dev: add ADK agent with vector search tool and Google Cloud file storage implementation
This commit is contained in:
@@ -4,20 +4,19 @@ version = "0.1.0"
|
|||||||
description = "Add your description here"
|
description = "Add your description here"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Anibal Angulo", email = "a8065384@banorte.com" }
|
{ name = "Anibal Angulo", email = "a8065384@banorte.com" },
|
||||||
|
{ name = "Jorge Juarez", email = "a8080816@banorte.com" }
|
||||||
]
|
]
|
||||||
requires-python = "~=3.12.0"
|
requires-python = "~=3.12.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aiohttp>=3.13.3",
|
"aiohttp>=3.13.3",
|
||||||
"fastapi>=0.129.0",
|
|
||||||
"gcloud-aio-auth>=5.4.2",
|
"gcloud-aio-auth>=5.4.2",
|
||||||
"gcloud-aio-storage>=9.6.1",
|
"gcloud-aio-storage>=9.6.1",
|
||||||
"google-cloud-aiplatform>=1.138.0",
|
"google-adk>=1.14.1",
|
||||||
"google-cloud-storage>=3.9.0",
|
"google-cloud-aiplatform>=1.126.1",
|
||||||
"pydantic-ai-slim[google]>=1.62.0",
|
"google-cloud-storage>=2.19.0",
|
||||||
"pydantic-settings[yaml]>=2.10.1",
|
"pydantic-settings[yaml]>=2.13.1",
|
||||||
"structlog>=25.5.0",
|
"structlog>=25.5.0",
|
||||||
"uvicorn>=0.41.0",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
@@ -30,6 +29,7 @@ build-backend = "uv_build"
|
|||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
"clai>=1.62.0",
|
"clai>=1.62.0",
|
||||||
|
"marimo>=0.20.1",
|
||||||
"pytest>=8.4.1",
|
"pytest>=8.4.1",
|
||||||
"ruff>=0.12.10",
|
"ruff>=0.12.10",
|
||||||
"ty>=0.0.1a19",
|
"ty>=0.0.1a19",
|
||||||
|
|||||||
1
rag_agent/__init__.py
Normal file
1
rag_agent/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from . import agent
|
||||||
47
rag_agent/agent.py
Normal file
47
rag_agent/agent.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# Copyright 2026 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""ADK agent with vector search RAG tool."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from google.adk.agents.llm_agent import Agent
|
||||||
|
|
||||||
|
from .config_helper import settings
|
||||||
|
from .vector_search_tool import VectorSearchTool
|
||||||
|
|
||||||
|
# Create vector search tool with configuration
# similarity_top_k candidates are fetched, then filtered inside the tool by
# both an absolute floor (min_similarity_threshold) and a cutoff relative to
# the best match (relative_threshold_factor).
vector_search_tool = VectorSearchTool(
    name='conocimiento',
    description='Search the vector index for company products and services information',
    embedder=settings.embedder,
    project_id=settings.project_id,
    location=settings.location,
    bucket=settings.bucket,
    index_name=settings.index_name,
    index_endpoint=settings.index_endpoint,
    index_deployed_id=settings.index_deployed_id,
    similarity_top_k=5,
    min_similarity_threshold=0.6,
    relative_threshold_factor=0.9,
)

