Files
searchbox/tests/test_engine/test_base_engine.py
2025-09-26 21:02:27 +00:00

247 lines
8.5 KiB
Python

from typing import Any
import pytest
from vector_search_mcp.engine.base_engine import BaseEngine
from vector_search_mcp.models import Condition, Match, MatchAny, SearchRow
class MockEngine(BaseEngine[dict[str, Any], str, dict]):
    """In-memory concrete BaseEngine used to exercise the abstract contract.

    Each search hook flips a matching ``*_called`` flag so tests can tell
    which hooks ran. Every hook returns deterministic canned data.
    """

    def __init__(self):
        # Spy flags; each is set to True by the hook of the same name.
        self.transform_conditions_called = False
        self.transform_response_called = False
        self.run_similarity_query_called = False

    def transform_conditions(self, conditions: list[Condition] | None) -> str | None:
        """Return a placeholder filter string, or None for no/empty conditions.

        The placeholder encodes only how many conditions were received,
        e.g. ``"transformed_2_conditions"``.
        """
        self.transform_conditions_called = True
        return f"transformed_{len(conditions)}_conditions" if conditions else None

    def transform_response(self, response: dict[str, Any]) -> list[SearchRow]:
        """Turn a canned response dict into ``SearchRow`` objects.

        ``response["count"]`` (default 1) controls how many rows are built.
        Row ``i`` takes its score from ``response["score_i"]``, falling back
        to 0.5 when that key is missing.
        """
        self.transform_response_called = True
        hit_count = response.get("count", 1)
        rows: list[SearchRow] = []
        for index in range(hit_count):
            rows.append(
                SearchRow(
                    chunk_id=str(index),
                    score=response.get(f"score_{index}", 0.5),
                    payload={"text": f"result_{index}"},
                )
            )
        return rows

    async def run_similarity_query(
        self,
        embedding: list[float],
        collection: str,
        limit: int = 10,
        conditions: str | None = None,
        threshold: float | None = None,
    ) -> dict[str, Any]:
        """Return a fake backend response that echoes its inputs.

        At most three hits are reported (``min(limit, 3)``), scored 0.95,
        0.85 and 0.75 in order. ``embedding`` is accepted but unused.
        """
        self.run_similarity_query_called = True
        canned_scores = (0.95, 0.85, 0.75)
        result: dict[str, Any] = {
            "count": min(limit, len(canned_scores)),
            "collection": collection,
            "conditions": conditions,
            "threshold": threshold,
        }
        # Score keys follow the "score_<index>" naming read by transform_response.
        result.update(
            (f"score_{position}", value)
            for position, value in enumerate(canned_scores)
        )
        return result

    async def create_index(self, name: str, size: int) -> bool:
        """Pretend to create an index; always reports success."""
        return True

    def transform_chunk(self, chunk) -> dict:
        """Flatten a chunk into a dict holding its id, vector and dumped payload."""
        record = {"id": chunk.id, "vector": chunk.vector}
        record["payload"] = chunk.payload.model_dump()
        return record

    async def run_upload_chunk(self, index_name: str, chunk: dict) -> bool:
        """Pretend to upload a chunk; always reports success."""
        return True
