Compare commits

1 Commits

Author SHA1 Message Date
72808b1475 Add filter with metadata using restricts 2026-02-24 03:05:50 +00:00
3 changed files with 66 additions and 33 deletions

View File

@@ -6,7 +6,24 @@ An MCP (Model Context Protocol) server that exposes a `knowledge_search` tool fo
1. A natural-language query is embedded using a Gemini embedding model.
2. The embedding is sent to a Vertex AI Matching Engine index endpoint to find nearest neighbors.
3. Optional filters (restricts) can be applied to search only specific source folders.
4. The matched document contents are fetched from a GCS bucket and returned to the caller.
## Filtering by Source Folder
The `knowledge_search` tool supports filtering results by source folder:
```python
# Search all folders
knowledge_search(query="what is a savings account?")
# Search only in specific folders
knowledge_search(
query="what is a savings account?",
source_folders=["Educacion Financiera", "Productos y Servicios"]
)
```
## Prerequisites

View File

@@ -57,9 +57,20 @@ async def async_main() -> None:
model="gemini-2.0-flash", model="gemini-2.0-flash",
name="knowledge_agent", name="knowledge_agent",
instruction=( instruction=(
"You are a helpful assistant with access to a knowledge base. " "You are a helpful assistant with access to a knowledge base organized by folders. "
"Use the knowledge_search tool to find relevant information " "Use the knowledge_search tool to find relevant information when the user asks questions.\n\n"
"when the user asks questions. Summarize the results clearly." "Available folders in the knowledge base:\n"
"- 'Educacion Financiera': Educational content about finance, savings, investments, financial concepts\n"
"- 'Funcionalidades de la App Movil': Mobile app features, functionality, usage instructions\n"
"- 'Productos y Servicios': Bank products and services, accounts, procedures\n\n"
"IMPORTANT: When the user asks about a specific topic, analyze which folders are relevant "
"and use the source_folders parameter to filter results for more precise answers.\n\n"
"Examples:\n"
"- User asks about 'cuenta de ahorros' → Use source_folders=['Educacion Financiera', 'Productos y Servicios']\n"
"- User asks about 'cómo usar la app móvil' → Use source_folders=['Funcionalidades de la App Movil']\n"
"- User asks about 'transferencias en la app' → Use source_folders=['Funcionalidades de la App Movil', 'Productos y Servicios']\n"
"- User asks general question → Don't use source_folders (search all)\n\n"
"Summarize the results clearly in Spanish."
), ),
tools=[toolset], tools=[toolset],
) )

63
main.py
View File