# Create agent with vector search tool
# `root_agent` is the module-level object re-exported via the package's
# `from . import agent`; presumably the ADK runtime discovers it by this
# name — confirm against the deployment entry point.
root_agent = Agent(
    model=settings.agent_language_model,
    name=settings.agent_name,
    description='A helpful assistant for user questions.',
    instruction=settings.agent_instructions,
    tools=[vector_search_tool],
)
|
||||||
68
rag_agent/base.py
Normal file
68
rag_agent/base.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
"""Abstract base class for vector search providers."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
class SearchResult(TypedDict):
    """A single vector search result."""

    # Identifier of the matched datapoint in the index.
    id: str
    # Match score as reported by the provider. NOTE(review): consumers in
    # this package filter with "larger is better" semantics (similarity),
    # despite the field name — confirm against the index's distance measure.
    distance: float
    # Text content associated with the matched datapoint.
    content: str
|
||||||
|
|
||||||
|
|
||||||
|
class BaseVectorSearch(ABC):
    """Interface that every vector search provider must implement.

    Concrete subclasses supply the mechanics of building an index from
    source content and executing nearest-neighbour queries against it.
    """

    @abstractmethod
    def create_index(
        self, name: str, content_path: str, **kwargs: Any  # noqa: ANN401
    ) -> None:
        """Build a brand-new index populated from the given content.

        Args:
            name: Display name to assign to the new index.
            content_path: Location of the source data for the index.
            **kwargs: Extra options specific to the concrete provider.

        """
        ...

    @abstractmethod
    def update_index(
        self, index_name: str, content_path: str, **kwargs: Any  # noqa: ANN401
    ) -> None:
        """Refresh an existing index with the data at the given path.

        Args:
            index_name: Identifier of the index to refresh.
            content_path: Location of the replacement data for the index.
            **kwargs: Extra options specific to the concrete provider.

        """
        ...

    @abstractmethod
    def run_query(
        self,
        deployed_index_id: str,
        query: list[float],
        limit: int,
    ) -> list[SearchResult]:
        """Execute a nearest-neighbour search and return the matches.

        Args:
            deployed_index_id: Identifier of the deployed index to query.
            query: Embedding vector to search with.
            limit: Cap on the number of neighbours returned.

        Returns:
            The matches as ``SearchResult`` items (id, distance, content).

        """
        ...
|
||||||
120
rag_agent/config_helper.py
Normal file
120
rag_agent/config_helper.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
# Copyright 2026 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Configuration helper for ADK agent with vector search."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import cached_property
|
||||||
|
|
||||||
|
import vertexai
|
||||||
|
from pydantic_settings import (
|
||||||
|
BaseSettings,
|
||||||
|
PydanticBaseSettingsSource,
|
||||||
|
SettingsConfigDict,
|
||||||
|
YamlConfigSettingsSource,
|
||||||
|
)
|
||||||
|
from vertexai.language_models import TextEmbeddingModel
|
||||||
|
|
||||||
|
CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class EmbeddingResult:
    """Result from embedding a query."""

    # One embedding vector per input text; VertexAIEmbedder.embed_query
    # always produces a single-element outer list.
    embeddings: list[list[float]]
|
||||||
|
|
||||||
|
|
||||||
|
class VertexAIEmbedder:
    """Embedder using Vertex AI TextEmbeddingModel."""

    def __init__(self, model_name: str, project_id: str, location: str) -> None:
        """Initialize the embedder.

        Args:
            model_name: Name of the embedding model (e.g., 'text-embedding-004')
            project_id: GCP project ID
            location: GCP location

        """
        # vertexai.init mutates process-global SDK state.
        vertexai.init(project=project_id, location=location)
        self.model = TextEmbeddingModel.from_pretrained(model_name)

    async def embed_query(self, query: str) -> EmbeddingResult:
        """Embed a single query string.

        Args:
            query: Text to embed

        Returns:
            EmbeddingResult with embeddings list

        """
        # NOTE(review): get_embeddings is a synchronous SDK call executed
        # directly inside this coroutine; it blocks the event loop for the
        # duration of the request — consider asyncio.to_thread.
        embeddings = self.model.get_embeddings([query])
        # Wrap the single result vector in a one-element outer list.
        return EmbeddingResult(embeddings=[list(embeddings[0].values)])
|
||||||
|
|
||||||
|
|
||||||
|
class AgentSettings(BaseSettings):
    """Settings for ADK agent with vector search.

    Values are loaded from environment variables and the YAML file named by
    CONFIG_YAML (default: config.yaml); all fields are required.
    """

    # Google Cloud settings
    project_id: str
    location: str
    bucket: str

    # Agent configuration
    agent_name: str
    agent_instructions: str
    agent_language_model: str
    agent_embedding_model: str

    # Vector index configuration
    index_name: str
    index_deployed_id: str
    index_endpoint: str

    model_config = SettingsConfigDict(
        yaml_file=CONFIG_FILE_PATH,
        extra="ignore",  # Ignore extra fields from config.yaml
    )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,  # noqa: ARG003
        env_settings: PydanticBaseSettingsSource,  # noqa: ARG003
        dotenv_settings: PydanticBaseSettingsSource,  # noqa: ARG003
        file_secret_settings: PydanticBaseSettingsSource,  # noqa: ARG003
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Use env vars and YAML as settings sources."""
        # Order matters: earlier sources win, so environment variables
        # override values from the YAML config file.
        return (
            env_settings,
            YamlConfigSettingsSource(settings_cls),
        )

    @cached_property
    def embedder(self) -> VertexAIEmbedder:
        """Return an embedder configured for the agent's embedding model."""
        # cached_property: the Vertex AI model client is built once on
        # first access and reused for the lifetime of this settings object.
        return VertexAIEmbedder(
            model_name=self.agent_embedding_model,
            project_id=self.project_id,
            location=self.location,
        )


