initial implementation

This commit is contained in:
2026-02-22 05:56:43 +00:00
parent 20f29d959b
commit 56e181a772
5 changed files with 1931 additions and 0 deletions

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.12

3
AGENTS.md Normal file
View File

@@ -0,0 +1,3 @@
Use `uv` for project management
Linter: `uv run ruff check`
Type-checking: `uv run ty check`

401
main.py Normal file
View File

@@ -0,0 +1,401 @@
# ruff: noqa: INP001
"""Async helpers for querying Vertex AI vector search via MCP."""
import asyncio
import io
import logging
import time
from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import BinaryIO, TypedDict

import aiohttp
from gcloud.aio.auth import Token
from gcloud.aio.storage import Storage
from google import genai
from google.genai import types as genai_types
from mcp.server.fastmcp import Context, FastMCP
from pydantic_settings import BaseSettings
logger = logging.getLogger(__name__)
HTTP_TOO_MANY_REQUESTS = 429
HTTP_SERVER_ERROR = 500
class GoogleCloudFileStorage:
    """Cache-aware helper for downloading files from Google Cloud Storage.

    Downloaded blobs are cached in memory, so repeated reads of the same
    blob never hit the network twice.  NOTE(review): the cache is unbounded
    -- fine for a small, static corpus, but worth revisiting if the bucket
    grows.
    """

    def __init__(self, bucket: str) -> None:
        """Initialize the storage helper.

        Args:
            bucket: Name of the GCS bucket blobs are read from.
        """
        self.bucket_name = bucket
        # Session and storage client are created lazily so construction is
        # cheap and does not require a running event loop.
        self._aio_session: aiohttp.ClientSession | None = None
        self._aio_storage: Storage | None = None
        # blob name -> raw bytes; populated on first successful download.
        self._cache: dict[str, bytes] = {}

    def _get_aio_session(self) -> aiohttp.ClientSession:
        # Recreate the session if it was never built or has been closed.
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(
                limit=300,
                limit_per_host=50,
            )
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout,
                connector=connector,
            )
        return self._aio_session

    def _get_aio_storage(self) -> Storage:
        # The Storage client reuses our pooled session.
        if self._aio_storage is None:
            self._aio_storage = Storage(
                session=self._get_aio_session(),
            )
        return self._aio_storage

    async def close(self) -> None:
        """Release the underlying HTTP session.

        Fixes a resource leak: the lazily created aiohttp session was never
        closed anywhere.  Safe to call multiple times.
        """
        if self._aio_session is not None and not self._aio_session.closed:
            await self._aio_session.close()
        self._aio_session = None
        self._aio_storage = None

    async def async_get_file_stream(
        self,
        file_name: str,
        max_retries: int = 3,
    ) -> BinaryIO:
        """Get a file asynchronously with retry on transient errors.

        Args:
            file_name: The blob name to retrieve.
            max_retries: Maximum number of retry attempts.

        Returns:
            A BytesIO stream with the file contents, positioned at offset 0.

        Raises:
            TimeoutError: If all retry attempts fail.
            aiohttp.ClientResponseError: On non-retryable HTTP errors
                (anything other than 429 or 5xx).
        """
        # Serve from cache when possible -- no network round-trip.
        if file_name in self._cache:
            file_stream = io.BytesIO(self._cache[file_name])
            file_stream.name = file_name
            return file_stream
        storage_client = self._get_aio_storage()
        last_exception: Exception | None = None
        for attempt in range(max_retries):
            try:
                self._cache[file_name] = await storage_client.download(
                    self.bucket_name,
                    file_name,
                )
                file_stream = io.BytesIO(self._cache[file_name])
                file_stream.name = file_name
            except TimeoutError as exc:
                last_exception = exc
                logger.warning(
                    "Timeout downloading gs://%s/%s (attempt %d/%d)",
                    self.bucket_name,
                    file_name,
                    attempt + 1,
                    max_retries,
                )
            except aiohttp.ClientResponseError as exc:
                last_exception = exc
                # Retry only on rate limiting (429) and server errors (5xx);
                # other client errors (e.g. 404) are permanent -> re-raise.
                if (
                    exc.status == HTTP_TOO_MANY_REQUESTS
                    or exc.status >= HTTP_SERVER_ERROR
                ):
                    logger.warning(
                        "HTTP %d downloading gs://%s/%s (attempt %d/%d)",
                        exc.status,
                        self.bucket_name,
                        file_name,
                        attempt + 1,
                        max_retries,
                    )
                else:
                    raise
            else:
                return file_stream
            # Exponential backoff: 0.5s, 1s, 2s, ...
            if attempt < max_retries - 1:
                delay = 0.5 * (2**attempt)
                await asyncio.sleep(delay)
        msg = (
            f"Failed to download gs://{self.bucket_name}/{file_name} "
            f"after {max_retries} attempts"
        )
        raise TimeoutError(msg) from last_exception
class SearchResult(TypedDict):
    """One matched item returned by the vector search API.

    Keys:
        id: Datapoint identifier of the matched document chunk.
        distance: Raw score reported by the index for this match.
        content: Full text of the matched document chunk.
    """

    id: str
    distance: float
    content: str
