# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """A retrieval tool that uses Vertex AI Vector Search (not RAG Engine).""" from __future__ import annotations import logging from typing import Any from typing import TYPE_CHECKING from google.adk.tools.tool_context import ToolContext from typing_extensions import override from .vertex_ai import GoogleCloudVectorSearch from google.adk.tools.retrieval.base_retrieval_tool import BaseRetrievalTool if TYPE_CHECKING: from .config_helper import VertexAIEmbedder logger = logging.getLogger('google_adk.' + __name__) class VectorSearchTool(BaseRetrievalTool): """A retrieval tool using Vertex AI Vector Search (not RAG Engine). This tool uses GoogleCloudVectorSearch to query a vector index directly, which is useful when Vertex AI RAG Engine is not available in your GCP project. """ def __init__( self, *, name: str, description: str, embedder: VertexAIEmbedder, project_id: str, location: str, bucket: str, index_name: str, index_endpoint: str, index_deployed_id: str, similarity_top_k: int = 5, min_similarity_threshold: float = 0.6, relative_threshold_factor: float = 0.9, ): """Initialize the VectorSearchTool. Args: name: Tool name for function declaration description: Tool description for LLM embedder: Embedder instance for query embedding project_id: GCP project ID location: GCP location (e.g., 'us-central1') bucket: GCS bucket for content storage index_name: Vector search index name index_endpoint: Resource name of index endpoint index_deployed_id: Deployed index ID similarity_top_k: Number of results to retrieve (default: 5) min_similarity_threshold: Minimum similarity score 0.0-1.0 (default: 0.6) relative_threshold_factor: Factor of max similarity for dynamic filtering (default: 0.9) """ super().__init__(name=name, description=description) self.embedder = embedder self.index_endpoint = index_endpoint self.index_deployed_id = index_deployed_id self.similarity_top_k = similarity_top_k self.min_similarity_threshold = min_similarity_threshold self.relative_threshold_factor = relative_threshold_factor # Initialize vector search (endpoint loaded lazily on first use) self.vector_search = GoogleCloudVectorSearch( project_id=project_id, location=location, bucket=bucket, index_name=index_name, ) self._endpoint_loaded = False logger.info( 'VectorSearchTool initialized with index=%s, deployed_id=%s', index_name, index_deployed_id, ) @override async def run_async( self, *, args: dict[str, Any], tool_context: ToolContext, ) -> Any: """Execute vector search with the user's query. Args: args: Dictionary containing 'query' key tool_context: Tool execution context Returns: Formatted search results as XML-like documents or error message """ query = args['query'] logger.debug('VectorSearchTool query: %s', query) try: # Load index endpoint on first use (lazy loading) if not self._endpoint_loaded: self.vector_search.load_index_endpoint(self.index_endpoint) self._endpoint_loaded = True logger.info('Index endpoint loaded successfully') # Embed the query using the configured embedder embedding_result = await self.embedder.embed_query(query) query_embedding = list(embedding_result.embeddings[0]) # Run vector search search_results = await self.vector_search.async_run_query( deployed_index_id=self.index_deployed_id, query=query_embedding, limit=self.similarity_top_k, ) # Apply similarity filtering (dual threshold approach) if search_results: # Dynamic threshold based on max similarity max_similarity = max(r['distance'] for r in search_results) dynamic_cutoff = max_similarity * self.relative_threshold_factor # Filter by both absolute and relative thresholds search_results = [ result for result in search_results if ( result['distance'] > dynamic_cutoff and result['distance'] > self.min_similarity_threshold ) ] logger.debug( 'VectorSearchTool results: %d documents after filtering', len(search_results), ) # Format results if not search_results: return ( f"No matching documents found for query: '{query}' " f'(min_threshold={self.min_similarity_threshold})' ) # Format as XML-like documents (matching pydantic_ai pattern) formatted_results = [ f'\n' f'{result["content"]}\n' f'' for i, result in enumerate(search_results, start=1) ] return '\n'.join(formatted_results) except Exception as e: logger.error('VectorSearchTool error: %s', e, exc_info=True) return f'Error during vector search: {str(e)}'