@@ -9,7 +9,6 @@ import os
from collections.abc import AsyncIterator, Sequence from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum
from typing import BinaryIO, TypedDict from typing import BinaryIO, TypedDict
import aiohttp import aiohttp
@@ -26,14 +25,6 @@ HTTP_TOO_MANY_REQUESTS = 429
HTTP_SERVER_ERROR = 500 HTTP_SERVER_ERROR = 500
class SourceNamespace(str, Enum):
"""Allowed values for the 'source' namespace filter."""
EDUCACION_FINANCIERA = "Educacion Financiera"
PRODUCTOS_Y_SERVICIOS = "Productos y Servicios"
FUNCIONALIDADES_APP_MOVIL = "Funcionalidades de la App Movil"
class GoogleCloudFileStorage: class GoogleCloudFileStorage:
"""Cache-aware helper for downloading files from Google Cloud Storage.""" """Cache-aware helper for downloading files from Google Cloud Storage."""
@@ -213,7 +204,7 @@ class GoogleCloudVectorSearch:
deployed_index_id: str, deployed_index_id: str,
query: Sequence[float], query: Sequence[float],
limit: int, limit: int,
source: SourceNamespace | None = None, restricts: list[dict[str, list[str]]] | None = None,
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Run an async similarity search via the REST API. """Run an async similarity search via the REST API.
@@ -221,7 +212,6 @@ class GoogleCloudVectorSearch:
deployed_index_id: The ID of the deployed index. deployed_index_id: The ID of the deployed index.
query: The embedding vector for the search query. query: The embedding vector for the search query.
limit: Maximum number of nearest neighbors to return. limit: Maximum number of nearest neighbors to return.
source: Optional namespace filter to restrict results by source.
Returns: Returns:
A list of matched items with id, distance, and content. A list of matched items with id, distance, and content.
@@ -240,19 +230,18 @@ class GoogleCloudVectorSearch:
f"/locations/{self.location}" f"/locations/{self.location}"
f"/indexEndpoints/{endpoint_id}:findNeighbors" f"/indexEndpoints/{endpoint_id}:findNeighbors"
) )
datapoint: dict = {"feature_vector": list(query)} query_payload = {
if source is not None: "datapoint": {"feature_vector": list(query)},
datapoint["restricts"] = [ "neighbor_count": limit,
{"namespace": "source", "allow_list": [source.value]}, }
]
# Add restricts if provided
if restricts:
query_payload["restricts"] = restricts
payload = { payload = {
"deployed_index_id": deployed_index_id, "deployed_index_id": deployed_index_id,
"queries": [ "queries": [query_payload],
{
"datapoint": datapoint,
"neighbor_count": limit,
},
],
} }
headers = await self._async_get_auth_headers() headers = await self._async_get_auth_headers()
@@ -401,16 +390,16 @@ mcp = FastMCP(
async def knowledge_search( async def knowledge_search(
query: str, query: str,
ctx: Context, ctx: Context,
source: SourceNamespace | None = None, source_folders: list[str] | None = None,
) -> str: ) -> str:
"""Search a knowledge base using a natural-language query. """Search a knowledge base using a natural-language query.
Args: Args:
query: The text query to search for. query: The text query to search for.
ctx: MCP request context (injected automatically). ctx: MCP request context (injected automatically).
source: Optional filter to restrict results by source. source_folders: Optional list of source folder paths to filter results.
Allowed values: 'Educacion Financiera', If provided, only documents from these folders will be returned.
'Productos y Servicios', 'Funcionalidades de la App Movil'. Example: ["Educacion Financiera", "Productos y Servicios"]
Returns: Returns:
A formatted string containing matched documents with id and content. A formatted string containing matched documents with id and content.
@@ -433,14 +422,31 @@ async def knowledge_search(
embedding = response.embeddings[0].values embedding = response.embeddings[0].values
t_embed = time.perf_counter() t_embed = time.perf_counter()
# Build restricts for source folder filtering if provided
restricts = None
if source_folders:
restricts = [
{
"namespace": "source_folder",
"allow_list": source_folders,  # NOTE(review): the Vertex AI Restriction field is 'allow_list' (JSON: allowList); the previous code used 'allow_list' — '"allow"' is not a valid field
}
]
logger.info(f"Filtering by source_folders: {source_folders}")
else:
logger.info("No filtering - searching all folders")
search_results = await app.vector_search.async_run_query( search_results = await app.vector_search.async_run_query(
deployed_index_id=app.settings.deployed_index_id, deployed_index_id=app.settings.deployed_index_id,
query=embedding, query=embedding,
limit=app.settings.search_limit, limit=app.settings.search_limit,
source=source, restricts=restricts,
) )
t_search = time.perf_counter() t_search = time.perf_counter()
# Log raw results from Vertex AI before similarity filtering
logger.info(f"Raw results from Vertex AI (before similarity filter): {len(search_results)} chunks")
logger.info(f"Raw chunk IDs: {[s['id'] for s in search_results]}")
# Apply similarity filtering # Apply similarity filtering
if search_results: if search_results:
max_sim = max(r["distance"] for r in search_results) max_sim = max(r["distance"] for r in search_results)
@@ -452,11 +458,10 @@ async def knowledge_search(
] ]
logger.info( logger.info(
"knowledge_search timing: embedding=%sms, vector_search=%sms, total=%sms, source_filter=%s, chunks=%s", "knowledge_search timing: embedding=%sms, vector_search=%sms, total=%sms, chunks=%s",
round((t_embed - t0) * 1000, 1), round((t_embed - t0) * 1000, 1),
round((t_search - t_embed) * 1000, 1), round((t_search - t_embed) * 1000, 1),
round((t_search - t0) * 1000, 1), round((t_search - t0) * 1000, 1),
source.value if source is not None else None,
[s["id"] for s in search_results], [s["id"] for s in search_results],
) )