"""Tests for the main knowledge_search tool.""" import pytest from unittest.mock import AsyncMock, MagicMock, patch from knowledge_search_mcp.__main__ import knowledge_search from knowledge_search_mcp.models import AppContext, SourceNamespace, SearchResult class TestKnowledgeSearch: """Tests for knowledge_search tool function.""" @pytest.fixture def mock_app_context(self): """Create a mock AppContext.""" app = MagicMock(spec=AppContext) # Mock genai_client app.genai_client = MagicMock() # Mock vector_search app.vector_search = MagicMock() app.vector_search.async_run_query = AsyncMock() # Mock settings app.settings = MagicMock() app.settings.embedding_model = "models/text-embedding-004" app.settings.deployed_index_id = "test-deployed-index" app.settings.search_limit = 10 # No semantic cache by default app.semantic_cache = None return app @pytest.fixture def mock_context(self, mock_app_context): """Create a mock MCP Context.""" ctx = MagicMock() ctx.request_context = MagicMock() ctx.request_context.lifespan_context = mock_app_context return ctx @pytest.fixture def sample_embedding(self): """Create a sample embedding vector.""" return [0.1, 0.2, 0.3, 0.4, 0.5] @pytest.fixture def sample_search_results(self): """Create sample search results.""" results: list[SearchResult] = [ {"id": "doc1.txt", "distance": 0.95, "content": "First document content"}, {"id": "doc2.txt", "distance": 0.85, "content": "Second document content"}, {"id": "doc3.txt", "distance": 0.75, "content": "Third document content"}, ] return results @patch('knowledge_search_mcp.__main__.generate_query_embedding') @patch('knowledge_search_mcp.__main__.filter_search_results') @patch('knowledge_search_mcp.__main__.format_search_results') async def test_successful_search( self, mock_format, mock_filter, mock_generate, mock_context, sample_embedding, sample_search_results ): """Test successful search workflow.""" # Setup mocks mock_generate.return_value = (sample_embedding, None) mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results mock_filter.return_value = sample_search_results mock_format.return_value = "\nFirst document content\n" # Execute result = await knowledge_search("What is financial education?", mock_context) # Assert assert result == "\nFirst document content\n" mock_generate.assert_called_once() mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_called_once_with( deployed_index_id="test-deployed-index", query=sample_embedding, limit=10, source=None, ) mock_filter.assert_called_once_with(sample_search_results) mock_format.assert_called_once_with(sample_search_results) @patch('knowledge_search_mcp.__main__.generate_query_embedding') async def test_embedding_generation_error(self, mock_generate, mock_context): """Test handling of embedding generation error.""" # Setup mock to return error mock_generate.return_value = ([], "Error: API rate limit exceeded. Please try again later.") # Execute result = await knowledge_search("test query", mock_context) # Assert assert result == "Error: API rate limit exceeded. Please try again later." mock_generate.assert_called_once() # Vector search should not be called mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_not_called() @patch('knowledge_search_mcp.__main__.generate_query_embedding') async def test_empty_query_error(self, mock_generate, mock_context): """Test handling of empty query.""" # Setup mock to return error for empty query mock_generate.return_value = ([], "Error: Query cannot be empty") # Execute result = await knowledge_search("", mock_context) # Assert assert result == "Error: Query cannot be empty" mock_generate.assert_called_once() @patch('knowledge_search_mcp.__main__.generate_query_embedding') async def test_vector_search_error(self, mock_generate, mock_context, sample_embedding): """Test handling of vector search error.""" # Setup mocks mock_generate.return_value = (sample_embedding, None) mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = Exception( "Vector search service unavailable" ) # Execute result = await knowledge_search("test query", mock_context) # Assert assert "Error performing vector search:" in result assert "Vector search service unavailable" in result @patch('knowledge_search_mcp.__main__.generate_query_embedding') @patch('knowledge_search_mcp.__main__.filter_search_results') @patch('knowledge_search_mcp.__main__.format_search_results') async def test_empty_search_results( self, mock_format, mock_filter, mock_generate, mock_context, sample_embedding ): """Test handling of empty search results.""" # Setup mocks mock_generate.return_value = (sample_embedding, None) mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = [] mock_filter.return_value = [] mock_format.return_value = "No relevant documents found for your query." # Execute result = await knowledge_search("obscure query", mock_context) # Assert assert result == "No relevant documents found for your query." mock_format.assert_called_once_with([]) @patch('knowledge_search_mcp.__main__.generate_query_embedding') @patch('knowledge_search_mcp.__main__.filter_search_results') @patch('knowledge_search_mcp.__main__.format_search_results') async def test_filtered_results_empty( self, mock_format, mock_filter, mock_generate, mock_context, sample_embedding, sample_search_results ): """Test when filtering removes all results.""" # Setup mocks - results exist but get filtered out mock_generate.return_value = (sample_embedding, None) mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results mock_filter.return_value = [] # All filtered out mock_format.return_value = "No relevant documents found for your query." # Execute result = await knowledge_search("test query", mock_context) # Assert assert result == "No relevant documents found for your query." mock_filter.assert_called_once_with(sample_search_results) @patch('knowledge_search_mcp.__main__.generate_query_embedding') @patch('knowledge_search_mcp.__main__.filter_search_results') @patch('knowledge_search_mcp.__main__.format_search_results') async def test_source_filter_parameter( self, mock_format, mock_filter, mock_generate, mock_context, sample_embedding, sample_search_results ): """Test that source filter is passed correctly to vector search.""" # Setup mocks mock_generate.return_value = (sample_embedding, None) mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results mock_filter.return_value = sample_search_results mock_format.return_value = "formatted results" # Execute with source filter source_filter = SourceNamespace.EDUCACION_FINANCIERA result = await knowledge_search("test query", mock_context, source=source_filter) # Assert assert result == "formatted results" mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_called_once_with( deployed_index_id="test-deployed-index", query=sample_embedding, limit=10, source=source_filter, ) @patch('knowledge_search_mcp.__main__.generate_query_embedding') @patch('knowledge_search_mcp.__main__.filter_search_results') @patch('knowledge_search_mcp.__main__.format_search_results') async def test_all_source_filters( self, mock_format, mock_filter, mock_generate, mock_context, sample_embedding, sample_search_results ): """Test all available source filter values.""" mock_generate.return_value = (sample_embedding, None) mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results mock_filter.return_value = sample_search_results mock_format.return_value = "results" # Test each source filter for source in SourceNamespace: result = await knowledge_search("test query", mock_context, source=source) assert result == "results" @patch('knowledge_search_mcp.__main__.generate_query_embedding') async def test_vector_search_timeout(self, mock_generate, mock_context, sample_embedding): """Test handling of vector search timeout.""" mock_generate.return_value = (sample_embedding, None) mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = TimeoutError( "Request timed out" ) result = await knowledge_search("test query", mock_context) assert "Error performing vector search:" in result assert "Request timed out" in result @patch('knowledge_search_mcp.__main__.generate_query_embedding') async def test_vector_search_connection_error(self, mock_generate, mock_context, sample_embedding): """Test handling of vector search connection error.""" mock_generate.return_value = (sample_embedding, None) mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = ConnectionError( "Connection refused" ) result = await knowledge_search("test query", mock_context) assert "Error performing vector search:" in result assert "Connection refused" in result @patch('knowledge_search_mcp.__main__.generate_query_embedding') @patch('knowledge_search_mcp.__main__.filter_search_results') async def test_format_results_unexpected_error( self, mock_filter, mock_generate, mock_context, sample_embedding, sample_search_results ): """Test handling of unexpected error in format_search_results.""" mock_generate.return_value = (sample_embedding, None) mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results mock_filter.return_value = sample_search_results # Mock format_search_results to raise an error with patch('knowledge_search_mcp.__main__.format_search_results', side_effect=ValueError("Format error")): result = await knowledge_search("test query", mock_context) assert "Unexpected error during search:" in result assert "Format error" in result @patch('knowledge_search_mcp.__main__.generate_query_embedding') async def test_filter_results_unexpected_error(self, mock_generate, mock_context, sample_embedding): """Test handling of unexpected error in filter_search_results.""" mock_generate.return_value = (sample_embedding, None) mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = [ {"id": "doc1", "distance": 0.9, "content": "test"} ] # Mock filter_search_results to raise an error with patch('knowledge_search_mcp.__main__.filter_search_results', side_effect=TypeError("Filter error")): result = await knowledge_search("test query", mock_context) assert "Unexpected error during search:" in result assert "Filter error" in result @patch('knowledge_search_mcp.__main__.generate_query_embedding') @patch('knowledge_search_mcp.__main__.filter_search_results') @patch('knowledge_search_mcp.__main__.format_search_results') async def test_long_query_truncation_in_logs( self, mock_format, mock_filter, mock_generate, mock_context, sample_embedding, sample_search_results ): """Test that long queries are handled correctly.""" # Setup mocks mock_generate.return_value = (sample_embedding, None) mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results mock_filter.return_value = sample_search_results mock_format.return_value = "results" # Execute with very long query long_query = "a" * 500 result = await knowledge_search(long_query, mock_context) # Assert - should succeed assert result == "results" # Verify generate_query_embedding was called with full query assert mock_generate.call_args[0][2] == long_query @patch('knowledge_search_mcp.__main__.generate_query_embedding') @patch('knowledge_search_mcp.__main__.filter_search_results') @patch('knowledge_search_mcp.__main__.format_search_results') async def test_multiple_results_returned( self, mock_format, mock_filter, mock_generate, mock_context, sample_embedding ): """Test handling of multiple search results.""" # Create larger result set large_results: list[SearchResult] = [ {"id": f"doc{i}.txt", "distance": 0.9 - (i * 0.05), "content": f"Content {i}"} for i in range(10) ] mock_generate.return_value = (sample_embedding, None) mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = large_results mock_filter.return_value = large_results[:5] # Filter to top 5 mock_format.return_value = "formatted 5 results" result = await knowledge_search("test query", mock_context) assert result == "formatted 5 results" mock_filter.assert_called_once_with(large_results) mock_format.assert_called_once_with(large_results[:5]) @patch('knowledge_search_mcp.__main__.generate_query_embedding') @patch('knowledge_search_mcp.__main__.filter_search_results') @patch('knowledge_search_mcp.__main__.format_search_results') async def test_settings_values_used_correctly( self, mock_format, mock_filter, mock_generate, mock_context, sample_embedding, sample_search_results ): """Test that settings values are used correctly.""" # Customize settings mock_context.request_context.lifespan_context.settings.embedding_model = "custom-model" mock_context.request_context.lifespan_context.settings.deployed_index_id = "custom-index" mock_context.request_context.lifespan_context.settings.search_limit = 20 mock_generate.return_value = (sample_embedding, None) mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results mock_filter.return_value = sample_search_results mock_format.return_value = "results" result = await knowledge_search("test query", mock_context) # Verify embedding model assert mock_generate.call_args[0][1] == "custom-model" # Verify vector search parameters call_kwargs = mock_context.request_context.lifespan_context.vector_search.async_run_query.call_args.kwargs assert call_kwargs["deployed_index_id"] == "custom-index" assert call_kwargs["limit"] == 20 @patch('knowledge_search_mcp.__main__.generate_query_embedding') async def test_graceful_degradation_on_partial_failure( self, mock_generate, mock_context, sample_embedding ): """Test that errors are caught and returned as strings, not raised.""" mock_generate.return_value = (sample_embedding, None) mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = RuntimeError( "Critical failure" ) # Should not raise, should return error message result = await knowledge_search("test query", mock_context) assert isinstance(result, str) assert "Error performing vector search:" in result assert "Critical failure" in result