forked from innovacion/searchbox
270 lines
9.2 KiB
Python
270 lines
9.2 KiB
Python
from typing import Any
|
|
|
|
import pytest
|
|
|
|
from vector_search_mcp.engine.base_engine import BaseEngine
|
|
from vector_search_mcp.models import Condition, Match, MatchAny, SearchRow
|
|
|
|
|
|
class MockEngine(BaseEngine[dict[str, Any], str, dict]):
    """Mock engine for testing BaseEngine abstract functionality.

    Every overridden hook sets a ``*_called`` flag so tests can assert which
    parts of the BaseEngine workflow actually ran.
    """

    def __init__(self):
        self.transform_conditions_called = False
        self.transform_response_called = False
        self.run_similarity_query_called = False
        # Tracked for consistency with the search hooks above, so the
        # index/upload workflow can be observed the same way.
        self.create_index_called = False
        self.transform_chunk_called = False
        self.run_upload_chunk_called = False

    def transform_conditions(self, conditions: list[Condition] | None) -> str | None:
        """Return a placeholder string encoding how many conditions were given.

        Returns ``None`` for ``None`` or an empty list.
        """
        self.transform_conditions_called = True
        if not conditions:
            return None
        return f"transformed_{len(conditions)}_conditions"

    def transform_response(self, response: dict[str, Any]) -> list[SearchRow]:
        """Build ``response["count"]`` rows (default 1) from a mock response.

        Row ``i`` gets ``chunk_id=str(i)``, the score stored under
        ``f"score_{i}"`` (default 0.5) and payload ``{"text": f"result_{i}"}``.
        """
        self.transform_response_called = True
        return [
            SearchRow(
                chunk_id=str(i),
                score=response.get(f"score_{i}", 0.5),
                payload={"text": f"result_{i}"},
            )
            for i in range(response.get("count", 1))
        ]

    async def run_similarity_query(
        self,
        embedding: list[float],
        collection: str,
        limit: int = 10,
        conditions: str | None = None,
        threshold: float | None = None,
    ) -> dict[str, Any]:
        """Return a canned response that echoes the query arguments.

        The count is capped at ``min(limit, 3)``, and the fixed scores
        0.95 / 0.85 / 0.75 are provided for rows 0-2. ``embedding`` is unused.
        """
        self.run_similarity_query_called = True
        return {
            "count": min(limit, 3),
            "collection": collection,
            "conditions": conditions,
            "threshold": threshold,
            "score_0": 0.95,
            "score_1": 0.85,
            "score_2": 0.75,
        }

    async def create_index(self, name: str, size: int) -> bool:
        """Mock implementation of create_index; always succeeds."""
        self.create_index_called = True
        return True

    def transform_chunk(self, chunk) -> dict:
        """Convert a chunk into a plain dict with ``id``, ``vector`` and ``payload``.

        The payload is serialized with ``model_dump()``, which presumes a
        pydantic-style payload model.
        """
        self.transform_chunk_called = True
        return {
            "id": chunk.id,
            "vector": chunk.vector,
            "payload": chunk.payload.model_dump(),
        }

    async def run_upload_chunk(self, index_name: str, chunk: dict) -> bool:
        """Mock implementation of run_upload_chunk; always succeeds."""
        self.run_upload_chunk_called = True
        return True
