forked from innovacion/searchbox
Add engine abstraction
This commit is contained in:
@@ -30,8 +30,8 @@ dev = [
|
|||||||
|
|
||||||
[tool.basedpyright]
|
[tool.basedpyright]
|
||||||
reportAny = false
|
reportAny = false
|
||||||
reportExplicitAny = false
|
|
||||||
enableTypeIgnoreComments = true
|
enableTypeIgnoreComments = true
|
||||||
|
reportUnreachable = false
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
|
|||||||
@@ -1,40 +0,0 @@
|
|||||||
from collections.abc import Sequence
|
|
||||||
from typing import Any, final
|
|
||||||
|
|
||||||
from qdrant_client import AsyncQdrantClient, models
|
|
||||||
|
|
||||||
from .config import Settings
|
|
||||||
from .models import SearchRow
|
|
||||||
|
|
||||||
|
|
||||||
@final
class QdrantEngine:
    """Async helper that runs vector similarity searches against Qdrant."""

    def __init__(self) -> None:
        """Build ``Settings`` and an ``AsyncQdrantClient`` from its url/api_key."""
        settings = Settings()  # type: ignore[reportCallIssue]
        self.settings = settings
        self.client = AsyncQdrantClient(url=settings.url, api_key=settings.api_key)

    async def semantic_search(
        self,
        embedding: Sequence[float] | models.NamedVector,
        collection: str,
        limit: int = 10,
        conditions: Any | None = None,
        threshold: float | None = None,
    ) -> list[SearchRow]:
        """Search ``collection`` for the points closest to ``embedding``.

        Args:
            embedding: Query vector, forwarded as ``query_vector``.
            collection: Name of the collection to search.
            limit: Maximum number of points requested from the client.
            conditions: Forwarded unchanged as ``query_filter``.
            threshold: Forwarded unchanged as ``score_threshold``.

        Returns:
            One ``SearchRow`` per returned point that carries a payload;
            points whose payload is ``None`` are skipped.
        """
        hits = await self.client.search(
            collection_name=collection,
            query_vector=embedding,
            query_filter=conditions,
            limit=limit,
            with_payload=True,
            with_vectors=False,
            score_threshold=threshold,
        )

        rows: list[SearchRow] = []
        for hit in hits:
            if hit.payload is None:
                continue
            rows.append(
                SearchRow(chunk_id=str(hit.id), score=hit.score, payload=hit.payload)
            )
        return rows
|
|
||||||
26
src/vector_search_mcp/engine/__init__.py
Normal file
26
src/vector_search_mcp/engine/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from enum import StrEnum
|
||||||
|
from typing import Literal, overload
|
||||||
|
|
||||||
|
from .qdrant_engine import QdrantEngine
|
||||||
|
|
||||||
|
|
||||||
|
class EngineType(StrEnum):
|
||||||
|
QDRANT = "qdrant"
|
||||||
|
COSMOS = "cosmos"
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_engine(backend: Literal[EngineType.QDRANT]) -> QdrantEngine: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_engine(backend: Literal[EngineType.COSMOS]) -> QdrantEngine: ...
|
||||||
|
|
||||||
|
|
||||||
|
def get_engine(backend: EngineType):
|
||||||
|
if backend == EngineType.QDRANT:
|
||||||
|
return QdrantEngine()
|
||||||
|
elif backend == EngineType.COSMOS:
|
||||||
|
raise NotImplementedError("Cosmos engine is not implemented yet")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown engine type: {backend}")
|
||||||
43
src/vector_search_mcp/engine/base_engine.py
Normal file
43
src/vector_search_mcp/engine/base_engine.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
from ..models import Condition, SearchRow
|
||||||
|
|
||||||
|
ResponseType = TypeVar("ResponseType")
|
||||||
|
ConditionType = TypeVar("ConditionType")
|
||||||
|
|
||||||
|
__all__ = ["BaseEngine"]
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEngine(ABC, Generic[ResponseType, ConditionType]):
    """Abstract search engine parameterised by its native response and filter types.

    Subclasses supply the three abstract hooks; ``semantic_search`` chains
    them: convert conditions, run the query, convert the response.
    """

    @abstractmethod
    def transform_conditions(
        self, conditions: list[Condition] | None
    ) -> ConditionType | None:
        """Convert generic ``Condition`` objects into the backend's filter type."""
        ...

    @abstractmethod
    def transform_response(self, response: ResponseType) -> list[SearchRow]:
        """Convert the backend's raw response into ``SearchRow`` objects."""
        ...

    @abstractmethod
    async def run_similarity_query(
        self,
        embedding: list[float],
        collection: str,
        limit: int = 10,
        conditions: ConditionType | None = None,
        threshold: float | None = None,
    ) -> ResponseType:
        """Run the similarity query on the backend and return its raw response."""
        ...

    async def semantic_search(
        self,
        vector: list[float],
        collection: str,
        limit: int = 10,
        conditions: list[Condition] | None = None,
        threshold: float | None = None,
    ) -> list[SearchRow]:
        """Run a similarity search and return backend-independent rows.

        Args:
            vector: Query embedding, passed to ``run_similarity_query``.
            collection: Name of the collection to search.
            limit: Result limit passed to ``run_similarity_query``.
            conditions: Generic conditions, converted by ``transform_conditions``.
            threshold: Score threshold passed to ``run_similarity_query``.

        Returns:
            The output of ``transform_response`` applied to the raw response.
        """
        backend_filter = self.transform_conditions(conditions)
        # Arguments are passed positionally, in the abstract method's order.
        raw_response = await self.run_similarity_query(
            vector, collection, limit, backend_filter, threshold
        )
        return self.transform_response(raw_response)
|
||||||
79
src/vector_search_mcp/engine/qdrant_engine.py
Normal file
79
src/vector_search_mcp/engine/qdrant_engine.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import final, override
|
||||||
|
|
||||||
|
from qdrant_client import AsyncQdrantClient, models
|
||||||
|
|
||||||
|
from ..config import Settings
|
||||||
|
from ..models import SearchRow, Condition, Match, MatchAny, MatchExclude
|
||||||
|
from .base_engine import BaseEngine
|
||||||
|
|
||||||
|
__all__ = ["QdrantEngine"]
|
||||||
|
|
||||||
|
|
||||||
|
@final
class QdrantEngine(BaseEngine[list[models.ScoredPoint], models.Filter]):
    """``BaseEngine`` implementation that queries Qdrant.

    The raw response type is ``list[models.ScoredPoint]`` and the native
    filter type is ``models.Filter``.
    """

    def __init__(self) -> None:
        """Build ``Settings`` and an ``AsyncQdrantClient`` from its url/api_key."""
        self.settings = Settings()  # type: ignore[reportCallIssue]
        self.client = AsyncQdrantClient(
            url=self.settings.url, api_key=self.settings.api_key
        )

    @override
    def transform_conditions(
        self, conditions: list[Condition] | None
    ) -> models.Filter | None:
        """Translate generic conditions into a Qdrant ``must`` filter.

        Args:
            conditions: Generic conditions; ``None`` or an empty list means
                "no filter".

        Returns:
            ``None`` when there are no conditions, otherwise a
            ``models.Filter`` whose ``must`` list holds one
            ``FieldCondition`` per input condition.

        Raises:
            TypeError: If a condition is not a ``Match``, ``MatchAny`` or
                ``MatchExclude``. Dropping it silently would run the search
                with a looser filter than the caller asked for.
        """
        if not conditions:
            return None

        filters: list[models.Condition] = []

        for condition in conditions:
            if isinstance(condition, Match):
                filters.append(
                    models.FieldCondition(
                        key=condition.key,
                        match=models.MatchValue(value=condition.value),
                    )
                )
            elif isinstance(condition, MatchAny):
                filters.append(
                    models.FieldCondition(
                        key=condition.key, match=models.MatchAny(any=condition.any)
                    )
                )
            elif isinstance(condition, MatchExclude):
                filters.append(
                    models.FieldCondition(
                        key=condition.key,
                        # "except" is a Python keyword, so it is passed via a dict.
                        match=models.MatchExcept(**{"except": condition.exclude}),
                    )
                )
            else:
                raise TypeError(
                    f"Unsupported condition type: {type(condition).__name__}"
                )

        return models.Filter(must=filters)

    @override
    def transform_response(self, response: list[models.ScoredPoint]) -> list[SearchRow]:
        """Convert scored points into ``SearchRow`` objects.

        Points whose payload is ``None`` are skipped; the point id is
        stringified into ``chunk_id``.
        """
        return [
            SearchRow(chunk_id=str(point.id), score=point.score, payload=point.payload)
            for point in response
            if point.payload is not None
        ]

    @override
    async def run_similarity_query(
        self,
        embedding: Sequence[float] | models.NamedVector,
        collection: str,
        limit: int = 10,
        conditions: models.Filter | None = None,
        threshold: float | None = None,
    ) -> list[models.ScoredPoint]:
        """Call the client's ``search`` and return its result unchanged.

        Args:
            embedding: Query vector, forwarded as ``query_vector``.
            collection: Name of the collection to search.
            limit: Maximum number of points requested.
            conditions: Qdrant filter, forwarded as ``query_filter``.
            threshold: Forwarded as ``score_threshold``.

        Returns:
            The scored points returned by the client, requested with
            payloads and without vectors.
        """
        return await self.client.search(
            collection_name=collection,
            query_vector=embedding,
            query_filter=conditions,
            limit=limit,
            with_payload=True,
            with_vectors=False,
            score_threshold=threshold,
        )
|
||||||
@@ -6,4 +6,22 @@ from pydantic import BaseModel
|
|||||||
class SearchRow(BaseModel):
    """A single search hit in backend-independent form."""

    # Identifier of the matched chunk, as a string.
    chunk_id: str
    # Similarity score reported for the hit.
    score: float
    # Payload stored alongside the hit.
    payload: dict[str, Any]  # type: ignore[reportExplicitAny]
|
||||||
|
|
||||||
|
|
||||||
|
class Condition(BaseModel):
    """Common base class for filter conditions; declares no fields itself."""
|
||||||
|
|
||||||
|
|
||||||
|
class Match(Condition):
    """Condition pairing a payload key with a single string value."""

    # Name of the field the condition applies to.
    key: str
    # Value associated with the key.
    value: str
|
||||||
|
|
||||||
|
|
||||||
|
class MatchAny(Condition):
    """Condition pairing a payload key with a list of candidate values."""

    # Name of the field the condition applies to.
    key: str
    # Candidate values; the field name mirrors the builtin ``any`` by design
    # of the model, so it must be accessed as an attribute.
    any: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class MatchExclude(Condition):
    """Condition pairing a payload key with a list of values to exclude."""

    # Name of the field the condition applies to.
    key: str
    # Values to exclude for that field.
    exclude: list[str]
|
||||||
|
|||||||
Reference in New Issue
Block a user