Add Backend enum

This commit is contained in:
2025-09-26 15:13:22 +00:00
parent 0656ed93f1
commit a91188e83f
4 changed files with 13 additions and 11 deletions

View File

@@ -1,26 +1,28 @@
from enum import StrEnum
from functools import cache
from typing import Literal, overload
from .qdrant_engine import QdrantEngine
class EngineType(StrEnum):
class Backend(StrEnum):
QDRANT = "qdrant"
COSMOS = "cosmos"
@overload
def get_engine(backend: Literal[EngineType.QDRANT]) -> QdrantEngine: ...
def get_engine(backend: Literal[Backend.QDRANT]) -> QdrantEngine: ...
@overload
def get_engine(backend: Literal[EngineType.COSMOS]) -> QdrantEngine: ...
def get_engine(backend: Literal[Backend.COSMOS]) -> QdrantEngine: ...
def get_engine(backend: EngineType):
if backend == EngineType.QDRANT:
@cache
def get_engine(backend: Backend):
if backend == Backend.QDRANT:
return QdrantEngine()
elif backend == EngineType.COSMOS:
elif backend == Backend.COSMOS:
raise NotImplementedError("Cosmos engine is not implemented yet")
else:
raise ValueError(f"Unknown engine type: {backend}")

View File

@@ -30,7 +30,7 @@ class BaseEngine(ABC, Generic[ResponseType, ConditionType]):
async def semantic_search(
self,
vector: list[float],
embedding: list[float],
collection: str,
limit: int = 10,
conditions: list[Condition] | None = None,
@@ -38,6 +38,6 @@ class BaseEngine(ABC, Generic[ResponseType, ConditionType]):
) -> list[SearchRow]:
transformed_conditions = self.transform_conditions(conditions)
response = await self.run_similarity_query(
vector, collection, limit, transformed_conditions, threshold
embedding, collection, limit, transformed_conditions, threshold
)
return self.transform_response(response)

View File

@@ -4,7 +4,7 @@ from typing import final, override
from qdrant_client import AsyncQdrantClient, models
from ..config import Settings
from ..models import SearchRow, Condition, Match, MatchAny, MatchExclude
from ..models import Condition, Match, MatchAny, MatchExclude, SearchRow
from .base_engine import BaseEngine
__all__ = ["QdrantEngine"]

View File

@@ -1,9 +1,9 @@
from fastmcp import FastMCP
from .engine import QdrantEngine
from .engine import Backend, get_engine
mcp = FastMCP("Vector Search MCP")
engine = QdrantEngine()
engine = get_engine(Backend.QDRANT)
_ = mcp.tool(engine.semantic_search)