diff --git a/README.md b/README.md index 0192487..ed892cd 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,24 @@ An MCP (Model Context Protocol) server that exposes a `knowledge_search` tool fo 1. A natural-language query is embedded using a Gemini embedding model. 2. The embedding is sent to a Vertex AI Matching Engine index endpoint to find nearest neighbors. -3. The matched document contents are fetched from a GCS bucket and returned to the caller. +3. Optional filters (restricts) can be applied to search only specific source folders. +4. The matched document contents are fetched from a GCS bucket and returned to the caller. + +## Filtering by Source Folder + +The `knowledge_search` tool supports filtering results by source folder: + +```python +# Search all folders +knowledge_search(query="what is a savings account?") + +# Search only in specific folders +knowledge_search( + query="what is a savings account?", + source_folders=["Educacion Financiera", "Productos y Servicios"] +) +``` + ## Prerequisites diff --git a/agent.py b/agent.py index 66d8e46..74ef3ad 100644 --- a/agent.py +++ b/agent.py @@ -57,9 +57,20 @@ async def async_main() -> None: model="gemini-2.0-flash", name="knowledge_agent", instruction=( - "You are a helpful assistant with access to a knowledge base. " - "Use the knowledge_search tool to find relevant information " - "when the user asks questions. Summarize the results clearly." + "You are a helpful assistant with access to a knowledge base organized by folders. " + "Use the knowledge_search tool to find relevant information when the user asks questions.\n\n" + "Available folders in the knowledge base:\n" + "- 'Educacion Financiera': Educational content about finance, savings, investments, financial concepts\n" + "- 'Funcionalidades de la App Movil': Mobile app features, functionality, usage instructions\n" + "- 'Productos y Servicios': Bank products and services, accounts, procedures\n\n" + "IMPORTANT: When the user asks about a specific topic, analyze which folders are relevant " + "and use the source_folders parameter to filter results for more precise answers.\n\n" + "Examples:\n" + "- User asks about 'cuenta de ahorros' → Use source_folders=['Educacion Financiera', 'Productos y Servicios']\n" + "- User asks about 'cómo usar la app móvil' → Use source_folders=['Funcionalidades de App Movil']\n" + "- User asks about 'transferencias en la app' → Use source_folders=['Funcionalidades de App Movil', 'Productos y Servicios']\n" + "- User asks general question → Don't use source_folders (search all)\n\n" + "Summarize the results clearly in Spanish." ), tools=[toolset], ) diff --git a/main.py b/main.py index dfb5d96..42f7fc8 100644 --- a/main.py +++ b/main.py @@ -204,6 +204,7 @@ class GoogleCloudVectorSearch: deployed_index_id: str, query: Sequence[float], limit: int, + restricts: list[dict[str, list[str]]] | None = None, ) -> list[SearchResult]: """Run an async similarity search via the REST API. @@ -229,14 +230,18 @@ class GoogleCloudVectorSearch: f"/locations/{self.location}" f"/indexEndpoints/{endpoint_id}:findNeighbors" ) + query_payload = { + "datapoint": {"feature_vector": list(query)}, + "neighbor_count": limit, + } + + # Add restricts if provided + if restricts: + query_payload["restricts"] = restricts + payload = { "deployed_index_id": deployed_index_id, - "queries": [ - { - "datapoint": {"feature_vector": list(query)}, - "neighbor_count": limit, - }, - ], + "queries": [query_payload], } headers = await self._async_get_auth_headers() @@ -385,12 +390,16 @@ mcp = FastMCP( async def knowledge_search( query: str, ctx: Context, + source_folders: list[str] | None = None, ) -> str: """Search a knowledge base using a natural-language query. Args: query: The text query to search for. ctx: MCP request context (injected automatically). + source_folders: Optional list of source folder paths to filter results. + If provided, only documents from these folders will be returned. + Example: ["Educacion Financiera", "Productos y Servicios"] Returns: A formatted string containing matched documents with id and content. @@ -413,13 +422,31 @@ async def knowledge_search( embedding = response.embeddings[0].values t_embed = time.perf_counter() + # Build restricts for source folder filtering if provided + restricts = None + if source_folders: + restricts = [ + { + "namespace": "source_folder", + "allow": source_folders, + } + ] + logger.info(f"Filtering by source_folders: {source_folders}") + else: + logger.info("No filtering - searching all folders") + search_results = await app.vector_search.async_run_query( deployed_index_id=app.settings.deployed_index_id, query=embedding, limit=app.settings.search_limit, + restricts=restricts, ) t_search = time.perf_counter() + # Log raw results from Vertex AI before similarity filtering + logger.info(f"Raw results from Vertex AI (before similarity filter): {len(search_results)} chunks") + logger.info(f"Raw chunk IDs: {[s['id'] for s in search_results]}") + # Apply similarity filtering if search_results: max_sim = max(r["distance"] for r in search_results)