"""Unit tests for QdrantEngine, run against a mocked async Qdrant client."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from qdrant_client import models
from qdrant_client.models import ScoredPoint

from searchbox.engine.base_engine import BaseEngine
from searchbox.engine.qdrant_engine import QdrantEngine
from searchbox.models import (
    Chunk,
    ChunkData,
    Match,
    MatchAny,
    MatchExclude,
    SearchRow,
)


class TestQdrantEngine:
    """Test suite for QdrantEngine.

    Every test talks to an ``AsyncMock`` standing in for the Qdrant client,
    so no network access or running Qdrant instance is required.
    """

    @pytest.fixture
    def mock_client(self):
        """Create a mock Qdrant client."""
        return AsyncMock()

    @pytest.fixture
    def mock_settings(self):
        """Create mock settings exposing ``url`` and ``api_key``."""
        settings = MagicMock()
        settings.url = "http://localhost:6333"
        settings.api_key = "test_api_key"
        return settings

    @pytest.fixture
    def qdrant_engine(self, mock_client, mock_settings):
        """Create a QdrantEngine instance with mocked dependencies."""
        with (
            patch(
                "searchbox.engine.qdrant_engine.Settings"
            ) as mock_settings_class,
            patch(
                "searchbox.engine.qdrant_engine.AsyncQdrantClient"
            ) as mock_client_class,
        ):
            mock_settings_class.return_value = mock_settings
            mock_client_class.return_value = mock_client
            engine = QdrantEngine()
            engine.client = mock_client  # Ensure we use our mock
            return engine

    def test_inheritance(self, qdrant_engine):
        """Test that QdrantEngine properly inherits from BaseEngine."""
        assert isinstance(qdrant_engine, BaseEngine)
        assert isinstance(qdrant_engine, QdrantEngine)

    def test_typing_parameters(self):
        """Test that QdrantEngine exposes the expected engine methods."""
        # QdrantEngine should be BaseEngine[list[models.ScoredPoint], models.Filter]
        # This is verified by the type checker, but we can test the methods exist
        assert hasattr(QdrantEngine, "transform_conditions")
        assert hasattr(QdrantEngine, "transform_response")
        assert hasattr(QdrantEngine, "run_similarity_query")

    def test_transform_conditions_none(self, qdrant_engine):
        """Test transform_conditions with None input."""
        result = qdrant_engine.transform_conditions(None)

        assert result is None

    def test_transform_conditions_empty_list(self, qdrant_engine):
        """Test transform_conditions with empty list."""
        result = qdrant_engine.transform_conditions([])

        assert result is None

    def test_transform_conditions_with_match(self, qdrant_engine):
        """Test transform_conditions with Match condition."""
        conditions = [Match(key="category", value="document")]

        result = qdrant_engine.transform_conditions(conditions)

        assert isinstance(result, models.Filter)
        assert len(result.must) == 1
        condition = result.must[0]
        assert isinstance(condition, models.FieldCondition)
        assert condition.key == "category"
        assert isinstance(condition.match, models.MatchValue)
        assert condition.match.value == "document"

    def test_transform_conditions_with_match_any(self, qdrant_engine):
        """Test transform_conditions with MatchAny condition."""
        conditions = [MatchAny(key="tags", any=["python", "rust", "javascript"])]

        result = qdrant_engine.transform_conditions(conditions)

        assert isinstance(result, models.Filter)
        assert len(result.must) == 1
        condition = result.must[0]
        assert isinstance(condition, models.FieldCondition)
        assert condition.key == "tags"
        assert isinstance(condition.match, models.MatchAny)
        assert condition.match.any == ["python", "rust", "javascript"]

    def test_transform_conditions_with_match_exclude(self, qdrant_engine):
        """Test transform_conditions with MatchExclude condition."""
        conditions = [MatchExclude(key="status", exclude=["deleted", "archived"])]

        result = qdrant_engine.transform_conditions(conditions)

        assert isinstance(result, models.Filter)
        assert len(result.must) == 1
        condition = result.must[0]
        assert isinstance(condition, models.FieldCondition)
        assert condition.key == "status"
        assert isinstance(condition.match, models.MatchExcept)
        # MatchExcept stores its values on ``except_`` because ``except`` is a
        # Python keyword. Compare the forwarded values: a bare ``hasattr``
        # check is true for every MatchExcept instance and verifies nothing.
        assert condition.match.except_ == ["deleted", "archived"]

    def test_transform_conditions_multiple(self, qdrant_engine):
        """Test transform_conditions with multiple conditions."""
        conditions = [
            Match(key="type", value="article"),
            MatchAny(key="language", any=["en", "es"]),
            MatchExclude(key="status", exclude=["draft"]),
        ]

        result = qdrant_engine.transform_conditions(conditions)

        assert isinstance(result, models.Filter)
        assert len(result.must) == 3
        # Verify all conditions are FieldCondition instances
        assert all(isinstance(cond, models.FieldCondition) for cond in result.must)

    def test_transform_response_empty(self, qdrant_engine):
        """Test transform_response with empty results."""
        response = []

        result = qdrant_engine.transform_response(response)

        assert result == []

    @pytest.mark.asyncio
    async def test_create_index(self, qdrant_engine, mock_client):
        """Test create_index forwards name, size and cosine distance."""
        mock_client.create_collection.return_value = True

        result = await qdrant_engine.create_index("test_collection", 384)

        assert result is True
        mock_client.create_collection.assert_called_once_with(
            collection_name="test_collection",
            vectors_config=models.VectorParams(
                size=384, distance=models.Distance.COSINE
            ),
        )

    @pytest.mark.asyncio
    async def test_create_index_failure(self, qdrant_engine, mock_client):
        """Test create_index propagates client errors."""
        mock_client.create_collection.side_effect = Exception(
            "Collection creation failed"
        )

        with pytest.raises(Exception, match="Collection creation failed"):
            await qdrant_engine.create_index("failing_collection", 512)

    def test_transform_chunk(self, qdrant_engine):
        """Test transform_chunk converts a Chunk into a PointStruct."""
        chunk = Chunk(
            id="test-chunk-1",
            vector=[0.1, 0.2, 0.3, 0.4, 0.5],
            payload=ChunkData(
                page_content="This is test content",
                filename="test_doc.pdf",
                page=42,
            ),
        )

        result = qdrant_engine.transform_chunk(chunk)

        assert isinstance(result, models.PointStruct)
        assert result.id == "test-chunk-1"
        assert result.vector == [0.1, 0.2, 0.3, 0.4, 0.5]
        assert result.payload == {
            "page_content": "This is test content",
            "filename": "test_doc.pdf",
            "page": 42,
        }

    @pytest.mark.asyncio
    async def test_run_upload_chunk(self, qdrant_engine, mock_client):
        """Test run_upload_chunk returns True on an ACKNOWLEDGED upsert."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.status = models.UpdateStatus.ACKNOWLEDGED
        mock_client.upsert.return_value = mock_response

        # Create test point
        test_point = models.PointStruct(
            id="test-point-1", vector=[0.1, 0.2, 0.3], payload={"content": "test"}
        )

        result = await qdrant_engine.run_upload_chunk("test_index", test_point)

        assert result is True
        mock_client.upsert.assert_called_once_with(
            collection_name="test_index", points=[test_point]
        )

    @pytest.mark.asyncio
    async def test_run_upload_chunk_failure(self, qdrant_engine, mock_client):
        """Test run_upload_chunk returns False for a non-ACKNOWLEDGED status."""
        # Setup mock response with failure status
        # NOTE(review): in Qdrant, COMPLETED normally means the update was
        # applied successfully. This test pins the engine's current behavior of
        # returning False for anything other than ACKNOWLEDGED -- confirm that
        # this is intended rather than an engine bug.
        mock_response = MagicMock()
        mock_response.status = models.UpdateStatus.COMPLETED  # Not ACKNOWLEDGED
        mock_client.upsert.return_value = mock_response

        test_point = models.PointStruct(
            id="test-point-1", vector=[0.1, 0.2, 0.3], payload={"content": "test"}
        )

        result = await qdrant_engine.run_upload_chunk("test_index", test_point)

        assert result is False

    @pytest.mark.asyncio
    async def test_upload_chunk_integration(self, qdrant_engine, mock_client):
        """Test the complete upload_chunk workflow."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.status = models.UpdateStatus.ACKNOWLEDGED
        mock_client.upsert.return_value = mock_response

        chunk = Chunk(
            id="integration-test-chunk",
            vector=[0.5, 0.4, 0.3, 0.2, 0.1],
            payload=ChunkData(
                page_content="Integration test content",
                filename="integration_test.pdf",
                page=1,
            ),
        )

        result = await qdrant_engine.upload_chunk("integration_collection", chunk)

        assert result is True
        # Verify the complete workflow: transform_chunk -> run_upload_chunk
        mock_client.upsert.assert_called_once()
        upsert_kwargs = mock_client.upsert.call_args.kwargs
        assert upsert_kwargs["collection_name"] == "integration_collection"
        assert len(upsert_kwargs["points"]) == 1

        uploaded_point = upsert_kwargs["points"][0]
        assert uploaded_point.id == "integration-test-chunk"
        assert uploaded_point.vector == [0.5, 0.4, 0.3, 0.2, 0.1]
        assert uploaded_point.payload == {
            "page_content": "Integration test content",
            "filename": "integration_test.pdf",
            "page": 1,
        }

    def test_transform_response_with_scored_points(self, qdrant_engine):
        """Test transform_response with valid ScoredPoint objects."""
        response = [
            ScoredPoint(
                id=1,
                score=0.95,
                payload={"text": "First document", "category": "tech"},
                version=1,
            ),
            ScoredPoint(
                id=2,
                score=0.87,
                payload={"text": "Second document", "category": "science"},
                version=1,
            ),
        ]

        result = qdrant_engine.transform_response(response)

        assert isinstance(result, list)
        assert len(result) == 2

        # Check first result
        assert isinstance(result[0], SearchRow)
        assert result[0].chunk_id == "1"
        assert result[0].score == 0.95
        assert result[0].payload == {"text": "First document", "category": "tech"}

        # Check second result
        assert isinstance(result[1], SearchRow)
        assert result[1].chunk_id == "2"
        assert result[1].score == 0.87
        assert result[1].payload == {
            "text": "Second document",
            "category": "science",
        }

    def test_transform_response_filters_none_payload(self, qdrant_engine):
        """Test transform_response filters out points with None payload."""
        response = [
            ScoredPoint(
                id=1, score=0.95, payload={"text": "Valid document"}, version=1
            ),
            ScoredPoint(
                id=2,
                score=0.87,
                payload=None,  # This should be filtered out
                version=1,
            ),
            ScoredPoint(
                id=3,
                score=0.75,
                payload={"text": "Another valid document"},
                version=1,
            ),
        ]

        result = qdrant_engine.transform_response(response)

        assert isinstance(result, list)
        assert len(result) == 2  # Only 2 valid results
        assert result[0].chunk_id == "1"
        assert result[1].chunk_id == "3"

    @pytest.mark.asyncio
    async def test_run_similarity_query_basic(self, qdrant_engine, mock_client):
        """Test run_similarity_query with basic parameters."""
        # Setup mock response
        mock_response = [
            ScoredPoint(id=1, score=0.9, payload={"text": "Test document"}, version=1)
        ]
        mock_client.search.return_value = mock_response

        embedding = [0.1, 0.2, 0.3, 0.4, 0.5]
        collection = "test_collection"

        result = await qdrant_engine.run_similarity_query(
            embedding=embedding, collection=collection
        )

        # Verify client.search was called with correct parameters
        mock_client.search.assert_called_once_with(
            collection_name=collection,
            query_vector=embedding,
            query_filter=None,
            limit=10,  # default
            with_payload=True,
            with_vectors=False,
            score_threshold=None,
        )
        assert result == mock_response

    @pytest.mark.asyncio
    async def test_run_similarity_query_with_all_parameters(
        self, qdrant_engine, mock_client
    ):
        """Test run_similarity_query with all parameters."""
        mock_response = []
        mock_client.search.return_value = mock_response

        embedding = [0.1, 0.2, 0.3]
        collection = "test_collection"
        limit = 5
        conditions = models.Filter(must=[])
        threshold = 0.75

        result = await qdrant_engine.run_similarity_query(
            embedding=embedding,
            collection=collection,
            limit=limit,
            conditions=conditions,
            threshold=threshold,
        )

        mock_client.search.assert_called_once_with(
            collection_name=collection,
            query_vector=embedding,
            query_filter=conditions,
            limit=limit,
            with_payload=True,
            with_vectors=False,
            score_threshold=threshold,
        )
        assert result == mock_response

    @pytest.mark.asyncio
    async def test_run_similarity_query_with_named_vector(
        self, qdrant_engine, mock_client
    ):
        """Test run_similarity_query with NamedVector."""
        mock_response = []
        mock_client.search.return_value = mock_response

        named_vector = models.NamedVector(name="text", vector=[0.1, 0.2, 0.3])
        collection = "test_collection"

        result = await qdrant_engine.run_similarity_query(
            embedding=named_vector, collection=collection
        )

        mock_client.search.assert_called_once_with(
            collection_name=collection,
            query_vector=named_vector,
            query_filter=None,
            limit=10,
            with_payload=True,
            with_vectors=False,
            score_threshold=None,
        )
        assert result == mock_response

    @pytest.mark.asyncio
    async def test_semantic_search_integration(self, qdrant_engine, mock_client):
        """Test the full semantic_search flow through QdrantEngine."""
        # Setup mock response
        mock_search_response = [
            ScoredPoint(
                id=1,
                score=0.95,
                payload={"text": "Python programming guide", "category": "tech"},
                version=1,
            ),
            ScoredPoint(
                id=2,
                score=0.87,
                payload={"text": "Rust systems programming", "category": "tech"},
                version=1,
            ),
        ]
        mock_client.search.return_value = mock_search_response

        # Test data
        vector = [0.1, 0.2, 0.3, 0.4, 0.5]
        collection = "documents"
        conditions = [
            Match(key="category", value="tech"),
            MatchAny(key="language", any=["python", "rust"]),
        ]

        result = await qdrant_engine.semantic_search(
            embedding=vector,
            collection=collection,
            limit=5,
            conditions=conditions,
            threshold=0.8,
        )

        # Verify the search was called with transformed conditions
        assert mock_client.search.called
        search_kwargs = mock_client.search.call_args.kwargs
        assert search_kwargs["collection_name"] == collection
        assert search_kwargs["query_vector"] == vector
        assert search_kwargs["limit"] == 5
        assert search_kwargs["score_threshold"] == 0.8
        assert isinstance(search_kwargs["query_filter"], models.Filter)

        # Verify the response was transformed correctly
        assert isinstance(result, list)
        assert len(result) == 2
        assert all(isinstance(row, SearchRow) for row in result)
        assert result[0].chunk_id == "1"
        assert result[0].score == 0.95
        assert result[1].chunk_id == "2"
        assert result[1].score == 0.87

    def test_initialization_with_settings(self, mock_settings):
        """Test QdrantEngine initialization uses settings correctly."""
        with (
            patch(
                "searchbox.engine.qdrant_engine.Settings"
            ) as mock_settings_class,
            patch(
                "searchbox.engine.qdrant_engine.AsyncQdrantClient"
            ) as mock_client_class,
        ):
            mock_settings_class.return_value = mock_settings
            mock_client = AsyncMock()
            mock_client_class.return_value = mock_client

            engine = QdrantEngine()

            # Verify Settings was instantiated
            mock_settings_class.assert_called_once()
            # Verify AsyncQdrantClient was created with correct parameters
            mock_client_class.assert_called_once_with(
                url=mock_settings.url, api_key=mock_settings.api_key
            )
            assert engine.client == mock_client
            assert engine.settings == mock_settings

    @pytest.mark.asyncio
    async def test_client_search_exception_propagation(
        self, qdrant_engine, mock_client
    ):
        """Test that exceptions from client.search are properly propagated."""
        # Setup mock to raise an exception
        mock_client.search.side_effect = Exception("Qdrant connection failed")

        with pytest.raises(Exception, match="Qdrant connection failed"):
            await qdrant_engine.run_similarity_query(
                embedding=[0.1, 0.2, 0.3], collection="test_collection"
            )