Add more tests
This commit is contained in:
381
tests/test_search_services.py
Normal file
381
tests/test_search_services.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""Tests for search service functions."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from knowledge_search_mcp.services.search import (
|
||||
generate_query_embedding,
|
||||
filter_search_results,
|
||||
format_search_results,
|
||||
)
|
||||
from knowledge_search_mcp.models import SearchResult
|
||||
|
||||
|
||||
class TestGenerateQueryEmbedding:
    """Exercise generate_query_embedding against a mocked genai client.

    Every test unpacks the call result as ``(embedding, error)`` and checks
    both halves. The real client is never contacted: the ``embed_content``
    coroutine is replaced by an ``AsyncMock`` wherever a call is expected.
    """

    @pytest.fixture
    def mock_genai_client(self):
        """Return a MagicMock exposing the ``aio.models`` attribute chain."""
        genai = MagicMock()
        genai.aio = MagicMock()
        genai.aio.models = MagicMock()
        return genai

    @staticmethod
    def _fail_embed_with(client, exc):
        """Make ``client.aio.models.embed_content`` an AsyncMock raising *exc*."""
        client.aio.models.embed_content = AsyncMock(side_effect=exc)

    @staticmethod
    async def _embed(client, query, model="models/text-embedding-004"):
        """Await ``generate_query_embedding(client, model, query)`` and return its result."""
        return await generate_query_embedding(client, model, query)

    async def test_successful_embedding_generation(self, mock_genai_client):
        """A successful API call yields the embedding values and no error."""
        # Arrange: a response object carrying exactly one embedding.
        fake_embedding = MagicMock()
        fake_embedding.values = [0.1, 0.2, 0.3, 0.4, 0.5]
        api_response = MagicMock()
        api_response.embeddings = [fake_embedding]
        embed_mock = AsyncMock(return_value=api_response)
        mock_genai_client.aio.models.embed_content = embed_mock

        # Act
        embedding, error = await self._embed(
            mock_genai_client, "What is financial education?"
        )

        # Assert: returned values first, then the keyword arguments sent to the API.
        assert error is None
        assert embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
        embed_mock.assert_called_once()
        sent = embed_mock.call_args.kwargs
        assert sent["model"] == "models/text-embedding-004"
        assert sent["contents"] == "What is financial education?"
        assert sent["config"].task_type == "RETRIEVAL_QUERY"

    async def test_empty_query_string(self, mock_genai_client):
        """An empty query is rejected without calling embed_content."""
        embedding, error = await self._embed(mock_genai_client, "")

        assert (embedding, error) == ([], "Error: Query cannot be empty")
        mock_genai_client.aio.models.embed_content.assert_not_called()

    async def test_whitespace_only_query(self, mock_genai_client):
        """A query made only of whitespace is rejected like an empty one."""
        embedding, error = await self._embed(mock_genai_client, "   \t\n  ")

        assert (embedding, error) == ([], "Error: Query cannot be empty")
        mock_genai_client.aio.models.embed_content.assert_not_called()

    async def test_rate_limit_error_429(self, mock_genai_client):
        """An exception mentioning 429 is reported as a rate-limit error."""
        self._fail_embed_with(mock_genai_client, Exception("429 Too Many Requests"))

        embedding, error = await self._embed(mock_genai_client, "test query")

        assert (embedding, error) == (
            [],
            "Error: API rate limit exceeded. Please try again later.",
        )

    async def test_rate_limit_error_resource_exhausted(self, mock_genai_client):
        """An exception mentioning RESOURCE_EXHAUSTED is reported as a rate-limit error."""
        self._fail_embed_with(
            mock_genai_client, Exception("RESOURCE_EXHAUSTED: Quota exceeded")
        )

        embedding, error = await self._embed(mock_genai_client, "test query")

        assert (embedding, error) == (
            [],
            "Error: API rate limit exceeded. Please try again later.",
        )

    async def test_generic_api_error(self, mock_genai_client):
        """A ValueError from the API surfaces its message in the error string."""
        self._fail_embed_with(mock_genai_client, ValueError("Invalid model name"))

        embedding, error = await self._embed(
            mock_genai_client, "test query", model="invalid-model"
        )

        assert embedding == []
        assert "Error generating embedding: Invalid model name" in error

    async def test_network_error(self, mock_genai_client):
        """A ConnectionError from the API surfaces its message in the error string."""
        self._fail_embed_with(mock_genai_client, ConnectionError("Network unreachable"))

        embedding, error = await self._embed(mock_genai_client, "test query")

        assert embedding == []
        assert "Error generating embedding: Network unreachable" in error

    async def test_long_query_truncation_in_logging(self, mock_genai_client):
        """A 200-character query that fails still returns an empty embedding and an error.

        NOTE(review): despite the name, no log output is inspected here; only
        the returned pair is checked.
        """
        oversized_query = "a" * 200
        self._fail_embed_with(mock_genai_client, Exception("API error"))

        embedding, error = await self._embed(mock_genai_client, oversized_query)

        assert embedding == []
        assert error is not None
