Files
knowledge-search-mcp/main.py

783 lines
27 KiB
Python

# ruff: noqa: INP001
"""Async helpers for querying Vertex AI vector search via MCP."""
import asyncio
import io
import time
from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager
from dataclasses import dataclass
from http import HTTPStatus
from typing import BinaryIO, TypedDict

import aiohttp
from gcloud.aio.auth import Token
from gcloud.aio.storage import Storage
from google import genai
from google.genai import types as genai_types
from mcp.server.fastmcp import Context, FastMCP

from utils import Settings, _args, log_structured_entry
# HTTP status thresholds consulted by GoogleCloudFileStorage when deciding
# whether a failed GCS download is worth retrying: 429 (rate limited) and any
# status at or above 500 (server-side failure) are treated as transient.
HTTP_TOO_MANY_REQUESTS: int = 429
HTTP_SERVER_ERROR: int = 500
class GoogleCloudFileStorage:
    """Cache-aware helper for downloading files from Google Cloud Storage.

    Downloaded blobs are memoised in an in-process ``dict`` keyed by blob
    name. Entries are never evicted, so the cache grows with the number of
    distinct files fetched over the lifetime of the instance.
    """

    def __init__(self, bucket: str) -> None:
        """Initialize the storage helper.

        Args:
            bucket: Name of the GCS bucket that every download reads from.
        """
        self.bucket_name = bucket
        self._aio_session: aiohttp.ClientSession | None = None
        self._aio_storage: Storage | None = None
        # blob name -> raw bytes of every file downloaded so far.
        self._cache: dict[str, bytes] = {}

    def _get_aio_session(self) -> aiohttp.ClientSession:
        """Return the shared aiohttp session, creating a new one if needed.

        A fresh session is created on first use and whenever the previous
        one has been closed.
        """
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(
                limit=300,
                limit_per_host=50,
            )
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout,
                connector=connector,
            )
            # A Storage client built on the previous session would keep using
            # that (closed) session, so force it to be rebuilt.
            self._aio_storage = None
        return self._aio_session

    def _get_aio_storage(self) -> Storage:
        """Return a ``Storage`` client bound to the current live session."""
        # Resolve the session first: doing so resets ``_aio_storage`` when a
        # closed session had to be replaced.
        session = self._get_aio_session()
        if self._aio_storage is None:
            self._aio_storage = Storage(session=session)
        return self._aio_storage

    def _cached_stream(self, file_name: str) -> BinaryIO:
        """Wrap the cached bytes of *file_name* in a named ``BytesIO``."""
        file_stream = io.BytesIO(self._cache[file_name])
        file_stream.name = file_name
        return file_stream

    async def async_get_file_stream(
        self,
        file_name: str,
        max_retries: int = 3,
    ) -> BinaryIO:
        """Get a file asynchronously with retry on transient errors.

        Timeouts, connection-level failures, HTTP 429 and HTTP 5xx responses
        are retried with exponential backoff (0.5s, 1s, 2s, ...). Any other
        HTTP error status is raised immediately. Successful downloads are
        cached, so later calls for the same blob never hit the network.

        Args:
            file_name: The blob name to retrieve.
            max_retries: Maximum number of download attempts.

        Returns:
            A BytesIO stream with the file contents whose ``name`` attribute
            is set to *file_name*.

        Raises:
            aiohttp.ClientResponseError: For non-retryable HTTP status codes.
            TimeoutError: If all retry attempts fail.
        """
        if file_name in self._cache:
            log_structured_entry(
                "File retrieved from cache",
                "INFO",
                {"file": file_name, "bucket": self.bucket_name},
            )
            return self._cached_stream(file_name)

        log_structured_entry(
            "Starting file download from GCS",
            "INFO",
            {"file": file_name, "bucket": self.bucket_name},
        )
        storage_client = self._get_aio_storage()
        gcs_uri = f"gs://{self.bucket_name}/{file_name}"
        last_exception: Exception | None = None

        for attempt in range(max_retries):
            attempt_label = f"(attempt {attempt + 1}/{max_retries})"
            try:
                payload = await storage_client.download(
                    self.bucket_name,
                    file_name,
                )
            except (TimeoutError, asyncio.TimeoutError) as exc:
                # Before Python 3.11, asyncio.TimeoutError (which aiohttp
                # raises) is not the builtin TimeoutError; catch both.
                last_exception = exc
                log_structured_entry(
                    f"Timeout downloading {gcs_uri} {attempt_label}",
                    "WARNING",
                    {"error": str(exc)},
                )
            except aiohttp.ClientConnectionError as exc:
                # Dropped or reset connections are transient; retry them too.
                last_exception = exc
                log_structured_entry(
                    f"Connection error downloading {gcs_uri} {attempt_label}",
                    "WARNING",
                    {"error": str(exc), "error_type": type(exc).__name__},
                )
            except aiohttp.ClientResponseError as exc:
                last_exception = exc
                if (
                    exc.status == HTTP_TOO_MANY_REQUESTS
                    or exc.status >= HTTP_SERVER_ERROR
                ):
                    log_structured_entry(
                        f"HTTP {exc.status} downloading {gcs_uri} {attempt_label}",
                        "WARNING",
                        {"status": exc.status, "message": str(exc)},
                    )
                else:
                    log_structured_entry(
                        f"Non-retryable HTTP error downloading {gcs_uri}",
                        "ERROR",
                        {"status": exc.status, "message": str(exc)},
                    )
                    raise
            else:
                self._cache[file_name] = payload
                log_structured_entry(
                    "File downloaded successfully",
                    "INFO",
                    {
                        "file": file_name,
                        "bucket": self.bucket_name,
                        "size_bytes": len(payload),
                        "attempt": attempt + 1,
                    },
                )
                return self._cached_stream(file_name)

            # Back off before the next attempt (no sleep after the last one).
            if attempt < max_retries - 1:
                delay = 0.5 * (2**attempt)
                log_structured_entry(
                    "Retrying file download",
                    "INFO",
                    {"file": file_name, "delay_seconds": delay},
                )
                await asyncio.sleep(delay)

        msg = (
            f"Failed to download gs://{self.bucket_name}/{file_name} "
            f"after {max_retries} attempts"
        )
        log_structured_entry(
            "File download failed after all retries",
            "ERROR",
            {
                "file": file_name,
                "bucket": self.bucket_name,
                "max_retries": max_retries,
                "last_error": str(last_exception),
            },
        )
        raise TimeoutError(msg) from last_exception
