diff --git a/main.py b/main.py index c70a6c9..7199cb3 100644 --- a/main.py +++ b/main.py @@ -6,6 +6,7 @@ import io from collections.abc import AsyncIterator, Sequence from contextlib import asynccontextmanager from dataclasses import dataclass +from enum import Enum from typing import BinaryIO, TypedDict import aiohttp @@ -21,6 +22,14 @@ HTTP_TOO_MANY_REQUESTS = 429 HTTP_SERVER_ERROR = 500 +class SourceNamespace(str, Enum): + """Allowed values for the 'source' namespace filter.""" + + EDUCACION_FINANCIERA = "Educacion Financiera" + PRODUCTOS_Y_SERVICIOS = "Productos y Servicios" + FUNCIONALIDADES_APP_MOVIL = "Funcionalidades de la App Movil" + + class GoogleCloudFileStorage: """Cache-aware helper for downloading files from Google Cloud Storage.""" @@ -236,6 +245,7 @@ class GoogleCloudVectorSearch: deployed_index_id: str, query: Sequence[float], limit: int, + source: SourceNamespace | None = None, ) -> list[SearchResult]: """Run an async similarity search via the REST API. @@ -243,6 +253,7 @@ class GoogleCloudVectorSearch: deployed_index_id: The ID of the deployed index. query: The embedding vector for the search query. limit: Maximum number of nearest neighbors to return. + source: Optional namespace filter to restrict results by source. Returns: A list of matched items with id, distance, and content. @@ -279,11 +290,16 @@ class GoogleCloudVectorSearch: } ) + datapoint: dict = {"feature_vector": list(query)} + if source is not None: + datapoint["restricts"] = [ + {"namespace": "source", "allow_list": [source.value]}, + ] payload = { "deployed_index_id": deployed_index_id, "queries": [ { - "datapoint": {"feature_vector": list(query)}, + "datapoint": datapoint, "neighbor_count": limit, }, ], @@ -636,12 +652,16 @@ mcp = FastMCP( async def knowledge_search( query: str, ctx: Context, + source: SourceNamespace | None = None, ) -> str: """Search a knowledge base using a natural-language query. Args: query: The text query to search for. ctx: MCP request context (injected automatically). + source: Optional filter to restrict results by source. + Allowed values: 'Educacion Financiera', + 'Productos y Servicios', 'Funcionalidades de la App Movil'. Returns: A formatted string containing matched documents with id and content. @@ -712,6 +732,7 @@ async def knowledge_search( deployed_index_id=app.settings.deployed_index_id, query=embedding, limit=app.settings.search_limit, + source=source, ) t_search = time.perf_counter() except Exception as e: @@ -743,6 +764,7 @@ async def knowledge_search( "embedding_ms": f"{round((t_embed - t0) * 1000, 1)}ms", "vector_search_ms": f"{round((t_search - t_embed) * 1000, 1)}ms", "total_ms": f"{round((t_search - t0) * 1000, 1)}ms", + "source_filter": source.value if source is not None else None, "results_count": len(search_results), "chunks": [s["id"] for s in search_results] }