Files
knowledge-search-mcp/tests/test_main_tool.py
Anibal Angulo 132ea1c04f
Some checks failed
CI / lint (pull_request) Failing after 12s
CI / typecheck (pull_request) Successful in 13s
CI / test (pull_request) Failing after 27s
Add semantic caching
2026-03-05 22:10:46 +00:00

412 lines
17 KiB
Python

"""Tests for the main knowledge_search tool."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from knowledge_search_mcp.__main__ import knowledge_search
from knowledge_search_mcp.models import AppContext, SourceNamespace, SearchResult
class TestKnowledgeSearch:
    """Tests for the ``knowledge_search`` tool function.

    Embedding generation, result filtering, result formatting and the
    vector-search client are all replaced with mocks, so each test pins the
    behaviour of one step of the search pipeline.
    """

    # NOTE(review): the coroutine tests below carry no ``@pytest.mark.asyncio``
    # marker; presumably pytest-asyncio (or an equivalent plugin) is configured
    # in auto mode for this project — confirm in the pytest configuration.

    @staticmethod
    def _run_query(ctx):
        """Return the mocked ``vector_search.async_run_query`` held by *ctx*."""
        return ctx.request_context.lifespan_context.vector_search.async_run_query

    @pytest.fixture
    def mock_app_context(self):
        """Create a mock AppContext with default settings and no semantic cache."""
        app = MagicMock(spec=AppContext)
        # Mock genai_client
        app.genai_client = MagicMock()
        # Mock vector_search; AsyncMock so the query result can be awaited
        app.vector_search = MagicMock()
        app.vector_search.async_run_query = AsyncMock()
        # Mock settings
        app.settings = MagicMock()
        app.settings.embedding_model = "models/text-embedding-004"
        app.settings.deployed_index_id = "test-deployed-index"
        app.settings.search_limit = 10
        # No semantic cache by default
        app.semantic_cache = None
        return app

    @pytest.fixture
    def mock_context(self, mock_app_context):
        """Create a mock MCP Context whose lifespan context is the mock AppContext."""
        ctx = MagicMock()
        ctx.request_context = MagicMock()
        ctx.request_context.lifespan_context = mock_app_context
        return ctx

    @pytest.fixture
    def sample_embedding(self):
        """Create a sample embedding vector."""
        return [0.1, 0.2, 0.3, 0.4, 0.5]

    @pytest.fixture
    def sample_search_results(self):
        """Create sample search results."""
        results: list[SearchResult] = [
            {"id": "doc1.txt", "distance": 0.95, "content": "First document content"},
            {"id": "doc2.txt", "distance": 0.85, "content": "Second document content"},
            {"id": "doc3.txt", "distance": 0.75, "content": "Third document content"},
        ]
        return results

    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    @patch('knowledge_search_mcp.__main__.filter_search_results')
    @patch('knowledge_search_mcp.__main__.format_search_results')
    async def test_successful_search(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
        sample_search_results,
    ):
        """Test successful search workflow."""
        expected = "<document 1 name=doc1.txt>\nFirst document content\n</document 1>"
        run_query = self._run_query(mock_context)

        # Setup mocks
        mock_generate.return_value = (sample_embedding, None)
        run_query.return_value = sample_search_results
        mock_filter.return_value = sample_search_results
        mock_format.return_value = expected

        # Execute
        result = await knowledge_search("What is financial education?", mock_context)

        # Assert
        assert result == expected
        mock_generate.assert_called_once()
        run_query.assert_called_once_with(
            deployed_index_id="test-deployed-index",
            query=sample_embedding,
            limit=10,
            source=None,
        )
        mock_filter.assert_called_once_with(sample_search_results)
        mock_format.assert_called_once_with(sample_search_results)

    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    async def test_embedding_generation_error(self, mock_generate, mock_context):
        """Test handling of embedding generation error."""
        # Setup mock to return error
        mock_generate.return_value = ([], "Error: API rate limit exceeded. Please try again later.")

        # Execute
        result = await knowledge_search("test query", mock_context)

        # Assert
        assert result == "Error: API rate limit exceeded. Please try again later."
        mock_generate.assert_called_once()
        # Vector search should not be called
        self._run_query(mock_context).assert_not_called()

    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    async def test_empty_query_error(self, mock_generate, mock_context):
        """Test handling of empty query."""
        # Setup mock to return error for empty query
        mock_generate.return_value = ([], "Error: Query cannot be empty")

        # Execute
        result = await knowledge_search("", mock_context)

        # Assert
        assert result == "Error: Query cannot be empty"
        mock_generate.assert_called_once()
        # As with any embedding error, the vector search must be skipped
        self._run_query(mock_context).assert_not_called()

    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    async def test_vector_search_error(self, mock_generate, mock_context, sample_embedding):
        """Test handling of vector search error."""
        # Setup mocks
        mock_generate.return_value = (sample_embedding, None)
        self._run_query(mock_context).side_effect = Exception(
            "Vector search service unavailable"
        )

        # Execute
        result = await knowledge_search("test query", mock_context)

        # Assert
        assert "Error performing vector search:" in result
        assert "Vector search service unavailable" in result

    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    @patch('knowledge_search_mcp.__main__.filter_search_results')
    @patch('knowledge_search_mcp.__main__.format_search_results')
    async def test_empty_search_results(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
    ):
        """Test handling of empty search results."""
        # Setup mocks
        mock_generate.return_value = (sample_embedding, None)
        self._run_query(mock_context).return_value = []
        mock_filter.return_value = []
        mock_format.return_value = "No relevant documents found for your query."

        # Execute
        result = await knowledge_search("obscure query", mock_context)

        # Assert
        assert result == "No relevant documents found for your query."
        mock_format.assert_called_once_with([])

    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    @patch('knowledge_search_mcp.__main__.filter_search_results')
    @patch('knowledge_search_mcp.__main__.format_search_results')
    async def test_filtered_results_empty(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
        sample_search_results,
    ):
        """Test when filtering removes all results."""
        # Setup mocks - results exist but get filtered out
        mock_generate.return_value = (sample_embedding, None)
        self._run_query(mock_context).return_value = sample_search_results
        mock_filter.return_value = []  # All filtered out
        mock_format.return_value = "No relevant documents found for your query."

        # Execute
        result = await knowledge_search("test query", mock_context)

        # Assert
        assert result == "No relevant documents found for your query."
        mock_filter.assert_called_once_with(sample_search_results)

    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    @patch('knowledge_search_mcp.__main__.filter_search_results')
    @patch('knowledge_search_mcp.__main__.format_search_results')
    async def test_source_filter_parameter(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
        sample_search_results,
    ):
        """Test that source filter is passed correctly to vector search."""
        run_query = self._run_query(mock_context)

        # Setup mocks
        mock_generate.return_value = (sample_embedding, None)
        run_query.return_value = sample_search_results
        mock_filter.return_value = sample_search_results
        mock_format.return_value = "formatted results"

        # Execute with source filter
        source_filter = SourceNamespace.EDUCACION_FINANCIERA
        result = await knowledge_search("test query", mock_context, source=source_filter)

        # Assert
        assert result == "formatted results"
        run_query.assert_called_once_with(
            deployed_index_id="test-deployed-index",
            query=sample_embedding,
            limit=10,
            source=source_filter,
        )

    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    @patch('knowledge_search_mcp.__main__.filter_search_results')
    @patch('knowledge_search_mcp.__main__.format_search_results')
    async def test_all_source_filters(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
        sample_search_results,
    ):
        """Test that every SourceNamespace value is accepted and forwarded."""
        run_query = self._run_query(mock_context)
        mock_generate.return_value = (sample_embedding, None)
        run_query.return_value = sample_search_results
        mock_filter.return_value = sample_search_results
        mock_format.return_value = "results"

        # Test each source filter
        sources = list(SourceNamespace)
        for source in sources:
            result = await knowledge_search("test query", mock_context, source=source)
            assert result == "results"
            # The return value alone would not catch a dropped filter, so
            # check that this member actually reached the vector search.
            assert run_query.call_args.kwargs["source"] == source
        assert run_query.call_count == len(sources)

    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    async def test_vector_search_timeout(self, mock_generate, mock_context, sample_embedding):
        """Test handling of vector search timeout."""
        mock_generate.return_value = (sample_embedding, None)
        self._run_query(mock_context).side_effect = TimeoutError(
            "Request timed out"
        )

        result = await knowledge_search("test query", mock_context)

        assert "Error performing vector search:" in result
        assert "Request timed out" in result

    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    async def test_vector_search_connection_error(self, mock_generate, mock_context, sample_embedding):
        """Test handling of vector search connection error."""
        mock_generate.return_value = (sample_embedding, None)
        self._run_query(mock_context).side_effect = ConnectionError(
            "Connection refused"
        )

        result = await knowledge_search("test query", mock_context)

        assert "Error performing vector search:" in result
        assert "Connection refused" in result

    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    @patch('knowledge_search_mcp.__main__.filter_search_results')
    async def test_format_results_unexpected_error(
        self,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
        sample_search_results,
    ):
        """Test handling of unexpected error in format_search_results."""
        mock_generate.return_value = (sample_embedding, None)
        self._run_query(mock_context).return_value = sample_search_results
        mock_filter.return_value = sample_search_results

        # Mock format_search_results to raise an error
        with patch('knowledge_search_mcp.__main__.format_search_results', side_effect=ValueError("Format error")):
            result = await knowledge_search("test query", mock_context)

        assert "Unexpected error during search:" in result
        assert "Format error" in result

    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    async def test_filter_results_unexpected_error(self, mock_generate, mock_context, sample_embedding):
        """Test handling of unexpected error in filter_search_results."""
        mock_generate.return_value = (sample_embedding, None)
        self._run_query(mock_context).return_value = [
            {"id": "doc1", "distance": 0.9, "content": "test"}
        ]

        # Mock filter_search_results to raise an error
        with patch('knowledge_search_mcp.__main__.filter_search_results', side_effect=TypeError("Filter error")):
            result = await knowledge_search("test query", mock_context)

        assert "Unexpected error during search:" in result
        assert "Filter error" in result

    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    @patch('knowledge_search_mcp.__main__.filter_search_results')
    @patch('knowledge_search_mcp.__main__.format_search_results')
    async def test_long_query_truncation_in_logs(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
        sample_search_results,
    ):
        """Test that a long query reaches embedding generation untruncated.

        Whatever shortening may be applied elsewhere (the test name suggests
        log output), the full query text must still be what gets embedded.
        """
        # Setup mocks
        mock_generate.return_value = (sample_embedding, None)
        self._run_query(mock_context).return_value = sample_search_results
        mock_filter.return_value = sample_search_results
        mock_format.return_value = "results"

        # Execute with very long query
        long_query = "a" * 500
        result = await knowledge_search(long_query, mock_context)

        # Assert - should succeed
        assert result == "results"
        # The query is the third positional argument of generate_query_embedding
        assert mock_generate.call_args[0][2] == long_query

    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    @patch('knowledge_search_mcp.__main__.filter_search_results')
    @patch('knowledge_search_mcp.__main__.format_search_results')
    async def test_multiple_results_returned(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
    ):
        """Test handling of multiple search results."""
        # Create larger result set
        large_results: list[SearchResult] = [
            {"id": f"doc{i}.txt", "distance": 0.9 - (i * 0.05), "content": f"Content {i}"}
            for i in range(10)
        ]

        mock_generate.return_value = (sample_embedding, None)
        self._run_query(mock_context).return_value = large_results
        mock_filter.return_value = large_results[:5]  # Filter to top 5
        mock_format.return_value = "formatted 5 results"

        result = await knowledge_search("test query", mock_context)

        assert result == "formatted 5 results"
        mock_filter.assert_called_once_with(large_results)
        mock_format.assert_called_once_with(large_results[:5])

    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    @patch('knowledge_search_mcp.__main__.filter_search_results')
    @patch('knowledge_search_mcp.__main__.format_search_results')
    async def test_settings_values_used_correctly(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
        sample_search_results,
    ):
        """Test that settings values are used correctly."""
        app = mock_context.request_context.lifespan_context
        run_query = self._run_query(mock_context)

        # Customize settings
        app.settings.embedding_model = "custom-model"
        app.settings.deployed_index_id = "custom-index"
        app.settings.search_limit = 20

        mock_generate.return_value = (sample_embedding, None)
        run_query.return_value = sample_search_results
        mock_filter.return_value = sample_search_results
        mock_format.return_value = "results"

        result = await knowledge_search("test query", mock_context)

        # The search must still succeed with the customised settings
        assert result == "results"
        # Verify embedding model (second positional argument)
        assert mock_generate.call_args[0][1] == "custom-model"
        # Verify vector search parameters
        call_kwargs = run_query.call_args.kwargs
        assert call_kwargs["deployed_index_id"] == "custom-index"
        assert call_kwargs["limit"] == 20

    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    async def test_graceful_degradation_on_partial_failure(
        self, mock_generate, mock_context, sample_embedding
    ):
        """Test that errors are caught and returned as strings, not raised."""
        mock_generate.return_value = (sample_embedding, None)
        self._run_query(mock_context).side_effect = RuntimeError(
            "Critical failure"
        )

        # Should not raise, should return error message
        result = await knowledge_search("test query", mock_context)

        assert isinstance(result, str)
        assert "Error performing vector search:" in result
        assert "Critical failure" in result