forked from innovacion/searchbox
This commit renames the package from vector-search-mcp to searchbox. The package imports and executable name are updated accordingly.
260 lines · 10 KiB · Python
"""Tests for the vector search client module."""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from searchbox.client import Client
|
|
from searchbox.engine import Backend
|
|
from searchbox.models import Chunk, ChunkData, Match
|
|
|
|
|
|
@pytest.fixture
def mock_engine():
    """Provide an ``AsyncMock`` engine whose operations report success.

    ``create_index`` and ``upload_chunk`` resolve to ``True``; ``semantic_search``
    resolves to two canned hits listed in descending score order.
    """
    canned_hits = [
        {"chunk_id": "1", "score": 0.95, "payload": {"text": "result 1"}},
        {"chunk_id": "2", "score": 0.85, "payload": {"text": "result 2"}},
    ]

    fake_engine = AsyncMock()
    fake_engine.create_index.return_value = True
    fake_engine.upload_chunk.return_value = True
    fake_engine.semantic_search.return_value = canned_hits
    return fake_engine
|
|
|
|
|
|
@pytest.fixture
def sample_chunk():
    """Provide a small ``Chunk`` with a 5-element vector and page-1 PDF metadata."""
    chunk_payload = ChunkData(
        page_content="This is a test chunk content",
        filename="test_document.pdf",
        page=1,
    )
    chunk_vector = [0.1, 0.2, 0.3, 0.4, 0.5]
    return Chunk(id="test-chunk-1", vector=chunk_vector, payload=chunk_payload)
