forked from innovacion/searchbox
Add testing
tests/test_engine/test_qdrant_engine.py (new normal file, 380 lines added)
@@ -0,0 +1,380 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from qdrant_client import models
from qdrant_client.models import ScoredPoint

from vector_search_mcp.engine.base_engine import BaseEngine
from vector_search_mcp.engine.qdrant_engine import QdrantEngine
from vector_search_mcp.models import Match, MatchAny, MatchExclude, SearchRow


class TestQdrantEngine:
    """Test suite for QdrantEngine"""

    @pytest.fixture
    def mock_client(self):
        """Create a mock Qdrant client"""
        return AsyncMock()

    @pytest.fixture
    def mock_settings(self):
        """Create mock settings"""
        settings = MagicMock()
        settings.url = "http://localhost:6333"
        settings.api_key = "test_api_key"
        return settings

    @pytest.fixture
    def qdrant_engine(self, mock_client, mock_settings):
        """Create a QdrantEngine instance with mocked dependencies"""
        with (
            patch(
                "vector_search_mcp.engine.qdrant_engine.Settings"
            ) as mock_settings_class,
            patch(
                "vector_search_mcp.engine.qdrant_engine.AsyncQdrantClient"
            ) as mock_client_class,
        ):
            mock_settings_class.return_value = mock_settings
            mock_client_class.return_value = mock_client

            engine = QdrantEngine()
            engine.client = mock_client  # Ensure we use our mock
            return engine

    def test_inheritance(self, qdrant_engine):
        """Test that QdrantEngine properly inherits from BaseEngine"""
        assert isinstance(qdrant_engine, BaseEngine)
        assert isinstance(qdrant_engine, QdrantEngine)

    def test_typing_parameters(self):
        """Test that QdrantEngine has correct generic type parameters"""
        # QdrantEngine should be BaseEngine[list[models.ScoredPoint], models.Filter]
        # This is verified by the type checker, but we can test the methods exist
        assert hasattr(QdrantEngine, "transform_conditions")
        assert hasattr(QdrantEngine, "transform_response")
        assert hasattr(QdrantEngine, "run_similarity_query")

    def test_transform_conditions_none(self, qdrant_engine):
        """Test transform_conditions with None input"""
        result = qdrant_engine.transform_conditions(None)
        assert result is None

    def test_transform_conditions_empty_list(self, qdrant_engine):
        """Test transform_conditions with empty list"""
        result = qdrant_engine.transform_conditions([])
        assert result is None

    def test_transform_conditions_with_match(self, qdrant_engine):
        """Test transform_conditions with Match condition"""
        conditions = [Match(key="category", value="document")]

        result = qdrant_engine.transform_conditions(conditions)

        assert isinstance(result, models.Filter)
        assert len(result.must) == 1

        condition = result.must[0]
        assert isinstance(condition, models.FieldCondition)
        assert condition.key == "category"
        assert isinstance(condition.match, models.MatchValue)
        assert condition.match.value == "document"

    def test_transform_conditions_with_match_any(self, qdrant_engine):
        """Test transform_conditions with MatchAny condition"""
        conditions = [MatchAny(key="tags", any=["python", "rust", "javascript"])]

        result = qdrant_engine.transform_conditions(conditions)

        assert isinstance(result, models.Filter)
        assert len(result.must) == 1

        condition = result.must[0]
        assert isinstance(condition, models.FieldCondition)
        assert condition.key == "tags"
        assert isinstance(condition.match, models.MatchAny)
        assert condition.match.any == ["python", "rust", "javascript"]

    def test_transform_conditions_with_match_exclude(self, qdrant_engine):
        """Test transform_conditions with MatchExclude condition"""
        conditions = [MatchExclude(key="status", exclude=["deleted", "archived"])]

        result = qdrant_engine.transform_conditions(conditions)

        assert isinstance(result, models.Filter)
        assert len(result.must) == 1

        condition = result.must[0]
        assert isinstance(condition, models.FieldCondition)
        assert condition.key == "status"
        assert isinstance(condition.match, models.MatchExcept)
        # MatchExcept uses 'except' parameter which conflicts with Python keyword
        assert hasattr(condition.match, "except_")

    def test_transform_conditions_multiple(self, qdrant_engine):
        """Test transform_conditions with multiple conditions"""
        conditions = [
            Match(key="type", value="article"),
            MatchAny(key="language", any=["en", "es"]),
            MatchExclude(key="status", exclude=["draft"]),
        ]

        result = qdrant_engine.transform_conditions(conditions)

        assert isinstance(result, models.Filter)
        assert len(result.must) == 3

        # Verify all conditions are FieldCondition instances
        assert all(isinstance(cond, models.FieldCondition) for cond in result.must)

    def test_transform_response_empty(self, qdrant_engine):
        """Test transform_response with empty response"""
        response = []

        result = qdrant_engine.transform_response(response)

        assert isinstance(result, list)
        assert len(result) == 0

    def test_transform_response_with_scored_points(self, qdrant_engine):
        """Test transform_response with valid ScoredPoint objects"""
        response = [
            ScoredPoint(
                id=1,
                score=0.95,
                payload={"text": "First document", "category": "tech"},
                version=1,
            ),
            ScoredPoint(
                id=2,
                score=0.87,
                payload={"text": "Second document", "category": "science"},
                version=1,
            ),
        ]

        result = qdrant_engine.transform_response(response)

        assert isinstance(result, list)
        assert len(result) == 2

        # Check first result
        assert isinstance(result[0], SearchRow)
        assert result[0].chunk_id == "1"
        assert result[0].score == 0.95
        assert result[0].payload == {"text": "First document", "category": "tech"}

        # Check second result
        assert isinstance(result[1], SearchRow)
        assert result[1].chunk_id == "2"
        assert result[1].score == 0.87
        assert result[1].payload == {"text": "Second document", "category": "science"}

    def test_transform_response_filters_none_payload(self, qdrant_engine):
        """Test transform_response filters out points with None payload"""
        response = [
            ScoredPoint(
                id=1, score=0.95, payload={"text": "Valid document"}, version=1
            ),
            ScoredPoint(
                id=2,
                score=0.87,
                payload=None,  # This should be filtered out
                version=1,
            ),
            ScoredPoint(
                id=3, score=0.75, payload={"text": "Another valid document"}, version=1
            ),
        ]

        result = qdrant_engine.transform_response(response)

        assert isinstance(result, list)
        assert len(result) == 2  # Only 2 valid results
        assert result[0].chunk_id == "1"
        assert result[1].chunk_id == "3"

    @pytest.mark.asyncio
    async def test_run_similarity_query_basic(self, qdrant_engine, mock_client):
        """Test run_similarity_query with basic parameters"""
        # Setup mock response
        mock_response = [
            ScoredPoint(id=1, score=0.9, payload={"text": "Test document"}, version=1)
        ]
        mock_client.search.return_value = mock_response

        embedding = [0.1, 0.2, 0.3, 0.4, 0.5]
        collection = "test_collection"

        result = await qdrant_engine.run_similarity_query(
            embedding=embedding, collection=collection
        )

        # Verify client.search was called with correct parameters
        mock_client.search.assert_called_once_with(
            collection_name=collection,
            query_vector=embedding,
            query_filter=None,
            limit=10,  # default
            with_payload=True,
            with_vectors=False,
            score_threshold=None,
        )

        assert result == mock_response

    @pytest.mark.asyncio
    async def test_run_similarity_query_with_all_parameters(
        self, qdrant_engine, mock_client
    ):
        """Test run_similarity_query with all parameters"""
        mock_response = []
        mock_client.search.return_value = mock_response

        embedding = [0.1, 0.2, 0.3]
        collection = "test_collection"
        limit = 5
        conditions = models.Filter(must=[])
        threshold = 0.75

        result = await qdrant_engine.run_similarity_query(
            embedding=embedding,
            collection=collection,
            limit=limit,
            conditions=conditions,
            threshold=threshold,
        )

        mock_client.search.assert_called_once_with(
            collection_name=collection,
            query_vector=embedding,
            query_filter=conditions,
            limit=limit,
            with_payload=True,
            with_vectors=False,
            score_threshold=threshold,
        )

        assert result == mock_response

    @pytest.mark.asyncio
    async def test_run_similarity_query_with_named_vector(
        self, qdrant_engine, mock_client
    ):
        """Test run_similarity_query with NamedVector"""
        mock_response = []
        mock_client.search.return_value = mock_response

        named_vector = models.NamedVector(name="text", vector=[0.1, 0.2, 0.3])
        collection = "test_collection"

        result = await qdrant_engine.run_similarity_query(
            embedding=named_vector, collection=collection
        )

        mock_client.search.assert_called_once_with(
            collection_name=collection,
            query_vector=named_vector,
            query_filter=None,
            limit=10,
            with_payload=True,
            with_vectors=False,
            score_threshold=None,
        )

        assert result == mock_response

    @pytest.mark.asyncio
    async def test_semantic_search_integration(self, qdrant_engine, mock_client):
        """Test the full semantic_search flow through QdrantEngine"""
        # Setup mock response
        mock_search_response = [
            ScoredPoint(
                id=1,
                score=0.95,
                payload={"text": "Python programming guide", "category": "tech"},
                version=1,
            ),
            ScoredPoint(
                id=2,
                score=0.87,
                payload={"text": "Rust systems programming", "category": "tech"},
                version=1,
            ),
        ]
        mock_client.search.return_value = mock_search_response

        # Test data
        vector = [0.1, 0.2, 0.3, 0.4, 0.5]
        collection = "documents"
        conditions = [
            Match(key="category", value="tech"),
            MatchAny(key="language", any=["python", "rust"]),
        ]

        result = await qdrant_engine.semantic_search(
            embedding=vector,
            collection=collection,
            limit=5,
            conditions=conditions,
            threshold=0.8,
        )

        # Verify the search was called with transformed conditions
        assert mock_client.search.called
        call_args = mock_client.search.call_args
        assert call_args[1]["collection_name"] == collection
        assert call_args[1]["query_vector"] == vector
        assert call_args[1]["limit"] == 5
        assert call_args[1]["score_threshold"] == 0.8
        assert isinstance(call_args[1]["query_filter"], models.Filter)

        # Verify the response was transformed correctly
        assert isinstance(result, list)
        assert len(result) == 2
        assert all(isinstance(row, SearchRow) for row in result)
        assert result[0].chunk_id == "1"
        assert result[0].score == 0.95
        assert result[1].chunk_id == "2"
        assert result[1].score == 0.87

    def test_initialization_with_settings(self, mock_settings):
        """Test QdrantEngine initialization uses settings correctly"""
        with (
            patch(
                "vector_search_mcp.engine.qdrant_engine.Settings"
            ) as mock_settings_class,
            patch(
                "vector_search_mcp.engine.qdrant_engine.AsyncQdrantClient"
            ) as mock_client_class,
        ):
            mock_settings_class.return_value = mock_settings
            mock_client = AsyncMock()
            mock_client_class.return_value = mock_client

            engine = QdrantEngine()

            # Verify Settings was instantiated
            mock_settings_class.assert_called_once()

            # Verify AsyncQdrantClient was created with correct parameters
            mock_client_class.assert_called_once_with(
                url=mock_settings.url, api_key=mock_settings.api_key
            )

            assert engine.client == mock_client
            assert engine.settings == mock_settings

    @pytest.mark.asyncio
    async def test_client_search_exception_propagation(
        self, qdrant_engine, mock_client
    ):
        """Test that exceptions from client.search are properly propagated"""
        # Setup mock to raise an exception
        mock_client.search.side_effect = Exception("Qdrant connection failed")

        with pytest.raises(Exception, match="Qdrant connection failed"):
            await qdrant_engine.run_similarity_query(
                embedding=[0.1, 0.2, 0.3], collection="test_collection"
            )
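Because several tests are marked with pytest.mark.asyncio, running the suite requires the pytest-asyncio plugin alongside pytest. As a minimal sketch (not part of the commit itself), one way to invoke just this new module from Python, equivalent to calling pytest on the file from the command line:

import sys

import pytest

if __name__ == "__main__":
    # Run only the QdrantEngine tests added in this commit; pytest-asyncio
    # collects the @pytest.mark.asyncio tests once the plugin is installed.
    sys.exit(pytest.main(["-v", "tests/test_engine/test_qdrant_engine.py"]))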