"""Tests for search service functions."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from knowledge_search_mcp.services.search import (
generate_query_embedding,
filter_search_results,
format_search_results,
)
from knowledge_search_mcp.models import SearchResult
class TestGenerateQueryEmbedding:
    """Exercise generate_query_embedding against a mocked genai client.

    Every test awaits the function with ``(client, model_name, query)`` and
    inspects the ``(embedding, error)`` pair that comes back.
    """

    @pytest.fixture
    def mock_genai_client(self):
        """Provide a MagicMock client exposing ``client.aio.models``."""
        models = MagicMock()
        aio = MagicMock()
        aio.models = models
        client = MagicMock()
        client.aio = aio
        return client

    async def test_successful_embedding_generation(self, mock_genai_client):
        """A successful embed_content call yields the vector and no error."""
        model_name = "models/text-embedding-004"
        query = "What is financial education?"
        expected_vector = [0.1, 0.2, 0.3, 0.4, 0.5]

        # Arrange: a response carrying a single embedding with known values.
        first_embedding = MagicMock()
        first_embedding.values = list(expected_vector)
        response = MagicMock()
        response.embeddings = [first_embedding]
        embed_content = AsyncMock(return_value=response)
        mock_genai_client.aio.models.embed_content = embed_content

        # Act
        vector, err = await generate_query_embedding(
            mock_genai_client, model_name, query
        )

        # Assert: result pair first, then the keyword arguments sent to the API.
        assert err is None
        assert vector == expected_vector
        embed_content.assert_called_once()
        sent = embed_content.call_args.kwargs
        assert sent["model"] == model_name
        assert sent["contents"] == query
        assert sent["config"].task_type == "RETRIEVAL_QUERY"

    async def test_empty_query_string(self, mock_genai_client):
        """An empty query is rejected without calling embed_content."""
        outcome = await generate_query_embedding(
            mock_genai_client, "models/text-embedding-004", ""
        )
        vector, err = outcome

        assert vector == []
        assert err == "Error: Query cannot be empty"
        mock_genai_client.aio.models.embed_content.assert_not_called()

    async def test_whitespace_only_query(self, mock_genai_client):
        """A query made only of whitespace is treated like an empty one."""
        blank_query = " \t\n "
        outcome = await generate_query_embedding(
            mock_genai_client, "models/text-embedding-004", blank_query
        )
        vector, err = outcome

        assert vector == []
        assert err == "Error: Query cannot be empty"
        mock_genai_client.aio.models.embed_content.assert_not_called()

    async def test_rate_limit_error_429(self, mock_genai_client):
        """An exception mentioning 429 maps to the rate-limit message."""
        failure = Exception("429 Too Many Requests")
        mock_genai_client.aio.models.embed_content = AsyncMock(side_effect=failure)

        vector, err = await generate_query_embedding(
            mock_genai_client, "models/text-embedding-004", "test query"
        )

        assert vector == []
        assert err == "Error: API rate limit exceeded. Please try again later."

    async def test_rate_limit_error_resource_exhausted(self, mock_genai_client):
        """A RESOURCE_EXHAUSTED exception maps to the rate-limit message."""
        failure = Exception("RESOURCE_EXHAUSTED: Quota exceeded")
        mock_genai_client.aio.models.embed_content = AsyncMock(side_effect=failure)

        vector, err = await generate_query_embedding(
            mock_genai_client, "models/text-embedding-004", "test query"
        )

        assert vector == []
        assert err == "Error: API rate limit exceeded. Please try again later."

    async def test_generic_api_error(self, mock_genai_client):
        """Any other exception text is surfaced in the returned error."""
        failure = ValueError("Invalid model name")
        mock_genai_client.aio.models.embed_content = AsyncMock(side_effect=failure)

        vector, err = await generate_query_embedding(
            mock_genai_client, "invalid-model", "test query"
        )

        assert vector == []
        assert "Error generating embedding: Invalid model name" in err

    async def test_network_error(self, mock_genai_client):
        """A ConnectionError is reported through the generic error message."""
        failure = ConnectionError("Network unreachable")
        mock_genai_client.aio.models.embed_content = AsyncMock(side_effect=failure)

        vector, err = await generate_query_embedding(
            mock_genai_client, "models/text-embedding-004", "test query"
        )

        assert vector == []
        assert "Error generating embedding: Network unreachable" in err

    async def test_long_query_truncation_in_logging(self, mock_genai_client):
        """A 200-character query that fails still returns an error pair.

        Only the returned ``(embedding, error)`` pair is checked here; the
        log output itself is not inspected.
        """
        oversized_query = "a" * 200
        failure = Exception("API error")
        mock_genai_client.aio.models.embed_content = AsyncMock(side_effect=failure)

        vector, err = await generate_query_embedding(
            mock_genai_client, "models/text-embedding-004", oversized_query
        )

        assert vector == []
        assert err is not None
class TestFilterSearchResults:
    """Check filter_search_results against its two thresholds.

    The arithmetic in the comments below mirrors what these tests expect:
    a result is kept when its ``distance`` exceeds both ``min_similarity``
    and ``max(distance) * top_percent``.
    """

    def test_empty_results(self):
        """An empty input list comes back as an empty list."""
        assert filter_search_results([]) == []

    def test_single_result_above_thresholds(self):
        """A lone result above both thresholds is kept."""
        hits: list[SearchResult] = [
            {"id": "doc1", "distance": 0.85, "content": "test content"},
        ]

        kept = filter_search_results(hits, min_similarity=0.6, top_percent=0.9)

        assert [hit["id"] for hit in kept] == ["doc1"]

    def test_single_result_below_min_similarity(self):
        """A lone result under min_similarity is dropped."""
        hits: list[SearchResult] = [
            {"id": "doc1", "distance": 0.5, "content": "test content"},
        ]

        kept = filter_search_results(hits, min_similarity=0.6, top_percent=0.9)

        assert kept == []

    def test_multiple_results_all_above_thresholds(self):
        """All results survive when each clears both thresholds."""
        hits: list[SearchResult] = [
            {"id": "doc1", "distance": 0.95, "content": "content 1"},
            {"id": "doc2", "distance": 0.90, "content": "content 2"},
            {"id": "doc3", "distance": 0.85, "content": "content 3"},
        ]

        kept = filter_search_results(hits, min_similarity=0.6, top_percent=0.8)

        # Best score 0.95 -> relative cutoff 0.95 * 0.8 = 0.76; every score
        # is above 0.76 and above the 0.6 floor.
        assert len(kept) == 3

    def test_top_percent_filtering(self):
        """Results under the top_percent cutoff are dropped."""
        hits: list[SearchResult] = [
            {"id": "doc1", "distance": 1.0, "content": "content 1"},
            {"id": "doc2", "distance": 0.95, "content": "content 2"},
            {"id": "doc3", "distance": 0.85, "content": "content 3"},
            {"id": "doc4", "distance": 0.70, "content": "content 4"},
        ]

        kept = filter_search_results(hits, min_similarity=0.6, top_percent=0.9)

        # Best score 1.0 -> relative cutoff 0.9; only doc1 (1.0) and
        # doc2 (0.95) exceed it.
        assert [hit["id"] for hit in kept] == ["doc1", "doc2"]

    def test_min_similarity_filtering(self):
        """The min_similarity floor and the relative cutoff both apply."""
        hits: list[SearchResult] = [
            {"id": "doc1", "distance": 0.95, "content": "content 1"},
            {"id": "doc2", "distance": 0.75, "content": "content 2"},
            {"id": "doc3", "distance": 0.55, "content": "content 3"},
        ]

        kept = filter_search_results(hits, min_similarity=0.7, top_percent=0.9)

        # Best score 0.95 -> relative cutoff 0.855.
        #   doc1 clears 0.855 and 0.7          -> kept
        #   doc2 is under 0.855                -> dropped by top_percent
        #   doc3 is under 0.7                  -> dropped by min_similarity
        assert [hit["id"] for hit in kept] == ["doc1"]

    def test_default_parameters(self):
        """Calling without keyword arguments uses the default thresholds."""
        hits: list[SearchResult] = [
            {"id": "doc1", "distance": 0.95, "content": "content 1"},
            {"id": "doc2", "distance": 0.85, "content": "content 2"},
            {"id": "doc3", "distance": 0.50, "content": "content 3"},
        ]

        kept = filter_search_results(hits)

        # This test expects the defaults min_similarity=0.6, top_percent=0.9:
        # best score 0.95 -> relative cutoff 0.855, so doc2 (0.85) and
        # doc3 (0.50) are dropped and only doc1 remains.
        assert [hit["id"] for hit in kept] == ["doc1"]

    def test_all_results_filtered_out(self):
        """Nothing is returned when every score is under min_similarity."""
        hits: list[SearchResult] = [
            {"id": "doc1", "distance": 0.55, "content": "content 1"},
            {"id": "doc2", "distance": 0.45, "content": "content 2"},
            {"id": "doc3", "distance": 0.35, "content": "content 3"},
        ]

        kept = filter_search_results(hits, min_similarity=0.6, top_percent=0.9)

        assert kept == []

    def test_exact_threshold_boundaries(self):
        """A score sitting exactly on min_similarity is not returned here."""
        hits: list[SearchResult] = [
            {"id": "doc1", "distance": 0.9, "content": "content 1"},
            {"id": "doc2", "distance": 0.6, "content": "content 2"},
        ]

        kept = filter_search_results(hits, min_similarity=0.6, top_percent=0.9)

        # Best score 0.9 -> relative cutoff 0.81. doc1 (0.9) clears both
        # thresholds; doc2 (0.6) is under the relative cutoff.
        assert [hit["id"] for hit in kept] == ["doc1"]

    def test_identical_distances(self):
        """Results sharing one score are kept or dropped together."""
        hits: list[SearchResult] = [
            {"id": "doc1", "distance": 0.8, "content": "content 1"},
            {"id": "doc2", "distance": 0.8, "content": "content 2"},
            {"id": "doc3", "distance": 0.8, "content": "content 3"},
        ]

        kept = filter_search_results(hits, min_similarity=0.6, top_percent=0.9)

        # Best score 0.8 -> relative cutoff 0.72; 0.8 clears both thresholds.
        assert len(kept) == 3
class TestFormatSearchResults:
    """Check the text that format_search_results builds from result dicts.

    The exact-match expectations below pin each document's rendering to its
    content wrapped in newlines, with consecutive documents separated by one
    further newline.
    """

    def test_empty_results(self):
        """An empty list produces the fixed "nothing found" message."""
        formatted = format_search_results([])

        assert formatted == "No relevant documents found for your query."

    def test_single_result(self):
        """One result renders as its content wrapped in newlines."""
        results: list[SearchResult] = [
            {"id": "doc1.txt", "distance": 0.95, "content": "This is the content."},
        ]

        formatted = format_search_results(results)

        assert formatted == "\nThis is the content.\n"

    def test_multiple_results(self):
        """Several results are rendered in order, newline-separated."""
        results: list[SearchResult] = [
            {"id": "doc1.txt", "distance": 0.95, "content": "First document content."},
            {"id": "doc2.txt", "distance": 0.85, "content": "Second document content."},
            {"id": "doc3.txt", "distance": 0.75, "content": "Third document content."},
        ]

        formatted = format_search_results(results)

        expected = (
            "\nFirst document content.\n\n"
            "\nSecond document content.\n\n"
            "\nThird document content.\n"
        )
        assert formatted == expected

    def test_multiline_content(self):
        """Newlines inside the content are passed through untouched."""
        results: list[SearchResult] = [
            {
                "id": "doc1.txt",
                "distance": 0.95,
                "content": "Line 1\nLine 2\nLine 3",
            },
        ]

        formatted = format_search_results(results)

        assert formatted == "\nLine 1\nLine 2\nLine 3\n"

    def test_special_characters_in_content(self):
        """Ampersands and double quotes in the content are not escaped."""
        results: list[SearchResult] = [
            {
                "id": "doc1.txt",
                "distance": 0.95,
                "content": "Content with & \"characters\"",
            },
        ]

        formatted = format_search_results(results)

        assert formatted == '\nContent with & "characters"\n'

    def test_special_characters_in_document_id(self):
        """A path-like document id does not alter the rendered content."""
        results: list[SearchResult] = [
            {
                "id": "path/to/doc-name_v2.txt",
                "distance": 0.95,
                "content": "Some content",
            },
        ]

        formatted = format_search_results(results)

        assert formatted == "\nSome content\n"

    def test_empty_content(self):
        """Empty content still yields the surrounding newlines."""
        results: list[SearchResult] = [
            {"id": "doc1.txt", "distance": 0.95, "content": ""},
        ]

        formatted = format_search_results(results)

        assert formatted == "\n\n"

    def test_document_numbering(self):
        """Every document appears in the output, in input order.

        NOTE(review): the previous assertions were all ``"" in formatted``,
        which is true for any string and therefore verified nothing. The
        numbered markers they were presumably meant to match are not
        recoverable from this file; if the formatter emits per-document
        markers, add explicit assertions for them here.
        """
        results: list[SearchResult] = [
            {"id": "a.txt", "distance": 0.9, "content": "A"},
            {"id": "b.txt", "distance": 0.8, "content": "B"},
            {"id": "c.txt", "distance": 0.7, "content": "C"},
            {"id": "d.txt", "distance": 0.6, "content": "D"},
            {"id": "e.txt", "distance": 0.5, "content": "E"},
        ]

        formatted = format_search_results(results)

        # str.find returns -1 for a missing substring, so a dropped document
        # is caught before the ordering comparison.
        positions = [formatted.find(result["content"]) for result in results]
        assert -1 not in positions
        # Strictly increasing offsets mean documents 1..5 keep their order.
        assert all(
            earlier < later for earlier, later in zip(positions, positions[1:])
        )

    def test_very_long_content(self):
        """A 10,000-character content string is emitted without truncation."""
        long_content = "A" * 10000
        results: list[SearchResult] = [
            {"id": "doc1.txt", "distance": 0.95, "content": long_content},
        ]

        formatted = format_search_results(results)

        assert f"\n{long_content}\n" == formatted
        assert len(formatted) > 10000