Files
knowledge-search-mcp/src/knowledge_search_mcp/clients/vector_search.py
Anibal Angulo 8675a89b80
Some checks failed
CI / lint (pull_request) Successful in 11s
CI / typecheck (pull_request) Successful in 11s
CI / test (pull_request) Successful in 25s
CI / build (pull_request) Failing after 3s
Add CI
2026-03-05 21:55:44 +00:00

224 lines
7.5 KiB
Python

"""Google Cloud Vector Search client."""
import asyncio
from collections.abc import Sequence
from gcloud.aio.auth import Token
from knowledge_search_mcp.logging import log_structured_entry
from knowledge_search_mcp.models import SearchResult, SourceNamespace
from .base import BaseGoogleCloudClient
from .storage import GoogleCloudFileStorage
class GoogleCloudVectorSearch(BaseGoogleCloudClient):
    """Minimal async client for the Vertex AI Matching Engine REST API.

    A query is a two-step operation: the ``findNeighbors`` REST method is
    called on a configured index endpoint, then the content of every returned
    datapoint is read from ``<index_name>/contents/<datapoint_id>.md`` through
    the attached ``GoogleCloudFileStorage`` client.
    """

    def __init__(
        self,
        project_id: str,
        location: str,
        bucket: str,
        index_name: str | None = None,
    ) -> None:
        """Store configuration used to issue Matching Engine queries.

        Args:
            project_id: Project segment of the ``findNeighbors`` request URL.
            location: Location segment of the ``findNeighbors`` request URL.
            bucket: Bucket passed to the ``GoogleCloudFileStorage`` client
                that serves the content of matched datapoints.
            index_name: Path prefix under which ``contents/<id>.md`` files
                are looked up. Required as soon as a query returns neighbors.
        """
        super().__init__()
        self.project_id = project_id
        self.location = location
        self.storage = GoogleCloudFileStorage(bucket=bucket)
        self.index_name = index_name
        # Created lazily on the first authenticated request.
        self._async_token: Token | None = None
        # Populated by `configure_index_endpoint`; both are required to query.
        self._endpoint_domain: str | None = None
        self._endpoint_name: str | None = None

    async def _async_get_auth_headers(self) -> dict[str, str]:
        """Return the HTTP headers for an authenticated JSON request.

        The ``Token`` helper is created on first use and then reused for
        every subsequent call.
        """
        if self._async_token is None:
            self._async_token = Token(
                session=self._get_aio_session(),
                scopes=[
                    "https://www.googleapis.com/auth/cloud-platform",
                ],
            )
        access_token = await self._async_token.get()
        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    async def close(self) -> None:
        """Close aiohttp sessions for both vector search and storage."""
        try:
            await super().close()
        finally:
            # Release the storage client even when closing our own session
            # fails; otherwise its session would be leaked.
            await self.storage.close()

    def configure_index_endpoint(
        self,
        *,
        name: str,
        public_domain: str,
    ) -> None:
        """Persist the metadata needed to access a deployed endpoint.

        Args:
            name: Index endpoint name. Only the segment after the last ``/``
                is used when building the request URL.
            public_domain: Host name the ``findNeighbors`` request is sent to.

        Raises:
            ValueError: If ``name`` or ``public_domain`` is empty.
        """
        if not name:
            msg = "Index endpoint name must be a non-empty string."
            raise ValueError(msg)
        if not public_domain:
            msg = "Index endpoint domain must be a non-empty public domain."
            raise ValueError(msg)
        self._endpoint_name = name
        self._endpoint_domain = public_domain

    async def async_run_query(
        self,
        deployed_index_id: str,
        query: Sequence[float],
        limit: int,
        source: SourceNamespace | None = None,
    ) -> list[SearchResult]:
        """Run an async similarity search via the REST API.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.
            source: Optional namespace filter to restrict results by source.

        Returns:
            A list of matched items with id, distance, and content. The list
            is empty when the API reports no neighbors.

        Raises:
            RuntimeError: If `configure_index_endpoint` has not been called,
                if the API answers with a non-OK status, or if neighbors were
                found but no ``index_name`` is set to locate their content.
        """
        if self._endpoint_domain is None or self._endpoint_name is None:
            msg = (
                "Missing endpoint metadata. Call "
                "`configure_index_endpoint` before querying."
            )
            log_structured_entry(
                "Vector search query failed - endpoint not configured",
                "ERROR",
                {"error": msg},
            )
            raise RuntimeError(msg)

        domain = self._endpoint_domain
        # The endpoint may be stored as a full resource name; the URL only
        # needs its trailing identifier.
        endpoint_id = self._endpoint_name.split("/")[-1]
        url = (
            f"https://{domain}/v1/projects/{self.project_id}"
            f"/locations/{self.location}"
            f"/indexEndpoints/{endpoint_id}:findNeighbors"
        )

        log_structured_entry(
            "Starting vector search query",
            "INFO",
            {
                "deployed_index_id": deployed_index_id,
                "neighbor_count": limit,
                "endpoint_id": endpoint_id,
                "embedding_dimension": len(query),
            },
        )

        datapoint: dict[str, object] = {"feature_vector": list(query)}
        if source is not None:
            datapoint["restricts"] = [
                {"namespace": "source", "allow_list": [source.value]},
            ]
        payload = {
            "deployed_index_id": deployed_index_id,
            "queries": [
                {
                    "datapoint": datapoint,
                    "neighbor_count": limit,
                },
            ],
        }

        try:
            headers = await self._async_get_auth_headers()
            session = self._get_aio_session()
            async with session.post(
                url,
                json=payload,
                headers=headers,
            ) as response:
                if not response.ok:
                    body = await response.text()
                    msg = f"findNeighbors returned {response.status}: {body}"
                    log_structured_entry(
                        "Vector search API request failed",
                        "ERROR",
                        {
                            "status": response.status,
                            "response_body": body,
                            "deployed_index_id": deployed_index_id,
                        },
                    )
                    raise RuntimeError(msg)  # noqa: TRY301
                data = await response.json()

            # Exactly one query is sent, so only the first entry matters.
            # `or` also covers an empty/null "nearestNeighbors" list, which
            # a `.get` default alone would not (it would raise IndexError).
            per_query = data.get("nearestNeighbors") or [{}]
            neighbors = per_query[0].get("neighbors") or []

            log_structured_entry(
                "Vector search API request successful",
                "INFO",
                {
                    "neighbors_found": len(neighbors),
                    "deployed_index_id": deployed_index_id,
                },
            )

            if not neighbors:
                log_structured_entry(
                    "No neighbors found in vector search",
                    "WARNING",
                    {"deployed_index_id": deployed_index_id},
                )
                return []

            if self.index_name is None:
                # Without this guard the path below would silently become
                # "None/contents/<id>.md".
                msg = (
                    "Missing index name. Provide `index_name` to load "
                    "the content of search results."
                )
                raise RuntimeError(msg)  # noqa: TRY301

            # Fetch content for all neighbors
            datapoint_ids = [
                neighbor["datapoint"]["datapointId"] for neighbor in neighbors
            ]
            content_tasks = [
                self.storage.async_get_file_stream(
                    f"{self.index_name}/contents/{datapoint_id}.md",
                )
                for datapoint_id in datapoint_ids
            ]

            log_structured_entry(
                "Fetching content for search results",
                "INFO",
                {"file_count": len(content_tasks)},
            )

            file_streams = await asyncio.gather(*content_tasks)

            results: list[SearchResult] = [
                SearchResult(
                    id=datapoint_id,
                    distance=neighbor["distance"],
                    content=stream.read().decode("utf-8"),
                )
                for datapoint_id, neighbor, stream in zip(
                    datapoint_ids,
                    neighbors,
                    file_streams,
                    strict=True,
                )
            ]

            log_structured_entry(
                "Vector search completed successfully",
                "INFO",
                {
                    "results_count": len(results),
                    "deployed_index_id": deployed_index_id,
                },
            )
            return results  # noqa: TRY300
        except Exception as e:
            # Boundary logging only: the exception is always re-raised.
            log_structured_entry(
                "Vector search query failed with exception",
                "ERROR",
                {
                    "error": str(e),
                    "error_type": type(e).__name__,
                    "deployed_index_id": deployed_index_id,
                },
            )
            raise