|
|
|
|
|
|
class TestBaseEngine:
    """Test suite for the BaseEngine abstract class, exercised through MockEngine."""

    def test_base_engine_is_abstract(self):
        """Test that BaseEngine cannot be instantiated directly"""
        with pytest.raises(TypeError):
            BaseEngine()

    def test_base_engine_abc_methods(self):
        """Test that BaseEngine declares exactly the expected abstract methods"""
        abstract_methods = BaseEngine.__abstractmethods__
        expected_methods = {
            "transform_conditions",
            "transform_response",
            "run_similarity_query",
            "create_index",
            "transform_chunk",
            "run_upload_chunk",
        }
        assert abstract_methods == expected_methods

    def test_mock_engine_instantiation(self):
        """Test that the mock engine can be instantiated with no hooks called yet"""
        engine = MockEngine()
        assert isinstance(engine, BaseEngine)
        assert not engine.transform_conditions_called
        assert not engine.transform_response_called
        assert not engine.run_similarity_query_called

    def test_transform_conditions_with_none(self):
        """Test transform_conditions with None input"""
        engine = MockEngine()
        result = engine.transform_conditions(None)

        assert result is None
        assert engine.transform_conditions_called

    def test_transform_conditions_with_conditions(self):
        """Test transform_conditions with actual conditions"""
        engine = MockEngine()
        conditions = [
            Match(key="category", value="test"),
            MatchAny(key="tags", any=["tag1", "tag2"]),
        ]

        result = engine.transform_conditions(conditions)

        assert result == "transformed_2_conditions"
        assert engine.transform_conditions_called

    def test_transform_response(self):
        """Test transform_response with mock data"""
        engine = MockEngine()
        mock_response = {"count": 2, "score_0": 0.9, "score_1": 0.8}

        result = engine.transform_response(mock_response)

        assert len(result) == 2
        assert all(isinstance(row, SearchRow) for row in result)
        assert result[0].chunk_id == "0"
        assert result[0].score == 0.9
        assert result[1].chunk_id == "1"
        assert result[1].score == 0.8
        assert engine.transform_response_called

    @pytest.mark.asyncio
    async def test_run_similarity_query(self):
        """Test run_similarity_query with all parameters supplied"""
        engine = MockEngine()
        embedding = [0.1, 0.2, 0.3, 0.4, 0.5]

        result = await engine.run_similarity_query(
            embedding=embedding,
            collection="test_collection",
            limit=5,
            conditions="test_conditions",
            threshold=0.7,
        )

        assert isinstance(result, dict)
        assert result["collection"] == "test_collection"
        assert result["conditions"] == "test_conditions"
        assert result["threshold"] == 0.7
        assert result["count"] == 3  # min(limit=5, 3)
        assert engine.run_similarity_query_called

    @pytest.mark.asyncio
    async def test_semantic_search_full_flow(self):
        """Test the complete semantic_search flow"""
        engine = MockEngine()
        vector = [0.1, 0.2, 0.3, 0.4, 0.5]
        conditions = [Match(key="status", value="active")]

        result = await engine.semantic_search(
            embedding=vector,
            collection="test_collection",
            limit=2,
            conditions=conditions,
            threshold=0.8,
        )

        # Verify all hooks were called
        assert engine.transform_conditions_called
        assert engine.run_similarity_query_called
        assert engine.transform_response_called

        # Verify result structure
        assert isinstance(result, list)
        assert len(result) == 2
        assert all(isinstance(row, SearchRow) for row in result)
        assert result[0].score == 0.95
        assert result[1].score == 0.85

    @pytest.mark.asyncio
    async def test_semantic_search_with_none_conditions(self):
        """Test semantic_search with None conditions"""
        engine = MockEngine()
        vector = [0.1, 0.2, 0.3]

        result = await engine.semantic_search(
            embedding=vector, collection="test_collection"
        )

        assert engine.transform_conditions_called
        assert engine.run_similarity_query_called
        assert engine.transform_response_called
        assert isinstance(result, list)
        # MockEngine.run_similarity_query caps the count at min(limit, 3).
        assert len(result) == 3

    @pytest.mark.asyncio
    async def test_semantic_search_default_parameters(self):
        """Test that semantic_search forwards its default parameters to the engine"""
        engine = MockEngine()
        vector = [0.1, 0.2, 0.3]

        # Record what semantic_search forwards to run_similarity_query, then
        # delegate to the mock implementation so its flag and response stay real.
        # Parameter names mirror MockEngine.run_similarity_query so both
        # positional and keyword call styles are captured.
        forwarded: dict[str, Any] = {}
        original_query = engine.run_similarity_query

        async def recording_query(
            embedding, collection, limit=10, conditions=None, threshold=None
        ):
            forwarded.update(
                embedding=embedding,
                collection=collection,
                limit=limit,
                conditions=conditions,
                threshold=threshold,
            )
            return await original_query(
                embedding=embedding,
                collection=collection,
                limit=limit,
                conditions=conditions,
                threshold=threshold,
            )

        engine.run_similarity_query = recording_query

        result = await engine.semantic_search(
            embedding=vector, collection="test_collection"
        )

        # Expected defaults: limit=10, conditions=None, threshold=None.
        assert forwarded["embedding"] == vector
        assert forwarded["collection"] == "test_collection"
        assert forwarded["limit"] == 10
        assert forwarded["conditions"] is None
        assert forwarded["threshold"] is None

        assert isinstance(result, list)
        assert len(result) == 3  # MockEngine caps the count at min(limit, 3)

    def test_typing_constraints(self):
        """Test that MockEngine parameterizes BaseEngine with the declared types"""
        from typing import get_args

        # MockEngine is declared as BaseEngine[dict[str, Any], str, dict];
        # verify those type arguments are actually recorded on the class.
        generic_base = MockEngine.__orig_bases__[0]
        assert get_args(generic_base) == (dict[str, Any], str, dict)

        engine = MockEngine()

        # Conditions transform to str (or None when there are none).
        conditions_result = engine.transform_conditions([])
        assert conditions_result is None or isinstance(conditions_result, str)

        # A dict response transforms to a list of SearchRow.
        response_result = engine.transform_response({"test": "data"})
        assert isinstance(response_result, list)
        assert all(isinstance(item, SearchRow) for item in response_result)
