Add engine abstraction

This commit is contained in:
2025-09-26 14:38:44 +00:00
parent de9826a4b6
commit 0656ed93f1
6 changed files with 168 additions and 42 deletions

View File

@@ -1,40 +0,0 @@
from collections.abc import Sequence
from typing import Any, final
from qdrant_client import AsyncQdrantClient, models
from .config import Settings
from .models import SearchRow
@final
class QdrantEngine:
    """Search engine backed by an asynchronous Qdrant client."""

    def __init__(self) -> None:
        """Load ``Settings`` and build the client from its URL and API key."""
        self.settings = Settings()  # type: ignore[reportCallIssue]
        self.client = AsyncQdrantClient(
            url=self.settings.url,
            api_key=self.settings.api_key,
        )

    async def semantic_search(
        self,
        embedding: Sequence[float] | models.NamedVector,
        collection: str,
        limit: int = 10,
        conditions: Any | None = None,
        threshold: float | None = None,
    ) -> list[SearchRow]:
        """Run a vector search and return the hits that carry a payload.

        Args:
            embedding: Query vector, forwarded as ``query_vector``.
            collection: Name of the collection to search.
            limit: Forwarded to the client as ``limit``.
            conditions: Forwarded unchanged as ``query_filter``.
            threshold: Forwarded as ``score_threshold``.

        Returns:
            One ``SearchRow`` per returned point whose payload is not ``None``.
        """
        hits = await self.client.search(
            collection_name=collection,
            query_vector=embedding,
            query_filter=conditions,
            limit=limit,
            with_payload=True,
            with_vectors=False,
            score_threshold=threshold,
        )

        rows: list[SearchRow] = []
        for hit in hits:
            if hit.payload is None:
                # A row cannot be built without a payload, so skip the hit.
                continue
            rows.append(
                SearchRow(chunk_id=str(hit.id), score=hit.score, payload=hit.payload)
            )
        return rows

View File

@@ -0,0 +1,26 @@
from enum import StrEnum
from typing import Literal, NoReturn, overload

from .qdrant_engine import QdrantEngine
class EngineType(StrEnum):
QDRANT = "qdrant"
COSMOS = "cosmos"
@overload
def get_engine(backend: Literal[EngineType.QDRANT]) -> QdrantEngine: ...
@overload
def get_engine(backend: Literal[EngineType.COSMOS]) -> NoReturn: ...
def get_engine(backend: EngineType) -> QdrantEngine:
    """Create the search engine for the requested backend.

    Args:
        backend: Which backend to build an engine for.

    Returns:
        A new ``QdrantEngine`` when ``backend`` is ``EngineType.QDRANT``.

    Raises:
        NotImplementedError: If ``backend`` is ``EngineType.COSMOS``; that
            overload is typed ``NoReturn`` because the call never returns.
        ValueError: If ``backend`` is not a known engine type.
    """
    if backend == EngineType.QDRANT:
        return QdrantEngine()
    if backend == EngineType.COSMOS:
        raise NotImplementedError("Cosmos engine is not implemented yet")
    raise ValueError(f"Unknown engine type: {backend}")

View File

@@ -0,0 +1,43 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from ..models import Condition, SearchRow
ResponseType = TypeVar("ResponseType")
ConditionType = TypeVar("ConditionType")
__all__ = ["BaseEngine"]
class BaseEngine(ABC, Generic[ResponseType, ConditionType]):
    """Abstract base for search engines.

    Subclasses provide three backend-specific steps; ``semantic_search``
    chains them. ``ResponseType`` is the backend's raw query result and
    ``ConditionType`` is its filter representation.
    """

    @abstractmethod
    def transform_conditions(
        self, conditions: list[Condition] | None
    ) -> ConditionType | None:
        """Convert ``Condition`` objects into the backend's filter type."""

    @abstractmethod
    def transform_response(self, response: ResponseType) -> list[SearchRow]:
        """Convert the backend's raw response into ``SearchRow`` objects."""

    @abstractmethod
    async def run_similarity_query(
        self,
        embedding: list[float],
        collection: str,
        limit: int = 10,
        conditions: ConditionType | None = None,
        threshold: float | None = None,
    ) -> ResponseType:
        """Execute the similarity query against the backend."""

    async def semantic_search(
        self,
        vector: list[float],
        collection: str,
        limit: int = 10,
        conditions: list[Condition] | None = None,
        threshold: float | None = None,
    ) -> list[SearchRow]:
        """Search ``collection`` for entries similar to ``vector``.

        Args:
            vector: Query embedding passed to ``run_similarity_query``.
            collection: Name of the collection to search.
            limit: Passed through to ``run_similarity_query``.
            conditions: Conditions, converted with ``transform_conditions``
                before the query runs.
            threshold: Passed through to ``run_similarity_query``.

        Returns:
            The result of ``transform_response`` applied to the raw response.
        """
        backend_filter = self.transform_conditions(conditions)
        # Arguments are forwarded positionally, so a subclass may name its
        # parameters differently without breaking this call.
        raw = await self.run_similarity_query(
            vector, collection, limit, backend_filter, threshold
        )
        return self.transform_response(raw)

