"""Tests for search service functions.""" import pytest from unittest.mock import AsyncMock, MagicMock, patch from knowledge_search_mcp.services.search import ( generate_query_embedding, filter_search_results, format_search_results, ) from knowledge_search_mcp.models import SearchResult class TestGenerateQueryEmbedding: """Tests for generate_query_embedding function.""" @pytest.fixture def mock_genai_client(self): """Create a mock genai client.""" client = MagicMock() client.aio = MagicMock() client.aio.models = MagicMock() return client async def test_successful_embedding_generation(self, mock_genai_client): """Test successful embedding generation.""" # Setup mock response mock_response = MagicMock() mock_embedding = MagicMock() mock_embedding.values = [0.1, 0.2, 0.3, 0.4, 0.5] mock_response.embeddings = [mock_embedding] mock_genai_client.aio.models.embed_content = AsyncMock(return_value=mock_response) # Execute embedding, error = await generate_query_embedding( mock_genai_client, "models/text-embedding-004", "What is financial education?" ) # Assert assert error is None assert embedding == [0.1, 0.2, 0.3, 0.4, 0.5] mock_genai_client.aio.models.embed_content.assert_called_once() call_kwargs = mock_genai_client.aio.models.embed_content.call_args.kwargs assert call_kwargs["model"] == "models/text-embedding-004" assert call_kwargs["contents"] == "What is financial education?" assert call_kwargs["config"].task_type == "RETRIEVAL_QUERY" async def test_empty_query_string(self, mock_genai_client): """Test handling of empty query string.""" embedding, error = await generate_query_embedding( mock_genai_client, "models/text-embedding-004", "" ) assert embedding == [] assert error == "Error: Query cannot be empty" mock_genai_client.aio.models.embed_content.assert_not_called() async def test_whitespace_only_query(self, mock_genai_client): """Test handling of whitespace-only query.""" embedding, error = await generate_query_embedding( mock_genai_client, "models/text-embedding-004", " \t\n " ) assert embedding == [] assert error == "Error: Query cannot be empty" mock_genai_client.aio.models.embed_content.assert_not_called() async def test_rate_limit_error_429(self, mock_genai_client): """Test handling of 429 rate limit error.""" mock_genai_client.aio.models.embed_content = AsyncMock( side_effect=Exception("429 Too Many Requests") ) embedding, error = await generate_query_embedding( mock_genai_client, "models/text-embedding-004", "test query" ) assert embedding == [] assert error == "Error: API rate limit exceeded. Please try again later." async def test_rate_limit_error_resource_exhausted(self, mock_genai_client): """Test handling of RESOURCE_EXHAUSTED error.""" mock_genai_client.aio.models.embed_content = AsyncMock( side_effect=Exception("RESOURCE_EXHAUSTED: Quota exceeded") ) embedding, error = await generate_query_embedding( mock_genai_client, "models/text-embedding-004", "test query" ) assert embedding == [] assert error == "Error: API rate limit exceeded. Please try again later." async def test_generic_api_error(self, mock_genai_client): """Test handling of generic API error.""" mock_genai_client.aio.models.embed_content = AsyncMock( side_effect=ValueError("Invalid model name") ) embedding, error = await generate_query_embedding( mock_genai_client, "invalid-model", "test query" ) assert embedding == [] assert "Error generating embedding: Invalid model name" in error async def test_network_error(self, mock_genai_client): """Test handling of network error.""" mock_genai_client.aio.models.embed_content = AsyncMock( side_effect=ConnectionError("Network unreachable") ) embedding, error = await generate_query_embedding( mock_genai_client, "models/text-embedding-004", "test query" ) assert embedding == [] assert "Error generating embedding: Network unreachable" in error async def test_long_query_truncation_in_logging(self, mock_genai_client): """Test that long queries are truncated in error logging.""" long_query = "a" * 200 mock_genai_client.aio.models.embed_content = AsyncMock( side_effect=Exception("API error") ) embedding, error = await generate_query_embedding( mock_genai_client, "models/text-embedding-004", long_query ) assert embedding == [] assert error is not None class TestFilterSearchResults: """Tests for filter_search_results function.""" def test_empty_results(self): """Test filtering empty results list.""" filtered = filter_search_results([]) assert filtered == [] def test_single_result_above_thresholds(self): """Test single result above both thresholds.""" results: list[SearchResult] = [ {"id": "doc1", "distance": 0.85, "content": "test content"} ] filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9) assert len(filtered) == 1 assert filtered[0]["id"] == "doc1" def test_single_result_below_min_similarity(self): """Test single result below minimum similarity threshold.""" results: list[SearchResult] = [ {"id": "doc1", "distance": 0.5, "content": "test content"} ] filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9) assert filtered == [] def test_multiple_results_all_above_thresholds(self): """Test multiple results all above thresholds.""" results: list[SearchResult] = [ {"id": "doc1", "distance": 0.95, "content": "content 1"}, {"id": "doc2", "distance": 0.90, "content": "content 2"}, {"id": "doc3", "distance": 0.85, "content": "content 3"}, ] filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.8) # max_sim = 0.95, cutoff = 0.95 * 0.8 = 0.76 # Results with distance > 0.76 and > 0.6: all three assert len(filtered) == 3 def test_top_percent_filtering(self): """Test filtering by top_percent threshold.""" results: list[SearchResult] = [ {"id": "doc1", "distance": 1.0, "content": "content 1"}, {"id": "doc2", "distance": 0.95, "content": "content 2"}, {"id": "doc3", "distance": 0.85, "content": "content 3"}, {"id": "doc4", "distance": 0.70, "content": "content 4"}, ] # max_sim = 1.0, cutoff = 1.0 * 0.9 = 0.9 # Results with distance > 0.9: doc1 (1.0), doc2 (0.95) filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9) assert len(filtered) == 2 assert filtered[0]["id"] == "doc1" assert filtered[1]["id"] == "doc2" def test_min_similarity_filtering(self): """Test filtering by minimum similarity threshold.""" results: list[SearchResult] = [ {"id": "doc1", "distance": 0.95, "content": "content 1"}, {"id": "doc2", "distance": 0.75, "content": "content 2"}, {"id": "doc3", "distance": 0.55, "content": "content 3"}, ] # max_sim = 0.95, cutoff = 0.95 * 0.9 = 0.855 # doc1 > 0.855 and > 0.7: included # doc2 < 0.855: excluded by top_percent # doc3 < 0.7: excluded by min_similarity filtered = filter_search_results(results, min_similarity=0.7, top_percent=0.9) assert len(filtered) == 1 assert filtered[0]["id"] == "doc1" def test_default_parameters(self): """Test filtering with default parameters.""" results: list[SearchResult] = [ {"id": "doc1", "distance": 0.95, "content": "content 1"}, {"id": "doc2", "distance": 0.85, "content": "content 2"}, {"id": "doc3", "distance": 0.50, "content": "content 3"}, ] # Default: min_similarity=0.6, top_percent=0.9 # max_sim = 0.95, cutoff = 0.95 * 0.9 = 0.855 # doc1 > 0.855 and > 0.6: included # doc2 < 0.855: excluded # doc3 < 0.6: excluded filtered = filter_search_results(results) assert len(filtered) == 1 assert filtered[0]["id"] == "doc1" def test_all_results_filtered_out(self): """Test when all results are filtered out.""" results: list[SearchResult] = [ {"id": "doc1", "distance": 0.55, "content": "content 1"}, {"id": "doc2", "distance": 0.45, "content": "content 2"}, {"id": "doc3", "distance": 0.35, "content": "content 3"}, ] filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9) assert filtered == [] def test_exact_threshold_boundaries(self): """Test behavior at exact threshold boundaries.""" results: list[SearchResult] = [ {"id": "doc1", "distance": 0.9, "content": "content 1"}, {"id": "doc2", "distance": 0.6, "content": "content 2"}, ] # max_sim = 0.9, cutoff = 0.9 * 0.9 = 0.81 # doc1: 0.9 > 0.81 and 0.9 > 0.6: included # doc2: 0.6 < 0.81: excluded filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9) assert len(filtered) == 1 assert filtered[0]["id"] == "doc1" def test_identical_distances(self): """Test filtering with identical distance values.""" results: list[SearchResult] = [ {"id": "doc1", "distance": 0.8, "content": "content 1"}, {"id": "doc2", "distance": 0.8, "content": "content 2"}, {"id": "doc3", "distance": 0.8, "content": "content 3"}, ] # max_sim = 0.8, cutoff = 0.8 * 0.9 = 0.72 # All have distance 0.8 > 0.72 and > 0.6: all included filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9) assert len(filtered) == 3 class TestFormatSearchResults: """Tests for format_search_results function.""" def test_empty_results(self): """Test formatting empty results list.""" formatted = format_search_results([]) assert formatted == "No relevant documents found for your query." def test_single_result(self): """Test formatting single result.""" results: list[SearchResult] = [ {"id": "doc1.txt", "distance": 0.95, "content": "This is the content."} ] formatted = format_search_results(results) expected = "\nThis is the content.\n" assert formatted == expected def test_multiple_results(self): """Test formatting multiple results.""" results: list[SearchResult] = [ {"id": "doc1.txt", "distance": 0.95, "content": "First document content."}, {"id": "doc2.txt", "distance": 0.85, "content": "Second document content."}, {"id": "doc3.txt", "distance": 0.75, "content": "Third document content."}, ] formatted = format_search_results(results) expected = ( "\nFirst document content.\n\n" "\nSecond document content.\n\n" "\nThird document content.\n" ) assert formatted == expected def test_multiline_content(self): """Test formatting results with multiline content.""" results: list[SearchResult] = [ { "id": "doc1.txt", "distance": 0.95, "content": "Line 1\nLine 2\nLine 3" } ] formatted = format_search_results(results) expected = "\nLine 1\nLine 2\nLine 3\n" assert formatted == expected def test_special_characters_in_content(self): """Test formatting with special characters in content.""" results: list[SearchResult] = [ { "id": "doc1.txt", "distance": 0.95, "content": "Content with & \"characters\"" } ] formatted = format_search_results(results) expected = '\nContent with & "characters"\n' assert formatted == expected def test_special_characters_in_document_id(self): """Test formatting with special characters in document ID.""" results: list[SearchResult] = [ { "id": "path/to/doc-name_v2.txt", "distance": 0.95, "content": "Some content" } ] formatted = format_search_results(results) expected = "\nSome content\n" assert formatted == expected def test_empty_content(self): """Test formatting result with empty content.""" results: list[SearchResult] = [ {"id": "doc1.txt", "distance": 0.95, "content": ""} ] formatted = format_search_results(results) expected = "\n\n" assert formatted == expected def test_document_numbering(self): """Test that document numbering starts at 1 and increments correctly.""" results: list[SearchResult] = [ {"id": "a.txt", "distance": 0.9, "content": "A"}, {"id": "b.txt", "distance": 0.8, "content": "B"}, {"id": "c.txt", "distance": 0.7, "content": "C"}, {"id": "d.txt", "distance": 0.6, "content": "D"}, {"id": "e.txt", "distance": 0.5, "content": "E"}, ] formatted = format_search_results(results) assert "" in formatted assert "" in formatted assert "" in formatted assert "" in formatted assert "" in formatted assert "" in formatted assert "" in formatted assert "" in formatted assert "" in formatted assert "" in formatted def test_very_long_content(self): """Test formatting with very long content.""" long_content = "A" * 10000 results: list[SearchResult] = [ {"id": "doc1.txt", "distance": 0.95, "content": long_content} ] formatted = format_search_results(results) assert f"\n{long_content}\n" == formatted assert len(formatted) > 10000