|
||||
|
||||
|
||||
class TestFilterSearchResults:
    """Check filter_search_results for several min_similarity / top_percent settings.

    The arithmetic in the comments states what each test expects: a result
    is kept when its distance clears both ``min_similarity`` and
    ``max(distance) * top_percent``. The filter itself is not defined in
    this file, so treat that rule as the tests' assumption.
    """

    @staticmethod
    def _numbered(*distances):
        """Build results doc1..docN with bodies "content 1".."content N"."""
        return [
            {"id": f"doc{n}", "distance": dist, "content": f"content {n}"}
            for n, dist in enumerate(distances, start=1)
        ]

    @staticmethod
    def _ids(results):
        """Return the ``id`` of every result, in order."""
        return [entry["id"] for entry in results]

    def test_empty_results(self):
        """An empty input list comes back as an empty list."""
        assert filter_search_results([]) == []

    def test_single_result_above_thresholds(self):
        """A lone result that clears both thresholds is kept."""
        lone: list[SearchResult] = [
            {"id": "doc1", "distance": 0.85, "content": "test content"}
        ]

        kept = filter_search_results(lone, min_similarity=0.6, top_percent=0.9)

        assert self._ids(kept) == ["doc1"]

    def test_single_result_below_min_similarity(self):
        """A lone result under min_similarity is dropped."""
        lone: list[SearchResult] = [
            {"id": "doc1", "distance": 0.5, "content": "test content"}
        ]

        kept = filter_search_results(lone, min_similarity=0.6, top_percent=0.9)

        assert kept == []

    def test_multiple_results_all_above_thresholds(self):
        """Three results that all clear both thresholds are all kept."""
        candidates: list[SearchResult] = self._numbered(0.95, 0.90, 0.85)

        kept = filter_search_results(candidates, min_similarity=0.6, top_percent=0.8)

        # Expected cutoff: 0.95 * 0.8 = 0.76; every distance exceeds 0.76 and 0.6.
        assert len(kept) == 3

    def test_top_percent_filtering(self):
        """Only results within top_percent of the best distance are kept."""
        candidates: list[SearchResult] = self._numbered(1.0, 0.95, 0.85, 0.70)

        # Expected cutoff: 1.0 * 0.9 = 0.9, leaving doc1 (1.0) and doc2 (0.95).
        kept = filter_search_results(candidates, min_similarity=0.6, top_percent=0.9)

        assert self._ids(kept) == ["doc1", "doc2"]

    def test_min_similarity_filtering(self):
        """Only the best result survives a 0.7 floor combined with a 0.9 top_percent."""
        candidates: list[SearchResult] = self._numbered(0.95, 0.75, 0.55)

        # Expected cutoff: 0.95 * 0.9 = 0.855.
        #   doc1 exceeds 0.855 and 0.7 -> kept
        #   doc2 is under 0.855        -> dropped by top_percent
        #   doc3 is under 0.7          -> dropped by min_similarity
        kept = filter_search_results(candidates, min_similarity=0.7, top_percent=0.9)

        assert self._ids(kept) == ["doc1"]

    def test_default_parameters(self):
        """Calling without keyword arguments keeps only the best result."""
        candidates: list[SearchResult] = self._numbered(0.95, 0.85, 0.50)

        # The original test comment gives the defaults as min_similarity=0.6,
        # top_percent=0.9, so the expected cutoff is 0.95 * 0.9 = 0.855.
        #   doc1 exceeds 0.855 and 0.6 -> kept
        #   doc2 is under 0.855        -> dropped
        #   doc3 is under 0.6          -> dropped
        kept = filter_search_results(candidates)

        assert self._ids(kept) == ["doc1"]

    def test_all_results_filtered_out(self):
        """When every distance is under min_similarity, nothing is returned."""
        candidates: list[SearchResult] = self._numbered(0.55, 0.45, 0.35)

        kept = filter_search_results(candidates, min_similarity=0.6, top_percent=0.9)

        assert kept == []

    def test_exact_threshold_boundaries(self):
        """A result sitting exactly on min_similarity is still cut by top_percent."""
        candidates: list[SearchResult] = self._numbered(0.9, 0.6)

        # Expected cutoff: 0.9 * 0.9 = 0.81.
        #   doc1: 0.9 exceeds 0.81 and 0.6 -> kept
        #   doc2: 0.6 is under 0.81        -> dropped
        kept = filter_search_results(candidates, min_similarity=0.6, top_percent=0.9)

        assert self._ids(kept) == ["doc1"]

    def test_identical_distances(self):
        """Results that share one distance are kept or dropped together."""
        candidates: list[SearchResult] = self._numbered(0.8, 0.8, 0.8)

        # Expected cutoff: 0.8 * 0.9 = 0.72; 0.8 exceeds both 0.72 and 0.6.
        kept = filter_search_results(candidates, min_similarity=0.6, top_percent=0.9)

        assert len(kept) == 3
