311 lines
10 KiB
Python
311 lines
10 KiB
Python
"""Google Cloud Vertex AI Vector Search implementation."""
|
|
|
|
import asyncio
|
|
from collections.abc import Sequence
|
|
from typing import Any
|
|
from uuid import uuid4
|
|
|
|
import aiohttp
|
|
import google.auth
|
|
import google.auth.credentials
|
|
import google.auth.transport.requests
|
|
from gcloud.aio.auth import Token
|
|
from google.cloud import aiplatform
|
|
|
|
from .file_storage.google_cloud import GoogleCloudFileStorage
|
|
from .base import BaseVectorSearch, SearchResult
|
|
|
|
|
|
class GoogleCloudVectorSearch(BaseVectorSearch):
    """A vector search provider using Vertex AI Vector Search."""

    def __init__(
        self,
        project_id: str,
        location: str,
        bucket: str,
        index_name: str | None = None,
    ) -> None:
        """Initialize the GoogleCloudVectorSearch client.

        Args:
            project_id: The Google Cloud project ID.
            location: The Google Cloud location (e.g., 'us-central1').
            bucket: The GCS bucket to use for file storage.
            index_name: The name of the index.

        """
        aiplatform.init(project=project_id, location=location)
        self.project_id = project_id
        self.location = location
        self.storage = GoogleCloudFileStorage(bucket=bucket)
        self.index_name = index_name
        # Populated by create_index/update_index and by
        # deploy_index/load_index_endpoint respectively. Initialized to None
        # so misuse raises a clear RuntimeError instead of an AttributeError.
        self.index: aiplatform.MatchingEngineIndex | None = None
        self.index_endpoint: aiplatform.MatchingEngineIndexEndpoint | None = (
            None
        )
        # Lazily created auth and HTTP session state (see helpers below).
        self._credentials: google.auth.credentials.Credentials | None = None
        self._aio_session: aiohttp.ClientSession | None = None
        self._async_token: Token | None = None

    def _require_endpoint(self) -> aiplatform.MatchingEngineIndexEndpoint:
        """Return the active index endpoint or raise a descriptive error.

        Raises:
            RuntimeError: If no endpoint has been deployed or loaded yet.

        """
        if self.index_endpoint is None:
            msg = (
                "No index endpoint available. Call deploy_index() or "
                "load_index_endpoint() first."
            )
            raise RuntimeError(msg)
        return self.index_endpoint

    def _get_auth_headers(self) -> dict[str, str]:
        """Build Bearer-token headers from application default credentials."""
        if self._credentials is None:
            self._credentials, _ = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
        # Refresh when the access token is missing or has expired.
        if not self._credentials.token or self._credentials.expired:
            self._credentials.refresh(
                google.auth.transport.requests.Request(),
            )
        return {
            "Authorization": f"Bearer {self._credentials.token}",
            "Content-Type": "application/json",
        }

    async def _async_get_auth_headers(self) -> dict[str, str]:
        """Build Bearer-token headers without blocking the event loop."""
        if self._async_token is None:
            self._async_token = Token(
                session=self._get_aio_session(),
                scopes=[
                    "https://www.googleapis.com/auth/cloud-platform",
                ],
            )
        # Token handles caching and refreshing of the access token.
        access_token = await self._async_token.get()
        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def _get_aio_session(self) -> aiohttp.ClientSession:
        """Return the shared aiohttp session, recreating it if closed."""
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(
                limit=300, limit_per_host=50,
            )
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout, connector=connector,
            )
        return self._aio_session

    async def aclose(self) -> None:
        """Release the shared aiohttp session.

        Safe to call multiple times; a no-op when no session was ever
        created. The cached async token is dropped as well because it is
        bound to the session being closed.
        """
        self._async_token = None
        if self._aio_session is not None and not self._aio_session.closed:
            await self._aio_session.close()
        self._aio_session = None

    def create_index(
        self,
        name: str,
        content_path: str,
        *,
        dimensions: int = 3072,
        approximate_neighbors_count: int = 150,
        distance_measure_type: str = "DOT_PRODUCT_DISTANCE",
        **kwargs: Any,  # noqa: ANN401, ARG002
    ) -> None:
        """Create a new Vertex AI Vector Search index.

        Args:
            name: The display name for the new index.
            content_path: GCS URI to the embeddings JSON file.
            dimensions: Number of dimensions in embedding vectors.
            approximate_neighbors_count: Neighbors to find per vector.
            distance_measure_type: The distance measure to use.
            **kwargs: Additional arguments.

        """
        index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
            display_name=name,
            contents_delta_uri=content_path,
            dimensions=dimensions,
            approximate_neighbors_count=approximate_neighbors_count,
            distance_measure_type=distance_measure_type,  # type: ignore[arg-type]
            leaf_node_embedding_count=1000,
            leaf_nodes_to_search_percent=10,
        )
        self.index = index

    def update_index(
        self, index_name: str, content_path: str, **kwargs: Any,  # noqa: ANN401, ARG002
    ) -> None:
        """Update an existing Vertex AI Vector Search index.

        Args:
            index_name: The resource name of the index to update.
            content_path: GCS URI to the new embeddings JSON file.
            **kwargs: Additional arguments.

        """
        index = aiplatform.MatchingEngineIndex(index_name=index_name)
        index.update_embeddings(
            contents_delta_uri=content_path,
        )
        self.index = index

    def deploy_index(
        self,
        index_name: str,
        machine_type: str = "e2-standard-2",
    ) -> None:
        """Deploy a Vertex AI Vector Search index to an endpoint.

        Args:
            index_name: The name of the index to deploy.
            machine_type: The machine type for the endpoint.

        Raises:
            RuntimeError: If no index has been created or updated yet.

        """
        if self.index is None:
            msg = (
                "No index available. Call create_index() or update_index() "
                "first."
            )
            raise RuntimeError(msg)
        index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
            display_name=f"{index_name}-endpoint",
            public_endpoint_enabled=True,
        )
        index_endpoint.deploy_index(
            index=self.index,
            # Deployed-index IDs must be unique and may not contain dashes.
            deployed_index_id=(
                f"{index_name.replace('-', '_')}_deployed_{uuid4().hex}"
            ),
            machine_type=machine_type,
        )
        self.index_endpoint = index_endpoint

    def load_index_endpoint(self, endpoint_name: str) -> None:
        """Load an existing Vertex AI Vector Search index endpoint.

        Args:
            endpoint_name: The resource name of the index endpoint.

        Raises:
            ValueError: If the endpoint has no public domain, which the
                async REST query path requires.

        """
        self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
            endpoint_name,
        )
        if not self.index_endpoint.public_endpoint_domain_name:
            msg = (
                "The index endpoint does not have a public endpoint. "
                "Ensure the endpoint is configured for public access."
            )
            raise ValueError(msg)

    def run_query(
        self,
        deployed_index_id: str,
        query: list[float],
        limit: int,
    ) -> list[SearchResult]:
        """Run a similarity search query against the deployed index.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.

        Returns:
            A list of matched items with id, distance, and content.

        Raises:
            RuntimeError: If no endpoint has been deployed or loaded yet.

        """
        endpoint = self._require_endpoint()
        response = endpoint.find_neighbors(
            deployed_index_id=deployed_index_id,
            queries=[query],
            num_neighbors=limit,
        )
        results = []
        # Neighbor IDs double as content file names in the storage bucket.
        for neighbor in response[0]:
            file_path = (
                f"{self.index_name}/contents/{neighbor.id}.md"
            )
            content = (
                self.storage.get_file_stream(file_path)
                .read()
                .decode("utf-8")
            )
            results.append(
                SearchResult(
                    id=neighbor.id,
                    distance=float(neighbor.distance or 0),
                    content=content,
                ),
            )
        return results

    async def async_run_query(
        self,
        deployed_index_id: str,
        query: Sequence[float],
        limit: int,
    ) -> list[SearchResult]:
        """Run an async similarity search via the REST API.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.

        Returns:
            A list of matched items with id, distance, and content.

        Raises:
            RuntimeError: If no endpoint has been deployed or loaded yet.

        """
        endpoint = self._require_endpoint()
        domain = endpoint.public_endpoint_domain_name
        endpoint_id = endpoint.name.split("/")[-1]
        url = (
            f"https://{domain}/v1/projects/{self.project_id}"
            f"/locations/{self.location}"
            f"/indexEndpoints/{endpoint_id}:findNeighbors"
        )
        payload = {
            "deployed_index_id": deployed_index_id,
            "queries": [
                {
                    "datapoint": {"feature_vector": list(query)},
                    "neighbor_count": limit,
                },
            ],
        }

        headers = await self._async_get_auth_headers()
        session = self._get_aio_session()
        async with session.post(
            url, json=payload, headers=headers,
        ) as response:
            response.raise_for_status()
            data = await response.json()

        neighbors = (
            data.get("nearestNeighbors", [{}])[0].get("neighbors", [])
        )
        # Fetch all content files concurrently; order matches `neighbors`.
        content_tasks = []
        for neighbor in neighbors:
            datapoint_id = neighbor["datapoint"]["datapointId"]
            file_path = (
                f"{self.index_name}/contents/{datapoint_id}.md"
            )
            content_tasks.append(
                self.storage.async_get_file_stream(file_path),
            )

        file_streams = await asyncio.gather(*content_tasks)
        results: list[SearchResult] = []
        for neighbor, stream in zip(
            neighbors, file_streams, strict=True,
        ):
            results.append(
                SearchResult(
                    id=neighbor["datapoint"]["datapointId"],
                    # Mirror run_query's defensive handling: the REST
                    # response may omit "distance".
                    distance=float(neighbor.get("distance", 0)),
                    content=stream.read().decode("utf-8"),
                ),
            )
        return results

    def delete_index(self, index_name: str) -> None:
        """Delete a Vertex AI Vector Search index.

        Args:
            index_name: The resource name of the index.

        """
        index = aiplatform.MatchingEngineIndex(index_name)
        index.delete()

    def delete_index_endpoint(
        self, index_endpoint_name: str,
    ) -> None:
        """Delete a Vertex AI Vector Search index endpoint.

        Args:
            index_endpoint_name: The resource name of the endpoint.

        """
        index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
            index_endpoint_name,
        )
        # An endpoint with deployed indexes cannot be deleted directly.
        index_endpoint.undeploy_all()
        index_endpoint.delete(force=True)