class GoogleCloudVectorSearch:
    """Minimal async client for the Vertex AI Matching Engine REST API.

    Issues `findNeighbors` queries against a deployed index endpoint and
    resolves each neighbor's content from Cloud Storage.
    """

    def __init__(
        self,
        project_id: str,
        location: str,
        bucket: str,
        index_name: str | None = None,
    ) -> None:
        """Store configuration used to issue Matching Engine queries.

        Args:
            project_id: GCP project that owns the index endpoint.
            location: GCP region of the index endpoint.
            bucket: GCS bucket holding per-datapoint content files.
            index_name: Prefix under which contents are stored
                (``{index_name}/contents/{datapoint_id}.md``).
        """
        self.project_id = project_id
        self.location = location
        self.storage = GoogleCloudFileStorage(bucket=bucket)
        self.index_name = index_name
        # Created lazily; endpoint metadata must be supplied via
        # `configure_index_endpoint` before the first query.
        self._aio_session: aiohttp.ClientSession | None = None
        self._async_token: Token | None = None
        self._endpoint_domain: str | None = None
        self._endpoint_name: str | None = None

    async def _async_get_auth_headers(self) -> dict[str, str]:
        # Token caches/refreshes credentials internally; create it once.
        if self._async_token is None:
            self._async_token = Token(
                session=self._get_aio_session(),
                scopes=[
                    "https://www.googleapis.com/auth/cloud-platform",
                ],
            )
        access_token = await self._async_token.get()
        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def _get_aio_session(self) -> aiohttp.ClientSession:
        # Recreate the session if it was never built or has been closed.
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(
                limit=300,
                limit_per_host=50,
            )
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout,
                connector=connector,
            )
        return self._aio_session

    async def close(self) -> None:
        """Release the lazily created HTTP session.

        Fixes a resource leak: the aiohttp session was never closed
        anywhere.  Safe to call multiple times.
        """
        if self._aio_session is not None and not self._aio_session.closed:
            await self._aio_session.close()
        self._aio_session = None

    def configure_index_endpoint(
        self,
        *,
        name: str,
        public_domain: str,
    ) -> None:
        """Persist the metadata needed to access a deployed endpoint.

        Args:
            name: Full resource name of the index endpoint; only the last
                path segment is used when building the query URL.
            public_domain: Public domain of the endpoint used for REST calls.

        Raises:
            ValueError: If either argument is empty.
        """
        if not name:
            msg = "Index endpoint name must be a non-empty string."
            raise ValueError(msg)
        if not public_domain:
            msg = "Index endpoint domain must be a non-empty public domain."
            raise ValueError(msg)
        self._endpoint_name = name
        self._endpoint_domain = public_domain

    async def async_run_query(
        self,
        deployed_index_id: str,
        query: Sequence[float],
        limit: int,
    ) -> list[SearchResult]:
        """Run an async similarity search via the REST API.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.

        Returns:
            A list of matched items with id, distance, and content.

        Raises:
            RuntimeError: If `configure_index_endpoint` was not called first.
            aiohttp.ClientResponseError: If the REST call fails.
        """
        if self._endpoint_domain is None or self._endpoint_name is None:
            msg = (
                "Missing endpoint metadata. Call "
                "`configure_index_endpoint` before querying."
            )
            raise RuntimeError(msg)
        domain = self._endpoint_domain
        # The URL wants only the endpoint ID, not the full resource name.
        endpoint_id = self._endpoint_name.split("/")[-1]
        url = (
            f"https://{domain}/v1/projects/{self.project_id}"
            f"/locations/{self.location}"
            f"/indexEndpoints/{endpoint_id}:findNeighbors"
        )
        payload = {
            "deployed_index_id": deployed_index_id,
            "queries": [
                {
                    "datapoint": {"feature_vector": list(query)},
                    "neighbor_count": limit,
                },
            ],
        }
        headers = await self._async_get_auth_headers()
        session = self._get_aio_session()
        async with session.post(
            url,
            json=payload,
            headers=headers,
        ) as response:
            response.raise_for_status()
            data = await response.json()
        # Guard against both a missing key AND an empty list: the previous
        # `data.get("nearestNeighbors", [{}])[0]` raised IndexError when the
        # key was present but empty.
        nearest = data.get("nearestNeighbors") or [{}]
        neighbors = nearest[0].get("neighbors", [])
        # Fetch all neighbor contents concurrently from Cloud Storage.
        content_tasks = []
        for neighbor in neighbors:
            datapoint_id = neighbor["datapoint"]["datapointId"]
            file_path = f"{self.index_name}/contents/{datapoint_id}.md"
            content_tasks.append(
                self.storage.async_get_file_stream(file_path),
            )
        file_streams = await asyncio.gather(*content_tasks)
        results: list[SearchResult] = []
        for neighbor, stream in zip(
            neighbors,
            file_streams,
            strict=True,
        ):
            results.append(
                SearchResult(
                    id=neighbor["datapoint"]["datapointId"],
                    distance=neighbor["distance"],
                    content=stream.read().decode("utf-8"),
                ),
            )
        return results
