"""Google Cloud Vector Search client."""
|
|
|
|
import asyncio
|
|
from collections.abc import Sequence
|
|
|
|
from gcloud.aio.auth import Token
|
|
|
|
from knowledge_search_mcp.logging import log_structured_entry
|
|
from knowledge_search_mcp.models import SearchResult, SourceNamespace
|
|
|
|
from .base import BaseGoogleCloudClient
|
|
from .storage import GoogleCloudFileStorage
|
|
|
|
|
|
class GoogleCloudVectorSearch(BaseGoogleCloudClient):
    """Minimal async client for the Vertex AI Matching Engine REST API."""

    def __init__(
        self,
        project_id: str,
        location: str,
        bucket: str,
        index_name: str | None = None,
    ) -> None:
        """Store configuration used to issue Matching Engine queries.

        Args:
            project_id: Google Cloud project that owns the index.
            location: Region of the index endpoint (e.g. ``us-central1``).
            bucket: GCS bucket that holds the per-datapoint content files.
            index_name: Prefix under which content files live; used to build
                ``{index_name}/contents/{datapoint_id}.md`` object paths.

        """
        super().__init__()
        self.project_id = project_id
        self.location = location
        self.storage = GoogleCloudFileStorage(bucket=bucket)
        self.index_name = index_name
        # OAuth token is created lazily on first request and reused after.
        self._async_token: Token | None = None
        # Both are populated by `configure_index_endpoint`; queries fail
        # with RuntimeError until that has been called.
        self._endpoint_domain: str | None = None
        self._endpoint_name: str | None = None

    async def _async_get_auth_headers(self) -> dict[str, str]:
        """Return auth headers for the REST call, creating the token lazily.

        Returns:
            Headers containing a Bearer token and a JSON content type.

        """
        if self._async_token is None:
            self._async_token = Token(
                session=self._get_aio_session(),
                scopes=[
                    "https://www.googleapis.com/auth/cloud-platform",
                ],
            )
        access_token = await self._async_token.get()
        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    async def close(self) -> None:
        """Close aiohttp sessions for both vector search and storage."""
        await super().close()
        await self.storage.close()

    def configure_index_endpoint(
        self,
        *,
        name: str,
        public_domain: str,
    ) -> None:
        """Persist the metadata needed to access a deployed endpoint.

        Args:
            name: Full resource name of the index endpoint; only the final
                path segment is used in request URLs.
            public_domain: Public domain of the deployed index endpoint.

        Raises:
            ValueError: If either argument is empty.

        """
        if not name:
            msg = "Index endpoint name must be a non-empty string."
            raise ValueError(msg)
        if not public_domain:
            msg = "Index endpoint domain must be a non-empty public domain."
            raise ValueError(msg)
        self._endpoint_name = name
        self._endpoint_domain = public_domain

    async def async_run_query(
        self,
        deployed_index_id: str,
        query: Sequence[float],
        limit: int,
        source: SourceNamespace | None = None,
    ) -> list[SearchResult]:
        """Run an async similarity search via the REST API.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.
            source: Optional namespace filter to restrict results by source.

        Returns:
            A list of matched items with id, distance, and content.

        Raises:
            RuntimeError: If the endpoint is not configured or the
                findNeighbors call returns a non-2xx status.

        """
        if self._endpoint_domain is None or self._endpoint_name is None:
            msg = (
                "Missing endpoint metadata. Call "
                "`configure_index_endpoint` before querying."
            )
            log_structured_entry(
                "Vector search query failed - endpoint not configured",
                "ERROR",
                {"error": msg},
            )
            raise RuntimeError(msg)

        domain = self._endpoint_domain
        # The endpoint resource name's final path segment is the numeric ID.
        endpoint_id = self._endpoint_name.split("/")[-1]
        url = (
            f"https://{domain}/v1/projects/{self.project_id}"
            f"/locations/{self.location}"
            f"/indexEndpoints/{endpoint_id}:findNeighbors"
        )

        log_structured_entry(
            "Starting vector search query",
            "INFO",
            {
                "deployed_index_id": deployed_index_id,
                "neighbor_count": limit,
                "endpoint_id": endpoint_id,
                "embedding_dimension": len(query),
            },
        )

        datapoint: dict = {"feature_vector": list(query)}
        if source is not None:
            # Restrict matches to datapoints tagged with this source namespace.
            datapoint["restricts"] = [
                {"namespace": "source", "allow_list": [source.value]},
            ]
        payload = {
            "deployed_index_id": deployed_index_id,
            "queries": [
                {
                    "datapoint": datapoint,
                    "neighbor_count": limit,
                },
            ],
        }

        try:
            headers = await self._async_get_auth_headers()
            session = self._get_aio_session()
            async with session.post(
                url,
                json=payload,
                headers=headers,
            ) as response:
                if not response.ok:
                    body = await response.text()
                    msg = f"findNeighbors returned {response.status}: {body}"
                    log_structured_entry(
                        "Vector search API request failed",
                        "ERROR",
                        {
                            "status": response.status,
                            "response_body": body,
                            "deployed_index_id": deployed_index_id,
                        },
                    )
                    raise RuntimeError(msg)  # noqa: TRY301
                data = await response.json()

            # BUGFIX: `data.get("nearestNeighbors", [{}])[0]` raised
            # IndexError when the key was present but an empty list (the
            # `.get` default only applies when the key is missing). Guard
            # the empty case explicitly.
            nearest = data.get("nearestNeighbors") or []
            neighbors = nearest[0].get("neighbors", []) if nearest else []
            log_structured_entry(
                "Vector search API request successful",
                "INFO",
                {
                    "neighbors_found": len(neighbors),
                    "deployed_index_id": deployed_index_id,
                },
            )

            if not neighbors:
                log_structured_entry(
                    "No neighbors found in vector search",
                    "WARNING",
                    {"deployed_index_id": deployed_index_id},
                )
                return []

            # Fetch content for all neighbors concurrently.
            content_tasks = []
            for neighbor in neighbors:
                datapoint_id = neighbor["datapoint"]["datapointId"]
                file_path = f"{self.index_name}/contents/{datapoint_id}.md"
                content_tasks.append(
                    self.storage.async_get_file_stream(file_path),
                )

            log_structured_entry(
                "Fetching content for search results",
                "INFO",
                {"file_count": len(content_tasks)},
            )

            file_streams = await asyncio.gather(*content_tasks)
            results: list[SearchResult] = []
            # strict=True: every neighbor must have a matching content stream.
            for neighbor, stream in zip(
                neighbors,
                file_streams,
                strict=True,
            ):
                results.append(
                    SearchResult(
                        id=neighbor["datapoint"]["datapointId"],
                        distance=neighbor["distance"],
                        content=stream.read().decode("utf-8"),
                    ),
                )

            log_structured_entry(
                "Vector search completed successfully",
                "INFO",
                {"results_count": len(results), "deployed_index_id": deployed_index_id},
            )
            return results  # noqa: TRY300

        except Exception as e:
            # Log every failure with context, then re-raise unchanged so
            # callers see the original exception type.
            log_structured_entry(
                "Vector search query failed with exception",
                "ERROR",
                {
                    "error": str(e),
                    "error_type": type(e).__name__,
                    "deployed_index_id": deployed_index_id,
                },
            )
            raise
|