Split out main module
This commit is contained in:
226
src/knowledge_search_mcp/clients/vector_search.py
Normal file
226
src/knowledge_search_mcp/clients/vector_search.py
Normal file
@@ -0,0 +1,226 @@
|
||||
# ruff: noqa: INP001
|
||||
"""Google Cloud Vector Search client."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
|
||||
from gcloud.aio.auth import Token
|
||||
|
||||
from ..logging import log_structured_entry
|
||||
from ..models import SearchResult, SourceNamespace
|
||||
from .base import BaseGoogleCloudClient
|
||||
from .storage import GoogleCloudFileStorage
|
||||
|
||||
|
||||
class GoogleCloudVectorSearch(BaseGoogleCloudClient):
|
||||
"""Minimal async client for the Vertex AI Matching Engine REST API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
project_id: str,
|
||||
location: str,
|
||||
bucket: str,
|
||||
index_name: str | None = None,
|
||||
) -> None:
|
||||
"""Store configuration used to issue Matching Engine queries."""
|
||||
super().__init__()
|
||||
self.project_id = project_id
|
||||
self.location = location
|
||||
self.storage = GoogleCloudFileStorage(bucket=bucket)
|
||||
self.index_name = index_name
|
||||
self._async_token: Token | None = None
|
||||
self._endpoint_domain: str | None = None
|
||||
self._endpoint_name: str | None = None
|
||||
|
||||
async def _async_get_auth_headers(self) -> dict[str, str]:
|
||||
if self._async_token is None:
|
||||
self._async_token = Token(
|
||||
session=self._get_aio_session(),
|
||||
scopes=[
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
],
|
||||
)
|
||||
access_token = await self._async_token.get()
|
||||
return {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close aiohttp sessions for both vector search and storage."""
|
||||
await super().close()
|
||||
await self.storage.close()
|
||||
|
||||
def configure_index_endpoint(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
public_domain: str,
|
||||
) -> None:
|
||||
"""Persist the metadata needed to access a deployed endpoint."""
|
||||
if not name:
|
||||
msg = "Index endpoint name must be a non-empty string."
|
||||
raise ValueError(msg)
|
||||
if not public_domain:
|
||||
msg = "Index endpoint domain must be a non-empty public domain."
|
||||
raise ValueError(msg)
|
||||
self._endpoint_name = name
|
||||
self._endpoint_domain = public_domain
|
||||
|
||||
async def async_run_query(
|
||||
self,
|
||||
deployed_index_id: str,
|
||||
query: Sequence[float],
|
||||
limit: int,
|
||||
source: SourceNamespace | None = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Run an async similarity search via the REST API.
|
||||
|
||||
Args:
|
||||
deployed_index_id: The ID of the deployed index.
|
||||
query: The embedding vector for the search query.
|
||||
limit: Maximum number of nearest neighbors to return.
|
||||
source: Optional namespace filter to restrict results by source.
|
||||
|
||||
Returns:
|
||||
A list of matched items with id, distance, and content.
|
||||
|
||||
"""
|
||||
if self._endpoint_domain is None or self._endpoint_name is None:
|
||||
msg = (
|
||||
"Missing endpoint metadata. Call "
|
||||
"`configure_index_endpoint` before querying."
|
||||
)
|
||||
log_structured_entry(
|
||||
"Vector search query failed - endpoint not configured",
|
||||
"ERROR",
|
||||
{"error": msg}
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
domain = self._endpoint_domain
|
||||
endpoint_id = self._endpoint_name.split("/")[-1]
|
||||
url = (
|
||||
f"https://{domain}/v1/projects/{self.project_id}"
|
||||
f"/locations/{self.location}"
|
||||
f"/indexEndpoints/{endpoint_id}:findNeighbors"
|
||||
)
|
||||
|
||||
log_structured_entry(
|
||||
"Starting vector search query",
|
||||
"INFO",
|
||||
{
|
||||
"deployed_index_id": deployed_index_id,
|
||||
"neighbor_count": limit,
|
||||
"endpoint_id": endpoint_id,
|
||||
"embedding_dimension": len(query)
|
||||
}
|
||||
)
|
||||
|
||||
datapoint: dict = {"feature_vector": list(query)}
|
||||
if source is not None:
|
||||
datapoint["restricts"] = [
|
||||
{"namespace": "source", "allow_list": [source.value]},
|
||||
]
|
||||
payload = {
|
||||
"deployed_index_id": deployed_index_id,
|
||||
"queries": [
|
||||
{
|
||||
"datapoint": datapoint,
|
||||
"neighbor_count": limit,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
try:
|
||||
headers = await self._async_get_auth_headers()
|
||||
session = self._get_aio_session()
|
||||
async with session.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
) as response:
|
||||
if not response.ok:
|
||||
body = await response.text()
|
||||
msg = f"findNeighbors returned {response.status}: {body}"
|
||||
log_structured_entry(
|
||||
"Vector search API request failed",
|
||||
"ERROR",
|
||||
{
|
||||
"status": response.status,
|
||||
"response_body": body,
|
||||
"deployed_index_id": deployed_index_id
|
||||
}
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
data = await response.json()
|
||||
|
||||
neighbors = data.get("nearestNeighbors", [{}])[0].get("neighbors", [])
|
||||
log_structured_entry(
|
||||
"Vector search API request successful",
|
||||
"INFO",
|
||||
{
|
||||
"neighbors_found": len(neighbors),
|
||||
"deployed_index_id": deployed_index_id
|
||||
}
|
||||
)
|
||||
|
||||
if not neighbors:
|
||||
log_structured_entry(
|
||||
"No neighbors found in vector search",
|
||||
"WARNING",
|
||||
{"deployed_index_id": deployed_index_id}
|
||||
)
|
||||
return []
|
||||
|
||||
# Fetch content for all neighbors
|
||||
content_tasks = []
|
||||
for neighbor in neighbors:
|
||||
datapoint_id = neighbor["datapoint"]["datapointId"]
|
||||
file_path = f"{self.index_name}/contents/{datapoint_id}.md"
|
||||
content_tasks.append(
|
||||
self.storage.async_get_file_stream(file_path),
|
||||
)
|
||||
|
||||
log_structured_entry(
|
||||
"Fetching content for search results",
|
||||
"INFO",
|
||||
{"file_count": len(content_tasks)}
|
||||
)
|
||||
|
||||
file_streams = await asyncio.gather(*content_tasks)
|
||||
results: list[SearchResult] = []
|
||||
for neighbor, stream in zip(
|
||||
neighbors,
|
||||
file_streams,
|
||||
strict=True,
|
||||
):
|
||||
results.append(
|
||||
SearchResult(
|
||||
id=neighbor["datapoint"]["datapointId"],
|
||||
distance=neighbor["distance"],
|
||||
content=stream.read().decode("utf-8"),
|
||||
),
|
||||
)
|
||||
|
||||
log_structured_entry(
|
||||
"Vector search completed successfully",
|
||||
"INFO",
|
||||
{
|
||||
"results_count": len(results),
|
||||
"deployed_index_id": deployed_index_id
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
log_structured_entry(
|
||||
"Vector search query failed with exception",
|
||||
"ERROR",
|
||||
{
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
"deployed_index_id": deployed_index_id
|
||||
}
|
||||
)
|
||||
raise
|
||||
Reference in New Issue
Block a user