diff --git a/src/vector_search_mcp/engine/__init__.py b/src/vector_search_mcp/engine/__init__.py index 9dbf918..4fd5768 100644 --- a/src/vector_search_mcp/engine/__init__.py +++ b/src/vector_search_mcp/engine/__init__.py @@ -1,26 +1,28 @@ from enum import StrEnum +from functools import cache from typing import Literal, overload from .qdrant_engine import QdrantEngine -class EngineType(StrEnum): +class Backend(StrEnum): QDRANT = "qdrant" COSMOS = "cosmos" @overload -def get_engine(backend: Literal[EngineType.QDRANT]) -> QdrantEngine: ... +def get_engine(backend: Literal[Backend.QDRANT]) -> QdrantEngine: ... @overload -def get_engine(backend: Literal[EngineType.COSMOS]) -> QdrantEngine: ... +def get_engine(backend: Literal[Backend.COSMOS]) -> QdrantEngine: ... -def get_engine(backend: EngineType): - if backend == EngineType.QDRANT: +@cache +def get_engine(backend: Backend): + if backend == Backend.QDRANT: return QdrantEngine() - elif backend == EngineType.COSMOS: + elif backend == Backend.COSMOS: raise NotImplementedError("Cosmos engine is not implemented yet") else: raise ValueError(f"Unknown engine type: {backend}") diff --git a/src/vector_search_mcp/engine/base_engine.py b/src/vector_search_mcp/engine/base_engine.py index 0cb2b3e..1c8456a 100644 --- a/src/vector_search_mcp/engine/base_engine.py +++ b/src/vector_search_mcp/engine/base_engine.py @@ -30,7 +30,7 @@ class BaseEngine(ABC, Generic[ResponseType, ConditionType]): async def semantic_search( self, - vector: list[float], + embedding: list[float], collection: str, limit: int = 10, conditions: list[Condition] | None = None, @@ -38,6 +38,6 @@ class BaseEngine(ABC, Generic[ResponseType, ConditionType]): ) -> list[SearchRow]: transformed_conditions = self.transform_conditions(conditions) response = await self.run_similarity_query( - vector, collection, limit, transformed_conditions, threshold + embedding, collection, limit, transformed_conditions, threshold ) return self.transform_response(response) diff --git a/src/vector_search_mcp/engine/qdrant_engine.py b/src/vector_search_mcp/engine/qdrant_engine.py index 79c8739..a288423 100644 --- a/src/vector_search_mcp/engine/qdrant_engine.py +++ b/src/vector_search_mcp/engine/qdrant_engine.py @@ -4,7 +4,7 @@ from typing import final, override from qdrant_client import AsyncQdrantClient, models from ..config import Settings -from ..models import SearchRow, Condition, Match, MatchAny, MatchExclude +from ..models import Condition, Match, MatchAny, MatchExclude, SearchRow from .base_engine import BaseEngine __all__ = ["QdrantEngine"] diff --git a/src/vector_search_mcp/main.py b/src/vector_search_mcp/main.py index 68692a9..a84ee6d 100644 --- a/src/vector_search_mcp/main.py +++ b/src/vector_search_mcp/main.py @@ -1,9 +1,9 @@ from fastmcp import FastMCP -from .engine import QdrantEngine +from .engine import Backend, get_engine mcp = FastMCP("Vector Search MCP") -engine = QdrantEngine() +engine = get_engine(Backend.QDRANT) _ = mcp.tool(engine.semantic_search)