Rewrite tests while keeping 97% cov

2025-09-27 16:26:03 +00:00
parent cf7e3d8244
commit 51606fc959
23 changed files with 180 additions and 2968 deletions
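
The 97% in the commit title refers to the test suite's line coverage. As an illustration only (the CI configuration is not part of this diff; the tests/ path and exact flags are assumptions), a pytest-cov gate enforcing such a floor could look like this:

    # Hypothetical coverage gate, not taken from this repository's config.
    # Requires pytest and pytest-cov to be installed.
    import sys

    import pytest

    if __name__ == "__main__":
        sys.exit(
            pytest.main(
                [
                    "tests/",                 # assumed test directory
                    "--cov=searchbox",        # package exercised by these tests
                    "--cov-report=term-missing",
                    "--cov-fail-under=97",    # fail the run if coverage drops below 97%
                ]
            )
        )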


@@ -1,505 +1,62 @@
-from unittest.mock import AsyncMock, MagicMock, patch
-
-import pytest
-from qdrant_client import models
-from qdrant_client.models import ScoredPoint
-
-from searchbox.engine.base_engine import BaseEngine
-from searchbox.engine.qdrant_engine import QdrantEngine
-from searchbox.models import Match, MatchAny, MatchExclude, SearchRow
-
-
-class TestQdrantEngine:
-    """Test suite for QdrantEngine"""
-
-    @pytest.fixture
-    def mock_client(self):
-        """Create a mock Qdrant client"""
-        return AsyncMock()
-
-    @pytest.fixture
-    def mock_settings(self):
-        """Create mock settings"""
-        settings = MagicMock()
-        settings.url = "http://localhost:6333"
-        settings.api_key = "test_api_key"
-        return settings
-
-    @pytest.fixture
-    def qdrant_engine(self, mock_client, mock_settings):
-        """Create a QdrantEngine instance with mocked dependencies"""
-        with (
-            patch(
-                "searchbox.engine.qdrant_engine.Settings"
-            ) as mock_settings_class,
-            patch(
-                "searchbox.engine.qdrant_engine.AsyncQdrantClient"
-            ) as mock_client_class,
-        ):
-            mock_settings_class.return_value = mock_settings
-            mock_client_class.return_value = mock_client
-
-            engine = QdrantEngine()
-            engine.client = mock_client  # Ensure we use our mock
-            return engine
-
-    def test_inheritance(self, qdrant_engine):
-        """Test that QdrantEngine properly inherits from BaseEngine"""
-        assert isinstance(qdrant_engine, BaseEngine)
-        assert isinstance(qdrant_engine, QdrantEngine)
-
-    def test_typing_parameters(self):
-        """Test that QdrantEngine has correct generic type parameters"""
-        # QdrantEngine should be BaseEngine[list[models.ScoredPoint], models.Filter]
-        # This is verified by the type checker, but we can test the methods exist
-        assert hasattr(QdrantEngine, "transform_conditions")
-        assert hasattr(QdrantEngine, "transform_response")
-        assert hasattr(QdrantEngine, "run_similarity_query")
-
-    def test_transform_conditions_none(self, qdrant_engine):
-        """Test transform_conditions with None input"""
-        result = qdrant_engine.transform_conditions(None)
-        assert result is None
-
-    def test_transform_conditions_empty_list(self, qdrant_engine):
-        """Test transform_conditions with empty list"""
-        result = qdrant_engine.transform_conditions([])
-        assert result is None
-
-    def test_transform_conditions_with_match(self, qdrant_engine):
-        """Test transform_conditions with Match condition"""
-        conditions = [Match(key="category", value="document")]
-        result = qdrant_engine.transform_conditions(conditions)
-        assert isinstance(result, models.Filter)
-        assert len(result.must) == 1
-        condition = result.must[0]
-        assert isinstance(condition, models.FieldCondition)
-        assert condition.key == "category"
-        assert isinstance(condition.match, models.MatchValue)
-        assert condition.match.value == "document"
-
-    def test_transform_conditions_with_match_any(self, qdrant_engine):
-        """Test transform_conditions with MatchAny condition"""
-        conditions = [MatchAny(key="tags", any=["python", "rust", "javascript"])]
-        result = qdrant_engine.transform_conditions(conditions)
-        assert isinstance(result, models.Filter)
-        assert len(result.must) == 1
-        condition = result.must[0]
-        assert isinstance(condition, models.FieldCondition)
-        assert condition.key == "tags"
-        assert isinstance(condition.match, models.MatchAny)
-        assert condition.match.any == ["python", "rust", "javascript"]
-
-    def test_transform_conditions_with_match_exclude(self, qdrant_engine):
-        """Test transform_conditions with MatchExclude condition"""
-        conditions = [MatchExclude(key="status", exclude=["deleted", "archived"])]
-        result = qdrant_engine.transform_conditions(conditions)
-        assert isinstance(result, models.Filter)
-        assert len(result.must) == 1
-        condition = result.must[0]
-        assert isinstance(condition, models.FieldCondition)
-        assert condition.key == "status"
-        assert isinstance(condition.match, models.MatchExcept)
-        # MatchExcept uses 'except' parameter which conflicts with Python keyword
-        assert hasattr(condition.match, "except_")
-
-    def test_transform_conditions_multiple(self, qdrant_engine):
-        """Test transform_conditions with multiple conditions"""
-        conditions = [
-            Match(key="type", value="article"),
-            MatchAny(key="language", any=["en", "es"]),
-            MatchExclude(key="status", exclude=["draft"]),
-        ]
-        result = qdrant_engine.transform_conditions(conditions)
-        assert isinstance(result, models.Filter)
-        assert len(result.must) == 3
-        # Verify all conditions are FieldCondition instances
-        assert all(isinstance(cond, models.FieldCondition) for cond in result.must)
-
-    def test_transform_response_empty(self, qdrant_engine):
-        """Test transform_response with empty results"""
-        response = []
-        result = qdrant_engine.transform_response(response)
-        assert result == []
-
-    @pytest.mark.asyncio
-    async def test_create_index(self, qdrant_engine, mock_client):
-        """Test create_index method"""
-        mock_client.create_collection.return_value = True
-        result = await qdrant_engine.create_index("test_collection", 384)
-        assert result is True
-        mock_client.create_collection.assert_called_once_with(
-            collection_name="test_collection",
-            vectors_config=models.VectorParams(
-                size=384, distance=models.Distance.COSINE
-            ),
-        )
-
-    @pytest.mark.asyncio
-    async def test_create_index_failure(self, qdrant_engine, mock_client):
-        """Test create_index method when it fails"""
-        mock_client.create_collection.side_effect = Exception("Collection creation failed")
-        with pytest.raises(Exception, match="Collection creation failed"):
-            await qdrant_engine.create_index("failing_collection", 512)
-
-    def test_transform_chunk(self, qdrant_engine):
-        """Test transform_chunk method"""
-        from searchbox.models import Chunk, ChunkData
-
-        chunk = Chunk(
-            id="test-chunk-1",
-            vector=[0.1, 0.2, 0.3, 0.4, 0.5],
-            payload=ChunkData(
-                page_content="This is test content",
-                filename="test_doc.pdf",
-                page=42
-            )
-        )
-        result = qdrant_engine.transform_chunk(chunk)
-        assert isinstance(result, models.PointStruct)
-        assert result.id == "test-chunk-1"
-        assert result.vector == [0.1, 0.2, 0.3, 0.4, 0.5]
-        assert result.payload == {
-            "page_content": "This is test content",
-            "filename": "test_doc.pdf",
-            "page": 42
-        }
-
-    @pytest.mark.asyncio
-    async def test_run_upload_chunk(self, qdrant_engine, mock_client):
-        """Test run_upload_chunk method"""
-        # Setup mock response
-        mock_response = MagicMock()
-        mock_response.status = models.UpdateStatus.ACKNOWLEDGED
-        mock_client.upsert.return_value = mock_response
-
-        # Create test point
-        test_point = models.PointStruct(
-            id="test-point-1",
-            vector=[0.1, 0.2, 0.3],
-            payload={"content": "test"}
-        )
-
-        result = await qdrant_engine.run_upload_chunk("test_index", test_point)
-        assert result is True
-        mock_client.upsert.assert_called_once_with(
-            collection_name="test_index",
-            points=[test_point]
-        )
-
-    @pytest.mark.asyncio
-    async def test_run_upload_chunk_failure(self, qdrant_engine, mock_client):
-        """Test run_upload_chunk method when upload fails"""
-        # Setup mock response with failure status
-        mock_response = MagicMock()
-        mock_response.status = models.UpdateStatus.COMPLETED  # Not ACKNOWLEDGED
-        mock_client.upsert.return_value = mock_response
-
-        test_point = models.PointStruct(
-            id="test-point-1",
-            vector=[0.1, 0.2, 0.3],
-            payload={"content": "test"}
-        )
-
-        result = await qdrant_engine.run_upload_chunk("test_index", test_point)
-        assert result is False
-
-    @pytest.mark.asyncio
-    async def test_upload_chunk_integration(self, qdrant_engine, mock_client):
-        """Test the complete upload_chunk workflow"""
-        from searchbox.models import Chunk, ChunkData
-
-        # Setup mock response
-        mock_response = MagicMock()
-        mock_response.status = models.UpdateStatus.ACKNOWLEDGED
-        mock_client.upsert.return_value = mock_response
-
-        chunk = Chunk(
-            id="integration-test-chunk",
-            vector=[0.5, 0.4, 0.3, 0.2, 0.1],
-            payload=ChunkData(
-                page_content="Integration test content",
-                filename="integration_test.pdf",
-                page=1
-            )
-        )
-
-        result = await qdrant_engine.upload_chunk("integration_collection", chunk)
-        assert result is True
-
-        # Verify the complete workflow: transform_chunk -> run_upload_chunk
-        mock_client.upsert.assert_called_once()
-        args, kwargs = mock_client.upsert.call_args
-        assert kwargs["collection_name"] == "integration_collection"
-        assert len(kwargs["points"]) == 1
-        uploaded_point = kwargs["points"][0]
-        assert uploaded_point.id == "integration-test-chunk"
-        assert uploaded_point.vector == [0.5, 0.4, 0.3, 0.2, 0.1]
-        assert uploaded_point.payload == {
-            "page_content": "Integration test content",
-            "filename": "integration_test.pdf",
-            "page": 1
-        }
-
-    def test_transform_response_with_scored_points(self, qdrant_engine):
-        """Test transform_response with valid ScoredPoint objects"""
-        response = [
-            ScoredPoint(
-                id=1,
-                score=0.95,
-                payload={"text": "First document", "category": "tech"},
-                version=1,
-            ),
-            ScoredPoint(
-                id=2,
-                score=0.87,
-                payload={"text": "Second document", "category": "science"},
-                version=1,
-            ),
-        ]
-        result = qdrant_engine.transform_response(response)
-        assert isinstance(result, list)
-        assert len(result) == 2
-        # Check first result
-        assert isinstance(result[0], SearchRow)
-        assert result[0].chunk_id == "1"
-        assert result[0].score == 0.95
-        assert result[0].payload == {"text": "First document", "category": "tech"}
-        # Check second result
-        assert isinstance(result[1], SearchRow)
-        assert result[1].chunk_id == "2"
-        assert result[1].score == 0.87
-        assert result[1].payload == {"text": "Second document", "category": "science"}
-
-    def test_transform_response_filters_none_payload(self, qdrant_engine):
-        """Test transform_response filters out points with None payload"""
-        response = [
-            ScoredPoint(
-                id=1, score=0.95, payload={"text": "Valid document"}, version=1
-            ),
-            ScoredPoint(
-                id=2,
-                score=0.87,
-                payload=None,  # This should be filtered out
-                version=1,
-            ),
-            ScoredPoint(
-                id=3, score=0.75, payload={"text": "Another valid document"}, version=1
-            ),
-        ]
-        result = qdrant_engine.transform_response(response)
-        assert isinstance(result, list)
-        assert len(result) == 2  # Only 2 valid results
-        assert result[0].chunk_id == "1"
-        assert result[1].chunk_id == "3"
-
-    @pytest.mark.asyncio
-    async def test_run_similarity_query_basic(self, qdrant_engine, mock_client):
-        """Test run_similarity_query with basic parameters"""
-        # Setup mock response
-        mock_response = [
-            ScoredPoint(id=1, score=0.9, payload={"text": "Test document"}, version=1)
-        ]
-        mock_client.search.return_value = mock_response
-
-        embedding = [0.1, 0.2, 0.3, 0.4, 0.5]
-        collection = "test_collection"
-
-        result = await qdrant_engine.run_similarity_query(
-            embedding=embedding, collection=collection
-        )
-
-        # Verify client.search was called with correct parameters
-        mock_client.search.assert_called_once_with(
-            collection_name=collection,
-            query_vector=embedding,
-            query_filter=None,
-            limit=10,  # default
-            with_payload=True,
-            with_vectors=False,
-            score_threshold=None,
-        )
-        assert result == mock_response
-
-    @pytest.mark.asyncio
-    async def test_run_similarity_query_with_all_parameters(
-        self, qdrant_engine, mock_client
-    ):
-        """Test run_similarity_query with all parameters"""
-        mock_response = []
-        mock_client.search.return_value = mock_response
-
-        embedding = [0.1, 0.2, 0.3]
-        collection = "test_collection"
-        limit = 5
-        conditions = models.Filter(must=[])
-        threshold = 0.75
-
-        result = await qdrant_engine.run_similarity_query(
-            embedding=embedding,
-            collection=collection,
-            limit=limit,
-            conditions=conditions,
-            threshold=threshold,
-        )
-
-        mock_client.search.assert_called_once_with(
-            collection_name=collection,
-            query_vector=embedding,
-            query_filter=conditions,
-            limit=limit,
-            with_payload=True,
-            with_vectors=False,
-            score_threshold=threshold,
-        )
-        assert result == mock_response
-
-    @pytest.mark.asyncio
-    async def test_run_similarity_query_with_named_vector(
-        self, qdrant_engine, mock_client
-    ):
-        """Test run_similarity_query with NamedVector"""
-        mock_response = []
-        mock_client.search.return_value = mock_response
-
-        named_vector = models.NamedVector(name="text", vector=[0.1, 0.2, 0.3])
-        collection = "test_collection"
-
-        result = await qdrant_engine.run_similarity_query(
-            embedding=named_vector, collection=collection
-        )
-
-        mock_client.search.assert_called_once_with(
-            collection_name=collection,
-            query_vector=named_vector,
-            query_filter=None,
-            limit=10,
-            with_payload=True,
-            with_vectors=False,
-            score_threshold=None,
-        )
-        assert result == mock_response
-
-    @pytest.mark.asyncio
-    async def test_semantic_search_integration(self, qdrant_engine, mock_client):
-        """Test the full semantic_search flow through QdrantEngine"""
-        # Setup mock response
-        mock_search_response = [
-            ScoredPoint(
-                id=1,
-                score=0.95,
-                payload={"text": "Python programming guide", "category": "tech"},
-                version=1,
-            ),
-            ScoredPoint(
-                id=2,
-                score=0.87,
-                payload={"text": "Rust systems programming", "category": "tech"},
-                version=1,
-            ),
-        ]
-        mock_client.search.return_value = mock_search_response
-
-        # Test data
-        vector = [0.1, 0.2, 0.3, 0.4, 0.5]
-        collection = "documents"
-        conditions = [
-            Match(key="category", value="tech"),
-            MatchAny(key="language", any=["python", "rust"]),
-        ]
-
-        result = await qdrant_engine.semantic_search(
-            embedding=vector,
-            collection=collection,
-            limit=5,
-            conditions=conditions,
-            threshold=0.8,
-        )
-
-        # Verify the search was called with transformed conditions
-        assert mock_client.search.called
-        call_args = mock_client.search.call_args
-        assert call_args[1]["collection_name"] == collection
-        assert call_args[1]["query_vector"] == vector
-        assert call_args[1]["limit"] == 5
-        assert call_args[1]["score_threshold"] == 0.8
-        assert isinstance(call_args[1]["query_filter"], models.Filter)
-
-        # Verify the response was transformed correctly
-        assert isinstance(result, list)
-        assert len(result) == 2
-        assert all(isinstance(row, SearchRow) for row in result)
-        assert result[0].chunk_id == "1"
-        assert result[0].score == 0.95
-        assert result[1].chunk_id == "2"
-        assert result[1].score == 0.87
-
-    def test_initialization_with_settings(self, mock_settings):
-        """Test QdrantEngine initialization uses settings correctly"""
-        with (
-            patch(
-                "searchbox.engine.qdrant_engine.Settings"
-            ) as mock_settings_class,
-            patch(
-                "searchbox.engine.qdrant_engine.AsyncQdrantClient"
-            ) as mock_client_class,
-        ):
-            mock_settings_class.return_value = mock_settings
-            mock_client = AsyncMock()
-            mock_client_class.return_value = mock_client
-
-            engine = QdrantEngine()
-
-            # Verify Settings was instantiated
-            mock_settings_class.assert_called_once()
-            # Verify AsyncQdrantClient was created with correct parameters
-            mock_client_class.assert_called_once_with(
-                url=mock_settings.url, api_key=mock_settings.api_key
-            )
-            assert engine.client == mock_client
-            assert engine.settings == mock_settings
-
-    @pytest.mark.asyncio
-    async def test_client_search_exception_propagation(
-        self, qdrant_engine, mock_client
-    ):
-        """Test that exceptions from client.search are properly propagated"""
-        # Setup mock to raise an exception
-        mock_client.search.side_effect = Exception("Qdrant connection failed")
-
-        with pytest.raises(Exception, match="Qdrant connection failed"):
-            await qdrant_engine.run_similarity_query(
-                embedding=[0.1, 0.2, 0.3], collection="test_collection"
-            )
+from uuid import uuid4
+
+import pytest
+
+from searchbox.engine.qdrant_engine import QdrantEngine
+from searchbox.models import Chunk, Match, MatchAny, MatchExclude
+
+
+@pytest.fixture(scope="module")
+def engine():
+    return QdrantEngine(":memory:")
+
+
+async def test_create_index(engine: QdrantEngine):
+    result = await engine.create_index("test_index", 3)
+
+    assert result is True
+
+
+async def test_upload_chunk(engine: QdrantEngine):
+    chunk = Chunk.model_validate(
+        {
+            "id": uuid4(),
+            "vector": [0.0, 0.1, 0.3],
+            "payload": {
+                "page_content": "This is a test page content.",
+                "filename": "test.txt",
+                "page": 1,
+            },
+        }
+    )
+
+    result = await engine.upload_chunk("test_index", chunk)
+
+    assert result is True
+
+
+async def test_search_chunk(engine: QdrantEngine):
+    result = await engine.semantic_search([0.0, 0.1, 0.3], "test_index")
+
+    assert len(result) == 1
+
+    first_result = result[0]
+    assert first_result.chunk_id is not None
+    assert first_result.score > 0.9
+    assert first_result.payload == {
+        "page_content": "This is a test page content.",
+        "filename": "test.txt",
+        "page": 1,
+    }
+
+
+async def test_search_chunk_with_conditions(engine: QdrantEngine):
+    conditions = [
+        Match(key="filename", value="test.md"),
+        MatchAny(key="filename", any=["test.md", "test.docx"]),
+        MatchExclude(key="filename", exclude=["test.txt"]),
+    ]
+
+    result = await engine.semantic_search(
+        [0.0, 0.1, 0.3], "test_index", conditions=conditions
+    )
+
+    assert len(result) == 0
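
What makes the mock-free rewrite possible is that QdrantEngine wraps AsyncQdrantClient (visible in the removed fixture that patched searchbox.engine.qdrant_engine.AsyncQdrantClient), and qdrant-client can run entirely in process when given ":memory:" as its location, so the new tests exercise a real engine instead of patched collaborators. Below is a minimal sketch of that local mode at the client level; it is illustrative only, the QdrantEngine internals are not part of this diff, and the collection name and payload fields are reused from the tests above as placeholders:

    # Sketch of qdrant-client's in-process ":memory:" mode, which the new
    # module-scoped fixture presumably relies on; not code from this repository.
    import asyncio

    from qdrant_client import AsyncQdrantClient, models


    async def main() -> None:
        client = AsyncQdrantClient(":memory:")  # no running Qdrant server required

        await client.create_collection(
            collection_name="test_index",
            vectors_config=models.VectorParams(size=3, distance=models.Distance.COSINE),
        )

        await client.upsert(
            collection_name="test_index",
            points=[
                models.PointStruct(
                    id=1,
                    vector=[0.0, 0.1, 0.3],
                    payload={"filename": "test.txt", "page": 1},
                )
            ],
        )

        hits = await client.search(
            collection_name="test_index",
            query_vector=[0.0, 0.1, 0.3],
            limit=1,
        )
        # Cosine similarity of a vector with itself is 1.0, which is why
        # test_search_chunk above can assert score > 0.9.
        print(hits[0].id, hits[0].score)


    asyncio.run(main())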