class SearchResult(TypedDict, total=True):
    """One nearest-neighbor hit, paired with its stored document text.

    All three keys are required.
    """

    # ``datapointId`` of the matched datapoint.
    id: str
    # ``distance`` value reported by findNeighbors for this datapoint.
    distance: float
    # UTF-8 decoded body of the datapoint's content file.
    content: str
class GoogleCloudVectorSearch:
    """Minimal async client for the Vertex AI Matching Engine REST API.

    Sends ``findNeighbors`` requests to a deployed index endpoint and fetches
    the matching document bodies from GCS through ``GoogleCloudFileStorage``.
    """

    def __init__(
        self,
        project_id: str,
        location: str,
        bucket: str,
        index_name: str | None = None,
    ) -> None:
        """Store configuration used to issue Matching Engine queries.

        Args:
            project_id: Project ID used in the findNeighbors request URL.
            location: Location segment used in the findNeighbors request URL.
            bucket: GCS bucket holding the per-datapoint content files.
            index_name: Prefix of the content files, which are read from
                ``<index_name>/contents/<datapoint_id>.md``.
        """
        self.project_id = project_id
        self.location = location
        self.storage = GoogleCloudFileStorage(bucket=bucket)
        self.index_name = index_name
        self._aio_session: aiohttp.ClientSession | None = None
        self._async_token: Token | None = None
        self._endpoint_domain: str | None = None
        self._endpoint_name: str | None = None

    async def _async_get_auth_headers(self) -> dict[str, str]:
        """Return JSON request headers carrying a cloud-platform bearer token."""
        # Resolve the session first: it drops a token bound to a closed session.
        session = self._get_aio_session()
        if self._async_token is None:
            self._async_token = Token(
                session=session,
                scopes=[
                    "https://www.googleapis.com/auth/cloud-platform",
                ],
            )
        access_token = await self._async_token.get()
        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def _get_aio_session(self) -> aiohttp.ClientSession:
        """Return the shared aiohttp session, creating a new one if needed."""
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(
                limit=300,
                limit_per_host=50,
            )
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout,
                connector=connector,
            )
            # The cached Token was created on the previous (closed) session;
            # force it to be rebuilt on the new one.
            self._async_token = None
        return self._aio_session

    def configure_index_endpoint(
        self,
        *,
        name: str,
        public_domain: str,
    ) -> None:
        """Persist the metadata needed to access a deployed endpoint.

        Args:
            name: Index endpoint resource name; only its last ``/`` segment
                is used when building query URLs.
            public_domain: Host name that findNeighbors requests are sent to.

        Raises:
            ValueError: If either argument is empty.
        """
        if not name:
            msg = "Index endpoint name must be a non-empty string."
            raise ValueError(msg)
        if not public_domain:
            msg = "Index endpoint domain must be a non-empty public domain."
            raise ValueError(msg)
        self._endpoint_name = name
        self._endpoint_domain = public_domain

    @staticmethod
    def _extract_neighbors(data: dict) -> list[dict]:
        """Return the neighbor list of the single query in a findNeighbors reply.

        A missing or empty ``nearestNeighbors`` array, or a missing/null
        ``neighbors`` field, is treated as "no matches".
        """
        nearest = data.get("nearestNeighbors") or [{}]
        return nearest[0].get("neighbors") or []

    async def async_run_query(
        self,
        deployed_index_id: str,
        query: Sequence[float],
        limit: int,
    ) -> list[SearchResult]:
        """Run an async similarity search via the REST API.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.

        Returns:
            A list of matched items with id, distance, and content, in the
            order the API returned them. Empty when there are no neighbors.

        Raises:
            RuntimeError: If the endpoint is not configured or the API
                responds with a non-success status.
            Exception: Errors from authentication or content download are
                logged and re-raised unchanged.
        """
        if self._endpoint_domain is None or self._endpoint_name is None:
            msg = (
                "Missing endpoint metadata. Call "
                "`configure_index_endpoint` before querying."
            )
            log_structured_entry(
                "Vector search query failed - endpoint not configured",
                "ERROR",
                {"error": msg},
            )
            raise RuntimeError(msg)

        domain = self._endpoint_domain
        endpoint_id = self._endpoint_name.split("/")[-1]
        url = (
            f"https://{domain}/v1/projects/{self.project_id}"
            f"/locations/{self.location}"
            f"/indexEndpoints/{endpoint_id}:findNeighbors"
        )
        log_structured_entry(
            "Starting vector search query",
            "INFO",
            {
                "deployed_index_id": deployed_index_id,
                "neighbor_count": limit,
                "endpoint_id": endpoint_id,
                "embedding_dimension": len(query),
            },
        )
        payload = {
            "deployed_index_id": deployed_index_id,
            "queries": [
                {
                    "datapoint": {"feature_vector": list(query)},
                    "neighbor_count": limit,
                },
            ],
        }

        try:
            headers = await self._async_get_auth_headers()
            session = self._get_aio_session()
            async with session.post(
                url,
                json=payload,
                headers=headers,
            ) as response:
                if not response.ok:
                    body = await response.text()
                    msg = f"findNeighbors returned {response.status}: {body}"
                    log_structured_entry(
                        "Vector search API request failed",
                        "ERROR",
                        {
                            "status": response.status,
                            "response_body": body,
                            "deployed_index_id": deployed_index_id,
                        },
                    )
                    raise RuntimeError(msg)
                data = await response.json()

            neighbors = self._extract_neighbors(data)
            log_structured_entry(
                "Vector search API request successful",
                "INFO",
                {
                    "neighbors_found": len(neighbors),
                    "deployed_index_id": deployed_index_id,
                },
            )
            if not neighbors:
                log_structured_entry(
                    "No neighbors found in vector search",
                    "WARNING",
                    {"deployed_index_id": deployed_index_id},
                )
                return []

            # Fetch the stored content of every neighbor concurrently.
            datapoint_ids = [n["datapoint"]["datapointId"] for n in neighbors]
            # NOTE(review): when index_name is None this builds
            # "None/contents/<id>.md"; confirm callers always set it.
            content_tasks = [
                self.storage.async_get_file_stream(
                    f"{self.index_name}/contents/{datapoint_id}.md",
                )
                for datapoint_id in datapoint_ids
            ]
            log_structured_entry(
                "Fetching content for search results",
                "INFO",
                {"file_count": len(content_tasks)},
            )
            file_streams = await asyncio.gather(*content_tasks)

            results: list[SearchResult] = [
                SearchResult(
                    id=datapoint_id,
                    # The REST API uses the proto3 JSON mapping, which omits
                    # zero-valued fields, so a 0.0 distance may be absent.
                    distance=neighbor.get("distance", 0.0),
                    content=stream.read().decode("utf-8"),
                )
                for datapoint_id, neighbor, stream in zip(
                    datapoint_ids,
                    neighbors,
                    file_streams,
                    strict=True,
                )
            ]
            log_structured_entry(
                "Vector search completed successfully",
                "INFO",
                {
                    "results_count": len(results),
                    "deployed_index_id": deployed_index_id,
                },
            )
            return results
        except Exception as e:
            log_structured_entry(
                "Vector search query failed with exception",
                "ERROR",
                {
                    "error": str(e),
                    "error_type": type(e).__name__,
                    "deployed_index_id": deployed_index_id,
                },
            )
            raise
