# File: agent/rag_agent/vertex_ai.py (Python, ~311 lines, ~10 KiB)
"""Google Cloud Vertex AI Vector Search implementation."""
import asyncio
from collections.abc import Sequence
from typing import Any
from uuid import uuid4
import aiohttp
import google.auth
import google.auth.credentials
import google.auth.transport.requests
from gcloud.aio.auth import Token
from google.cloud import aiplatform
from .file_storage.google_cloud import GoogleCloudFileStorage
from .base import BaseVectorSearch, SearchResult
class GoogleCloudVectorSearch(BaseVectorSearch):
"""A vector search provider using Vertex AI Vector Search."""
def __init__(
    self,
    project_id: str,
    location: str,
    bucket: str,
    index_name: str | None = None,
) -> None:
    """Initialize the GoogleCloudVectorSearch client.

    Args:
        project_id: The Google Cloud project ID.
        location: The Google Cloud location (e.g., 'us-central1').
        bucket: The GCS bucket to use for file storage.
        index_name: The name of the index.
    """
    aiplatform.init(project=project_id, location=location)
    self.project_id = project_id
    self.location = location
    self.index_name = index_name
    self.storage = GoogleCloudFileStorage(bucket=bucket)
    # Auth and HTTP state, created lazily on first use.
    self._credentials: google.auth.credentials.Credentials | None = None
    self._async_token: Token | None = None
    self._aio_session: aiohttp.ClientSession | None = None
def _get_auth_headers(self) -> dict[str, str]:
    """Return REST headers with a valid OAuth bearer token (sync path)."""
    creds = self._credentials
    if creds is None:
        # Resolve Application Default Credentials once and cache them.
        creds, _ = google.auth.default(
            scopes=["https://www.googleapis.com/auth/cloud-platform"],
        )
        self._credentials = creds
    # Refresh whenever there is no token yet or the token has expired.
    if creds.expired or not creds.token:
        creds.refresh(google.auth.transport.requests.Request())
    return {
        "Authorization": f"Bearer {creds.token}",
        "Content-Type": "application/json",
    }
async def _async_get_auth_headers(self) -> dict[str, str]:
    """Return REST headers with a valid OAuth bearer token (async path)."""
    token = self._async_token
    if token is None:
        # The async token helper shares the cached aiohttp session.
        token = Token(
            session=self._get_aio_session(),
            scopes=["https://www.googleapis.com/auth/cloud-platform"],
        )
        self._async_token = token
    bearer = await token.get()
    return {
        "Authorization": f"Bearer {bearer}",
        "Content-Type": "application/json",
    }
def _get_aio_session(self) -> aiohttp.ClientSession:
    """Return the shared aiohttp session, recreating it if it was closed."""
    needs_new = self._aio_session is None or self._aio_session.closed
    if needs_new:
        self._aio_session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=60),
            connector=aiohttp.TCPConnector(limit=300, limit_per_host=50),
        )
    return self._aio_session
def create_index(
    self,
    name: str,
    content_path: str,
    *,
    dimensions: int = 3072,
    approximate_neighbors_count: int = 150,
    distance_measure_type: str = "DOT_PRODUCT_DISTANCE",
    **kwargs: Any,  # noqa: ANN401, ARG002
) -> None:
    """Create a new Vertex AI Vector Search index.

    The created index is stored on ``self.index`` for a later
    ``deploy_index`` call.

    Args:
        name: The display name for the new index.
        content_path: GCS URI to the embeddings JSON file.
        dimensions: Number of dimensions in embedding vectors.
        approximate_neighbors_count: Neighbors to find per vector.
        distance_measure_type: The distance measure to use.
        **kwargs: Additional arguments.
    """
    self.index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
        display_name=name,
        contents_delta_uri=content_path,
        dimensions=dimensions,
        approximate_neighbors_count=approximate_neighbors_count,
        distance_measure_type=distance_measure_type,  # type: ignore[arg-type]
        leaf_node_embedding_count=1000,
        leaf_nodes_to_search_percent=10,
    )
def update_index(
    self, index_name: str, content_path: str, **kwargs: Any,  # noqa: ANN401, ARG002
) -> None:
    """Update an existing Vertex AI Vector Search index.

    The updated index is stored on ``self.index``.

    Args:
        index_name: The resource name of the index to update.
        content_path: GCS URI to the new embeddings JSON file.
        **kwargs: Additional arguments.
    """
    existing = aiplatform.MatchingEngineIndex(index_name=index_name)
    existing.update_embeddings(contents_delta_uri=content_path)
    self.index = existing
def deploy_index(
    self,
    index_name: str,
    machine_type: str = "e2-standard-2",
) -> None:
    """Deploy a Vertex AI Vector Search index to a new public endpoint.

    Requires ``self.index`` to have been set by ``create_index`` or
    ``update_index``. The endpoint is stored on ``self.index_endpoint``.

    Args:
        index_name: The name of the index to deploy.
        machine_type: The machine type for the endpoint.
    """
    # Unique per-deployment ID; dashes are not allowed in deployed IDs.
    deployed_id = f"{index_name.replace('-', '_')}_deployed_{uuid4().hex}"
    endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
        display_name=f"{index_name}-endpoint",
        public_endpoint_enabled=True,
    )
    endpoint.deploy_index(
        index=self.index,
        deployed_index_id=deployed_id,
        machine_type=machine_type,
    )
    self.index_endpoint = endpoint
