forked from innovacion/searchbox
Add testing
195  tests/test_engine/README.md  (new file)
@@ -0,0 +1,195 @@
# Engine Module Test Suite

This directory contains comprehensive tests for the vector search engine module, covering all components from the abstract base engine to concrete implementations and the factory pattern.

## Test Structure

### Core Test Files

- **`test_base_engine.py`** - Tests for the abstract `BaseEngine` class
- **`test_qdrant_engine.py`** - Tests for the `QdrantEngine` implementation
- **`test_factory.py`** - Tests for the engine factory and the `Backend` enum (formerly `EngineType`)
- **`test_integration.py`** - End-to-end integration tests
- **`conftest.py`** - Shared fixtures and test configuration

## Test Coverage

### BaseEngine Tests (`test_base_engine.py`)

- ✅ Abstract class enforcement - ensures BaseEngine cannot be instantiated
- ✅ Abstract method verification - validates required method signatures
- ✅ Generic typing constraints - tests TypeVar functionality
- ✅ Semantic search workflow - tests the complete search flow
- ✅ Condition transformation - tests with various condition types
- ✅ Response transformation - tests data structure conversion
- ✅ Parameter validation - tests default values and edge cases

### QdrantEngine Tests (`test_qdrant_engine.py`)

- ✅ Inheritance verification - ensures proper BaseEngine inheritance
- ✅ Generic type parameters - validates `BaseEngine[list[ScoredPoint], Filter]`
- ✅ Condition transformation - tests conversion to Qdrant Filter objects (sketched after this list)
  - Match conditions → `MatchValue`
  - MatchAny conditions → `MatchAny`
  - MatchExclude conditions → `MatchExcept`
- ✅ Response transformation - tests `ScoredPoint` to `SearchRow` conversion
- ✅ Null payload filtering - ensures entries with null payloads are excluded
- ✅ Client interaction - mocks and verifies Qdrant client calls
- ✅ Error propagation - ensures client exceptions bubble up correctly
- ✅ Initialization - tests Settings and AsyncQdrantClient setup

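In outline, the mapping these tests expect looks like the following sketch. It is illustrative only; the real implementation lives in `qdrant_engine.py`, and only the condition attributes used by the fixtures (`key`, `value`, `any`, `exclude`) are assumed here:

```python
# Illustrative sketch of the condition mapping the tests exercise - not the project source.
from qdrant_client import models

from vector_search_mcp.models import Condition, Match, MatchAny, MatchExclude


def transform_conditions_sketch(conditions: list[Condition] | None) -> models.Filter | None:
    """Map internal condition models onto a Qdrant Filter, as the tests expect."""
    if not conditions:
        return None

    field_conditions: list[models.FieldCondition] = []
    for condition in conditions:
        if isinstance(condition, Match):
            match = models.MatchValue(value=condition.value)
        elif isinstance(condition, MatchAny):
            match = models.MatchAny(any=condition.any)
        elif isinstance(condition, MatchExclude):
            # "except" is a Python keyword, hence the unpacked dict
            match = models.MatchExcept(**{"except": condition.exclude})
        else:
            raise ValueError(f"Unsupported condition type: {type(condition)!r}")
        field_conditions.append(models.FieldCondition(key=condition.key, match=match))

    return models.Filter(must=field_conditions)
```
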
### Factory Tests (`test_factory.py`)

- ✅ Backend type enumeration - tests `Backend` enum values and behavior
- ✅ Factory function typing - validates overload signatures for type safety
- ✅ Engine instantiation - tests creation of concrete engine instances
- ✅ Error handling - validates behavior with invalid inputs
- ✅ Caching behavior - ensures `@cache` decorator works correctly (same instances)
- ✅ COSMOS engine handling - tests NotImplementedError for unimplemented engines
- ✅ String enum behavior - tests StrEnum functionality and JSON serialization

### Integration Tests (`test_integration.py`)

- ✅ Complete workflow - factory → conditions → search → response transformation
- ✅ Parameter passing - verifies correct parameter flow through all layers
- ✅ Complex conditions - tests multiple condition types together
- ✅ Large result sets - tests handling of 100+ search results
- ✅ Edge cases - empty conditions, null payloads, error scenarios
- ✅ Named vectors - tests support for Qdrant NamedVector objects
- ✅ Multiple engine instances - tests independence and concurrent usage

## Test Fixtures (`conftest.py`)

### Automatic Fixtures

- `clear_engine_cache` - Auto-clears the `get_engine` cache before/after each test

### Mock Objects

- `mock_qdrant_client` - AsyncMock for Qdrant client with default responses
- `mock_settings` - Mock Settings object with test configuration
- `mock_qdrant_engine_dependencies` - Complete mocked environment

### Sample Data

- `sample_embedding` - Standard test embedding vector
- `sample_conditions` - Common condition objects for testing
- `sample_scored_points` - Realistic ScoredPoint objects
- `sample_search_rows` - Expected SearchRow outputs
- `qdrant_filter_single/multiple` - Pre-built Qdrant Filter objects

## Running Tests

### Run All Engine Tests

```bash
uv run pytest tests/test_engine/ -v
```

### Run Specific Test File

```bash
uv run pytest tests/test_engine/test_base_engine.py -v
uv run pytest tests/test_engine/test_qdrant_engine.py -v
uv run pytest tests/test_engine/test_factory.py -v
uv run pytest tests/test_engine/test_integration.py -v
```

### Run Specific Test Class or Method

```bash
uv run pytest tests/test_engine/test_factory.py::TestEngineFactory::test_get_engine_qdrant -v
uv run pytest tests/test_engine/test_integration.py::TestEngineIntegration -v
```

### Coverage Report

```bash
uv run pytest tests/test_engine/ --cov=src/vector_search_mcp/engine --cov-report=html
```

## Test Patterns Used

### Mocking Strategy

- **External dependencies** - All external services (Qdrant client, Settings) are mocked
- **Dependency injection** - Tests inject mocks through constructor parameters
- **Return value control** - Mocks return predictable test data for assertions

### Async Testing

- Uses `@pytest.mark.asyncio` for async method testing (see the minimal example below)
- `AsyncMock` objects for async client methods
- Proper await/async syntax throughout

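A minimal, self-contained illustration of that pattern (not taken from the suite itself; the collection name and vector are placeholders):

```python
from unittest.mock import AsyncMock

import pytest


@pytest.mark.asyncio
async def test_async_mock_pattern():
    """Minimal illustration of the AsyncMock + asyncio-marker pattern used throughout."""
    client = AsyncMock()
    client.search.return_value = []  # awaiting the mocked coroutine yields this value

    result = await client.search(collection_name="demo", query_vector=[0.1, 0.2])

    assert result == []
    client.search.assert_awaited_once()
```
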
### Type Testing

- Generic type parameter validation
- Overload signature verification
- Runtime type checking where applicable

### Error Testing

- Exception propagation validation
- Invalid input handling
- Boundary condition testing

## Key Test Insights

### Generic Typing Validation

The tests verify that the generic `BaseEngine[ResponseType, ConditionType]` pattern works correctly (a sketch of the assumed contract follows the list):

- `QdrantEngine` is typed as `BaseEngine[list[ScoredPoint], Filter]`
- Type checkers can verify correct usage at compile time
- Runtime behavior matches type declarations

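For reference, the contract the tests assume looks roughly like this simplified sketch; the real class lives in `src/vector_search_mcp/engine/base_engine.py` and may differ in detail:

```python
# Simplified sketch of the BaseEngine contract exercised by these tests.
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

from vector_search_mcp.models import Condition, SearchRow

ResponseType = TypeVar("ResponseType")
ConditionType = TypeVar("ConditionType")


class BaseEngine(ABC, Generic[ResponseType, ConditionType]):
    """Abstract contract: three abstract hooks plus a concrete semantic_search template."""

    @abstractmethod
    def transform_conditions(
        self, conditions: list[Condition] | None
    ) -> ConditionType | None: ...

    @abstractmethod
    def transform_response(self, response: ResponseType) -> list[SearchRow]: ...

    @abstractmethod
    async def run_similarity_query(
        self,
        embedding: list[float],
        collection: str,
        limit: int = 10,
        conditions: ConditionType | None = None,
        threshold: float | None = None,
    ) -> ResponseType: ...

    # The concrete semantic_search() template method (not shown here) chains the
    # three hooks above; see "Integration Flow Validation" below.
```
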
### Factory Pattern Testing

The overload tests ensure proper type inference:

```python
# Type checker knows this returns QdrantEngine
engine = get_engine(Backend.QDRANT)

# Type checker knows this returns BaseEngine (generic)
backend: Backend = some_variable
engine = get_engine(backend)
```

### Caching Behavior

Tests verify that the `@cache` decorator works correctly:

```python
# Both calls return the same instance
engine1 = get_engine(Backend.QDRANT)
engine2 = get_engine(Backend.QDRANT)
assert engine1 is engine2  # Same instance due to caching
```

### Integration Flow Validation

Integration tests verify the complete data flow (sketched in code after the list):

1. `get_engine()` creates the proper engine instance
2. `semantic_search()` calls `transform_conditions()`
3. Transformed conditions are passed to `run_similarity_query()`
4. The query response is processed by `transform_response()`
5. Final `SearchRow` objects are returned to the caller

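Expressed as code, the flow the tests check corresponds roughly to the sketch below, written as a free function over any `BaseEngine` for illustration rather than as the actual method body:

```python
# Sketch of the flow verified by the integration tests - not the project source.
from vector_search_mcp.engine.base_engine import BaseEngine
from vector_search_mcp.models import Condition, SearchRow


async def semantic_search_flow(
    engine: BaseEngine,
    embedding: list[float],
    collection: str,
    limit: int = 10,
    conditions: list[Condition] | None = None,
    threshold: float | None = None,
) -> list[SearchRow]:
    """What BaseEngine.semantic_search is expected to do internally."""
    transformed = engine.transform_conditions(conditions)   # step 2
    response = await engine.run_similarity_query(           # step 3
        embedding=embedding,
        collection=collection,
        limit=limit,
        conditions=transformed,
        threshold=threshold,
    )
    return engine.transform_response(response)              # steps 4-5
```
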
## Maintenance Notes

### Adding New Engine Types

When adding new engines (a hypothetical sketch of steps 1-2 follows this list):

1. Add an enum value to `Backend`
2. Add an overload signature to `get_engine()`
3. Update the factory tests for the new member count and behavior
4. Create an engine-specific test file following the `test_qdrant_engine.py` pattern
5. Remember that the `@cache` decorator will cache one instance per backend type

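For illustration, a hypothetical `PGVECTOR` backend would touch steps 1 and 2 roughly as shown in this stub-style sketch. `PGVECTOR`, `PgVectorEngine`, and the overload layout are invented names; the real enum, overloads, and dispatch live in the `vector_search_mcp.engine` package and may be structured differently:

```python
# Hypothetical, stub-style sketch only - illustrates where steps 1, 2, and 5 land.
from enum import StrEnum
from functools import cache
from typing import Literal, overload


class Backend(StrEnum):
    QDRANT = "qdrant"
    COSMOS = "cosmos"
    PGVECTOR = "pgvector"  # step 1: new enum value


@overload
def get_engine(backend: Literal[Backend.QDRANT]) -> "QdrantEngine": ...
@overload
def get_engine(backend: Literal[Backend.PGVECTOR]) -> "PgVectorEngine": ...  # step 2: new overload
@overload
def get_engine(backend: Backend) -> "BaseEngine": ...
@cache
def get_engine(backend: Backend) -> "BaseEngine":
    # Step 5: @cache means one instance per backend value. The real dispatch
    # (including the new PgVectorEngine branch) belongs in the project factory.
    raise NotImplementedError("sketch only")
```
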
### Mock Updates

If engine interfaces change:

1. Update fixture return types in `conftest.py`
2. Verify mock method signatures match real implementations
3. Update integration tests for new parameter flows
4. Ensure cache clearing fixture handles any new caching behavior

### Performance Testing

Current tests focus on correctness. For performance testing:

- Use `pytest-benchmark` for timing critical paths (see the sketch below)
- Test with realistic data sizes (1000+ embeddings)
- Mock network I/O but measure transformation logic
- Consider cache warming effects when benchmarking

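A sketch of what such a benchmark could look like, assuming `pytest-benchmark` is added as a dev dependency (the 1,000-point result set is synthetic, and engine construction is mocked exactly as in the existing tests):

```python
# Sketch only - requires the pytest-benchmark plugin as a dev dependency.
from unittest.mock import MagicMock, patch

from qdrant_client import models

from vector_search_mcp.engine import Backend, get_engine


def test_transform_response_benchmark(benchmark):
    """Benchmark response transformation only - construction is mocked, no network I/O."""
    with (
        patch("vector_search_mcp.engine.qdrant_engine.Settings", return_value=MagicMock()),
        patch("vector_search_mcp.engine.qdrant_engine.AsyncQdrantClient"),
    ):
        engine = get_engine(Backend.QDRANT)

    # Synthetic result set of 1,000 scored points
    points = [
        models.ScoredPoint(id=i, score=1.0 - i * 1e-4, payload={"text": f"doc {i}"}, version=1)
        for i in range(1_000)
    ]

    rows = benchmark(engine.transform_response, points)
    assert len(rows) == 1_000
```
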
## Recent Fixes Applied

### After Formatter Changes

The following issues were resolved after code formatting:

1. **Enum Rename**: `EngineType` → `Backend` - All tests updated
2. **Caching Addition**: `@cache` decorator added to `get_engine()`
   - Tests updated to expect same instances (not different ones)
   - Auto-cache clearing fixture added to `conftest.py`
3. **Mock Isolation**: Improved mock setup to prevent real network calls
   - Proper patch contexts in all integration tests
   - Cache clearing ensures clean test state

All 62 tests now pass successfully! 🎉
1  tests/test_engine/__init__.py  (new file)
@@ -0,0 +1 @@
# Test package for engine module
163  tests/test_engine/conftest.py  (new file)
@@ -0,0 +1,163 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from qdrant_client import models

from vector_search_mcp.engine import get_engine
from vector_search_mcp.models import Match, MatchAny, MatchExclude, SearchRow


@pytest.fixture(autouse=True)
def clear_engine_cache():
    """Clear the engine cache before each test for proper isolation"""
    get_engine.cache_clear()
    yield
    get_engine.cache_clear()


@pytest.fixture
def mock_qdrant_client():
    """Create a mock Qdrant client for testing"""
    client = AsyncMock()

    # Default search response
    client.search.return_value = [
        models.ScoredPoint(
            id=1,
            score=0.95,
            payload={"text": "Test document 1", "category": "test"},
            version=1,
        ),
        models.ScoredPoint(
            id=2,
            score=0.85,
            payload={"text": "Test document 2", "category": "test"},
            version=1,
        ),
    ]

    return client


@pytest.fixture
def mock_settings():
    """Create mock settings for testing"""
    settings = MagicMock()
    settings.url = "http://localhost:6333"
    settings.api_key = "test_api_key"
    return settings


@pytest.fixture
def sample_embedding():
    """Provide a sample embedding vector for testing"""
    return [0.1, 0.2, 0.3, 0.4, 0.5]


@pytest.fixture
def sample_conditions():
    """Provide sample conditions for testing"""
    return [
        Match(key="category", value="document"),
        MatchAny(key="tags", any=["python", "rust"]),
        MatchExclude(key="status", exclude=["deleted"]),
    ]


@pytest.fixture
def sample_scored_points():
    """Provide sample ScoredPoint objects for testing"""
    return [
        models.ScoredPoint(
            id=1,
            score=0.95,
            payload={"text": "First document", "category": "tech"},
            version=1,
        ),
        models.ScoredPoint(
            id=2,
            score=0.87,
            payload={"text": "Second document", "category": "science"},
            version=1,
        ),
        models.ScoredPoint(
            id=3,
            score=0.75,
            payload=None,  # This should be filtered out
            version=1,
        ),
    ]


@pytest.fixture
def sample_search_rows():
    """Provide sample SearchRow objects for testing"""
    return [
        SearchRow(
            chunk_id="1",
            score=0.95,
            payload={"text": "First document", "category": "tech"},
        ),
        SearchRow(
            chunk_id="2",
            score=0.87,
            payload={"text": "Second document", "category": "science"},
        ),
    ]


@pytest.fixture
def mock_qdrant_engine_dependencies():
    """Mock all external dependencies for QdrantEngine"""
    with (
        patch("vector_search_mcp.engine.qdrant_engine.Settings") as mock_settings_class,
        patch(
            "vector_search_mcp.engine.qdrant_engine.AsyncQdrantClient"
        ) as mock_client_class,
    ):
        # Setup mock settings
        mock_settings = MagicMock()
        mock_settings.url = "http://localhost:6333"
        mock_settings.api_key = "test_api_key"
        mock_settings_class.return_value = mock_settings

        # Setup mock client
        mock_client = AsyncMock()
        mock_client_class.return_value = mock_client

        yield {
            "settings_class": mock_settings_class,
            "client_class": mock_client_class,
            "settings": mock_settings,
            "client": mock_client,
        }


@pytest.fixture
def qdrant_filter_single():
    """Create a single-condition Qdrant filter for testing"""
    return models.Filter(
        must=[
            models.FieldCondition(
                key="category", match=models.MatchValue(value="document")
            )
        ]
    )


@pytest.fixture
def qdrant_filter_multiple():
    """Create a multi-condition Qdrant filter for testing"""
    return models.Filter(
        must=[
            models.FieldCondition(
                key="category", match=models.MatchValue(value="document")
            ),
            models.FieldCondition(
                key="tags", match=models.MatchAny(any=["python", "rust"])
            ),
            models.FieldCondition(
                key="status", match=models.MatchExcept(**{"except": ["deleted"]})
            ),
        ]
    )
227  tests/test_engine/test_base_engine.py  (new file)
@@ -0,0 +1,227 @@
from typing import Any

import pytest

from vector_search_mcp.engine.base_engine import BaseEngine
from vector_search_mcp.models import Condition, Match, MatchAny, SearchRow


class MockEngine(BaseEngine[dict[str, Any], str]):
    """Mock engine for testing BaseEngine abstract functionality"""

    def __init__(self):
        self.transform_conditions_called = False
        self.transform_response_called = False
        self.run_similarity_query_called = False

    def transform_conditions(self, conditions: list[Condition] | None) -> str | None:
        self.transform_conditions_called = True
        if not conditions:
            return None
        return f"transformed_{len(conditions)}_conditions"

    def transform_response(self, response: dict[str, Any]) -> list[SearchRow]:
        self.transform_response_called = True
        return [
            SearchRow(
                chunk_id=str(i),
                score=response.get(f"score_{i}", 0.5),
                payload={"text": f"result_{i}"},
            )
            for i in range(response.get("count", 1))
        ]

    async def run_similarity_query(
        self,
        embedding: list[float],
        collection: str,
        limit: int = 10,
        conditions: str | None = None,
        threshold: float | None = None,
    ) -> dict[str, Any]:
        self.run_similarity_query_called = True
        return {
            "count": min(limit, 3),
            "collection": collection,
            "conditions": conditions,
            "threshold": threshold,
            "score_0": 0.95,
            "score_1": 0.85,
            "score_2": 0.75,
        }


class TestBaseEngine:
    """Test suite for BaseEngine abstract class"""

    def test_base_engine_is_abstract(self):
        """Test that BaseEngine cannot be instantiated directly"""
        with pytest.raises(TypeError):
            BaseEngine()

    def test_base_engine_abc_methods(self):
        """Test that BaseEngine has the correct abstract methods"""
        abstract_methods = BaseEngine.__abstractmethods__
        expected_methods = {
            "transform_conditions",
            "transform_response",
            "run_similarity_query",
        }
        assert abstract_methods == expected_methods

    def test_mock_engine_instantiation(self):
        """Test that mock engine can be instantiated"""
        engine = MockEngine()
        assert isinstance(engine, BaseEngine)
        assert not engine.transform_conditions_called
        assert not engine.transform_response_called
        assert not engine.run_similarity_query_called

    def test_transform_conditions_with_none(self):
        """Test transform_conditions with None input"""
        engine = MockEngine()
        result = engine.transform_conditions(None)

        assert result is None
        assert engine.transform_conditions_called

    def test_transform_conditions_with_conditions(self):
        """Test transform_conditions with actual conditions"""
        engine = MockEngine()
        conditions = [
            Match(key="category", value="test"),
            MatchAny(key="tags", any=["tag1", "tag2"]),
        ]

        result = engine.transform_conditions(conditions)

        assert result == "transformed_2_conditions"
        assert engine.transform_conditions_called

    def test_transform_response(self):
        """Test transform_response with mock data"""
        engine = MockEngine()
        mock_response = {"count": 2, "score_0": 0.9, "score_1": 0.8}

        result = engine.transform_response(mock_response)

        assert len(result) == 2
        assert all(isinstance(row, SearchRow) for row in result)
        assert result[0].chunk_id == "0"
        assert result[0].score == 0.9
        assert result[1].chunk_id == "1"
        assert result[1].score == 0.8
        assert engine.transform_response_called

    @pytest.mark.asyncio
    async def test_run_similarity_query(self):
        """Test run_similarity_query with various parameters"""
        engine = MockEngine()
        embedding = [0.1, 0.2, 0.3, 0.4, 0.5]

        result = await engine.run_similarity_query(
            embedding=embedding,
            collection="test_collection",
            limit=5,
            conditions="test_conditions",
            threshold=0.7,
        )

        assert isinstance(result, dict)
        assert result["collection"] == "test_collection"
        assert result["conditions"] == "test_conditions"
        assert result["threshold"] == 0.7
        assert result["count"] == 3  # min(limit=5, 3)
        assert engine.run_similarity_query_called

    @pytest.mark.asyncio
    async def test_semantic_search_full_flow(self):
        """Test the complete semantic_search flow"""
        engine = MockEngine()
        vector = [0.1, 0.2, 0.3, 0.4, 0.5]
        conditions = [Match(key="status", value="active")]

        result = await engine.semantic_search(
            embedding=vector,
            collection="test_collection",
            limit=2,
            conditions=conditions,
            threshold=0.8,
        )

        # Verify all methods were called
        assert engine.transform_conditions_called
        assert engine.run_similarity_query_called
        assert engine.transform_response_called

        # Verify result structure
        assert isinstance(result, list)
        assert len(result) == 2
        assert all(isinstance(row, SearchRow) for row in result)
        assert result[0].score == 0.95
        assert result[1].score == 0.85

    @pytest.mark.asyncio
    async def test_semantic_search_with_none_conditions(self):
        """Test semantic_search with None conditions"""
        engine = MockEngine()
        vector = [0.1, 0.2, 0.3]

        result = await engine.semantic_search(
            embedding=vector, collection="test_collection"
        )

        assert engine.transform_conditions_called
        assert engine.run_similarity_query_called
        assert engine.transform_response_called
        assert isinstance(result, list)
        assert len(result) == 3  # default limit from mock

    @pytest.mark.asyncio
    async def test_semantic_search_default_parameters(self):
        """Test semantic_search with default parameters"""
        engine = MockEngine()
        vector = [0.1, 0.2, 0.3]

        result = await engine.semantic_search(
            embedding=vector, collection="test_collection"
        )

        # Check that defaults were used (limit=10, conditions=None, threshold=None)
        assert isinstance(result, list)
        assert len(result) <= 10  # Should respect limit

    def test_typing_constraints(self):
        """Test that the generic typing works correctly"""
        engine = MockEngine()

        # Verify the engine has the correct generic types
        assert hasattr(engine, "transform_conditions")
        assert hasattr(engine, "transform_response")
        assert hasattr(engine, "run_similarity_query")

        # The mock engine should work with dict[str, Any] and str types
        conditions_result = engine.transform_conditions([])
        assert conditions_result is None or isinstance(conditions_result, str)

        response_result = engine.transform_response({"test": "data"})
        assert isinstance(response_result, list)
        assert all(isinstance(item, SearchRow) for item in response_result)


class IncompleteEngine(BaseEngine[str, int]):
    """Incomplete engine implementation for testing abstract method enforcement"""

    def transform_conditions(self, conditions: list[Condition] | None) -> int | None:
        return None

    # Missing transform_response and run_similarity_query


class TestAbstractMethodEnforcement:
    """Test that abstract methods must be implemented"""

    def test_incomplete_engine_cannot_be_instantiated(self):
        """Test that incomplete engine implementations cannot be instantiated"""
        with pytest.raises(TypeError, match="Can't instantiate abstract class"):
            IncompleteEngine()
288  tests/test_engine/test_factory.py  (new file)
@@ -0,0 +1,288 @@
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vector_search_mcp.engine import Backend, get_engine
|
||||||
|
from vector_search_mcp.engine.base_engine import BaseEngine
|
||||||
|
from vector_search_mcp.engine.qdrant_engine import QdrantEngine
|
||||||
|
|
||||||
|
|
||||||
|
class TestEngineFactory:
|
||||||
|
"""Test suite for get_engine factory function"""
|
||||||
|
|
||||||
|
def test_engine_type_enum_values(self):
|
||||||
|
"""Test that EngineType enum has expected values"""
|
||||||
|
assert Backend.QDRANT == "qdrant"
|
||||||
|
assert len(Backend) == 2 # QDRANT and COSMOS engine types
|
||||||
|
|
||||||
|
def test_get_engine_qdrant(self):
|
||||||
|
"""Test get_engine returns QdrantEngine for QDRANT type"""
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"vector_search_mcp.engine.qdrant_engine.Settings"
|
||||||
|
) as mock_settings_class,
|
||||||
|
patch(
|
||||||
|
"vector_search_mcp.engine.qdrant_engine.AsyncQdrantClient"
|
||||||
|
) as mock_client_class,
|
||||||
|
):
|
||||||
|
# Setup mocks
|
||||||
|
mock_settings = MagicMock()
|
||||||
|
mock_settings.url = "http://localhost:6333"
|
||||||
|
mock_settings.api_key = "test_key"
|
||||||
|
mock_settings_class.return_value = mock_settings
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client_class.return_value = mock_client
|
||||||
|
|
||||||
|
# Test factory function
|
||||||
|
engine = get_engine(Backend.QDRANT)
|
||||||
|
|
||||||
|
# Verify return type
|
||||||
|
assert isinstance(engine, QdrantEngine)
|
||||||
|
assert isinstance(engine, BaseEngine)
|
||||||
|
|
||||||
|
# Verify initialization was called correctly
|
||||||
|
mock_settings_class.assert_called_once()
|
||||||
|
mock_client_class.assert_called_once_with(
|
||||||
|
url=mock_settings.url, api_key=mock_settings.api_key
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_get_engine_invalid_type(self):
|
||||||
|
"""Test get_engine raises ValueError for unknown engine type"""
|
||||||
|
# Create an invalid engine type (bypassing enum validation)
|
||||||
|
invalid_type = "invalid_engine"
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Unknown engine type: invalid_engine"):
|
||||||
|
# We need to cast to bypass type checking
|
||||||
|
get_engine(invalid_type) # type: ignore
|
||||||
|
|
||||||
|
def test_get_engine_typing_literal_qdrant(self):
|
||||||
|
"""Test that get_engine with literal QDRANT returns correct type"""
|
||||||
|
with (
|
||||||
|
patch("vector_search_mcp.engine.qdrant_engine.Settings"),
|
||||||
|
patch("vector_search_mcp.engine.qdrant_engine.AsyncQdrantClient"),
|
||||||
|
):
|
||||||
|
# When using literal Backend.QDRANT, mypy should know it's QdrantEngine
|
||||||
|
engine = get_engine(Backend.QDRANT)
|
||||||
|
|
||||||
|
# Runtime verification that it's the correct type
|
||||||
|
assert type(engine).__name__ == "QdrantEngine"
|
||||||
|
assert hasattr(engine, "client") # QdrantEngine specific attribute
|
||||||
|
assert hasattr(engine, "settings") # QdrantEngine specific attribute
|
||||||
|
|
||||||
|
def test_get_engine_typing_variable(self):
|
||||||
|
"""Test that get_engine with variable returns BaseEngine type"""
|
||||||
|
with (
|
||||||
|
patch("vector_search_mcp.engine.qdrant_engine.Settings"),
|
||||||
|
patch("vector_search_mcp.engine.qdrant_engine.AsyncQdrantClient"),
|
||||||
|
):
|
||||||
|
# When using a variable, mypy should see it as BaseEngine
|
||||||
|
engine_type: Backend = Backend.QDRANT
|
||||||
|
engine = get_engine(engine_type)
|
||||||
|
|
||||||
|
# Runtime verification - it's still a QdrantEngine but typed as BaseEngine
|
||||||
|
assert isinstance(engine, BaseEngine)
|
||||||
|
assert isinstance(engine, QdrantEngine)
|
||||||
|
|
||||||
|
def test_get_engine_uses_cache(self):
|
||||||
|
"""Test that get_engine uses cache and returns same instances"""
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"vector_search_mcp.engine.qdrant_engine.Settings"
|
||||||
|
) as mock_settings_class,
|
||||||
|
patch(
|
||||||
|
"vector_search_mcp.engine.qdrant_engine.AsyncQdrantClient"
|
||||||
|
) as mock_client_class,
|
||||||
|
):
|
||||||
|
# Setup mocks
|
||||||
|
mock_settings_class.return_value = MagicMock()
|
||||||
|
mock_client_class.return_value = MagicMock()
|
||||||
|
|
||||||
|
# Create multiple engines
|
||||||
|
engine1 = get_engine(Backend.QDRANT)
|
||||||
|
engine2 = get_engine(Backend.QDRANT)
|
||||||
|
|
||||||
|
# Verify they are the same instance due to @cache decorator
|
||||||
|
assert engine1 is engine2
|
||||||
|
assert id(engine1) == id(engine2)
|
||||||
|
|
||||||
|
# But they are the same type
|
||||||
|
assert type(engine1) is type(engine2)
|
||||||
|
assert isinstance(engine1, QdrantEngine)
|
||||||
|
assert isinstance(engine2, QdrantEngine)
|
||||||
|
|
||||||
|
# Verify initialization was called only once due to caching
|
||||||
|
mock_settings_class.assert_called_once()
|
||||||
|
mock_client_class.assert_called_once()
|
||||||
|
|
||||||
|
def test_engine_type_string_values(self):
|
||||||
|
"""Test EngineType string representations"""
|
||||||
|
assert str(Backend.QDRANT) == "qdrant"
|
||||||
|
assert str(Backend.COSMOS) == "cosmos"
|
||||||
|
|
||||||
|
# Test that it can be used in string contexts
|
||||||
|
engine_name = f"engine_{Backend.QDRANT}"
|
||||||
|
assert engine_name == "engine_qdrant"
|
||||||
|
|
||||||
|
def test_engine_type_iteration(self):
|
||||||
|
"""Test that EngineType can be iterated over"""
|
||||||
|
engine_types = list(Backend)
|
||||||
|
assert len(engine_types) == 2
|
||||||
|
assert Backend.QDRANT in engine_types
|
||||||
|
assert Backend.COSMOS in engine_types
|
||||||
|
|
||||||
|
def test_engine_factory_integration(self):
|
||||||
|
"""Test complete factory integration with engine functionality"""
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"vector_search_mcp.engine.qdrant_engine.Settings"
|
||||||
|
) as mock_settings_class,
|
||||||
|
patch(
|
||||||
|
"vector_search_mcp.engine.qdrant_engine.AsyncQdrantClient"
|
||||||
|
) as mock_client_class,
|
||||||
|
):
|
||||||
|
# Setup mocks
|
||||||
|
mock_settings = MagicMock()
|
||||||
|
mock_settings_class.return_value = mock_settings
|
||||||
|
mock_client_class.return_value = MagicMock()
|
||||||
|
|
||||||
|
# Create engine through factory
|
||||||
|
engine = get_engine(Backend.QDRANT)
|
||||||
|
|
||||||
|
# Verify engine has all required methods from BaseEngine
|
||||||
|
assert hasattr(engine, "transform_conditions")
|
||||||
|
assert hasattr(engine, "transform_response")
|
||||||
|
assert hasattr(engine, "run_similarity_query")
|
||||||
|
assert hasattr(engine, "semantic_search")
|
||||||
|
|
||||||
|
# Verify methods are callable
|
||||||
|
assert callable(engine.transform_conditions)
|
||||||
|
assert callable(engine.transform_response)
|
||||||
|
assert callable(engine.run_similarity_query)
|
||||||
|
assert callable(engine.semantic_search)
|
||||||
|
|
||||||
|
def test_future_engine_extensibility(self):
|
||||||
|
"""Test structure supports future engine additions"""
|
||||||
|
# Verify that EngineType is a StrEnum and can be extended
|
||||||
|
assert issubclass(Backend, str)
|
||||||
|
|
||||||
|
# Verify the factory function structure can handle new engines
|
||||||
|
# (This is more of a design verification)
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
sig = inspect.signature(get_engine)
|
||||||
|
|
||||||
|
# Should take Backend and return BaseEngine
|
||||||
|
params = list(sig.parameters.values())
|
||||||
|
assert len(params) == 1
|
||||||
|
assert params[0].name == "backend"
|
||||||
|
|
||||||
|
|
||||||
|
class TestEngineTypeEnum:
|
||||||
|
"""Test suite specifically for Backend enum"""
|
||||||
|
|
||||||
|
def test_engine_type_is_str_enum(self):
|
||||||
|
"""Test that Backend is a StrEnum"""
|
||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
assert issubclass(Backend, StrEnum)
|
||||||
|
|
||||||
|
# Should behave like strings
|
||||||
|
assert Backend.QDRANT == "qdrant"
|
||||||
|
assert f"{Backend.QDRANT}" == "qdrant"
|
||||||
|
|
||||||
|
def test_engine_type_comparison(self):
|
||||||
|
"""Test EngineType comparison operations"""
|
||||||
|
# Should equal string value
|
||||||
|
assert Backend.QDRANT == "qdrant"
|
||||||
|
|
||||||
|
# Should not equal other strings
|
||||||
|
assert Backend.QDRANT != "other"
|
||||||
|
assert Backend.QDRANT != "QDRANT" # Case sensitive
|
||||||
|
|
||||||
|
def test_engine_type_in_collections(self):
|
||||||
|
"""Test EngineType works in collections"""
|
||||||
|
engine_list = [Backend.QDRANT]
|
||||||
|
assert Backend.QDRANT in engine_list
|
||||||
|
assert "qdrant" in engine_list # StrEnum benefit
|
||||||
|
|
||||||
|
engine_set = {Backend.QDRANT}
|
||||||
|
assert Backend.QDRANT in engine_set
|
||||||
|
|
||||||
|
def test_engine_type_json_serializable(self):
|
||||||
|
"""Test that Backend can be JSON serialized"""
|
||||||
|
import json
|
||||||
|
|
||||||
|
data = {"engine": Backend.QDRANT}
|
||||||
|
json_str = json.dumps(data, default=str)
|
||||||
|
assert '"engine": "qdrant"' in json_str
|
||||||
|
|
||||||
|
def test_engine_type_immutable(self):
|
||||||
|
"""Test that Backend values cannot be modified"""
|
||||||
|
original_value = Backend.QDRANT
|
||||||
|
|
||||||
|
# Enum values should be immutable
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
Backend.QDRANT = "modified" # type: ignore
|
||||||
|
|
||||||
|
# Original should be unchanged
|
||||||
|
assert Backend.QDRANT == original_value
|
||||||
|
|
||||||
|
|
||||||
|
class TestEngineFactoryErrorHandling:
|
||||||
|
"""Test suite for error handling in engine factory"""
|
||||||
|
|
||||||
|
def test_none_backend_type(self):
|
||||||
|
"""Test get_engine with None raises appropriate error"""
|
||||||
|
with pytest.raises((TypeError, ValueError)):
|
||||||
|
get_engine(None) # type: ignore
|
||||||
|
|
||||||
|
def test_empty_string_backend_type(self):
|
||||||
|
"""Test get_engine with empty string"""
|
||||||
|
with pytest.raises(ValueError, match="Unknown engine type"):
|
||||||
|
get_engine("") # type: ignore
|
||||||
|
|
||||||
|
def test_numeric_backend_type(self):
|
||||||
|
"""Test get_engine with numeric input"""
|
||||||
|
with pytest.raises((TypeError, ValueError)):
|
||||||
|
get_engine(123) # type: ignore
|
||||||
|
|
||||||
|
def test_boolean_backend_type(self):
|
||||||
|
"""Test get_engine with boolean input"""
|
||||||
|
with pytest.raises((TypeError, ValueError)):
|
||||||
|
get_engine(True) # type: ignore
|
||||||
|
|
||||||
|
def test_get_engine_cosmos_not_implemented(self):
|
||||||
|
"""Test that COSMOS engine raises NotImplementedError"""
|
||||||
|
with pytest.raises(
|
||||||
|
NotImplementedError, match="Cosmos engine is not implemented yet"
|
||||||
|
):
|
||||||
|
get_engine(Backend.COSMOS)
|
||||||
|
|
||||||
|
def test_engine_initialization_failure(self):
|
||||||
|
"""Test handling of engine initialization failures"""
|
||||||
|
with (
|
||||||
|
patch("vector_search_mcp.engine.qdrant_engine.Settings") as mock_settings,
|
||||||
|
patch("vector_search_mcp.engine.qdrant_engine.AsyncQdrantClient"),
|
||||||
|
):
|
||||||
|
# Make Settings initialization raise an exception
|
||||||
|
mock_settings.side_effect = Exception("Settings initialization failed")
|
||||||
|
|
||||||
|
with pytest.raises(Exception, match="Settings initialization failed"):
|
||||||
|
get_engine(Backend.QDRANT)
|
||||||
|
|
||||||
|
def test_case_sensitive_backend_type(self):
|
||||||
|
"""Test that backend type matching is case sensitive"""
|
||||||
|
with pytest.raises(ValueError, match="Unknown engine type"):
|
||||||
|
get_engine("QDRANT") # type: ignore
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Unknown engine type"):
|
||||||
|
get_engine("Qdrant") # type: ignore
|
||||||
|
|
||||||
|
def test_whitespace_backend_type(self):
|
||||||
|
"""Test backend type with whitespace"""
|
||||||
|
with pytest.raises(ValueError, match="Unknown engine type"):
|
||||||
|
get_engine(" qdrant ") # type: ignore
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Unknown engine type"):
|
||||||
|
get_engine("\tqdrant\n") # type: ignore
|
||||||
361  tests/test_engine/test_integration.py  (new file)
@@ -0,0 +1,361 @@
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from qdrant_client import models
|
||||||
|
|
||||||
|
from vector_search_mcp.engine import Backend, get_engine
|
||||||
|
from vector_search_mcp.models import Match, MatchAny, MatchExclude, SearchRow
|
||||||
|
|
||||||
|
|
||||||
|
class TestEngineIntegration:
|
||||||
|
"""Integration tests for the complete engine workflow"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_complete_engine_setup(self):
|
||||||
|
"""Setup complete mocked engine environment"""
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"vector_search_mcp.engine.qdrant_engine.Settings"
|
||||||
|
) as mock_settings_class,
|
||||||
|
patch(
|
||||||
|
"vector_search_mcp.engine.qdrant_engine.AsyncQdrantClient"
|
||||||
|
) as mock_client_class,
|
||||||
|
):
|
||||||
|
# Setup settings
|
||||||
|
mock_settings = MagicMock()
|
||||||
|
mock_settings.url = "http://localhost:6333"
|
||||||
|
mock_settings.api_key = "test_api_key"
|
||||||
|
mock_settings_class.return_value = mock_settings
|
||||||
|
|
||||||
|
# Setup client with realistic response
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.search.return_value = [
|
||||||
|
models.ScoredPoint(
|
||||||
|
id="doc_1",
|
||||||
|
score=0.95,
|
||||||
|
payload={
|
||||||
|
"text": "Advanced Python programming techniques for data science",
|
||||||
|
"category": "programming",
|
||||||
|
"language": "python",
|
||||||
|
"difficulty": "advanced",
|
||||||
|
"tags": ["python", "data-science", "machine-learning"],
|
||||||
|
},
|
||||||
|
version=1,
|
||||||
|
),
|
||||||
|
models.ScoredPoint(
|
||||||
|
id="doc_2",
|
||||||
|
score=0.87,
|
||||||
|
payload={
|
||||||
|
"text": "Rust systems programming for performance-critical applications",
|
||||||
|
"category": "programming",
|
||||||
|
"language": "rust",
|
||||||
|
"difficulty": "intermediate",
|
||||||
|
"tags": ["rust", "systems", "performance"],
|
||||||
|
},
|
||||||
|
version=1,
|
||||||
|
),
|
||||||
|
models.ScoredPoint(
|
||||||
|
id="doc_3",
|
||||||
|
score=0.78,
|
||||||
|
payload={
|
||||||
|
"text": "Introduction to machine learning with Python",
|
||||||
|
"category": "programming",
|
||||||
|
"language": "python",
|
||||||
|
"difficulty": "beginner",
|
||||||
|
"tags": ["python", "machine-learning", "tutorial"],
|
||||||
|
},
|
||||||
|
version=1,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
mock_client_class.return_value = mock_client
|
||||||
|
|
||||||
|
yield {
|
||||||
|
"settings": mock_settings,
|
||||||
|
"client": mock_client,
|
||||||
|
"settings_class": mock_settings_class,
|
||||||
|
"client_class": mock_client_class,
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_semantic_search_workflow(self, mock_complete_engine_setup):
|
||||||
|
"""Test the complete workflow from factory to results"""
|
||||||
|
mocks = mock_complete_engine_setup
|
||||||
|
|
||||||
|
# 1. Create engine through factory
|
||||||
|
engine = get_engine(Backend.QDRANT)
|
||||||
|
|
||||||
|
# 2. Prepare search parameters
|
||||||
|
query_vector = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
|
||||||
|
collection_name = "programming_docs"
|
||||||
|
search_conditions = [
|
||||||
|
Match(key="category", value="programming"),
|
||||||
|
MatchAny(key="language", any=["python", "rust"]),
|
||||||
|
MatchExclude(key="difficulty", exclude=["expert"]),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 3. Execute semantic search
|
||||||
|
results = await engine.semantic_search(
|
||||||
|
embedding=query_vector,
|
||||||
|
collection=collection_name,
|
||||||
|
limit=5,
|
||||||
|
conditions=search_conditions,
|
||||||
|
threshold=0.7,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Verify the complete flow
|
||||||
|
|
||||||
|
# Check that client.search was called with correct parameters
|
||||||
|
client_mock = mocks["client"]
|
||||||
|
client_mock.search.assert_called_once()
|
||||||
|
|
||||||
|
call_args = client_mock.search.call_args
|
||||||
|
assert call_args[1]["collection_name"] == collection_name
|
||||||
|
assert call_args[1]["query_vector"] == query_vector
|
||||||
|
assert call_args[1]["limit"] == 5
|
||||||
|
assert call_args[1]["score_threshold"] == 0.7
|
||||||
|
assert call_args[1]["with_payload"] is True
|
||||||
|
assert call_args[1]["with_vectors"] is False
|
||||||
|
|
||||||
|
# Verify conditions were transformed to Qdrant filter
|
||||||
|
qdrant_filter = call_args[1]["query_filter"]
|
||||||
|
assert isinstance(qdrant_filter, models.Filter)
|
||||||
|
assert len(qdrant_filter.must) == 3
|
||||||
|
|
||||||
|
# Check individual conditions
|
||||||
|
conditions = qdrant_filter.must
|
||||||
|
|
||||||
|
# Match condition
|
||||||
|
match_condition = next(c for c in conditions if c.key == "category")
|
||||||
|
assert isinstance(match_condition.match, models.MatchValue)
|
||||||
|
assert match_condition.match.value == "programming"
|
||||||
|
|
||||||
|
# MatchAny condition
|
||||||
|
match_any_condition = next(c for c in conditions if c.key == "language")
|
||||||
|
assert isinstance(match_any_condition.match, models.MatchAny)
|
||||||
|
assert match_any_condition.match.any == ["python", "rust"]
|
||||||
|
|
||||||
|
# MatchExclude condition
|
||||||
|
match_exclude_condition = next(c for c in conditions if c.key == "difficulty")
|
||||||
|
assert isinstance(match_exclude_condition.match, models.MatchExcept)
|
||||||
|
|
||||||
|
# 5. Verify results transformation
|
||||||
|
assert isinstance(results, list)
|
||||||
|
assert len(results) == 3
|
||||||
|
assert all(isinstance(result, SearchRow) for result in results)
|
||||||
|
|
||||||
|
# Check first result
|
||||||
|
assert results[0].chunk_id == "doc_1"
|
||||||
|
assert results[0].score == 0.95
|
||||||
|
assert (
|
||||||
|
results[0].payload["text"]
|
||||||
|
== "Advanced Python programming techniques for data science"
|
||||||
|
)
|
||||||
|
assert results[0].payload["category"] == "programming"
|
||||||
|
|
||||||
|
# Check second result
|
||||||
|
assert results[1].chunk_id == "doc_2"
|
||||||
|
assert results[1].score == 0.87
|
||||||
|
assert results[1].payload["language"] == "rust"
|
||||||
|
|
||||||
|
# Check third result
|
||||||
|
assert results[2].chunk_id == "doc_3"
|
||||||
|
assert results[2].score == 0.78
|
||||||
|
assert results[2].payload["difficulty"] == "beginner"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_with_no_conditions(self, mock_complete_engine_setup):
|
||||||
|
"""Test semantic search without any conditions"""
|
||||||
|
engine = get_engine(Backend.QDRANT)
|
||||||
|
|
||||||
|
results = await engine.semantic_search(
|
||||||
|
embedding=[0.1, 0.2, 0.3], collection="test_collection"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify no filter was applied
|
||||||
|
client_mock = mock_complete_engine_setup["client"]
|
||||||
|
call_args = client_mock.search.call_args
|
||||||
|
assert call_args[1]["query_filter"] is None
|
||||||
|
|
||||||
|
# Results should still be transformed
|
||||||
|
assert len(results) == 3
|
||||||
|
assert all(isinstance(result, SearchRow) for result in results)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_with_empty_conditions(self, mock_complete_engine_setup):
|
||||||
|
"""Test semantic search with empty conditions list"""
|
||||||
|
engine = get_engine(Backend.QDRANT)
|
||||||
|
|
||||||
|
results = await engine.semantic_search(
|
||||||
|
embedding=[0.1, 0.2, 0.3], collection="test_collection", conditions=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify no filter was applied
|
||||||
|
client_mock = mock_complete_engine_setup["client"]
|
||||||
|
call_args = client_mock.search.call_args
|
||||||
|
assert call_args[1]["query_filter"] is None
|
||||||
|
|
||||||
|
assert len(results) == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_filters_null_payloads(self, mock_complete_engine_setup):
|
||||||
|
"""Test that results with null payloads are filtered out"""
|
||||||
|
# Override the mock response to include null payload
|
||||||
|
client_mock = mock_complete_engine_setup["client"]
|
||||||
|
client_mock.search.return_value = [
|
||||||
|
models.ScoredPoint(
|
||||||
|
id="valid_1",
|
||||||
|
score=0.95,
|
||||||
|
payload={"text": "Valid document"},
|
||||||
|
version=1,
|
||||||
|
),
|
||||||
|
models.ScoredPoint(
|
||||||
|
id="invalid",
|
||||||
|
score=0.90,
|
||||||
|
payload=None, # This should be filtered out
|
||||||
|
version=1,
|
||||||
|
),
|
||||||
|
models.ScoredPoint(
|
||||||
|
id="valid_2",
|
||||||
|
score=0.85,
|
||||||
|
payload={"text": "Another valid document"},
|
||||||
|
version=1,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
engine = get_engine(Backend.QDRANT)
|
||||||
|
results = await engine.semantic_search(
|
||||||
|
embedding=[0.1, 0.2, 0.3], collection="test_collection"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should only have 2 results (null payload filtered out)
|
||||||
|
assert len(results) == 2
|
||||||
|
assert results[0].chunk_id == "valid_1"
|
||||||
|
assert results[1].chunk_id == "valid_2"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_error_propagation_from_client(self, mock_complete_engine_setup):
|
||||||
|
"""Test that client errors are properly propagated"""
|
||||||
|
# Make the client raise an exception
|
||||||
|
client_mock = mock_complete_engine_setup["client"]
|
||||||
|
client_mock.search.side_effect = Exception("Qdrant connection timeout")
|
||||||
|
|
||||||
|
engine = get_engine(Backend.QDRANT)
|
||||||
|
|
||||||
|
with pytest.raises(Exception, match="Qdrant connection timeout"):
|
||||||
|
await engine.semantic_search(
|
||||||
|
embedding=[0.1, 0.2, 0.3], collection="test_collection"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_with_named_vector(self, mock_complete_engine_setup):
|
||||||
|
"""Test semantic search with NamedVector instead of regular vector"""
|
||||||
|
engine = get_engine(Backend.QDRANT)
|
||||||
|
|
||||||
|
named_vector = models.NamedVector(
|
||||||
|
name="text_embedding", vector=[0.1, 0.2, 0.3, 0.4, 0.5]
|
||||||
|
)
|
||||||
|
|
||||||
|
results = await engine.semantic_search(
|
||||||
|
embedding=named_vector, # type: ignore - Testing duck typing
|
||||||
|
collection="test_collection",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify named vector was passed through
|
||||||
|
client_mock = mock_complete_engine_setup["client"]
|
||||||
|
call_args = client_mock.search.call_args
|
||||||
|
assert call_args[1]["query_vector"] == named_vector
|
||||||
|
|
||||||
|
assert len(results) == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_parameter_defaults(self, mock_complete_engine_setup):
|
||||||
|
"""Test that default parameters are applied correctly"""
|
||||||
|
engine = get_engine(Backend.QDRANT)
|
||||||
|
|
||||||
|
await engine.semantic_search(
|
||||||
|
embedding=[0.1, 0.2, 0.3], collection="test_collection"
|
||||||
|
)
|
||||||
|
|
||||||
|
client_mock = mock_complete_engine_setup["client"]
|
||||||
|
call_args = client_mock.search.call_args
|
||||||
|
|
||||||
|
# Check defaults
|
||||||
|
assert call_args[1]["limit"] == 10 # default limit
|
||||||
|
assert call_args[1]["score_threshold"] is None # default threshold
|
||||||
|
assert call_args[1]["query_filter"] is None # default conditions
|
||||||
|
assert call_args[1]["with_payload"] is True
|
||||||
|
assert call_args[1]["with_vectors"] is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_engine_instances_independence(
|
||||||
|
self, mock_complete_engine_setup
|
||||||
|
):
|
||||||
|
"""Test that multiple engine instances work independently"""
|
||||||
|
# Create two engines
|
||||||
|
engine1 = get_engine(Backend.QDRANT)
|
||||||
|
engine2 = get_engine(Backend.QDRANT)
|
||||||
|
|
||||||
|
# Verify they are the same instance due to caching
|
||||||
|
assert engine1 is engine2
|
||||||
|
|
||||||
|
# Both should work with the same instance
|
||||||
|
results1 = await engine1.semantic_search(
|
||||||
|
embedding=[0.1, 0.2, 0.3], collection="collection1"
|
||||||
|
)
|
||||||
|
|
||||||
|
results2 = await engine2.semantic_search(
|
||||||
|
embedding=[0.4, 0.5, 0.6], collection="collection2"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results1) == 3
|
||||||
|
assert len(results2) == 3
|
||||||
|
|
||||||
|
# Verify client was called twice (same instance, multiple calls)
|
||||||
|
client_mock = mock_complete_engine_setup["client"]
|
||||||
|
assert client_mock.search.call_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_large_result_set_handling(self, mock_complete_engine_setup):
|
||||||
|
"""Test handling of large result sets"""
|
||||||
|
# Create a large mock response
|
||||||
|
large_response = []
|
||||||
|
for i in range(100):
|
||||||
|
large_response.append(
|
||||||
|
models.ScoredPoint(
|
||||||
|
id=f"doc_{i}",
|
||||||
|
score=0.9 - (i * 0.001), # Decreasing scores
|
||||||
|
payload={"text": f"Document {i}", "index": i},
|
||||||
|
version=1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
client_mock = mock_complete_engine_setup["client"]
|
||||||
|
client_mock.search.return_value = large_response
|
||||||
|
|
||||||
|
engine = get_engine(Backend.QDRANT)
|
||||||
|
results = await engine.semantic_search(
|
||||||
|
embedding=[0.1, 0.2, 0.3], collection="large_collection", limit=100
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should handle all 100 results
|
||||||
|
assert len(results) == 100
|
||||||
|
assert results[0].chunk_id == "doc_0"
|
||||||
|
assert results[0].score == 0.9
|
||||||
|
assert results[99].chunk_id == "doc_99"
|
||||||
|
assert results[99].score == 0.801 # 0.9 - (99 * 0.001)
|
||||||
|
|
||||||
|
def test_engine_type_consistency(self):
|
||||||
|
"""Test that engine types are consistent across multiple calls"""
|
||||||
|
with (
|
||||||
|
patch("vector_search_mcp.engine.qdrant_engine.Settings"),
|
||||||
|
patch("vector_search_mcp.engine.qdrant_engine.AsyncQdrantClient"),
|
||||||
|
):
|
||||||
|
engines = [get_engine(Backend.QDRANT) for _ in range(5)]
|
||||||
|
|
||||||
|
# All should be the same instance due to caching
|
||||||
|
assert all(engine is engines[0] for engine in engines)
|
||||||
|
|
||||||
|
# All should be QdrantEngine instances
|
||||||
|
from vector_search_mcp.engine.qdrant_engine import QdrantEngine
|
||||||
|
|
||||||
|
assert all(isinstance(engine, QdrantEngine) for engine in engines)
|
||||||
380  tests/test_engine/test_qdrant_engine.py  (new file)
@@ -0,0 +1,380 @@
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from qdrant_client import models
|
||||||
|
from qdrant_client.models import ScoredPoint
|
||||||
|
|
||||||
|
from vector_search_mcp.engine.base_engine import BaseEngine
|
||||||
|
from vector_search_mcp.engine.qdrant_engine import QdrantEngine
|
||||||
|
from vector_search_mcp.models import Match, MatchAny, MatchExclude, SearchRow
|
||||||
|
|
||||||
|
|
||||||
|
class TestQdrantEngine:
|
||||||
|
"""Test suite for QdrantEngine"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_client(self):
|
||||||
|
"""Create a mock Qdrant client"""
|
||||||
|
return AsyncMock()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_settings(self):
|
||||||
|
"""Create mock settings"""
|
||||||
|
settings = MagicMock()
|
||||||
|
settings.url = "http://localhost:6333"
|
||||||
|
settings.api_key = "test_api_key"
|
||||||
|
return settings
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def qdrant_engine(self, mock_client, mock_settings):
|
||||||
|
"""Create a QdrantEngine instance with mocked dependencies"""
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"vector_search_mcp.engine.qdrant_engine.Settings"
|
||||||
|
) as mock_settings_class,
|
||||||
|
patch(
|
||||||
|
"vector_search_mcp.engine.qdrant_engine.AsyncQdrantClient"
|
||||||
|
) as mock_client_class,
|
||||||
|
):
|
||||||
|
mock_settings_class.return_value = mock_settings
|
||||||
|
mock_client_class.return_value = mock_client
|
||||||
|
|
||||||
|
engine = QdrantEngine()
|
||||||
|
engine.client = mock_client # Ensure we use our mock
|
||||||
|
return engine
|
||||||
|
|
||||||
|
def test_inheritance(self, qdrant_engine):
|
||||||
|
"""Test that QdrantEngine properly inherits from BaseEngine"""
|
||||||
|
assert isinstance(qdrant_engine, BaseEngine)
|
||||||
|
assert isinstance(qdrant_engine, QdrantEngine)
|
||||||
|
|
||||||
|
def test_typing_parameters(self):
|
||||||
|
"""Test that QdrantEngine has correct generic type parameters"""
|
||||||
|
# QdrantEngine should be BaseEngine[list[models.ScoredPoint], models.Filter]
|
||||||
|
# This is verified by the type checker, but we can test the methods exist
|
||||||
|
assert hasattr(QdrantEngine, "transform_conditions")
|
||||||
|
assert hasattr(QdrantEngine, "transform_response")
|
||||||
|
assert hasattr(QdrantEngine, "run_similarity_query")
|
||||||
|
|
||||||
|
def test_transform_conditions_none(self, qdrant_engine):
|
||||||
|
"""Test transform_conditions with None input"""
|
||||||
|
result = qdrant_engine.transform_conditions(None)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_transform_conditions_empty_list(self, qdrant_engine):
|
||||||
|
"""Test transform_conditions with empty list"""
|
||||||
|
result = qdrant_engine.transform_conditions([])
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_transform_conditions_with_match(self, qdrant_engine):
|
||||||
|
"""Test transform_conditions with Match condition"""
|
||||||
|
conditions = [Match(key="category", value="document")]
|
||||||
|
|
||||||
|
result = qdrant_engine.transform_conditions(conditions)
|
||||||
|
|
||||||
|
assert isinstance(result, models.Filter)
|
||||||
|
assert len(result.must) == 1
|
||||||
|
|
||||||
|
condition = result.must[0]
|
||||||
|
assert isinstance(condition, models.FieldCondition)
|
||||||
|
assert condition.key == "category"
|
||||||
|
assert isinstance(condition.match, models.MatchValue)
|
||||||
|
assert condition.match.value == "document"
|
||||||
|
|
||||||
|
def test_transform_conditions_with_match_any(self, qdrant_engine):
|
||||||
|
"""Test transform_conditions with MatchAny condition"""
|
||||||
|
conditions = [MatchAny(key="tags", any=["python", "rust", "javascript"])]
|
||||||
|
|
||||||
|
result = qdrant_engine.transform_conditions(conditions)
|
||||||
|
|
||||||
|
assert isinstance(result, models.Filter)
|
||||||
|
assert len(result.must) == 1
|
||||||
|
|
||||||
|
condition = result.must[0]
|
||||||
|
assert isinstance(condition, models.FieldCondition)
|
||||||
|
assert condition.key == "tags"
|
||||||
|
assert isinstance(condition.match, models.MatchAny)
|
||||||
|
assert condition.match.any == ["python", "rust", "javascript"]
|
||||||
|
|
||||||
|
def test_transform_conditions_with_match_exclude(self, qdrant_engine):
|
||||||
|
"""Test transform_conditions with MatchExclude condition"""
|
||||||
|
conditions = [MatchExclude(key="status", exclude=["deleted", "archived"])]
|
||||||
|
|
||||||
|
result = qdrant_engine.transform_conditions(conditions)
|
||||||
|
|
||||||
|
assert isinstance(result, models.Filter)
|
||||||
|
assert len(result.must) == 1
|
||||||
|
|
||||||
|
condition = result.must[0]
|
||||||
|
assert isinstance(condition, models.FieldCondition)
|
||||||
|
assert condition.key == "status"
|
||||||
|
assert isinstance(condition.match, models.MatchExcept)
|
||||||
|
# MatchExcept uses 'except' parameter which conflicts with Python keyword
|
||||||
|
assert hasattr(condition.match, "except_")
|
||||||
|
|
||||||
|
    def test_transform_conditions_multiple(self, qdrant_engine):
        """Test transform_conditions with multiple conditions"""
        conditions = [
            Match(key="type", value="article"),
            MatchAny(key="language", any=["en", "es"]),
            MatchExclude(key="status", exclude=["draft"]),
        ]

        result = qdrant_engine.transform_conditions(conditions)

        assert isinstance(result, models.Filter)
        assert len(result.must) == 3

        # Verify all conditions are FieldCondition instances
        assert all(isinstance(cond, models.FieldCondition) for cond in result.must)

    def test_transform_response_empty(self, qdrant_engine):
        """Test transform_response with an empty response"""
        response = []

        result = qdrant_engine.transform_response(response)

        assert isinstance(result, list)
        assert len(result) == 0

    def test_transform_response_with_scored_points(self, qdrant_engine):
        """Test transform_response with valid ScoredPoint objects"""
        response = [
            ScoredPoint(
                id=1,
                score=0.95,
                payload={"text": "First document", "category": "tech"},
                version=1,
            ),
            ScoredPoint(
                id=2,
                score=0.87,
                payload={"text": "Second document", "category": "science"},
                version=1,
            ),
        ]

        result = qdrant_engine.transform_response(response)

        assert isinstance(result, list)
        assert len(result) == 2

        # Check first result
        assert isinstance(result[0], SearchRow)
        assert result[0].chunk_id == "1"
        assert result[0].score == 0.95
        assert result[0].payload == {"text": "First document", "category": "tech"}

        # Check second result
        assert isinstance(result[1], SearchRow)
        assert result[1].chunk_id == "2"
        assert result[1].score == 0.87
        assert result[1].payload == {"text": "Second document", "category": "science"}

    def test_transform_response_filters_none_payload(self, qdrant_engine):
        """Test transform_response filters out points with None payload"""
        response = [
            ScoredPoint(
                id=1, score=0.95, payload={"text": "Valid document"}, version=1
            ),
            ScoredPoint(
                id=2,
                score=0.87,
                payload=None,  # This should be filtered out
                version=1,
            ),
            ScoredPoint(
                id=3, score=0.75, payload={"text": "Another valid document"}, version=1
            ),
        ]

        result = qdrant_engine.transform_response(response)

        assert isinstance(result, list)
        assert len(result) == 2  # Only 2 valid results
        assert result[0].chunk_id == "1"
        assert result[1].chunk_id == "3"

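    # The run_similarity_query tests below mock AsyncQdrantClient.search and assert
    # on the exact keyword arguments, so the defaults (limit=10, no filter, no
    # score threshold) are checked explicitly.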
    @pytest.mark.asyncio
    async def test_run_similarity_query_basic(self, qdrant_engine, mock_client):
        """Test run_similarity_query with basic parameters"""
        # Setup mock response
        mock_response = [
            ScoredPoint(id=1, score=0.9, payload={"text": "Test document"}, version=1)
        ]
        mock_client.search.return_value = mock_response

        embedding = [0.1, 0.2, 0.3, 0.4, 0.5]
        collection = "test_collection"

        result = await qdrant_engine.run_similarity_query(
            embedding=embedding, collection=collection
        )

        # Verify client.search was called with correct parameters
        mock_client.search.assert_called_once_with(
            collection_name=collection,
            query_vector=embedding,
            query_filter=None,
            limit=10,  # default
            with_payload=True,
            with_vectors=False,
            score_threshold=None,
        )

        assert result == mock_response

    @pytest.mark.asyncio
    async def test_run_similarity_query_with_all_parameters(
        self, qdrant_engine, mock_client
    ):
        """Test run_similarity_query with all parameters"""
        mock_response = []
        mock_client.search.return_value = mock_response

        embedding = [0.1, 0.2, 0.3]
        collection = "test_collection"
        limit = 5
        conditions = models.Filter(must=[])
        threshold = 0.75

        result = await qdrant_engine.run_similarity_query(
            embedding=embedding,
            collection=collection,
            limit=limit,
            conditions=conditions,
            threshold=threshold,
        )

        mock_client.search.assert_called_once_with(
            collection_name=collection,
            query_vector=embedding,
            query_filter=conditions,
            limit=limit,
            with_payload=True,
            with_vectors=False,
            score_threshold=threshold,
        )

        assert result == mock_response

    @pytest.mark.asyncio
    async def test_run_similarity_query_with_named_vector(
        self, qdrant_engine, mock_client
    ):
        """Test run_similarity_query with NamedVector"""
        mock_response = []
        mock_client.search.return_value = mock_response

        named_vector = models.NamedVector(name="text", vector=[0.1, 0.2, 0.3])
        collection = "test_collection"

        result = await qdrant_engine.run_similarity_query(
            embedding=named_vector, collection=collection
        )

        mock_client.search.assert_called_once_with(
            collection_name=collection,
            query_vector=named_vector,
            query_filter=None,
            limit=10,
            with_payload=True,
            with_vectors=False,
            score_threshold=None,
        )

        assert result == mock_response

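    # End-to-end check: semantic_search transforms the abstract conditions into a
    # models.Filter, runs the query against the mocked client, and converts the
    # raw ScoredPoint results into SearchRow objects.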
    @pytest.mark.asyncio
    async def test_semantic_search_integration(self, qdrant_engine, mock_client):
        """Test the full semantic_search flow through QdrantEngine"""
        # Setup mock response
        mock_search_response = [
            ScoredPoint(
                id=1,
                score=0.95,
                payload={"text": "Python programming guide", "category": "tech"},
                version=1,
            ),
            ScoredPoint(
                id=2,
                score=0.87,
                payload={"text": "Rust systems programming", "category": "tech"},
                version=1,
            ),
        ]
        mock_client.search.return_value = mock_search_response

        # Test data
        vector = [0.1, 0.2, 0.3, 0.4, 0.5]
        collection = "documents"
        conditions = [
            Match(key="category", value="tech"),
            MatchAny(key="language", any=["python", "rust"]),
        ]

        result = await qdrant_engine.semantic_search(
            embedding=vector,
            collection=collection,
            limit=5,
            conditions=conditions,
            threshold=0.8,
        )

        # Verify the search was called with transformed conditions
        assert mock_client.search.called
        call_args = mock_client.search.call_args
        assert call_args[1]["collection_name"] == collection
        assert call_args[1]["query_vector"] == vector
        assert call_args[1]["limit"] == 5
        assert call_args[1]["score_threshold"] == 0.8
        assert isinstance(call_args[1]["query_filter"], models.Filter)

        # Verify the response was transformed correctly
        assert isinstance(result, list)
        assert len(result) == 2
        assert all(isinstance(row, SearchRow) for row in result)
        assert result[0].chunk_id == "1"
        assert result[0].score == 0.95
        assert result[1].chunk_id == "2"
        assert result[1].score == 0.87

    def test_initialization_with_settings(self, mock_settings):
        """Test QdrantEngine initialization uses settings correctly"""
        with (
            patch(
                "vector_search_mcp.engine.qdrant_engine.Settings"
            ) as mock_settings_class,
            patch(
                "vector_search_mcp.engine.qdrant_engine.AsyncQdrantClient"
            ) as mock_client_class,
        ):
            mock_settings_class.return_value = mock_settings
            mock_client = AsyncMock()
            mock_client_class.return_value = mock_client

            engine = QdrantEngine()

            # Verify Settings was instantiated
            mock_settings_class.assert_called_once()

            # Verify AsyncQdrantClient was created with correct parameters
            mock_client_class.assert_called_once_with(
                url=mock_settings.url, api_key=mock_settings.api_key
            )

            assert engine.client == mock_client
            assert engine.settings == mock_settings

    @pytest.mark.asyncio
    async def test_client_search_exception_propagation(
        self, qdrant_engine, mock_client
    ):
        """Test that exceptions from client.search are properly propagated"""
        # Setup mock to raise an exception
        mock_client.search.side_effect = Exception("Qdrant connection failed")

        with pytest.raises(Exception, match="Qdrant connection failed"):
            await qdrant_engine.run_similarity_query(
                embedding=[0.1, 0.2, 0.3], collection="test_collection"
            )