# Module-level singleton; validation fails fast at import time when any
# required field is missing from the environment / YAML config.
settings = AgentSettings.model_validate({})
|
||||||
1
rag_agent/file_storage/__init__.py
Normal file
1
rag_agent/file_storage/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""File storage provider implementations."""
|
||||||
56
rag_agent/file_storage/base.py
Normal file
56
rag_agent/file_storage/base.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""Abstract base class for file storage providers."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import BinaryIO
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFileStorage(ABC):
    """Interface for remote file storage backends.

    Concrete implementations provide uploading, enumeration, and
    streamed retrieval of files held in a remote store.
    """

    @abstractmethod
    def upload_file(
        self,
        file_path: str,
        destination_blob_name: str,
        content_type: str | None = None,
    ) -> None:
        """Push a local file into the remote store.

        Args:
            file_path: Local filesystem path of the file to send.
            destination_blob_name: Name to store the file under remotely.
            content_type: MIME type to record for the file, if any.

        """
        ...

    @abstractmethod
    def list_files(self, path: str | None = None) -> list[str]:
        """Enumerate files available in the remote store.

        Args:
            path: Specific file or directory to enumerate. When None,
                every file in the store is listed recursively.

        Returns:
            The matching file paths.

        """
        ...

    @abstractmethod
    def get_file_stream(self, file_name: str) -> BinaryIO:
        """Fetch a remote file as a readable binary stream.

        Args:
            file_name: Name of the remote file to fetch.

        Returns:
            A file-like object holding the file's data.

        """
        ...
|
||||||
188
rag_agent/file_storage/google_cloud.py
Normal file
188
rag_agent/file_storage/google_cloud.py
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
"""Google Cloud Storage file storage implementation."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from typing import BinaryIO
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from gcloud.aio.storage import Storage
|
||||||
|
from google.cloud import storage
|
||||||
|
|
||||||
|
from .base import BaseFileStorage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
HTTP_TOO_MANY_REQUESTS = 429
|
||||||
|
HTTP_SERVER_ERROR = 500
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleCloudFileStorage(BaseFileStorage):
    """File storage backed by Google Cloud Storage.

    Downloads are memoized in an in-process byte cache keyed by blob name.
    NOTE(review): the cache is unbounded — many or large files grow
    resident memory without limit; confirm this is acceptable.
    """

    def __init__(self, bucket: str) -> None:  # noqa: D107
        self.bucket_name = bucket

        # Synchronous client, used for upload/list/delete and sync reads.
        self.storage_client = storage.Client()
        self.bucket_client = self.storage_client.bucket(self.bucket_name)
        # Async session/storage are created lazily on first async download.
        self._aio_session: aiohttp.ClientSession | None = None
        self._aio_storage: Storage | None = None
        # blob name -> raw bytes of the last downloaded copy.
        self._cache: dict[str, bytes] = {}

    def upload_file(
        self,
        file_path: str,
        destination_blob_name: str,
        content_type: str | None = None,
    ) -> None:
        """Upload a file to Cloud Storage.

        Args:
            file_path: The local path to the file to upload.
            destination_blob_name: Name of the blob in the bucket.
            content_type: The content type of the file.

        """
        blob = self.bucket_client.blob(destination_blob_name)
        # if_generation_match=0 is a create-only precondition: the upload
        # fails (HTTP 412) if an object with this name already exists.
        blob.upload_from_filename(
            file_path,
            content_type=content_type,
            if_generation_match=0,
        )
        # Drop any stale cached copy so subsequent reads refetch.
        self._cache.pop(destination_blob_name, None)

    def list_files(self, path: str | None = None) -> list[str]:
        """List all files at the given path in the bucket.

        If path is None, recursively lists all files.

        Args:
            path: Prefix to filter files by.

        Returns:
            A list of blob names.

        """
        # `path` is a prefix match, not an exact directory listing.
        blobs = self.storage_client.list_blobs(
            self.bucket_name, prefix=path,
        )
        return [blob.name for blob in blobs]

    def get_file_stream(self, file_name: str) -> BinaryIO:
        """Get a file as a file-like object, using cache.

        Args:
            file_name: The blob name to retrieve.

        Returns:
            A BytesIO stream with the file contents.

        """
        # Cache miss: download synchronously and memoize the raw bytes.
        if file_name not in self._cache:
            blob = self.bucket_client.blob(file_name)
            self._cache[file_name] = blob.download_as_bytes()
        # Each caller gets an independent stream over the cached bytes.
        file_stream = io.BytesIO(self._cache[file_name])
        file_stream.name = file_name  # mimic a real file object's .name
        return file_stream

    def _get_aio_session(self) -> aiohttp.ClientSession:
        """Return a pooled aiohttp session, recreating it if closed."""
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(
                limit=300, limit_per_host=50,
            )
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout, connector=connector,
            )
        return self._aio_session

    def _get_aio_storage(self) -> Storage:
        """Return the lazily-created async Storage client."""
        if self._aio_storage is None:
            self._aio_storage = Storage(
                session=self._get_aio_session(),
            )
        return self._aio_storage

    async def async_get_file_stream(
        self, file_name: str, max_retries: int = 3,
    ) -> BinaryIO:
        """Get a file asynchronously with retry on transient errors.

        Args:
            file_name: The blob name to retrieve.
            max_retries: Maximum number of retry attempts.

        Returns:
            A BytesIO stream with the file contents.

        Raises:
            TimeoutError: If all retry attempts fail.

        """
        # Serve from cache without touching the network.
        if file_name in self._cache:
            file_stream = io.BytesIO(self._cache[file_name])
            file_stream.name = file_name
            return file_stream

        storage_client = self._get_aio_storage()
        last_exception: Exception | None = None

        for attempt in range(max_retries):
            try:
                self._cache[file_name] = await storage_client.download(
                    self.bucket_name, file_name,
                )
                file_stream = io.BytesIO(self._cache[file_name])
                file_stream.name = file_name
            except TimeoutError as exc:
                # Transient: record, log, and fall through to the backoff.
                last_exception = exc
                logger.warning(
                    "Timeout downloading gs://%s/%s (attempt %d/%d)",
                    self.bucket_name,
                    file_name,
                    attempt + 1,
                    max_retries,
                )
            except aiohttp.ClientResponseError as exc:
                last_exception = exc
                # Retry only rate limiting (429) and server errors (>=500);
                # other client errors (403, 404, ...) are re-raised as-is.
                if (
                    exc.status == HTTP_TOO_MANY_REQUESTS
                    or exc.status >= HTTP_SERVER_ERROR
                ):
                    logger.warning(
                        "HTTP %d downloading gs://%s/%s "
                        "(attempt %d/%d)",
                        exc.status,
                        self.bucket_name,
                        file_name,
                        attempt + 1,
                        max_retries,
                    )
                else:
                    raise
            else:
                # Success: return the freshly cached stream.
                return file_stream

            # Exponential backoff: 0.5s, 1s, 2s, ... between attempts.
            if attempt < max_retries - 1:
                delay = 0.5 * (2**attempt)
                await asyncio.sleep(delay)

        msg = (
            f"Failed to download gs://{self.bucket_name}/{file_name} "
            f"after {max_retries} attempts"
        )
        # NOTE(review): a retryable HTTP failure also surfaces here as
        # TimeoutError (with the real cause chained) — confirm callers
        # expect that exception type for non-timeout failures.
        raise TimeoutError(msg) from last_exception

    def delete_files(self, path: str) -> None:
        """Delete all files at the given path in the bucket.

        Args:
            path: Prefix of blobs to delete.

        """
        blobs = self.storage_client.list_blobs(
            self.bucket_name, prefix=path,
        )
        for blob in blobs:
            blob.delete()
            # Keep the local cache consistent with the bucket contents.
            self._cache.pop(blob.name, None)
