diff --git a/pyproject.toml b/pyproject.toml
index 4ccbcee..474162b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,8 +30,8 @@ dev = [
 
 [tool.basedpyright]
 reportAny = false
-reportExplicitAny = false
 enableTypeIgnoreComments = true
+reportUnreachable = false
 
 [tool.pytest.ini_options]
 asyncio_mode = "auto"
diff --git a/src/vector_search_mcp/engine.py b/src/vector_search_mcp/engine.py
deleted file mode 100644
index dabcc18..0000000
--- a/src/vector_search_mcp/engine.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from collections.abc import Sequence
-from typing import Any, final
-
-from qdrant_client import AsyncQdrantClient, models
-
-from .config import Settings
-from .models import SearchRow
-
-
-@final
-class QdrantEngine:
-    def __init__(self) -> None:
-        self.settings = Settings()  # type: ignore[reportCallIssue]
-        self.client = AsyncQdrantClient(
-            url=self.settings.url, api_key=self.settings.api_key
-        )
-
-    async def semantic_search(
-        self,
-        embedding: Sequence[float] | models.NamedVector,
-        collection: str,
-        limit: int = 10,
-        conditions: Any | None = None,
-        threshold: float | None = None,
-    ) -> list[SearchRow]:
-        points = await self.client.search(
-            collection_name=collection,
-            query_vector=embedding,
-            query_filter=conditions,
-            limit=limit,
-            with_payload=True,
-            with_vectors=False,
-            score_threshold=threshold,
-        )
-
-        return [
-            SearchRow(chunk_id=str(point.id), score=point.score, payload=point.payload)
-            for point in points
-            if point.payload is not None
-        ]
diff --git a/src/vector_search_mcp/engine/__init__.py b/src/vector_search_mcp/engine/__init__.py
new file mode 100644
index 0000000..9dbf918
--- /dev/null
+++ b/src/vector_search_mcp/engine/__init__.py
@@ -0,0 +1,43 @@
+"""Engine factory: maps an ``EngineType`` to a concrete search engine."""
+
+from enum import StrEnum
+from typing import Literal, NoReturn, overload
+
+from .qdrant_engine import QdrantEngine
+
+
+class EngineType(StrEnum):
+    """Backends that ``get_engine`` recognises."""
+
+    QDRANT = "qdrant"
+    COSMOS = "cosmos"
+
+
+@overload
+def get_engine(backend: Literal[EngineType.QDRANT]) -> QdrantEngine: ...
+
+
+@overload
+def get_engine(backend: Literal[EngineType.COSMOS]) -> NoReturn: ...
+
+
+def get_engine(backend: EngineType) -> QdrantEngine:
+    """Instantiate the engine that serves ``backend``.
+
+    Args:
+        backend: Which backend to build an engine for.
+
+    Returns:
+        A new ``QdrantEngine`` when ``backend`` is ``EngineType.QDRANT``.
+
+    Raises:
+        NotImplementedError: ``backend`` is ``EngineType.COSMOS``; no Cosmos
+            engine exists yet.
+        ValueError: ``backend`` is none of the known engine types.
+    """
+    if backend == EngineType.QDRANT:
+        return QdrantEngine()
+    if backend == EngineType.COSMOS:
+        raise NotImplementedError("Cosmos engine is not implemented yet")
+    # Anything else is not a known EngineType member.
+    raise ValueError(f"Unknown engine type: {backend}")
diff --git a/src/vector_search_mcp/engine/base_engine.py b/src/vector_search_mcp/engine/base_engine.py
new file mode 100644
index 0000000..0cb2b3e
--- /dev/null
+++ b/src/vector_search_mcp/engine/base_engine.py
@@ -0,0 +1,70 @@
+from abc import ABC, abstractmethod
+from typing import Generic, TypeVar
+
+from ..models import Condition, SearchRow
+
+# Raw result type returned by a backend's similarity query.
+ResponseType = TypeVar("ResponseType")
+# Backend-native filter type produced from the generic conditions.
+ConditionType = TypeVar("ConditionType")
+
+__all__ = ["BaseEngine"]
+
+
+class BaseEngine(ABC, Generic[ResponseType, ConditionType]):
+    """Template for a vector-search backend.
+
+    Subclasses supply three backend-specific steps; ``semantic_search`` chains
+    them: translate conditions, run the query, normalise the response.
+    """
+
+    @abstractmethod
+    def transform_conditions(
+        self, conditions: list[Condition] | None
+    ) -> ConditionType | None:
+        """Translate generic conditions into the backend's filter type."""
+        ...
+
+    @abstractmethod
+    def transform_response(self, response: ResponseType) -> list[SearchRow]:
+        """Convert the backend's raw response into ``SearchRow`` objects."""
+        ...
+
+    @abstractmethod
+    async def run_similarity_query(
+        self,
+        embedding: list[float],
+        collection: str,
+        limit: int = 10,
+        conditions: ConditionType | None = None,
+        threshold: float | None = None,
+    ) -> ResponseType:
+        """Query the backend and return its raw, untransformed response."""
+        ...
+
+    async def semantic_search(
+        self,
+        vector: list[float],
+        collection: str,
+        limit: int = 10,
+        conditions: list[Condition] | None = None,
+        threshold: float | None = None,
+    ) -> list[SearchRow]:
+        """Search ``collection`` for rows similar to ``vector``.
+
+        Args:
+            vector: Query embedding.
+            collection: Name of the collection to search.
+            limit: Forwarded to ``run_similarity_query`` as ``limit``.
+            conditions: Generic filter conditions; translated with
+                ``transform_conditions`` before the query runs.
+            threshold: Forwarded to ``run_similarity_query`` as ``threshold``.
+
+        Returns:
+            The backend response converted by ``transform_response``.
+        """
+        backend_filter = self.transform_conditions(conditions)
+        raw = await self.run_similarity_query(
+            vector, collection, limit, backend_filter, threshold
+        )
+        return self.transform_response(raw)
diff --git a/src/vector_search_mcp/engine/qdrant_engine.py b/src/vector_search_mcp/engine/qdrant_engine.py
new file mode 100644
index 0000000..79c8739
--- /dev/null
+++ b/src/vector_search_mcp/engine/qdrant_engine.py
@@ -0,0 +1,112 @@
+from collections.abc import Sequence
+from typing import final, override
+
+from qdrant_client import AsyncQdrantClient, models
+
+from ..config import Settings
+from ..models import Condition, Match, MatchAny, MatchExclude, SearchRow
+from .base_engine import BaseEngine
+
+__all__ = ["QdrantEngine"]
+
+
+@final
+class QdrantEngine(BaseEngine[list[models.ScoredPoint], models.Filter]):
+    """``BaseEngine`` implementation that queries Qdrant."""
+
+    def __init__(self) -> None:
+        """Load ``Settings`` and build the async client from its URL and API key."""
+        self.settings = Settings()  # type: ignore[reportCallIssue]
+        self.client = AsyncQdrantClient(
+            url=self.settings.url, api_key=self.settings.api_key
+        )
+
+    @staticmethod
+    def _to_field_condition(condition: Condition) -> models.FieldCondition:
+        """Translate one generic condition into a Qdrant field condition.
+
+        Raises:
+            TypeError: ``condition`` is not a supported ``Condition`` subclass.
+        """
+        if isinstance(condition, Match):
+            matcher = models.MatchValue(value=condition.value)
+        elif isinstance(condition, MatchAny):
+            matcher = models.MatchAny(any=condition.any)
+        elif isinstance(condition, MatchExclude):
+            # "except" is a Python keyword, hence the ** unpacking.
+            matcher = models.MatchExcept(**{"except": condition.exclude})
+        else:
+            # Dropping an unknown condition would silently widen the search.
+            raise TypeError(
+                f"Unsupported condition type: {type(condition).__name__}"
+            )
+        return models.FieldCondition(key=condition.key, match=matcher)
+
+    @override
+    def transform_conditions(
+        self, conditions: list[Condition] | None
+    ) -> models.Filter | None:
+        """Combine ``conditions`` into one Qdrant filter.
+
+        Returns:
+            ``None`` when ``conditions`` is ``None`` or empty; otherwise a
+            ``models.Filter`` whose ``must`` list holds every translated
+            condition.
+
+        Raises:
+            TypeError: A condition is not a supported ``Condition`` subclass.
+        """
+        if not conditions:
+            return None
+
+        must: list[models.Condition] = [
+            self._to_field_condition(condition) for condition in conditions
+        ]
+        return models.Filter(must=must)
+
+    @override
+    def transform_response(self, response: list[models.ScoredPoint]) -> list[SearchRow]:
+        """Convert scored points into ``SearchRow`` objects.
+
+        Points whose payload is ``None`` are left out of the result.
+        """
+        return [
+            SearchRow(chunk_id=str(point.id), score=point.score, payload=point.payload)
+            for point in response
+            if point.payload is not None
+        ]
+
+    @override
+    async def run_similarity_query(
+        self,
+        embedding: Sequence[float] | models.NamedVector,
+        collection: str,
+        limit: int = 10,
+        conditions: models.Filter | None = None,
+        threshold: float | None = None,
+    ) -> list[models.ScoredPoint]:
+        """Search ``collection`` for the points closest to ``embedding``.
+
+        Payloads are requested and stored vectors are not.
+
+        Args:
+            embedding: Query vector, passed as ``query_vector``.
+            collection: Passed as ``collection_name``.
+            limit: Passed as ``limit``.
+            conditions: Passed as ``query_filter``.
+            threshold: Passed as ``score_threshold``.
+
+        Returns:
+            The scored points returned by the client.
+        """
+        # NOTE(review): newer qdrant-client releases reportedly deprecate
+        # ``search`` in favour of ``query_points`` -- confirm before upgrading.
+        return await self.client.search(
+            collection_name=collection,
+            query_vector=embedding,
+            query_filter=conditions,
+            limit=limit,
+            with_payload=True,
+            with_vectors=False,
+            score_threshold=threshold,
+        )
diff --git a/src/vector_search_mcp/models.py b/src/vector_search_mcp/models.py
index cf6b84c..4727ddc 100644
--- a/src/vector_search_mcp/models.py
+++ b/src/vector_search_mcp/models.py
@@ -6,4 +6,27 @@ from pydantic import BaseModel
 class SearchRow(BaseModel):
     chunk_id: str
     score: float
-    payload: dict[str, Any]
+    payload: dict[str, Any]  # type: ignore[reportExplicitAny]
+
+
+# Common base for filter conditions; each engine translates the subclasses
+# below into its own filter type.
+class Condition(BaseModel): ...
+
+
+# Single-value condition (QdrantEngine maps it to ``models.MatchValue``).
+class Match(Condition):
+    key: str
+    value: str
+
+
+# Value-list condition (QdrantEngine maps it to ``models.MatchAny``).
+class MatchAny(Condition):
+    key: str
+    any: list[str]
+
+
+# Exclusion-list condition (QdrantEngine maps it to ``models.MatchExcept``).
+class MatchExclude(Condition):
+    key: str
+    exclude: list[str]