forked from innovacion/searchbox
Compare commits
1 Commits
main
...
a3d972ddb9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a3d972ddb9 |
3
.github/workflows/ci.yaml
vendored
3
.github/workflows/ci.yaml
vendored
@@ -4,9 +4,6 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- main
|
- main
|
||||||
pull_request:
|
|
||||||
branches:
|
|
||||||
- main
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
ci:
|
ci:
|
||||||
|
|||||||
@@ -6,10 +6,19 @@ operations across different backend implementations.
|
|||||||
|
|
||||||
from typing import final
|
from typing import final
|
||||||
|
|
||||||
|
from .embedder.base import BaseEmbedder
|
||||||
from .engine import Backend, get_engine
|
from .engine import Backend, get_engine
|
||||||
from .models import Chunk, Condition
|
from .models import Chunk, Condition
|
||||||
|
|
||||||
|
|
||||||
|
class QueryError(ValueError):
|
||||||
|
"""Raised when query parameters are invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
class EmbedderNotConfiguredError(ValueError):
|
||||||
|
"""Raised when embedder is required but not configured."""
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
class Client:
|
class Client:
|
||||||
"""High-level client for vector search operations.
|
"""High-level client for vector search operations.
|
||||||
@@ -20,20 +29,29 @@ class Client:
|
|||||||
Args:
|
Args:
|
||||||
backend: The vector search backend to use (e.g., Backend.QDRANT)
|
backend: The vector search backend to use (e.g., Backend.QDRANT)
|
||||||
collection: Name of the collection to operate on
|
collection: Name of the collection to operate on
|
||||||
|
embedder: Optional embedder for converting text queries to vectors
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, backend: Backend, collection: str, **kwargs: str):
|
def __init__(
|
||||||
|
self,
|
||||||
|
backend: Backend,
|
||||||
|
collection: str,
|
||||||
|
embedder: BaseEmbedder | None = None,
|
||||||
|
**kwargs: str,
|
||||||
|
):
|
||||||
"""Initialize the client with a specific backend and collection.
|
"""Initialize the client with a specific backend and collection.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
backend: The vector search backend to use
|
backend: The vector search backend to use
|
||||||
collection: Name of the collection to operate on
|
collection: Name of the collection to operate on
|
||||||
|
embedder: Optional embedder for automatic query embedding
|
||||||
**kwargs: Additional keyword arguments to pass to the backend
|
**kwargs: Additional keyword arguments to pass to the backend
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.engine = get_engine(backend, **kwargs)
|
self.engine = get_engine(backend, **kwargs)
|
||||||
self.collection = collection
|
self.collection = collection
|
||||||
|
self.embedder = embedder
|
||||||
|
|
||||||
async def create_index(self, size: int) -> bool:
|
async def create_index(self, size: int) -> bool:
|
||||||
"""Create a vector index with the specified dimension size.
|
"""Create a vector index with the specified dimension size.
|
||||||
@@ -61,7 +79,8 @@ class Client:
|
|||||||
|
|
||||||
async def semantic_search(
|
async def semantic_search(
|
||||||
self,
|
self,
|
||||||
embedding: list[float],
|
query: str | list[float] | None = None,
|
||||||
|
embedding: list[float] | None = None,
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
conditions: list[Condition] | None = None,
|
conditions: list[Condition] | None = None,
|
||||||
threshold: float | None = None,
|
threshold: float | None = None,
|
||||||
@@ -69,7 +88,8 @@ class Client:
|
|||||||
"""Perform semantic search using vector similarity.
|
"""Perform semantic search using vector similarity.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
embedding: Query vector as a list of floats
|
query: Text query to embed (requires embedder to be configured)
|
||||||
|
embedding: Pre-computed query vector as a list of floats
|
||||||
limit: Maximum number of results to return (default: 10)
|
limit: Maximum number of results to return (default: 10)
|
||||||
conditions: Optional list of filter conditions to apply
|
conditions: Optional list of filter conditions to apply
|
||||||
threshold: Optional minimum similarity score threshold
|
threshold: Optional minimum similarity score threshold
|
||||||
@@ -77,7 +97,30 @@ class Client:
|
|||||||
Returns:
|
Returns:
|
||||||
List of search results with chunk IDs, scores, and metadata
|
List of search results with chunk IDs, scores, and metadata
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If neither query nor embedding is provided, or if query
|
||||||
|
is provided but no embedder is configured
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if query is None and embedding is None:
|
||||||
|
msg = "Either 'query' or 'embedding' must be provided"
|
||||||
|
raise QueryError(msg)
|
||||||
|
|
||||||
|
if query is not None and embedding is not None:
|
||||||
|
msg = "Only one of 'query' or 'embedding' should be provided"
|
||||||
|
raise QueryError(msg)
|
||||||
|
|
||||||
|
# Handle query string
|
||||||
|
if query is not None:
|
||||||
|
if isinstance(query, str):
|
||||||
|
if self.embedder is None:
|
||||||
|
msg = "Cannot use 'query' parameter without an embedder"
|
||||||
|
raise EmbedderNotConfiguredError(msg)
|
||||||
|
embedding = self.embedder.embed(query)
|
||||||
|
else:
|
||||||
|
# query is already a list[float]
|
||||||
|
embedding = query
|
||||||
|
|
||||||
return await self.engine.semantic_search(
|
return await self.engine.semantic_search(
|
||||||
embedding, self.collection, limit, conditions, threshold
|
embedding, self.collection, limit, conditions, threshold
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,7 +1,69 @@
|
|||||||
"""Embedder class using Azure AI Foundry."""
|
"""Embedder class using Azure AI Foundry."""
|
||||||
|
|
||||||
|
from openai import AzureOpenAI
|
||||||
|
|
||||||
from .base import BaseEmbedder
|
from .base import BaseEmbedder
|
||||||
|
|
||||||
|
|
||||||
class AzureEmbedder(BaseEmbedder):
|
class AzureEmbedder(BaseEmbedder):
|
||||||
def embed(self, text: str) -> list[float]: ...
|
"""Embedder implementation using Azure OpenAI Service.
|
||||||
|
|
||||||
|
Provides text embedding generation through Azure's OpenAI API endpoint.
|
||||||
|
Compatible with any Azure OpenAI embedding model (text-embedding-ada-002,
|
||||||
|
text-embedding-3-small, text-embedding-3-large, etc.).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The embedding model name (e.g., "text-embedding-3-large")
|
||||||
|
azure_endpoint: Azure OpenAI endpoint URL
|
||||||
|
api_key: Azure OpenAI API key
|
||||||
|
openai_api_version: API version (e.g., "2024-02-01")
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> embedder = AzureEmbedder(
|
||||||
|
... model="text-embedding-3-large",
|
||||||
|
... azure_endpoint="https://chatocp.openai.azure.com/",
|
||||||
|
... api_key="your-api-key",
|
||||||
|
... openai_api_version="2024-02-01"
|
||||||
|
... )
|
||||||
|
>>> embedding = embedder.embed("Hello world")
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
azure_endpoint: str,
|
||||||
|
api_key: str,
|
||||||
|
openai_api_version: str,
|
||||||
|
):
|
||||||
|
"""Initialize the Azure OpenAI embedder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The embedding model name (e.g., "text-embedding-3-large")
|
||||||
|
azure_endpoint: Azure OpenAI endpoint URL
|
||||||
|
api_key: Azure OpenAI API key
|
||||||
|
openai_api_version: API version (e.g., "2024-02-01")
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.model = model
|
||||||
|
self.client = AzureOpenAI(
|
||||||
|
azure_endpoint=azure_endpoint,
|
||||||
|
api_key=api_key,
|
||||||
|
api_version=openai_api_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
def embed(self, text: str) -> list[float]:
|
||||||
|
"""Generate embedding vector for the given text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Input text to embed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of floats representing the embedding vector
|
||||||
|
|
||||||
|
"""
|
||||||
|
response = self.client.embeddings.create(
|
||||||
|
model=self.model,
|
||||||
|
input=text,
|
||||||
|
)
|
||||||
|
return response.data[0].embedding
|
||||||
|
|||||||
@@ -1,6 +1,50 @@
|
|||||||
|
"""Base embedder interface for text embedding models.
|
||||||
|
|
||||||
|
This module defines the abstract base class that all embedder implementations
|
||||||
|
must inherit from, ensuring a consistent interface across different embedding
|
||||||
|
providers (Azure OpenAI, FastEmbed, OpenAI, Cohere, etc.).
|
||||||
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
class BaseEmbedder(ABC):
|
class BaseEmbedder(ABC):
|
||||||
|
"""Abstract base class for text embedding models.
|
||||||
|
|
||||||
|
This class defines the interface that all embedder implementations must follow,
|
||||||
|
allowing the system to work with any embedding model provider through a
|
||||||
|
unified API.
|
||||||
|
|
||||||
|
Implementations should inherit from this class and provide concrete
|
||||||
|
implementations of the embed() method for their specific embedding service.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> class MyEmbedder(BaseEmbedder):
|
||||||
|
... def embed(self, text: str) -> list[float]:
|
||||||
|
... # Implementation specific to your embedding service
|
||||||
|
... return [0.1, 0.2, 0.3, ...]
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def embed(self, text: str) -> list[float]: ...
|
def embed(self, text: str) -> list[float]:
|
||||||
|
"""Generate embedding vector for the given text.
|
||||||
|
|
||||||
|
This method must be implemented by all concrete embedder classes to
|
||||||
|
convert input text into a dense vector representation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Input text to embed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of floats representing the embedding vector. The dimension
|
||||||
|
of the vector depends on the specific embedding model being used.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> embedder = SomeEmbedder()
|
||||||
|
>>> vector = embedder.embed("Hello world")
|
||||||
|
>>> len(vector)
|
||||||
|
1536
|
||||||
|
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|||||||
@@ -22,30 +22,43 @@ from fastmcp import FastMCP
|
|||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
from ..engine import get_engine
|
from ..client import Backend, Client
|
||||||
|
from ..embedder.azure import AzureEmbedder
|
||||||
|
|
||||||
mcp = FastMCP("Searchbox MCP")
|
mcp = FastMCP("Searchbox MCP")
|
||||||
|
|
||||||
engine_map = {"qdrant": get_engine("qdrant")}
|
# Initialize Azure embedder
|
||||||
|
embedder = AzureEmbedder(
|
||||||
|
model="",
|
||||||
|
azure_endpoint="",
|
||||||
|
api_key="",
|
||||||
|
openai_api_version="",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool(exclude_args=["backend", "embedding", "collection", "limit", "threshold"])
|
@mcp.tool(exclude_args=["backend", "collection", "limit", "threshold"])
|
||||||
async def get_information(
|
async def get_information(
|
||||||
query: Annotated[str, "The user query"],
|
query: Annotated[str, "The user query"],
|
||||||
backend: str = "qdrant",
|
backend: str = "qdrant",
|
||||||
embedding: list[float] = [],
|
|
||||||
collection: str = "default",
|
collection: str = "default",
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
threshold: float | None = None,
|
threshold: float | None = None,
|
||||||
):
|
):
|
||||||
"""Search a private repository for information."""
|
"""Search a private repository for information using semantic search.
|
||||||
_ = query
|
|
||||||
|
|
||||||
engine = engine_map[backend]
|
The query will be automatically converted to an embedding vector using
|
||||||
|
Azure OpenAI's text-embedding-3-large model before searching.
|
||||||
result = await engine.semantic_search(
|
"""
|
||||||
embedding=embedding,
|
# Create client with embedder
|
||||||
|
client = Client(
|
||||||
|
backend=Backend.QDRANT if backend == "qdrant" else Backend.QDRANT,
|
||||||
collection=collection,
|
collection=collection,
|
||||||
|
embedder=embedder,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform semantic search with automatic embedding
|
||||||
|
result = await client.semantic_search(
|
||||||
|
query=query,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
threshold=threshold,
|
threshold=threshold,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastmcp import Client
|
from fastmcp import Client
|
||||||
from fastembed import TextEmbedding
|
|
||||||
|
|
||||||
from searchbox.mcp_server.server import mcp
|
from searchbox.mcp_server.server import mcp
|
||||||
|
|
||||||
embedding_model = TextEmbedding()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def mcp_client():
|
async def mcp_client():
|
||||||
@@ -15,19 +13,18 @@ async def mcp_client():
|
|||||||
|
|
||||||
|
|
||||||
async def test_mcp_qdrant_backend(mcp_client):
|
async def test_mcp_qdrant_backend(mcp_client):
|
||||||
embedding = list(embedding_model.embed("Quien es el mas guapo"))[0].tolist()
|
"""Test MCP server with automatic Azure embedding."""
|
||||||
|
|
||||||
result = await mcp_client.call_tool(
|
result = await mcp_client.call_tool(
|
||||||
name="get_information",
|
name="get_information",
|
||||||
arguments={
|
arguments={
|
||||||
"query": "dummy value",
|
"query": "Quien es el mas guapo",
|
||||||
"collection": "dummy_collection",
|
"collection": "azure_collection",
|
||||||
"embedding": embedding,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
content = json.loads(result.content[0].text)[0]
|
content = json.loads(result.content[0].text)[0]
|
||||||
|
|
||||||
assert content["chunk_id"] == "0"
|
assert content["score"] >= 0.65
|
||||||
assert content["score"] >= 0.7
|
assert content["payload"]["page_content"] == "Rick es el mas guapo"
|
||||||
assert content["payload"] == {"text": "Rick es el mas guapo"}
|
assert content["payload"]["filename"] == "test.txt"
|
||||||
|
assert content["payload"]["page"] == 1
|
||||||
|
|||||||
Reference in New Issue
Block a user