View File

@@ -0,0 +1,79 @@
from collections.abc import Sequence
from typing import final, override
from qdrant_client import AsyncQdrantClient, models
from ..config import Settings
from ..models import SearchRow, Condition, Match, MatchAny, MatchExclude
from .base_engine import BaseEngine
__all__ = ["QdrantEngine"]
@final
class QdrantEngine(BaseEngine[list[models.ScoredPoint], models.Filter]):
    """``BaseEngine`` implementation that queries Qdrant."""

    def __init__(self) -> None:
        """Load ``Settings`` and build the async client from its URL and API key."""
        self.settings = Settings()  # type: ignore[reportCallArgs]
        self.client = AsyncQdrantClient(
            url=self.settings.url, api_key=self.settings.api_key
        )

    @staticmethod
    def _to_field_condition(condition: Condition) -> models.FieldCondition:
        """Translate one ``Condition`` into a Qdrant ``FieldCondition``.

        Raises:
            TypeError: If ``condition`` is not a ``Match``, ``MatchAny`` or
                ``MatchExclude`` instance.
        """
        if isinstance(condition, Match):
            return models.FieldCondition(
                key=condition.key,
                match=models.MatchValue(value=condition.value),
            )
        if isinstance(condition, MatchAny):
            return models.FieldCondition(
                key=condition.key, match=models.MatchAny(any=condition.any)
            )
        if isinstance(condition, MatchExclude):
            return models.FieldCondition(
                key=condition.key,
                # ``except`` is a Python keyword, so it cannot be written as a
                # normal keyword argument and is passed via dict unpacking.
                match=models.MatchExcept(**{"except": condition.exclude}),
            )
        # Skipping an unknown condition would silently weaken the filter and
        # return rows the caller asked to exclude, so fail loudly instead.
        raise TypeError(f"Unsupported condition type: {type(condition).__name__}")

    @override
    def transform_conditions(
        self, conditions: list[Condition] | None
    ) -> models.Filter | None:
        """Build a Qdrant filter in which every condition must hold.

        Args:
            conditions: Conditions to translate; ``None`` or an empty list
                means no filtering.

        Returns:
            ``None`` when there are no conditions, otherwise a
            ``models.Filter`` whose ``must`` list holds one field condition
            per input condition.

        Raises:
            TypeError: If a condition is of an unsupported type.
        """
        if not conditions:
            return None
        must: list[models.Condition] = [
            self._to_field_condition(condition) for condition in conditions
        ]
        return models.Filter(must=must)

    @override
    def transform_response(self, response: list[models.ScoredPoint]) -> list[SearchRow]:
        """Convert scored points into ``SearchRow`` objects.

        Points whose payload is ``None`` are left out of the result.
        """
        return [
            SearchRow(chunk_id=str(point.id), score=point.score, payload=point.payload)
            for point in response
            if point.payload is not None
        ]

    @override
    async def run_similarity_query(
        self,
        embedding: Sequence[float] | models.NamedVector,
        collection: str,
        limit: int = 10,
        conditions: models.Filter | None = None,
        threshold: float | None = None,
    ) -> list[models.ScoredPoint]:
        """Run the vector search against Qdrant.

        Args:
            embedding: Query vector, forwarded as ``query_vector``.
            collection: Name of the collection to search.
            limit: Forwarded to the client as ``limit``.
            conditions: Filter forwarded as ``query_filter``.
            threshold: Forwarded as ``score_threshold``.

        Returns:
            The scored points returned by the client. Payloads are requested
            and vectors are not.
        """
        return await self.client.search(
            collection_name=collection,
            query_vector=embedding,
            query_filter=conditions,
            limit=limit,
            with_payload=True,
            with_vectors=False,
            score_threshold=threshold,
        )

View File

@@ -6,4 +6,22 @@ from pydantic import BaseModel
class SearchRow(BaseModel):
    """One search hit returned by an engine."""

    # Identifier of the hit; the engines fill it with ``str(point.id)``.
    chunk_id: str
    # Score reported by the backend for this hit.
    score: float
    # Payload stored with the hit. Declared once: the earlier duplicate
    # declaration without the ignore comment was redundant.
    payload: dict[str, Any]  # type: ignore[reportExplicitAny]
class Condition(BaseModel):
    """Common base class of the search conditions; declares no fields."""
class Match(Condition):
    """Condition pairing a payload key with a single value to match."""

    # Name of the field the condition applies to.
    key: str
    # Value the field is matched against.
    value: str
class MatchAny(Condition):
    """Condition pairing a payload key with a list of accepted values."""

    # Name of the field the condition applies to.
    key: str
    # Candidate values for the field.
    any: list[str]
class MatchExclude(Condition):
    """Condition pairing a payload key with a list of values to exclude."""

    # Name of the field the condition applies to.
    key: str
    # Values the field must not take.
    exclude: list[str]