Compare commits

1 Commits

Author SHA1 Message Date
Anibal Angulo
81fcc83bdf Add metadata filtering 2026-02-24 20:03:28 +00:00
3 changed files with 33 additions and 66 deletions

View File

@@ -6,24 +6,7 @@ An MCP (Model Context Protocol) server that exposes a `knowledge_search` tool fo
1. A natural-language query is embedded using a Gemini embedding model. 1. A natural-language query is embedded using a Gemini embedding model.
2. The embedding is sent to a Vertex AI Matching Engine index endpoint to find nearest neighbors. 2. The embedding is sent to a Vertex AI Matching Engine index endpoint to find nearest neighbors.
3. Optional filters (restricts) can be applied to search only specific source folders. 3. The matched document contents are fetched from a GCS bucket and returned to the caller.
4. The matched document contents are fetched from a GCS bucket and returned to the caller.
## Filtering by Source Folder
The `knowledge_search` tool supports filtering results by source folder:
```python
# Search all folders
knowledge_search(query="what is a savings account?")
# Search only in specific folders
knowledge_search(
query="what is a savings account?",
source_folders=["Educacion Financiera", "Productos y Servicios"]
)
```
## Prerequisites ## Prerequisites

View File

@@ -57,20 +57,9 @@ async def async_main() -> None:
model="gemini-2.0-flash", model="gemini-2.0-flash",
name="knowledge_agent", name="knowledge_agent",
instruction=( instruction=(
"You are a helpful assistant with access to a knowledge base organized by folders. " "You are a helpful assistant with access to a knowledge base. "
"Use the knowledge_search tool to find relevant information when the user asks questions.\n\n" "Use the knowledge_search tool to find relevant information "
"Available folders in the knowledge base:\n" "when the user asks questions. Summarize the results clearly."
"- 'Educacion Financiera': Educational content about finance, savings, investments, financial concepts\n"
"- 'Funcionalidades de la App Movil': Mobile app features, functionality, usage instructions\n"
"- 'Productos y Servicios': Bank products and services, accounts, procedures\n\n"
"IMPORTANT: When the user asks about a specific topic, analyze which folders are relevant "
"and use the source_folders parameter to filter results for more precise answers.\n\n"
"Examples:\n"
"- User asks about 'cuenta de ahorros' → Use source_folders=['Educacion Financiera', 'Productos y Servicios']\n"
"- User asks about 'cómo usar la app móvil' → Use source_folders=['Funcionalidades de App Movil']\n"
"- User asks about 'transferencias en la app' → Use source_folders=['Funcionalidades de App Movil', 'Productos y Servicios']\n"
"- User asks general question → Don't use source_folders (search all)\n\n"
"Summarize the results clearly in Spanish."
), ),
tools=[toolset], tools=[toolset],
) )

63
main.py
View File