|
||||||
176
rag_agent/vector_search_tool.py
Normal file
176
rag_agent/vector_search_tool.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
# Copyright 2026 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""A retrieval tool that uses Vertex AI Vector Search (not RAG Engine)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from google.adk.tools.tool_context import ToolContext
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from .vertex_ai import GoogleCloudVectorSearch
|
||||||
|
|
||||||
|
from google.adk.tools.retrieval.base_retrieval_tool import BaseRetrievalTool
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .config_helper import VertexAIEmbedder
|
||||||
|
|
||||||
|
logger = logging.getLogger('google_adk.' + __name__)
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSearchTool(BaseRetrievalTool):
    """A retrieval tool using Vertex AI Vector Search (not RAG Engine).

    This tool uses GoogleCloudVectorSearch to query a vector index directly,
    which is useful when Vertex AI RAG Engine is not available in your GCP project.
    """

    def __init__(
        self,
        *,
        name: str,
        description: str,
        embedder: VertexAIEmbedder,
        project_id: str,
        location: str,
        bucket: str,
        index_name: str,
        index_endpoint: str,
        index_deployed_id: str,
        similarity_top_k: int = 5,
        min_similarity_threshold: float = 0.6,
        relative_threshold_factor: float = 0.9,
    ):
        """Initialize the VectorSearchTool.

        Args:
            name: Tool name for function declaration
            description: Tool description for LLM
            embedder: Embedder instance for query embedding
            project_id: GCP project ID
            location: GCP location (e.g., 'us-central1')
            bucket: GCS bucket for content storage
            index_name: Vector search index name
            index_endpoint: Resource name of index endpoint
            index_deployed_id: Deployed index ID
            similarity_top_k: Number of results to retrieve (default: 5)
            min_similarity_threshold: Minimum similarity score 0.0-1.0 (default: 0.6)
            relative_threshold_factor: Factor of max similarity for dynamic filtering (default: 0.9)

        """
        super().__init__(name=name, description=description)

        self.embedder = embedder
        self.index_endpoint = index_endpoint
        self.index_deployed_id = index_deployed_id
        self.similarity_top_k = similarity_top_k
        self.min_similarity_threshold = min_similarity_threshold
        self.relative_threshold_factor = relative_threshold_factor

        # Initialize vector search (endpoint loaded lazily on first use)
        self.vector_search = GoogleCloudVectorSearch(
            project_id=project_id,
            location=location,
            bucket=bucket,
            index_name=index_name,
        )
        self._endpoint_loaded = False

        logger.info(
            'VectorSearchTool initialized with index=%s, deployed_id=%s',
            index_name,
            index_deployed_id,
        )

    @override
    async def run_async(
        self,
        *,
        args: dict[str, Any],
        tool_context: ToolContext,
    ) -> Any:
        """Execute vector search with the user's query.

        Args:
            args: Dictionary containing 'query' key
            tool_context: Tool execution context

        Returns:
            Formatted search results as XML-like documents or error message

        """
        # NOTE(review): this lookup is outside the try block below, so a
        # missing 'query' key raises KeyError to the caller instead of
        # being converted to an error string — confirm that is intended.
        query = args['query']
        logger.debug('VectorSearchTool query: %s', query)

        try:
            # Load index endpoint on first use (lazy loading)
            if not self._endpoint_loaded:
                self.vector_search.load_index_endpoint(self.index_endpoint)
                self._endpoint_loaded = True
                logger.info('Index endpoint loaded successfully')

            # Embed the query using the configured embedder
            embedding_result = await self.embedder.embed_query(query)
            query_embedding = list(embedding_result.embeddings[0])

            # Run vector search
            search_results = await self.vector_search.async_run_query(
                deployed_index_id=self.index_deployed_id,
                query=query_embedding,
                limit=self.similarity_top_k,
            )

            # Apply similarity filtering (dual threshold approach)
            if search_results:
                # Dynamic threshold based on max similarity
                # NOTE(review): 'distance' is treated as a similarity score
                # (higher is better) — confirm against the index's distance
                # measure (e.g. DOT_PRODUCT_DISTANCE).
                max_similarity = max(r['distance'] for r in search_results)
                dynamic_cutoff = max_similarity * self.relative_threshold_factor

                # Filter by both absolute and relative thresholds
                search_results = [
                    result
                    for result in search_results
                    if (
                        result['distance'] > dynamic_cutoff
                        and result['distance'] > self.min_similarity_threshold
                    )
                ]

            logger.debug(
                'VectorSearchTool results: %d documents after filtering',
                len(search_results),
            )

            # Format results
            if not search_results:
                return (
                    f"No matching documents found for query: '{query}' "
                    f'(min_threshold={self.min_similarity_threshold})'
                )

            # Format as XML-like documents (matching pydantic_ai pattern)
            formatted_results = [
                f'<document {i} name={result["id"]}>\n'
                f'{result["content"]}\n'
                f'</document {i}>'
                for i, result in enumerate(search_results, start=1)
            ]

            return '\n'.join(formatted_results)

        except Exception as e:
            # Broad catch by design: tool failures are reported back to the
            # LLM as text rather than crashing the agent turn.
            logger.error('VectorSearchTool error: %s', e, exc_info=True)
            return f'Error during vector search: {str(e)}'