# ---------------------------------------------------------------------------
# MCP Server
# ---------------------------------------------------------------------------
class Settings(BaseSettings):
    """Server configuration populated from environment variables.

    Fields without a default are required at startup; `model_validate({})`
    in `lifespan` fails fast if any are missing.
    """
    # GCP project that owns the index endpoint.
    project_id: str
    # GCP region of the index endpoint.
    location: str
    # GCS bucket holding per-datapoint content files.
    bucket: str
    # Prefix under which document contents live in the bucket.
    index_name: str
    # ID of the index deployed on the endpoint.
    deployed_index_id: str
    # Full resource name of the index endpoint.
    endpoint_name: str
    # Public domain of the index endpoint used for REST calls.
    endpoint_domain: str
    # Vertex AI model used to embed incoming queries.
    embedding_model: str = "text-embedding-005"
    # Maximum number of nearest neighbors requested per query.
    search_limit: int = 10
@dataclass
class AppContext:
    """Shared resources initialised once at server startup."""
    # Client used to run similarity searches and fetch document contents.
    vector_search: GoogleCloudVectorSearch
    # Vertex AI client used to embed incoming queries.
    genai_client: genai.Client
    # Environment-derived server configuration.
    settings: Settings
@asynccontextmanager
async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
    """Build the shared clients once and expose them for the server lifetime."""
    # Reads all required configuration from environment variables; fails
    # fast at startup if anything is missing.
    settings = Settings.model_validate({})
    search_client = GoogleCloudVectorSearch(
        project_id=settings.project_id,
        location=settings.location,
        bucket=settings.bucket,
        index_name=settings.index_name,
    )
    search_client.configure_index_endpoint(
        name=settings.endpoint_name,
        public_domain=settings.endpoint_domain,
    )
    embeddings_client = genai.Client(
        vertexai=True,
        project=settings.project_id,
        location=settings.location,
    )
    yield AppContext(
        vector_search=search_client,
        genai_client=embeddings_client,
        settings=settings,
    )
mcp = FastMCP("knowledge-search", lifespan=lifespan)
@mcp.tool()
async def knowledge_search(
    query: str,
    ctx: Context,
) -> str:
    """Search a knowledge base using a natural-language query.

    Embeds the query, runs a nearest-neighbor search, filters weak matches,
    and formats the surviving documents as XML-like blocks.

    Args:
        query: The text query to search for.
        ctx: MCP request context (injected automatically).

    Returns:
        A formatted string containing matched documents with id and content;
        empty string when nothing survives the filter.
    """
    # `import time` was previously a function-scope import; it now lives at
    # module level with the other stdlib imports.
    app: AppContext = ctx.request_context.lifespan_context
    t0 = time.perf_counter()
    # Absolute floor for a match to be kept.  NOTE(review): `distance` is
    # treated as a similarity score (higher is better); this assumes the
    # index uses a similarity metric such as dot product -- confirm.
    min_sim = 0.6
    response = await app.genai_client.aio.models.embed_content(
        model=app.settings.embedding_model,
        contents=query,
        config=genai_types.EmbedContentConfig(
            task_type="RETRIEVAL_QUERY",
        ),
    )
    embedding = response.embeddings[0].values
    t_embed = time.perf_counter()
    search_results = await app.vector_search.async_run_query(
        deployed_index_id=app.settings.deployed_index_id,
        query=embedding,
        limit=app.settings.search_limit,
    )
    t_search = time.perf_counter()
    # Keep only results within 90% of the best score AND above the floor.
    if search_results:
        max_sim = max(r["distance"] for r in search_results)
        cutoff = max_sim * 0.9
        search_results = [
            s
            for s in search_results
            if s["distance"] > cutoff and s["distance"] > min_sim
        ]
    logger.info(
        "knowledge_search timing: embedding=%sms, vector_search=%sms, total=%sms, chunks=%s",
        round((t_embed - t0) * 1000, 1),
        round((t_search - t_embed) * 1000, 1),
        round((t_search - t0) * 1000, 1),
        [s["id"] for s in search_results],
    )
    # Format results as XML-like documents
    formatted_results = [
        f"<document {i} name={result['id']}>\n{result['content']}\n</document {i}>"
        for i, result in enumerate(search_results, start=1)
    ]
    return "\n".join(formatted_results)
# Start the MCP server when executed as a script.
if __name__ == "__main__":
    mcp.run()

21
pyproject.toml Normal file
View File

@@ -0,0 +1,21 @@
[project]
name = "knowledge-search-mcp"
version = "0.1.0"
description = "MCP server exposing Vertex AI vector search over a knowledge base"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"aiohttp>=3.13.3",
"gcloud-aio-auth>=5.4.2",
"gcloud-aio-storage>=9.6.1",
"google-auth>=2.48.0",
"google-genai>=1.64.0",
"mcp[cli]>=1.26.0",
"pydantic-settings>=2.9.1",
]
[dependency-groups]
dev = [
"ruff>=0.15.2",
"ty>=0.0.18",
]

1505
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff