Add semantic caching
This commit is contained in:
@@ -28,6 +28,9 @@ class TestKnowledgeSearch:
|
||||
app.settings.deployed_index_id = "test-deployed-index"
|
||||
app.settings.search_limit = 10
|
||||
|
||||
# No semantic cache by default
|
||||
app.semantic_cache = None
|
||||
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
272
tests/test_semantic_cache.py
Normal file
272
tests/test_semantic_cache.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""Tests for the semantic cache service and its integration."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from knowledge_search_mcp.__main__ import knowledge_search
|
||||
from knowledge_search_mcp.models import AppContext, SearchResult, SourceNamespace
|
||||
from knowledge_search_mcp.services.semantic_cache import KnowledgeSemanticCache
|
||||
|
||||
|
||||
class TestKnowledgeSemanticCache:
    """Unit tests for the KnowledgeSemanticCache wrapper.

    The wrapper delegates to redisvl's ``SemanticCache``; both the cache class
    and the vectorizer are patched so no Redis connection is needed. Note that
    stacked ``@patch`` decorators are applied bottom-up, so the innermost
    (``SemanticCache``) patch binds to the first mock parameter.
    """

    @patch("knowledge_search_mcp.services.semantic_cache.CustomVectorizer")
    @patch("knowledge_search_mcp.services.semantic_cache.SemanticCache")
    def test_init_creates_cache(self, mock_sc_cls, mock_vec_cls):
        """Test that __init__ creates the SemanticCache with correct params."""
        mock_vectorizer = MagicMock()
        mock_vec_cls.return_value = mock_vectorizer

        KnowledgeSemanticCache(
            redis_url="redis://localhost:6379",
            name="test_cache",
            vector_dims=3072,
            distance_threshold=0.12,
            ttl=3600,
        )

        mock_vec_cls.assert_called_once()
        # overwrite=False preserves an existing index across restarts.
        mock_sc_cls.assert_called_once_with(
            name="test_cache",
            distance_threshold=0.12,
            ttl=3600,
            redis_url="redis://localhost:6379",
            vectorizer=mock_vectorizer,
            overwrite=False,
        )

    @patch("knowledge_search_mcp.services.semantic_cache.CustomVectorizer")
    @patch("knowledge_search_mcp.services.semantic_cache.SemanticCache")
    async def test_check_returns_response_on_hit(self, mock_sc_cls, _mock_vec_cls):
        """Test cache check returns response when a similar vector is found."""
        mock_inner = MagicMock()
        mock_inner.acheck = AsyncMock(return_value=[
            {"response": "cached answer", "prompt": "original q", "vector_distance": 0.05},
        ])
        mock_sc_cls.return_value = mock_inner

        cache = KnowledgeSemanticCache(redis_url="redis://localhost:6379")
        result = await cache.check([0.1] * 3072)

        # Only the "response" field of the best (single) match is surfaced.
        assert result == "cached answer"
        mock_inner.acheck.assert_awaited_once_with(
            vector=[0.1] * 3072,
            num_results=1,
        )

    @patch("knowledge_search_mcp.services.semantic_cache.CustomVectorizer")
    @patch("knowledge_search_mcp.services.semantic_cache.SemanticCache")
    async def test_check_returns_none_on_miss(self, mock_sc_cls, _mock_vec_cls):
        """Test cache check returns None when no similar vector is found."""
        mock_inner = MagicMock()
        mock_inner.acheck = AsyncMock(return_value=[])
        mock_sc_cls.return_value = mock_inner

        cache = KnowledgeSemanticCache(redis_url="redis://localhost:6379")
        result = await cache.check([0.1] * 3072)

        assert result is None

    @patch("knowledge_search_mcp.services.semantic_cache.CustomVectorizer")
    @patch("knowledge_search_mcp.services.semantic_cache.SemanticCache")
    async def test_check_returns_none_on_error(self, mock_sc_cls, _mock_vec_cls):
        """Test cache check degrades gracefully on Redis errors."""
        mock_inner = MagicMock()
        mock_inner.acheck = AsyncMock(side_effect=ConnectionError("Redis down"))
        mock_sc_cls.return_value = mock_inner

        cache = KnowledgeSemanticCache(redis_url="redis://localhost:6379")
        # A broken cache backend must look like a miss, not raise.
        result = await cache.check([0.1] * 3072)

        assert result is None

    @patch("knowledge_search_mcp.services.semantic_cache.CustomVectorizer")
    @patch("knowledge_search_mcp.services.semantic_cache.SemanticCache")
    async def test_store_calls_astore(self, mock_sc_cls, _mock_vec_cls):
        """Test store delegates to SemanticCache.astore."""
        mock_inner = MagicMock()
        mock_inner.astore = AsyncMock()
        mock_sc_cls.return_value = mock_inner

        cache = KnowledgeSemanticCache(redis_url="redis://localhost:6379")
        await cache.store("query", "response", [0.1] * 3072, {"key": "val"})

        mock_inner.astore.assert_awaited_once_with(
            prompt="query",
            response="response",
            vector=[0.1] * 3072,
            metadata={"key": "val"},
        )

    @patch("knowledge_search_mcp.services.semantic_cache.CustomVectorizer")
    @patch("knowledge_search_mcp.services.semantic_cache.SemanticCache")
    async def test_store_does_not_raise_on_error(self, mock_sc_cls, _mock_vec_cls):
        """Test store degrades gracefully on Redis errors."""
        mock_inner = MagicMock()
        mock_inner.astore = AsyncMock(side_effect=ConnectionError("Redis down"))
        mock_sc_cls.return_value = mock_inner

        cache = KnowledgeSemanticCache(redis_url="redis://localhost:6379")
        # Must not propagate: failing to cache should never fail the request.
        await cache.store("query", "response", [0.1] * 3072)
