# Forked from innovacion/searchbox (506 lines, 18 KiB, Python).
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from qdrant_client import models
from qdrant_client.models import ScoredPoint

from vector_search_mcp.engine.base_engine import BaseEngine
from vector_search_mcp.engine.qdrant_engine import QdrantEngine
from vector_search_mcp.models import (
    Chunk,
    ChunkData,
    Match,
    MatchAny,
    MatchExclude,
    SearchRow,
)


class TestQdrantEngine:
    """Test suite for QdrantEngine, exercised against a mocked Qdrant client."""

    @pytest.fixture
    def mock_client(self):
        """Create a mock Qdrant client (AsyncMock, so its methods are awaitable)."""
        return AsyncMock()

    @pytest.fixture
    def mock_settings(self):
        """Create mock settings exposing a fixed ``url`` and ``api_key``."""
        settings = MagicMock()
        settings.url = "http://localhost:6333"
        settings.api_key = "test_api_key"
        return settings

    @pytest.fixture
    def qdrant_engine(self, mock_client, mock_settings):
        """Create a QdrantEngine instance with mocked dependencies.

        ``Settings`` and ``AsyncQdrantClient`` are patched in the engine module
        while the engine is constructed, so no real configuration or network
        client is created.
        """
        with (
            patch(
                "vector_search_mcp.engine.qdrant_engine.Settings"
            ) as mock_settings_class,
            patch(
                "vector_search_mcp.engine.qdrant_engine.AsyncQdrantClient"
            ) as mock_client_class,
        ):
            mock_settings_class.return_value = mock_settings
            mock_client_class.return_value = mock_client

            engine = QdrantEngine()
            engine.client = mock_client  # Ensure we use our mock
            return engine

    def test_inheritance(self, qdrant_engine):
        """Test that QdrantEngine properly inherits from BaseEngine."""
        assert isinstance(qdrant_engine, BaseEngine)
        assert isinstance(qdrant_engine, QdrantEngine)

    def test_typing_parameters(self):
        """Test that QdrantEngine exposes the methods its generic base requires."""
        # QdrantEngine should be BaseEngine[list[models.ScoredPoint], models.Filter]
        # This is verified by the type checker, but we can test the methods exist
        assert hasattr(QdrantEngine, "transform_conditions")
        assert hasattr(QdrantEngine, "transform_response")
        assert hasattr(QdrantEngine, "run_similarity_query")

    def test_transform_conditions_none(self, qdrant_engine):
        """Test transform_conditions with None input."""
        result = qdrant_engine.transform_conditions(None)
        assert result is None

    def test_transform_conditions_empty_list(self, qdrant_engine):
        """Test transform_conditions with empty list."""
        result = qdrant_engine.transform_conditions([])
        assert result is None

    def test_transform_conditions_with_match(self, qdrant_engine):
        """Test transform_conditions with Match condition."""
        conditions = [Match(key="category", value="document")]

        result = qdrant_engine.transform_conditions(conditions)

        assert isinstance(result, models.Filter)
        assert len(result.must) == 1

        condition = result.must[0]
        assert isinstance(condition, models.FieldCondition)
        assert condition.key == "category"
        assert isinstance(condition.match, models.MatchValue)
        assert condition.match.value == "document"

    def test_transform_conditions_with_match_any(self, qdrant_engine):
        """Test transform_conditions with MatchAny condition."""
        conditions = [MatchAny(key="tags", any=["python", "rust", "javascript"])]

        result = qdrant_engine.transform_conditions(conditions)

        assert isinstance(result, models.Filter)
        assert len(result.must) == 1

        condition = result.must[0]
        assert isinstance(condition, models.FieldCondition)
        assert condition.key == "tags"
        assert isinstance(condition.match, models.MatchAny)
        assert condition.match.any == ["python", "rust", "javascript"]

    def test_transform_conditions_with_match_exclude(self, qdrant_engine):
        """Test transform_conditions with MatchExclude condition."""
        conditions = [MatchExclude(key="status", exclude=["deleted", "archived"])]

        result = qdrant_engine.transform_conditions(conditions)

        assert isinstance(result, models.Filter)
        assert len(result.must) == 1

        condition = result.must[0]
        assert isinstance(condition, models.FieldCondition)
        assert condition.key == "status"
        assert isinstance(condition.match, models.MatchExcept)
        # MatchExcept uses 'except' parameter which conflicts with Python keyword,
        # so the values are read back through the `except_` attribute. Check the
        # values themselves (not just the attribute's existence), mirroring the
        # MatchAny test above.
        assert condition.match.except_ == ["deleted", "archived"]

    def test_transform_conditions_multiple(self, qdrant_engine):
        """Test transform_conditions with multiple conditions."""
        conditions = [
            Match(key="type", value="article"),
            MatchAny(key="language", any=["en", "es"]),
            MatchExclude(key="status", exclude=["draft"]),
        ]

        result = qdrant_engine.transform_conditions(conditions)

        assert isinstance(result, models.Filter)
        assert len(result.must) == 3

        # Verify all conditions are FieldCondition instances
        assert all(isinstance(cond, models.FieldCondition) for cond in result.must)

    def test_transform_response_empty(self, qdrant_engine):
        """Test transform_response with empty results."""
        response = []
        result = qdrant_engine.transform_response(response)
        assert result == []

    @pytest.mark.asyncio
    async def test_create_index(self, qdrant_engine, mock_client):
        """Test create_index method."""
        mock_client.create_collection.return_value = True

        result = await qdrant_engine.create_index("test_collection", 384)

        assert result is True
        mock_client.create_collection.assert_called_once_with(
            collection_name="test_collection",
            vectors_config=models.VectorParams(
                size=384, distance=models.Distance.COSINE
            ),
        )

    @pytest.mark.asyncio
    async def test_create_index_failure(self, qdrant_engine, mock_client):
        """Test create_index method when it fails."""
        mock_client.create_collection.side_effect = Exception(
            "Collection creation failed"
        )

        with pytest.raises(Exception, match="Collection creation failed"):
            await qdrant_engine.create_index("failing_collection", 512)

    def test_transform_chunk(self, qdrant_engine):
        """Test transform_chunk method."""
        chunk = Chunk(
            id="test-chunk-1",
            vector=[0.1, 0.2, 0.3, 0.4, 0.5],
            payload=ChunkData(
                page_content="This is test content",
                filename="test_doc.pdf",
                page=42,
            ),
        )

        result = qdrant_engine.transform_chunk(chunk)

        assert isinstance(result, models.PointStruct)
        assert result.id == "test-chunk-1"
        assert result.vector == [0.1, 0.2, 0.3, 0.4, 0.5]
        assert result.payload == {
            "page_content": "This is test content",
            "filename": "test_doc.pdf",
            "page": 42,
        }

    @pytest.mark.asyncio
    async def test_run_upload_chunk(self, qdrant_engine, mock_client):
        """Test run_upload_chunk method."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.status = models.UpdateStatus.ACKNOWLEDGED
        mock_client.upsert.return_value = mock_response

        # Create test point
        test_point = models.PointStruct(
            id="test-point-1",
            vector=[0.1, 0.2, 0.3],
            payload={"content": "test"},
        )

        result = await qdrant_engine.run_upload_chunk("test_index", test_point)

        assert result is True
        mock_client.upsert.assert_called_once_with(
            collection_name="test_index",
            points=[test_point],
        )

    @pytest.mark.asyncio
    async def test_run_upload_chunk_failure(self, qdrant_engine, mock_client):
        """Test run_upload_chunk method when upload fails."""
        # Setup mock response with failure status: this test expects any status
        # other than ACKNOWLEDGED to be reported as False.
        mock_response = MagicMock()
        mock_response.status = models.UpdateStatus.COMPLETED  # Not ACKNOWLEDGED
        mock_client.upsert.return_value = mock_response

        test_point = models.PointStruct(
            id="test-point-1",
            vector=[0.1, 0.2, 0.3],
            payload={"content": "test"},
        )

        result = await qdrant_engine.run_upload_chunk("test_index", test_point)

        assert result is False

    @pytest.mark.asyncio
    async def test_upload_chunk_integration(self, qdrant_engine, mock_client):
        """Test the complete upload_chunk workflow."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.status = models.UpdateStatus.ACKNOWLEDGED
        mock_client.upsert.return_value = mock_response

        chunk = Chunk(
            id="integration-test-chunk",
            vector=[0.5, 0.4, 0.3, 0.2, 0.1],
            payload=ChunkData(
                page_content="Integration test content",
                filename="integration_test.pdf",
                page=1,
            ),
        )

        result = await qdrant_engine.upload_chunk("integration_collection", chunk)

        assert result is True
        # Verify the complete workflow: transform_chunk -> run_upload_chunk
        mock_client.upsert.assert_called_once()
        # Only the keyword arguments are inspected; positional args are unused.
        kwargs = mock_client.upsert.call_args.kwargs

        assert kwargs["collection_name"] == "integration_collection"
        assert len(kwargs["points"]) == 1

        uploaded_point = kwargs["points"][0]
        assert uploaded_point.id == "integration-test-chunk"
        assert uploaded_point.vector == [0.5, 0.4, 0.3, 0.2, 0.1]
        assert uploaded_point.payload == {
            "page_content": "Integration test content",
            "filename": "integration_test.pdf",
            "page": 1,
        }

    def test_transform_response_with_scored_points(self, qdrant_engine):
        """Test transform_response with valid ScoredPoint objects."""
        response = [
            ScoredPoint(
                id=1,
                score=0.95,
                payload={"text": "First document", "category": "tech"},
                version=1,
            ),
            ScoredPoint(
                id=2,
                score=0.87,
                payload={"text": "Second document", "category": "science"},
                version=1,
            ),
        ]

        result = qdrant_engine.transform_response(response)

        assert isinstance(result, list)
        assert len(result) == 2

        # Check first result (integer point ids are expected back as strings)
        assert isinstance(result[0], SearchRow)
        assert result[0].chunk_id == "1"
        assert result[0].score == 0.95
        assert result[0].payload == {"text": "First document", "category": "tech"}

        # Check second result
        assert isinstance(result[1], SearchRow)
        assert result[1].chunk_id == "2"
        assert result[1].score == 0.87
        assert result[1].payload == {"text": "Second document", "category": "science"}

    def test_transform_response_filters_none_payload(self, qdrant_engine):
        """Test transform_response filters out points with None payload."""
        response = [
            ScoredPoint(
                id=1, score=0.95, payload={"text": "Valid document"}, version=1
            ),
            ScoredPoint(
                id=2,
                score=0.87,
                payload=None,  # This should be filtered out
                version=1,
            ),
            ScoredPoint(
                id=3, score=0.75, payload={"text": "Another valid document"}, version=1
            ),
        ]

        result = qdrant_engine.transform_response(response)

        assert isinstance(result, list)
        assert len(result) == 2  # Only 2 valid results
        assert result[0].chunk_id == "1"
        assert result[1].chunk_id == "3"

    @pytest.mark.asyncio
    async def test_run_similarity_query_basic(self, qdrant_engine, mock_client):
        """Test run_similarity_query with basic parameters."""
        # Setup mock response
        mock_response = [
            ScoredPoint(id=1, score=0.9, payload={"text": "Test document"}, version=1)
        ]
        mock_client.search.return_value = mock_response

        embedding = [0.1, 0.2, 0.3, 0.4, 0.5]
        collection = "test_collection"

        result = await qdrant_engine.run_similarity_query(
            embedding=embedding, collection=collection
        )

        # Verify client.search was called with correct parameters
        mock_client.search.assert_called_once_with(
            collection_name=collection,
            query_vector=embedding,
            query_filter=None,
            limit=10,  # default
            with_payload=True,
            with_vectors=False,
            score_threshold=None,
        )

        assert result == mock_response

    @pytest.mark.asyncio
    async def test_run_similarity_query_with_all_parameters(
        self, qdrant_engine, mock_client
    ):
        """Test run_similarity_query with all parameters."""
        mock_response = []
        mock_client.search.return_value = mock_response

        embedding = [0.1, 0.2, 0.3]
        collection = "test_collection"
        limit = 5
        conditions = models.Filter(must=[])
        threshold = 0.75

        result = await qdrant_engine.run_similarity_query(
            embedding=embedding,
            collection=collection,
            limit=limit,
            conditions=conditions,
            threshold=threshold,
        )

        mock_client.search.assert_called_once_with(
            collection_name=collection,
            query_vector=embedding,
            query_filter=conditions,
            limit=limit,
            with_payload=True,
            with_vectors=False,
            score_threshold=threshold,
        )

        assert result == mock_response

    @pytest.mark.asyncio
    async def test_run_similarity_query_with_named_vector(
        self, qdrant_engine, mock_client
    ):
        """Test run_similarity_query with NamedVector."""
        mock_response = []
        mock_client.search.return_value = mock_response

        named_vector = models.NamedVector(name="text", vector=[0.1, 0.2, 0.3])
        collection = "test_collection"

        result = await qdrant_engine.run_similarity_query(
            embedding=named_vector, collection=collection
        )

        mock_client.search.assert_called_once_with(
            collection_name=collection,
            query_vector=named_vector,
            query_filter=None,
            limit=10,
            with_payload=True,
            with_vectors=False,
            score_threshold=None,
        )

        assert result == mock_response

    @pytest.mark.asyncio
    async def test_semantic_search_integration(self, qdrant_engine, mock_client):
        """Test the full semantic_search flow through QdrantEngine."""
        # Setup mock response
        mock_search_response = [
            ScoredPoint(
                id=1,
                score=0.95,
                payload={"text": "Python programming guide", "category": "tech"},
                version=1,
            ),
            ScoredPoint(
                id=2,
                score=0.87,
                payload={"text": "Rust systems programming", "category": "tech"},
                version=1,
            ),
        ]
        mock_client.search.return_value = mock_search_response

        # Test data
        vector = [0.1, 0.2, 0.3, 0.4, 0.5]
        collection = "documents"
        conditions = [
            Match(key="category", value="tech"),
            MatchAny(key="language", any=["python", "rust"]),
        ]

        result = await qdrant_engine.semantic_search(
            embedding=vector,
            collection=collection,
            limit=5,
            conditions=conditions,
            threshold=0.8,
        )

        # Verify the search was called (exactly once) with transformed conditions
        mock_client.search.assert_called_once()
        search_kwargs = mock_client.search.call_args.kwargs
        assert search_kwargs["collection_name"] == collection
        assert search_kwargs["query_vector"] == vector
        assert search_kwargs["limit"] == 5
        assert search_kwargs["score_threshold"] == 0.8
        assert isinstance(search_kwargs["query_filter"], models.Filter)

        # Verify the response was transformed correctly
        assert isinstance(result, list)
        assert len(result) == 2
        assert all(isinstance(row, SearchRow) for row in result)
        assert result[0].chunk_id == "1"
        assert result[0].score == 0.95
        assert result[1].chunk_id == "2"
        assert result[1].score == 0.87

    def test_initialization_with_settings(self, mock_settings):
        """Test QdrantEngine initialization uses settings correctly."""
        with (
            patch(
                "vector_search_mcp.engine.qdrant_engine.Settings"
            ) as mock_settings_class,
            patch(
                "vector_search_mcp.engine.qdrant_engine.AsyncQdrantClient"
            ) as mock_client_class,
        ):
            mock_settings_class.return_value = mock_settings
            mock_client = AsyncMock()
            mock_client_class.return_value = mock_client

            engine = QdrantEngine()

            # Verify Settings was instantiated
            mock_settings_class.assert_called_once()

            # Verify AsyncQdrantClient was created with correct parameters
            mock_client_class.assert_called_once_with(
                url=mock_settings.url, api_key=mock_settings.api_key
            )

            assert engine.client == mock_client
            assert engine.settings == mock_settings

    @pytest.mark.asyncio
    async def test_client_search_exception_propagation(
        self, qdrant_engine, mock_client
    ):
        """Test that exceptions from client.search are properly propagated."""
        # Setup mock to raise an exception
        mock_client.search.side_effect = Exception("Qdrant connection failed")

        with pytest.raises(Exception, match="Qdrant connection failed"):
            await qdrant_engine.run_similarity_query(
                embedding=[0.1, 0.2, 0.3], collection="test_collection"
            )