Add more tests

This commit is contained in:
2026-03-04 04:55:21 +00:00
parent f6e122b5a9
commit d69c4e4f4a
6 changed files with 1326 additions and 136 deletions

408
tests/test_main_tool.py Normal file
View File

@@ -0,0 +1,408 @@
"""Tests for the main knowledge_search tool."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from knowledge_search_mcp.__main__ import knowledge_search
from knowledge_search_mcp.models import AppContext, SourceNamespace, SearchResult
class TestKnowledgeSearch:
"""Tests for knowledge_search tool function."""
@pytest.fixture
def mock_app_context(self):
"""Create a mock AppContext."""
app = MagicMock(spec=AppContext)
# Mock genai_client
app.genai_client = MagicMock()
# Mock vector_search
app.vector_search = MagicMock()
app.vector_search.async_run_query = AsyncMock()
# Mock settings
app.settings = MagicMock()
app.settings.embedding_model = "models/text-embedding-004"
app.settings.deployed_index_id = "test-deployed-index"
app.settings.search_limit = 10
return app
@pytest.fixture
def mock_context(self, mock_app_context):
"""Create a mock MCP Context."""
ctx = MagicMock()
ctx.request_context = MagicMock()
ctx.request_context.lifespan_context = mock_app_context
return ctx
@pytest.fixture
def sample_embedding(self):
"""Create a sample embedding vector."""
return [0.1, 0.2, 0.3, 0.4, 0.5]
@pytest.fixture
def sample_search_results(self):
"""Create sample search results."""
results: list[SearchResult] = [
{"id": "doc1.txt", "distance": 0.95, "content": "First document content"},
{"id": "doc2.txt", "distance": 0.85, "content": "Second document content"},
{"id": "doc3.txt", "distance": 0.75, "content": "Third document content"},
]
return results
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
@patch('knowledge_search_mcp.__main__.filter_search_results')
@patch('knowledge_search_mcp.__main__.format_search_results')
async def test_successful_search(
self,
mock_format,
mock_filter,
mock_generate,
mock_context,
sample_embedding,
sample_search_results
):
"""Test successful search workflow."""
# Setup mocks
mock_generate.return_value = (sample_embedding, None)
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
mock_filter.return_value = sample_search_results
mock_format.return_value = "<document 1 name=doc1.txt>\nFirst document content\n</document 1>"
# Execute
result = await knowledge_search("What is financial education?", mock_context)
# Assert
assert result == "<document 1 name=doc1.txt>\nFirst document content\n</document 1>"
mock_generate.assert_called_once()
mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_called_once_with(
deployed_index_id="test-deployed-index",
query=sample_embedding,
limit=10,
source=None,
)
mock_filter.assert_called_once_with(sample_search_results)
mock_format.assert_called_once_with(sample_search_results)
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
async def test_embedding_generation_error(self, mock_generate, mock_context):
"""Test handling of embedding generation error."""
# Setup mock to return error
mock_generate.return_value = ([], "Error: API rate limit exceeded. Please try again later.")
# Execute
result = await knowledge_search("test query", mock_context)
# Assert
assert result == "Error: API rate limit exceeded. Please try again later."
mock_generate.assert_called_once()
# Vector search should not be called
mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_not_called()
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
async def test_empty_query_error(self, mock_generate, mock_context):
"""Test handling of empty query."""
# Setup mock to return error for empty query
mock_generate.return_value = ([], "Error: Query cannot be empty")
# Execute
result = await knowledge_search("", mock_context)
# Assert
assert result == "Error: Query cannot be empty"
mock_generate.assert_called_once()
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
async def test_vector_search_error(self, mock_generate, mock_context, sample_embedding):
"""Test handling of vector search error."""
# Setup mocks
mock_generate.return_value = (sample_embedding, None)
mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = Exception(
"Vector search service unavailable"
)
# Execute
result = await knowledge_search("test query", mock_context)
# Assert
assert "Error performing vector search:" in result
assert "Vector search service unavailable" in result
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
@patch('knowledge_search_mcp.__main__.filter_search_results')
@patch('knowledge_search_mcp.__main__.format_search_results')
async def test_empty_search_results(
self,
mock_format,
mock_filter,
mock_generate,
mock_context,
sample_embedding
):
"""Test handling of empty search results."""
# Setup mocks
mock_generate.return_value = (sample_embedding, None)
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = []
mock_filter.return_value = []
mock_format.return_value = "No relevant documents found for your query."
# Execute
result = await knowledge_search("obscure query", mock_context)
# Assert
assert result == "No relevant documents found for your query."
mock_format.assert_called_once_with([])
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
@patch('knowledge_search_mcp.__main__.filter_search_results')
@patch('knowledge_search_mcp.__main__.format_search_results')
async def test_filtered_results_empty(
self,
mock_format,
mock_filter,
mock_generate,
mock_context,
sample_embedding,
sample_search_results
):
"""Test when filtering removes all results."""
# Setup mocks - results exist but get filtered out
mock_generate.return_value = (sample_embedding, None)
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
mock_filter.return_value = [] # All filtered out
mock_format.return_value = "No relevant documents found for your query."
# Execute
result = await knowledge_search("test query", mock_context)
# Assert
assert result == "No relevant documents found for your query."
mock_filter.assert_called_once_with(sample_search_results)
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
@patch('knowledge_search_mcp.__main__.filter_search_results')
@patch('knowledge_search_mcp.__main__.format_search_results')
async def test_source_filter_parameter(
self,
mock_format,
mock_filter,
mock_generate,
mock_context,
sample_embedding,
sample_search_results
):
"""Test that source filter is passed correctly to vector search."""
# Setup mocks
mock_generate.return_value = (sample_embedding, None)
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
mock_filter.return_value = sample_search_results
mock_format.return_value = "formatted results"
# Execute with source filter
source_filter = SourceNamespace.EDUCACION_FINANCIERA
result = await knowledge_search("test query", mock_context, source=source_filter)
# Assert
assert result == "formatted results"
mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_called_once_with(
deployed_index_id="test-deployed-index",
query=sample_embedding,
limit=10,
source=source_filter,
)
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
@patch('knowledge_search_mcp.__main__.filter_search_results')
@patch('knowledge_search_mcp.__main__.format_search_results')
async def test_all_source_filters(
self,
mock_format,
mock_filter,
mock_generate,
mock_context,
sample_embedding,
sample_search_results
):
"""Test all available source filter values."""
mock_generate.return_value = (sample_embedding, None)
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
mock_filter.return_value = sample_search_results
mock_format.return_value = "results"
# Test each source filter
for source in SourceNamespace:
result = await knowledge_search("test query", mock_context, source=source)
assert result == "results"
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
async def test_vector_search_timeout(self, mock_generate, mock_context, sample_embedding):
"""Test handling of vector search timeout."""
mock_generate.return_value = (sample_embedding, None)
mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = TimeoutError(
"Request timed out"
)
result = await knowledge_search("test query", mock_context)
assert "Error performing vector search:" in result
assert "Request timed out" in result
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
async def test_vector_search_connection_error(self, mock_generate, mock_context, sample_embedding):
"""Test handling of vector search connection error."""
mock_generate.return_value = (sample_embedding, None)
mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = ConnectionError(
"Connection refused"
)
result = await knowledge_search("test query", mock_context)
assert "Error performing vector search:" in result
assert "Connection refused" in result
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
@patch('knowledge_search_mcp.__main__.filter_search_results')
async def test_format_results_unexpected_error(
self,
mock_filter,
mock_generate,
mock_context,
sample_embedding,
sample_search_results
):
"""Test handling of unexpected error in format_search_results."""
mock_generate.return_value = (sample_embedding, None)
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
mock_filter.return_value = sample_search_results
# Mock format_search_results to raise an error
with patch('knowledge_search_mcp.__main__.format_search_results', side_effect=ValueError("Format error")):
result = await knowledge_search("test query", mock_context)
assert "Unexpected error during search:" in result
assert "Format error" in result
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
async def test_filter_results_unexpected_error(self, mock_generate, mock_context, sample_embedding):
"""Test handling of unexpected error in filter_search_results."""
mock_generate.return_value = (sample_embedding, None)
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = [
{"id": "doc1", "distance": 0.9, "content": "test"}
]
# Mock filter_search_results to raise an error
with patch('knowledge_search_mcp.__main__.filter_search_results', side_effect=TypeError("Filter error")):
result = await knowledge_search("test query", mock_context)
assert "Unexpected error during search:" in result
assert "Filter error" in result
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
@patch('knowledge_search_mcp.__main__.filter_search_results')
@patch('knowledge_search_mcp.__main__.format_search_results')
async def test_long_query_truncation_in_logs(
self,
mock_format,
mock_filter,
mock_generate,
mock_context,
sample_embedding,
sample_search_results
):
"""Test that long queries are handled correctly."""
# Setup mocks
mock_generate.return_value = (sample_embedding, None)
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
mock_filter.return_value = sample_search_results
mock_format.return_value = "results"
# Execute with very long query
long_query = "a" * 500
result = await knowledge_search(long_query, mock_context)
# Assert - should succeed
assert result == "results"
# Verify generate_query_embedding was called with full query
assert mock_generate.call_args[0][2] == long_query
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
@patch('knowledge_search_mcp.__main__.filter_search_results')
@patch('knowledge_search_mcp.__main__.format_search_results')
async def test_multiple_results_returned(
self,
mock_format,
mock_filter,
mock_generate,
mock_context,
sample_embedding
):
"""Test handling of multiple search results."""
# Create larger result set
large_results: list[SearchResult] = [
{"id": f"doc{i}.txt", "distance": 0.9 - (i * 0.05), "content": f"Content {i}"}
for i in range(10)
]
mock_generate.return_value = (sample_embedding, None)
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = large_results
mock_filter.return_value = large_results[:5] # Filter to top 5
mock_format.return_value = "formatted 5 results"
result = await knowledge_search("test query", mock_context)
assert result == "formatted 5 results"
mock_filter.assert_called_once_with(large_results)
mock_format.assert_called_once_with(large_results[:5])
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
@patch('knowledge_search_mcp.__main__.filter_search_results')
@patch('knowledge_search_mcp.__main__.format_search_results')
async def test_settings_values_used_correctly(
self,
mock_format,
mock_filter,
mock_generate,
mock_context,
sample_embedding,
sample_search_results
):
"""Test that settings values are used correctly."""
# Customize settings
mock_context.request_context.lifespan_context.settings.embedding_model = "custom-model"
mock_context.request_context.lifespan_context.settings.deployed_index_id = "custom-index"
mock_context.request_context.lifespan_context.settings.search_limit = 20
mock_generate.return_value = (sample_embedding, None)
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
mock_filter.return_value = sample_search_results
mock_format.return_value = "results"
result = await knowledge_search("test query", mock_context)
# Verify embedding model
assert mock_generate.call_args[0][1] == "custom-model"
# Verify vector search parameters
call_kwargs = mock_context.request_context.lifespan_context.vector_search.async_run_query.call_args.kwargs
assert call_kwargs["deployed_index_id"] == "custom-index"
assert call_kwargs["limit"] == 20
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
async def test_graceful_degradation_on_partial_failure(
self, mock_generate, mock_context, sample_embedding
):
"""Test that errors are caught and returned as strings, not raised."""
mock_generate.return_value = (sample_embedding, None)
mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = RuntimeError(
"Critical failure"
)
# Should not raise, should return error message
result = await knowledge_search("test query", mock_context)
assert isinstance(result, str)
assert "Error performing vector search:" in result
assert "Critical failure" in result

View File

@@ -0,0 +1,381 @@
"""Tests for search service functions."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from knowledge_search_mcp.services.search import (
generate_query_embedding,
filter_search_results,
format_search_results,
)
from knowledge_search_mcp.models import SearchResult
class TestGenerateQueryEmbedding:
"""Tests for generate_query_embedding function."""
@pytest.fixture
def mock_genai_client(self):
"""Create a mock genai client."""
client = MagicMock()
client.aio = MagicMock()
client.aio.models = MagicMock()
return client
async def test_successful_embedding_generation(self, mock_genai_client):
"""Test successful embedding generation."""
# Setup mock response
mock_response = MagicMock()
mock_embedding = MagicMock()
mock_embedding.values = [0.1, 0.2, 0.3, 0.4, 0.5]
mock_response.embeddings = [mock_embedding]
mock_genai_client.aio.models.embed_content = AsyncMock(return_value=mock_response)
# Execute
embedding, error = await generate_query_embedding(
mock_genai_client,
"models/text-embedding-004",
"What is financial education?"
)
# Assert
assert error is None
assert embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
mock_genai_client.aio.models.embed_content.assert_called_once()
call_kwargs = mock_genai_client.aio.models.embed_content.call_args.kwargs
assert call_kwargs["model"] == "models/text-embedding-004"
assert call_kwargs["contents"] == "What is financial education?"
assert call_kwargs["config"].task_type == "RETRIEVAL_QUERY"
async def test_empty_query_string(self, mock_genai_client):
"""Test handling of empty query string."""
embedding, error = await generate_query_embedding(
mock_genai_client,
"models/text-embedding-004",
""
)
assert embedding == []
assert error == "Error: Query cannot be empty"
mock_genai_client.aio.models.embed_content.assert_not_called()
async def test_whitespace_only_query(self, mock_genai_client):
"""Test handling of whitespace-only query."""
embedding, error = await generate_query_embedding(
mock_genai_client,
"models/text-embedding-004",
" \t\n "
)
assert embedding == []
assert error == "Error: Query cannot be empty"
mock_genai_client.aio.models.embed_content.assert_not_called()
async def test_rate_limit_error_429(self, mock_genai_client):
"""Test handling of 429 rate limit error."""
mock_genai_client.aio.models.embed_content = AsyncMock(
side_effect=Exception("429 Too Many Requests")
)
embedding, error = await generate_query_embedding(
mock_genai_client,
"models/text-embedding-004",
"test query"
)
assert embedding == []
assert error == "Error: API rate limit exceeded. Please try again later."
async def test_rate_limit_error_resource_exhausted(self, mock_genai_client):
"""Test handling of RESOURCE_EXHAUSTED error."""
mock_genai_client.aio.models.embed_content = AsyncMock(
side_effect=Exception("RESOURCE_EXHAUSTED: Quota exceeded")
)
embedding, error = await generate_query_embedding(
mock_genai_client,
"models/text-embedding-004",
"test query"
)
assert embedding == []
assert error == "Error: API rate limit exceeded. Please try again later."
async def test_generic_api_error(self, mock_genai_client):
"""Test handling of generic API error."""
mock_genai_client.aio.models.embed_content = AsyncMock(
side_effect=ValueError("Invalid model name")
)
embedding, error = await generate_query_embedding(
mock_genai_client,
"invalid-model",
"test query"
)
assert embedding == []
assert "Error generating embedding: Invalid model name" in error
async def test_network_error(self, mock_genai_client):
"""Test handling of network error."""
mock_genai_client.aio.models.embed_content = AsyncMock(
side_effect=ConnectionError("Network unreachable")
)
embedding, error = await generate_query_embedding(
mock_genai_client,
"models/text-embedding-004",
"test query"
)
assert embedding == []
assert "Error generating embedding: Network unreachable" in error
async def test_long_query_truncation_in_logging(self, mock_genai_client):
"""Test that long queries are truncated in error logging."""
long_query = "a" * 200
mock_genai_client.aio.models.embed_content = AsyncMock(
side_effect=Exception("API error")
)
embedding, error = await generate_query_embedding(
mock_genai_client,
"models/text-embedding-004",
long_query
)
assert embedding == []
assert error is not None
class TestFilterSearchResults:
"""Tests for filter_search_results function."""
def test_empty_results(self):
"""Test filtering empty results list."""
filtered = filter_search_results([])
assert filtered == []
def test_single_result_above_thresholds(self):
"""Test single result above both thresholds."""
results: list[SearchResult] = [
{"id": "doc1", "distance": 0.85, "content": "test content"}
]
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
assert len(filtered) == 1
assert filtered[0]["id"] == "doc1"
def test_single_result_below_min_similarity(self):
"""Test single result below minimum similarity threshold."""
results: list[SearchResult] = [
{"id": "doc1", "distance": 0.5, "content": "test content"}
]
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
assert filtered == []
def test_multiple_results_all_above_thresholds(self):
"""Test multiple results all above thresholds."""
results: list[SearchResult] = [
{"id": "doc1", "distance": 0.95, "content": "content 1"},
{"id": "doc2", "distance": 0.90, "content": "content 2"},
{"id": "doc3", "distance": 0.85, "content": "content 3"},
]
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.8)
# max_sim = 0.95, cutoff = 0.95 * 0.8 = 0.76
# Results with distance > 0.76 and > 0.6: all three
assert len(filtered) == 3
def test_top_percent_filtering(self):
"""Test filtering by top_percent threshold."""
results: list[SearchResult] = [
{"id": "doc1", "distance": 1.0, "content": "content 1"},
{"id": "doc2", "distance": 0.95, "content": "content 2"},
{"id": "doc3", "distance": 0.85, "content": "content 3"},
{"id": "doc4", "distance": 0.70, "content": "content 4"},
]
# max_sim = 1.0, cutoff = 1.0 * 0.9 = 0.9
# Results with distance > 0.9: doc1 (1.0), doc2 (0.95)
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
assert len(filtered) == 2
assert filtered[0]["id"] == "doc1"
assert filtered[1]["id"] == "doc2"
def test_min_similarity_filtering(self):
"""Test filtering by minimum similarity threshold."""
results: list[SearchResult] = [
{"id": "doc1", "distance": 0.95, "content": "content 1"},
{"id": "doc2", "distance": 0.75, "content": "content 2"},
{"id": "doc3", "distance": 0.55, "content": "content 3"},
]
# max_sim = 0.95, cutoff = 0.95 * 0.9 = 0.855
# doc1 > 0.855 and > 0.7: included
# doc2 < 0.855: excluded by top_percent
# doc3 < 0.7: excluded by min_similarity
filtered = filter_search_results(results, min_similarity=0.7, top_percent=0.9)
assert len(filtered) == 1
assert filtered[0]["id"] == "doc1"
def test_default_parameters(self):
"""Test filtering with default parameters."""
results: list[SearchResult] = [
{"id": "doc1", "distance": 0.95, "content": "content 1"},
{"id": "doc2", "distance": 0.85, "content": "content 2"},
{"id": "doc3", "distance": 0.50, "content": "content 3"},
]
# Default: min_similarity=0.6, top_percent=0.9
# max_sim = 0.95, cutoff = 0.95 * 0.9 = 0.855
# doc1 > 0.855 and > 0.6: included
# doc2 < 0.855: excluded
# doc3 < 0.6: excluded
filtered = filter_search_results(results)
assert len(filtered) == 1
assert filtered[0]["id"] == "doc1"
def test_all_results_filtered_out(self):
"""Test when all results are filtered out."""
results: list[SearchResult] = [
{"id": "doc1", "distance": 0.55, "content": "content 1"},
{"id": "doc2", "distance": 0.45, "content": "content 2"},
{"id": "doc3", "distance": 0.35, "content": "content 3"},
]
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
assert filtered == []
def test_exact_threshold_boundaries(self):
"""Test behavior at exact threshold boundaries."""
results: list[SearchResult] = [
{"id": "doc1", "distance": 0.9, "content": "content 1"},
{"id": "doc2", "distance": 0.6, "content": "content 2"},
]
# max_sim = 0.9, cutoff = 0.9 * 0.9 = 0.81
# doc1: 0.9 > 0.81 and 0.9 > 0.6: included
# doc2: 0.6 < 0.81: excluded
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
assert len(filtered) == 1
assert filtered[0]["id"] == "doc1"
def test_identical_distances(self):
"""Test filtering with identical distance values."""
results: list[SearchResult] = [
{"id": "doc1", "distance": 0.8, "content": "content 1"},
{"id": "doc2", "distance": 0.8, "content": "content 2"},
{"id": "doc3", "distance": 0.8, "content": "content 3"},
]
# max_sim = 0.8, cutoff = 0.8 * 0.9 = 0.72
# All have distance 0.8 > 0.72 and > 0.6: all included
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
assert len(filtered) == 3
class TestFormatSearchResults:
"""Tests for format_search_results function."""
def test_empty_results(self):
"""Test formatting empty results list."""
formatted = format_search_results([])
assert formatted == "No relevant documents found for your query."
def test_single_result(self):
"""Test formatting single result."""
results: list[SearchResult] = [
{"id": "doc1.txt", "distance": 0.95, "content": "This is the content."}
]
formatted = format_search_results(results)
expected = "<document 1 name=doc1.txt>\nThis is the content.\n</document 1>"
assert formatted == expected
def test_multiple_results(self):
"""Test formatting multiple results."""
results: list[SearchResult] = [
{"id": "doc1.txt", "distance": 0.95, "content": "First document content."},
{"id": "doc2.txt", "distance": 0.85, "content": "Second document content."},
{"id": "doc3.txt", "distance": 0.75, "content": "Third document content."},
]
formatted = format_search_results(results)
expected = (
"<document 1 name=doc1.txt>\nFirst document content.\n</document 1>\n"
"<document 2 name=doc2.txt>\nSecond document content.\n</document 2>\n"
"<document 3 name=doc3.txt>\nThird document content.\n</document 3>"
)
assert formatted == expected
def test_multiline_content(self):
"""Test formatting results with multiline content."""
results: list[SearchResult] = [
{
"id": "doc1.txt",
"distance": 0.95,
"content": "Line 1\nLine 2\nLine 3"
}
]
formatted = format_search_results(results)
expected = "<document 1 name=doc1.txt>\nLine 1\nLine 2\nLine 3\n</document 1>"
assert formatted == expected
def test_special_characters_in_content(self):
"""Test formatting with special characters in content."""
results: list[SearchResult] = [
{
"id": "doc1.txt",
"distance": 0.95,
"content": "Content with <special> & \"characters\""
}
]
formatted = format_search_results(results)
expected = '<document 1 name=doc1.txt>\nContent with <special> & "characters"\n</document 1>'
assert formatted == expected
def test_special_characters_in_document_id(self):
"""Test formatting with special characters in document ID."""
results: list[SearchResult] = [
{
"id": "path/to/doc-name_v2.txt",
"distance": 0.95,
"content": "Some content"
}
]
formatted = format_search_results(results)
expected = "<document 1 name=path/to/doc-name_v2.txt>\nSome content\n</document 1>"
assert formatted == expected
def test_empty_content(self):
"""Test formatting result with empty content."""
results: list[SearchResult] = [
{"id": "doc1.txt", "distance": 0.95, "content": ""}
]
formatted = format_search_results(results)
expected = "<document 1 name=doc1.txt>\n\n</document 1>"
assert formatted == expected
def test_document_numbering(self):
"""Test that document numbering starts at 1 and increments correctly."""
results: list[SearchResult] = [
{"id": "a.txt", "distance": 0.9, "content": "A"},
{"id": "b.txt", "distance": 0.8, "content": "B"},
{"id": "c.txt", "distance": 0.7, "content": "C"},
{"id": "d.txt", "distance": 0.6, "content": "D"},
{"id": "e.txt", "distance": 0.5, "content": "E"},
]
formatted = format_search_results(results)
assert "<document 1 name=a.txt>" in formatted
assert "</document 1>" in formatted
assert "<document 2 name=b.txt>" in formatted
assert "</document 2>" in formatted
assert "<document 3 name=c.txt>" in formatted
assert "</document 3>" in formatted
assert "<document 4 name=d.txt>" in formatted
assert "</document 4>" in formatted
assert "<document 5 name=e.txt>" in formatted
assert "</document 5>" in formatted
def test_very_long_content(self):
"""Test formatting with very long content."""
long_content = "A" * 10000
results: list[SearchResult] = [
{"id": "doc1.txt", "distance": 0.95, "content": long_content}
]
formatted = format_search_results(results)
assert f"<document 1 name=doc1.txt>\n{long_content}\n</document 1>" == formatted
assert len(formatted) > 10000

