Add more tests
This commit is contained in:
408
tests/test_main_tool.py
Normal file
408
tests/test_main_tool.py
Normal file
@@ -0,0 +1,408 @@
|
||||
"""Tests for the main knowledge_search tool."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from knowledge_search_mcp.__main__ import knowledge_search
|
||||
from knowledge_search_mcp.models import AppContext, SourceNamespace, SearchResult
|
||||
|
||||
|
||||
class TestKnowledgeSearch:
|
||||
"""Tests for knowledge_search tool function."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_context(self):
|
||||
"""Create a mock AppContext."""
|
||||
app = MagicMock(spec=AppContext)
|
||||
|
||||
# Mock genai_client
|
||||
app.genai_client = MagicMock()
|
||||
|
||||
# Mock vector_search
|
||||
app.vector_search = MagicMock()
|
||||
app.vector_search.async_run_query = AsyncMock()
|
||||
|
||||
# Mock settings
|
||||
app.settings = MagicMock()
|
||||
app.settings.embedding_model = "models/text-embedding-004"
|
||||
app.settings.deployed_index_id = "test-deployed-index"
|
||||
app.settings.search_limit = 10
|
||||
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context(self, mock_app_context):
|
||||
"""Create a mock MCP Context."""
|
||||
ctx = MagicMock()
|
||||
ctx.request_context = MagicMock()
|
||||
ctx.request_context.lifespan_context = mock_app_context
|
||||
return ctx
|
||||
|
||||
@pytest.fixture
|
||||
def sample_embedding(self):
|
||||
"""Create a sample embedding vector."""
|
||||
return [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
|
||||
@pytest.fixture
|
||||
def sample_search_results(self):
|
||||
"""Create sample search results."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1.txt", "distance": 0.95, "content": "First document content"},
|
||||
{"id": "doc2.txt", "distance": 0.85, "content": "Second document content"},
|
||||
{"id": "doc3.txt", "distance": 0.75, "content": "Third document content"},
|
||||
]
|
||||
return results
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_successful_search(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test successful search workflow."""
|
||||
# Setup mocks
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = sample_search_results
|
||||
mock_format.return_value = "<document 1 name=doc1.txt>\nFirst document content\n</document 1>"
|
||||
|
||||
# Execute
|
||||
result = await knowledge_search("What is financial education?", mock_context)
|
||||
|
||||
# Assert
|
||||
assert result == "<document 1 name=doc1.txt>\nFirst document content\n</document 1>"
|
||||
mock_generate.assert_called_once()
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_called_once_with(
|
||||
deployed_index_id="test-deployed-index",
|
||||
query=sample_embedding,
|
||||
limit=10,
|
||||
source=None,
|
||||
)
|
||||
mock_filter.assert_called_once_with(sample_search_results)
|
||||
mock_format.assert_called_once_with(sample_search_results)
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
async def test_embedding_generation_error(self, mock_generate, mock_context):
|
||||
"""Test handling of embedding generation error."""
|
||||
# Setup mock to return error
|
||||
mock_generate.return_value = ([], "Error: API rate limit exceeded. Please try again later.")
|
||||
|
||||
# Execute
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
# Assert
|
||||
assert result == "Error: API rate limit exceeded. Please try again later."
|
||||
mock_generate.assert_called_once()
|
||||
# Vector search should not be called
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_not_called()
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
async def test_empty_query_error(self, mock_generate, mock_context):
|
||||
"""Test handling of empty query."""
|
||||
# Setup mock to return error for empty query
|
||||
mock_generate.return_value = ([], "Error: Query cannot be empty")
|
||||
|
||||
# Execute
|
||||
result = await knowledge_search("", mock_context)
|
||||
|
||||
# Assert
|
||||
assert result == "Error: Query cannot be empty"
|
||||
mock_generate.assert_called_once()
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
async def test_vector_search_error(self, mock_generate, mock_context, sample_embedding):
|
||||
"""Test handling of vector search error."""
|
||||
# Setup mocks
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = Exception(
|
||||
"Vector search service unavailable"
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
# Assert
|
||||
assert "Error performing vector search:" in result
|
||||
assert "Vector search service unavailable" in result
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_empty_search_results(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding
|
||||
):
|
||||
"""Test handling of empty search results."""
|
||||
# Setup mocks
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = []
|
||||
mock_filter.return_value = []
|
||||
mock_format.return_value = "No relevant documents found for your query."
|
||||
|
||||
# Execute
|
||||
result = await knowledge_search("obscure query", mock_context)
|
||||
|
||||
# Assert
|
||||
assert result == "No relevant documents found for your query."
|
||||
mock_format.assert_called_once_with([])
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_filtered_results_empty(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test when filtering removes all results."""
|
||||
# Setup mocks - results exist but get filtered out
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = [] # All filtered out
|
||||
mock_format.return_value = "No relevant documents found for your query."
|
||||
|
||||
# Execute
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
# Assert
|
||||
assert result == "No relevant documents found for your query."
|
||||
mock_filter.assert_called_once_with(sample_search_results)
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_source_filter_parameter(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test that source filter is passed correctly to vector search."""
|
||||
# Setup mocks
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = sample_search_results
|
||||
mock_format.return_value = "formatted results"
|
||||
|
||||
# Execute with source filter
|
||||
source_filter = SourceNamespace.EDUCACION_FINANCIERA
|
||||
result = await knowledge_search("test query", mock_context, source=source_filter)
|
||||
|
||||
# Assert
|
||||
assert result == "formatted results"
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_called_once_with(
|
||||
deployed_index_id="test-deployed-index",
|
||||
query=sample_embedding,
|
||||
limit=10,
|
||||
source=source_filter,
|
||||
)
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_all_source_filters(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test all available source filter values."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = sample_search_results
|
||||
mock_format.return_value = "results"
|
||||
|
||||
# Test each source filter
|
||||
for source in SourceNamespace:
|
||||
result = await knowledge_search("test query", mock_context, source=source)
|
||||
assert result == "results"
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
async def test_vector_search_timeout(self, mock_generate, mock_context, sample_embedding):
|
||||
"""Test handling of vector search timeout."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = TimeoutError(
|
||||
"Request timed out"
|
||||
)
|
||||
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert "Error performing vector search:" in result
|
||||
assert "Request timed out" in result
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
async def test_vector_search_connection_error(self, mock_generate, mock_context, sample_embedding):
|
||||
"""Test handling of vector search connection error."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = ConnectionError(
|
||||
"Connection refused"
|
||||
)
|
||||
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert "Error performing vector search:" in result
|
||||
assert "Connection refused" in result
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
async def test_format_results_unexpected_error(
|
||||
self,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test handling of unexpected error in format_search_results."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = sample_search_results
|
||||
|
||||
# Mock format_search_results to raise an error
|
||||
with patch('knowledge_search_mcp.__main__.format_search_results', side_effect=ValueError("Format error")):
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert "Unexpected error during search:" in result
|
||||
assert "Format error" in result
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
async def test_filter_results_unexpected_error(self, mock_generate, mock_context, sample_embedding):
|
||||
"""Test handling of unexpected error in filter_search_results."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = [
|
||||
{"id": "doc1", "distance": 0.9, "content": "test"}
|
||||
]
|
||||
|
||||
# Mock filter_search_results to raise an error
|
||||
with patch('knowledge_search_mcp.__main__.filter_search_results', side_effect=TypeError("Filter error")):
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert "Unexpected error during search:" in result
|
||||
assert "Filter error" in result
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_long_query_truncation_in_logs(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test that long queries are handled correctly."""
|
||||
# Setup mocks
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = sample_search_results
|
||||
mock_format.return_value = "results"
|
||||
|
||||
# Execute with very long query
|
||||
long_query = "a" * 500
|
||||
result = await knowledge_search(long_query, mock_context)
|
||||
|
||||
# Assert - should succeed
|
||||
assert result == "results"
|
||||
# Verify generate_query_embedding was called with full query
|
||||
assert mock_generate.call_args[0][2] == long_query
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_multiple_results_returned(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding
|
||||
):
|
||||
"""Test handling of multiple search results."""
|
||||
# Create larger result set
|
||||
large_results: list[SearchResult] = [
|
||||
{"id": f"doc{i}.txt", "distance": 0.9 - (i * 0.05), "content": f"Content {i}"}
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = large_results
|
||||
mock_filter.return_value = large_results[:5] # Filter to top 5
|
||||
mock_format.return_value = "formatted 5 results"
|
||||
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert result == "formatted 5 results"
|
||||
mock_filter.assert_called_once_with(large_results)
|
||||
mock_format.assert_called_once_with(large_results[:5])
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_settings_values_used_correctly(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test that settings values are used correctly."""
|
||||
# Customize settings
|
||||
mock_context.request_context.lifespan_context.settings.embedding_model = "custom-model"
|
||||
mock_context.request_context.lifespan_context.settings.deployed_index_id = "custom-index"
|
||||
mock_context.request_context.lifespan_context.settings.search_limit = 20
|
||||
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = sample_search_results
|
||||
mock_format.return_value = "results"
|
||||
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
# Verify embedding model
|
||||
assert mock_generate.call_args[0][1] == "custom-model"
|
||||
|
||||
# Verify vector search parameters
|
||||
call_kwargs = mock_context.request_context.lifespan_context.vector_search.async_run_query.call_args.kwargs
|
||||
assert call_kwargs["deployed_index_id"] == "custom-index"
|
||||
assert call_kwargs["limit"] == 20
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
async def test_graceful_degradation_on_partial_failure(
|
||||
self, mock_generate, mock_context, sample_embedding
|
||||
):
|
||||
"""Test that errors are caught and returned as strings, not raised."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = RuntimeError(
|
||||
"Critical failure"
|
||||
)
|
||||
|
||||
# Should not raise, should return error message
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "Error performing vector search:" in result
|
||||
assert "Critical failure" in result
|
||||
Reference in New Issue
Block a user