Add ADK Agent implementation #1
@@ -4,20 +4,19 @@ version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
authors = [
    { name = "Anibal Angulo", email = "a8065384@banorte.com" },
    { name = "Jorge Juarez", email = "a8080816@banorte.com" }
]
|
||||
requires-python = "~=3.12.0"
|
||||
dependencies = [
    "aiohttp>=3.13.3",
    "fastapi>=0.129.0",
    "gcloud-aio-auth>=5.4.2",
    "gcloud-aio-storage>=9.6.1",
    "google-adk>=1.14.1",
    "google-cloud-aiplatform>=1.138.0",
    "google-cloud-storage>=3.9.0",
    "pydantic-ai-slim[google]>=1.62.0",
    "pydantic-settings[yaml]>=2.13.1",
    "structlog>=25.5.0",
    "uvicorn>=0.41.0",
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -30,6 +29,7 @@ build-backend = "uv_build"
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"clai>=1.62.0",
|
||||
"marimo>=0.20.1",
|
||||
"pytest>=8.4.1",
|
||||
"ruff>=0.12.10",
|
||||
"ty>=0.0.1a19",
|
||||
|
||||
1
rag_agent/__init__.py
Normal file
1
rag_agent/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from . import agent
|
||||
47
rag_agent/agent.py
Normal file
47
rag_agent/agent.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""ADK agent with vector search RAG tool."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from google.adk.agents.llm_agent import Agent
|
||||
|
||||
from .config_helper import settings
|
||||
from .vector_search_tool import VectorSearchTool
|
||||
|
||||
# Module-level wiring: the tool and agent are built once at import time so the
# ADK runtime can discover `root_agent` from this package.
# NOTE(review): constructing these here triggers settings loading (env vars +
# config.yaml) and Vertex AI client setup on import — confirm that is
# acceptable for every entry point that imports this module.
vector_search_tool = VectorSearchTool(
    name='conocimiento',
    description='Search the vector index for company products and services information',
    embedder=settings.embedder,
    project_id=settings.project_id,
    location=settings.location,
    bucket=settings.bucket,
    index_name=settings.index_name,
    index_endpoint=settings.index_endpoint,
    index_deployed_id=settings.index_deployed_id,
    # Retrieval tuning: fetch top-5 neighbors, keep only those above an
    # absolute similarity floor of 0.6 AND above 90% of the best match.
    similarity_top_k=5,
    min_similarity_threshold=0.6,
    relative_threshold_factor=0.9,
)

# `root_agent` is the conventional name the ADK loader looks for in an
# agent package; all other configuration comes from settings.
root_agent = Agent(
    model=settings.agent_language_model,
    name=settings.agent_name,
    description='A helpful assistant for user questions.',
    instruction=settings.agent_instructions,
    tools=[vector_search_tool],
)
|
||||
68
rag_agent/base.py
Normal file
68
rag_agent/base.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Abstract base class for vector search providers."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, TypedDict
|
||||
|
||||
|
||||
class SearchResult(TypedDict):
    """A single vector search result."""

    # Datapoint ID of the matched document in the vector index.
    id: str
    # Score reported by the search backend for this match.
    distance: float
    # Full text content of the matched document.
    content: str


class BaseVectorSearch(ABC):
    """Interface for vector search backends.

    A concrete implementation must know how to build an index from a
    content location, refresh it with new content, and answer
    nearest-neighbor queries against a deployed index.
    """

    @abstractmethod
    def create_index(
        self, name: str, content_path: str, **kwargs: Any  # noqa: ANN401
    ) -> None:
        """Build a brand-new index from the data at *content_path*.

        Args:
            name: Display name to give the new index.
            content_path: Location of the source data for the index.
            **kwargs: Backend-specific options.

        """
        ...

    @abstractmethod
    def update_index(
        self, index_name: str, content_path: str, **kwargs: Any  # noqa: ANN401
    ) -> None:
        """Refresh the index *index_name* with the data at *content_path*.

        Args:
            index_name: Identifier of the index to refresh.
            content_path: Location of the replacement data.
            **kwargs: Backend-specific options.

        """
        ...

    @abstractmethod
    def run_query(
        self,
        deployed_index_id: str,
        query: list[float],
        limit: int,
    ) -> list[SearchResult]:
        """Return up to *limit* nearest neighbors for *query*.

        Args:
            deployed_index_id: Identifier of the deployed index to query.
            query: Embedding vector to search with.
            limit: Cap on the number of neighbors returned.

        Returns:
            Matches as ``SearchResult`` dicts (id, distance, content).

        """
        ...
|
||||
120
rag_agent/config_helper.py
Normal file
120
rag_agent/config_helper.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration helper for ADK agent with vector search."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
|
||||
import vertexai
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
PydanticBaseSettingsSource,
|
||||
SettingsConfigDict,
|
||||
YamlConfigSettingsSource,
|
||||
)
|
||||
from vertexai.language_models import TextEmbeddingModel
|
||||
|
||||
CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")
|
||||
|
||||
|
||||
@dataclass
class EmbeddingResult:
    """Result from embedding a query.

    ``embeddings`` holds one vector per input text; a single-query call
    yields a one-element outer list, read by callers as
    ``result.embeddings[0]``.
    """

    # One embedding vector (list of floats) per embedded text.
    embeddings: list[list[float]]