|
|
|
|
|
|
class TestClient:
    """Test suite for the Client class.

    Each test swaps ``searchbox.client.get_engine`` for a mock that returns the
    ``mock_engine`` fixture, so the client runs against a mock instead of a real
    backend. The engine's methods are ``AsyncMock`` coroutines, so delegation is
    verified with ``assert_awaited_once_with``, which proves the coroutine was
    both called with the expected arguments and actually awaited.
    """

    @staticmethod
    def _patch_get_engine(monkeypatch, engine):
        """Patch ``searchbox.client.get_engine`` so it returns *engine*.

        Returns the ``MagicMock`` installed in place of ``get_engine`` so a
        test can assert how the backend was resolved.
        """
        get_engine = MagicMock(return_value=engine)
        monkeypatch.setattr("searchbox.client.get_engine", get_engine)
        return get_engine

    def test_client_initialization(self, mock_engine, monkeypatch):
        """Test that Client initializes correctly with backend and collection."""
        mock_get_engine = self._patch_get_engine(monkeypatch, mock_engine)

        client = Client(backend=Backend.QDRANT, collection="test_collection")

        assert client.collection == "test_collection"
        # Identity, not equality: the client must keep the exact engine object.
        assert client.engine is mock_engine
        mock_get_engine.assert_called_once_with(Backend.QDRANT)

    @pytest.mark.asyncio
    async def test_create_index(self, mock_engine, monkeypatch):
        """Test create_index method delegates to engine."""
        self._patch_get_engine(monkeypatch, mock_engine)

        client = Client(backend=Backend.QDRANT, collection="test_collection")
        result = await client.create_index(size=512)

        assert result is True
        mock_engine.create_index.assert_awaited_once_with("test_collection", 512)

    @pytest.mark.asyncio
    async def test_create_index_failure(self, mock_engine, monkeypatch):
        """Test create_index method handles failure."""
        mock_engine.create_index.return_value = False
        self._patch_get_engine(monkeypatch, mock_engine)

        client = Client(backend=Backend.QDRANT, collection="test_collection")
        result = await client.create_index(size=256)

        assert result is False
        mock_engine.create_index.assert_awaited_once_with("test_collection", 256)

    @pytest.mark.asyncio
    async def test_upload_chunk(self, mock_engine, monkeypatch, sample_chunk):
        """Test upload_chunk method delegates to engine."""
        self._patch_get_engine(monkeypatch, mock_engine)

        client = Client(backend=Backend.QDRANT, collection="documents")
        result = await client.upload_chunk(sample_chunk)

        assert result is True
        mock_engine.upload_chunk.assert_awaited_once_with("documents", sample_chunk)

    @pytest.mark.asyncio
    async def test_upload_chunk_failure(self, mock_engine, monkeypatch, sample_chunk):
        """Test upload_chunk method handles failure."""
        mock_engine.upload_chunk.return_value = False
        self._patch_get_engine(monkeypatch, mock_engine)

        client = Client(backend=Backend.QDRANT, collection="documents")
        result = await client.upload_chunk(sample_chunk)

        assert result is False
        mock_engine.upload_chunk.assert_awaited_once_with("documents", sample_chunk)

    @pytest.mark.asyncio
    async def test_semantic_search_default_parameters(self, mock_engine, monkeypatch):
        """Test semantic_search with default parameters."""
        self._patch_get_engine(monkeypatch, mock_engine)

        client = Client(backend=Backend.QDRANT, collection="search_collection")
        embedding = [0.1, 0.2, 0.3, 0.4, 0.5]

        result = await client.semantic_search(embedding)

        # Defaults forwarded by the client: limit=10, conditions=None, threshold=None.
        mock_engine.semantic_search.assert_awaited_once_with(
            embedding, "search_collection", 10, None, None
        )
        assert result == mock_engine.semantic_search.return_value

    @pytest.mark.asyncio
    async def test_semantic_search_with_limit(self, mock_engine, monkeypatch):
        """Test semantic_search with custom limit."""
        self._patch_get_engine(monkeypatch, mock_engine)

        client = Client(backend=Backend.QDRANT, collection="search_collection")
        embedding = [0.1, 0.2, 0.3]

        result = await client.semantic_search(embedding, limit=5)

        mock_engine.semantic_search.assert_awaited_once_with(
            embedding, "search_collection", 5, None, None
        )
        assert result == mock_engine.semantic_search.return_value

    @pytest.mark.asyncio
    async def test_semantic_search_with_conditions(self, mock_engine, monkeypatch):
        """Test semantic_search with filter conditions."""
        self._patch_get_engine(monkeypatch, mock_engine)

        client = Client(backend=Backend.QDRANT, collection="filtered_collection")
        embedding = [0.1, 0.2, 0.3, 0.4]
        conditions = [Match(key="category", value="technology")]

        result = await client.semantic_search(embedding, conditions=conditions)

        mock_engine.semantic_search.assert_awaited_once_with(
            embedding, "filtered_collection", 10, conditions, None
        )
        assert result == mock_engine.semantic_search.return_value

    @pytest.mark.asyncio
    async def test_semantic_search_with_threshold(self, mock_engine, monkeypatch):
        """Test semantic_search with similarity threshold."""
        self._patch_get_engine(monkeypatch, mock_engine)

        client = Client(backend=Backend.QDRANT, collection="threshold_collection")
        embedding = [0.5, 0.4, 0.3, 0.2, 0.1]

        result = await client.semantic_search(embedding, threshold=0.8)

        mock_engine.semantic_search.assert_awaited_once_with(
            embedding, "threshold_collection", 10, None, 0.8
        )
        assert result == mock_engine.semantic_search.return_value

    @pytest.mark.asyncio
    async def test_semantic_search_all_parameters(self, mock_engine, monkeypatch):
        """Test semantic_search with all parameters specified."""
        self._patch_get_engine(monkeypatch, mock_engine)

        client = Client(backend=Backend.QDRANT, collection="full_params_collection")
        embedding = [0.2, 0.4, 0.6, 0.8, 1.0]
        conditions = [
            Match(key="status", value="published"),
            Match(key="author", value="john_doe"),
        ]

        result = await client.semantic_search(
            embedding=embedding,
            limit=3,
            conditions=conditions,
            threshold=0.75,
        )

        # Keyword arguments on the client are forwarded positionally to the engine.
        mock_engine.semantic_search.assert_awaited_once_with(
            embedding, "full_params_collection", 3, conditions, 0.75
        )
        assert result == mock_engine.semantic_search.return_value

    def test_client_is_final(self):
        """Test that Client class is marked as final."""
        # typing.final (Python 3.11+) sets ``__final__ = True`` on the decorated
        # class. Read it from Client's own namespace so an inherited attribute
        # cannot satisfy the check, and require True rather than mere presence.
        assert vars(Client).get("__final__", False) is True

    @pytest.mark.asyncio
    async def test_client_integration_workflow(self, mock_engine, monkeypatch, sample_chunk):
        """Test a complete workflow: create index, upload chunk, search."""
        self._patch_get_engine(monkeypatch, mock_engine)

        client = Client(backend=Backend.QDRANT, collection="workflow_test")

        # Create index
        index_result = await client.create_index(size=384)
        assert index_result is True

        # Upload chunk
        upload_result = await client.upload_chunk(sample_chunk)
        assert upload_result is True

        # Search
        search_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]
        search_result = await client.semantic_search(search_embedding, limit=5)

        # Verify every operation was delegated (and awaited) exactly once.
        mock_engine.create_index.assert_awaited_once_with("workflow_test", 384)
        mock_engine.upload_chunk.assert_awaited_once_with("workflow_test", sample_chunk)
        mock_engine.semantic_search.assert_awaited_once_with(
            search_embedding, "workflow_test", 5, None, None
        )
        assert search_result == mock_engine.semantic_search.return_value

    @pytest.mark.asyncio
    async def test_semantic_search_empty_results(self, mock_engine, monkeypatch):
        """Test semantic_search when no results are found."""
        mock_engine.semantic_search.return_value = []
        self._patch_get_engine(monkeypatch, mock_engine)

        client = Client(backend=Backend.QDRANT, collection="empty_results")
        embedding = [0.1, 0.2, 0.3]

        result = await client.semantic_search(embedding)

        assert result == []
        mock_engine.semantic_search.assert_awaited_once_with(
            embedding, "empty_results", 10, None, None
        )

    def test_client_attributes_after_init(self, mock_engine, monkeypatch):
        """Test that client has the expected attributes after initialization."""
        self._patch_get_engine(monkeypatch, mock_engine)

        client = Client(backend=Backend.QDRANT, collection="attr_test")

        assert hasattr(client, "engine")
        assert hasattr(client, "collection")
        assert client.engine is mock_engine
        assert client.collection == "attr_test"
        # The public operations must exist and be callable, not merely present.
        for method_name in ("create_index", "upload_chunk", "semantic_search"):
            assert callable(getattr(client, method_name, None)), method_name
|