# ---------------------------------------------------------------------------
# MCP Server
# ---------------------------------------------------------------------------
@dataclass()
class AppContext:
    """Bundle of long-lived clients handed to every tool invocation.

    Built by ``lifespan`` and read by tools through
    ``ctx.request_context.lifespan_context``.
    """

    # Vertex AI vector-search client; also owns the GCS content fetcher.
    vector_search: GoogleCloudVectorSearch
    # Google GenAI client used to embed incoming queries.
    genai_client: genai.Client
    # Resolved application settings.
    settings: Settings
_CLOUD_PLATFORM_SCOPE = "https://www.googleapis.com/auth/cloud-platform"


async def _validate_genai_access(
    genai_client: genai.Client,
    settings: Settings,
) -> list[str]:
    """Probe the embedding model with a tiny request.

    Best-effort: failures are logged as warnings and returned, never raised.

    Returns:
        An empty list on success, otherwise a single error description.
    """
    log_structured_entry("Validating GenAI embedding access", "INFO")
    try:
        test_response = await genai_client.aio.models.embed_content(
            model=settings.embedding_model,
            contents="test",
            config=genai_types.EmbedContentConfig(
                task_type="RETRIEVAL_QUERY",
            ),
        )
        if test_response and test_response.embeddings:
            log_structured_entry(
                "GenAI embedding validation successful",
                "INFO",
                {"embedding_dimension": len(test_response.embeddings[0].values)},
            )
            return []
        msg = "Embedding validation returned empty response"
        log_structured_entry(msg, "WARNING")
        return [msg]
    except Exception as e:  # validation must never abort startup
        log_structured_entry(
            "Failed to validate GenAI embedding access - service may not work correctly",
            "WARNING",
            {"error": str(e), "error_type": type(e).__name__},
        )
        return [f"GenAI: {e!s}"]


