Files
agent/rag_agent/vector_search_tool.py

177 lines
5.6 KiB
Python

# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A retrieval tool that uses Vertex AI Vector Search (not RAG Engine)."""
from __future__ import annotations
import logging
from typing import Any
from typing import TYPE_CHECKING
from google.adk.tools.tool_context import ToolContext
from typing_extensions import override
from .vertex_ai import GoogleCloudVectorSearch
from google.adk.tools.retrieval.base_retrieval_tool import BaseRetrievalTool
if TYPE_CHECKING:
from .config_helper import VertexAIEmbedder
logger = logging.getLogger('google_adk.' + __name__)
class VectorSearchTool(BaseRetrievalTool):
    """A retrieval tool using Vertex AI Vector Search (not RAG Engine).

    This tool uses GoogleCloudVectorSearch to query a vector index directly,
    which is useful when Vertex AI RAG Engine is not available in your GCP
    project.
    """

    def __init__(
        self,
        *,
        name: str,
        description: str,
        embedder: VertexAIEmbedder,
        project_id: str,
        location: str,
        bucket: str,
        index_name: str,
        index_endpoint: str,
        index_deployed_id: str,
        similarity_top_k: int = 5,
        min_similarity_threshold: float = 0.6,
        relative_threshold_factor: float = 0.9,
    ):
        """Initialize the VectorSearchTool.

        Args:
            name: Tool name for function declaration.
            description: Tool description for the LLM.
            embedder: Embedder instance used to embed the query text.
            project_id: GCP project ID.
            location: GCP location (e.g., 'us-central1').
            bucket: GCS bucket for content storage.
            index_name: Vector search index name.
            index_endpoint: Resource name of the index endpoint.
            index_deployed_id: Deployed index ID.
            similarity_top_k: Number of results to retrieve (default: 5).
            min_similarity_threshold: Absolute minimum similarity score a
                result must exceed, in 0.0-1.0 (default: 0.6).
            relative_threshold_factor: Results must also exceed this fraction
                of the batch's best similarity; enables dynamic filtering
                (default: 0.9).
        """
        super().__init__(name=name, description=description)
        self.embedder = embedder
        self.index_endpoint = index_endpoint
        self.index_deployed_id = index_deployed_id
        self.similarity_top_k = similarity_top_k
        self.min_similarity_threshold = min_similarity_threshold
        self.relative_threshold_factor = relative_threshold_factor
        # Initialize vector search; the endpoint itself is loaded lazily on
        # first use (see run_async), so construction performs no network I/O
        # against the endpoint.
        self.vector_search = GoogleCloudVectorSearch(
            project_id=project_id,
            location=location,
            bucket=bucket,
            index_name=index_name,
        )
        self._endpoint_loaded = False
        logger.info(
            'VectorSearchTool initialized with index=%s, deployed_id=%s',
            index_name,
            index_deployed_id,
        )

    def _filter_results(
        self, search_results: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """Apply the dual-threshold similarity filter to raw search results.

        A result is kept only if its similarity score (stored under the
        'distance' key, higher is better) strictly exceeds BOTH the absolute
        `min_similarity_threshold` and a dynamic cutoff computed as
        `relative_threshold_factor` times the best score in the batch.
        """
        if not search_results:
            return search_results
        # Dynamic threshold based on the maximum similarity in this batch.
        max_similarity = max(r['distance'] for r in search_results)
        dynamic_cutoff = max_similarity * self.relative_threshold_factor
        return [
            result
            for result in search_results
            if result['distance'] > dynamic_cutoff
            and result['distance'] > self.min_similarity_threshold
        ]

    def _format_results(self, search_results: list[dict[str, Any]]) -> str:
        """Format results as XML-like documents (matching pydantic_ai pattern)."""
        return '\n'.join(
            f'<document {i} name={result["id"]}>\n'
            f'{result["content"]}\n'
            f'</document {i}>'
            for i, result in enumerate(search_results, start=1)
        )

    @override
    async def run_async(
        self,
        *,
        args: dict[str, Any],
        tool_context: ToolContext,
    ) -> Any:
        """Execute vector search with the user's query.

        Args:
            args: Dictionary containing a 'query' key with the search text.
            tool_context: Tool execution context (required by the
                BaseRetrievalTool interface; not used directly here).

        Returns:
            Formatted search results as XML-like documents, or a
            human-readable message when no documents match or an error occurs.
        """
        # Fix: the original read args['query'] outside the try block, so a
        # missing 'query' key raised an uncaught KeyError instead of taking
        # the tool's error-string return path. Validate explicitly instead.
        query = args.get('query')
        if query is None:
            return "Error during vector search: missing required argument 'query'"
        logger.debug('VectorSearchTool query: %s', query)
        try:
            # Load index endpoint on first use (lazy loading). The flag is set
            # only after a successful load, so a failed load is retried on the
            # next invocation.
            if not self._endpoint_loaded:
                self.vector_search.load_index_endpoint(self.index_endpoint)
                self._endpoint_loaded = True
                logger.info('Index endpoint loaded successfully')

            # Embed the query using the configured embedder.
            embedding_result = await self.embedder.embed_query(query)
            query_embedding = list(embedding_result.embeddings[0])

            # Run vector search against the deployed index.
            search_results = await self.vector_search.async_run_query(
                deployed_index_id=self.index_deployed_id,
                query=query_embedding,
                limit=self.similarity_top_k,
            )

            # Apply similarity filtering (dual threshold approach).
            search_results = self._filter_results(search_results)
            logger.debug(
                'VectorSearchTool results: %d documents after filtering',
                len(search_results),
            )

            if not search_results:
                return (
                    f"No matching documents found for query: '{query}' "
                    f'(min_threshold={self.min_similarity_threshold})'
                )
            return self._format_results(search_results)
        except Exception as e:
            logger.error('VectorSearchTool error: %s', e, exc_info=True)
            return f'Error during vector search: {e}'