class TestBaseEngine:
    """Test suite for BaseEngine abstract class"""

    @staticmethod
    def _capture_raw_responses(engine: MockEngine) -> list[dict[str, Any]]:
        """Record each raw response dict passed to ``engine.transform_response``.

        ``MockEngine.run_similarity_query`` echoes collection, conditions and
        threshold into its response, but ``transform_response`` discards them.
        Wrapping the hook on this single instance lets a test inspect what
        reached the query. The original hook, and its spy flag, still run.
        """
        captured: list[dict[str, Any]] = []
        original = engine.transform_response

        def recording_transform(response: dict[str, Any]) -> list[SearchRow]:
            captured.append(response)
            return original(response)

        # The instance attribute shadows the class method for this engine only.
        engine.transform_response = recording_transform
        return captured

    def test_base_engine_is_abstract(self):
        """Test that BaseEngine cannot be instantiated directly"""
        with pytest.raises(TypeError):
            BaseEngine()

    def test_base_engine_abc_methods(self):
        """Test that BaseEngine has the correct abstract methods"""
        abstract_methods = BaseEngine.__abstractmethods__
        expected_methods = {
            "transform_conditions",
            "transform_response",
            "run_similarity_query",
            "create_index",
            "transform_chunk",
            "run_upload_chunk",
        }
        assert abstract_methods == expected_methods

    def test_mock_engine_instantiation(self):
        """Test that mock engine can be instantiated"""
        engine = MockEngine()
        assert isinstance(engine, BaseEngine)
        assert not engine.transform_conditions_called
        assert not engine.transform_response_called
        assert not engine.run_similarity_query_called

    def test_transform_conditions_with_none(self):
        """Test transform_conditions with None input"""
        engine = MockEngine()
        result = engine.transform_conditions(None)
        assert result is None
        assert engine.transform_conditions_called

    def test_transform_conditions_with_conditions(self):
        """Test transform_conditions with actual conditions"""
        engine = MockEngine()
        conditions = [
            Match(key="category", value="test"),
            MatchAny(key="tags", any=["tag1", "tag2"]),
        ]
        result = engine.transform_conditions(conditions)
        assert result == "transformed_2_conditions"
        assert engine.transform_conditions_called

    def test_transform_response(self):
        """Test transform_response with mock data"""
        engine = MockEngine()
        mock_response = {"count": 2, "score_0": 0.9, "score_1": 0.8}
        result = engine.transform_response(mock_response)
        assert len(result) == 2
        assert all(isinstance(row, SearchRow) for row in result)
        assert result[0].chunk_id == "0"
        assert result[0].score == 0.9
        assert result[1].chunk_id == "1"
        assert result[1].score == 0.8
        assert engine.transform_response_called

    @pytest.mark.asyncio
    async def test_run_similarity_query(self):
        """Test run_similarity_query with various parameters"""
        engine = MockEngine()
        embedding = [0.1, 0.2, 0.3, 0.4, 0.5]
        result = await engine.run_similarity_query(
            embedding=embedding,
            collection="test_collection",
            limit=5,
            conditions="test_conditions",
            threshold=0.7,
        )
        assert isinstance(result, dict)
        assert result["collection"] == "test_collection"
        assert result["conditions"] == "test_conditions"
        assert result["threshold"] == 0.7
        assert result["count"] == 3  # min(limit=5, 3)
        assert engine.run_similarity_query_called

    @pytest.mark.asyncio
    async def test_semantic_search_full_flow(self):
        """Test the complete semantic_search flow"""
        engine = MockEngine()
        raw_responses = self._capture_raw_responses(engine)
        vector = [0.1, 0.2, 0.3, 0.4, 0.5]
        conditions = [Match(key="status", value="active")]
        result = await engine.semantic_search(
            embedding=vector,
            collection="test_collection",
            limit=2,
            conditions=conditions,
            threshold=0.8,
        )
        # Verify all methods were called
        assert engine.transform_conditions_called
        assert engine.run_similarity_query_called
        assert engine.transform_response_called
        # Verify the arguments reached the query. Conditions should arrive in
        # their transformed form; collection, limit and threshold should be
        # forwarded unchanged.
        assert len(raw_responses) == 1
        raw = raw_responses[0]
        assert raw["collection"] == "test_collection"
        assert raw["conditions"] == "transformed_1_conditions"
        assert raw["threshold"] == 0.8
        assert raw["count"] == 2
        # Verify result structure
        assert isinstance(result, list)
        assert len(result) == 2
        assert all(isinstance(row, SearchRow) for row in result)
        assert result[0].score == 0.95
        assert result[1].score == 0.85

    @pytest.mark.asyncio
    async def test_semantic_search_with_none_conditions(self):
        """Test semantic_search with None conditions"""
        engine = MockEngine()
        raw_responses = self._capture_raw_responses(engine)
        vector = [0.1, 0.2, 0.3]
        # Pass None explicitly so this test differs from the defaults test.
        result = await engine.semantic_search(
            embedding=vector, collection="test_collection", conditions=None
        )
        assert engine.transform_conditions_called
        assert engine.run_similarity_query_called
        assert engine.transform_response_called
        assert len(raw_responses) == 1
        assert raw_responses[0]["conditions"] is None
        assert isinstance(result, list)
        # The mock reports min(limit, 3) hits, so any limit >= 3 gives 3 rows.
        assert len(result) == 3

    @pytest.mark.asyncio
    async def test_semantic_search_default_parameters(self):
        """Test semantic_search with default parameters"""
        engine = MockEngine()
        raw_responses = self._capture_raw_responses(engine)
        vector = [0.1, 0.2, 0.3]
        result = await engine.semantic_search(
            embedding=vector, collection="test_collection"
        )
        # Expected defaults: limit=10, conditions=None, threshold=None.
        # The mock caps hits at min(limit, 3), so a default limit >= 3 must
        # produce exactly 3 rows.
        assert len(raw_responses) == 1
        raw = raw_responses[0]
        assert raw["collection"] == "test_collection"
        assert raw["conditions"] is None
        assert raw["threshold"] is None
        assert raw["count"] == 3
        assert isinstance(result, list)
        assert len(result) == 3

    def test_typing_constraints(self):
        """Test that the generic typing works correctly"""
        engine = MockEngine()
        # Verify the engine has the correct generic types
        assert hasattr(engine, "transform_conditions")
        assert hasattr(engine, "transform_response")
        assert hasattr(engine, "run_similarity_query")
        # An empty condition list is falsy, so the mock returns None.
        conditions_result = engine.transform_conditions([])
        assert conditions_result is None
        # With no "count" key the mock builds one row. With no "score_0" key
        # that row gets the fallback score of 0.5.
        response_result = engine.transform_response({"test": "data"})
        assert isinstance(response_result, list)
        assert len(response_result) == 1
        assert all(isinstance(item, SearchRow) for item in response_result)
        assert response_result[0].chunk_id == "0"
        assert response_result[0].score == 0.5
class IncompleteEngine(BaseEngine[str, int, str]):
    """Deliberately partial BaseEngine subclass that must remain abstract.

    Only ``transform_conditions`` is provided. transform_response,
    run_similarity_query, create_index, transform_chunk and
    run_upload_chunk are intentionally left unimplemented.
    """

    def transform_conditions(self, conditions: list[Condition] | None) -> int | None:
        """Ignore the input; this override exists only to be partial."""
        return None
class TestAbstractMethodEnforcement:
    """Checks that ABC enforcement rejects partially implemented engines."""

    def test_incomplete_engine_cannot_be_instantiated(self):
        """Constructing an engine with unimplemented abstract methods raises TypeError."""
        with pytest.raises(TypeError) as exc_info:
            IncompleteEngine()
        # Same check as a regex search: the phrase has no regex metacharacters.
        assert "Can't instantiate abstract class" in str(exc_info.value)