|
||||
|
||||
|
||||
class TestKnowledgeSearchCacheIntegration:
    """Tests for cache integration in the knowledge_search tool.

    These tests mock the MCP request context end-to-end so ``knowledge_search``
    can be exercised without Vertex AI, Redis, or a running MCP server.
    """

    @pytest.fixture
    def mock_cache(self):
        """Create a mock KnowledgeSemanticCache."""
        cache = MagicMock(spec=KnowledgeSemanticCache)
        # Default to a cache miss; individual tests override as needed.
        cache.check = AsyncMock(return_value=None)
        cache.store = AsyncMock()
        return cache

    @pytest.fixture
    def mock_app_context(self, mock_cache):
        """Create a mock AppContext with semantic cache."""
        app = MagicMock(spec=AppContext)
        app.genai_client = MagicMock()
        app.vector_search = MagicMock()
        app.vector_search.async_run_query = AsyncMock()
        app.settings = MagicMock()
        app.settings.embedding_model = "gemini-embedding-001"
        app.settings.deployed_index_id = "test-deployed-index"
        app.settings.search_limit = 10
        app.semantic_cache = mock_cache
        return app

    @pytest.fixture
    def mock_context(self, mock_app_context):
        """Create a mock MCP Context."""
        ctx = MagicMock()
        ctx.request_context.lifespan_context = mock_app_context
        return ctx

    @pytest.fixture
    def sample_embedding(self):
        # Matches the 3072-dim embeddings used throughout these tests.
        return [0.1] * 3072

    @pytest.fixture
    def sample_results(self) -> list[SearchResult]:
        return [
            {"id": "doc1", "distance": 0.95, "content": "Content 1"},
            {"id": "doc2", "distance": 0.90, "content": "Content 2"},
        ]

    @patch("knowledge_search_mcp.__main__.generate_query_embedding")
    async def test_cache_hit_skips_vector_search(
        self, mock_generate, mock_context, sample_embedding, mock_cache
    ):
        """On cache hit, vector search is never called."""
        mock_generate.return_value = (sample_embedding, None)
        mock_cache.check.return_value = "cached result"

        result = await knowledge_search("test query", mock_context)

        assert result == "cached result"
        mock_cache.check.assert_awaited_once_with(sample_embedding)
        mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_not_called()
        # A hit must not be re-stored.
        mock_cache.store.assert_not_awaited()

    @patch("knowledge_search_mcp.__main__.generate_query_embedding")
    @patch("knowledge_search_mcp.__main__.filter_search_results")
    @patch("knowledge_search_mcp.__main__.format_search_results")
    async def test_cache_miss_stores_result(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
        sample_results,
        mock_cache,
    ):
        """On cache miss, results are fetched and stored in cache."""
        mock_generate.return_value = (sample_embedding, None)
        mock_cache.check.return_value = None
        mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_results
        mock_filter.return_value = sample_results
        mock_format.return_value = "formatted results"

        result = await knowledge_search("test query", mock_context)

        assert result == "formatted results"
        mock_cache.check.assert_awaited_once_with(sample_embedding)
        # The formatted (final) response is what gets cached, not raw results.
        mock_cache.store.assert_awaited_once_with(
            "test query", "formatted results", sample_embedding,
        )

    @patch("knowledge_search_mcp.__main__.generate_query_embedding")
    @patch("knowledge_search_mcp.__main__.filter_search_results")
    @patch("knowledge_search_mcp.__main__.format_search_results")
    async def test_cache_skipped_when_source_filter_set(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
        sample_results,
        mock_cache,
    ):
        """Cache is bypassed when a source filter is specified."""
        mock_generate.return_value = (sample_embedding, None)
        mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_results
        mock_filter.return_value = sample_results
        mock_format.return_value = "formatted results"

        result = await knowledge_search(
            "test query", mock_context, source=SourceNamespace.EDUCACION_FINANCIERA,
        )

        assert result == "formatted results"
        # Filtered queries are scoped; caching them would poison unscoped lookups.
        mock_cache.check.assert_not_awaited()
        mock_cache.store.assert_not_awaited()

    @patch("knowledge_search_mcp.__main__.generate_query_embedding")
    @patch("knowledge_search_mcp.__main__.filter_search_results")
    @patch("knowledge_search_mcp.__main__.format_search_results")
    async def test_cache_not_stored_when_no_results(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
        mock_cache,
    ):
        """Empty results are not stored in the cache."""
        mock_generate.return_value = (sample_embedding, None)
        mock_cache.check.return_value = None
        mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = []
        mock_filter.return_value = []
        mock_format.return_value = "No relevant documents found for your query."

        result = await knowledge_search("test query", mock_context)

        assert result == "No relevant documents found for your query."
        mock_cache.store.assert_not_awaited()

    @patch("knowledge_search_mcp.__main__.generate_query_embedding")
    @patch("knowledge_search_mcp.__main__.filter_search_results")
    @patch("knowledge_search_mcp.__main__.format_search_results")
    async def test_works_without_cache(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
        sample_results,
    ):
        """Tool works normally when semantic_cache is None."""
        mock_generate.return_value = (sample_embedding, None)
        mock_context.request_context.lifespan_context.semantic_cache = None
        mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_results
        mock_filter.return_value = sample_results
        mock_format.return_value = "formatted results"

        result = await knowledge_search("test query", mock_context)

        assert result == "formatted results"
|
||||
Reference in New Issue
Block a user