"""Google Cloud Vector Search client."""

import asyncio
from collections.abc import Sequence

from gcloud.aio.auth import Token

from knowledge_search_mcp.logging import log_structured_entry
from knowledge_search_mcp.models import SearchResult, SourceNamespace

from .base import BaseGoogleCloudClient
from .storage import GoogleCloudFileStorage


class GoogleCloudVectorSearch(BaseGoogleCloudClient):
    """Minimal async client for the Vertex AI Matching Engine REST API."""

    def __init__(
        self,
        project_id: str,
        location: str,
        bucket: str,
        index_name: str | None = None,
    ) -> None:
        """Store configuration used to issue Matching Engine queries.

        Args:
            project_id: Project ID interpolated into the findNeighbors URL.
            location: Location interpolated into the findNeighbors URL.
            bucket: Bucket handed to ``GoogleCloudFileStorage``; result
                content is read from it.
            index_name: Object-path prefix for result content; each hit is
                read from ``{index_name}/contents/{datapoint_id}.md``.
        """
        super().__init__()
        self.project_id = project_id
        self.location = location
        self.storage = GoogleCloudFileStorage(bucket=bucket)
        self.index_name = index_name
        # Created lazily on the first authenticated request.
        self._async_token: Token | None = None
        # Both are set by configure_index_endpoint() and required by queries.
        self._endpoint_domain: str | None = None
        self._endpoint_name: str | None = None

    async def _async_get_auth_headers(self) -> dict[str, str]:
        """Return JSON request headers carrying a bearer access token.

        The ``Token`` is created once (cloud-platform scope) and reused for
        every subsequent call.
        """
        if self._async_token is None:
            # NOTE(review): the Token is bound to the session returned here on
            # first use; confirm it stays valid if the base class ever
            # recreates its session (e.g. after close()).
            self._async_token = Token(
                session=self._get_aio_session(),
                scopes=[
                    "https://www.googleapis.com/auth/cloud-platform",
                ],
            )
        access_token = await self._async_token.get()
        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    async def close(self) -> None:
        """Close aiohttp sessions for both vector search and storage."""
        await super().close()
        await self.storage.close()

    def configure_index_endpoint(
        self,
        *,
        name: str,
        public_domain: str,
    ) -> None:
        """Persist the metadata needed to access a deployed endpoint.

        Args:
            name: Index endpoint name; only its last ``/``-separated segment
                is used as the endpoint ID when querying.
            public_domain: Host used for findNeighbors requests.

        Raises:
            ValueError: If either argument is empty.
        """
        if not name:
            msg = "Index endpoint name must be a non-empty string."
            raise ValueError(msg)
        if not public_domain:
            msg = "Index endpoint domain must be a non-empty public domain."
            raise ValueError(msg)
        self._endpoint_name = name
        self._endpoint_domain = public_domain

    @staticmethod
    def _build_find_neighbors_payload(
        deployed_index_id: str,
        query: Sequence[float],
        limit: int,
        source: SourceNamespace | None,
    ) -> dict[str, object]:
        """Build the JSON body for a single-query findNeighbors request."""
        datapoint: dict[str, object] = {"feature_vector": list(query)}
        if source is not None:
            # Restrict matches to datapoints tagged with this source namespace.
            datapoint["restricts"] = [
                {"namespace": "source", "allow_list": [source.value]},
            ]
        return {
            "deployed_index_id": deployed_index_id,
            "queries": [
                {
                    "datapoint": datapoint,
                    "neighbor_count": limit,
                },
            ],
        }

    @staticmethod
    def _extract_neighbors(data: dict) -> list[dict]:
        """Return the neighbors of the first query in a findNeighbors response.

        A missing, ``null`` or empty ``nearestNeighbors`` array yields an
        empty list instead of an ``IndexError``/``TypeError``.
        """
        nearest = data.get("nearestNeighbors") or [{}]
        return nearest[0].get("neighbors", [])

    async def async_run_query(
        self,
        deployed_index_id: str,
        query: Sequence[float],
        limit: int,
        source: SourceNamespace | None = None,
    ) -> list[SearchResult]:
        """Run an async similarity search via the REST API.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.
            source: Optional namespace filter to restrict results by source.

        Returns:
            A list of matched items with id, distance, and content, in the
            order returned by the API. Empty when nothing matched.

        Raises:
            RuntimeError: If the endpoint is not configured or the API
                responds with a non-success status.
        """
        if self._endpoint_domain is None or self._endpoint_name is None:
            msg = (
                "Missing endpoint metadata. Call "
                "`configure_index_endpoint` before querying."
            )
            log_structured_entry(
                "Vector search query failed - endpoint not configured",
                "ERROR",
                {"error": msg},
            )
            raise RuntimeError(msg)

        domain = self._endpoint_domain
        endpoint_id = self._endpoint_name.split("/")[-1]
        url = (
            f"https://{domain}/v1/projects/{self.project_id}"
            f"/locations/{self.location}"
            f"/indexEndpoints/{endpoint_id}:findNeighbors"
        )

        log_structured_entry(
            "Starting vector search query",
            "INFO",
            {
                "deployed_index_id": deployed_index_id,
                "neighbor_count": limit,
                "endpoint_id": endpoint_id,
                "embedding_dimension": len(query),
            },
        )

        payload = self._build_find_neighbors_payload(
            deployed_index_id,
            query,
            limit,
            source,
        )

        try:
            headers = await self._async_get_auth_headers()
            session = self._get_aio_session()
            async with session.post(
                url,
                json=payload,
                headers=headers,
            ) as response:
                if not response.ok:
                    body = await response.text()
                    msg = f"findNeighbors returned {response.status}: {body}"
                    log_structured_entry(
                        "Vector search API request failed",
                        "ERROR",
                        {
                            "status": response.status,
                            "response_body": body,
                            "deployed_index_id": deployed_index_id,
                        },
                    )
                    raise RuntimeError(msg)  # noqa: TRY301
                data = await response.json()

            # The HTTP response is released above; everything below works on
            # the decoded JSON only.
            neighbors = self._extract_neighbors(data)
            log_structured_entry(
                "Vector search API request successful",
                "INFO",
                {
                    "neighbors_found": len(neighbors),
                    "deployed_index_id": deployed_index_id,
                },
            )

            if not neighbors:
                log_structured_entry(
                    "No neighbors found in vector search",
                    "WARNING",
                    {"deployed_index_id": deployed_index_id},
                )
                return []

            # Fetch content for all neighbors concurrently.
            datapoint_ids = [
                neighbor["datapoint"]["datapointId"] for neighbor in neighbors
            ]
            # NOTE(review): when index_name is None these paths start with
            # "None/contents/"; confirm callers always pass an index_name.
            content_tasks = [
                self.storage.async_get_file_stream(
                    f"{self.index_name}/contents/{datapoint_id}.md",
                )
                for datapoint_id in datapoint_ids
            ]
            log_structured_entry(
                "Fetching content for search results",
                "INFO",
                {"file_count": len(content_tasks)},
            )
            file_streams = await asyncio.gather(*content_tasks)

            results: list[SearchResult] = [
                SearchResult(
                    id=datapoint_id,
                    distance=neighbor["distance"],
                    content=stream.read().decode("utf-8"),
                )
                for datapoint_id, neighbor, stream in zip(
                    datapoint_ids,
                    neighbors,
                    file_streams,
                    strict=True,
                )
            ]
            log_structured_entry(
                "Vector search completed successfully",
                "INFO",
                {"results_count": len(results), "deployed_index_id": deployed_index_id},
            )
            return results  # noqa: TRY300

        except Exception as e:
            # Boundary log for any failure above, then propagate unchanged.
            log_structured_entry(
                "Vector search query failed with exception",
                "ERROR",
                {
                    "error": str(e),
                    "error_type": type(e).__name__,
                    "deployed_index_id": deployed_index_id,
                },
            )
            raise