forked from innovacion/searchbox
Compare commits
1 commit
main...a3d972ddb9
.github/workflows/ci.yaml (vendored, 3 lines changed)
@@ -4,9 +4,6 @@ on:
   push:
-    branches:
-      - main
   pull_request:
-    branches:
-      - main
 
 jobs:
   ci:
searchbox/client.py

@@ -6,10 +6,19 @@ operations across different backend implementations.
 
 from typing import final
 
+from .embedder.base import BaseEmbedder
 from .engine import Backend, get_engine
 from .models import Chunk, Condition
 
 
+class QueryError(ValueError):
+    """Raised when query parameters are invalid."""
+
+
+class EmbedderNotConfiguredError(ValueError):
+    """Raised when embedder is required but not configured."""
+
+
 @final
 class Client:
     """High-level client for vector search operations.
@@ -20,20 +29,29 @@ class Client:
     Args:
         backend: The vector search backend to use (e.g., Backend.QDRANT)
         collection: Name of the collection to operate on
+        embedder: Optional embedder for converting text queries to vectors
 
     """
 
-    def __init__(self, backend: Backend, collection: str, **kwargs: str):
+    def __init__(
+        self,
+        backend: Backend,
+        collection: str,
+        embedder: BaseEmbedder | None = None,
+        **kwargs: str,
+    ):
        """Initialize the client with a specific backend and collection.
 
         Args:
             backend: The vector search backend to use
             collection: Name of the collection to operate on
+            embedder: Optional embedder for automatic query embedding
             **kwargs: Additional keyword arguments to pass to the backend
 
         """
         self.engine = get_engine(backend, **kwargs)
         self.collection = collection
+        self.embedder = embedder
 
     async def create_index(self, size: int) -> bool:
         """Create a vector index with the specified dimension size.
@@ -61,7 +79,8 @@ class Client:
 
     async def semantic_search(
         self,
-        embedding: list[float],
+        query: str | list[float] | None = None,
+        embedding: list[float] | None = None,
         limit: int = 10,
         conditions: list[Condition] | None = None,
         threshold: float | None = None,
@@ -69,7 +88,8 @@ class Client:
         """Perform semantic search using vector similarity.
 
         Args:
-            embedding: Query vector as a list of floats
+            query: Text query to embed (requires embedder to be configured)
+            embedding: Pre-computed query vector as a list of floats
             limit: Maximum number of results to return (default: 10)
             conditions: Optional list of filter conditions to apply
             threshold: Optional minimum similarity score threshold
@@ -77,7 +97,30 @@ class Client:
         Returns:
             List of search results with chunk IDs, scores, and metadata
 
+        Raises:
+            ValueError: If neither query nor embedding is provided, or if query
+                is provided but no embedder is configured
+
         """
+        if query is None and embedding is None:
+            msg = "Either 'query' or 'embedding' must be provided"
+            raise QueryError(msg)
+
+        if query is not None and embedding is not None:
+            msg = "Only one of 'query' or 'embedding' should be provided"
+            raise QueryError(msg)
+
+        # Handle query string
+        if query is not None:
+            if isinstance(query, str):
+                if self.embedder is None:
+                    msg = "Cannot use 'query' parameter without an embedder"
+                    raise EmbedderNotConfiguredError(msg)
+                embedding = self.embedder.embed(query)
+            else:
+                # query is already a list[float]
+                embedding = query
+
         return await self.engine.semantic_search(
             embedding, self.collection, limit, conditions, threshold
         )
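Taken together, the new API lets callers pass either a text query (which the configured embedder converts to a vector) or a pre-computed embedding, but not both. A minimal usage sketch under assumptions not fixed by this commit (Qdrant backend, placeholder Azure credentials, a collection named "docs"):

import asyncio

from searchbox.client import Backend, Client
from searchbox.embedder.azure import AzureEmbedder


async def main() -> None:
    # Placeholder credentials; real values are deployment-specific.
    embedder = AzureEmbedder(
        model="text-embedding-3-large",
        azure_endpoint="https://example.openai.azure.com/",
        api_key="your-api-key",
        openai_api_version="2024-02-01",
    )
    client = Client(Backend.QDRANT, collection="docs", embedder=embedder)

    # Text query: embedded automatically through the configured embedder.
    hits = await client.semantic_search(query="Quien es el mas guapo", limit=5)

    # Pre-computed vector: bypasses the embedder entirely.
    vector = embedder.embed("Quien es el mas guapo")
    hits = await client.semantic_search(embedding=vector, limit=5)

    # Passing both, or neither, raises QueryError.
    print(hits)


asyncio.run(main())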
searchbox/embedder/azure.py

@@ -1,7 +1,69 @@
 """Embedder class using Azure AI Foundry."""
 
+from openai import AzureOpenAI
+
 from .base import BaseEmbedder
 
 
 class AzureEmbedder(BaseEmbedder):
-    def embed(self, text: str) -> list[float]: ...
+    """Embedder implementation using Azure OpenAI Service.
+
+    Provides text embedding generation through Azure's OpenAI API endpoint.
+    Compatible with any Azure OpenAI embedding model (text-embedding-ada-002,
+    text-embedding-3-small, text-embedding-3-large, etc.).
+
+    Args:
+        model: The embedding model name (e.g., "text-embedding-3-large")
+        azure_endpoint: Azure OpenAI endpoint URL
+        api_key: Azure OpenAI API key
+        openai_api_version: API version (e.g., "2024-02-01")
+
+    Example:
+        >>> embedder = AzureEmbedder(
+        ...     model="text-embedding-3-large",
+        ...     azure_endpoint="https://chatocp.openai.azure.com/",
+        ...     api_key="your-api-key",
+        ...     openai_api_version="2024-02-01"
+        ... )
+        >>> embedding = embedder.embed("Hello world")
+
+    """
+
+    def __init__(
+        self,
+        model: str,
+        azure_endpoint: str,
+        api_key: str,
+        openai_api_version: str,
+    ):
+        """Initialize the Azure OpenAI embedder.
+
+        Args:
+            model: The embedding model name (e.g., "text-embedding-3-large")
+            azure_endpoint: Azure OpenAI endpoint URL
+            api_key: Azure OpenAI API key
+            openai_api_version: API version (e.g., "2024-02-01")
+
+        """
+        self.model = model
+        self.client = AzureOpenAI(
+            azure_endpoint=azure_endpoint,
+            api_key=api_key,
+            api_version=openai_api_version,
+        )
+
+    def embed(self, text: str) -> list[float]:
+        """Generate embedding vector for the given text.
+
+        Args:
+            text: Input text to embed
+
+        Returns:
+            List of floats representing the embedding vector
+
+        """
+        response = self.client.embeddings.create(
+            model=self.model,
+            input=text,
+        )
+        return response.data[0].embedding
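embed() sends one text per request. The underlying embeddings endpoint also accepts a list of inputs, so a batch variant is straightforward; the helper below is a hedged sketch and not part of this PR:

from openai import AzureOpenAI


def embed_batch(client: AzureOpenAI, model: str, texts: list[str]) -> list[list[float]]:
    """Embed several texts in one API call (hypothetical helper, not in this PR)."""
    response = client.embeddings.create(model=model, input=texts)
    # Sort by index so output order is guaranteed to match input order.
    ordered = sorted(response.data, key=lambda item: item.index)
    return [item.embedding for item in ordered]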
searchbox/embedder/base.py

@@ -1,6 +1,50 @@
+"""Base embedder interface for text embedding models.
+
+This module defines the abstract base class that all embedder implementations
+must inherit from, ensuring a consistent interface across different embedding
+providers (Azure OpenAI, FastEmbed, OpenAI, Cohere, etc.).
+"""
+
 from abc import ABC, abstractmethod
 
 
 class BaseEmbedder(ABC):
+    """Abstract base class for text embedding models.
+
+    This class defines the interface that all embedder implementations must follow,
+    allowing the system to work with any embedding model provider through a
+    unified API.
+
+    Implementations should inherit from this class and provide concrete
+    implementations of the embed() method for their specific embedding service.
+
+    Example:
+        >>> class MyEmbedder(BaseEmbedder):
+        ...     def embed(self, text: str) -> list[float]:
+        ...         # Implementation specific to your embedding service
+        ...         return [0.1, 0.2, 0.3, ...]
+
+    """
+
     @abstractmethod
-    def embed(self, text: str) -> list[float]: ...
+    def embed(self, text: str) -> list[float]:
+        """Generate embedding vector for the given text.
+
+        This method must be implemented by all concrete embedder classes to
+        convert input text into a dense vector representation.
+
+        Args:
+            text: Input text to embed
+
+        Returns:
+            A list of floats representing the embedding vector. The dimension
+            of the vector depends on the specific embedding model being used.
+
+        Example:
+            >>> embedder = SomeEmbedder()
+            >>> vector = embedder.embed("Hello world")
+            >>> len(vector)
+            1536
+
+        """
+        ...
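Any provider can be plugged in by subclassing BaseEmbedder. As a hedged illustration, here is a local implementation built on fastembed, the library the old test used directly; the class itself is hypothetical and not part of this PR:

from fastembed import TextEmbedding

from searchbox.embedder.base import BaseEmbedder


class FastEmbedEmbedder(BaseEmbedder):
    """Local CPU embedder backed by fastembed (hypothetical example)."""

    def __init__(self) -> None:
        self.model = TextEmbedding()

    def embed(self, text: str) -> list[float]:
        # fastembed yields numpy arrays; take the first result as a list.
        return list(self.model.embed(text))[0].tolist()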
searchbox/mcp_server/server.py

@@ -22,30 +22,43 @@ from fastmcp import FastMCP
 from starlette.requests import Request
 from starlette.responses import JSONResponse
 
-from ..engine import get_engine
+from ..client import Backend, Client
+from ..embedder.azure import AzureEmbedder
 
 mcp = FastMCP("Searchbox MCP")
 
-engine_map = {"qdrant": get_engine("qdrant")}
+# Initialize Azure embedder
+embedder = AzureEmbedder(
+    model="",
+    azure_endpoint="",
+    api_key="",
+    openai_api_version="",
+)
 
 
-@mcp.tool(exclude_args=["backend", "embedding", "collection", "limit", "threshold"])
+@mcp.tool(exclude_args=["backend", "collection", "limit", "threshold"])
 async def get_information(
     query: Annotated[str, "The user query"],
     backend: str = "qdrant",
-    embedding: list[float] = [],
     collection: str = "default",
     limit: int = 10,
     threshold: float | None = None,
 ):
-    """Search a private repository for information."""
-    _ = query
-
-    engine = engine_map[backend]
-
-    result = await engine.semantic_search(
-        embedding=embedding,
+    """Search a private repository for information using semantic search.
+
+    The query will be automatically converted to an embedding vector using
+    Azure OpenAI's text-embedding-3-large model before searching.
+    """
+    # Create client with embedder
+    client = Client(
+        backend=Backend.QDRANT if backend == "qdrant" else Backend.QDRANT,
+        collection=collection,
+        embedder=embedder,
+    )
+
+    # Perform semantic search with automatic embedding
+    result = await client.semantic_search(
+        query=query,
        limit=limit,
         threshold=threshold,
     )
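Note that this commit instantiates the module-level AzureEmbedder with empty strings, so the server cannot embed anything until real credentials are supplied. A hedged sketch of sourcing them from the environment instead (the variable names are assumptions, not defined by this PR):

import os

from searchbox.embedder.azure import AzureEmbedder

# Hypothetical environment variable names; the PR itself hardcodes "".
embedder = AzureEmbedder(
    model=os.environ.get("AZURE_OPENAI_EMBED_MODEL", "text-embedding-3-large"),
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    openai_api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2024-02-01"),
)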
@@ -1,12 +1,10 @@
 import json
 
 import pytest
 from fastmcp import Client
-from fastembed import TextEmbedding
 
 from searchbox.mcp_server.server import mcp
 
-embedding_model = TextEmbedding()
 
 
 @pytest.fixture
 async def mcp_client():
@@ -15,19 +13,18 @@ async def mcp_client():
 
 
 async def test_mcp_qdrant_backend(mcp_client):
-    embedding = list(embedding_model.embed("Quien es el mas guapo"))[0].tolist()
-
+    """Test MCP server with automatic Azure embedding."""
     result = await mcp_client.call_tool(
         name="get_information",
         arguments={
-            "query": "dummy value",
-            "collection": "dummy_collection",
-            "embedding": embedding,
+            "query": "Quien es el mas guapo",
+            "collection": "azure_collection",
         },
     )
 
     content = json.loads(result.content[0].text)[0]
 
     assert content["chunk_id"] == "0"
-    assert content["score"] >= 0.7
-    assert content["payload"] == {"text": "Rick es el mas guapo"}
+    assert content["score"] >= 0.65
+    assert content["payload"]["page_content"] == "Rick es el mas guapo"
+    assert content["payload"]["filename"] == "test.txt"
+    assert content["payload"]["page"] == 1
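Outside pytest, the same tool can be exercised with fastmcp's in-memory client. A hedged sketch, assuming the Qdrant collection exists and the Azure credentials in the server module are filled in:

import asyncio
import json

from fastmcp import Client

from searchbox.mcp_server.server import mcp


async def main() -> None:
    # In-memory transport: the client talks to the server object directly.
    async with Client(mcp) as client:
        result = await client.call_tool(
            name="get_information",
            arguments={
                "query": "Quien es el mas guapo",
                "collection": "azure_collection",
            },
        )
        print(json.loads(result.content[0].text))


asyncio.run(main())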