forked from innovacion/searchbox
Add engine abstraction
This commit is contained in:
79
src/vector_search_mcp/engine/qdrant_engine.py
Normal file
79
src/vector_search_mcp/engine/qdrant_engine.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import final, override
|
||||
|
||||
from qdrant_client import AsyncQdrantClient, models
|
||||
|
||||
from ..config import Settings
|
||||
from ..models import SearchRow, Condition, Match, MatchAny, MatchExclude
|
||||
from .base_engine import BaseEngine
|
||||
|
||||
__all__ = ["QdrantEngine"]
|
||||
|
||||
|
||||
@final
|
||||
class QdrantEngine(BaseEngine[list[models.ScoredPoint], models.Filter]):
|
||||
def __init__(self) -> None:
|
||||
self.settings = Settings() # type: ignore[reportCallArgs]
|
||||
self.client = AsyncQdrantClient(
|
||||
url=self.settings.url, api_key=self.settings.api_key
|
||||
)
|
||||
|
||||
@override
|
||||
def transform_conditions(
|
||||
self, conditions: list[Condition] | None
|
||||
) -> models.Filter | None:
|
||||
if not conditions:
|
||||
return None
|
||||
|
||||
filters: list[models.Condition] = []
|
||||
|
||||
for condition in conditions:
|
||||
if isinstance(condition, Match):
|
||||
filters.append(
|
||||
models.FieldCondition(
|
||||
key=condition.key,
|
||||
match=models.MatchValue(value=condition.value),
|
||||
)
|
||||
)
|
||||
elif isinstance(condition, MatchAny):
|
||||
filters.append(
|
||||
models.FieldCondition(
|
||||
key=condition.key, match=models.MatchAny(any=condition.any)
|
||||
)
|
||||
)
|
||||
elif isinstance(condition, MatchExclude):
|
||||
filters.append(
|
||||
models.FieldCondition(
|
||||
key=condition.key,
|
||||
match=models.MatchExcept(**{"except": condition.exclude}),
|
||||
)
|
||||
)
|
||||
|
||||
return models.Filter(must=filters)
|
||||
|
||||
@override
|
||||
def transform_response(self, response: list[models.ScoredPoint]) -> list[SearchRow]:
|
||||
return [
|
||||
SearchRow(chunk_id=str(point.id), score=point.score, payload=point.payload)
|
||||
for point in response
|
||||
if point.payload is not None
|
||||
]
|
||||
|
||||
@override
|
||||
async def run_similarity_query(
|
||||
self,
|
||||
embedding: Sequence[float] | models.NamedVector,
|
||||
collection: str,
|
||||
limit: int = 10,
|
||||
conditions: models.Filter | None = None,
|
||||
threshold: float | None = None,
|
||||
) -> list[models.ScoredPoint]:
|
||||
return await self.client.search(
|
||||
collection_name=collection,
|
||||
query_vector=embedding,
|
||||
query_filter=conditions,
|
||||
limit=limit,
|
||||
with_payload=True,
|
||||
with_vectors=False,
|
||||
score_threshold=threshold,
|
||||
)
|
||||
Reference in New Issue
Block a user