forked from innovacion/searchbox
Add Backend enum
This commit is contained in:
@@ -1,26 +1,28 @@
|
|||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
from functools import cache
|
||||||
from typing import Literal, overload
|
from typing import Literal, overload
|
||||||
|
|
||||||
from .qdrant_engine import QdrantEngine
|
from .qdrant_engine import QdrantEngine
|
||||||
|
|
||||||
|
|
||||||
class EngineType(StrEnum):
|
class Backend(StrEnum):
|
||||||
QDRANT = "qdrant"
|
QDRANT = "qdrant"
|
||||||
COSMOS = "cosmos"
|
COSMOS = "cosmos"
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def get_engine(backend: Literal[EngineType.QDRANT]) -> QdrantEngine: ...
|
def get_engine(backend: Literal[Backend.QDRANT]) -> QdrantEngine: ...
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def get_engine(backend: Literal[EngineType.COSMOS]) -> QdrantEngine: ...
|
def get_engine(backend: Literal[Backend.COSMOS]) -> QdrantEngine: ...
|
||||||
|
|
||||||
|
|
||||||
def get_engine(backend: EngineType):
|
@cache
|
||||||
if backend == EngineType.QDRANT:
|
def get_engine(backend: Backend):
|
||||||
|
if backend == Backend.QDRANT:
|
||||||
return QdrantEngine()
|
return QdrantEngine()
|
||||||
elif backend == EngineType.COSMOS:
|
elif backend == Backend.COSMOS:
|
||||||
raise NotImplementedError("Cosmos engine is not implemented yet")
|
raise NotImplementedError("Cosmos engine is not implemented yet")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown engine type: {backend}")
|
raise ValueError(f"Unknown engine type: {backend}")
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ class BaseEngine(ABC, Generic[ResponseType, ConditionType]):
|
|||||||
|
|
||||||
async def semantic_search(
|
async def semantic_search(
|
||||||
self,
|
self,
|
||||||
vector: list[float],
|
embedding: list[float],
|
||||||
collection: str,
|
collection: str,
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
conditions: list[Condition] | None = None,
|
conditions: list[Condition] | None = None,
|
||||||
@@ -38,6 +38,6 @@ class BaseEngine(ABC, Generic[ResponseType, ConditionType]):
|
|||||||
) -> list[SearchRow]:
|
) -> list[SearchRow]:
|
||||||
transformed_conditions = self.transform_conditions(conditions)
|
transformed_conditions = self.transform_conditions(conditions)
|
||||||
response = await self.run_similarity_query(
|
response = await self.run_similarity_query(
|
||||||
vector, collection, limit, transformed_conditions, threshold
|
embedding, collection, limit, transformed_conditions, threshold
|
||||||
)
|
)
|
||||||
return self.transform_response(response)
|
return self.transform_response(response)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from typing import final, override
|
|||||||
from qdrant_client import AsyncQdrantClient, models
|
from qdrant_client import AsyncQdrantClient, models
|
||||||
|
|
||||||
from ..config import Settings
|
from ..config import Settings
|
||||||
from ..models import SearchRow, Condition, Match, MatchAny, MatchExclude
|
from ..models import Condition, Match, MatchAny, MatchExclude, SearchRow
|
||||||
from .base_engine import BaseEngine
|
from .base_engine import BaseEngine
|
||||||
|
|
||||||
__all__ = ["QdrantEngine"]
|
__all__ = ["QdrantEngine"]
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
from fastmcp import FastMCP
|
from fastmcp import FastMCP
|
||||||
|
|
||||||
from .engine import QdrantEngine
|
from .engine import Backend, get_engine
|
||||||
|
|
||||||
mcp = FastMCP("Vector Search MCP")
|
mcp = FastMCP("Vector Search MCP")
|
||||||
|
|
||||||
engine = QdrantEngine()
|
engine = get_engine(Backend.QDRANT)
|
||||||
|
|
||||||
_ = mcp.tool(engine.semantic_search)
|
_ = mcp.tool(engine.semantic_search)
|
||||||
|
|||||||
Reference in New Issue
Block a user