"""Google Cloud Vertex AI Vector Search implementation."""

import asyncio
from collections.abc import Sequence
from typing import Any
from uuid import uuid4

import aiohttp
import google.auth
import google.auth.credentials
import google.auth.transport.requests
from gcloud.aio.auth import Token
from google.cloud import aiplatform

from .base import BaseVectorSearch, SearchResult
from .file_storage.google_cloud import GoogleCloudFileStorage


class GoogleCloudVectorSearch(BaseVectorSearch):
    """A vector search provider using Vertex AI Vector Search."""

    def __init__(
        self,
        project_id: str,
        location: str,
        bucket: str,
        index_name: str | None = None,
    ) -> None:
        """Initialize the GoogleCloudVectorSearch client.

        Args:
            project_id: The Google Cloud project ID.
            location: The Google Cloud location (e.g., 'us-central1').
            bucket: The GCS bucket to use for file storage.
            index_name: The name of the index.
        """
        aiplatform.init(project=project_id, location=location)
        self.project_id = project_id
        self.location = location
        self.storage = GoogleCloudFileStorage(bucket=bucket)
        self.index_name = index_name
        # Lazily-initialized auth/session state, created on first use.
        self._credentials: google.auth.credentials.Credentials | None = None
        self._aio_session: aiohttp.ClientSession | None = None
        self._async_token: Token | None = None

    def _get_auth_headers(self) -> dict[str, str]:
        """Return bearer-token headers for synchronous REST calls.

        Uses Application Default Credentials, refreshing the token when
        it is missing or expired.
        """
        if self._credentials is None:
            self._credentials, _ = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
        if not self._credentials.token or self._credentials.expired:
            self._credentials.refresh(
                google.auth.transport.requests.Request(),
            )
        return {
            "Authorization": f"Bearer {self._credentials.token}",
            "Content-Type": "application/json",
        }

    async def _async_get_auth_headers(self) -> dict[str, str]:
        """Return bearer-token headers for async REST calls.

        The gcloud-aio ``Token`` caches the access token and refreshes
        it internally as needed.
        """
        if self._async_token is None:
            self._async_token = Token(
                session=self._get_aio_session(),
                scopes=[
                    "https://www.googleapis.com/auth/cloud-platform",
                ],
            )
        access_token = await self._async_token.get()
        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def _get_aio_session(self) -> aiohttp.ClientSession:
        """Return a shared aiohttp session, recreating it if closed."""
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(
                limit=300,
                limit_per_host=50,
            )
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout,
                connector=connector,
            )
        return self._aio_session

    async def aclose(self) -> None:
        """Close the shared aiohttp session and drop the cached token.

        The session created by ``_get_aio_session`` was never closed
        anywhere, leaking connections. Call this during application
        shutdown; it is safe to call multiple times.
        """
        # Dropping the token reference is enough: it holds only the
        # shared session we close below.
        self._async_token = None
        if self._aio_session is not None and not self._aio_session.closed:
            await self._aio_session.close()

    def create_index(
        self,
        name: str,
        content_path: str,
        *,
        dimensions: int = 3072,
        approximate_neighbors_count: int = 150,
        distance_measure_type: str = "DOT_PRODUCT_DISTANCE",
        **kwargs: Any,  # noqa: ANN401, ARG002
    ) -> None:
        """Create a new Vertex AI Vector Search index.

        The created index is stored on ``self.index`` for a subsequent
        ``deploy_index`` call.

        Args:
            name: The display name for the new index.
            content_path: GCS URI to the embeddings JSON file.
            dimensions: Number of dimensions in embedding vectors.
            approximate_neighbors_count: Neighbors to find per vector.
            distance_measure_type: The distance measure to use.
            **kwargs: Additional arguments.
        """
        index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
            display_name=name,
            contents_delta_uri=content_path,
            dimensions=dimensions,
            approximate_neighbors_count=approximate_neighbors_count,
            distance_measure_type=distance_measure_type,  # type: ignore[arg-type]
            leaf_node_embedding_count=1000,
            leaf_nodes_to_search_percent=10,
        )
        self.index = index

    def update_index(
        self,
        index_name: str,
        content_path: str,
        **kwargs: Any,  # noqa: ANN401, ARG002
    ) -> None:
        """Update an existing Vertex AI Vector Search index.

        The updated index is stored on ``self.index`` for a subsequent
        ``deploy_index`` call.

        Args:
            index_name: The resource name of the index to update.
            content_path: GCS URI to the new embeddings JSON file.
            **kwargs: Additional arguments.
        """
        index = aiplatform.MatchingEngineIndex(index_name=index_name)
        index.update_embeddings(
            contents_delta_uri=content_path,
        )
        self.index = index

    def deploy_index(
        self,
        index_name: str,
        machine_type: str = "e2-standard-2",
    ) -> None:
        """Deploy a Vertex AI Vector Search index to a new endpoint.

        NOTE(review): ``index_name`` is used only to derive the endpoint
        display name and deployed-index ID; the index actually deployed
        is ``self.index`` set by ``create_index``/``update_index``.

        Args:
            index_name: The name of the index to deploy.
            machine_type: The machine type for the endpoint.

        Raises:
            RuntimeError: If no index has been created or updated yet.
        """
        # Guard against a bare AttributeError when called out of order.
        if getattr(self, "index", None) is None:
            msg = (
                "No index available: call create_index() or "
                "update_index() before deploy_index()."
            )
            raise RuntimeError(msg)
        index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
            display_name=f"{index_name}-endpoint",
            public_endpoint_enabled=True,
        )
        index_endpoint.deploy_index(
            index=self.index,
            deployed_index_id=(
                f"{index_name.replace('-', '_')}_deployed_{uuid4().hex}"
            ),
            machine_type=machine_type,
        )
        self.index_endpoint = index_endpoint

    def load_index_endpoint(self, endpoint_name: str) -> None:
        """Load an existing Vertex AI Vector Search index endpoint.

        Args:
            endpoint_name: The resource name of the index endpoint.

        Raises:
            ValueError: If the endpoint has no public domain name, which
                ``async_run_query`` requires for REST access.
        """
        self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
            endpoint_name,
        )
        if not self.index_endpoint.public_endpoint_domain_name:
            msg = (
                "The index endpoint does not have a public endpoint. "
                "Ensure the endpoint is configured for public access."
            )
            raise ValueError(msg)

    def run_query(
        self,
        deployed_index_id: str,
        query: list[float],
        limit: int,
    ) -> list[SearchResult]:
        """Run a similarity search query against the deployed index.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.

        Returns:
            A list of matched items with id, distance, and content.
        """
        response = self.index_endpoint.find_neighbors(
            deployed_index_id=deployed_index_id,
            queries=[query],
            num_neighbors=limit,
        )
        results = []
        for neighbor in response[0]:
            # Each datapoint's markdown content lives next to the index
            # in file storage, keyed by the datapoint id.
            file_path = (
                f"{self.index_name}/contents/{neighbor.id}.md"
            )
            content = (
                self.storage.get_file_stream(file_path)
                .read()
                .decode("utf-8")
            )
            results.append(
                SearchResult(
                    id=neighbor.id,
                    distance=float(neighbor.distance or 0),
                    content=content,
                ),
            )
        return results

    async def async_run_query(
        self,
        deployed_index_id: str,
        query: Sequence[float],
        limit: int,
    ) -> list[SearchResult]:
        """Run an async similarity search via the REST API.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.

        Returns:
            A list of matched items with id, distance, and content.
        """
        domain = self.index_endpoint.public_endpoint_domain_name
        endpoint_id = self.index_endpoint.name.split("/")[-1]
        url = (
            f"https://{domain}/v1/projects/{self.project_id}"
            f"/locations/{self.location}"
            f"/indexEndpoints/{endpoint_id}:findNeighbors"
        )
        payload = {
            "deployed_index_id": deployed_index_id,
            "queries": [
                {
                    "datapoint": {"feature_vector": list(query)},
                    "neighbor_count": limit,
                },
            ],
        }
        headers = await self._async_get_auth_headers()
        session = self._get_aio_session()
        async with session.post(
            url,
            json=payload,
            headers=headers,
        ) as response:
            response.raise_for_status()
            data = await response.json()
        # Guard both a missing key and an empty "nearestNeighbors" list
        # (the bare [0] index previously raised IndexError on the
        # latter).
        nearest = data.get("nearestNeighbors") or [{}]
        neighbors = nearest[0].get("neighbors", [])
        # Fetch all content files concurrently.
        content_tasks = []
        for neighbor in neighbors:
            datapoint_id = neighbor["datapoint"]["datapointId"]
            file_path = (
                f"{self.index_name}/contents/{datapoint_id}.md"
            )
            content_tasks.append(
                self.storage.async_get_file_stream(file_path),
            )
        file_streams = await asyncio.gather(*content_tasks)
        results: list[SearchResult] = []
        for neighbor, stream in zip(
            neighbors,
            file_streams,
            strict=True,
        ):
            results.append(
                SearchResult(
                    id=neighbor["datapoint"]["datapointId"],
                    # The REST API may omit "distance"; default to 0 for
                    # consistency with the sync run_query() handling.
                    distance=float(neighbor.get("distance", 0)),
                    content=stream.read().decode("utf-8"),
                ),
            )
        return results

    def delete_index(self, index_name: str) -> None:
        """Delete a Vertex AI Vector Search index.

        Args:
            index_name: The resource name of the index.
        """
        index = aiplatform.MatchingEngineIndex(index_name)
        index.delete()

    def delete_index_endpoint(
        self,
        index_endpoint_name: str,
    ) -> None:
        """Delete a Vertex AI Vector Search index endpoint.

        Undeploys all indexes first, then force-deletes the endpoint.

        Args:
            index_endpoint_name: The resource name of the endpoint.
        """
        index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
            index_endpoint_name,
        )
        index_endpoint.undeploy_all()
        index_endpoint.delete(force=True)