|
||||
|
||||
|
||||
class TestFormatSearchResults:
    """Check the text produced by format_search_results.

    The expected layout in these tests is one
    ``<document N name=ID>`` / content / ``</document N>`` group per result,
    numbered from 1 and separated by newlines.
    """

    @staticmethod
    def _hit(doc_id, content, distance=0.95):
        """Build one search-result dict with the given id, content and distance."""
        return {"id": doc_id, "distance": distance, "content": content}

    def test_empty_results(self):
        """An empty result list produces the fixed "nothing found" message."""
        rendered = format_search_results([])

        assert rendered == "No relevant documents found for your query."

    def test_single_result(self):
        """One result is wrapped in a single numbered document block."""
        docs: list[SearchResult] = [self._hit("doc1.txt", "This is the content.")]

        rendered = format_search_results(docs)

        assert rendered == "<document 1 name=doc1.txt>\nThis is the content.\n</document 1>"

    def test_multiple_results(self):
        """Several results are rendered in order, joined by newlines."""
        docs: list[SearchResult] = [
            self._hit("doc1.txt", "First document content.", 0.95),
            self._hit("doc2.txt", "Second document content.", 0.85),
            self._hit("doc3.txt", "Third document content.", 0.75),
        ]

        rendered = format_search_results(docs)

        wanted = (
            "<document 1 name=doc1.txt>\nFirst document content.\n</document 1>\n"
            "<document 2 name=doc2.txt>\nSecond document content.\n</document 2>\n"
            "<document 3 name=doc3.txt>\nThird document content.\n</document 3>"
        )
        assert rendered == wanted

    def test_multiline_content(self):
        """Newlines inside the content are passed through unchanged."""
        docs: list[SearchResult] = [self._hit("doc1.txt", "Line 1\nLine 2\nLine 3")]

        rendered = format_search_results(docs)

        assert rendered == "<document 1 name=doc1.txt>\nLine 1\nLine 2\nLine 3\n</document 1>"

    def test_special_characters_in_content(self):
        """Angle brackets, ampersands and quotes in content are not escaped."""
        docs: list[SearchResult] = [
            self._hit("doc1.txt", 'Content with <special> & "characters"')
        ]

        rendered = format_search_results(docs)

        wanted = '<document 1 name=doc1.txt>\nContent with <special> & "characters"\n</document 1>'
        assert rendered == wanted

    def test_special_characters_in_document_id(self):
        """Slashes, dashes and underscores in the id appear verbatim in the tag."""
        docs: list[SearchResult] = [
            self._hit("path/to/doc-name_v2.txt", "Some content")
        ]

        rendered = format_search_results(docs)

        assert rendered == "<document 1 name=path/to/doc-name_v2.txt>\nSome content\n</document 1>"

    def test_empty_content(self):
        """Empty content still yields opening and closing tags around a blank line."""
        docs: list[SearchResult] = [self._hit("doc1.txt", "")]

        rendered = format_search_results(docs)

        assert rendered == "<document 1 name=doc1.txt>\n\n</document 1>"

    def test_document_numbering(self):
        """Document numbers start at 1 and increase by one per result."""
        rows = [
            ("a.txt", 0.9, "A"),
            ("b.txt", 0.8, "B"),
            ("c.txt", 0.7, "C"),
            ("d.txt", 0.6, "D"),
            ("e.txt", 0.5, "E"),
        ]
        docs: list[SearchResult] = [
            self._hit(name, body, score) for name, score, body in rows
        ]

        rendered = format_search_results(docs)

        # Check the opening and closing tag for each position, in order.
        for position, (name, _score, _body) in enumerate(rows, start=1):
            assert f"<document {position} name={name}>" in rendered
            assert f"</document {position}>" in rendered

    def test_very_long_content(self):
        """A 10,000-character body is emitted in full between the tags."""
        body = "A" * 10000
        docs: list[SearchResult] = [self._hit("doc1.txt", body)]

        rendered = format_search_results(docs)

        assert f"<document 1 name=doc1.txt>\n{body}\n</document 1>" == rendered
        assert len(rendered) > 10000
|
||||
Reference in New Issue
Block a user