def load_index_endpoint(self, endpoint_name: str) -> None:
    """Load an existing Vertex AI Vector Search index endpoint.

    Args:
        endpoint_name: The resource name of the index endpoint.

    Raises:
        ValueError: If the endpoint has no public domain name.
    """
    endpoint = aiplatform.MatchingEngineIndexEndpoint(endpoint_name)
    self.index_endpoint = endpoint
    # Queries in this class go through the public REST domain, so a
    # private-only endpoint is unusable here.
    if endpoint.public_endpoint_domain_name:
        return
    msg = (
        "The index endpoint does not have a public endpoint. "
        "Ensure the endpoint is configured for public access."
    )
    raise ValueError(msg)
def run_query(
    self,
    deployed_index_id: str,
    query: list[float],
    limit: int,
) -> list[SearchResult]:
    """Run a similarity search query against the deployed index.

    Args:
        deployed_index_id: The ID of the deployed index.
        query: The embedding vector for the search query.
        limit: Maximum number of nearest neighbors to return.

    Returns:
        A list of matched items with id, distance, and content.
    """
    # Single-query request; the SDK returns one neighbor list per query.
    matches = self.index_endpoint.find_neighbors(
        deployed_index_id=deployed_index_id,
        queries=[query],
        num_neighbors=limit,
    )[0]
    found: list[SearchResult] = []
    for match in matches:
        # Document bodies live next to the index in file storage.
        doc_path = f"{self.index_name}/contents/{match.id}.md"
        raw = self.storage.get_file_stream(doc_path).read()
        found.append(
            SearchResult(
                id=match.id,
                distance=float(match.distance or 0),
                content=raw.decode("utf-8"),
            ),
        )
    return found
async def async_run_query(
    self,
    deployed_index_id: str,
    query: Sequence[float],
    limit: int,
) -> list[SearchResult]:
    """Run an async similarity search via the REST API.

    Args:
        deployed_index_id: The ID of the deployed index.
        query: The embedding vector for the search query.
        limit: Maximum number of nearest neighbors to return.

    Returns:
        A list of matched items with id, distance, and content.
    """
    domain = self.index_endpoint.public_endpoint_domain_name
    endpoint_id = self.index_endpoint.name.split("/")[-1]
    url = (
        f"https://{domain}/v1/projects/{self.project_id}"
        f"/locations/{self.location}"
        f"/indexEndpoints/{endpoint_id}:findNeighbors"
    )
    payload = {
        "deployed_index_id": deployed_index_id,
        "queries": [
            {
                "datapoint": {"feature_vector": list(query)},
                "neighbor_count": limit,
            },
        ],
    }
    headers = await self._async_get_auth_headers()
    session = self._get_aio_session()
    async with session.post(
        url, json=payload, headers=headers,
    ) as response:
        response.raise_for_status()
        data = await response.json()
    # BUGFIX: the previous `data.get("nearestNeighbors", [{}])[0]` raised
    # IndexError when the key was present but the list empty (no matches);
    # `or [{}]` covers missing, None, and empty-list cases alike.
    nearest = data.get("nearestNeighbors") or [{}]
    neighbors = nearest[0].get("neighbors", [])
    # Fetch all matched documents concurrently from file storage.
    content_tasks = [
        self.storage.async_get_file_stream(
            f"{self.index_name}/contents/"
            f"{neighbor['datapoint']['datapointId']}.md",
        )
        for neighbor in neighbors
    ]
    file_streams = await asyncio.gather(*content_tasks)
    return [
        SearchResult(
            id=neighbor["datapoint"]["datapointId"],
            # BUGFIX: avoid KeyError if "distance" is absent in the REST
            # response, and coerce None/missing to 0.0 to mirror the sync
            # path's `float(neighbor.distance or 0)`.
            distance=float(neighbor.get("distance") or 0),
            content=stream.read().decode("utf-8"),
        )
        for neighbor, stream in zip(neighbors, file_streams, strict=True)
    ]
def delete_index(self, index_name: str) -> None:
    """Delete a Vertex AI Vector Search index.

    Args:
        index_name: The resource name of the index.
    """
    aiplatform.MatchingEngineIndex(index_name).delete()
def delete_index_endpoint(
    self, index_endpoint_name: str,
) -> None:
    """Delete a Vertex AI Vector Search index endpoint.

    Args:
        index_endpoint_name: The resource name of the endpoint.
    """
    endpoint = aiplatform.MatchingEngineIndexEndpoint(index_endpoint_name)
    # Indexes must be undeployed before the endpoint can be removed.
    endpoint.undeploy_all()
    endpoint.delete(force=True)