@@ -9,6 +9,7 @@ import os
from collections.abc import AsyncIterator, Sequence from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum
from typing import BinaryIO, TypedDict from typing import BinaryIO, TypedDict
import aiohttp import aiohttp
@@ -25,6 +26,14 @@ HTTP_TOO_MANY_REQUESTS = 429
HTTP_SERVER_ERROR = 500 HTTP_SERVER_ERROR = 500
class SourceNamespace(str, Enum):
"""Allowed values for the 'source' namespace filter."""
EDUCACION_FINANCIERA = "Educacion Financiera"
PRODUCTOS_Y_SERVICIOS = "Productos y Servicios"
FUNCIONALIDADES_APP_MOVIL = "Funcionalidades de la App Movil"
class GoogleCloudFileStorage: class GoogleCloudFileStorage:
"""Cache-aware helper for downloading files from Google Cloud Storage.""" """Cache-aware helper for downloading files from Google Cloud Storage."""
@@ -204,7 +213,7 @@ class GoogleCloudVectorSearch:
deployed_index_id: str, deployed_index_id: str,
query: Sequence[float], query: Sequence[float],
limit: int, limit: int,
restricts: list[dict[str, list[str]]] | None = None, source: SourceNamespace | None = None,
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Run an async similarity search via the REST API. """Run an async similarity search via the REST API.
@@ -212,6 +221,7 @@ class GoogleCloudVectorSearch:
deployed_index_id: The ID of the deployed index. deployed_index_id: The ID of the deployed index.
query: The embedding vector for the search query. query: The embedding vector for the search query.
limit: Maximum number of nearest neighbors to return. limit: Maximum number of nearest neighbors to return.
source: Optional namespace filter to restrict results by source.
Returns: Returns:
A list of matched items with id, distance, and content. A list of matched items with id, distance, and content.
@@ -230,18 +240,19 @@ class GoogleCloudVectorSearch:
f"/locations/{self.location}" f"/locations/{self.location}"
f"/indexEndpoints/{endpoint_id}:findNeighbors" f"/indexEndpoints/{endpoint_id}:findNeighbors"
) )
query_payload = { datapoint: dict = {"feature_vector": list(query)}
"datapoint": {"feature_vector": list(query)}, if source is not None:
"neighbor_count": limit, datapoint["restricts"] = [
} {"namespace": "source", "allow_list": [source.value]},
]
# Add restricts if provided
if restricts:
query_payload["restricts"] = restricts
payload = { payload = {
"deployed_index_id": deployed_index_id, "deployed_index_id": deployed_index_id,
"queries": [query_payload], "queries": [
{
"datapoint": datapoint,
"neighbor_count": limit,
},
],
} }
headers = await self._async_get_auth_headers() headers = await self._async_get_auth_headers()
@@ -390,16 +401,16 @@ mcp = FastMCP(
async def knowledge_search( async def knowledge_search(
query: str, query: str,
ctx: Context, ctx: Context,
source_folders: list[str] | None = None, source: SourceNamespace | None = None,
) -> str: ) -> str:
"""Search a knowledge base using a natural-language query. """Search a knowledge base using a natural-language query.
Args: Args:
query: The text query to search for. query: The text query to search for.
ctx: MCP request context (injected automatically). ctx: MCP request context (injected automatically).
source_folders: Optional list of source folder paths to filter results. source: Optional filter to restrict results by source.
If provided, only documents from these folders will be returned. Allowed values: 'Educacion Financiera',
Example: ["Educacion Financiera", "Productos y Servicios"] 'Productos y Servicios', 'Funcionalidades de la App Movil'.
Returns: Returns:
A formatted string containing matched documents with id and content. A formatted string containing matched documents with id and content.
@@ -422,31 +433,14 @@ async def knowledge_search(
embedding = response.embeddings[0].values embedding = response.embeddings[0].values
t_embed = time.perf_counter() t_embed = time.perf_counter()
# Build restricts for source folder filtering if provided
restricts = None
if source_folders:
restricts = [
{
"namespace": "source_folder",
"allow": source_folders,
}
]
logger.info(f"Filtering by source_folders: {source_folders}")
else:
logger.info("No filtering - searching all folders")
search_results = await app.vector_search.async_run_query( search_results = await app.vector_search.async_run_query(
deployed_index_id=app.settings.deployed_index_id, deployed_index_id=app.settings.deployed_index_id,
query=embedding, query=embedding,
limit=app.settings.search_limit, limit=app.settings.search_limit,
restricts=restricts, source=source,
) )
t_search = time.perf_counter() t_search = time.perf_counter()
# Log raw results from Vertex AI before similarity filtering
logger.info(f"Raw results from Vertex AI (before similarity filter): {len(search_results)} chunks")
logger.info(f"Raw chunk IDs: {[s['id'] for s in search_results]}")
# Apply similarity filtering # Apply similarity filtering
if search_results: if search_results:
max_sim = max(r["distance"] for r in search_results) max_sim = max(r["distance"] for r in search_results)
@@ -458,10 +452,11 @@ async def knowledge_search(
] ]
logger.info( logger.info(
"knowledge_search timing: embedding=%sms, vector_search=%sms, total=%sms, chunks=%s", "knowledge_search timing: embedding=%sms, vector_search=%sms, total=%sms, source_filter=%s, chunks=%s",
round((t_embed - t0) * 1000, 1), round((t_embed - t0) * 1000, 1),
round((t_search - t_embed) * 1000, 1), round((t_search - t_embed) * 1000, 1),
round((t_search - t0) * 1000, 1), round((t_search - t0) * 1000, 1),
source.value if source is not None else None,
[s["id"] for s in search_results], [s["id"] for s in search_results],
) )