forked from innovacion/searchbox
Add Backend enum
This commit is contained in:
@@ -1,26 +1,28 @@
|
||||
from enum import StrEnum
|
||||
from functools import cache
|
||||
from typing import Literal, overload
|
||||
|
||||
from .qdrant_engine import QdrantEngine
|
||||
|
||||
|
||||
class EngineType(StrEnum):
|
||||
class Backend(StrEnum):
|
||||
QDRANT = "qdrant"
|
||||
COSMOS = "cosmos"
|
||||
|
||||
|
||||
@overload
|
||||
def get_engine(backend: Literal[EngineType.QDRANT]) -> QdrantEngine: ...
|
||||
def get_engine(backend: Literal[Backend.QDRANT]) -> QdrantEngine: ...
|
||||
|
||||
|
||||
@overload
|
||||
def get_engine(backend: Literal[EngineType.COSMOS]) -> QdrantEngine: ...
|
||||
def get_engine(backend: Literal[Backend.COSMOS]) -> QdrantEngine: ...
|
||||
|
||||
|
||||
def get_engine(backend: EngineType):
|
||||
if backend == EngineType.QDRANT:
|
||||
@cache
|
||||
def get_engine(backend: Backend):
|
||||
if backend == Backend.QDRANT:
|
||||
return QdrantEngine()
|
||||
elif backend == EngineType.COSMOS:
|
||||
elif backend == Backend.COSMOS:
|
||||
raise NotImplementedError("Cosmos engine is not implemented yet")
|
||||
else:
|
||||
raise ValueError(f"Unknown engine type: {backend}")
|
||||
|
||||
@@ -30,7 +30,7 @@ class BaseEngine(ABC, Generic[ResponseType, ConditionType]):
|
||||
|
||||
async def semantic_search(
|
||||
self,
|
||||
vector: list[float],
|
||||
embedding: list[float],
|
||||
collection: str,
|
||||
limit: int = 10,
|
||||
conditions: list[Condition] | None = None,
|
||||
@@ -38,6 +38,6 @@ class BaseEngine(ABC, Generic[ResponseType, ConditionType]):
|
||||
) -> list[SearchRow]:
|
||||
transformed_conditions = self.transform_conditions(conditions)
|
||||
response = await self.run_similarity_query(
|
||||
vector, collection, limit, transformed_conditions, threshold
|
||||
embedding, collection, limit, transformed_conditions, threshold
|
||||
)
|
||||
return self.transform_response(response)
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import final, override
|
||||
from qdrant_client import AsyncQdrantClient, models
|
||||
|
||||
from ..config import Settings
|
||||
from ..models import SearchRow, Condition, Match, MatchAny, MatchExclude
|
||||
from ..models import Condition, Match, MatchAny, MatchExclude, SearchRow
|
||||
from .base_engine import BaseEngine
|
||||
|
||||
__all__ = ["QdrantEngine"]
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from fastmcp import FastMCP
|
||||
|
||||
from .engine import QdrantEngine
|
||||
from .engine import Backend, get_engine
|
||||
|
||||
mcp = FastMCP("Vector Search MCP")
|
||||
|
||||
engine = QdrantEngine()
|
||||
engine = get_engine(Backend.QDRANT)
|
||||
|
||||
_ = mcp.tool(engine.semantic_search)
|
||||
|
||||
Reference in New Issue
Block a user