From da599fbbef44af94494835a6bad18ebb065590da Mon Sep 17 00:00:00 2001 From: A8080816 Date: Fri, 20 Feb 2026 23:02:57 +0000 Subject: [PATCH] dev: add ADK agent with vector search tool and Google Cloud file storage implementation --- pyproject.toml | 14 +- rag_agent/__init__.py | 1 + rag_agent/agent.py | 47 ++++ rag_agent/base.py | 68 ++++++ rag_agent/config_helper.py | 120 ++++++++++ rag_agent/file_storage/__init__.py | 1 + rag_agent/file_storage/base.py | 56 +++++ rag_agent/file_storage/google_cloud.py | 188 +++++++++++++++ rag_agent/vector_search_tool.py | 176 ++++++++++++++ rag_agent/vertex_ai.py | 310 +++++++++++++++++++++++++ 10 files changed, 974 insertions(+), 7 deletions(-) create mode 100644 rag_agent/__init__.py create mode 100644 rag_agent/agent.py create mode 100644 rag_agent/base.py create mode 100644 rag_agent/config_helper.py create mode 100644 rag_agent/file_storage/__init__.py create mode 100644 rag_agent/file_storage/base.py create mode 100644 rag_agent/file_storage/google_cloud.py create mode 100644 rag_agent/vector_search_tool.py create mode 100644 rag_agent/vertex_ai.py diff --git a/pyproject.toml b/pyproject.toml index 893a8ad..b27bf29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,20 +4,19 @@ version = "0.1.0" description = "Add your description here" readme = "README.md" authors = [ - { name = "Anibal Angulo", email = "a8065384@banorte.com" } + { name = "Anibal Angulo", email = "a8065384@banorte.com" }, + { name = "Jorge Juarez", email = "a8080816@banorte.com" } ] requires-python = "~=3.12.0" dependencies = [ "aiohttp>=3.13.3", - "fastapi>=0.129.0", "gcloud-aio-auth>=5.4.2", "gcloud-aio-storage>=9.6.1", - "google-cloud-aiplatform>=1.138.0", - "google-cloud-storage>=3.9.0", - "pydantic-ai-slim[google]>=1.62.0", - "pydantic-settings[yaml]>=2.10.1", + "google-adk>=1.14.1", + "google-cloud-aiplatform>=1.126.1", + "google-cloud-storage>=2.19.0", + "pydantic-settings[yaml]>=2.13.1", "structlog>=25.5.0", - "uvicorn>=0.41.0", ] 
[project.scripts] @@ -30,6 +29,7 @@ build-backend = "uv_build" [dependency-groups] dev = [ "clai>=1.62.0", + "marimo>=0.20.1", "pytest>=8.4.1", "ruff>=0.12.10", "ty>=0.0.1a19", diff --git a/rag_agent/__init__.py b/rag_agent/__init__.py new file mode 100644 index 0000000..02c597e --- /dev/null +++ b/rag_agent/__init__.py @@ -0,0 +1 @@ +from . import agent diff --git a/rag_agent/agent.py b/rag_agent/agent.py new file mode 100644 index 0000000..a1e6726 --- /dev/null +++ b/rag_agent/agent.py @@ -0,0 +1,47 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""ADK agent with vector search RAG tool.""" + +from __future__ import annotations + +from google.adk.agents.llm_agent import Agent + +from .config_helper import settings +from .vector_search_tool import VectorSearchTool + +# Create vector search tool with configuration +vector_search_tool = VectorSearchTool( + name='conocimiento', + description='Search the vector index for company products and services information', + embedder=settings.embedder, + project_id=settings.project_id, + location=settings.location, + bucket=settings.bucket, + index_name=settings.index_name, + index_endpoint=settings.index_endpoint, + index_deployed_id=settings.index_deployed_id, + similarity_top_k=5, + min_similarity_threshold=0.6, + relative_threshold_factor=0.9, +) + +# Create agent with vector search tool +root_agent = Agent( + model=settings.agent_language_model, + name=settings.agent_name, + description='A helpful assistant for user questions.', + instruction=settings.agent_instructions, + tools=[vector_search_tool], +) diff --git a/rag_agent/base.py b/rag_agent/base.py new file mode 100644 index 0000000..ab00142 --- /dev/null +++ b/rag_agent/base.py @@ -0,0 +1,68 @@ +"""Abstract base class for vector search providers.""" + +from abc import ABC, abstractmethod +from typing import Any, TypedDict + + +class SearchResult(TypedDict): + """A single vector search result.""" + + id: str + distance: float + content: str + + +class BaseVectorSearch(ABC): + """Abstract base class for a vector search provider. + + This class defines the standard interface for creating a vector search + index and running queries against it. + """ + + @abstractmethod + def create_index( + self, name: str, content_path: str, **kwargs: Any # noqa: ANN401 + ) -> None: + """Create a new vector search index with the provided content. + + Args: + name: The desired name for the new index. + content_path: Path to the data used to populate the index. + **kwargs: Additional provider-specific arguments. + + """ + ... 
+ + @abstractmethod + def update_index( + self, index_name: str, content_path: str, **kwargs: Any # noqa: ANN401 + ) -> None: + """Update an existing vector search index with new content. + + Args: + index_name: The name of the index to update. + content_path: Path to the data used to populate the index. + **kwargs: Additional provider-specific arguments. + + """ + ... + + @abstractmethod + def run_query( + self, + deployed_index_id: str, + query: list[float], + limit: int, + ) -> list[SearchResult]: + """Run a similarity search query against the index. + + Args: + deployed_index_id: The ID of the deployed index. + query: The embedding vector for the search query. + limit: Maximum number of nearest neighbors to return. + + Returns: + A list of matched items with id, distance, and content. + + """ + ... diff --git a/rag_agent/config_helper.py b/rag_agent/config_helper.py new file mode 100644 index 0000000..1f773bf --- /dev/null +++ b/rag_agent/config_helper.py @@ -0,0 +1,120 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Configuration helper for ADK agent with vector search.""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from functools import cached_property + +import vertexai +from pydantic_settings import ( + BaseSettings, + PydanticBaseSettingsSource, + SettingsConfigDict, + YamlConfigSettingsSource, +) +from vertexai.language_models import TextEmbeddingModel + +CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml") + + +@dataclass +class EmbeddingResult: + """Result from embedding a query.""" + + embeddings: list[list[float]] + + +class VertexAIEmbedder: + """Embedder using Vertex AI TextEmbeddingModel.""" + + def __init__(self, model_name: str, project_id: str, location: str) -> None: + """Initialize the embedder. + + Args: + model_name: Name of the embedding model (e.g., 'text-embedding-004') + project_id: GCP project ID + location: GCP location + + """ + vertexai.init(project=project_id, location=location) + self.model = TextEmbeddingModel.from_pretrained(model_name) + + async def embed_query(self, query: str) -> EmbeddingResult: + """Embed a single query string. 
+ + Args: + query: Text to embed + + Returns: + EmbeddingResult with embeddings list + + """ + embeddings = self.model.get_embeddings([query]) + return EmbeddingResult(embeddings=[list(embeddings[0].values)]) + + +class AgentSettings(BaseSettings): + """Settings for ADK agent with vector search.""" + + # Google Cloud settings + project_id: str + location: str + bucket: str + + # Agent configuration + agent_name: str + agent_instructions: str + agent_language_model: str + agent_embedding_model: str + + # Vector index configuration + index_name: str + index_deployed_id: str + index_endpoint: str + + model_config = SettingsConfigDict( + yaml_file=CONFIG_FILE_PATH, + extra="ignore", # Ignore extra fields from config.yaml + ) + + @classmethod + def settings_customise_sources( + cls, + settings_cls: type[BaseSettings], + init_settings: PydanticBaseSettingsSource, # noqa: ARG003 + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, # noqa: ARG003 + file_secret_settings: PydanticBaseSettingsSource, # noqa: ARG003 + ) -> tuple[PydanticBaseSettingsSource, ...]: + """Use env vars and YAML as settings sources.""" + return ( + env_settings, + YamlConfigSettingsSource(settings_cls), + ) + + @cached_property + def embedder(self) -> VertexAIEmbedder: + """Return an embedder configured for the agent's embedding model.""" + return VertexAIEmbedder( + model_name=self.agent_embedding_model, + project_id=self.project_id, + location=self.location, + ) + + +settings = AgentSettings.model_validate({}) diff --git a/rag_agent/file_storage/__init__.py b/rag_agent/file_storage/__init__.py new file mode 100644 index 0000000..7f88e76 --- /dev/null +++ b/rag_agent/file_storage/__init__.py @@ -0,0 +1 @@ +"""File storage provider implementations.""" diff --git a/rag_agent/file_storage/base.py b/rag_agent/file_storage/base.py new file mode 100644 index 0000000..a98ea39 --- /dev/null +++ b/rag_agent/file_storage/base.py @@ -0,0 +1,56 @@ +"""Abstract base class 
for file storage providers.""" + +from abc import ABC, abstractmethod +from typing import BinaryIO + + +class BaseFileStorage(ABC): + """Abstract base class for a remote file processor. + + Defines the interface for listing and processing files from + a remote source. + """ + + @abstractmethod + def upload_file( + self, + file_path: str, + destination_blob_name: str, + content_type: str | None = None, + ) -> None: + """Upload a file to the remote source. + + Args: + file_path: The local path to the file to upload. + destination_blob_name: Name of the file in remote storage. + content_type: The content type of the file. + + """ + ... + + @abstractmethod + def list_files(self, path: str | None = None) -> list[str]: + """List files from a remote location. + + Args: + path: Path to a specific file or directory. If None, + recursively lists all files in the bucket. + + Returns: + A list of file paths. + + """ + ... + + @abstractmethod + def get_file_stream(self, file_name: str) -> BinaryIO: + """Get a file from the remote source as a file-like object. + + Args: + file_name: The name of the file to retrieve. + + Returns: + A file-like object containing the file data. + + """ + ... 
diff --git a/rag_agent/file_storage/google_cloud.py b/rag_agent/file_storage/google_cloud.py new file mode 100644 index 0000000..7848d97 --- /dev/null +++ b/rag_agent/file_storage/google_cloud.py @@ -0,0 +1,188 @@ +"""Google Cloud Storage file storage implementation.""" + +import asyncio +import io +import logging +from typing import BinaryIO + +import aiohttp +from gcloud.aio.storage import Storage +from google.cloud import storage + +from .base import BaseFileStorage + +logger = logging.getLogger(__name__) + +HTTP_TOO_MANY_REQUESTS = 429 +HTTP_SERVER_ERROR = 500 + + +class GoogleCloudFileStorage(BaseFileStorage): + """File storage backed by Google Cloud Storage.""" + + def __init__(self, bucket: str) -> None: # noqa: D107 + self.bucket_name = bucket + + self.storage_client = storage.Client() + self.bucket_client = self.storage_client.bucket(self.bucket_name) + self._aio_session: aiohttp.ClientSession | None = None + self._aio_storage: Storage | None = None + self._cache: dict[str, bytes] = {} + + def upload_file( + self, + file_path: str, + destination_blob_name: str, + content_type: str | None = None, + ) -> None: + """Upload a file to Cloud Storage. + + Args: + file_path: The local path to the file to upload. + destination_blob_name: Name of the blob in the bucket. + content_type: The content type of the file. + + """ + blob = self.bucket_client.blob(destination_blob_name) + blob.upload_from_filename( + file_path, + content_type=content_type, + if_generation_match=0, + ) + self._cache.pop(destination_blob_name, None) + + def list_files(self, path: str | None = None) -> list[str]: + """List all files at the given path in the bucket. + + If path is None, recursively lists all files. + + Args: + path: Prefix to filter files by. + + Returns: + A list of blob names. 
+ + """ + blobs = self.storage_client.list_blobs( + self.bucket_name, prefix=path, + ) + return [blob.name for blob in blobs] + + def get_file_stream(self, file_name: str) -> BinaryIO: + """Get a file as a file-like object, using cache. + + Args: + file_name: The blob name to retrieve. + + Returns: + A BytesIO stream with the file contents. + + """ + if file_name not in self._cache: + blob = self.bucket_client.blob(file_name) + self._cache[file_name] = blob.download_as_bytes() + file_stream = io.BytesIO(self._cache[file_name]) + file_stream.name = file_name + return file_stream + + def _get_aio_session(self) -> aiohttp.ClientSession: + if self._aio_session is None or self._aio_session.closed: + connector = aiohttp.TCPConnector( + limit=300, limit_per_host=50, + ) + timeout = aiohttp.ClientTimeout(total=60) + self._aio_session = aiohttp.ClientSession( + timeout=timeout, connector=connector, + ) + return self._aio_session + + def _get_aio_storage(self) -> Storage: + if self._aio_storage is None: + self._aio_storage = Storage( + session=self._get_aio_session(), + ) + return self._aio_storage + + async def async_get_file_stream( + self, file_name: str, max_retries: int = 3, + ) -> BinaryIO: + """Get a file asynchronously with retry on transient errors. + + Args: + file_name: The blob name to retrieve. + max_retries: Maximum number of retry attempts. + + Returns: + A BytesIO stream with the file contents. + + Raises: + TimeoutError: If all retry attempts fail. 
+ + """ + if file_name in self._cache: + file_stream = io.BytesIO(self._cache[file_name]) + file_stream.name = file_name + return file_stream + + storage_client = self._get_aio_storage() + last_exception: Exception | None = None + + for attempt in range(max_retries): + try: + self._cache[file_name] = await storage_client.download( + self.bucket_name, file_name, + ) + file_stream = io.BytesIO(self._cache[file_name]) + file_stream.name = file_name + except TimeoutError as exc: + last_exception = exc + logger.warning( + "Timeout downloading gs://%s/%s (attempt %d/%d)", + self.bucket_name, + file_name, + attempt + 1, + max_retries, + ) + except aiohttp.ClientResponseError as exc: + last_exception = exc + if ( + exc.status == HTTP_TOO_MANY_REQUESTS + or exc.status >= HTTP_SERVER_ERROR + ): + logger.warning( + "HTTP %d downloading gs://%s/%s " + "(attempt %d/%d)", + exc.status, + self.bucket_name, + file_name, + attempt + 1, + max_retries, + ) + else: + raise + else: + return file_stream + + if attempt < max_retries - 1: + delay = 0.5 * (2**attempt) + await asyncio.sleep(delay) + + msg = ( + f"Failed to download gs://{self.bucket_name}/{file_name} " + f"after {max_retries} attempts" + ) + raise TimeoutError(msg) from last_exception + + def delete_files(self, path: str) -> None: + """Delete all files at the given path in the bucket. + + Args: + path: Prefix of blobs to delete. + + """ + blobs = self.storage_client.list_blobs( + self.bucket_name, prefix=path, + ) + for blob in blobs: + blob.delete() + self._cache.pop(blob.name, None) diff --git a/rag_agent/vector_search_tool.py b/rag_agent/vector_search_tool.py new file mode 100644 index 0000000..1393065 --- /dev/null +++ b/rag_agent/vector_search_tool.py @@ -0,0 +1,176 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A retrieval tool that uses Vertex AI Vector Search (not RAG Engine).""" + +from __future__ import annotations + +import logging +from typing import Any +from typing import TYPE_CHECKING + +from google.adk.tools.tool_context import ToolContext +from typing_extensions import override + +from .vertex_ai import GoogleCloudVectorSearch + +from google.adk.tools.retrieval.base_retrieval_tool import BaseRetrievalTool + +if TYPE_CHECKING: + from .config_helper import VertexAIEmbedder + +logger = logging.getLogger('google_adk.' + __name__) + + +class VectorSearchTool(BaseRetrievalTool): + """A retrieval tool using Vertex AI Vector Search (not RAG Engine). + + This tool uses GoogleCloudVectorSearch to query a vector index directly, + which is useful when Vertex AI RAG Engine is not available in your GCP project. + """ + + def __init__( + self, + *, + name: str, + description: str, + embedder: VertexAIEmbedder, + project_id: str, + location: str, + bucket: str, + index_name: str, + index_endpoint: str, + index_deployed_id: str, + similarity_top_k: int = 5, + min_similarity_threshold: float = 0.6, + relative_threshold_factor: float = 0.9, + ): + """Initialize the VectorSearchTool. 
+ + Args: + name: Tool name for function declaration + description: Tool description for LLM + embedder: Embedder instance for query embedding + project_id: GCP project ID + location: GCP location (e.g., 'us-central1') + bucket: GCS bucket for content storage + index_name: Vector search index name + index_endpoint: Resource name of index endpoint + index_deployed_id: Deployed index ID + similarity_top_k: Number of results to retrieve (default: 5) + min_similarity_threshold: Minimum similarity score 0.0-1.0 (default: 0.6) + relative_threshold_factor: Factor of max similarity for dynamic filtering (default: 0.9) + """ + super().__init__(name=name, description=description) + + self.embedder = embedder + self.index_endpoint = index_endpoint + self.index_deployed_id = index_deployed_id + self.similarity_top_k = similarity_top_k + self.min_similarity_threshold = min_similarity_threshold + self.relative_threshold_factor = relative_threshold_factor + + # Initialize vector search (endpoint loaded lazily on first use) + self.vector_search = GoogleCloudVectorSearch( + project_id=project_id, + location=location, + bucket=bucket, + index_name=index_name, + ) + self._endpoint_loaded = False + + logger.info( + 'VectorSearchTool initialized with index=%s, deployed_id=%s', + index_name, + index_deployed_id, + ) + + @override + async def run_async( + self, + *, + args: dict[str, Any], + tool_context: ToolContext, + ) -> Any: + """Execute vector search with the user's query. 
+ + Args: + args: Dictionary containing 'query' key + tool_context: Tool execution context + + Returns: + Formatted search results as XML-like documents or error message + """ + query = args['query'] + logger.debug('VectorSearchTool query: %s', query) + + try: + # Load index endpoint on first use (lazy loading) + if not self._endpoint_loaded: + self.vector_search.load_index_endpoint(self.index_endpoint) + self._endpoint_loaded = True + logger.info('Index endpoint loaded successfully') + + # Embed the query using the configured embedder + embedding_result = await self.embedder.embed_query(query) + query_embedding = list(embedding_result.embeddings[0]) + + # Run vector search + search_results = await self.vector_search.async_run_query( + deployed_index_id=self.index_deployed_id, + query=query_embedding, + limit=self.similarity_top_k, + ) + + # Apply similarity filtering (dual threshold approach) + if search_results: + # Dynamic threshold based on max similarity + max_similarity = max(r['distance'] for r in search_results) + dynamic_cutoff = max_similarity * self.relative_threshold_factor + + # Filter by both absolute and relative thresholds + search_results = [ + result + for result in search_results + if ( + result['distance'] > dynamic_cutoff + and result['distance'] > self.min_similarity_threshold + ) + ] + + logger.debug( + 'VectorSearchTool results: %d documents after filtering', + len(search_results), + ) + + # Format results + if not search_results: + return ( + f"No matching documents found for query: '{query}' " + f'(min_threshold={self.min_similarity_threshold})' + ) + + # Format as XML-like documents (matching pydantic_ai pattern) + formatted_results = [ + f'<document index="{i}">\n' + f'{result["content"]}\n' + f'</document>' + for i, result in enumerate(search_results, start=1) + ] + + return '\n'.join(formatted_results) + + except Exception as e: + logger.error('VectorSearchTool error: %s', e, exc_info=True) + return f'Error during vector search: {str(e)}' diff --git 
a/rag_agent/vertex_ai.py b/rag_agent/vertex_ai.py new file mode 100644 index 0000000..1371c00 --- /dev/null +++ b/rag_agent/vertex_ai.py @@ -0,0 +1,310 @@ +"""Google Cloud Vertex AI Vector Search implementation.""" + +import asyncio +from collections.abc import Sequence +from typing import Any +from uuid import uuid4 + +import aiohttp +import google.auth +import google.auth.credentials +import google.auth.transport.requests +from gcloud.aio.auth import Token +from google.cloud import aiplatform + +from .file_storage.google_cloud import GoogleCloudFileStorage +from .base import BaseVectorSearch, SearchResult + + +class GoogleCloudVectorSearch(BaseVectorSearch): + """A vector search provider using Vertex AI Vector Search.""" + + def __init__( + self, + project_id: str, + location: str, + bucket: str, + index_name: str | None = None, + ) -> None: + """Initialize the GoogleCloudVectorSearch client. + + Args: + project_id: The Google Cloud project ID. + location: The Google Cloud location (e.g., 'us-central1'). + bucket: The GCS bucket to use for file storage. + index_name: The name of the index. 
+ + """ + aiplatform.init(project=project_id, location=location) + self.project_id = project_id + self.location = location + self.storage = GoogleCloudFileStorage(bucket=bucket) + self.index_name = index_name + self._credentials: google.auth.credentials.Credentials | None = None + self._aio_session: aiohttp.ClientSession | None = None + self._async_token: Token | None = None + + def _get_auth_headers(self) -> dict[str, str]: + if self._credentials is None: + self._credentials, _ = google.auth.default( + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + if not self._credentials.token or self._credentials.expired: + self._credentials.refresh( + google.auth.transport.requests.Request(), + ) + return { + "Authorization": f"Bearer {self._credentials.token}", + "Content-Type": "application/json", + } + + async def _async_get_auth_headers(self) -> dict[str, str]: + if self._async_token is None: + self._async_token = Token( + session=self._get_aio_session(), + scopes=[ + "https://www.googleapis.com/auth/cloud-platform", + ], + ) + access_token = await self._async_token.get() + return { + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + } + + def _get_aio_session(self) -> aiohttp.ClientSession: + if self._aio_session is None or self._aio_session.closed: + connector = aiohttp.TCPConnector( + limit=300, limit_per_host=50, + ) + timeout = aiohttp.ClientTimeout(total=60) + self._aio_session = aiohttp.ClientSession( + timeout=timeout, connector=connector, + ) + return self._aio_session + + def create_index( + self, + name: str, + content_path: str, + *, + dimensions: int = 3072, + approximate_neighbors_count: int = 150, + distance_measure_type: str = "DOT_PRODUCT_DISTANCE", + **kwargs: Any, # noqa: ANN401, ARG002 + ) -> None: + """Create a new Vertex AI Vector Search index. + + Args: + name: The display name for the new index. + content_path: GCS URI to the embeddings JSON file. 
+ dimensions: Number of dimensions in embedding vectors. + approximate_neighbors_count: Neighbors to find per vector. + distance_measure_type: The distance measure to use. + **kwargs: Additional arguments. + + """ + index = aiplatform.MatchingEngineIndex.create_tree_ah_index( + display_name=name, + contents_delta_uri=content_path, + dimensions=dimensions, + approximate_neighbors_count=approximate_neighbors_count, + distance_measure_type=distance_measure_type, # type: ignore[arg-type] + leaf_node_embedding_count=1000, + leaf_nodes_to_search_percent=10, + ) + self.index = index + + def update_index( + self, index_name: str, content_path: str, **kwargs: Any, # noqa: ANN401, ARG002 + ) -> None: + """Update an existing Vertex AI Vector Search index. + + Args: + index_name: The resource name of the index to update. + content_path: GCS URI to the new embeddings JSON file. + **kwargs: Additional arguments. + + """ + index = aiplatform.MatchingEngineIndex(index_name=index_name) + index.update_embeddings( + contents_delta_uri=content_path, + ) + self.index = index + + def deploy_index( + self, + index_name: str, + machine_type: str = "e2-standard-2", + ) -> None: + """Deploy a Vertex AI Vector Search index to an endpoint. + + Args: + index_name: The name of the index to deploy. + machine_type: The machine type for the endpoint. + + """ + index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create( + display_name=f"{index_name}-endpoint", + public_endpoint_enabled=True, + ) + index_endpoint.deploy_index( + index=self.index, + deployed_index_id=( + f"{index_name.replace('-', '_')}_deployed_{uuid4().hex}" + ), + machine_type=machine_type, + ) + self.index_endpoint = index_endpoint + + def load_index_endpoint(self, endpoint_name: str) -> None: + """Load an existing Vertex AI Vector Search index endpoint. + + Args: + endpoint_name: The resource name of the index endpoint. 
+ + """ + self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + endpoint_name, + ) + if not self.index_endpoint.public_endpoint_domain_name: + msg = ( + "The index endpoint does not have a public endpoint. " + "Ensure the endpoint is configured for public access." + ) + raise ValueError(msg) + + def run_query( + self, + deployed_index_id: str, + query: list[float], + limit: int, + ) -> list[SearchResult]: + """Run a similarity search query against the deployed index. + + Args: + deployed_index_id: The ID of the deployed index. + query: The embedding vector for the search query. + limit: Maximum number of nearest neighbors to return. + + Returns: + A list of matched items with id, distance, and content. + + """ + response = self.index_endpoint.find_neighbors( + deployed_index_id=deployed_index_id, + queries=[query], + num_neighbors=limit, + ) + results = [] + for neighbor in response[0]: + file_path = ( + f"{self.index_name}/contents/{neighbor.id}.md" + ) + content = ( + self.storage.get_file_stream(file_path) + .read() + .decode("utf-8") + ) + results.append( + SearchResult( + id=neighbor.id, + distance=float(neighbor.distance or 0), + content=content, + ), + ) + return results + + async def async_run_query( + self, + deployed_index_id: str, + query: Sequence[float], + limit: int, + ) -> list[SearchResult]: + """Run an async similarity search via the REST API. + + Args: + deployed_index_id: The ID of the deployed index. + query: The embedding vector for the search query. + limit: Maximum number of nearest neighbors to return. + + Returns: + A list of matched items with id, distance, and content. 
+ + """ + domain = self.index_endpoint.public_endpoint_domain_name + endpoint_id = self.index_endpoint.name.split("/")[-1] + url = ( + f"https://{domain}/v1/projects/{self.project_id}" + f"/locations/{self.location}" + f"/indexEndpoints/{endpoint_id}:findNeighbors" + ) + payload = { + "deployed_index_id": deployed_index_id, + "queries": [ + { + "datapoint": {"feature_vector": list(query)}, + "neighbor_count": limit, + }, + ], + } + + headers = await self._async_get_auth_headers() + session = self._get_aio_session() + async with session.post( + url, json=payload, headers=headers, + ) as response: + response.raise_for_status() + data = await response.json() + + neighbors = ( + data.get("nearestNeighbors", [{}])[0].get("neighbors", []) + ) + content_tasks = [] + for neighbor in neighbors: + datapoint_id = neighbor["datapoint"]["datapointId"] + file_path = ( + f"{self.index_name}/contents/{datapoint_id}.md" + ) + content_tasks.append( + self.storage.async_get_file_stream(file_path), + ) + + file_streams = await asyncio.gather(*content_tasks) + results: list[SearchResult] = [] + for neighbor, stream in zip( + neighbors, file_streams, strict=True, + ): + results.append( + SearchResult( + id=neighbor["datapoint"]["datapointId"], + distance=neighbor["distance"], + content=stream.read().decode("utf-8"), + ), + ) + return results + + def delete_index(self, index_name: str) -> None: + """Delete a Vertex AI Vector Search index. + + Args: + index_name: The resource name of the index. + + """ + index = aiplatform.MatchingEngineIndex(index_name) + index.delete() + + def delete_index_endpoint( + self, index_endpoint_name: str, + ) -> None: + """Delete a Vertex AI Vector Search index endpoint. + + Args: + index_endpoint_name: The resource name of the endpoint. + + """ + index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + index_endpoint_name, + ) + index_endpoint.undeploy_all() + index_endpoint.delete(force=True)