|
|
|
|
|
|
class IncompleteEngine(BaseEngine[str, int, str]):
    """Incomplete engine implementation for testing abstract method enforcement.

    Only ``transform_conditions`` is implemented; every other abstract hook is
    deliberately omitted so that instantiation must fail.
    """

    def transform_conditions(self, conditions: list[Condition] | None) -> int | None:
        return None

    # Missing transform_response, run_similarity_query, create_index, transform_chunk, run_upload_chunk


class TestUploadChunkWorkflow:
    """Tests for the BaseEngine.upload_chunk workflow, exercised via MockEngine.

    This test previously lived inside IncompleteEngine, a non-``Test*`` class,
    so pytest never collected it.
    """

    @pytest.mark.asyncio
    async def test_upload_chunk_workflow(self):
        """Test that upload_chunk transforms the chunk and passes it to run_upload_chunk"""
        from vector_search_mcp.models import Chunk, ChunkData

        engine = MockEngine()
        source_chunk = Chunk(
            id="test-chunk-1",
            vector=[0.1, 0.2, 0.3],
            payload=ChunkData(
                page_content="Test content",
                filename="test.pdf",
                page=1,
            ),
        )

        # Capture every call upload_chunk makes to run_upload_chunk, delegating
        # to the mock implementation so its real return value is preserved.
        # Parameter names mirror MockEngine.run_upload_chunk so keyword calls work.
        uploads: list[tuple[str, dict]] = []
        original_upload = engine.run_upload_chunk

        async def recording_upload(index_name: str, chunk: dict) -> bool:
            uploads.append((index_name, chunk))
            return await original_upload(index_name, chunk)

        engine.run_upload_chunk = recording_upload

        result = await engine.upload_chunk("test_index", source_chunk)

        assert result is True
        # run_upload_chunk must have been called once with the transformed chunk.
        assert uploads == [("test_index", engine.transform_chunk(source_chunk))]
|
|
|
|
|
|
class TestAbstractMethodEnforcement:
    """Test that abstract methods must be implemented"""

    def test_incomplete_engine_cannot_be_instantiated(self):
        """Test that incomplete engine implementations cannot be instantiated"""
        with pytest.raises(TypeError, match="Can't instantiate abstract class"):
            IncompleteEngine()

    def test_incomplete_engine_reports_missing_methods(self):
        """Test that exactly the hooks IncompleteEngine omits remain abstract"""
        # IncompleteEngine implements only transform_conditions, so every other
        # BaseEngine abstract method must still be listed as abstract.
        assert IncompleteEngine.__abstractmethods__ == {
            "transform_response",
            "run_similarity_query",
            "create_index",
            "transform_chunk",
            "run_upload_chunk",
        }
|