# ruff: noqa: INP001
"""Async helpers for querying Vertex AI vector search via MCP."""

import argparse
import asyncio
import io
import logging
from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import BinaryIO, TypedDict

import aiohttp
from gcloud.aio.auth import Token
from gcloud.aio.storage import Storage
from google import genai
from google.genai import types as genai_types
from mcp.server.fastmcp import Context, FastMCP
from pydantic_settings import BaseSettings

logger = logging.getLogger(__name__)

# Status codes treated as transient/retryable by the download helper.
HTTP_TOO_MANY_REQUESTS = 429
HTTP_SERVER_ERROR = 500


class GoogleCloudFileStorage:
    """Cache-aware helper for downloading files from Google Cloud Storage."""

    def __init__(self, bucket: str) -> None:
        """Initialize the storage helper.

        Args:
            bucket: Name of the GCS bucket blobs are downloaded from.
        """
        self.bucket_name = bucket
        self._aio_session: aiohttp.ClientSession | None = None
        self._aio_storage: Storage | None = None
        # NOTE(review): unbounded in-memory cache keyed by blob name — fine
        # for a modest corpus, but consider an LRU bound if the index grows.
        self._cache: dict[str, bytes] = {}

    def _get_aio_session(self) -> aiohttp.ClientSession:
        # Lazily (re)create the shared session: aiohttp sessions must be
        # created while an event loop is running, so this cannot happen
        # in __init__.
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(
                limit=300,
                limit_per_host=50,
            )
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout,
                connector=connector,
            )
        return self._aio_session

    def _get_aio_storage(self) -> Storage:
        # Lazily build the async Storage client on top of the shared session.
        if self._aio_storage is None:
            self._aio_storage = Storage(
                session=self._get_aio_session(),
            )
        return self._aio_storage

    async def async_get_file_stream(
        self,
        file_name: str,
        max_retries: int = 3,
    ) -> BinaryIO:
        """Get a file asynchronously with retry on transient errors.

        Retries timeouts, HTTP 429, and HTTP 5xx with exponential backoff
        (0.5s, 1s, 2s, ...). Successful downloads are cached in memory, so
        repeated requests for the same blob are served without a network
        round-trip.

        Args:
            file_name: The blob name to retrieve.
            max_retries: Maximum number of retry attempts.

        Returns:
            A BytesIO stream with the file contents (``.name`` is set to
            the blob name).

        Raises:
            TimeoutError: If all retry attempts fail; the last underlying
                failure is attached as ``__cause__``.
            aiohttp.ClientResponseError: For non-retryable HTTP errors
                (e.g. 403/404), re-raised immediately.
        """
        if file_name in self._cache:
            file_stream = io.BytesIO(self._cache[file_name])
            file_stream.name = file_name
            return file_stream

        storage_client = self._get_aio_storage()
        last_exception: Exception | None = None
        for attempt in range(max_retries):
            try:
                self._cache[file_name] = await storage_client.download(
                    self.bucket_name,
                    file_name,
                )
                file_stream = io.BytesIO(self._cache[file_name])
                file_stream.name = file_name
            except TimeoutError as exc:
                last_exception = exc
                logger.warning(
                    "Timeout downloading gs://%s/%s (attempt %d/%d)",
                    self.bucket_name,
                    file_name,
                    attempt + 1,
                    max_retries,
                )
            except aiohttp.ClientResponseError as exc:
                last_exception = exc
                # Only 429 and 5xx are transient; anything else (401/403/404
                # ...) will not improve with retries, so propagate at once.
                if (
                    exc.status == HTTP_TOO_MANY_REQUESTS
                    or exc.status >= HTTP_SERVER_ERROR
                ):
                    logger.warning(
                        "HTTP %d downloading gs://%s/%s (attempt %d/%d)",
                        exc.status,
                        self.bucket_name,
                        file_name,
                        attempt + 1,
                        max_retries,
                    )
                else:
                    raise
            else:
                return file_stream
            if attempt < max_retries - 1:
                # Exponential backoff: 0.5, 1, 2, ... seconds.
                delay = 0.5 * (2**attempt)
                await asyncio.sleep(delay)
        msg = (
            f"Failed to download gs://{self.bucket_name}/{file_name} "
            f"after {max_retries} attempts"
        )
        raise TimeoutError(msg) from last_exception


class SearchResult(TypedDict):
    """Structured response item returned by the vector search API."""

    id: str  # datapoint ID from the Matching Engine response
    distance: float  # raw "distance" value reported by findNeighbors
    content: str  # UTF-8 decoded document body fetched from GCS


class GoogleCloudVectorSearch:
    """Minimal async client for the Vertex AI Matching Engine REST API."""

    def __init__(
        self,
        project_id: str,
        location: str,
        bucket: str,
        index_name: str | None = None,
    ) -> None:
        """Store configuration used to issue Matching Engine queries.

        Args:
            project_id: Google Cloud project ID.
            location: Vertex AI region (e.g. ``us-central1``).
            bucket: GCS bucket holding the document contents.
            index_name: Prefix under which document blobs are stored.
        """
        self.project_id = project_id
        self.location = location
        self.storage = GoogleCloudFileStorage(bucket=bucket)
        self.index_name = index_name
        self._aio_session: aiohttp.ClientSession | None = None
        self._async_token: Token | None = None
        self._endpoint_domain: str | None = None
        self._endpoint_name: str | None = None

    async def _async_get_auth_headers(self) -> dict[str, str]:
        # Token caches and refreshes the credential internally; get()
        # always returns a currently-valid access token.
        if self._async_token is None:
            self._async_token = Token(
                session=self._get_aio_session(),
                scopes=[
                    "https://www.googleapis.com/auth/cloud-platform",
                ],
            )
        access_token = await self._async_token.get()
        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def _get_aio_session(self) -> aiohttp.ClientSession:
        # Same lazy-session pattern as GoogleCloudFileStorage: sessions must
        # be created inside a running event loop.
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(
                limit=300,
                limit_per_host=50,
            )
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout,
                connector=connector,
            )
        return self._aio_session

    def configure_index_endpoint(
        self,
        *,
        name: str,
        public_domain: str,
    ) -> None:
        """Persist the metadata needed to access a deployed endpoint.

        Args:
            name: Full resource name of the index endpoint; only the last
                path segment (the endpoint ID) is used when building URLs.
            public_domain: Public endpoint domain used as the request host.

        Raises:
            ValueError: If either value is empty.
        """
        if not name:
            msg = "Index endpoint name must be a non-empty string."
            raise ValueError(msg)
        if not public_domain:
            msg = "Index endpoint domain must be a non-empty public domain."
            raise ValueError(msg)
        self._endpoint_name = name
        self._endpoint_domain = public_domain

    async def async_run_query(
        self,
        deployed_index_id: str,
        query: Sequence[float],
        limit: int,
    ) -> list[SearchResult]:
        """Run an async similarity search via the REST API.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.

        Returns:
            A list of matched items with id, distance, and content. Empty
            when the index returns no neighbors.

        Raises:
            RuntimeError: If the endpoint metadata is missing or the HTTP
                request fails.
        """
        if self._endpoint_domain is None or self._endpoint_name is None:
            msg = (
                "Missing endpoint metadata. Call "
                "`configure_index_endpoint` before querying."
            )
            raise RuntimeError(msg)

        domain = self._endpoint_domain
        endpoint_id = self._endpoint_name.split("/")[-1]
        url = (
            f"https://{domain}/v1/projects/{self.project_id}"
            f"/locations/{self.location}"
            f"/indexEndpoints/{endpoint_id}:findNeighbors"
        )
        payload = {
            "deployed_index_id": deployed_index_id,
            "queries": [
                {
                    "datapoint": {"feature_vector": list(query)},
                    "neighbor_count": limit,
                },
            ],
        }
        headers = await self._async_get_auth_headers()
        session = self._get_aio_session()
        async with session.post(
            url,
            json=payload,
            headers=headers,
        ) as response:
            if not response.ok:
                body = await response.text()
                msg = f"findNeighbors returned {response.status}: {body}"
                raise RuntimeError(msg)
            data = await response.json()

        # Fix: a *present but empty* "nearestNeighbors" list previously
        # raised IndexError; `or [{}]` covers missing, None, and [] alike,
        # so zero matches yields an empty result list instead.
        nearest = data.get("nearestNeighbors") or [{}]
        neighbors = nearest[0].get("neighbors", [])

        # Fetch all document bodies concurrently.
        content_tasks = []
        for neighbor in neighbors:
            datapoint_id = neighbor["datapoint"]["datapointId"]
            # NOTE(review): assumes index_name is set when querying —
            # a None index_name would produce "None/contents/...". Confirm
            # callers always pass index_name.
            file_path = f"{self.index_name}/contents/{datapoint_id}.md"
            content_tasks.append(
                self.storage.async_get_file_stream(file_path),
            )
        file_streams = await asyncio.gather(*content_tasks)

        results: list[SearchResult] = []
        for neighbor, stream in zip(
            neighbors,
            file_streams,
            strict=True,
        ):
            results.append(
                SearchResult(
                    id=neighbor["datapoint"]["datapointId"],
                    distance=neighbor["distance"],
                    content=stream.read().decode("utf-8"),
                ),
            )
        return results