|
||||||
310
rag_agent/vertex_ai.py
Normal file
310
rag_agent/vertex_ai.py
Normal file
@@ -0,0 +1,310 @@
|
|||||||
|
"""Google Cloud Vertex AI Vector Search implementation."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import google.auth
|
||||||
|
import google.auth.credentials
|
||||||
|
import google.auth.transport.requests
|
||||||
|
from gcloud.aio.auth import Token
|
||||||
|
from google.cloud import aiplatform
|
||||||
|
|
||||||
|
from .file_storage.google_cloud import GoogleCloudFileStorage
|
||||||
|
from .base import BaseVectorSearch, SearchResult
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleCloudVectorSearch(BaseVectorSearch):
|
||||||
|
"""A vector search provider using Vertex AI Vector Search."""
|
||||||
|
|
||||||
|
def __init__(
    self,
    project_id: str,
    location: str,
    bucket: str,
    index_name: str | None = None,
) -> None:
    """Initialize the GoogleCloudVectorSearch client.

    Args:
        project_id: The Google Cloud project ID.
        location: The Google Cloud location (e.g., 'us-central1').
        bucket: The GCS bucket to use for file storage.
        index_name: The name of the index.

    """
    # aiplatform.init mutates process-global SDK state.
    aiplatform.init(project=project_id, location=location)
    self.project_id = project_id
    self.location = location
    self.storage = GoogleCloudFileStorage(bucket=bucket)
    self.index_name = index_name
    # Lazily-populated sync credentials and async session/token.
    self._credentials: google.auth.credentials.Credentials | None = None
    self._aio_session: aiohttp.ClientSession | None = None
    self._async_token: Token | None = None
    # NOTE(review): self.index / self.index_endpoint are never set here —
    # only by create_index/update_index/deploy_index/load_index_endpoint.
    # Methods reading them raise AttributeError if called out of order.
|
||||||
|
|
||||||
|
def _get_auth_headers(self) -> dict[str, str]:
    """Return sync Authorization headers from Application Default Credentials."""
    if self._credentials is None:
        self._credentials, _ = google.auth.default(
            scopes=["https://www.googleapis.com/auth/cloud-platform"],
        )
    # Refresh when there is no token yet or the cached one has expired.
    if not self._credentials.token or self._credentials.expired:
        self._credentials.refresh(
            google.auth.transport.requests.Request(),
        )
    return {
        "Authorization": f"Bearer {self._credentials.token}",
        "Content-Type": "application/json",
    }
|
||||||
|
|
||||||
|
async def _async_get_auth_headers(self) -> dict[str, str]:
    """Return async Authorization headers using a cached gcloud-aio Token."""
    if self._async_token is None:
        self._async_token = Token(
            session=self._get_aio_session(),
            scopes=[
                "https://www.googleapis.com/auth/cloud-platform",
            ],
        )
    # Token.get() handles refreshing the access token when it expires.
    access_token = await self._async_token.get()
    return {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/json",
    }
|
||||||
|
|
||||||
|
def _get_aio_session(self) -> aiohttp.ClientSession:
    """Return a pooled aiohttp session, recreating it if closed."""
    if self._aio_session is None or self._aio_session.closed:
        connector = aiohttp.TCPConnector(
            limit=300, limit_per_host=50,
        )
        timeout = aiohttp.ClientTimeout(total=60)
        self._aio_session = aiohttp.ClientSession(
            timeout=timeout, connector=connector,
        )
    return self._aio_session
|
||||||
|
|
||||||
|
def create_index(
    self,
    name: str,
    content_path: str,
    *,
    dimensions: int = 3072,
    approximate_neighbors_count: int = 150,
    distance_measure_type: str = "DOT_PRODUCT_DISTANCE",
    **kwargs: Any,  # noqa: ANN401, ARG002
) -> None:
    """Create a new Vertex AI Vector Search index.

    Args:
        name: The display name for the new index.
        content_path: GCS URI to the embeddings JSON file.
        dimensions: Number of dimensions in embedding vectors.
        approximate_neighbors_count: Neighbors to find per vector.
        distance_measure_type: The distance measure to use.
        **kwargs: Additional arguments.

    """
    # Tree-AH index; the leaf parameters below are fixed here and not
    # exposed to callers.
    index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
        display_name=name,
        contents_delta_uri=content_path,
        dimensions=dimensions,
        approximate_neighbors_count=approximate_neighbors_count,
        distance_measure_type=distance_measure_type,  # type: ignore[arg-type]
        leaf_node_embedding_count=1000,
        leaf_nodes_to_search_percent=10,
    )
    # Remember the created index so a later deploy_index can deploy it.
    self.index = index
|
||||||
|
|
||||||
|
def update_index(
    self, index_name: str, content_path: str, **kwargs: Any,  # noqa: ANN401, ARG002
) -> None:
    """Update an existing Vertex AI Vector Search index.

    Args:
        index_name: The resource name of the index to update.
        content_path: GCS URI to the new embeddings JSON file.
        **kwargs: Additional arguments.

    """
    index = aiplatform.MatchingEngineIndex(index_name=index_name)
    index.update_embeddings(
        contents_delta_uri=content_path,
    )
    # Track the updated index for a subsequent deploy_index call.
    self.index = index
|
||||||
|
|
||||||
|
def deploy_index(
    self,
    index_name: str,
    machine_type: str = "e2-standard-2",
) -> None:
    """Deploy a Vertex AI Vector Search index to a new public endpoint.

    Creates a public endpoint named ``{index_name}-endpoint``, deploys
    ``self.index`` onto it under a unique deployed-index id, and stores
    the endpoint on ``self.index_endpoint``.

    Args:
        index_name: The name of the index to deploy.
        machine_type: The machine type for the endpoint.

    Raises:
        ValueError: If no index has been created or loaded on this
            instance (``self.index`` is missing or ``None``).

    """
    # Guard: without this, a missing index surfaces as an opaque
    # AttributeError / SDK error deep inside deploy_index().
    if getattr(self, "index", None) is None:
        msg = (
            "No index is loaded. Call create_index() or update_index() "
            "before deploying."
        )
        raise ValueError(msg)
    index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
        display_name=f"{index_name}-endpoint",
        public_endpoint_enabled=True,
    )
    index_endpoint.deploy_index(
        index=self.index,
        # Deployed-index ids must be unique; hex uuid avoids collisions
        # when the same index is redeployed.
        deployed_index_id=(
            f"{index_name.replace('-', '_')}_deployed_{uuid4().hex}"
        ),
        machine_type=machine_type,
    )
    self.index_endpoint = index_endpoint
|
||||||
|
|
||||||
|
def load_index_endpoint(self, endpoint_name: str) -> None:
    """Load an existing Vertex AI Vector Search index endpoint.

    The endpoint handle is stored on ``self.index_endpoint``. Endpoints
    without a public domain name are rejected, since the async query
    path relies on the public REST endpoint.

    Args:
        endpoint_name: The resource name of the index endpoint.

    Raises:
        ValueError: If the endpoint has no public domain name.

    """
    endpoint = aiplatform.MatchingEngineIndexEndpoint(endpoint_name)
    self.index_endpoint = endpoint
    if endpoint.public_endpoint_domain_name:
        return
    raise ValueError(
        "The index endpoint does not have a public endpoint. "
        "Ensure the endpoint is configured for public access.",
    )
|
||||||
|
|
||||||
|
def run_query(
    self,
    deployed_index_id: str,
    query: list[float],
    limit: int,
) -> list[SearchResult]:
    """Run a similarity search query against the deployed index.

    For each neighbor returned by the index, the corresponding markdown
    document is fetched from storage at
    ``{index_name}/contents/{id}.md`` and decoded as UTF-8.

    Args:
        deployed_index_id: The ID of the deployed index.
        query: The embedding vector for the search query.
        limit: Maximum number of nearest neighbors to return.

    Returns:
        A list of matched items with id, distance, and content.

    """
    neighbors = self.index_endpoint.find_neighbors(
        deployed_index_id=deployed_index_id,
        queries=[query],
        num_neighbors=limit,
    )[0]
    matches: list[SearchResult] = []
    for hit in neighbors:
        document_path = f"{self.index_name}/contents/{hit.id}.md"
        raw = self.storage.get_file_stream(document_path).read()
        matches.append(
            SearchResult(
                id=hit.id,
                # A missing distance is reported as 0.0.
                distance=float(hit.distance or 0),
                content=raw.decode("utf-8"),
            ),
        )
    return matches
|
||||||
|
|
||||||
|
async def async_run_query(
    self,
    deployed_index_id: str,
    query: Sequence[float],
    limit: int,
) -> list[SearchResult]:
    """Run an async similarity search via the public REST API.

    Calls the endpoint's ``:findNeighbors`` REST method directly, then
    fetches all matched documents from storage concurrently.

    Args:
        deployed_index_id: The ID of the deployed index.
        query: The embedding vector for the search query.
        limit: Maximum number of nearest neighbors to return.

    Returns:
        A list of matched items with id, distance, and content.

    """
    domain = self.index_endpoint.public_endpoint_domain_name
    endpoint_id = self.index_endpoint.name.split("/")[-1]
    url = (
        f"https://{domain}/v1/projects/{self.project_id}"
        f"/locations/{self.location}"
        f"/indexEndpoints/{endpoint_id}:findNeighbors"
    )
    payload = {
        "deployed_index_id": deployed_index_id,
        "queries": [
            {
                "datapoint": {"feature_vector": list(query)},
                "neighbor_count": limit,
            },
        ],
    }

    headers = await self._async_get_auth_headers()
    session = self._get_aio_session()
    async with session.post(
        url, json=payload, headers=headers,
    ) as response:
        response.raise_for_status()
        data = await response.json()

    neighbors = (
        data.get("nearestNeighbors", [{}])[0].get("neighbors", [])
    )
    # Fetch every matched document concurrently; gather preserves input
    # order, so streams line up with neighbors below.
    file_streams = await asyncio.gather(
        *(
            self.storage.async_get_file_stream(
                f"{self.index_name}/contents/"
                f"{neighbor['datapoint']['datapointId']}.md",
            )
            for neighbor in neighbors
        ),
    )
    return [
        SearchResult(
            id=neighbor["datapoint"]["datapointId"],
            distance=neighbor["distance"],
            content=stream.read().decode("utf-8"),
        )
        for neighbor, stream in zip(neighbors, file_streams, strict=True)
    ]
|
||||||
|
|
||||||
|
def delete_index(self, index_name: str) -> None:
    """Delete a Vertex AI Vector Search index.

    Args:
        index_name: The resource name of the index.

    """
    aiplatform.MatchingEngineIndex(index_name).delete()
|
||||||
|
|
||||||
|
def delete_index_endpoint(
    self, index_endpoint_name: str,
) -> None:
    """Delete a Vertex AI Vector Search index endpoint.

    Args:
        index_endpoint_name: The resource name of the endpoint.

    """
    endpoint = aiplatform.MatchingEngineIndexEndpoint(index_endpoint_name)
    # Undeploy every index first; delete(force=True) then removes the
    # endpoint itself.
    endpoint.undeploy_all()
    endpoint.delete(force=True)
|
||||||
Reference in New Issue
Block a user