async def _validate_gcs_access(
    vs: GoogleCloudVectorSearch,
    settings: Settings,
) -> list[str]:
    """List at most one object in the bucket to confirm read access.

    Best-effort: failures are logged as warnings and returned, never raised.

    Returns:
        An empty list on success, otherwise a single error description.
    """
    bucket = settings.bucket
    log_structured_entry(
        "Validating GCS bucket access",
        "INFO",
        {"bucket": bucket},
    )
    try:
        session = vs.storage._get_aio_session()
        token_obj = Token(session=session, scopes=[_CLOUD_PLATFORM_SCOPE])
        access_token = await token_obj.get()
        headers = {"Authorization": f"Bearer {access_token}"}
        async with session.get(
            f"https://storage.googleapis.com/storage/v1/b/{bucket}/o?maxResults=1",
            headers=headers,
        ) as response:
            status = response.status
            if status == HTTPStatus.FORBIDDEN:
                log_structured_entry(
                    "GCS bucket validation failed - access denied - service may not work correctly",
                    "WARNING",
                    {"bucket": bucket, "status": status},
                )
                return [f"Access denied to bucket '{bucket}'. Check permissions."]
            if status == HTTPStatus.NOT_FOUND:
                log_structured_entry(
                    "GCS bucket validation failed - not found - service may not work correctly",
                    "WARNING",
                    {"bucket": bucket, "status": status},
                )
                return [
                    f"Bucket '{bucket}' not found. Check bucket name and project.",
                ]
            if not response.ok:
                body = await response.text()
                log_structured_entry(
                    "GCS bucket validation failed - service may not work correctly",
                    "WARNING",
                    {"bucket": bucket, "status": status, "response": body},
                )
                return [f"Failed to access bucket '{bucket}': {status}"]
            log_structured_entry(
                "GCS bucket validation successful",
                "INFO",
                {"bucket": bucket},
            )
            return []
    except Exception as e:  # validation must never abort startup
        log_structured_entry(
            "Failed to validate GCS bucket access - service may not work correctly",
            "WARNING",
            {"error": str(e), "error_type": type(e).__name__, "bucket": bucket},
        )
        return [f"GCS: {e!s}"]


