Files
knowledge-search-mcp/tests/test_search_services.py
2026-03-04 05:13:50 +00:00

382 lines
15 KiB
Python

"""Tests for search service functions."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from knowledge_search_mcp.services.search import (
generate_query_embedding,
filter_search_results,
format_search_results,
)
from knowledge_search_mcp.models import SearchResult
class TestGenerateQueryEmbedding:
    """Tests for generate_query_embedding function.

    Each test awaits the function with a mocked genai client and inspects
    the ``(embedding, error)`` pair it returns.

    NOTE(review): the test methods are coroutines without an explicit
    asyncio marker, so they presumably rely on an async pytest plugin
    configured elsewhere (e.g. auto mode) -- confirm in the pytest config.
    """

    # Values shared by several tests below.
    _MODEL = "models/text-embedding-004"
    _EMPTY_QUERY_ERROR = "Error: Query cannot be empty"
    _RATE_LIMIT_ERROR = "Error: API rate limit exceeded. Please try again later."

    @pytest.fixture
    def mock_genai_client(self):
        """Create a mock genai client exposing ``aio.models``."""
        stub = MagicMock()
        stub.aio = MagicMock()
        stub.aio.models = MagicMock()
        return stub

    @staticmethod
    def _make_embed_fail(client, exc):
        """Replace ``embed_content`` with an AsyncMock that raises *exc*."""
        client.aio.models.embed_content = AsyncMock(side_effect=exc)

    async def test_successful_embedding_generation(self, mock_genai_client):
        """A mocked response yields its embedding values and no error."""
        # Arrange: a response carrying exactly one embedding.
        response = MagicMock(
            embeddings=[MagicMock(values=[0.1, 0.2, 0.3, 0.4, 0.5])]
        )
        embed_content = AsyncMock(return_value=response)
        mock_genai_client.aio.models.embed_content = embed_content

        # Act
        embedding, error = await generate_query_embedding(
            mock_genai_client,
            self._MODEL,
            "What is financial education?",
        )

        # Assert: result pair first, then the keyword arguments of the call.
        assert error is None
        assert embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
        embed_content.assert_called_once()
        sent = embed_content.call_args.kwargs
        assert sent["model"] == "models/text-embedding-004"
        assert sent["contents"] == "What is financial education?"
        assert sent["config"].task_type == "RETRIEVAL_QUERY"

    async def test_empty_query_string(self, mock_genai_client):
        """An empty query returns an error without calling the client."""
        embedding, error = await generate_query_embedding(
            mock_genai_client,
            self._MODEL,
            "",
        )

        assert embedding == []
        assert error == self._EMPTY_QUERY_ERROR
        mock_genai_client.aio.models.embed_content.assert_not_called()

    async def test_whitespace_only_query(self, mock_genai_client):
        """A whitespace-only query is rejected like an empty one."""
        embedding, error = await generate_query_embedding(
            mock_genai_client,
            self._MODEL,
            " \t\n ",
        )

        assert embedding == []
        assert error == self._EMPTY_QUERY_ERROR
        mock_genai_client.aio.models.embed_content.assert_not_called()

    async def test_rate_limit_error_429(self, mock_genai_client):
        """An exception mentioning 429 maps to the rate-limit message."""
        self._make_embed_fail(
            mock_genai_client, Exception("429 Too Many Requests")
        )

        embedding, error = await generate_query_embedding(
            mock_genai_client,
            self._MODEL,
            "test query",
        )

        assert embedding == []
        assert error == self._RATE_LIMIT_ERROR

    async def test_rate_limit_error_resource_exhausted(self, mock_genai_client):
        """A RESOURCE_EXHAUSTED exception maps to the rate-limit message."""
        self._make_embed_fail(
            mock_genai_client, Exception("RESOURCE_EXHAUSTED: Quota exceeded")
        )

        embedding, error = await generate_query_embedding(
            mock_genai_client,
            self._MODEL,
            "test query",
        )

        assert embedding == []
        assert error == self._RATE_LIMIT_ERROR

    async def test_generic_api_error(self, mock_genai_client):
        """Any other exception is reported with its own message."""
        self._make_embed_fail(
            mock_genai_client, ValueError("Invalid model name")
        )

        embedding, error = await generate_query_embedding(
            mock_genai_client,
            "invalid-model",
            "test query",
        )

        assert embedding == []
        assert "Error generating embedding: Invalid model name" in error

    async def test_network_error(self, mock_genai_client):
        """A ConnectionError is reported through the generic error message."""
        self._make_embed_fail(
            mock_genai_client, ConnectionError("Network unreachable")
        )

        embedding, error = await generate_query_embedding(
            mock_genai_client,
            self._MODEL,
            "test query",
        )

        assert embedding == []
        assert "Error generating embedding: Network unreachable" in error

    async def test_long_query_truncation_in_logging(self, mock_genai_client):
        """A 200-character query on the failure path still returns an error.

        NOTE(review): despite the name, no log output is inspected here;
        only the returned ``(embedding, error)`` pair is checked.
        """
        long_query = "a" * 200
        self._make_embed_fail(mock_genai_client, Exception("API error"))

        embedding, error = await generate_query_embedding(
            mock_genai_client,
            self._MODEL,
            long_query,
        )

        assert embedding == []
        assert error is not None
class TestFilterSearchResults:
    """Tests for filter_search_results function.

    The expectations below treat the ``distance`` field as a similarity
    score: a result is expected to survive only when it clears both
    ``min_similarity`` and ``top_percent`` times the best score in the list.
    """

    def test_empty_results(self):
        """An empty input list comes back empty."""
        assert filter_search_results([]) == []

    def test_single_result_above_thresholds(self):
        """A lone result that clears both thresholds is kept."""
        results: list[SearchResult] = [
            {"id": "doc1", "distance": 0.85, "content": "test content"},
        ]

        kept = filter_search_results(
            results, min_similarity=0.6, top_percent=0.9
        )

        assert [item["id"] for item in kept] == ["doc1"]

    def test_single_result_below_min_similarity(self):
        """A lone result under min_similarity is dropped."""
        results: list[SearchResult] = [
            {"id": "doc1", "distance": 0.5, "content": "test content"},
        ]

        kept = filter_search_results(
            results, min_similarity=0.6, top_percent=0.9
        )

        assert kept == []

    def test_multiple_results_all_above_thresholds(self):
        """Every result is kept when all of them clear both thresholds."""
        results: list[SearchResult] = [
            {"id": "doc1", "distance": 0.95, "content": "content 1"},
            {"id": "doc2", "distance": 0.90, "content": "content 2"},
            {"id": "doc3", "distance": 0.85, "content": "content 3"},
        ]

        kept = filter_search_results(
            results, min_similarity=0.6, top_percent=0.8
        )

        # Best score 0.95 -> cutoff 0.95 * 0.8 = 0.76; all three exceed
        # both 0.76 and 0.6.
        assert len(kept) == 3

    def test_top_percent_filtering(self):
        """Results under top_percent of the best score are dropped."""
        results: list[SearchResult] = [
            {"id": "doc1", "distance": 1.0, "content": "content 1"},
            {"id": "doc2", "distance": 0.95, "content": "content 2"},
            {"id": "doc3", "distance": 0.85, "content": "content 3"},
            {"id": "doc4", "distance": 0.70, "content": "content 4"},
        ]

        kept = filter_search_results(
            results, min_similarity=0.6, top_percent=0.9
        )

        # Best score 1.0 -> cutoff 0.9; only doc1 (1.0) and doc2 (0.95)
        # exceed it.
        assert [item["id"] for item in kept] == ["doc1", "doc2"]

    def test_min_similarity_filtering(self):
        """Both thresholds apply when min_similarity is raised to 0.7."""
        results: list[SearchResult] = [
            {"id": "doc1", "distance": 0.95, "content": "content 1"},
            {"id": "doc2", "distance": 0.75, "content": "content 2"},
            {"id": "doc3", "distance": 0.55, "content": "content 3"},
        ]

        kept = filter_search_results(
            results, min_similarity=0.7, top_percent=0.9
        )

        # Best score 0.95 -> cutoff 0.855.
        #   doc1 clears 0.855 and 0.7          -> kept
        #   doc2 is under the 0.855 cutoff     -> dropped by top_percent
        #   doc3 is under 0.7                  -> dropped by min_similarity
        assert [item["id"] for item in kept] == ["doc1"]

    def test_default_parameters(self):
        """Calling without keyword arguments uses the function's defaults."""
        results: list[SearchResult] = [
            {"id": "doc1", "distance": 0.95, "content": "content 1"},
            {"id": "doc2", "distance": 0.85, "content": "content 2"},
            {"id": "doc3", "distance": 0.50, "content": "content 3"},
        ]

        kept = filter_search_results(results)

        # Expected defaults: min_similarity=0.6, top_percent=0.9.
        # Best score 0.95 -> cutoff 0.855, so only doc1 survives
        # (doc2 is under the cutoff, doc3 is under 0.6).
        assert [item["id"] for item in kept] == ["doc1"]

    def test_all_results_filtered_out(self):
        """Nothing is returned when every score is under min_similarity."""
        results: list[SearchResult] = [
            {"id": "doc1", "distance": 0.55, "content": "content 1"},
            {"id": "doc2", "distance": 0.45, "content": "content 2"},
            {"id": "doc3", "distance": 0.35, "content": "content 3"},
        ]

        kept = filter_search_results(
            results, min_similarity=0.6, top_percent=0.9
        )

        assert kept == []

    def test_exact_threshold_boundaries(self):
        """A score equal to min_similarity that is also under the cutoff.

        NOTE(review): doc2 sits exactly on min_similarity but is already
        under the top_percent cutoff, so this case does not show whether
        the min_similarity comparison is inclusive.
        """
        results: list[SearchResult] = [
            {"id": "doc1", "distance": 0.9, "content": "content 1"},
            {"id": "doc2", "distance": 0.6, "content": "content 2"},
        ]

        kept = filter_search_results(
            results, min_similarity=0.6, top_percent=0.9
        )

        # Best score 0.9 -> cutoff 0.81.
        #   doc1: 0.9 exceeds 0.81 and 0.6 -> kept
        #   doc2: 0.6 is under 0.81        -> dropped
        assert [item["id"] for item in kept] == ["doc1"]

    def test_identical_distances(self):
        """Results sharing one score are all kept when that score passes."""
        results: list[SearchResult] = [
            {"id": "doc1", "distance": 0.8, "content": "content 1"},
            {"id": "doc2", "distance": 0.8, "content": "content 2"},
            {"id": "doc3", "distance": 0.8, "content": "content 3"},
        ]

        kept = filter_search_results(
            results, min_similarity=0.6, top_percent=0.9
        )

        # Best score 0.8 -> cutoff 0.72; each 0.8 exceeds 0.72 and 0.6.
        assert len(kept) == 3
