Add more tests
This commit is contained in:
408
tests/test_main_tool.py
Normal file
408
tests/test_main_tool.py
Normal file
@@ -0,0 +1,408 @@
|
||||
"""Tests for the main knowledge_search tool."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from knowledge_search_mcp.__main__ import knowledge_search
|
||||
from knowledge_search_mcp.models import AppContext, SourceNamespace, SearchResult
|
||||
|
||||
|
||||
class TestKnowledgeSearch:
|
||||
"""Tests for knowledge_search tool function."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_context(self):
|
||||
"""Create a mock AppContext."""
|
||||
app = MagicMock(spec=AppContext)
|
||||
|
||||
# Mock genai_client
|
||||
app.genai_client = MagicMock()
|
||||
|
||||
# Mock vector_search
|
||||
app.vector_search = MagicMock()
|
||||
app.vector_search.async_run_query = AsyncMock()
|
||||
|
||||
# Mock settings
|
||||
app.settings = MagicMock()
|
||||
app.settings.embedding_model = "models/text-embedding-004"
|
||||
app.settings.deployed_index_id = "test-deployed-index"
|
||||
app.settings.search_limit = 10
|
||||
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context(self, mock_app_context):
|
||||
"""Create a mock MCP Context."""
|
||||
ctx = MagicMock()
|
||||
ctx.request_context = MagicMock()
|
||||
ctx.request_context.lifespan_context = mock_app_context
|
||||
return ctx
|
||||
|
||||
@pytest.fixture
|
||||
def sample_embedding(self):
|
||||
"""Create a sample embedding vector."""
|
||||
return [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
|
||||
@pytest.fixture
|
||||
def sample_search_results(self):
|
||||
"""Create sample search results."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1.txt", "distance": 0.95, "content": "First document content"},
|
||||
{"id": "doc2.txt", "distance": 0.85, "content": "Second document content"},
|
||||
{"id": "doc3.txt", "distance": 0.75, "content": "Third document content"},
|
||||
]
|
||||
return results
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_successful_search(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test successful search workflow."""
|
||||
# Setup mocks
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = sample_search_results
|
||||
mock_format.return_value = "<document 1 name=doc1.txt>\nFirst document content\n</document 1>"
|
||||
|
||||
# Execute
|
||||
result = await knowledge_search("What is financial education?", mock_context)
|
||||
|
||||
# Assert
|
||||
assert result == "<document 1 name=doc1.txt>\nFirst document content\n</document 1>"
|
||||
mock_generate.assert_called_once()
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_called_once_with(
|
||||
deployed_index_id="test-deployed-index",
|
||||
query=sample_embedding,
|
||||
limit=10,
|
||||
source=None,
|
||||
)
|
||||
mock_filter.assert_called_once_with(sample_search_results)
|
||||
mock_format.assert_called_once_with(sample_search_results)
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
async def test_embedding_generation_error(self, mock_generate, mock_context):
|
||||
"""Test handling of embedding generation error."""
|
||||
# Setup mock to return error
|
||||
mock_generate.return_value = ([], "Error: API rate limit exceeded. Please try again later.")
|
||||
|
||||
# Execute
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
# Assert
|
||||
assert result == "Error: API rate limit exceeded. Please try again later."
|
||||
mock_generate.assert_called_once()
|
||||
# Vector search should not be called
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_not_called()
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
async def test_empty_query_error(self, mock_generate, mock_context):
|
||||
"""Test handling of empty query."""
|
||||
# Setup mock to return error for empty query
|
||||
mock_generate.return_value = ([], "Error: Query cannot be empty")
|
||||
|
||||
# Execute
|
||||
result = await knowledge_search("", mock_context)
|
||||
|
||||
# Assert
|
||||
assert result == "Error: Query cannot be empty"
|
||||
mock_generate.assert_called_once()
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
async def test_vector_search_error(self, mock_generate, mock_context, sample_embedding):
|
||||
"""Test handling of vector search error."""
|
||||
# Setup mocks
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = Exception(
|
||||
"Vector search service unavailable"
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
# Assert
|
||||
assert "Error performing vector search:" in result
|
||||
assert "Vector search service unavailable" in result
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_empty_search_results(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding
|
||||
):
|
||||
"""Test handling of empty search results."""
|
||||
# Setup mocks
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = []
|
||||
mock_filter.return_value = []
|
||||
mock_format.return_value = "No relevant documents found for your query."
|
||||
|
||||
# Execute
|
||||
result = await knowledge_search("obscure query", mock_context)
|
||||
|
||||
# Assert
|
||||
assert result == "No relevant documents found for your query."
|
||||
mock_format.assert_called_once_with([])
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_filtered_results_empty(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test when filtering removes all results."""
|
||||
# Setup mocks - results exist but get filtered out
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = [] # All filtered out
|
||||
mock_format.return_value = "No relevant documents found for your query."
|
||||
|
||||
# Execute
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
# Assert
|
||||
assert result == "No relevant documents found for your query."
|
||||
mock_filter.assert_called_once_with(sample_search_results)
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_source_filter_parameter(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test that source filter is passed correctly to vector search."""
|
||||
# Setup mocks
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = sample_search_results
|
||||
mock_format.return_value = "formatted results"
|
||||
|
||||
# Execute with source filter
|
||||
source_filter = SourceNamespace.EDUCACION_FINANCIERA
|
||||
result = await knowledge_search("test query", mock_context, source=source_filter)
|
||||
|
||||
# Assert
|
||||
assert result == "formatted results"
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_called_once_with(
|
||||
deployed_index_id="test-deployed-index",
|
||||
query=sample_embedding,
|
||||
limit=10,
|
||||
source=source_filter,
|
||||
)
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_all_source_filters(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test all available source filter values."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = sample_search_results
|
||||
mock_format.return_value = "results"
|
||||
|
||||
# Test each source filter
|
||||
for source in SourceNamespace:
|
||||
result = await knowledge_search("test query", mock_context, source=source)
|
||||
assert result == "results"
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
async def test_vector_search_timeout(self, mock_generate, mock_context, sample_embedding):
|
||||
"""Test handling of vector search timeout."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = TimeoutError(
|
||||
"Request timed out"
|
||||
)
|
||||
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert "Error performing vector search:" in result
|
||||
assert "Request timed out" in result
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
async def test_vector_search_connection_error(self, mock_generate, mock_context, sample_embedding):
|
||||
"""Test handling of vector search connection error."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = ConnectionError(
|
||||
"Connection refused"
|
||||
)
|
||||
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert "Error performing vector search:" in result
|
||||
assert "Connection refused" in result
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
async def test_format_results_unexpected_error(
|
||||
self,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test handling of unexpected error in format_search_results."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = sample_search_results
|
||||
|
||||
# Mock format_search_results to raise an error
|
||||
with patch('knowledge_search_mcp.__main__.format_search_results', side_effect=ValueError("Format error")):
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert "Unexpected error during search:" in result
|
||||
assert "Format error" in result
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
async def test_filter_results_unexpected_error(self, mock_generate, mock_context, sample_embedding):
|
||||
"""Test handling of unexpected error in filter_search_results."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = [
|
||||
{"id": "doc1", "distance": 0.9, "content": "test"}
|
||||
]
|
||||
|
||||
# Mock filter_search_results to raise an error
|
||||
with patch('knowledge_search_mcp.__main__.filter_search_results', side_effect=TypeError("Filter error")):
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert "Unexpected error during search:" in result
|
||||
assert "Filter error" in result
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_long_query_truncation_in_logs(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test that long queries are handled correctly."""
|
||||
# Setup mocks
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = sample_search_results
|
||||
mock_format.return_value = "results"
|
||||
|
||||
# Execute with very long query
|
||||
long_query = "a" * 500
|
||||
result = await knowledge_search(long_query, mock_context)
|
||||
|
||||
# Assert - should succeed
|
||||
assert result == "results"
|
||||
# Verify generate_query_embedding was called with full query
|
||||
assert mock_generate.call_args[0][2] == long_query
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_multiple_results_returned(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding
|
||||
):
|
||||
"""Test handling of multiple search results."""
|
||||
# Create larger result set
|
||||
large_results: list[SearchResult] = [
|
||||
{"id": f"doc{i}.txt", "distance": 0.9 - (i * 0.05), "content": f"Content {i}"}
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = large_results
|
||||
mock_filter.return_value = large_results[:5] # Filter to top 5
|
||||
mock_format.return_value = "formatted 5 results"
|
||||
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert result == "formatted 5 results"
|
||||
mock_filter.assert_called_once_with(large_results)
|
||||
mock_format.assert_called_once_with(large_results[:5])
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_settings_values_used_correctly(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test that settings values are used correctly."""
|
||||
# Customize settings
|
||||
mock_context.request_context.lifespan_context.settings.embedding_model = "custom-model"
|
||||
mock_context.request_context.lifespan_context.settings.deployed_index_id = "custom-index"
|
||||
mock_context.request_context.lifespan_context.settings.search_limit = 20
|
||||
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = sample_search_results
|
||||
mock_format.return_value = "results"
|
||||
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
# Verify embedding model
|
||||
assert mock_generate.call_args[0][1] == "custom-model"
|
||||
|
||||
# Verify vector search parameters
|
||||
call_kwargs = mock_context.request_context.lifespan_context.vector_search.async_run_query.call_args.kwargs
|
||||
assert call_kwargs["deployed_index_id"] == "custom-index"
|
||||
assert call_kwargs["limit"] == 20
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
async def test_graceful_degradation_on_partial_failure(
|
||||
self, mock_generate, mock_context, sample_embedding
|
||||
):
|
||||
"""Test that errors are caught and returned as strings, not raised."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = RuntimeError(
|
||||
"Critical failure"
|
||||
)
|
||||
|
||||
# Should not raise, should return error message
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "Error performing vector search:" in result
|
||||
assert "Critical failure" in result
|
||||
381
tests/test_search_services.py
Normal file
381
tests/test_search_services.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""Tests for search service functions."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from knowledge_search_mcp.services.search import (
|
||||
generate_query_embedding,
|
||||
filter_search_results,
|
||||
format_search_results,
|
||||
)
|
||||
from knowledge_search_mcp.models import SearchResult
|
||||
|
||||
|
||||
class TestGenerateQueryEmbedding:
|
||||
"""Tests for generate_query_embedding function."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_genai_client(self):
|
||||
"""Create a mock genai client."""
|
||||
client = MagicMock()
|
||||
client.aio = MagicMock()
|
||||
client.aio.models = MagicMock()
|
||||
return client
|
||||
|
||||
async def test_successful_embedding_generation(self, mock_genai_client):
|
||||
"""Test successful embedding generation."""
|
||||
# Setup mock response
|
||||
mock_response = MagicMock()
|
||||
mock_embedding = MagicMock()
|
||||
mock_embedding.values = [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
mock_response.embeddings = [mock_embedding]
|
||||
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Execute
|
||||
embedding, error = await generate_query_embedding(
|
||||
mock_genai_client,
|
||||
"models/text-embedding-004",
|
||||
"What is financial education?"
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert error is None
|
||||
assert embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
mock_genai_client.aio.models.embed_content.assert_called_once()
|
||||
call_kwargs = mock_genai_client.aio.models.embed_content.call_args.kwargs
|
||||
assert call_kwargs["model"] == "models/text-embedding-004"
|
||||
assert call_kwargs["contents"] == "What is financial education?"
|
||||
assert call_kwargs["config"].task_type == "RETRIEVAL_QUERY"
|
||||
|
||||
async def test_empty_query_string(self, mock_genai_client):
|
||||
"""Test handling of empty query string."""
|
||||
embedding, error = await generate_query_embedding(
|
||||
mock_genai_client,
|
||||
"models/text-embedding-004",
|
||||
""
|
||||
)
|
||||
|
||||
assert embedding == []
|
||||
assert error == "Error: Query cannot be empty"
|
||||
mock_genai_client.aio.models.embed_content.assert_not_called()
|
||||
|
||||
async def test_whitespace_only_query(self, mock_genai_client):
|
||||
"""Test handling of whitespace-only query."""
|
||||
embedding, error = await generate_query_embedding(
|
||||
mock_genai_client,
|
||||
"models/text-embedding-004",
|
||||
" \t\n "
|
||||
)
|
||||
|
||||
assert embedding == []
|
||||
assert error == "Error: Query cannot be empty"
|
||||
mock_genai_client.aio.models.embed_content.assert_not_called()
|
||||
|
||||
async def test_rate_limit_error_429(self, mock_genai_client):
|
||||
"""Test handling of 429 rate limit error."""
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||
side_effect=Exception("429 Too Many Requests")
|
||||
)
|
||||
|
||||
embedding, error = await generate_query_embedding(
|
||||
mock_genai_client,
|
||||
"models/text-embedding-004",
|
||||
"test query"
|
||||
)
|
||||
|
||||
assert embedding == []
|
||||
assert error == "Error: API rate limit exceeded. Please try again later."
|
||||
|
||||
async def test_rate_limit_error_resource_exhausted(self, mock_genai_client):
|
||||
"""Test handling of RESOURCE_EXHAUSTED error."""
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||
side_effect=Exception("RESOURCE_EXHAUSTED: Quota exceeded")
|
||||
)
|
||||
|
||||
embedding, error = await generate_query_embedding(
|
||||
mock_genai_client,
|
||||
"models/text-embedding-004",
|
||||
"test query"
|
||||
)
|
||||
|
||||
assert embedding == []
|
||||
assert error == "Error: API rate limit exceeded. Please try again later."
|
||||
|
||||
async def test_generic_api_error(self, mock_genai_client):
|
||||
"""Test handling of generic API error."""
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||
side_effect=ValueError("Invalid model name")
|
||||
)
|
||||
|
||||
embedding, error = await generate_query_embedding(
|
||||
mock_genai_client,
|
||||
"invalid-model",
|
||||
"test query"
|
||||
)
|
||||
|
||||
assert embedding == []
|
||||
assert "Error generating embedding: Invalid model name" in error
|
||||
|
||||
async def test_network_error(self, mock_genai_client):
|
||||
"""Test handling of network error."""
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||
side_effect=ConnectionError("Network unreachable")
|
||||
)
|
||||
|
||||
embedding, error = await generate_query_embedding(
|
||||
mock_genai_client,
|
||||
"models/text-embedding-004",
|
||||
"test query"
|
||||
)
|
||||
|
||||
assert embedding == []
|
||||
assert "Error generating embedding: Network unreachable" in error
|
||||
|
||||
async def test_long_query_truncation_in_logging(self, mock_genai_client):
|
||||
"""Test that long queries are truncated in error logging."""
|
||||
long_query = "a" * 200
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||
side_effect=Exception("API error")
|
||||
)
|
||||
|
||||
embedding, error = await generate_query_embedding(
|
||||
mock_genai_client,
|
||||
"models/text-embedding-004",
|
||||
long_query
|
||||
)
|
||||
|
||||
assert embedding == []
|
||||
assert error is not None
|
||||
|
||||
|
||||
class TestFilterSearchResults:
|
||||
"""Tests for filter_search_results function."""
|
||||
|
||||
def test_empty_results(self):
|
||||
"""Test filtering empty results list."""
|
||||
filtered = filter_search_results([])
|
||||
assert filtered == []
|
||||
|
||||
def test_single_result_above_thresholds(self):
|
||||
"""Test single result above both thresholds."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1", "distance": 0.85, "content": "test content"}
|
||||
]
|
||||
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]["id"] == "doc1"
|
||||
|
||||
def test_single_result_below_min_similarity(self):
|
||||
"""Test single result below minimum similarity threshold."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1", "distance": 0.5, "content": "test content"}
|
||||
]
|
||||
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
|
||||
assert filtered == []
|
||||
|
||||
def test_multiple_results_all_above_thresholds(self):
|
||||
"""Test multiple results all above thresholds."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1", "distance": 0.95, "content": "content 1"},
|
||||
{"id": "doc2", "distance": 0.90, "content": "content 2"},
|
||||
{"id": "doc3", "distance": 0.85, "content": "content 3"},
|
||||
]
|
||||
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.8)
|
||||
# max_sim = 0.95, cutoff = 0.95 * 0.8 = 0.76
|
||||
# Results with distance > 0.76 and > 0.6: all three
|
||||
assert len(filtered) == 3
|
||||
|
||||
def test_top_percent_filtering(self):
|
||||
"""Test filtering by top_percent threshold."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1", "distance": 1.0, "content": "content 1"},
|
||||
{"id": "doc2", "distance": 0.95, "content": "content 2"},
|
||||
{"id": "doc3", "distance": 0.85, "content": "content 3"},
|
||||
{"id": "doc4", "distance": 0.70, "content": "content 4"},
|
||||
]
|
||||
# max_sim = 1.0, cutoff = 1.0 * 0.9 = 0.9
|
||||
# Results with distance > 0.9: doc1 (1.0), doc2 (0.95)
|
||||
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
|
||||
assert len(filtered) == 2
|
||||
assert filtered[0]["id"] == "doc1"
|
||||
assert filtered[1]["id"] == "doc2"
|
||||
|
||||
def test_min_similarity_filtering(self):
|
||||
"""Test filtering by minimum similarity threshold."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1", "distance": 0.95, "content": "content 1"},
|
||||
{"id": "doc2", "distance": 0.75, "content": "content 2"},
|
||||
{"id": "doc3", "distance": 0.55, "content": "content 3"},
|
||||
]
|
||||
# max_sim = 0.95, cutoff = 0.95 * 0.9 = 0.855
|
||||
# doc1 > 0.855 and > 0.7: included
|
||||
# doc2 < 0.855: excluded by top_percent
|
||||
# doc3 < 0.7: excluded by min_similarity
|
||||
filtered = filter_search_results(results, min_similarity=0.7, top_percent=0.9)
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]["id"] == "doc1"
|
||||
|
||||
def test_default_parameters(self):
|
||||
"""Test filtering with default parameters."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1", "distance": 0.95, "content": "content 1"},
|
||||
{"id": "doc2", "distance": 0.85, "content": "content 2"},
|
||||
{"id": "doc3", "distance": 0.50, "content": "content 3"},
|
||||
]
|
||||
# Default: min_similarity=0.6, top_percent=0.9
|
||||
# max_sim = 0.95, cutoff = 0.95 * 0.9 = 0.855
|
||||
# doc1 > 0.855 and > 0.6: included
|
||||
# doc2 < 0.855: excluded
|
||||
# doc3 < 0.6: excluded
|
||||
filtered = filter_search_results(results)
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]["id"] == "doc1"
|
||||
|
||||
def test_all_results_filtered_out(self):
|
||||
"""Test when all results are filtered out."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1", "distance": 0.55, "content": "content 1"},
|
||||
{"id": "doc2", "distance": 0.45, "content": "content 2"},
|
||||
{"id": "doc3", "distance": 0.35, "content": "content 3"},
|
||||
]
|
||||
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
|
||||
assert filtered == []
|
||||
|
||||
def test_exact_threshold_boundaries(self):
|
||||
"""Test behavior at exact threshold boundaries."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1", "distance": 0.9, "content": "content 1"},
|
||||
{"id": "doc2", "distance": 0.6, "content": "content 2"},
|
||||
]
|
||||
# max_sim = 0.9, cutoff = 0.9 * 0.9 = 0.81
|
||||
# doc1: 0.9 > 0.81 and 0.9 > 0.6: included
|
||||
# doc2: 0.6 < 0.81: excluded
|
||||
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]["id"] == "doc1"
|
||||
|
||||
def test_identical_distances(self):
|
||||
"""Test filtering with identical distance values."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1", "distance": 0.8, "content": "content 1"},
|
||||
{"id": "doc2", "distance": 0.8, "content": "content 2"},
|
||||
{"id": "doc3", "distance": 0.8, "content": "content 3"},
|
||||
]
|
||||
# max_sim = 0.8, cutoff = 0.8 * 0.9 = 0.72
|
||||
# All have distance 0.8 > 0.72 and > 0.6: all included
|
||||
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
|
||||
assert len(filtered) == 3
|
||||
|
||||
|
||||
class TestFormatSearchResults:
|
||||
"""Tests for format_search_results function."""
|
||||
|
||||
def test_empty_results(self):
|
||||
"""Test formatting empty results list."""
|
||||
formatted = format_search_results([])
|
||||
assert formatted == "No relevant documents found for your query."
|
||||
|
||||
def test_single_result(self):
|
||||
"""Test formatting single result."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1.txt", "distance": 0.95, "content": "This is the content."}
|
||||
]
|
||||
formatted = format_search_results(results)
|
||||
expected = "<document 1 name=doc1.txt>\nThis is the content.\n</document 1>"
|
||||
assert formatted == expected
|
||||
|
||||
def test_multiple_results(self):
|
||||
"""Test formatting multiple results."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1.txt", "distance": 0.95, "content": "First document content."},
|
||||
{"id": "doc2.txt", "distance": 0.85, "content": "Second document content."},
|
||||
{"id": "doc3.txt", "distance": 0.75, "content": "Third document content."},
|
||||
]
|
||||
formatted = format_search_results(results)
|
||||
expected = (
|
||||
"<document 1 name=doc1.txt>\nFirst document content.\n</document 1>\n"
|
||||
"<document 2 name=doc2.txt>\nSecond document content.\n</document 2>\n"
|
||||
"<document 3 name=doc3.txt>\nThird document content.\n</document 3>"
|
||||
)
|
||||
assert formatted == expected
|
||||
|
||||
def test_multiline_content(self):
|
||||
"""Test formatting results with multiline content."""
|
||||
results: list[SearchResult] = [
|
||||
{
|
||||
"id": "doc1.txt",
|
||||
"distance": 0.95,
|
||||
"content": "Line 1\nLine 2\nLine 3"
|
||||
}
|
||||
]
|
||||
formatted = format_search_results(results)
|
||||
expected = "<document 1 name=doc1.txt>\nLine 1\nLine 2\nLine 3\n</document 1>"
|
||||
assert formatted == expected
|
||||
|
||||
def test_special_characters_in_content(self):
|
||||
"""Test formatting with special characters in content."""
|
||||
results: list[SearchResult] = [
|
||||
{
|
||||
"id": "doc1.txt",
|
||||
"distance": 0.95,
|
||||
"content": "Content with <special> & \"characters\""
|
||||
}
|
||||
]
|
||||
formatted = format_search_results(results)
|
||||
expected = '<document 1 name=doc1.txt>\nContent with <special> & "characters"\n</document 1>'
|
||||
assert formatted == expected
|
||||
|
||||
def test_special_characters_in_document_id(self):
|
||||
"""Test formatting with special characters in document ID."""
|
||||
results: list[SearchResult] = [
|
||||
{
|
||||
"id": "path/to/doc-name_v2.txt",
|
||||
"distance": 0.95,
|
||||
"content": "Some content"
|
||||
}
|
||||
]
|
||||
formatted = format_search_results(results)
|
||||
expected = "<document 1 name=path/to/doc-name_v2.txt>\nSome content\n</document 1>"
|
||||
assert formatted == expected
|
||||
|
||||
def test_empty_content(self):
|
||||
"""Test formatting result with empty content."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1.txt", "distance": 0.95, "content": ""}
|
||||
]
|
||||
formatted = format_search_results(results)
|
||||
expected = "<document 1 name=doc1.txt>\n\n</document 1>"
|
||||
assert formatted == expected
|
||||
|
||||
def test_document_numbering(self):
|
||||
"""Test that document numbering starts at 1 and increments correctly."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "a.txt", "distance": 0.9, "content": "A"},
|
||||
{"id": "b.txt", "distance": 0.8, "content": "B"},
|
||||
{"id": "c.txt", "distance": 0.7, "content": "C"},
|
||||
{"id": "d.txt", "distance": 0.6, "content": "D"},
|
||||
{"id": "e.txt", "distance": 0.5, "content": "E"},
|
||||
]
|
||||
formatted = format_search_results(results)
|
||||
|
||||
assert "<document 1 name=a.txt>" in formatted
|
||||
assert "</document 1>" in formatted
|
||||
assert "<document 2 name=b.txt>" in formatted
|
||||
assert "</document 2>" in formatted
|
||||
assert "<document 3 name=c.txt>" in formatted
|
||||
assert "</document 3>" in formatted
|
||||
assert "<document 4 name=d.txt>" in formatted
|
||||
assert "</document 4>" in formatted
|
||||
assert "<document 5 name=e.txt>" in formatted
|
||||
assert "</document 5>" in formatted
|
||||
|
||||
def test_very_long_content(self):
|
||||
"""Test formatting with very long content."""
|
||||
long_content = "A" * 10000
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1.txt", "distance": 0.95, "content": long_content}
|
||||
]
|
||||
formatted = format_search_results(results)
|
||||
assert f"<document 1 name=doc1.txt>\n{long_content}\n</document 1>" == formatted
|
||||
assert len(formatted) > 10000
|
||||
436
tests/test_validation_services.py
Normal file
436
tests/test_validation_services.py
Normal file
@@ -0,0 +1,436 @@
|
||||
"""Tests for validation service functions."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from aiohttp import ClientResponse
|
||||
|
||||
from knowledge_search_mcp.services.validation import (
|
||||
validate_genai_access,
|
||||
validate_gcs_access,
|
||||
validate_vector_search_access,
|
||||
)
|
||||
from knowledge_search_mcp.config import Settings
|
||||
|
||||
|
||||
class TestValidateGenAIAccess:
|
||||
"""Tests for validate_genai_access function."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings(self):
|
||||
"""Create mock settings."""
|
||||
settings = MagicMock(spec=Settings)
|
||||
settings.embedding_model = "models/text-embedding-004"
|
||||
settings.project_id = "test-project"
|
||||
settings.location = "us-central1"
|
||||
return settings
|
||||
|
||||
@pytest.fixture
|
||||
def mock_genai_client(self):
|
||||
"""Create a mock genai client."""
|
||||
client = MagicMock()
|
||||
client.aio = MagicMock()
|
||||
client.aio.models = MagicMock()
|
||||
return client
|
||||
|
||||
async def test_successful_validation(self, mock_genai_client, mock_settings):
|
||||
"""Test successful GenAI access validation."""
|
||||
# Setup mock response
|
||||
mock_response = MagicMock()
|
||||
mock_embedding = MagicMock()
|
||||
mock_embedding.values = [0.1] * 768 # Typical embedding dimension
|
||||
mock_response.embeddings = [mock_embedding]
|
||||
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Execute
|
||||
error = await validate_genai_access(mock_genai_client, mock_settings)
|
||||
|
||||
# Assert
|
||||
assert error is None
|
||||
mock_genai_client.aio.models.embed_content.assert_called_once()
|
||||
call_kwargs = mock_genai_client.aio.models.embed_content.call_args.kwargs
|
||||
assert call_kwargs["model"] == "models/text-embedding-004"
|
||||
assert call_kwargs["contents"] == "test"
|
||||
assert call_kwargs["config"].task_type == "RETRIEVAL_QUERY"
|
||||
|
||||
async def test_empty_response(self, mock_genai_client, mock_settings):
|
||||
"""Test handling of empty response."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.embeddings = []
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
error = await validate_genai_access(mock_genai_client, mock_settings)
|
||||
|
||||
assert error == "Embedding validation returned empty response"
|
||||
|
||||
async def test_none_response(self, mock_genai_client, mock_settings):
|
||||
"""Test handling of None response."""
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(return_value=None)
|
||||
|
||||
error = await validate_genai_access(mock_genai_client, mock_settings)
|
||||
|
||||
assert error == "Embedding validation returned empty response"
|
||||
|
||||
async def test_api_permission_error(self, mock_genai_client, mock_settings):
|
||||
"""Test handling of permission denied error."""
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||
side_effect=PermissionError("Permission denied for GenAI API")
|
||||
)
|
||||
|
||||
error = await validate_genai_access(mock_genai_client, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "GenAI:" in error
|
||||
assert "Permission denied for GenAI API" in error
|
||||
|
||||
async def test_api_quota_error(self, mock_genai_client, mock_settings):
|
||||
"""Test handling of quota exceeded error."""
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||
side_effect=Exception("Quota exceeded")
|
||||
)
|
||||
|
||||
error = await validate_genai_access(mock_genai_client, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "GenAI:" in error
|
||||
assert "Quota exceeded" in error
|
||||
|
||||
async def test_network_error(self, mock_genai_client, mock_settings):
|
||||
"""Test handling of network error."""
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||
side_effect=ConnectionError("Network unreachable")
|
||||
)
|
||||
|
||||
error = await validate_genai_access(mock_genai_client, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "GenAI:" in error
|
||||
assert "Network unreachable" in error
|
||||
|
||||
async def test_invalid_model_error(self, mock_genai_client, mock_settings):
|
||||
"""Test handling of invalid model error."""
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||
side_effect=ValueError("Invalid model name")
|
||||
)
|
||||
|
||||
error = await validate_genai_access(mock_genai_client, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "GenAI:" in error
|
||||
assert "Invalid model name" in error
|
||||
|
||||
async def test_embeddings_with_zero_values(self, mock_genai_client, mock_settings):
|
||||
"""Test validation with empty embedding values."""
|
||||
mock_response = MagicMock()
|
||||
mock_embedding = MagicMock()
|
||||
mock_embedding.values = []
|
||||
mock_response.embeddings = [mock_embedding]
|
||||
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
error = await validate_genai_access(mock_genai_client, mock_settings)
|
||||
|
||||
# Should succeed even with empty values, as long as embeddings exist
|
||||
assert error is None
|
||||
|
||||
|
||||
class TestValidateGCSAccess:
|
||||
"""Tests for validate_gcs_access function."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings(self):
|
||||
"""Create mock settings."""
|
||||
settings = MagicMock(spec=Settings)
|
||||
settings.bucket = "test-bucket"
|
||||
settings.project_id = "test-project"
|
||||
return settings
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_search(self):
|
||||
"""Create a mock vector search client."""
|
||||
vs = MagicMock()
|
||||
vs.storage = MagicMock()
|
||||
return vs
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Create a mock aiohttp session."""
|
||||
session = MagicMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response(self):
|
||||
"""Create a mock HTTP response."""
|
||||
response = MagicMock()
|
||||
response.text = AsyncMock(return_value='{"items": []}')
|
||||
return response
|
||||
|
||||
async def test_successful_validation(self, mock_vector_search, mock_settings, mock_session, mock_response):
|
||||
"""Test successful GCS bucket access validation."""
|
||||
# Setup mocks
|
||||
mock_response.status = 200
|
||||
mock_response.ok = True
|
||||
mock_session.get = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_vector_search.storage._get_aio_session.return_value = mock_session
|
||||
|
||||
with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
|
||||
mock_token = MockToken.return_value
|
||||
mock_token.get = AsyncMock(return_value="fake-access-token")
|
||||
|
||||
error = await validate_gcs_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is None
|
||||
mock_session.get.assert_called_once()
|
||||
call_args = mock_session.get.call_args
|
||||
assert "test-bucket" in call_args[0][0]
|
||||
assert call_args[1]["headers"]["Authorization"] == "Bearer fake-access-token"
|
||||
|
||||
async def test_access_denied_403(self, mock_vector_search, mock_settings, mock_session, mock_response):
|
||||
"""Test handling of 403 access denied."""
|
||||
mock_response.status = 403
|
||||
mock_response.ok = False
|
||||
mock_session.get = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_vector_search.storage._get_aio_session.return_value = mock_session
|
||||
|
||||
with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
|
||||
mock_token = MockToken.return_value
|
||||
mock_token.get = AsyncMock(return_value="fake-access-token")
|
||||
|
||||
error = await validate_gcs_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "Access denied to bucket 'test-bucket'" in error
|
||||
assert "permissions" in error.lower()
|
||||
|
||||
async def test_bucket_not_found_404(self, mock_vector_search, mock_settings, mock_session, mock_response):
|
||||
"""Test handling of 404 bucket not found."""
|
||||
mock_response.status = 404
|
||||
mock_response.ok = False
|
||||
mock_session.get = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_vector_search.storage._get_aio_session.return_value = mock_session
|
||||
|
||||
with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
|
||||
mock_token = MockToken.return_value
|
||||
mock_token.get = AsyncMock(return_value="fake-access-token")
|
||||
|
||||
error = await validate_gcs_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "Bucket 'test-bucket' not found" in error
|
||||
assert "bucket name" in error.lower()
|
||||
|
||||
async def test_server_error_500(self, mock_vector_search, mock_settings, mock_session, mock_response):
|
||||
"""Test handling of 500 server error."""
|
||||
mock_response.status = 500
|
||||
mock_response.ok = False
|
||||
mock_response.text = AsyncMock(return_value='{"error": "Internal server error"}')
|
||||
mock_session.get = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_vector_search.storage._get_aio_session.return_value = mock_session
|
||||
|
||||
with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
|
||||
mock_token = MockToken.return_value
|
||||
mock_token.get = AsyncMock(return_value="fake-access-token")
|
||||
|
||||
error = await validate_gcs_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "Failed to access bucket 'test-bucket': 500" in error
|
||||
|
||||
async def test_token_acquisition_error(self, mock_vector_search, mock_settings, mock_session):
|
||||
"""Test handling of token acquisition error."""
|
||||
mock_vector_search.storage._get_aio_session.return_value = mock_session
|
||||
|
||||
with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
|
||||
mock_token = MockToken.return_value
|
||||
mock_token.get = AsyncMock(side_effect=Exception("Failed to get access token"))
|
||||
|
||||
error = await validate_gcs_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "GCS:" in error
|
||||
assert "Failed to get access token" in error
|
||||
|
||||
async def test_network_error(self, mock_vector_search, mock_settings, mock_session):
|
||||
"""Test handling of network error."""
|
||||
mock_session.get = MagicMock(side_effect=ConnectionError("Network unreachable"))
|
||||
mock_vector_search.storage._get_aio_session.return_value = mock_session
|
||||
|
||||
with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
|
||||
mock_token = MockToken.return_value
|
||||
mock_token.get = AsyncMock(return_value="fake-access-token")
|
||||
|
||||
error = await validate_gcs_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "GCS:" in error
|
||||
assert "Network unreachable" in error
|
||||
|
||||
|
||||
class TestValidateVectorSearchAccess:
|
||||
"""Tests for validate_vector_search_access function."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings(self):
|
||||
"""Create mock settings."""
|
||||
settings = MagicMock(spec=Settings)
|
||||
settings.endpoint_name = "projects/test/locations/us-central1/indexEndpoints/test-endpoint"
|
||||
settings.location = "us-central1"
|
||||
return settings
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_search(self):
|
||||
"""Create a mock vector search client."""
|
||||
vs = MagicMock()
|
||||
vs._async_get_auth_headers = AsyncMock(return_value={"Authorization": "Bearer fake-token"})
|
||||
return vs
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
"""Create a mock aiohttp session."""
|
||||
session = MagicMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response(self):
|
||||
"""Create a mock HTTP response."""
|
||||
response = MagicMock()
|
||||
response.text = AsyncMock(return_value='{"name": "test-endpoint"}')
|
||||
return response
|
||||
|
||||
async def test_successful_validation(self, mock_vector_search, mock_settings, mock_session, mock_response):
|
||||
"""Test successful vector search endpoint validation."""
|
||||
mock_response.status = 200
|
||||
mock_response.ok = True
|
||||
mock_session.get = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_vector_search._get_aio_session.return_value = mock_session
|
||||
|
||||
error = await validate_vector_search_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is None
|
||||
mock_vector_search._async_get_auth_headers.assert_called_once()
|
||||
mock_session.get.assert_called_once()
|
||||
call_args = mock_session.get.call_args
|
||||
assert "us-central1-aiplatform.googleapis.com" in call_args[0][0]
|
||||
assert "test-endpoint" in call_args[0][0]
|
||||
assert call_args[1]["headers"]["Authorization"] == "Bearer fake-token"
|
||||
|
||||
async def test_access_denied_403(self, mock_vector_search, mock_settings, mock_session, mock_response):
|
||||
"""Test handling of 403 access denied."""
|
||||
mock_response.status = 403
|
||||
mock_response.ok = False
|
||||
mock_session.get = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_vector_search._get_aio_session.return_value = mock_session
|
||||
|
||||
error = await validate_vector_search_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "Access denied to endpoint" in error
|
||||
assert "test-endpoint" in error
|
||||
assert "permissions" in error.lower()
|
||||
|
||||
async def test_endpoint_not_found_404(self, mock_vector_search, mock_settings, mock_session, mock_response):
|
||||
"""Test handling of 404 endpoint not found."""
|
||||
mock_response.status = 404
|
||||
mock_response.ok = False
|
||||
mock_session.get = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_vector_search._get_aio_session.return_value = mock_session
|
||||
|
||||
error = await validate_vector_search_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "not found" in error.lower()
|
||||
assert "test-endpoint" in error
|
||||
|
||||
async def test_server_error_503(self, mock_vector_search, mock_settings, mock_session, mock_response):
|
||||
"""Test handling of 503 service unavailable."""
|
||||
mock_response.status = 503
|
||||
mock_response.ok = False
|
||||
mock_response.text = AsyncMock(return_value='{"error": "Service unavailable"}')
|
||||
mock_session.get = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_vector_search._get_aio_session.return_value = mock_session
|
||||
|
||||
error = await validate_vector_search_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "Failed to access endpoint" in error
|
||||
assert "503" in error
|
||||
|
||||
async def test_auth_header_error(self, mock_vector_search, mock_settings):
|
||||
"""Test handling of authentication header error."""
|
||||
mock_vector_search._async_get_auth_headers = AsyncMock(
|
||||
side_effect=Exception("Failed to get auth headers")
|
||||
)
|
||||
|
||||
error = await validate_vector_search_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "Vector Search:" in error
|
||||
assert "Failed to get auth headers" in error
|
||||
|
||||
async def test_network_timeout(self, mock_vector_search, mock_settings, mock_session):
|
||||
"""Test handling of network timeout."""
|
||||
mock_session.get = MagicMock(side_effect=TimeoutError("Request timed out"))
|
||||
mock_vector_search._get_aio_session.return_value = mock_session
|
||||
|
||||
error = await validate_vector_search_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "Vector Search:" in error
|
||||
assert "Request timed out" in error
|
||||
|
||||
async def test_connection_error(self, mock_vector_search, mock_settings, mock_session):
|
||||
"""Test handling of connection error."""
|
||||
mock_session.get = MagicMock(side_effect=ConnectionError("Connection refused"))
|
||||
mock_vector_search._get_aio_session.return_value = mock_session
|
||||
|
||||
error = await validate_vector_search_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is not None
|
||||
assert "Vector Search:" in error
|
||||
assert "Connection refused" in error
|
||||
|
||||
async def test_endpoint_url_construction(self, mock_vector_search, mock_settings, mock_session, mock_response):
|
||||
"""Test that endpoint URL is constructed correctly."""
|
||||
mock_response.status = 200
|
||||
mock_response.ok = True
|
||||
mock_session.get = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_vector_search._get_aio_session.return_value = mock_session
|
||||
|
||||
# Custom location
|
||||
mock_settings.location = "europe-west1"
|
||||
mock_settings.endpoint_name = "projects/my-project/locations/europe-west1/indexEndpoints/my-endpoint"
|
||||
|
||||
error = await validate_vector_search_access(mock_vector_search, mock_settings)
|
||||
|
||||
assert error is None
|
||||
call_args = mock_session.get.call_args
|
||||
url = call_args[0][0]
|
||||
assert "europe-west1-aiplatform.googleapis.com" in url
|
||||
assert "my-endpoint" in url
|
||||
Reference in New Issue
Block a user