async def _validate_endpoint_access(
    vs: GoogleCloudVectorSearch,
    settings: Settings,
) -> list[str]:
    """GET the index endpoint resource to confirm it exists and is readable.

    Best-effort: failures are logged as warnings and returned, never raised.

    Returns:
        An empty list on success, otherwise a single error description.
    """
    endpoint_name = settings.endpoint_name
    log_structured_entry(
        "Validating vector search endpoint access",
        "INFO",
        {"endpoint_name": endpoint_name},
    )
    try:
        headers = await vs._async_get_auth_headers()
        session = vs._get_aio_session()
        endpoint_url = (
            f"https://{settings.location}-aiplatform.googleapis.com/v1/{endpoint_name}"
        )
        async with session.get(endpoint_url, headers=headers) as response:
            status = response.status
            if status == HTTPStatus.FORBIDDEN:
                log_structured_entry(
                    "Vector search endpoint validation failed - access denied - service may not work correctly",
                    "WARNING",
                    {"endpoint": endpoint_name, "status": status},
                )
                return [
                    f"Access denied to endpoint '{endpoint_name}'. Check permissions.",
                ]
            if status == HTTPStatus.NOT_FOUND:
                log_structured_entry(
                    "Vector search endpoint validation failed - not found - service may not work correctly",
                    "WARNING",
                    {"endpoint": endpoint_name, "status": status},
                )
                return [
                    f"Endpoint '{endpoint_name}' not found. Check endpoint name and project.",
                ]
            if not response.ok:
                body = await response.text()
                log_structured_entry(
                    "Vector search endpoint validation failed - service may not work correctly",
                    "WARNING",
                    {"endpoint": endpoint_name, "status": status, "response": body},
                )
                return [f"Failed to access endpoint '{endpoint_name}': {status}"]
            log_structured_entry(
                "Vector search endpoint validation successful",
                "INFO",
                {"endpoint": endpoint_name},
            )
            return []
    except Exception as e:  # validation must never abort startup
        log_structured_entry(
            "Failed to validate vector search endpoint access - service may not work correctly",
            "WARNING",
            {"error": str(e), "error_type": type(e).__name__, "endpoint": endpoint_name},
        )
        return [f"Vector Search: {e!s}"]


async def _close_client_sessions(vs: GoogleCloudVectorSearch) -> None:
    """Close the aiohttp sessions held by the vector-search and storage clients."""
    for session in (vs._aio_session, vs.storage._aio_session):
        if session is None or session.closed:
            continue
        try:
            await session.close()
        except Exception as e:  # shutdown cleanup is best-effort
            log_structured_entry(
                "Failed to close aiohttp session",
                "WARNING",
                {"error": str(e), "error_type": type(e).__name__},
            )


