Compare commits

1 Commit

Author SHA1 Message Date
Anibal Angulo
81fcc83bdf Add metadata filtering 2026-02-24 20:03:28 +00:00

26
main.py
View File

@@ -9,6 +9,7 @@ import os
from collections.abc import AsyncIterator, Sequence from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum
from typing import BinaryIO, TypedDict from typing import BinaryIO, TypedDict
import aiohttp import aiohttp
@@ -25,6 +26,14 @@ HTTP_TOO_MANY_REQUESTS = 429
HTTP_SERVER_ERROR = 500 HTTP_SERVER_ERROR = 500
class SourceNamespace(str, Enum):
    """Allowed values for the 'source' namespace filter.

    Each member's value is the literal string placed in the ``allow_list``
    of the ``source`` namespace restrict when a vector-search query is
    filtered by source. Mixing in ``str`` makes every member compare equal
    to its plain string value.
    """

    # Values must match the stored namespace tokens character for character
    # (presumably the tokens attached to the indexed datapoints - confirm
    # against the indexing pipeline before renaming any of them).
    EDUCACION_FINANCIERA = "Educacion Financiera"
    PRODUCTOS_Y_SERVICIOS = "Productos y Servicios"
    FUNCIONALIDADES_APP_MOVIL = "Funcionalidades de la App Movil"
class GoogleCloudFileStorage: class GoogleCloudFileStorage:
"""Cache-aware helper for downloading files from Google Cloud Storage.""" """Cache-aware helper for downloading files from Google Cloud Storage."""
@@ -204,6 +213,7 @@ class GoogleCloudVectorSearch:
deployed_index_id: str, deployed_index_id: str,
query: Sequence[float], query: Sequence[float],
limit: int, limit: int,
source: SourceNamespace | None = None,
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Run an async similarity search via the REST API. """Run an async similarity search via the REST API.
@@ -211,6 +221,7 @@ class GoogleCloudVectorSearch:
deployed_index_id: The ID of the deployed index. deployed_index_id: The ID of the deployed index.
query: The embedding vector for the search query. query: The embedding vector for the search query.
limit: Maximum number of nearest neighbors to return. limit: Maximum number of nearest neighbors to return.
source: Optional namespace filter to restrict results by source.
Returns: Returns:
A list of matched items with id, distance, and content. A list of matched items with id, distance, and content.
@@ -229,11 +240,16 @@ class GoogleCloudVectorSearch:
f"/locations/{self.location}" f"/locations/{self.location}"
f"/indexEndpoints/{endpoint_id}:findNeighbors" f"/indexEndpoints/{endpoint_id}:findNeighbors"
) )
datapoint: dict = {"feature_vector": list(query)}
if source is not None:
datapoint["restricts"] = [
{"namespace": "source", "allow_list": [source.value]},
]
payload = { payload = {
"deployed_index_id": deployed_index_id, "deployed_index_id": deployed_index_id,
"queries": [ "queries": [
{ {
"datapoint": {"feature_vector": list(query)}, "datapoint": datapoint,
"neighbor_count": limit, "neighbor_count": limit,
}, },
], ],
@@ -385,12 +401,16 @@ mcp = FastMCP(
async def knowledge_search( async def knowledge_search(
query: str, query: str,
ctx: Context, ctx: Context,
source: SourceNamespace | None = None,
) -> str: ) -> str:
"""Search a knowledge base using a natural-language query. """Search a knowledge base using a natural-language query.
Args: Args:
query: The text query to search for. query: The text query to search for.
ctx: MCP request context (injected automatically). ctx: MCP request context (injected automatically).
source: Optional filter to restrict results by source.
Allowed values: 'Educacion Financiera',
'Productos y Servicios', 'Funcionalidades de la App Movil'.
Returns: Returns:
A formatted string containing matched documents with id and content. A formatted string containing matched documents with id and content.
@@ -417,6 +437,7 @@ async def knowledge_search(
deployed_index_id=app.settings.deployed_index_id, deployed_index_id=app.settings.deployed_index_id,
query=embedding, query=embedding,
limit=app.settings.search_limit, limit=app.settings.search_limit,
source=source,
) )
t_search = time.perf_counter() t_search = time.perf_counter()
@@ -431,10 +452,11 @@ async def knowledge_search(
] ]
logger.info( logger.info(
"knowledge_search timing: embedding=%sms, vector_search=%sms, total=%sms, chunks=%s", "knowledge_search timing: embedding=%sms, vector_search=%sms, total=%sms, source_filter=%s, chunks=%s",
round((t_embed - t0) * 1000, 1), round((t_embed - t0) * 1000, 1),
round((t_search - t_embed) * 1000, 1), round((t_search - t_embed) * 1000, 1),
round((t_search - t0) * 1000, 1), round((t_search - t0) * 1000, 1),
source.value if source is not None else None,
[s["id"] for s in search_results], [s["id"] for s in search_results],
) )