View File

@@ -0,0 +1,436 @@
"""Tests for validation service functions."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from aiohttp import ClientResponse
from knowledge_search_mcp.services.validation import (
validate_genai_access,
validate_gcs_access,
validate_vector_search_access,
)
from knowledge_search_mcp.config import Settings
class TestValidateGenAIAccess:
"""Tests for validate_genai_access function."""
@pytest.fixture
def mock_settings(self):
"""Create mock settings."""
settings = MagicMock(spec=Settings)
settings.embedding_model = "models/text-embedding-004"
settings.project_id = "test-project"
settings.location = "us-central1"
return settings
@pytest.fixture
def mock_genai_client(self):
"""Create a mock genai client."""
client = MagicMock()
client.aio = MagicMock()
client.aio.models = MagicMock()
return client
async def test_successful_validation(self, mock_genai_client, mock_settings):
"""Test successful GenAI access validation."""
# Setup mock response
mock_response = MagicMock()
mock_embedding = MagicMock()
mock_embedding.values = [0.1] * 768 # Typical embedding dimension
mock_response.embeddings = [mock_embedding]
mock_genai_client.aio.models.embed_content = AsyncMock(return_value=mock_response)
# Execute
error = await validate_genai_access(mock_genai_client, mock_settings)
# Assert
assert error is None
mock_genai_client.aio.models.embed_content.assert_called_once()
call_kwargs = mock_genai_client.aio.models.embed_content.call_args.kwargs
assert call_kwargs["model"] == "models/text-embedding-004"
assert call_kwargs["contents"] == "test"
assert call_kwargs["config"].task_type == "RETRIEVAL_QUERY"
async def test_empty_response(self, mock_genai_client, mock_settings):
"""Test handling of empty response."""
mock_response = MagicMock()
mock_response.embeddings = []
mock_genai_client.aio.models.embed_content = AsyncMock(return_value=mock_response)
error = await validate_genai_access(mock_genai_client, mock_settings)
assert error == "Embedding validation returned empty response"
async def test_none_response(self, mock_genai_client, mock_settings):
"""Test handling of None response."""
mock_genai_client.aio.models.embed_content = AsyncMock(return_value=None)
error = await validate_genai_access(mock_genai_client, mock_settings)
assert error == "Embedding validation returned empty response"
async def test_api_permission_error(self, mock_genai_client, mock_settings):
"""Test handling of permission denied error."""
mock_genai_client.aio.models.embed_content = AsyncMock(
side_effect=PermissionError("Permission denied for GenAI API")
)
error = await validate_genai_access(mock_genai_client, mock_settings)
assert error is not None
assert "GenAI:" in error
assert "Permission denied for GenAI API" in error
async def test_api_quota_error(self, mock_genai_client, mock_settings):
"""Test handling of quota exceeded error."""
mock_genai_client.aio.models.embed_content = AsyncMock(
side_effect=Exception("Quota exceeded")
)
error = await validate_genai_access(mock_genai_client, mock_settings)
assert error is not None
assert "GenAI:" in error
assert "Quota exceeded" in error
async def test_network_error(self, mock_genai_client, mock_settings):
"""Test handling of network error."""
mock_genai_client.aio.models.embed_content = AsyncMock(
side_effect=ConnectionError("Network unreachable")
)
error = await validate_genai_access(mock_genai_client, mock_settings)
assert error is not None
assert "GenAI:" in error
assert "Network unreachable" in error
async def test_invalid_model_error(self, mock_genai_client, mock_settings):
"""Test handling of invalid model error."""
mock_genai_client.aio.models.embed_content = AsyncMock(
side_effect=ValueError("Invalid model name")
)
error = await validate_genai_access(mock_genai_client, mock_settings)
assert error is not None
assert "GenAI:" in error
assert "Invalid model name" in error
async def test_embeddings_with_zero_values(self, mock_genai_client, mock_settings):
"""Test validation with empty embedding values."""
mock_response = MagicMock()
mock_embedding = MagicMock()
mock_embedding.values = []
mock_response.embeddings = [mock_embedding]
mock_genai_client.aio.models.embed_content = AsyncMock(return_value=mock_response)
error = await validate_genai_access(mock_genai_client, mock_settings)
# Should succeed even with empty values, as long as embeddings exist
assert error is None
class TestValidateGCSAccess:
"""Tests for validate_gcs_access function."""
@pytest.fixture
def mock_settings(self):
"""Create mock settings."""
settings = MagicMock(spec=Settings)
settings.bucket = "test-bucket"
settings.project_id = "test-project"
return settings
@pytest.fixture
def mock_vector_search(self):
"""Create a mock vector search client."""
vs = MagicMock()
vs.storage = MagicMock()
return vs
@pytest.fixture
def mock_session(self):
"""Create a mock aiohttp session."""
session = MagicMock()
return session
@pytest.fixture
def mock_response(self):
"""Create a mock HTTP response."""
response = MagicMock()
response.text = AsyncMock(return_value='{"items": []}')
return response
async def test_successful_validation(self, mock_vector_search, mock_settings, mock_session, mock_response):
"""Test successful GCS bucket access validation."""
# Setup mocks
mock_response.status = 200
mock_response.ok = True
mock_session.get = MagicMock()
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
mock_vector_search.storage._get_aio_session.return_value = mock_session
with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
mock_token = MockToken.return_value
mock_token.get = AsyncMock(return_value="fake-access-token")
error = await validate_gcs_access(mock_vector_search, mock_settings)
assert error is None
mock_session.get.assert_called_once()
call_args = mock_session.get.call_args
assert "test-bucket" in call_args[0][0]
assert call_args[1]["headers"]["Authorization"] == "Bearer fake-access-token"
async def test_access_denied_403(self, mock_vector_search, mock_settings, mock_session, mock_response):
"""Test handling of 403 access denied."""
mock_response.status = 403
mock_response.ok = False
mock_session.get = MagicMock()
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
mock_vector_search.storage._get_aio_session.return_value = mock_session
with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
mock_token = MockToken.return_value
mock_token.get = AsyncMock(return_value="fake-access-token")
error = await validate_gcs_access(mock_vector_search, mock_settings)
assert error is not None
assert "Access denied to bucket 'test-bucket'" in error
assert "permissions" in error.lower()
async def test_bucket_not_found_404(self, mock_vector_search, mock_settings, mock_session, mock_response):
"""Test handling of 404 bucket not found."""
mock_response.status = 404
mock_response.ok = False
mock_session.get = MagicMock()
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
mock_vector_search.storage._get_aio_session.return_value = mock_session
with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
mock_token = MockToken.return_value
mock_token.get = AsyncMock(return_value="fake-access-token")
error = await validate_gcs_access(mock_vector_search, mock_settings)
assert error is not None
assert "Bucket 'test-bucket' not found" in error
assert "bucket name" in error.lower()
async def test_server_error_500(self, mock_vector_search, mock_settings, mock_session, mock_response):
"""Test handling of 500 server error."""
mock_response.status = 500
mock_response.ok = False
mock_response.text = AsyncMock(return_value='{"error": "Internal server error"}')
mock_session.get = MagicMock()
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
mock_vector_search.storage._get_aio_session.return_value = mock_session
with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
mock_token = MockToken.return_value
mock_token.get = AsyncMock(return_value="fake-access-token")
error = await validate_gcs_access(mock_vector_search, mock_settings)
assert error is not None
assert "Failed to access bucket 'test-bucket': 500" in error
async def test_token_acquisition_error(self, mock_vector_search, mock_settings, mock_session):
"""Test handling of token acquisition error."""
mock_vector_search.storage._get_aio_session.return_value = mock_session
with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
mock_token = MockToken.return_value
mock_token.get = AsyncMock(side_effect=Exception("Failed to get access token"))
error = await validate_gcs_access(mock_vector_search, mock_settings)
assert error is not None
assert "GCS:" in error
assert "Failed to get access token" in error
async def test_network_error(self, mock_vector_search, mock_settings, mock_session):
"""Test handling of network error."""
mock_session.get = MagicMock(side_effect=ConnectionError("Network unreachable"))
mock_vector_search.storage._get_aio_session.return_value = mock_session
with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
mock_token = MockToken.return_value
mock_token.get = AsyncMock(return_value="fake-access-token")
error = await validate_gcs_access(mock_vector_search, mock_settings)
assert error is not None
assert "GCS:" in error
assert "Network unreachable" in error
class TestValidateVectorSearchAccess:
"""Tests for validate_vector_search_access function."""
@pytest.fixture
def mock_settings(self):
"""Create mock settings."""
settings = MagicMock(spec=Settings)
settings.endpoint_name = "projects/test/locations/us-central1/indexEndpoints/test-endpoint"
settings.location = "us-central1"
return settings
@pytest.fixture
def mock_vector_search(self):
"""Create a mock vector search client."""
vs = MagicMock()
vs._async_get_auth_headers = AsyncMock(return_value={"Authorization": "Bearer fake-token"})
return vs
@pytest.fixture
def mock_session(self):
"""Create a mock aiohttp session."""
session = MagicMock()
return session
@pytest.fixture
def mock_response(self):
"""Create a mock HTTP response."""
response = MagicMock()
response.text = AsyncMock(return_value='{"name": "test-endpoint"}')
return response
async def test_successful_validation(self, mock_vector_search, mock_settings, mock_session, mock_response):
"""Test successful vector search endpoint validation."""
mock_response.status = 200
mock_response.ok = True
mock_session.get = MagicMock()
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
mock_vector_search._get_aio_session.return_value = mock_session
error = await validate_vector_search_access(mock_vector_search, mock_settings)
assert error is None
mock_vector_search._async_get_auth_headers.assert_called_once()
mock_session.get.assert_called_once()
call_args = mock_session.get.call_args
assert "us-central1-aiplatform.googleapis.com" in call_args[0][0]
assert "test-endpoint" in call_args[0][0]
assert call_args[1]["headers"]["Authorization"] == "Bearer fake-token"
async def test_access_denied_403(self, mock_vector_search, mock_settings, mock_session, mock_response):
"""Test handling of 403 access denied."""
mock_response.status = 403
mock_response.ok = False
mock_session.get = MagicMock()
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
mock_vector_search._get_aio_session.return_value = mock_session
error = await validate_vector_search_access(mock_vector_search, mock_settings)
assert error is not None
assert "Access denied to endpoint" in error
assert "test-endpoint" in error
assert "permissions" in error.lower()
async def test_endpoint_not_found_404(self, mock_vector_search, mock_settings, mock_session, mock_response):
"""Test handling of 404 endpoint not found."""
mock_response.status = 404
mock_response.ok = False
mock_session.get = MagicMock()
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
mock_vector_search._get_aio_session.return_value = mock_session
error = await validate_vector_search_access(mock_vector_search, mock_settings)
assert error is not None
assert "not found" in error.lower()
assert "test-endpoint" in error
async def test_server_error_503(self, mock_vector_search, mock_settings, mock_session, mock_response):
"""Test handling of 503 service unavailable."""
mock_response.status = 503
mock_response.ok = False
mock_response.text = AsyncMock(return_value='{"error": "Service unavailable"}')
mock_session.get = MagicMock()
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
mock_vector_search._get_aio_session.return_value = mock_session
error = await validate_vector_search_access(mock_vector_search, mock_settings)
assert error is not None
assert "Failed to access endpoint" in error
assert "503" in error
async def test_auth_header_error(self, mock_vector_search, mock_settings):
"""Test handling of authentication header error."""
mock_vector_search._async_get_auth_headers = AsyncMock(
side_effect=Exception("Failed to get auth headers")
)
error = await validate_vector_search_access(mock_vector_search, mock_settings)
assert error is not None
assert "Vector Search:" in error
assert "Failed to get auth headers" in error
async def test_network_timeout(self, mock_vector_search, mock_settings, mock_session):
"""Test handling of network timeout."""
mock_session.get = MagicMock(side_effect=TimeoutError("Request timed out"))
mock_vector_search._get_aio_session.return_value = mock_session
error = await validate_vector_search_access(mock_vector_search, mock_settings)
assert error is not None
assert "Vector Search:" in error
assert "Request timed out" in error
async def test_connection_error(self, mock_vector_search, mock_settings, mock_session):
"""Test handling of connection error."""
mock_session.get = MagicMock(side_effect=ConnectionError("Connection refused"))
mock_vector_search._get_aio_session.return_value = mock_session
error = await validate_vector_search_access(mock_vector_search, mock_settings)
assert error is not None
assert "Vector Search:" in error
assert "Connection refused" in error
async def test_endpoint_url_construction(self, mock_vector_search, mock_settings, mock_session, mock_response):
"""Test that endpoint URL is constructed correctly."""
mock_response.status = 200
mock_response.ok = True
mock_session.get = MagicMock()
mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)
mock_vector_search._get_aio_session.return_value = mock_session
# Custom location
mock_settings.location = "europe-west1"
mock_settings.endpoint_name = "projects/my-project/locations/europe-west1/indexEndpoints/my-endpoint"
error = await validate_vector_search_access(mock_vector_search, mock_settings)
assert error is None
call_args = mock_session.get.call_args
url = call_args[0][0]
assert "europe-west1-aiplatform.googleapis.com" in url
assert "my-endpoint" in url