dev: add ADK agent with vector search tool and Google Cloud file storage implementation
This commit is contained in:
176
rag_agent/vector_search_tool.py
Normal file
176
rag_agent/vector_search_tool.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A retrieval tool that uses Vertex AI Vector Search (not RAG Engine)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from typing_extensions import override
|
||||
|
||||
from .vertex_ai import GoogleCloudVectorSearch
|
||||
|
||||
from google.adk.tools.retrieval.base_retrieval_tool import BaseRetrievalTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .config_helper import VertexAIEmbedder
|
||||
|
||||
logger = logging.getLogger('google_adk.' + __name__)
|
||||
|
||||
|
||||
class VectorSearchTool(BaseRetrievalTool):
|
||||
"""A retrieval tool using Vertex AI Vector Search (not RAG Engine).
|
||||
|
||||
This tool uses GoogleCloudVectorSearch to query a vector index directly,
|
||||
which is useful when Vertex AI RAG Engine is not available in your GCP project.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
description: str,
|
||||
embedder: VertexAIEmbedder,
|
||||
project_id: str,
|
||||
location: str,
|
||||
bucket: str,
|
||||
index_name: str,
|
||||
index_endpoint: str,
|
||||
index_deployed_id: str,
|
||||
similarity_top_k: int = 5,
|
||||
min_similarity_threshold: float = 0.6,
|
||||
relative_threshold_factor: float = 0.9,
|
||||
):
|
||||
"""Initialize the VectorSearchTool.
|
||||
|
||||
Args:
|
||||
name: Tool name for function declaration
|
||||
description: Tool description for LLM
|
||||
embedder: Embedder instance for query embedding
|
||||
project_id: GCP project ID
|
||||
location: GCP location (e.g., 'us-central1')
|
||||
bucket: GCS bucket for content storage
|
||||
index_name: Vector search index name
|
||||
index_endpoint: Resource name of index endpoint
|
||||
index_deployed_id: Deployed index ID
|
||||
similarity_top_k: Number of results to retrieve (default: 5)
|
||||
min_similarity_threshold: Minimum similarity score 0.0-1.0 (default: 0.6)
|
||||
relative_threshold_factor: Factor of max similarity for dynamic filtering (default: 0.9)
|
||||
"""
|
||||
super().__init__(name=name, description=description)
|
||||
|
||||
self.embedder = embedder
|
||||
self.index_endpoint = index_endpoint
|
||||
self.index_deployed_id = index_deployed_id
|
||||
self.similarity_top_k = similarity_top_k
|
||||
self.min_similarity_threshold = min_similarity_threshold
|
||||
self.relative_threshold_factor = relative_threshold_factor
|
||||
|
||||
# Initialize vector search (endpoint loaded lazily on first use)
|
||||
self.vector_search = GoogleCloudVectorSearch(
|
||||
project_id=project_id,
|
||||
location=location,
|
||||
bucket=bucket,
|
||||
index_name=index_name,
|
||||
)
|
||||
self._endpoint_loaded = False
|
||||
|
||||
logger.info(
|
||||
'VectorSearchTool initialized with index=%s, deployed_id=%s',
|
||||
index_name,
|
||||
index_deployed_id,
|
||||
)
|
||||
|
||||
@override
|
||||
async def run_async(
|
||||
self,
|
||||
*,
|
||||
args: dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
) -> Any:
|
||||
"""Execute vector search with the user's query.
|
||||
|
||||
Args:
|
||||
args: Dictionary containing 'query' key
|
||||
tool_context: Tool execution context
|
||||
|
||||
Returns:
|
||||
Formatted search results as XML-like documents or error message
|
||||
"""
|
||||
query = args['query']
|
||||
logger.debug('VectorSearchTool query: %s', query)
|
||||
|
||||
try:
|
||||
# Load index endpoint on first use (lazy loading)
|
||||
if not self._endpoint_loaded:
|
||||
self.vector_search.load_index_endpoint(self.index_endpoint)
|
||||
self._endpoint_loaded = True
|
||||
logger.info('Index endpoint loaded successfully')
|
||||
|
||||
# Embed the query using the configured embedder
|
||||
embedding_result = await self.embedder.embed_query(query)
|
||||
query_embedding = list(embedding_result.embeddings[0])
|
||||
|
||||
# Run vector search
|
||||
search_results = await self.vector_search.async_run_query(
|
||||
deployed_index_id=self.index_deployed_id,
|
||||
query=query_embedding,
|
||||
limit=self.similarity_top_k,
|
||||
)
|
||||
|
||||
# Apply similarity filtering (dual threshold approach)
|
||||
if search_results:
|
||||
# Dynamic threshold based on max similarity
|
||||
max_similarity = max(r['distance'] for r in search_results)
|
||||
dynamic_cutoff = max_similarity * self.relative_threshold_factor
|
||||
|
||||
# Filter by both absolute and relative thresholds
|
||||
search_results = [
|
||||
result
|
||||
for result in search_results
|
||||
if (
|
||||
result['distance'] > dynamic_cutoff
|
||||
and result['distance'] > self.min_similarity_threshold
|
||||
)
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
'VectorSearchTool results: %d documents after filtering',
|
||||
len(search_results),
|
||||
)
|
||||
|
||||
# Format results
|
||||
if not search_results:
|
||||
return (
|
||||
f"No matching documents found for query: '{query}' "
|
||||
f'(min_threshold={self.min_similarity_threshold})'
|
||||
)
|
||||
|
||||
# Format as XML-like documents (matching pydantic_ai pattern)
|
||||
formatted_results = [
|
||||
f'<document {i} name={result["id"]}>\n'
|
||||
f'{result["content"]}\n'
|
||||
f'</document {i}>'
|
||||
for i, result in enumerate(search_results, start=1)
|
||||
]
|
||||
|
||||
return '\n'.join(formatted_results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error('VectorSearchTool error: %s', e, exc_info=True)
|
||||
return f'Error during vector search: {str(e)}'
|
||||
Reference in New Issue
Block a user