forked from innovacion/searchbox
Add testing
This commit is contained in:
227
tests/test_engine/test_base_engine.py
Normal file
227
tests/test_engine/test_base_engine.py
Normal file
@@ -0,0 +1,227 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from vector_search_mcp.engine.base_engine import BaseEngine
|
||||
from vector_search_mcp.models import Condition, Match, MatchAny, SearchRow
|
||||
|
||||
|
||||
class MockEngine(BaseEngine[dict[str, Any], str]):
|
||||
"""Mock engine for testing BaseEngine abstract functionality"""
|
||||
|
||||
def __init__(self):
|
||||
self.transform_conditions_called = False
|
||||
self.transform_response_called = False
|
||||
self.run_similarity_query_called = False
|
||||
|
||||
def transform_conditions(self, conditions: list[Condition] | None) -> str | None:
|
||||
self.transform_conditions_called = True
|
||||
if not conditions:
|
||||
return None
|
||||
return f"transformed_{len(conditions)}_conditions"
|
||||
|
||||
def transform_response(self, response: dict[str, Any]) -> list[SearchRow]:
|
||||
self.transform_response_called = True
|
||||
return [
|
||||
SearchRow(
|
||||
chunk_id=str(i),
|
||||
score=response.get(f"score_{i}", 0.5),
|
||||
payload={"text": f"result_{i}"},
|
||||
)
|
||||
for i in range(response.get("count", 1))
|
||||
]
|
||||
|
||||
async def run_similarity_query(
|
||||
self,
|
||||
embedding: list[float],
|
||||
collection: str,
|
||||
limit: int = 10,
|
||||
conditions: str | None = None,
|
||||
threshold: float | None = None,
|
||||
) -> dict[str, Any]:
|
||||
self.run_similarity_query_called = True
|
||||
return {
|
||||
"count": min(limit, 3),
|
||||
"collection": collection,
|
||||
"conditions": conditions,
|
||||
"threshold": threshold,
|
||||
"score_0": 0.95,
|
||||
"score_1": 0.85,
|
||||
"score_2": 0.75,
|
||||
}
|
||||
|
||||
|
||||
class TestBaseEngine:
|
||||
"""Test suite for BaseEngine abstract class"""
|
||||
|
||||
def test_base_engine_is_abstract(self):
|
||||
"""Test that BaseEngine cannot be instantiated directly"""
|
||||
with pytest.raises(TypeError):
|
||||
BaseEngine()
|
||||
|
||||
def test_base_engine_abc_methods(self):
|
||||
"""Test that BaseEngine has the correct abstract methods"""
|
||||
abstract_methods = BaseEngine.__abstractmethods__
|
||||
expected_methods = {
|
||||
"transform_conditions",
|
||||
"transform_response",
|
||||
"run_similarity_query",
|
||||
}
|
||||
assert abstract_methods == expected_methods
|
||||
|
||||
def test_mock_engine_instantiation(self):
|
||||
"""Test that mock engine can be instantiated"""
|
||||
engine = MockEngine()
|
||||
assert isinstance(engine, BaseEngine)
|
||||
assert not engine.transform_conditions_called
|
||||
assert not engine.transform_response_called
|
||||
assert not engine.run_similarity_query_called
|
||||
|
||||
def test_transform_conditions_with_none(self):
|
||||
"""Test transform_conditions with None input"""
|
||||
engine = MockEngine()
|
||||
result = engine.transform_conditions(None)
|
||||
|
||||
assert result is None
|
||||
assert engine.transform_conditions_called
|
||||
|
||||
def test_transform_conditions_with_conditions(self):
|
||||
"""Test transform_conditions with actual conditions"""
|
||||
engine = MockEngine()
|
||||
conditions = [
|
||||
Match(key="category", value="test"),
|
||||
MatchAny(key="tags", any=["tag1", "tag2"]),
|
||||
]
|
||||
|
||||
result = engine.transform_conditions(conditions)
|
||||
|
||||
assert result == "transformed_2_conditions"
|
||||
assert engine.transform_conditions_called
|
||||
|
||||
def test_transform_response(self):
|
||||
"""Test transform_response with mock data"""
|
||||
engine = MockEngine()
|
||||
mock_response = {"count": 2, "score_0": 0.9, "score_1": 0.8}
|
||||
|
||||
result = engine.transform_response(mock_response)
|
||||
|
||||
assert len(result) == 2
|
||||
assert all(isinstance(row, SearchRow) for row in result)
|
||||
assert result[0].chunk_id == "0"
|
||||
assert result[0].score == 0.9
|
||||
assert result[1].chunk_id == "1"
|
||||
assert result[1].score == 0.8
|
||||
assert engine.transform_response_called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_similarity_query(self):
|
||||
"""Test run_similarity_query with various parameters"""
|
||||
engine = MockEngine()
|
||||
embedding = [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
|
||||
result = await engine.run_similarity_query(
|
||||
embedding=embedding,
|
||||
collection="test_collection",
|
||||
limit=5,
|
||||
conditions="test_conditions",
|
||||
threshold=0.7,
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["collection"] == "test_collection"
|
||||
assert result["conditions"] == "test_conditions"
|
||||
assert result["threshold"] == 0.7
|
||||
assert result["count"] == 3 # min(limit=5, 3)
|
||||
assert engine.run_similarity_query_called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_semantic_search_full_flow(self):
|
||||
"""Test the complete semantic_search flow"""
|
||||
engine = MockEngine()
|
||||
vector = [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
conditions = [Match(key="status", value="active")]
|
||||
|
||||
result = await engine.semantic_search(
|
||||
embedding=vector,
|
||||
collection="test_collection",
|
||||
limit=2,
|
||||
conditions=conditions,
|
||||
threshold=0.8,
|
||||
)
|
||||
|
||||
# Verify all methods were called
|
||||
assert engine.transform_conditions_called
|
||||
assert engine.run_similarity_query_called
|
||||
assert engine.transform_response_called
|
||||
|
||||
# Verify result structure
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
assert all(isinstance(row, SearchRow) for row in result)
|
||||
assert result[0].score == 0.95
|
||||
assert result[1].score == 0.85
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_semantic_search_with_none_conditions(self):
|
||||
"""Test semantic_search with None conditions"""
|
||||
engine = MockEngine()
|
||||
vector = [0.1, 0.2, 0.3]
|
||||
|
||||
result = await engine.semantic_search(
|
||||
embedding=vector, collection="test_collection"
|
||||
)
|
||||
|
||||
assert engine.transform_conditions_called
|
||||
assert engine.run_similarity_query_called
|
||||
assert engine.transform_response_called
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 3 # default limit from mock
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_semantic_search_default_parameters(self):
|
||||
"""Test semantic_search with default parameters"""
|
||||
engine = MockEngine()
|
||||
vector = [0.1, 0.2, 0.3]
|
||||
|
||||
result = await engine.semantic_search(
|
||||
embedding=vector, collection="test_collection"
|
||||
)
|
||||
|
||||
# Check that defaults were used (limit=10, conditions=None, threshold=None)
|
||||
assert isinstance(result, list)
|
||||
assert len(result) <= 10 # Should respect limit
|
||||
|
||||
def test_typing_constraints(self):
|
||||
"""Test that the generic typing works correctly"""
|
||||
engine = MockEngine()
|
||||
|
||||
# Verify the engine has the correct generic types
|
||||
assert hasattr(engine, "transform_conditions")
|
||||
assert hasattr(engine, "transform_response")
|
||||
assert hasattr(engine, "run_similarity_query")
|
||||
|
||||
# The mock engine should work with dict[str, Any] and str types
|
||||
conditions_result = engine.transform_conditions([])
|
||||
assert conditions_result is None or isinstance(conditions_result, str)
|
||||
|
||||
response_result = engine.transform_response({"test": "data"})
|
||||
assert isinstance(response_result, list)
|
||||
assert all(isinstance(item, SearchRow) for item in response_result)
|
||||
|
||||
|
||||
class IncompleteEngine(BaseEngine[str, int]):
|
||||
"""Incomplete engine implementation for testing abstract method enforcement"""
|
||||
|
||||
def transform_conditions(self, conditions: list[Condition] | None) -> int | None:
|
||||
return None
|
||||
|
||||
# Missing transform_response and run_similarity_query
|
||||
|
||||
|
||||
class TestAbstractMethodEnforcement:
|
||||
"""Test that abstract methods must be implemented"""
|
||||
|
||||
def test_incomplete_engine_cannot_be_instantiated(self):
|
||||
"""Test that incomplete engine implementations cannot be instantiated"""
|
||||
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
|
||||
IncompleteEngine()
|
||||
Reference in New Issue
Block a user