|
||||
|
||||
|
||||
class VertexAIEmbedder:
    """Embedder using Vertex AI TextEmbeddingModel."""

    def __init__(self, model_name: str, project_id: str, location: str) -> None:
        """Initialize the embedder.

        Args:
            model_name: Name of the embedding model (e.g., 'text-embedding-004')
            project_id: GCP project ID
            location: GCP location

        """
        # vertexai.init mutates global SDK state for this process.
        vertexai.init(project=project_id, location=location)
        self.model = TextEmbeddingModel.from_pretrained(model_name)

    async def embed_query(self, query: str) -> EmbeddingResult:
        """Embed a single query string.

        The Vertex AI SDK call is synchronous; it is dispatched to a worker
        thread so a slow embedding request does not block the event loop
        (this method is awaited from the agent's async tool path).

        Args:
            query: Text to embed

        Returns:
            EmbeddingResult with embeddings list

        """
        import asyncio  # local import keeps the module's import surface unchanged

        embeddings = await asyncio.to_thread(self.model.get_embeddings, [query])
        return EmbeddingResult(embeddings=[list(embeddings[0].values)])
|
||||
|
||||
|
||||
class AgentSettings(BaseSettings):
    """Settings for ADK agent with vector search.

    Values are resolved from environment variables first, then from the
    YAML file named by the ``CONFIG_YAML`` env var (default
    ``config.yaml``); see ``settings_customise_sources``. Every field is
    required, so validation fails fast if a value is missing from both
    sources.
    """

    # Google Cloud settings
    project_id: str
    location: str
    bucket: str

    # Agent configuration
    agent_name: str
    agent_instructions: str
    agent_language_model: str
    agent_embedding_model: str

    # Vector index configuration
    index_name: str
    index_deployed_id: str
    index_endpoint: str

    model_config = SettingsConfigDict(
        yaml_file=CONFIG_FILE_PATH,
        extra="ignore",  # Ignore extra fields from config.yaml
    )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,  # noqa: ARG003
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,  # noqa: ARG003
        file_secret_settings: PydanticBaseSettingsSource,  # noqa: ARG003
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Use env vars and YAML as settings sources.

        Order matters: env vars are listed first, so they take precedence
        over YAML values. Init kwargs, .env files, and secret files are
        deliberately excluded.
        """
        return (
            env_settings,
            YamlConfigSettingsSource(settings_cls),
        )

    @cached_property
    def embedder(self) -> VertexAIEmbedder:
        """Return an embedder configured for the agent's embedding model.

        Cached so the Vertex AI model handle is created at most once per
        settings instance.
        """
        return VertexAIEmbedder(
            model_name=self.agent_embedding_model,
            project_id=self.project_id,
            location=self.location,
        )
|
# Eagerly-built module singleton: validation runs at import time, so missing
# required settings fail fast. The empty dict passes no init overrides —
# all values come from env vars / config.yaml (see settings_customise_sources).
settings = AgentSettings.model_validate({})
||||
1
rag_agent/file_storage/__init__.py
Normal file
1
rag_agent/file_storage/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""File storage provider implementations."""
|
||||
56
rag_agent/file_storage/base.py
Normal file
56
rag_agent/file_storage/base.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Abstract base class for file storage providers."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import BinaryIO
|
||||
|
||||
|
||||
class BaseFileStorage(ABC):
|
||||
"""Abstract base class for a remote file processor.
|
||||
|
||||
Defines the interface for listing and processing files from
|
||||
a remote source.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def upload_file(
|
||||
self,
|
||||
file_path: str,
|
||||
destination_blob_name: str,
|
||||
content_type: str | None = None,
|
||||
) -> None:
|
||||
"""Upload a file to the remote source.
|
||||
|
||||
Args:
|
||||
file_path: The local path to the file to upload.
|
||||
destination_blob_name: Name of the file in remote storage.
|
||||
content_type: The content type of the file.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def list_files(self, path: str | None = None) -> list[str]:
|
||||
"""List files from a remote location.
|
||||
|
||||
Args:
|
||||
path: Path to a specific file or directory. If None,
|
||||
recursively lists all files in the bucket.
|
||||
|
||||
Returns:
|
||||
A list of file paths.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_file_stream(self, file_name: str) -> BinaryIO:
|
||||
"""Get a file from the remote source as a file-like object.
|
||||
|
||||
Args:
|
||||
file_name: The name of the file to retrieve.
|
||||
|
||||
Returns:
|
||||
A file-like object containing the file data.
|
||||
|
||||
"""
|
||||
...
|
||||
188
rag_agent/file_storage/google_cloud.py
Normal file
188
rag_agent/file_storage/google_cloud.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""Google Cloud Storage file storage implementation."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
from typing import BinaryIO
|
||||
|
||||
import aiohttp
|
||||
from gcloud.aio.storage import Storage
|
||||
from google.cloud import storage
|
||||
|
||||
from .base import BaseFileStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HTTP_TOO_MANY_REQUESTS = 429
|
||||
HTTP_SERVER_ERROR = 500
|
||||
|
||||
|
||||
class GoogleCloudFileStorage(BaseFileStorage):
    """File storage backed by Google Cloud Storage.

    Provides both sync (google-cloud-storage) and async (gcloud-aio-storage)
    download paths, sharing a simple in-memory byte cache keyed by blob name.
    """

    def __init__(self, bucket: str) -> None:  # noqa: D107
        self.bucket_name = bucket

        # Sync client used for uploads, listing, deletion, and sync reads.
        self.storage_client = storage.Client()
        self.bucket_client = self.storage_client.bucket(self.bucket_name)
        # Async session/client are created lazily on first async read.
        self._aio_session: aiohttp.ClientSession | None = None
        self._aio_storage: Storage | None = None
        # Whole-blob byte cache shared by sync and async reads.
        # NOTE(review): unbounded — grows with every distinct blob ever
        # fetched; consider a size cap if many large blobs are read.
        self._cache: dict[str, bytes] = {}

    def upload_file(
        self,
        file_path: str,
        destination_blob_name: str,
        content_type: str | None = None,
    ) -> None:
        """Upload a file to Cloud Storage.

        Args:
            file_path: The local path to the file to upload.
            destination_blob_name: Name of the blob in the bucket.
            content_type: The content type of the file.

        """
        blob = self.bucket_client.blob(destination_blob_name)
        # if_generation_match=0 is a GCS precondition: the upload succeeds
        # only if the object does not already exist (no silent overwrite);
        # an existing object makes this raise PreconditionFailed.
        blob.upload_from_filename(
            file_path,
            content_type=content_type,
            if_generation_match=0,
        )
        # Drop any stale cached copy of the blob just written.
        self._cache.pop(destination_blob_name, None)

    def list_files(self, path: str | None = None) -> list[str]:
        """List all files at the given path in the bucket.

        If path is None, recursively lists all files.

        Args:
            path: Prefix to filter files by.

        Returns:
            A list of blob names.

        """
        blobs = self.storage_client.list_blobs(
            self.bucket_name, prefix=path,
        )
        return [blob.name for blob in blobs]

    def get_file_stream(self, file_name: str) -> BinaryIO:
        """Get a file as a file-like object, using cache.

        Args:
            file_name: The blob name to retrieve.

        Returns:
            A BytesIO stream with the file contents.

        """
        # Download once, then serve subsequent reads from the byte cache;
        # each call still gets an independent stream positioned at 0.
        if file_name not in self._cache:
            blob = self.bucket_client.blob(file_name)
            self._cache[file_name] = blob.download_as_bytes()
        file_stream = io.BytesIO(self._cache[file_name])
        file_stream.name = file_name
        return file_stream

    def _get_aio_session(self) -> aiohttp.ClientSession:
        # Lazily (re)create the shared aiohttp session; `closed` check lets
        # the storage recover after the session was shut down elsewhere.
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(
                limit=300, limit_per_host=50,
            )
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout, connector=connector,
            )
        return self._aio_session

    def _get_aio_storage(self) -> Storage:
        # Lazily create the async GCS client bound to the shared session.
        # NOTE(review): not re-created if its session was closed — only the
        # session getter checks `closed`; confirm gcloud-aio handles this.
        if self._aio_storage is None:
            self._aio_storage = Storage(
                session=self._get_aio_session(),
            )
        return self._aio_storage

    async def async_get_file_stream(
        self, file_name: str, max_retries: int = 3,
    ) -> BinaryIO:
        """Get a file asynchronously with retry on transient errors.

        Retries only on timeouts, HTTP 429, and HTTP 5xx, with exponential
        backoff (0.5s, 1s, 2s, ...); other HTTP errors are re-raised
        immediately.

        Args:
            file_name: The blob name to retrieve.
            max_retries: Maximum number of retry attempts.

        Returns:
            A BytesIO stream with the file contents.

        Raises:
            TimeoutError: If all retry attempts fail.

        """
        # Fast path: serve from the shared byte cache.
        if file_name in self._cache:
            file_stream = io.BytesIO(self._cache[file_name])
            file_stream.name = file_name
            return file_stream

        storage_client = self._get_aio_storage()
        last_exception: Exception | None = None

        for attempt in range(max_retries):
            try:
                self._cache[file_name] = await storage_client.download(
                    self.bucket_name, file_name,
                )
                file_stream = io.BytesIO(self._cache[file_name])
                file_stream.name = file_name
            except TimeoutError as exc:
                last_exception = exc
                logger.warning(
                    "Timeout downloading gs://%s/%s (attempt %d/%d)",
                    self.bucket_name,
                    file_name,
                    attempt + 1,
                    max_retries,
                )
            except aiohttp.ClientResponseError as exc:
                last_exception = exc
                # Retry only rate limiting (429) and server errors (>=500);
                # anything else (e.g. 404, 403) is a hard failure.
                if (
                    exc.status == HTTP_TOO_MANY_REQUESTS
                    or exc.status >= HTTP_SERVER_ERROR
                ):
                    logger.warning(
                        "HTTP %d downloading gs://%s/%s "
                        "(attempt %d/%d)",
                        exc.status,
                        self.bucket_name,
                        file_name,
                        attempt + 1,
                        max_retries,
                    )
                else:
                    raise
            else:
                # Success path: return the freshly cached stream.
                return file_stream

            # Exponential backoff before the next attempt (skip after last).
            if attempt < max_retries - 1:
                delay = 0.5 * (2**attempt)
                await asyncio.sleep(delay)

        msg = (
            f"Failed to download gs://{self.bucket_name}/{file_name} "
            f"after {max_retries} attempts"
        )
        # TimeoutError is raised even when the last failure was an HTTP
        # error; the original cause is chained for diagnosis.
        raise TimeoutError(msg) from last_exception

    def delete_files(self, path: str) -> None:
        """Delete all files at the given path in the bucket.

        Args:
            path: Prefix of blobs to delete.

        """
        blobs = self.storage_client.list_blobs(
            self.bucket_name, prefix=path,
        )
        for blob in blobs:
            blob.delete()
            # Keep the cache consistent with the remote store.
            self._cache.pop(blob.name, None)
|
||||
176
rag_agent/vector_search_tool.py
Normal file
176
rag_agent/vector_search_tool.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A retrieval tool that uses Vertex AI Vector Search (not RAG Engine)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from typing_extensions import override
|
||||
|
||||
from .vertex_ai import GoogleCloudVectorSearch
|
||||
|
||||
from google.adk.tools.retrieval.base_retrieval_tool import BaseRetrievalTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .config_helper import VertexAIEmbedder
|
||||
|
||||
logger = logging.getLogger('google_adk.' + __name__)
|
||||
|
||||
|
||||
class VectorSearchTool(BaseRetrievalTool):
    """A retrieval tool using Vertex AI Vector Search (not RAG Engine).

    This tool uses GoogleCloudVectorSearch to query a vector index directly,
    which is useful when Vertex AI RAG Engine is not available in your GCP project.
    """

    def __init__(
        self,
        *,
        name: str,
        description: str,
        embedder: VertexAIEmbedder,
        project_id: str,
        location: str,
        bucket: str,
        index_name: str,
        index_endpoint: str,
        index_deployed_id: str,
        similarity_top_k: int = 5,
        min_similarity_threshold: float = 0.6,
        relative_threshold_factor: float = 0.9,
    ):
        """Initialize the VectorSearchTool.

        Args:
            name: Tool name for function declaration
            description: Tool description for LLM
            embedder: Embedder instance for query embedding
            project_id: GCP project ID
            location: GCP location (e.g., 'us-central1')
            bucket: GCS bucket for content storage
            index_name: Vector search index name
            index_endpoint: Resource name of index endpoint
            index_deployed_id: Deployed index ID
            similarity_top_k: Number of results to retrieve (default: 5)
            min_similarity_threshold: Minimum similarity score 0.0-1.0 (default: 0.6)
            relative_threshold_factor: Factor of max similarity for dynamic filtering (default: 0.9)
        """
        super().__init__(name=name, description=description)

        self.embedder = embedder
        self.index_endpoint = index_endpoint
        self.index_deployed_id = index_deployed_id
        self.similarity_top_k = similarity_top_k
        self.min_similarity_threshold = min_similarity_threshold
        self.relative_threshold_factor = relative_threshold_factor

        # Initialize vector search (endpoint loaded lazily on first use)
        self.vector_search = GoogleCloudVectorSearch(
            project_id=project_id,
            location=location,
            bucket=bucket,
            index_name=index_name,
        )
        self._endpoint_loaded = False

        logger.info(
            'VectorSearchTool initialized with index=%s, deployed_id=%s',
            index_name,
            index_deployed_id,
        )

    @override
    async def run_async(
        self,
        *,
        args: dict[str, Any],
        tool_context: ToolContext,
    ) -> Any:
        """Execute vector search with the user's query.

        Args:
            args: Dictionary containing 'query' key
            tool_context: Tool execution context

        Returns:
            Formatted search results as XML-like documents or error message
        """
        # NOTE(review): this lookup is outside the try block below, so a
        # missing 'query' key raises KeyError to the caller instead of
        # returning an error string — confirm that is intended.
        query = args['query']
        logger.debug('VectorSearchTool query: %s', query)

        try:
            # Load index endpoint on first use (lazy loading)
            if not self._endpoint_loaded:
                self.vector_search.load_index_endpoint(self.index_endpoint)
                self._endpoint_loaded = True
                logger.info('Index endpoint loaded successfully')

            # Embed the query using the configured embedder
            embedding_result = await self.embedder.embed_query(query)
            query_embedding = list(embedding_result.embeddings[0])

            # Run vector search
            search_results = await self.vector_search.async_run_query(
                deployed_index_id=self.index_deployed_id,
                query=query_embedding,
                limit=self.similarity_top_k,
            )

            # Apply similarity filtering (dual threshold approach):
            # results must beat BOTH the absolute floor and a dynamic
            # cutoff relative to the best match. This assumes 'distance'
            # is similarity-like (larger = better) — TODO confirm the
            # index's distance measure matches that assumption.
            if search_results:
                # Dynamic threshold based on max similarity
                max_similarity = max(r['distance'] for r in search_results)
                dynamic_cutoff = max_similarity * self.relative_threshold_factor

                # Filter by both absolute and relative thresholds.
                # Note the strict '>' — the top result itself passes only
                # because relative_threshold_factor < 1.
                search_results = [
                    result
                    for result in search_results
                    if (
                        result['distance'] > dynamic_cutoff
                        and result['distance'] > self.min_similarity_threshold
                    )
                ]

            logger.debug(
                'VectorSearchTool results: %d documents after filtering',
                len(search_results),
            )

            # Format results
            if not search_results:
                return (
                    f"No matching documents found for query: '{query}' "
                    f'(min_threshold={self.min_similarity_threshold})'
                )

            # Format as XML-like documents (matching pydantic_ai pattern)
            formatted_results = [
                f'<document {i} name={result["id"]}>\n'
                f'{result["content"]}\n'
                f'</document {i}>'
                for i, result in enumerate(search_results, start=1)
            ]

            return '\n'.join(formatted_results)

        except Exception as e:
            # Broad catch is deliberate: tool failures are surfaced to the
            # LLM as text rather than crashing the agent turn.
            logger.error('VectorSearchTool error: %s', e, exc_info=True)
            return f'Error during vector search: {str(e)}'
|
||||
310
rag_agent/vertex_ai.py
Normal file
310
rag_agent/vertex_ai.py
Normal file
@@ -0,0 +1,310 @@
|
||||
"""Google Cloud Vertex AI Vector Search implementation."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import aiohttp
|
||||
import google.auth
|
||||
import google.auth.credentials
|
||||
import google.auth.transport.requests
|
||||
from gcloud.aio.auth import Token
|
||||
from google.cloud import aiplatform
|
||||
|
||||
from .file_storage.google_cloud import GoogleCloudFileStorage
|
||||
from .base import BaseVectorSearch, SearchResult
|
||||
|
||||
|
||||
class GoogleCloudVectorSearch(BaseVectorSearch):
    """A vector search provider using Vertex AI Vector Search.

    Wraps index lifecycle (create/update/deploy/delete) via the aiplatform
    SDK, plus sync and async query paths. Document text is stored alongside
    the index in GCS under ``{index_name}/contents/{id}.md``.

    NOTE(review): ``self.index`` / ``self.index_endpoint`` are only set by
    create_index/update_index/deploy_index/load_index_endpoint — the query
    methods assume load_index_endpoint was called first.
    """

    def __init__(
        self,
        project_id: str,
        location: str,
        bucket: str,
        index_name: str | None = None,
    ) -> None:
        """Initialize the GoogleCloudVectorSearch client.

        Args:
            project_id: The Google Cloud project ID.
            location: The Google Cloud location (e.g., 'us-central1').
            bucket: The GCS bucket to use for file storage.
            index_name: The name of the index.

        """
        # aiplatform.init mutates global SDK state for this process.
        aiplatform.init(project=project_id, location=location)
        self.project_id = project_id
        self.location = location
        self.storage = GoogleCloudFileStorage(bucket=bucket)
        self.index_name = index_name
        # Lazily-created auth/session state for the sync and async paths.
        self._credentials: google.auth.credentials.Credentials | None = None
        self._aio_session: aiohttp.ClientSession | None = None
        self._async_token: Token | None = None

    def _get_auth_headers(self) -> dict[str, str]:
        # Sync path: Application Default Credentials, refreshed when the
        # token is missing or expired.
        if self._credentials is None:
            self._credentials, _ = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
        if not self._credentials.token or self._credentials.expired:
            self._credentials.refresh(
                google.auth.transport.requests.Request(),
            )
        return {
            "Authorization": f"Bearer {self._credentials.token}",
            "Content-Type": "application/json",
        }

    async def _async_get_auth_headers(self) -> dict[str, str]:
        # Async path: gcloud-aio Token handles caching/refresh internally.
        if self._async_token is None:
            self._async_token = Token(
                session=self._get_aio_session(),
                scopes=[
                    "https://www.googleapis.com/auth/cloud-platform",
                ],
            )
        access_token = await self._async_token.get()
        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def _get_aio_session(self) -> aiohttp.ClientSession:
        # Lazily (re)create the shared aiohttp session used for REST calls.
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(
                limit=300, limit_per_host=50,
            )
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout, connector=connector,
            )
        return self._aio_session

    def create_index(
        self,
        name: str,
        content_path: str,
        *,
        dimensions: int = 3072,
        approximate_neighbors_count: int = 150,
        distance_measure_type: str = "DOT_PRODUCT_DISTANCE",
        **kwargs: Any,  # noqa: ANN401, ARG002
    ) -> None:
        """Create a new Vertex AI Vector Search index.

        Blocking long-running operation; the created index is kept on
        ``self.index`` for a subsequent deploy_index call.

        Args:
            name: The display name for the new index.
            content_path: GCS URI to the embeddings JSON file.
            dimensions: Number of dimensions in embedding vectors.
            approximate_neighbors_count: Neighbors to find per vector.
            distance_measure_type: The distance measure to use.
            **kwargs: Additional arguments.

        """
        index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
            display_name=name,
            contents_delta_uri=content_path,
            dimensions=dimensions,
            approximate_neighbors_count=approximate_neighbors_count,
            distance_measure_type=distance_measure_type,  # type: ignore[arg-type]
            leaf_node_embedding_count=1000,
            leaf_nodes_to_search_percent=10,
        )
        self.index = index

    def update_index(
        self, index_name: str, content_path: str, **kwargs: Any,  # noqa: ANN401, ARG002
    ) -> None:
        """Update an existing Vertex AI Vector Search index.

        Args:
            index_name: The resource name of the index to update.
            content_path: GCS URI to the new embeddings JSON file.
            **kwargs: Additional arguments.

        """
        index = aiplatform.MatchingEngineIndex(index_name=index_name)
        index.update_embeddings(
            contents_delta_uri=content_path,
        )
        self.index = index

    def deploy_index(
        self,
        index_name: str,
        machine_type: str = "e2-standard-2",
    ) -> None:
        """Deploy a Vertex AI Vector Search index to an endpoint.

        Creates a NEW public endpoint each call and deploys ``self.index``
        (set by create_index/update_index) to it under a unique deployed ID.

        Args:
            index_name: The name of the index to deploy.
            machine_type: The machine type for the endpoint.

        """
        index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
            display_name=f"{index_name}-endpoint",
            public_endpoint_enabled=True,
        )
        index_endpoint.deploy_index(
            index=self.index,
            # Deployed IDs must be valid identifiers, hence '-' -> '_';
            # the uuid suffix keeps repeated deployments distinct.
            deployed_index_id=(
                f"{index_name.replace('-', '_')}_deployed_{uuid4().hex}"
            ),
            machine_type=machine_type,
        )
        self.index_endpoint = index_endpoint

    def load_index_endpoint(self, endpoint_name: str) -> None:
        """Load an existing Vertex AI Vector Search index endpoint.

        Args:
            endpoint_name: The resource name of the index endpoint.

        Raises:
            ValueError: If the endpoint has no public domain name — the
                async query path requires a public REST endpoint.

        """
        self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
            endpoint_name,
        )
        if not self.index_endpoint.public_endpoint_domain_name:
            msg = (
                "The index endpoint does not have a public endpoint. "
                "Ensure the endpoint is configured for public access."
            )
            raise ValueError(msg)

    def run_query(
        self,
        deployed_index_id: str,
        query: list[float],
        limit: int,
    ) -> list[SearchResult]:
        """Run a similarity search query against the deployed index.

        Synchronous SDK path; document text is fetched per-neighbor from
        GCS (one blocking read each, served from cache when warm).

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.

        Returns:
            A list of matched items with id, distance, and content.

        """
        response = self.index_endpoint.find_neighbors(
            deployed_index_id=deployed_index_id,
            queries=[query],
            num_neighbors=limit,
        )
        results = []
        # response[0] = neighbors of the single query vector submitted.
        for neighbor in response[0]:
            file_path = (
                f"{self.index_name}/contents/{neighbor.id}.md"
            )
            content = (
                self.storage.get_file_stream(file_path)
                .read()
                .decode("utf-8")
            )
            results.append(
                SearchResult(
                    id=neighbor.id,
                    distance=float(neighbor.distance or 0),
                    content=content,
                ),
            )
        return results

    async def async_run_query(
        self,
        deployed_index_id: str,
        query: Sequence[float],
        limit: int,
    ) -> list[SearchResult]:
        """Run an async similarity search via the REST API.

        Calls the public findNeighbors REST endpoint directly, then
        fetches all document texts from GCS concurrently.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.

        Returns:
            A list of matched items with id, distance, and content.

        """
        domain = self.index_endpoint.public_endpoint_domain_name
        # Endpoint resource name ends in the numeric endpoint ID.
        endpoint_id = self.index_endpoint.name.split("/")[-1]
        url = (
            f"https://{domain}/v1/projects/{self.project_id}"
            f"/locations/{self.location}"
            f"/indexEndpoints/{endpoint_id}:findNeighbors"
        )
        payload = {
            "deployed_index_id": deployed_index_id,
            "queries": [
                {
                    "datapoint": {"feature_vector": list(query)},
                    "neighbor_count": limit,
                },
            ],
        }

        headers = await self._async_get_auth_headers()
        session = self._get_aio_session()
        async with session.post(
            url, json=payload, headers=headers,
        ) as response:
            response.raise_for_status()
            data = await response.json()

        # Single query submitted, so take the first result set; the [{}]
        # default makes an empty response yield [] rather than IndexError.
        neighbors = (
            data.get("nearestNeighbors", [{}])[0].get("neighbors", [])
        )
        # Fan out the GCS content downloads concurrently.
        content_tasks = []
        for neighbor in neighbors:
            datapoint_id = neighbor["datapoint"]["datapointId"]
            file_path = (
                f"{self.index_name}/contents/{datapoint_id}.md"
            )
            content_tasks.append(
                self.storage.async_get_file_stream(file_path),
            )

        file_streams = await asyncio.gather(*content_tasks)
        results: list[SearchResult] = []
        # strict=True guards against a mismatch between neighbors and
        # downloaded streams (should be impossible, but fail loudly).
        for neighbor, stream in zip(
            neighbors, file_streams, strict=True,
        ):
            results.append(
                SearchResult(
                    id=neighbor["datapoint"]["datapointId"],
                    distance=neighbor["distance"],
                    content=stream.read().decode("utf-8"),
                ),
            )
        return results

    def delete_index(self, index_name: str) -> None:
        """Delete a Vertex AI Vector Search index.

        Args:
            index_name: The resource name of the index.

        """
        index = aiplatform.MatchingEngineIndex(index_name)
        index.delete()

    def delete_index_endpoint(
        self, index_endpoint_name: str,
    ) -> None:
        """Delete a Vertex AI Vector Search index endpoint.

        Undeploys all indexes first, then force-deletes the endpoint.

        Args:
            index_endpoint_name: The resource name of the endpoint.

        """
        index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
            index_endpoint_name,
        )
        index_endpoint.undeploy_all()
        index_endpoint.delete(force=True)
|
||||
Reference in New Issue
Block a user