@asynccontextmanager
async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
    """Create and configure the vector-search client for the server lifetime.

    Builds the vector-search and GenAI clients from the module-level ``cfg``,
    runs best-effort validation of GenAI, GCS and endpoint access (problems
    are logged as warnings, not raised), then yields an ``AppContext``. On
    exit, the aiohttp sessions opened by the clients are closed.

    Raises:
        Exception: Any error while constructing/configuring the clients is
            logged as an initialization failure and re-raised.
    """
    log_structured_entry(
        "Initializing MCP server",
        "INFO",
        {
            "project_id": cfg.project_id,
            "location": cfg.location,
            "bucket": cfg.bucket,
            "index_name": cfg.index_name,
        },
    )
    vs: GoogleCloudVectorSearch | None = None
    try:
        # Only initialization is wrapped here, so errors raised later while
        # the server is running are not misreported as startup failures.
        try:
            log_structured_entry("Creating GoogleCloudVectorSearch client", "INFO")
            vs = GoogleCloudVectorSearch(
                project_id=cfg.project_id,
                location=cfg.location,
                bucket=cfg.bucket,
                index_name=cfg.index_name,
            )

            log_structured_entry(
                "Configuring index endpoint",
                "INFO",
                {
                    "endpoint_name": cfg.endpoint_name,
                    "endpoint_domain": cfg.endpoint_domain,
                },
            )
            vs.configure_index_endpoint(
                name=cfg.endpoint_name,
                public_domain=cfg.endpoint_domain,
            )

            log_structured_entry(
                "Creating GenAI client",
                "INFO",
                {"project_id": cfg.project_id, "location": cfg.location},
            )
            genai_client = genai.Client(
                vertexai=True,
                project=cfg.project_id,
                location=cfg.location,
            )

            # Validations are non-blocking: errors are collected and logged
            # but do not stop startup.
            log_structured_entry(
                "Starting validation of credentials and resources",
                "INFO",
            )
            validation_errors: list[str] = []
            validation_errors += await _validate_genai_access(genai_client, cfg)
            validation_errors += await _validate_gcs_access(vs, cfg)
            validation_errors += await _validate_endpoint_access(vs, cfg)

            if validation_errors:
                log_structured_entry(
                    "MCP server started with validation errors - service may not work correctly",
                    "WARNING",
                    {
                        "validation_errors": validation_errors,
                        "error_count": len(validation_errors),
                    },
                )
            else:
                log_structured_entry(
                    "All validations passed - MCP server initialization complete",
                    "INFO",
                )

            app_context = AppContext(
                vector_search=vs,
                genai_client=genai_client,
                settings=cfg,
            )
        except Exception as e:
            log_structured_entry(
                "Failed to initialize MCP server",
                "ERROR",
                {
                    "error": str(e),
                    "error_type": type(e).__name__,
                },
            )
            raise

        yield app_context
    finally:
        # Release HTTP connections so they are not leaked when the lifespan ends.
        if vs is not None:
            await _close_client_sessions(vs)
        log_structured_entry("MCP server lifespan ending", "INFO")
# Settings are resolved once, at import time; `lifespan` reads this global.
cfg = Settings.model_validate({})

# The FastMCP server instance. Host and port come from `_args` in `utils`
# (presumably the parsed command-line arguments).
mcp = FastMCP(
    name="knowledge-search",
    lifespan=lifespan,
    host=_args.host,
    port=_args.port,
)
# Absolute floor: results scoring at or below this are dropped.
_MIN_SIMILARITY = 0.6
# Relative floor: results must score above this fraction of the best hit.
_RELATIVE_CUTOFF = 0.9


def _filter_by_similarity(
    results: list[SearchResult],
    min_similarity: float = _MIN_SIMILARITY,
    relative_cutoff: float = _RELATIVE_CUTOFF,
) -> list[SearchResult]:
    """Drop weak matches from vector-search results.

    Treats ``distance`` as a similarity score (larger is better). A result is
    kept only if its score is strictly greater than both *min_similarity* and
    *relative_cutoff* times the best score in *results*.
    """
    if not results:
        return results
    cutoff = max(r["distance"] for r in results) * relative_cutoff
    return [
        r
        for r in results
        if r["distance"] > cutoff and r["distance"] > min_similarity
    ]