# ---------------------------------------------------------------------------
# MCP Server
# ---------------------------------------------------------------------------


class Settings(BaseSettings):
    """Server configuration populated from environment variables."""

    model_config = {"env_file": ".env"}

    project_id: str
    location: str
    bucket: str
    index_name: str
    deployed_index_id: str
    endpoint_name: str
    endpoint_domain: str
    embedding_model: str = "gemini-embedding-001"
    search_limit: int = 10


@dataclass
class AppContext:
    """Shared resources initialised once at server startup."""

    vector_search: GoogleCloudVectorSearch
    genai_client: genai.Client
    settings: Settings
@asynccontextmanager
async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
    """Create and configure the vector-search client for the server lifetime.

    Yields:
        AppContext holding the vector-search client, GenAI client, and
        the validated settings.
    """
    # Fix: Settings() runs the pydantic-settings sources (env vars, .env);
    # Settings.model_validate({}) bypasses those sources and fails on the
    # required fields.
    cfg = Settings()
    vs = GoogleCloudVectorSearch(
        project_id=cfg.project_id,
        location=cfg.location,
        bucket=cfg.bucket,
        index_name=cfg.index_name,
    )
    vs.configure_index_endpoint(
        name=cfg.endpoint_name,
        public_domain=cfg.endpoint_domain,
    )
    genai_client = genai.Client(
        vertexai=True,
        project=cfg.project_id,
        location=cfg.location,
    )
    yield AppContext(
        vector_search=vs,
        genai_client=genai_client,
        settings=cfg,
    )


def _parse_args() -> argparse.Namespace:
    """Parse the transport/host/port command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transport",
        choices=["stdio", "sse"],
        default="stdio",
    )
    # NOTE(review): 0.0.0.0 binds all interfaces — intended for container
    # deployments; confirm this is not exposed on untrusted networks.
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8080)
    return parser.parse_args()


# Parsed at import time because FastMCP needs host/port at construction.
_args = _parse_args()

mcp = FastMCP(
    "knowledge-search",
    host=_args.host,
    port=_args.port,
    lifespan=lifespan,
)


@mcp.tool()
async def knowledge_search(
    query: str,
    ctx: Context,
) -> str:
    """Search a knowledge base using a natural-language query.

    Args:
        query: The text query to search for.
        ctx: MCP request context (injected automatically).

    Returns:
        Matched documents wrapped in ``<document index=".." id="..">``
        tags, joined by newlines; empty string when nothing clears the
        similarity cutoff.

    Raises:
        RuntimeError: If the embedding model returns no embeddings.
    """
    import time  # local import: only used for request timing below

    app: AppContext = ctx.request_context.lifespan_context
    t0 = time.perf_counter()
    min_sim = 0.6

    response = await app.genai_client.aio.models.embed_content(
        model=app.settings.embedding_model,
        contents=query,
        config=genai_types.EmbedContentConfig(
            task_type="RETRIEVAL_QUERY",
        ),
    )
    # Fix: fail with a clear message instead of an opaque TypeError /
    # IndexError when the embedding response is empty.
    if not response.embeddings:
        msg = "Embedding model returned no embeddings."
        raise RuntimeError(msg)
    embedding = response.embeddings[0].values
    t_embed = time.perf_counter()

    search_results = await app.vector_search.async_run_query(
        deployed_index_id=app.settings.deployed_index_id,
        query=embedding,
        limit=app.settings.search_limit,
    )
    t_search = time.perf_counter()

    # Apply similarity filtering. NOTE(review): "distance" is treated as a
    # similarity score (larger = better), which matches dot-product/cosine
    # indexes — confirm against the deployed index's distance metric.
    if search_results:
        max_sim = max(r["distance"] for r in search_results)
        cutoff = max_sim * 0.9
        search_results = [
            s
            for s in search_results
            if s["distance"] > cutoff and s["distance"] > min_sim
        ]

    logger.info(
        "knowledge_search timing: embedding=%sms, vector_search=%sms, total=%sms, chunks=%s",
        round((t_embed - t0) * 1000, 1),
        round((t_search - t_embed) * 1000, 1),
        round((t_search - t0) * 1000, 1),
        [s["id"] for s in search_results],
    )

    # Format results as XML-like documents. Fix: the enumerate index and
    # datapoint id were computed but never emitted — restore the document
    # wrapper the comment promised so callers can cite individual matches.
    formatted_results = [
        f"<document index=\"{i}\" id=\"{doc['id']}\">\n{doc['content']}\n</document>"
        for i, doc in enumerate(search_results, start=1)
    ]
    return "\n".join(formatted_results)


if __name__ == "__main__":
    mcp.run(transport=_args.transport)