Compare commits

1 Commits

Author SHA1 Message Date
72808b1475 Add filter with metadata using restricts 2026-02-24 03:05:50 +00:00
6 changed files with 142 additions and 135 deletions

View File

@@ -6,7 +6,24 @@ An MCP (Model Context Protocol) server that exposes a `knowledge_search` tool fo
1. A natural-language query is embedded using a Gemini embedding model.
2. The embedding is sent to a Vertex AI Matching Engine index endpoint to find nearest neighbors.
3. The matched document contents are fetched from a GCS bucket and returned to the caller.
3. Optional filters (restricts) can be applied to search only specific source folders.
4. The matched document contents are fetched from a GCS bucket and returned to the caller.
## Filtering by Source Folder
The `knowledge_search` tool supports filtering results by source folder:
```python
# Search all folders
knowledge_search(query="what is a savings account?")
# Search only in specific folders
knowledge_search(
query="what is a savings account?",
source_folders=["Educacion Financiera", "Productos y Servicios"]
)
```
## Prerequisites

View File

@@ -57,9 +57,20 @@ async def async_main() -> None:
model="gemini-2.0-flash",
name="knowledge_agent",
instruction=(
"You are a helpful assistant with access to a knowledge base. "
"Use the knowledge_search tool to find relevant information "
"when the user asks questions. Summarize the results clearly."
"You are a helpful assistant with access to a knowledge base organized by folders. "
"Use the knowledge_search tool to find relevant information when the user asks questions.\n\n"
"Available folders in the knowledge base:\n"
"- 'Educacion Financiera': Educational content about finance, savings, investments, financial concepts\n"
"- 'Funcionalidades de la App Movil': Mobile app features, functionality, usage instructions\n"
"- 'Productos y Servicios': Bank products and services, accounts, procedures\n\n"
"IMPORTANT: When the user asks about a specific topic, analyze which folders are relevant "
"and use the source_folders parameter to filter results for more precise answers.\n\n"
"Examples:\n"
"- User asks about 'cuenta de ahorros' → Use source_folders=['Educacion Financiera', 'Productos y Servicios']\n"
"- User asks about 'cómo usar la app móvil' → Use source_folders=['Funcionalidades de App Movil']\n"
"- User asks about 'transferencias en la app' → Use source_folders=['Funcionalidades de App Movil', 'Productos y Servicios']\n"
"- User asks general question → Don't use source_folders (search all)\n\n"
"Summarize the results clearly in Spanish."
),
tools=[toolset],
)

133
main.py
View File