@mcp.tool()
async def knowledge_search(
    query: str,
    ctx: Context,
) -> str:
    """Search a knowledge base using a natural-language query.

    Args:
        query: The text query to search for.
        ctx: MCP request context (injected automatically).

    Returns:
        A formatted string containing matched documents with id and content.
    """
    # NOTE(review): FastMCP presumably publishes the docstring above as the
    # tool description, so keep it client-facing; implementation notes go here.
    # Errors are reported to the client as "Error ..." strings, not raised.
    app: AppContext = ctx.request_context.lifespan_context
    query_preview = query[:100]  # only the first 100 chars are ever logged
    t0 = time.perf_counter()
    log_structured_entry(
        "knowledge_search request received",
        "INFO",
        {"query": query_preview},
    )
    try:
        # 1. Embed the query.
        log_structured_entry("Generating query embedding", "INFO")
        try:
            response = await app.genai_client.aio.models.embed_content(
                model=app.settings.embedding_model,
                contents=query,
                config=genai_types.EmbedContentConfig(
                    task_type="RETRIEVAL_QUERY",
                ),
            )
            embeddings = response.embeddings or []
            embedding = embeddings[0].values if embeddings else None
            if not embedding:
                # Surface a clear message instead of an opaque TypeError later.
                msg = "Embedding response contained no values"
                raise ValueError(msg)
            t_embed = time.perf_counter()
            log_structured_entry(
                "Query embedding generated successfully",
                "INFO",
                {"time_ms": round((t_embed - t0) * 1000, 1)},
            )
        except Exception as e:
            error_type = type(e).__name__
            error_msg = str(e)
            # Rate limiting is detected heuristically from the error text.
            if "429" in error_msg or "RESOURCE_EXHAUSTED" in error_msg:
                log_structured_entry(
                    "Rate limit exceeded while generating embedding",
                    "WARNING",
                    {
                        "error": error_msg,
                        "error_type": error_type,
                        "query": query_preview,
                    },
                )
                return "Error: API rate limit exceeded. Please try again later."
            log_structured_entry(
                "Failed to generate query embedding",
                "ERROR",
                {
                    "error": error_msg,
                    "error_type": error_type,
                    "query": query_preview,
                },
            )
            return f"Error generating embedding: {error_msg}"

        # 2. Nearest-neighbor search.
        log_structured_entry("Performing vector search", "INFO")
        try:
            search_results = await app.vector_search.async_run_query(
                deployed_index_id=app.settings.deployed_index_id,
                query=embedding,
                limit=app.settings.search_limit,
            )
            t_search = time.perf_counter()
        except Exception as e:
            log_structured_entry(
                "Vector search failed",
                "ERROR",
                {
                    "error": str(e),
                    "error_type": type(e).__name__,
                    "query": query_preview,
                },
            )
            return f"Error performing vector search: {e!s}"

        # 3. Keep only results that are both strong and close to the best hit.
        search_results = _filter_by_similarity(search_results)
        log_structured_entry(
            "knowledge_search completed successfully",
            "INFO",
            {
                "embedding_ms": f"{round((t_embed - t0) * 1000, 1)}ms",
                "vector_search_ms": f"{round((t_search - t_embed) * 1000, 1)}ms",
                "total_ms": f"{round((t_search - t0) * 1000, 1)}ms",
                "results_count": len(search_results),
                "chunks": [s["id"] for s in search_results],
            },
        )

        if not search_results:
            log_structured_entry(
                "No results found for query",
                "INFO",
                {"query": query_preview},
            )
            return "No relevant documents found for your query."

        # 4. Render results as XML-like documents, numbered from 1.
        return "\n".join(
            f"<document {i} name={result['id']}>\n{result['content']}\n</document {i}>"
            for i, result in enumerate(search_results, start=1)
        )
    except Exception as e:
        # Catch-all for any unexpected errors.
        log_structured_entry(
            "Unexpected error in knowledge_search",
            "ERROR",
            {
                "error": str(e),
                "error_type": type(e).__name__,
                "query": query_preview,
            },
        )
        return f"Unexpected error during search: {e!s}"
if __name__ == "__main__":
    # Serve over whichever transport `_args.transport` selects.
    chosen_transport = _args.transport
    mcp.run(transport=chosen_transport)