class TestFormatSearchResults:
    """Tests for format_search_results function.

    Each result is expected to render as a numbered
    ``<document N name=ID>`` / ``</document N>`` pair wrapped around its
    content, with consecutive documents separated by a newline.
    """

    def test_empty_results(self):
        """An empty list produces the fixed no-documents message."""
        assert (
            format_search_results([])
            == "No relevant documents found for your query."
        )

    def test_single_result(self):
        """One result renders as a single numbered document block."""
        results: list[SearchResult] = [
            {
                "id": "doc1.txt",
                "distance": 0.95,
                "content": "This is the content.",
            },
        ]

        rendered = format_search_results(results)

        assert rendered == (
            "<document 1 name=doc1.txt>\nThis is the content.\n</document 1>"
        )

    def test_multiple_results(self):
        """Several results render in order, joined by newlines."""
        results: list[SearchResult] = [
            {
                "id": "doc1.txt",
                "distance": 0.95,
                "content": "First document content.",
            },
            {
                "id": "doc2.txt",
                "distance": 0.85,
                "content": "Second document content.",
            },
            {
                "id": "doc3.txt",
                "distance": 0.75,
                "content": "Third document content.",
            },
        ]

        rendered = format_search_results(results)

        expected = "\n".join(
            [
                "<document 1 name=doc1.txt>\nFirst document content.\n</document 1>",
                "<document 2 name=doc2.txt>\nSecond document content.\n</document 2>",
                "<document 3 name=doc3.txt>\nThird document content.\n</document 3>",
            ]
        )
        assert rendered == expected

    def test_multiline_content(self):
        """Newlines inside the content are passed through unchanged."""
        results: list[SearchResult] = [
            {
                "id": "doc1.txt",
                "distance": 0.95,
                "content": "Line 1\nLine 2\nLine 3",
            },
        ]

        rendered = format_search_results(results)

        assert rendered == (
            "<document 1 name=doc1.txt>\nLine 1\nLine 2\nLine 3\n</document 1>"
        )

    def test_special_characters_in_content(self):
        """Angle brackets, ampersands and quotes in content are not escaped."""
        results: list[SearchResult] = [
            {
                "id": "doc1.txt",
                "distance": 0.95,
                "content": 'Content with <special> & "characters"',
            },
        ]

        rendered = format_search_results(results)

        assert rendered == (
            "<document 1 name=doc1.txt>\n"
            'Content with <special> & "characters"\n'
            "</document 1>"
        )

    def test_special_characters_in_document_id(self):
        """Slashes, dashes and underscores in the id appear verbatim."""
        results: list[SearchResult] = [
            {
                "id": "path/to/doc-name_v2.txt",
                "distance": 0.95,
                "content": "Some content",
            },
        ]

        rendered = format_search_results(results)

        assert rendered == (
            "<document 1 name=path/to/doc-name_v2.txt>\n"
            "Some content\n"
            "</document 1>"
        )

    def test_empty_content(self):
        """Empty content still yields the opening and closing tags."""
        results: list[SearchResult] = [
            {"id": "doc1.txt", "distance": 0.95, "content": ""},
        ]

        rendered = format_search_results(results)

        assert rendered == "<document 1 name=doc1.txt>\n\n</document 1>"

    def test_document_numbering(self):
        """Document numbers start at 1 and increase by one per result."""
        entries = [
            ("a.txt", 0.9, "A"),
            ("b.txt", 0.8, "B"),
            ("c.txt", 0.7, "C"),
            ("d.txt", 0.6, "D"),
            ("e.txt", 0.5, "E"),
        ]
        results: list[SearchResult] = [
            {"id": doc_id, "distance": distance, "content": content}
            for doc_id, distance, content in entries
        ]

        rendered = format_search_results(results)

        # Every result must contribute a matching open/close tag pair.
        for number, (doc_id, _, _) in enumerate(entries, start=1):
            assert f"<document {number} name={doc_id}>" in rendered
            assert f"</document {number}>" in rendered

    def test_very_long_content(self):
        """A 10,000-character content string is emitted in full."""
        long_content = "A" * 10000
        results: list[SearchResult] = [
            {"id": "doc1.txt", "distance": 0.95, "content": long_content},
        ]

        rendered = format_search_results(results)

        assert (
            f"<document 1 name=doc1.txt>\n{long_content}\n</document 1>"
            == rendered
        )
        assert len(rendered) > 10000