@@ -1,8 +1,11 @@
# ruff: noqa: INP001
"""Async helpers for querying Vertex AI vector search via MCP."""
import argparse
import asyncio
import io
import logging
import os
from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager
from dataclasses import dataclass
@@ -14,8 +17,9 @@ from gcloud.aio.storage import Storage
from google import genai
from google.genai import types as genai_types
from mcp.server.fastmcp import Context, FastMCP
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, YamlConfigSettingsSource
from .utils import Settings, _args, log_structured_entry
logger = logging.getLogger(__name__)
HTTP_TOO_MANY_REQUESTS = 429
HTTP_SERVER_ERROR = 500
@@ -87,9 +91,12 @@ class GoogleCloudFileStorage:
file_stream.name = file_name
except TimeoutError as exc:
last_exception = exc
log_structured_entry(
f"Timeout downloading gs://{self.bucket_name}/{file_name} (attempt {attempt + 1}/{max_retries})"
"WARNING"
logger.warning(
"Timeout downloading gs://%s/%s (attempt %d/%d)",
self.bucket_name,
file_name,
attempt + 1,
max_retries,
)
except aiohttp.ClientResponseError as exc:
last_exception = exc
@@ -97,9 +104,13 @@ class GoogleCloudFileStorage:
exc.status == HTTP_TOO_MANY_REQUESTS
or exc.status >= HTTP_SERVER_ERROR
):
log_structured_entry(
f"HTTP {exc.status} downloading gs://{self.bucket_name}/{file_name} (attempt {attempt + 1}/{max_retries})"
"WARNING"
logger.warning(
"HTTP %d downloading gs://%s/%s (attempt %d/%d)",
exc.status,
self.bucket_name,
file_name,
attempt + 1,
max_retries,
)
else:
raise
@@ -193,6 +204,7 @@ class GoogleCloudVectorSearch:
deployed_index_id: str,
query: Sequence[float],
limit: int,
restricts: list[dict[str, list[str]]] | None = None,
) -> list[SearchResult]:
"""Run an async similarity search via the REST API.
@@ -218,14 +230,18 @@ class GoogleCloudVectorSearch:
f"/locations/{self.location}"
f"/indexEndpoints/{endpoint_id}:findNeighbors"
)
query_payload = {
"datapoint": {"feature_vector": list(query)},
"neighbor_count": limit,
}
# Add restricts if provided
if restricts:
query_payload["restricts"] = restricts
payload = {
"deployed_index_id": deployed_index_id,
"queries": [
{
"datapoint": {"feature_vector": list(query)},
"neighbor_count": limit,
},
],
"queries": [query_payload],
}
headers = await self._async_get_auth_headers()
@@ -272,6 +288,58 @@ class GoogleCloudVectorSearch:
# ---------------------------------------------------------------------------
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"--transport",
choices=["stdio", "sse"],
default="stdio",
)
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=8080)
parser.add_argument(
"--config",
default=os.environ.get("CONFIG_FILE", "config.yaml"),
)
return parser.parse_args()
_args = _parse_args()
class Settings(BaseSettings):
"""Server configuration populated from env vars and a YAML config file."""
model_config = {"env_file": ".env", "yaml_file": _args.config}
project_id: str
location: str
bucket: str
index_name: str
deployed_index_id: str
endpoint_name: str
endpoint_domain: str
embedding_model: str = "gemini-embedding-001"
search_limit: int = 10
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (
init_settings,
env_settings,
dotenv_settings,
YamlConfigSettingsSource(settings_cls),
file_secret_settings,
)
@dataclass
class AppContext:
"""Shared resources initialised once at server startup."""
@@ -322,12 +390,16 @@ mcp = FastMCP(
async def knowledge_search(
query: str,
ctx: Context,
source_folders: list[str] | None = None,
) -> str:
"""Search a knowledge base using a natural-language query.
Args:
query: The text query to search for.
ctx: MCP request context (injected automatically).
source_folders: Optional list of source folder paths to filter results.
If provided, only documents from these folders will be returned.
Example: ["Educacion Financiera", "Productos y Servicios"]
Returns:
A formatted string containing matched documents with id and content.
@@ -350,13 +422,31 @@ async def knowledge_search(
embedding = response.embeddings[0].values
t_embed = time.perf_counter()
# Build restricts for source folder filtering if provided
restricts = None
if source_folders:
restricts = [
{
"namespace": "source_folder",
"allow": source_folders,
}
]
logger.info(f"Filtering by source_folders: {source_folders}")
else:
logger.info("No filtering - searching all folders")
search_results = await app.vector_search.async_run_query(
deployed_index_id=app.settings.deployed_index_id,
query=embedding,
limit=app.settings.search_limit,
restricts=restricts,
)
t_search = time.perf_counter()
# Log raw results from Vertex AI before similarity filtering
logger.info(f"Raw results from Vertex AI (before similarity filter): {len(search_results)} chunks")
logger.info(f"Raw chunk IDs: {[s['id'] for s in search_results]}")
# Apply similarity filtering
if search_results:
max_sim = max(r["distance"] for r in search_results)
@@ -366,16 +456,13 @@ async def knowledge_search(
for s in search_results
if s["distance"] > cutoff and s["distance"] > min_sim
]
log_structured_entry(
"knowledge_search timing",
"INFO",
{
"embedding": f"{round((t_embed - t0) * 1000, 1)}ms",
"vector_serach": f"{round((t_search - t_embed) * 1000, 1)}ms",
"total": f"{round((t_search - t0) * 1000, 1)}ms",
"chunks": {[s["id"] for s in search_results]}
}
logger.info(
"knowledge_search timing: embedding=%sms, vector_search=%sms, total=%sms, chunks=%s",
round((t_embed - t0) * 1000, 1),
round((t_search - t_embed) * 1000, 1),
round((t_search - t0) * 1000, 1),
[s["id"] for s in search_results],
)
# Format results as XML-like documents

View File

@@ -1,4 +0,0 @@
from .config import Settings, _args
from .logging_setup import log_structured_entry
__all__ = ['Settings', '_args', 'log_structured_entry']

View File

@@ -1,54 +0,0 @@
import os
import argparse
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, YamlConfigSettingsSource
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"--transport",
choices=["stdio", "sse"],
default="stdio",
)
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=8080)
parser.add_argument(
"--config",
default=os.environ.get("CONFIG_FILE", "config.yaml"),
)
return parser.parse_args()
_args = _parse_args()
class Settings(BaseSettings):
"""Server configuration populated from env vars and a YAML config file."""
model_config = {"env_file": ".env", "yaml_file": _args.config}
project_id: str
location: str
bucket: str
index_name: str
deployed_index_id: str
endpoint_name: str
endpoint_domain: str
embedding_model: str = "gemini-embedding-001"
search_limit: int = 10
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (
init_settings,
env_settings,
dotenv_settings,
YamlConfigSettingsSource(settings_cls),
file_secret_settings,
)

View File

@@ -1,50 +0,0 @@
"""
Centralized Cloud Logging setup.
Uses CloudLoggingHandler (background thread) so logging does not add latency
"""
import logging
from typing import Optional, Dict, Literal
import google.cloud.logging
from google.cloud.logging.handlers import CloudLoggingHandler
from .config import Settings
def _setup_logger() -> logging.Logger:
"""Create or return the singleton evaluation logger."""
log_name = "va_agent-evaluation-logs"
logger = logging.getLogger(log_name)
cfg = Settings.model_validate({})
if any(isinstance(h, CloudLoggingHandler) for h in logger.handlers):
return logger
try:
client = google.cloud.logging.Client(project=cfg.project_id)
handler = CloudLoggingHandler(client, name=log_name) # async transport
logger.addHandler(handler)
logger.setLevel(logging.INFO)
except Exception as e:
# Fallback to console if Cloud Logging is unavailable (local dev)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(log_name)
logger.warning("Cloud Logging setup failed; using console. Error: %s", e)
return logger
_eval_log = _setup_logger()
def log_structured_entry(message: str, severity: Literal["INFO", "WARNING", "ERROR"], custom_log: Optional[Dict] = None) -> None:
"""
Emit a JSON-structured log row.
Args:
message: Short label for the row (e.g., "Final agent turn").
severity: "INFO" | "WARNING" | "ERROR" etc.
custom_log: A dict with your structured payload.
"""
level = getattr(logging, severity.upper(), logging.INFO)
_eval_log.log(level, message, extra={"json_fields": {"message": message, "custom": custom_log or {}}})