Add semantic caching

This commit is contained in:
2026-03-04 06:02:24 +00:00
parent 694b060fa4
commit e81aac2e29
9 changed files with 625 additions and 2 deletions

View File

@@ -28,6 +28,9 @@ class TestKnowledgeSearch:
app.settings.deployed_index_id = "test-deployed-index"
app.settings.search_limit = 10
# No semantic cache by default
app.semantic_cache = None
return app
@pytest.fixture

View File

@@ -0,0 +1,272 @@
"""Tests for the semantic cache service and its integration."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from knowledge_search_mcp.__main__ import knowledge_search
from knowledge_search_mcp.models import AppContext, SearchResult, SourceNamespace
from knowledge_search_mcp.services.semantic_cache import KnowledgeSemanticCache
class TestKnowledgeSemanticCache:
    """Unit tests for the KnowledgeSemanticCache wrapper.

    Both the redisvl ``SemanticCache`` and the ``CustomVectorizer`` are patched
    at the module under test, so these tests never touch Redis.
    """

    @patch("knowledge_search_mcp.services.semantic_cache.CustomVectorizer")
    @patch("knowledge_search_mcp.services.semantic_cache.SemanticCache")
    def test_init_creates_cache(self, mock_sc_cls, mock_vec_cls):
        """Test that __init__ creates the SemanticCache with correct params."""
        mock_vectorizer = MagicMock()
        mock_vec_cls.return_value = mock_vectorizer

        KnowledgeSemanticCache(
            redis_url="redis://localhost:6379",
            name="test_cache",
            vector_dims=3072,
            distance_threshold=0.12,
            ttl=3600,
        )

        mock_vec_cls.assert_called_once()
        # The wrapper must forward every tuning knob to the inner cache and
        # must not overwrite an existing index.
        mock_sc_cls.assert_called_once_with(
            name="test_cache",
            distance_threshold=0.12,
            ttl=3600,
            redis_url="redis://localhost:6379",
            vectorizer=mock_vectorizer,
            overwrite=False,
        )

    @patch("knowledge_search_mcp.services.semantic_cache.CustomVectorizer")
    @patch("knowledge_search_mcp.services.semantic_cache.SemanticCache")
    async def test_check_returns_response_on_hit(self, mock_sc_cls, _mock_vec_cls):
        """Test cache check returns response when a similar vector is found."""
        mock_inner = MagicMock()
        mock_inner.acheck = AsyncMock(return_value=[
            {"response": "cached answer", "prompt": "original q", "vector_distance": 0.05},
        ])
        mock_sc_cls.return_value = mock_inner
        cache = KnowledgeSemanticCache(redis_url="redis://localhost:6379")

        result = await cache.check([0.1] * 3072)

        assert result == "cached answer"
        # Only the closest match is requested from the inner cache.
        mock_inner.acheck.assert_awaited_once_with(
            vector=[0.1] * 3072,
            num_results=1,
        )

    @patch("knowledge_search_mcp.services.semantic_cache.CustomVectorizer")
    @patch("knowledge_search_mcp.services.semantic_cache.SemanticCache")
    async def test_check_returns_none_on_miss(self, mock_sc_cls, _mock_vec_cls):
        """Test cache check returns None when no similar vector is found."""
        mock_inner = MagicMock()
        mock_inner.acheck = AsyncMock(return_value=[])
        mock_sc_cls.return_value = mock_inner
        cache = KnowledgeSemanticCache(redis_url="redis://localhost:6379")

        result = await cache.check([0.1] * 3072)

        assert result is None

    @patch("knowledge_search_mcp.services.semantic_cache.CustomVectorizer")
    @patch("knowledge_search_mcp.services.semantic_cache.SemanticCache")
    async def test_check_returns_none_on_error(self, mock_sc_cls, _mock_vec_cls):
        """Test cache check degrades gracefully on Redis errors."""
        mock_inner = MagicMock()
        mock_inner.acheck = AsyncMock(side_effect=ConnectionError("Redis down"))
        mock_sc_cls.return_value = mock_inner
        cache = KnowledgeSemanticCache(redis_url="redis://localhost:6379")

        # A Redis outage must look like a cache miss, never a raised error.
        result = await cache.check([0.1] * 3072)

        assert result is None

    @patch("knowledge_search_mcp.services.semantic_cache.CustomVectorizer")
    @patch("knowledge_search_mcp.services.semantic_cache.SemanticCache")
    async def test_store_calls_astore(self, mock_sc_cls, _mock_vec_cls):
        """Test store delegates to SemanticCache.astore."""
        mock_inner = MagicMock()
        mock_inner.astore = AsyncMock()
        mock_sc_cls.return_value = mock_inner
        cache = KnowledgeSemanticCache(redis_url="redis://localhost:6379")

        await cache.store("query", "response", [0.1] * 3072, {"key": "val"})

        mock_inner.astore.assert_awaited_once_with(
            prompt="query",
            response="response",
            vector=[0.1] * 3072,
            metadata={"key": "val"},
        )

    @patch("knowledge_search_mcp.services.semantic_cache.CustomVectorizer")
    @patch("knowledge_search_mcp.services.semantic_cache.SemanticCache")
    async def test_store_does_not_raise_on_error(self, mock_sc_cls, _mock_vec_cls):
        """Test store degrades gracefully on Redis errors."""
        mock_inner = MagicMock()
        mock_inner.astore = AsyncMock(side_effect=ConnectionError("Redis down"))
        mock_sc_cls.return_value = mock_inner
        cache = KnowledgeSemanticCache(redis_url="redis://localhost:6379")

        # Must not propagate the ConnectionError raised by astore.
        await cache.store("query", "response", [0.1] * 3072)

        # Fix: previously this test asserted nothing, so it would still pass
        # if store() silently skipped the write. Verify the write was attempted.
        mock_inner.astore.assert_awaited_once()
class TestKnowledgeSearchCacheIntegration:
    """Tests for cache integration in the knowledge_search tool.

    Each test patches the embedding / filter / format helpers in
    ``knowledge_search_mcp.__main__`` and drives the tool through a mocked
    MCP context carrying a mocked AppContext.
    """

    @pytest.fixture
    def mock_cache(self):
        """A KnowledgeSemanticCache double that misses by default."""
        fake = MagicMock(spec=KnowledgeSemanticCache)
        fake.check = AsyncMock(return_value=None)
        fake.store = AsyncMock()
        return fake

    @pytest.fixture
    def mock_app_context(self, mock_cache):
        """An AppContext double wired with settings, vector search and cache."""
        app = MagicMock(spec=AppContext)
        app.genai_client = MagicMock()
        app.vector_search = MagicMock()
        app.vector_search.async_run_query = AsyncMock()
        app.settings = MagicMock()
        app.settings.embedding_model = "gemini-embedding-001"
        app.settings.deployed_index_id = "test-deployed-index"
        app.settings.search_limit = 10
        app.semantic_cache = mock_cache
        return app

    @pytest.fixture
    def mock_context(self, mock_app_context):
        """An MCP Context double exposing the app context via lifespan."""
        ctx = MagicMock()
        ctx.request_context.lifespan_context = mock_app_context
        return ctx

    @pytest.fixture
    def sample_embedding(self):
        return [0.1] * 3072

    @pytest.fixture
    def sample_results(self) -> list[SearchResult]:
        return [
            {"id": "doc1", "distance": 0.95, "content": "Content 1"},
            {"id": "doc2", "distance": 0.90, "content": "Content 2"},
        ]

    @patch("knowledge_search_mcp.__main__.generate_query_embedding")
    async def test_cache_hit_skips_vector_search(
        self, mock_generate, mock_context, sample_embedding, mock_cache
    ):
        """On cache hit, vector search is never called."""
        mock_generate.return_value = (sample_embedding, None)
        mock_cache.check.return_value = "cached result"

        answer = await knowledge_search("test query", mock_context)

        assert answer == "cached result"
        mock_cache.check.assert_awaited_once_with(sample_embedding)
        app = mock_context.request_context.lifespan_context
        app.vector_search.async_run_query.assert_not_called()
        # A hit must not be re-written to the cache.
        mock_cache.store.assert_not_awaited()

    @patch("knowledge_search_mcp.__main__.generate_query_embedding")
    @patch("knowledge_search_mcp.__main__.filter_search_results")
    @patch("knowledge_search_mcp.__main__.format_search_results")
    async def test_cache_miss_stores_result(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
        sample_results,
        mock_cache,
    ):
        """On cache miss, results are fetched and stored in cache."""
        mock_generate.return_value = (sample_embedding, None)
        mock_cache.check.return_value = None
        app = mock_context.request_context.lifespan_context
        app.vector_search.async_run_query.return_value = sample_results
        mock_filter.return_value = sample_results
        mock_format.return_value = "formatted results"

        answer = await knowledge_search("test query", mock_context)

        assert answer == "formatted results"
        mock_cache.check.assert_awaited_once_with(sample_embedding)
        mock_cache.store.assert_awaited_once_with(
            "test query", "formatted results", sample_embedding,
        )

    @patch("knowledge_search_mcp.__main__.generate_query_embedding")
    @patch("knowledge_search_mcp.__main__.filter_search_results")
    @patch("knowledge_search_mcp.__main__.format_search_results")
    async def test_cache_skipped_when_source_filter_set(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
        sample_results,
        mock_cache,
    ):
        """Cache is bypassed when a source filter is specified."""
        mock_generate.return_value = (sample_embedding, None)
        app = mock_context.request_context.lifespan_context
        app.vector_search.async_run_query.return_value = sample_results
        mock_filter.return_value = sample_results
        mock_format.return_value = "formatted results"

        answer = await knowledge_search(
            "test query", mock_context, source=SourceNamespace.EDUCACION_FINANCIERA,
        )

        assert answer == "formatted results"
        # Filtered searches are neither served from nor written to the cache.
        mock_cache.check.assert_not_awaited()
        mock_cache.store.assert_not_awaited()

    @patch("knowledge_search_mcp.__main__.generate_query_embedding")
    @patch("knowledge_search_mcp.__main__.filter_search_results")
    @patch("knowledge_search_mcp.__main__.format_search_results")
    async def test_cache_not_stored_when_no_results(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
        mock_cache,
    ):
        """Empty results are not stored in the cache."""
        mock_generate.return_value = (sample_embedding, None)
        mock_cache.check.return_value = None
        app = mock_context.request_context.lifespan_context
        app.vector_search.async_run_query.return_value = []
        mock_filter.return_value = []
        mock_format.return_value = "No relevant documents found for your query."

        answer = await knowledge_search("test query", mock_context)

        assert answer == "No relevant documents found for your query."
        mock_cache.store.assert_not_awaited()

    @patch("knowledge_search_mcp.__main__.generate_query_embedding")
    @patch("knowledge_search_mcp.__main__.filter_search_results")
    @patch("knowledge_search_mcp.__main__.format_search_results")
    async def test_works_without_cache(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
        sample_results,
    ):
        """Tool works normally when semantic_cache is None."""
        mock_generate.return_value = (sample_embedding, None)
        app = mock_context.request_context.lifespan_context
        app.semantic_cache = None
        app.vector_search.async_run_query.return_value = sample_results
        mock_filter.return_value = sample_results
        mock_format.return_value = "formatted results"

        answer = await knowledge_search("test query", mock_context)

        assert answer == "formatted results"