Add metadata filtering #5
24
main.py
24
main.py
@@ -6,6 +6,7 @@ import io
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import BinaryIO, TypedDict
|
||||
|
||||
import aiohttp
|
||||
@@ -21,6 +22,14 @@ HTTP_TOO_MANY_REQUESTS = 429
|
||||
HTTP_SERVER_ERROR = 500
|
||||
|
||||
|
||||
class SourceNamespace(str, Enum):
    """Allowed values for the 'source' namespace filter.

    Members subclass ``str``, so each one compares equal to its plain
    string value. The ``.value`` of a member is the exact text used when
    filtering search results by source, so the spelling (including the
    unaccented "Educacion" / "Movil") must stay in sync with the values
    stored in the index.
    """

    # Financial-education content.
    EDUCACION_FINANCIERA = "Educacion Financiera"
    # Product and service descriptions.
    PRODUCTOS_Y_SERVICIOS = "Productos y Servicios"
    # Mobile-app feature documentation.
    FUNCIONALIDADES_APP_MOVIL = "Funcionalidades de la App Movil"
|
||||
|
||||
|
||||
class GoogleCloudFileStorage:
|
||||
"""Cache-aware helper for downloading files from Google Cloud Storage."""
|
||||
|
||||
@@ -236,6 +245,7 @@ class GoogleCloudVectorSearch:
|
||||
deployed_index_id: str,
|
||||
query: Sequence[float],
|
||||
limit: int,
|
||||
source: SourceNamespace | None = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Run an async similarity search via the REST API.
|
||||
|
||||
@@ -243,6 +253,7 @@ class GoogleCloudVectorSearch:
|
||||
deployed_index_id: The ID of the deployed index.
|
||||
query: The embedding vector for the search query.
|
||||
limit: Maximum number of nearest neighbors to return.
|
||||
source: Optional namespace filter to restrict results by source.
|
||||
|
||||
Returns:
|
||||
A list of matched items with id, distance, and content.
|
||||
@@ -279,11 +290,16 @@ class GoogleCloudVectorSearch:
|
||||
}
|
||||
)
|
||||
|
||||
datapoint: dict = {"feature_vector": list(query)}
|
||||
if source is not None:
|
||||
datapoint["restricts"] = [
|
||||
{"namespace": "source", "allow_list": [source.value]},
|
||||
]
|
||||
payload = {
|
||||
"deployed_index_id": deployed_index_id,
|
||||
"queries": [
|
||||
{
|
||||
"datapoint": {"feature_vector": list(query)},
|
||||
"datapoint": datapoint,
|
||||
"neighbor_count": limit,
|
||||
},
|
||||
],
|
||||
@@ -636,12 +652,16 @@ mcp = FastMCP(
|
||||
async def knowledge_search(
|
||||
query: str,
|
||||
ctx: Context,
|
||||
source: SourceNamespace | None = None,
|
||||
) -> str:
|
||||
"""Search a knowledge base using a natural-language query.
|
||||
|
||||
Args:
|
||||
query: The text query to search for.
|
||||
ctx: MCP request context (injected automatically).
|
||||
source: Optional filter to restrict results by source.
|
||||
Allowed values: 'Educacion Financiera',
|
||||
'Productos y Servicios', 'Funcionalidades de la App Movil'.
|
||||
|
||||
Returns:
|
||||
A formatted string containing matched documents with id and content.
|
||||
@@ -712,6 +732,7 @@ async def knowledge_search(
|
||||
deployed_index_id=app.settings.deployed_index_id,
|
||||
query=embedding,
|
||||
limit=app.settings.search_limit,
|
||||
source=source,
|
||||
)
|
||||
t_search = time.perf_counter()
|
||||
except Exception as e:
|
||||
@@ -743,6 +764,7 @@ async def knowledge_search(
|
||||
"embedding_ms": f"{round((t_embed - t0) * 1000, 1)}ms",
|
||||
"vector_search_ms": f"{round((t_search - t_embed) * 1000, 1)}ms",
|
||||
"total_ms": f"{round((t_search - t0) * 1000, 1)}ms",
|
||||
"source_filter": source.value if source is not None else None,
|
||||
"results_count": len(search_results),
|
||||
"chunks": [s["id"] for s in search_results]
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user