From dba94107a55989bede995af151230f145e0d1f14 Mon Sep 17 00:00:00 2001 From: Anibal Angulo Date: Tue, 3 Mar 2026 18:34:33 +0000 Subject: [PATCH] Split out main module --- src/knowledge_search_mcp/__init__.py | 15 + src/knowledge_search_mcp/clients/__init__.py | 11 + src/knowledge_search_mcp/clients/base.py | 31 + src/knowledge_search_mcp/clients/storage.py | 144 +++ .../clients/vector_search.py | 226 +++++ src/knowledge_search_mcp/main.py | 832 +----------------- src/knowledge_search_mcp/models.py | 37 + src/knowledge_search_mcp/server.py | 129 +++ src/knowledge_search_mcp/services/__init__.py | 13 + src/knowledge_search_mcp/services/search.py | 110 +++ .../services/validation.py | 171 ++++ src/knowledge_search_mcp/utils/__init__.py | 5 + src/knowledge_search_mcp/utils/cache.py | 33 + tests/test_search.py | 2 +- 14 files changed, 934 insertions(+), 825 deletions(-) create mode 100644 src/knowledge_search_mcp/clients/__init__.py create mode 100644 src/knowledge_search_mcp/clients/base.py create mode 100644 src/knowledge_search_mcp/clients/storage.py create mode 100644 src/knowledge_search_mcp/clients/vector_search.py create mode 100644 src/knowledge_search_mcp/models.py create mode 100644 src/knowledge_search_mcp/server.py create mode 100644 src/knowledge_search_mcp/services/__init__.py create mode 100644 src/knowledge_search_mcp/services/search.py create mode 100644 src/knowledge_search_mcp/services/validation.py create mode 100644 src/knowledge_search_mcp/utils/__init__.py create mode 100644 src/knowledge_search_mcp/utils/cache.py diff --git a/src/knowledge_search_mcp/__init__.py b/src/knowledge_search_mcp/__init__.py index e69de29..3c3648f 100644 --- a/src/knowledge_search_mcp/__init__.py +++ b/src/knowledge_search_mcp/__init__.py @@ -0,0 +1,15 @@ +"""MCP server for semantic search over Vertex AI Vector Search.""" + +from .clients.storage import GoogleCloudFileStorage +from .clients.vector_search import GoogleCloudVectorSearch +from .models import AppContext, SearchResult, SourceNamespace +from .utils.cache import LRUCache + +__all__ = [ + "GoogleCloudFileStorage", + "GoogleCloudVectorSearch", + "SourceNamespace", + "SearchResult", + "AppContext", + "LRUCache", +] diff --git a/src/knowledge_search_mcp/clients/__init__.py b/src/knowledge_search_mcp/clients/__init__.py new file mode 100644 index 0000000..745d991 --- /dev/null +++ b/src/knowledge_search_mcp/clients/__init__.py @@ -0,0 +1,11 @@ +"""Client modules for Google Cloud services.""" + +from .base import BaseGoogleCloudClient +from .storage import GoogleCloudFileStorage +from .vector_search import GoogleCloudVectorSearch + +__all__ = [ + "BaseGoogleCloudClient", + "GoogleCloudFileStorage", + "GoogleCloudVectorSearch", +] diff --git a/src/knowledge_search_mcp/clients/base.py b/src/knowledge_search_mcp/clients/base.py new file mode 100644 index 0000000..4e4b0b2 --- /dev/null +++ b/src/knowledge_search_mcp/clients/base.py @@ -0,0 +1,31 @@ +# ruff: noqa: INP001 +"""Base client with shared aiohttp session management.""" + +import aiohttp + + +class BaseGoogleCloudClient: + """Base class with shared aiohttp session management.""" + + def __init__(self) -> None: + """Initialize session tracking.""" + self._aio_session: aiohttp.ClientSession | None = None + + def _get_aio_session(self) -> aiohttp.ClientSession: + """Get or create aiohttp session with connection pooling.""" + if self._aio_session is None or self._aio_session.closed: + connector = aiohttp.TCPConnector( + limit=300, + limit_per_host=50, + ) + timeout = aiohttp.ClientTimeout(total=60) + self._aio_session = aiohttp.ClientSession( + timeout=timeout, + connector=connector, + ) + return self._aio_session + + async def close(self) -> None: + """Close aiohttp session if open.""" + if self._aio_session and not self._aio_session.closed: + await self._aio_session.close() diff --git a/src/knowledge_search_mcp/clients/storage.py b/src/knowledge_search_mcp/clients/storage.py new file mode 100644 index 0000000..6004df8 --- /dev/null +++ b/src/knowledge_search_mcp/clients/storage.py @@ -0,0 +1,144 @@ +# ruff: noqa: INP001 +"""Google Cloud Storage client with caching.""" + +import asyncio +import io +from typing import BinaryIO + +import aiohttp +from gcloud.aio.storage import Storage + +from ..logging import log_structured_entry +from ..utils.cache import LRUCache +from .base import BaseGoogleCloudClient + +HTTP_TOO_MANY_REQUESTS = 429 +HTTP_SERVER_ERROR = 500 + + +class GoogleCloudFileStorage(BaseGoogleCloudClient): + """Cache-aware helper for downloading files from Google Cloud Storage.""" + + def __init__(self, bucket: str, cache_size: int = 100) -> None: + """Initialize the storage helper with LRU cache.""" + super().__init__() + self.bucket_name = bucket + self._aio_storage: Storage | None = None + self._cache = LRUCache(max_size=cache_size) + + def _get_aio_storage(self) -> Storage: + if self._aio_storage is None: + self._aio_storage = Storage( + session=self._get_aio_session(), + ) + return self._aio_storage + + async def async_get_file_stream( + self, + file_name: str, + max_retries: int = 3, + ) -> BinaryIO: + """Get a file asynchronously with retry on transient errors. + + Args: + file_name: The blob name to retrieve. + max_retries: Maximum number of retry attempts. + + Returns: + A BytesIO stream with the file contents. + + Raises: + TimeoutError: If all retry attempts fail. + + """ + cached_content = self._cache.get(file_name) + if cached_content is not None: + log_structured_entry( + "File retrieved from cache", + "INFO", + {"file": file_name, "bucket": self.bucket_name} + ) + file_stream = io.BytesIO(cached_content) + file_stream.name = file_name + return file_stream + + log_structured_entry( + "Starting file download from GCS", + "INFO", + {"file": file_name, "bucket": self.bucket_name} + ) + + storage_client = self._get_aio_storage() + last_exception: Exception | None = None + + for attempt in range(max_retries): + try: + content = await storage_client.download( + self.bucket_name, + file_name, + ) + self._cache.put(file_name, content) + file_stream = io.BytesIO(content) + file_stream.name = file_name + log_structured_entry( + "File downloaded successfully", + "INFO", + { + "file": file_name, + "bucket": self.bucket_name, + "size_bytes": len(content), + "attempt": attempt + 1 + } + ) + except TimeoutError as exc: + last_exception = exc + log_structured_entry( + f"Timeout downloading gs://{self.bucket_name}/{file_name} (attempt {attempt + 1}/{max_retries})", + "WARNING", + {"error": str(exc)} + ) + except aiohttp.ClientResponseError as exc: + last_exception = exc + if ( + exc.status == HTTP_TOO_MANY_REQUESTS + or exc.status >= HTTP_SERVER_ERROR + ): + log_structured_entry( + f"HTTP {exc.status} downloading gs://{self.bucket_name}/{file_name} (attempt {attempt + 1}/{max_retries})", + "WARNING", + {"status": exc.status, "message": str(exc)} + ) + else: + log_structured_entry( + f"Non-retryable HTTP error downloading gs://{self.bucket_name}/{file_name}", + "ERROR", + {"status": exc.status, "message": str(exc)} + ) + raise + else: + return file_stream + + if attempt < max_retries - 1: + delay = 0.5 * (2**attempt) + log_structured_entry( + "Retrying file download", + "INFO", + {"file": file_name, "delay_seconds": delay} + ) + await asyncio.sleep(delay) + + msg = ( + f"Failed to download gs://{self.bucket_name}/{file_name} " + f"after {max_retries} attempts" + ) + log_structured_entry( + "File download failed after all retries", + "ERROR", + { + "file": file_name, + "bucket": self.bucket_name, + "max_retries": max_retries, + "last_error": str(last_exception) + } + ) + raise TimeoutError(msg) from last_exception diff --git a/src/knowledge_search_mcp/clients/vector_search.py b/src/knowledge_search_mcp/clients/vector_search.py new file mode 100644 index 0000000..bd80585 --- /dev/null +++ b/src/knowledge_search_mcp/clients/vector_search.py @@ -0,0 +1,226 @@ +# ruff: noqa: INP001 +"""Google Cloud Vector Search client.""" + +import asyncio +from collections.abc import Sequence + +from gcloud.aio.auth import Token + +from ..logging import log_structured_entry +from ..models import SearchResult, SourceNamespace +from .base import BaseGoogleCloudClient +from .storage import GoogleCloudFileStorage + + +class GoogleCloudVectorSearch(BaseGoogleCloudClient): + """Minimal async client for the Vertex AI Matching Engine REST API.""" + + def __init__( + self, + project_id: str, + location: str, + bucket: str, + index_name: str | None = None, + ) -> None: + """Store configuration used to issue Matching Engine queries.""" + super().__init__() + self.project_id = project_id + self.location = location + self.storage = GoogleCloudFileStorage(bucket=bucket) + self.index_name = index_name + self._async_token: Token | None = None + self._endpoint_domain: str | None = None + self._endpoint_name: str | None = None + + async def _async_get_auth_headers(self) -> dict[str, str]: + if self._async_token is None: + self._async_token = Token( + session=self._get_aio_session(), + scopes=[ + "https://www.googleapis.com/auth/cloud-platform", + ], + ) + access_token = await self._async_token.get() + return { + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + } + + async def close(self) -> None: + """Close aiohttp sessions for both vector search and storage.""" + await super().close() + await self.storage.close() + + def configure_index_endpoint( + self, + *, + name: str, + public_domain: str, + ) -> None: + """Persist the metadata needed to access a deployed endpoint.""" + if not name: + msg = "Index endpoint name must be a non-empty string." + raise ValueError(msg) + if not public_domain: + msg = "Index endpoint domain must be a non-empty public domain." + raise ValueError(msg) + self._endpoint_name = name + self._endpoint_domain = public_domain + + async def async_run_query( + self, + deployed_index_id: str, + query: Sequence[float], + limit: int, + source: SourceNamespace | None = None, + ) -> list[SearchResult]: + """Run an async similarity search via the REST API. + + Args: + deployed_index_id: The ID of the deployed index. + query: The embedding vector for the search query. + limit: Maximum number of nearest neighbors to return. + source: Optional namespace filter to restrict results by source. + + Returns: + A list of matched items with id, distance, and content. + + """ + if self._endpoint_domain is None or self._endpoint_name is None: + msg = ( + "Missing endpoint metadata. Call " + "`configure_index_endpoint` before querying." + ) + log_structured_entry( + "Vector search query failed - endpoint not configured", + "ERROR", + {"error": msg} + ) + raise RuntimeError(msg) + + domain = self._endpoint_domain + endpoint_id = self._endpoint_name.split("/")[-1] + url = ( + f"https://{domain}/v1/projects/{self.project_id}" + f"/locations/{self.location}" + f"/indexEndpoints/{endpoint_id}:findNeighbors" + ) + + log_structured_entry( + "Starting vector search query", + "INFO", + { + "deployed_index_id": deployed_index_id, + "neighbor_count": limit, + "endpoint_id": endpoint_id, + "embedding_dimension": len(query) + } + ) + + datapoint: dict = {"feature_vector": list(query)} + if source is not None: + datapoint["restricts"] = [ + {"namespace": "source", "allow_list": [source.value]}, + ] + payload = { + "deployed_index_id": deployed_index_id, + "queries": [ + { + "datapoint": datapoint, + "neighbor_count": limit, + }, + ], + } + + try: + headers = await self._async_get_auth_headers() + session = self._get_aio_session() + async with session.post( + url, + json=payload, + headers=headers, + ) as response: + if not response.ok: + body = await response.text() + msg = f"findNeighbors returned {response.status}: {body}" + log_structured_entry( + "Vector search API request failed", + "ERROR", + { + "status": response.status, + "response_body": body, + "deployed_index_id": deployed_index_id + } + ) + raise RuntimeError(msg) + data = await response.json() + + neighbors = data.get("nearestNeighbors", [{}])[0].get("neighbors", []) + log_structured_entry( + "Vector search API request successful", + "INFO", + { + "neighbors_found": len(neighbors), + "deployed_index_id": deployed_index_id + } + ) + + if not neighbors: + log_structured_entry( + "No neighbors found in vector search", + "WARNING", + {"deployed_index_id": deployed_index_id} + ) + return [] + + # Fetch content for all neighbors + content_tasks = [] + for neighbor in neighbors: + datapoint_id = neighbor["datapoint"]["datapointId"] + file_path = f"{self.index_name}/contents/{datapoint_id}.md" + content_tasks.append( + self.storage.async_get_file_stream(file_path), + ) + + log_structured_entry( + "Fetching content for search results", + "INFO", + {"file_count": len(content_tasks)} + ) + + file_streams = await asyncio.gather(*content_tasks) + results: list[SearchResult] = [] + for neighbor, stream in zip( + neighbors, + file_streams, + strict=True, + ): + results.append( + SearchResult( + id=neighbor["datapoint"]["datapointId"], + distance=neighbor["distance"], + content=stream.read().decode("utf-8"), + ), + ) + + log_structured_entry( + "Vector search completed successfully", + "INFO", + { + "results_count": len(results), + "deployed_index_id": deployed_index_id + } + ) + return results + + except Exception as e: + log_structured_entry( + "Vector search query failed with exception", + "ERROR", + { + "error": str(e), + "error_type": type(e).__name__, + "deployed_index_id": deployed_index_id + } + ) + raise diff --git a/src/knowledge_search_mcp/main.py b/src/knowledge_search_mcp/main.py index ccb3a05..4c98112 100644 --- a/src/knowledge_search_mcp/main.py +++ b/src/knowledge_search_mcp/main.py @@ -1,729 +1,15 @@ # ruff: noqa: INP001 -"""Async helpers for querying Vertex AI vector search via MCP.""" +"""MCP server for semantic search over Vertex AI Vector Search.""" -import asyncio -import io import time -from collections import OrderedDict -from collections.abc import AsyncIterator, Sequence -from contextlib import asynccontextmanager -from dataclasses import dataclass -from enum import Enum -from typing import BinaryIO, TypedDict -import aiohttp -from gcloud.aio.auth import Token -from gcloud.aio.storage import Storage -from google import genai -from google.genai import types as genai_types from mcp.server.fastmcp import Context, FastMCP -from .config import Settings, _args, cfg +from .config import _args from .logging import log_structured_entry - -HTTP_TOO_MANY_REQUESTS = 429 -HTTP_SERVER_ERROR = 500 - - -class LRUCache: - """Simple LRU cache with size limit.""" - - def __init__(self, max_size: int = 100) -> None: - """Initialize cache with maximum size.""" - self.cache: OrderedDict[str, bytes] = OrderedDict() - self.max_size = max_size - - def get(self, key: str) -> bytes | None: - """Get item from cache, returning None if not found.""" - if key not in self.cache: - return None - # Move to end to mark as recently used - self.cache.move_to_end(key) - return self.cache[key] - - def put(self, key: str, value: bytes) -> None: - """Put item in cache, evicting oldest if at capacity.""" - if key in self.cache: - self.cache.move_to_end(key) - self.cache[key] = value - if len(self.cache) > self.max_size: - self.cache.popitem(last=False) - - def __contains__(self, key: str) -> bool: - """Check if key exists in cache.""" - return key in self.cache - - -class BaseGoogleCloudClient: - """Base class with shared aiohttp session management.""" - - def __init__(self) -> None: - """Initialize session tracking.""" - self._aio_session: aiohttp.ClientSession | None = None - - def _get_aio_session(self) -> aiohttp.ClientSession: - """Get or create aiohttp session with connection pooling.""" - if self._aio_session is None or self._aio_session.closed: - connector = aiohttp.TCPConnector( - limit=300, - limit_per_host=50, - ) - timeout = aiohttp.ClientTimeout(total=60) - self._aio_session = aiohttp.ClientSession( - timeout=timeout, - connector=connector, - ) - return self._aio_session - - async def close(self) -> None: - """Close aiohttp session if open.""" - if self._aio_session and not self._aio_session.closed: - await self._aio_session.close() - - -class SourceNamespace(str, Enum): - """Allowed values for the 'source' namespace filter.""" - - EDUCACION_FINANCIERA = "Educacion Financiera" - PRODUCTOS_Y_SERVICIOS = "Productos y Servicios" - FUNCIONALIDADES_APP_MOVIL = "Funcionalidades de la App Movil" - - -class GoogleCloudFileStorage(BaseGoogleCloudClient): - """Cache-aware helper for downloading files from Google Cloud Storage.""" - - def __init__(self, bucket: str, cache_size: int = 100) -> None: - """Initialize the storage helper with LRU cache.""" - super().__init__() - self.bucket_name = bucket - self._aio_storage: Storage | None = None - self._cache = LRUCache(max_size=cache_size) - - def _get_aio_storage(self) -> Storage: - if self._aio_storage is None: - self._aio_storage = Storage( - session=self._get_aio_session(), - ) - return self._aio_storage - - async def async_get_file_stream( - self, - file_name: str, - max_retries: int = 3, - ) -> BinaryIO: - """Get a file asynchronously with retry on transient errors. - - Args: - file_name: The blob name to retrieve. - max_retries: Maximum number of retry attempts. - - Returns: - A BytesIO stream with the file contents. - - Raises: - TimeoutError: If all retry attempts fail. - - """ - cached_content = self._cache.get(file_name) - if cached_content is not None: - log_structured_entry( - "File retrieved from cache", - "INFO", - {"file": file_name, "bucket": self.bucket_name} - ) - file_stream = io.BytesIO(cached_content) - file_stream.name = file_name - return file_stream - - log_structured_entry( - "Starting file download from GCS", - "INFO", - {"file": file_name, "bucket": self.bucket_name} - ) - - storage_client = self._get_aio_storage() - last_exception: Exception | None = None - - for attempt in range(max_retries): - try: - content = await storage_client.download( - self.bucket_name, - file_name, - ) - self._cache.put(file_name, content) - file_stream = io.BytesIO(content) - file_stream.name = file_name - log_structured_entry( - "File downloaded successfully", - "INFO", - { - "file": file_name, - "bucket": self.bucket_name, - "size_bytes": len(content), - "attempt": attempt + 1 - } - ) - except TimeoutError as exc: - last_exception = exc - log_structured_entry( - f"Timeout downloading gs://{self.bucket_name}/{file_name} (attempt {attempt + 1}/{max_retries})", - "WARNING", - {"error": str(exc)} - ) - except aiohttp.ClientResponseError as exc: - last_exception = exc - if ( - exc.status == HTTP_TOO_MANY_REQUESTS - or exc.status >= HTTP_SERVER_ERROR - ): - log_structured_entry( - f"HTTP {exc.status} downloading gs://{self.bucket_name}/{file_name} (attempt {attempt + 1}/{max_retries})", - "WARNING", - {"status": exc.status, "message": str(exc)} - ) - else: - log_structured_entry( - f"Non-retryable HTTP error downloading gs://{self.bucket_name}/{file_name}", - "ERROR", - {"status": exc.status, "message": str(exc)} - ) - raise - else: - return file_stream - - if attempt < max_retries - 1: - delay = 0.5 * (2**attempt) - log_structured_entry( - "Retrying file download", - "INFO", - {"file": file_name, "delay_seconds": delay} - ) - await asyncio.sleep(delay) - - msg = ( - f"Failed to download gs://{self.bucket_name}/{file_name} " - f"after {max_retries} attempts" - ) - log_structured_entry( - "File download failed after all retries", - "ERROR", - { - "file": file_name, - "bucket": self.bucket_name, - "max_retries": max_retries, - "last_error": str(last_exception) - } - ) - raise TimeoutError(msg) from last_exception - - -class SearchResult(TypedDict): - """Structured response item returned by the vector search API.""" - - id: str - distance: float - content: str - - -class GoogleCloudVectorSearch(BaseGoogleCloudClient): - """Minimal async client for the Vertex AI Matching Engine REST API.""" - - def __init__( - self, - project_id: str, - location: str, - bucket: str, - index_name: str | None = None, - ) -> None: - """Store configuration used to issue Matching Engine queries.""" - super().__init__() - self.project_id = project_id - self.location = location - self.storage = GoogleCloudFileStorage(bucket=bucket) - self.index_name = index_name - self._async_token: Token | None = None - self._endpoint_domain: str | None = None - self._endpoint_name: str | None = None - - async def _async_get_auth_headers(self) -> dict[str, str]: - if self._async_token is None: - self._async_token = Token( - session=self._get_aio_session(), - scopes=[ - "https://www.googleapis.com/auth/cloud-platform", - ], - ) - access_token = await self._async_token.get() - return { - "Authorization": f"Bearer {access_token}", - "Content-Type": "application/json", - } - - async def close(self) -> None: - """Close aiohttp sessions for both vector search and storage.""" - await super().close() - await self.storage.close() - - def configure_index_endpoint( - self, - *, - name: str, - public_domain: str, - ) -> None: - """Persist the metadata needed to access a deployed endpoint.""" - if not name: - msg = "Index endpoint name must be a non-empty string." - raise ValueError(msg) - if not public_domain: - msg = "Index endpoint domain must be a non-empty public domain." - raise ValueError(msg) - self._endpoint_name = name - self._endpoint_domain = public_domain - - async def async_run_query( - self, - deployed_index_id: str, - query: Sequence[float], - limit: int, - source: SourceNamespace | None = None, - ) -> list[SearchResult]: - """Run an async similarity search via the REST API. - - Args: - deployed_index_id: The ID of the deployed index. - query: The embedding vector for the search query. - limit: Maximum number of nearest neighbors to return. - source: Optional namespace filter to restrict results by source. - - Returns: - A list of matched items with id, distance, and content. - - """ - if self._endpoint_domain is None or self._endpoint_name is None: - msg = ( - "Missing endpoint metadata. Call " - "`configure_index_endpoint` before querying." - ) - log_structured_entry( - "Vector search query failed - endpoint not configured", - "ERROR", - {"error": msg} - ) - raise RuntimeError(msg) - - domain = self._endpoint_domain - endpoint_id = self._endpoint_name.split("/")[-1] - url = ( - f"https://{domain}/v1/projects/{self.project_id}" - f"/locations/{self.location}" - f"/indexEndpoints/{endpoint_id}:findNeighbors" - ) - - log_structured_entry( - "Starting vector search query", - "INFO", - { - "deployed_index_id": deployed_index_id, - "neighbor_count": limit, - "endpoint_id": endpoint_id, - "embedding_dimension": len(query) - } - ) - - datapoint: dict = {"feature_vector": list(query)} - if source is not None: - datapoint["restricts"] = [ - {"namespace": "source", "allow_list": [source.value]}, - ] - payload = { - "deployed_index_id": deployed_index_id, - "queries": [ - { - "datapoint": datapoint, - "neighbor_count": limit, - }, - ], - } - - try: - headers = await self._async_get_auth_headers() - session = self._get_aio_session() - async with session.post( - url, - json=payload, - headers=headers, - ) as response: - if not response.ok: - body = await response.text() - msg = f"findNeighbors returned {response.status}: {body}" - log_structured_entry( - "Vector search API request failed", - "ERROR", - { - "status": response.status, - "response_body": body, - "deployed_index_id": deployed_index_id - } - ) - raise RuntimeError(msg) - data = await response.json() - - neighbors = data.get("nearestNeighbors", [{}])[0].get("neighbors", []) - log_structured_entry( - "Vector search API request successful", - "INFO", - { - "neighbors_found": len(neighbors), - "deployed_index_id": deployed_index_id - } - ) - - if not neighbors: - log_structured_entry( - "No neighbors found in vector search", - "WARNING", - {"deployed_index_id": deployed_index_id} - ) - return [] - - # Fetch content for all neighbors - content_tasks = [] - for neighbor in neighbors: - datapoint_id = neighbor["datapoint"]["datapointId"] - file_path = f"{self.index_name}/contents/{datapoint_id}.md" - content_tasks.append( - self.storage.async_get_file_stream(file_path), - ) - - log_structured_entry( - "Fetching content for search results", - "INFO", - {"file_count": len(content_tasks)} - ) - - file_streams = await asyncio.gather(*content_tasks) - results: list[SearchResult] = [] - for neighbor, stream in zip( - neighbors, - file_streams, - strict=True, - ): - results.append( - SearchResult( - id=neighbor["datapoint"]["datapointId"], - distance=neighbor["distance"], - content=stream.read().decode("utf-8"), - ), - ) - - log_structured_entry( - "Vector search completed successfully", - "INFO", - { - "results_count": len(results), - "deployed_index_id": deployed_index_id - } - ) - return results - - except Exception as e: - log_structured_entry( - "Vector search query failed with exception", - "ERROR", - { - "error": str(e), - "error_type": type(e).__name__, - "deployed_index_id": deployed_index_id - } - ) - raise - - -# --------------------------------------------------------------------------- -# MCP Server -# --------------------------------------------------------------------------- - - -@dataclass -class AppContext: - """Shared resources initialised once at server startup.""" - - vector_search: GoogleCloudVectorSearch - genai_client: genai.Client - settings: Settings - - -async def _validate_genai_access(genai_client: genai.Client, cfg: Settings) -> str | None: - """Validate GenAI embedding access. - - Returns: - Error message if validation fails, None if successful. - """ - log_structured_entry("Validating GenAI embedding access", "INFO") - try: - test_response = await genai_client.aio.models.embed_content( - model=cfg.embedding_model, - contents="test", - config=genai_types.EmbedContentConfig( - task_type="RETRIEVAL_QUERY", - ), - ) - if test_response and test_response.embeddings: - embedding_values = test_response.embeddings[0].values - log_structured_entry( - "GenAI embedding validation successful", - "INFO", - {"embedding_dimension": len(embedding_values) if embedding_values else 0} - ) - return None - else: - msg = "Embedding validation returned empty response" - log_structured_entry(msg, "WARNING") - return msg - except Exception as e: - log_structured_entry( - "Failed to validate GenAI embedding access - service may not work correctly", - "WARNING", - {"error": str(e), "error_type": type(e).__name__} - ) - return f"GenAI: {str(e)}" - - -async def _validate_gcs_access(vs: GoogleCloudVectorSearch, cfg: Settings) -> str | None: - """Validate GCS bucket access. - - Returns: - Error message if validation fails, None if successful. - """ - log_structured_entry( - "Validating GCS bucket access", - "INFO", - {"bucket": cfg.bucket} - ) - try: - session = vs.storage._get_aio_session() - token_obj = Token( - session=session, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) - access_token = await token_obj.get() - headers = {"Authorization": f"Bearer {access_token}"} - - async with session.get( - f"https://storage.googleapis.com/storage/v1/b/{cfg.bucket}/o?maxResults=1", - headers=headers, - ) as response: - if response.status == 403: - msg = f"Access denied to bucket '{cfg.bucket}'. Check permissions." - log_structured_entry( - "GCS bucket validation failed - access denied - service may not work correctly", - "WARNING", - {"bucket": cfg.bucket, "status": response.status} - ) - return msg - elif response.status == 404: - msg = f"Bucket '{cfg.bucket}' not found. Check bucket name and project." - log_structured_entry( - "GCS bucket validation failed - not found - service may not work correctly", - "WARNING", - {"bucket": cfg.bucket, "status": response.status} - ) - return msg - elif not response.ok: - body = await response.text() - msg = f"Failed to access bucket '{cfg.bucket}': {response.status}" - log_structured_entry( - "GCS bucket validation failed - service may not work correctly", - "WARNING", - {"bucket": cfg.bucket, "status": response.status, "response": body} - ) - return msg - else: - log_structured_entry( - "GCS bucket validation successful", - "INFO", - {"bucket": cfg.bucket} - ) - return None - except Exception as e: - log_structured_entry( - "Failed to validate GCS bucket access - service may not work correctly", - "WARNING", - {"error": str(e), "error_type": type(e).__name__, "bucket": cfg.bucket} - ) - return f"GCS: {str(e)}" - - -async def _validate_vector_search_access(vs: GoogleCloudVectorSearch, cfg: Settings) -> str | None: - """Validate vector search endpoint access. - - Returns: - Error message if validation fails, None if successful. - """ - log_structured_entry( - "Validating vector search endpoint access", - "INFO", - {"endpoint_name": cfg.endpoint_name} - ) - try: - headers = await vs._async_get_auth_headers() - session = vs._get_aio_session() - endpoint_url = ( - f"https://{cfg.location}-aiplatform.googleapis.com/v1/{cfg.endpoint_name}" - ) - - async with session.get(endpoint_url, headers=headers) as response: - if response.status == 403: - msg = f"Access denied to endpoint '{cfg.endpoint_name}'. Check permissions." - log_structured_entry( - "Vector search endpoint validation failed - access denied - service may not work correctly", - "WARNING", - {"endpoint": cfg.endpoint_name, "status": response.status} - ) - return msg - elif response.status == 404: - msg = f"Endpoint '{cfg.endpoint_name}' not found. Check endpoint name and project." - log_structured_entry( - "Vector search endpoint validation failed - not found - service may not work correctly", - "WARNING", - {"endpoint": cfg.endpoint_name, "status": response.status} - ) - return msg - elif not response.ok: - body = await response.text() - msg = f"Failed to access endpoint '{cfg.endpoint_name}': {response.status}" - log_structured_entry( - "Vector search endpoint validation failed - service may not work correctly", - "WARNING", - {"endpoint": cfg.endpoint_name, "status": response.status, "response": body} - ) - return msg - else: - log_structured_entry( - "Vector search endpoint validation successful", - "INFO", - {"endpoint": cfg.endpoint_name} - ) - return None - except Exception as e: - log_structured_entry( - "Failed to validate vector search endpoint access - service may not work correctly", - "WARNING", - {"error": str(e), "error_type": type(e).__name__, "endpoint": cfg.endpoint_name} - ) - return f"Vector Search: {str(e)}" - - -@asynccontextmanager -async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]: - """Create and configure the vector-search client for the server lifetime.""" - log_structured_entry( - "Initializing MCP server", - "INFO", - { - "project_id": cfg.project_id, - "location": cfg.location, - "bucket": cfg.bucket, - "index_name": cfg.index_name, - } - ) - - vs: GoogleCloudVectorSearch | None = None - try: - # Initialize vector search client - log_structured_entry("Creating GoogleCloudVectorSearch client", "INFO") - vs = GoogleCloudVectorSearch( - project_id=cfg.project_id, - location=cfg.location, - bucket=cfg.bucket, - index_name=cfg.index_name, - ) - - # Configure endpoint - log_structured_entry( - "Configuring index endpoint", - "INFO", - { - "endpoint_name": cfg.endpoint_name, - "endpoint_domain": cfg.endpoint_domain, - } - ) - vs.configure_index_endpoint( - name=cfg.endpoint_name, - public_domain=cfg.endpoint_domain, - ) - - # Initialize GenAI client - log_structured_entry( - "Creating GenAI client", - "INFO", - {"project_id": cfg.project_id, "location": cfg.location} - ) - genai_client = genai.Client( - vertexai=True, - project=cfg.project_id, - location=cfg.location, - ) - - # Validate credentials and configuration by testing actual resources - # These validations are non-blocking - errors are logged but won't stop startup - log_structured_entry("Starting validation of credentials and resources", "INFO") - - validation_errors = [] - - # Run all validations - genai_error = await _validate_genai_access(genai_client, cfg) - if genai_error: - validation_errors.append(genai_error) - - gcs_error = await _validate_gcs_access(vs, cfg) - if gcs_error: - validation_errors.append(gcs_error) - - vs_error = await _validate_vector_search_access(vs, cfg) - if vs_error: - validation_errors.append(vs_error) - - # Summary of validations - if validation_errors: - log_structured_entry( - "MCP server started with validation errors - service may not work correctly", - "WARNING", - {"validation_errors": validation_errors, "error_count": len(validation_errors)} - ) - else: - log_structured_entry("All validations passed - MCP server initialization complete", "INFO") - - yield AppContext( - vector_search=vs, - genai_client=genai_client, - settings=cfg, - ) - - except Exception as e: - log_structured_entry( - "Failed to initialize MCP server", - "ERROR", - { - "error": str(e), - "error_type": type(e).__name__, - } - ) - raise - finally: - log_structured_entry("MCP server lifespan ending", "INFO") - # Clean up resources - if vs is not None: - try: - await vs.close() - log_structured_entry("Closed aiohttp sessions", "INFO") - except Exception as e: - log_structured_entry( - "Error closing aiohttp sessions", - "WARNING", - {"error": str(e), "error_type": type(e).__name__} - ) - +from .models import AppContext, SourceNamespace +from .server import lifespan +from .services.search import filter_search_results, format_search_results, generate_query_embedding mcp = FastMCP( "knowledge-search", @@ -733,108 +19,6 @@ mcp = FastMCP( ) -async def _generate_query_embedding( - genai_client: genai.Client, - embedding_model: str, - query: str, -) -> tuple[list[float], str | None]: - """Generate embedding for search query. - - Returns: - Tuple of (embedding vector, error message). Error message is None on success. - """ - if not query or not query.strip(): - return ([], "Error: Query cannot be empty") - - log_structured_entry("Generating query embedding", "INFO") - try: - response = await genai_client.aio.models.embed_content( - model=embedding_model, - contents=query, - config=genai_types.EmbedContentConfig( - task_type="RETRIEVAL_QUERY", - ), - ) - embedding = response.embeddings[0].values - return (embedding, None) - except Exception as e: - error_type = type(e).__name__ - error_msg = str(e) - - # Check if it's a rate limit error - if "429" in error_msg or "RESOURCE_EXHAUSTED" in error_msg: - log_structured_entry( - "Rate limit exceeded while generating embedding", - "WARNING", - { - "error": error_msg, - "error_type": error_type, - "query": query[:100] - } - ) - return ([], "Error: API rate limit exceeded. Please try again later.") - else: - log_structured_entry( - "Failed to generate query embedding", - "ERROR", - { - "error": error_msg, - "error_type": error_type, - "query": query[:100] - } - ) - return ([], f"Error generating embedding: {error_msg}") - - -def _filter_search_results( - results: list[SearchResult], - min_similarity: float = 0.6, - top_percent: float = 0.9, -) -> list[SearchResult]: - """Filter search results by similarity thresholds. - - Args: - results: Raw search results from vector search. - min_similarity: Minimum similarity score (distance) to include. - top_percent: Keep results within this percentage of the top score. - - Returns: - Filtered list of search results. - """ - if not results: - return [] - - max_sim = max(r["distance"] for r in results) - cutoff = max_sim * top_percent - - filtered = [ - s - for s in results - if s["distance"] > cutoff and s["distance"] > min_similarity - ] - - return filtered - - -def _format_search_results(results: list[SearchResult]) -> str: - """Format search results as XML-like documents. - - Args: - results: List of search results to format. - - Returns: - Formatted string with document tags. - """ - if not results: - return "No relevant documents found for your query." - - formatted_results = [ - f"\n{result['content']}\n" - for i, result in enumerate(results, start=1) - ] - return "\n".join(formatted_results) - - @mcp.tool() async def knowledge_search( query: str, @@ -865,7 +49,7 @@ async def knowledge_search( try: # Generate embedding for the query - embedding, error = await _generate_query_embedding( + embedding, error = await generate_query_embedding( app.genai_client, app.settings.embedding_model, query, @@ -903,7 +87,7 @@ async def knowledge_search( return f"Error performing vector search: {str(e)}" # Apply similarity filtering - filtered_results = _filter_search_results(search_results) + filtered_results = filter_search_results(search_results) log_structured_entry( "knowledge_search completed successfully", @@ -926,7 +110,7 @@ async def knowledge_search( {"query": query[:100]} ) - return _format_search_results(filtered_results) + return format_search_results(filtered_results) except Exception as e: # Catch-all for any unexpected errors diff --git a/src/knowledge_search_mcp/models.py b/src/knowledge_search_mcp/models.py new file mode 100644 index 0000000..37e412e --- /dev/null +++ b/src/knowledge_search_mcp/models.py @@ -0,0 +1,37 @@ +# ruff: noqa: INP001 +"""Domain models for knowledge search MCP server.""" + +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, TypedDict + +if TYPE_CHECKING: + from google import genai + + from .clients.vector_search import GoogleCloudVectorSearch + from .config import Settings + + +class SourceNamespace(str, Enum): + """Allowed values for the 'source' namespace filter.""" + + EDUCACION_FINANCIERA = "Educacion Financiera" + PRODUCTOS_Y_SERVICIOS = "Productos y Servicios" + FUNCIONALIDADES_APP_MOVIL = "Funcionalidades de la App Movil" + + +class SearchResult(TypedDict): + """Structured response item returned by the vector search API.""" + + id: str + distance: float + content: str + + +@dataclass +class AppContext: + """Shared resources initialised once at server startup.""" + + vector_search: "GoogleCloudVectorSearch" + genai_client: "genai.Client" + settings: "Settings" diff --git a/src/knowledge_search_mcp/server.py b/src/knowledge_search_mcp/server.py new file mode 100644 index 0000000..ca7591e --- /dev/null +++ b/src/knowledge_search_mcp/server.py @@ -0,0 +1,129 @@ +# ruff: noqa: INP001 +"""MCP server lifecycle management.""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + +from google import genai +from mcp.server.fastmcp import FastMCP + +from .clients.vector_search import GoogleCloudVectorSearch +from .config import Settings, cfg +from .logging import log_structured_entry +from .models import AppContext +from .services.validation import ( + validate_genai_access, + validate_gcs_access, + validate_vector_search_access, +) + + +@asynccontextmanager +async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]: + """Create and configure the vector-search client for the server lifetime.""" + log_structured_entry( + "Initializing MCP server", + "INFO", + { + "project_id": cfg.project_id, + "location": cfg.location, + "bucket": cfg.bucket, + "index_name": cfg.index_name, + } + ) + + vs: GoogleCloudVectorSearch | None = None + try: + # Initialize vector search client + log_structured_entry("Creating GoogleCloudVectorSearch client", "INFO") + vs = GoogleCloudVectorSearch( + project_id=cfg.project_id, + location=cfg.location, + bucket=cfg.bucket, + index_name=cfg.index_name, + ) + + # Configure endpoint + log_structured_entry( + "Configuring index endpoint", + "INFO", + { + "endpoint_name": cfg.endpoint_name, + "endpoint_domain": cfg.endpoint_domain, + } + ) + vs.configure_index_endpoint( + name=cfg.endpoint_name, + public_domain=cfg.endpoint_domain, + ) + + # Initialize GenAI client + log_structured_entry( + "Creating GenAI client", + "INFO", + {"project_id": cfg.project_id, "location": cfg.location} + ) + genai_client = genai.Client( + vertexai=True, + project=cfg.project_id, + location=cfg.location, + ) + + # Validate credentials and configuration by testing actual resources + # These validations are non-blocking - errors are logged but won't stop startup + log_structured_entry("Starting validation of credentials and resources", "INFO") + + validation_errors = [] + + # Run all validations + genai_error = await validate_genai_access(genai_client, cfg) + if genai_error: + validation_errors.append(genai_error) + + gcs_error = await validate_gcs_access(vs, cfg) + if gcs_error: + validation_errors.append(gcs_error) + + vs_error = await validate_vector_search_access(vs, cfg) + if vs_error: + validation_errors.append(vs_error) + + # Summary of validations + if validation_errors: + log_structured_entry( + "MCP server started with validation errors - service may not work correctly", + "WARNING", + {"validation_errors": validation_errors, "error_count": len(validation_errors)} + ) + else: + log_structured_entry("All validations passed - MCP server initialization complete", "INFO") + + yield AppContext( + vector_search=vs, + genai_client=genai_client, + settings=cfg, + ) + + except Exception as e: + log_structured_entry( + "Failed to initialize MCP server", + "ERROR", + { + "error": str(e), + "error_type": type(e).__name__, + } + ) + raise + finally: + log_structured_entry("MCP server lifespan ending", "INFO") + # Clean up resources + if vs is not None: + try: + await vs.close() + log_structured_entry("Closed aiohttp sessions", "INFO") + except Exception as e: + log_structured_entry( + "Error closing aiohttp sessions", + "WARNING", + {"error": str(e), "error_type": type(e).__name__} + ) diff --git a/src/knowledge_search_mcp/services/__init__.py b/src/knowledge_search_mcp/services/__init__.py new file mode 100644 index 0000000..6ea8345 --- /dev/null +++ b/src/knowledge_search_mcp/services/__init__.py @@ -0,0 +1,13 @@ +"""Service modules for business logic.""" + +from .search import filter_search_results, format_search_results, generate_query_embedding +from .validation import validate_genai_access, validate_gcs_access, validate_vector_search_access + +__all__ = [ + "filter_search_results", + "format_search_results", + "generate_query_embedding", + "validate_genai_access", + "validate_gcs_access", + "validate_vector_search_access", +] diff --git a/src/knowledge_search_mcp/services/search.py b/src/knowledge_search_mcp/services/search.py new file mode 100644 index 0000000..b33dd9e --- /dev/null +++ b/src/knowledge_search_mcp/services/search.py @@ -0,0 +1,110 @@ +# ruff: noqa: INP001 +"""Search helper functions.""" + +from google import genai +from google.genai import types as genai_types + +from ..logging import log_structured_entry +from ..models import SearchResult + + +async def generate_query_embedding( + genai_client: genai.Client, + embedding_model: str, + query: str, +) -> tuple[list[float], str | None]: + """Generate embedding for search query. + + Returns: + Tuple of (embedding vector, error message). Error message is None on success. + """ + if not query or not query.strip(): + return ([], "Error: Query cannot be empty") + + log_structured_entry("Generating query embedding", "INFO") + try: + response = await genai_client.aio.models.embed_content( + model=embedding_model, + contents=query, + config=genai_types.EmbedContentConfig( + task_type="RETRIEVAL_QUERY", + ), + ) + embedding = response.embeddings[0].values + return (embedding, None) + except Exception as e: + error_type = type(e).__name__ + error_msg = str(e) + + # Check if it's a rate limit error + if "429" in error_msg or "RESOURCE_EXHAUSTED" in error_msg: + log_structured_entry( + "Rate limit exceeded while generating embedding", + "WARNING", + { + "error": error_msg, + "error_type": error_type, + "query": query[:100] + } + ) + return ([], "Error: API rate limit exceeded. Please try again later.") + else: + log_structured_entry( + "Failed to generate query embedding", + "ERROR", + { + "error": error_msg, + "error_type": error_type, + "query": query[:100] + } + ) + return ([], f"Error generating embedding: {error_msg}") + + +def filter_search_results( + results: list[SearchResult], + min_similarity: float = 0.6, + top_percent: float = 0.9, +) -> list[SearchResult]: + """Filter search results by similarity thresholds. + + Args: + results: Raw search results from vector search. + min_similarity: Minimum similarity score (distance) to include. + top_percent: Keep results within this percentage of the top score. + + Returns: + Filtered list of search results. + """ + if not results: + return [] + + max_sim = max(r["distance"] for r in results) + cutoff = max_sim * top_percent + + filtered = [ + s + for s in results + if s["distance"] > cutoff and s["distance"] > min_similarity + ] + + return filtered + + +def format_search_results(results: list[SearchResult]) -> str: + """Format search results as XML-like documents. + + Args: + results: List of search results to format. + + Returns: + Formatted string with document tags. + """ + if not results: + return "No relevant documents found for your query." + + formatted_results = [ + f"\n{result['content']}\n" + for i, result in enumerate(results, start=1) + ] + return "\n".join(formatted_results) diff --git a/src/knowledge_search_mcp/services/validation.py b/src/knowledge_search_mcp/services/validation.py new file mode 100644 index 0000000..23cbb7f --- /dev/null +++ b/src/knowledge_search_mcp/services/validation.py @@ -0,0 +1,171 @@ +# ruff: noqa: INP001 +"""Validation functions for Google Cloud services.""" + +from gcloud.aio.auth import Token +from google import genai +from google.genai import types as genai_types + +from ..clients.vector_search import GoogleCloudVectorSearch +from ..config import Settings +from ..logging import log_structured_entry + + +async def validate_genai_access(genai_client: genai.Client, cfg: Settings) -> str | None: + """Validate GenAI embedding access. + + Returns: + Error message if validation fails, None if successful. + """ + log_structured_entry("Validating GenAI embedding access", "INFO") + try: + test_response = await genai_client.aio.models.embed_content( + model=cfg.embedding_model, + contents="test", + config=genai_types.EmbedContentConfig( + task_type="RETRIEVAL_QUERY", + ), + ) + if test_response and test_response.embeddings: + embedding_values = test_response.embeddings[0].values + log_structured_entry( + "GenAI embedding validation successful", + "INFO", + {"embedding_dimension": len(embedding_values) if embedding_values else 0} + ) + return None + else: + msg = "Embedding validation returned empty response" + log_structured_entry(msg, "WARNING") + return msg + except Exception as e: + log_structured_entry( + "Failed to validate GenAI embedding access - service may not work correctly", + "WARNING", + {"error": str(e), "error_type": type(e).__name__} + ) + return f"GenAI: {str(e)}" + + +async def validate_gcs_access(vs: GoogleCloudVectorSearch, cfg: Settings) -> str | None: + """Validate GCS bucket access. + + Returns: + Error message if validation fails, None if successful. + """ + log_structured_entry( + "Validating GCS bucket access", + "INFO", + {"bucket": cfg.bucket} + ) + try: + session = vs.storage._get_aio_session() + token_obj = Token( + session=session, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + access_token = await token_obj.get() + headers = {"Authorization": f"Bearer {access_token}"} + + async with session.get( + f"https://storage.googleapis.com/storage/v1/b/{cfg.bucket}/o?maxResults=1", + headers=headers, + ) as response: + if response.status == 403: + msg = f"Access denied to bucket '{cfg.bucket}'. Check permissions." + log_structured_entry( + "GCS bucket validation failed - access denied - service may not work correctly", + "WARNING", + {"bucket": cfg.bucket, "status": response.status} + ) + return msg + elif response.status == 404: + msg = f"Bucket '{cfg.bucket}' not found. Check bucket name and project." + log_structured_entry( + "GCS bucket validation failed - not found - service may not work correctly", + "WARNING", + {"bucket": cfg.bucket, "status": response.status} + ) + return msg + elif not response.ok: + body = await response.text() + msg = f"Failed to access bucket '{cfg.bucket}': {response.status}" + log_structured_entry( + "GCS bucket validation failed - service may not work correctly", + "WARNING", + {"bucket": cfg.bucket, "status": response.status, "response": body} + ) + return msg + else: + log_structured_entry( + "GCS bucket validation successful", + "INFO", + {"bucket": cfg.bucket} + ) + return None + except Exception as e: + log_structured_entry( + "Failed to validate GCS bucket access - service may not work correctly", + "WARNING", + {"error": str(e), "error_type": type(e).__name__, "bucket": cfg.bucket} + ) + return f"GCS: {str(e)}" + + +async def validate_vector_search_access(vs: GoogleCloudVectorSearch, cfg: Settings) -> str | None: + """Validate vector search endpoint access. + + Returns: + Error message if validation fails, None if successful. + """ + log_structured_entry( + "Validating vector search endpoint access", + "INFO", + {"endpoint_name": cfg.endpoint_name} + ) + try: + headers = await vs._async_get_auth_headers() + session = vs._get_aio_session() + endpoint_url = ( + f"https://{cfg.location}-aiplatform.googleapis.com/v1/{cfg.endpoint_name}" + ) + + async with session.get(endpoint_url, headers=headers) as response: + if response.status == 403: + msg = f"Access denied to endpoint '{cfg.endpoint_name}'. Check permissions." + log_structured_entry( + "Vector search endpoint validation failed - access denied - service may not work correctly", + "WARNING", + {"endpoint": cfg.endpoint_name, "status": response.status} + ) + return msg + elif response.status == 404: + msg = f"Endpoint '{cfg.endpoint_name}' not found. Check endpoint name and project." + log_structured_entry( + "Vector search endpoint validation failed - not found - service may not work correctly", + "WARNING", + {"endpoint": cfg.endpoint_name, "status": response.status} + ) + return msg + elif not response.ok: + body = await response.text() + msg = f"Failed to access endpoint '{cfg.endpoint_name}': {response.status}" + log_structured_entry( + "Vector search endpoint validation failed - service may not work correctly", + "WARNING", + {"endpoint": cfg.endpoint_name, "status": response.status, "response": body} + ) + return msg + else: + log_structured_entry( + "Vector search endpoint validation successful", + "INFO", + {"endpoint": cfg.endpoint_name} + ) + return None + except Exception as e: + log_structured_entry( + "Failed to validate vector search endpoint access - service may not work correctly", + "WARNING", + {"error": str(e), "error_type": type(e).__name__, "endpoint": cfg.endpoint_name} + ) + return f"Vector Search: {str(e)}" diff --git a/src/knowledge_search_mcp/utils/__init__.py b/src/knowledge_search_mcp/utils/__init__.py new file mode 100644 index 0000000..b63ef0f --- /dev/null +++ b/src/knowledge_search_mcp/utils/__init__.py @@ -0,0 +1,5 @@ +"""Utility modules for knowledge search MCP server.""" + +from .cache import LRUCache + +__all__ = ["LRUCache"] diff --git a/src/knowledge_search_mcp/utils/cache.py b/src/knowledge_search_mcp/utils/cache.py new file mode 100644 index 0000000..2235f66 --- /dev/null +++ b/src/knowledge_search_mcp/utils/cache.py @@ -0,0 +1,33 @@ +# ruff: noqa: INP001 +"""LRU cache implementation.""" + +from collections import OrderedDict + + +class LRUCache: + """Simple LRU cache with size limit.""" + + def __init__(self, max_size: int = 100) -> None: + """Initialize cache with maximum size.""" + self.cache: OrderedDict[str, bytes] = OrderedDict() + self.max_size = max_size + + def get(self, key: str) -> bytes | None: + """Get item from cache, returning None if not found.""" + if key not in self.cache: + return None + # Move to end to mark as recently used + self.cache.move_to_end(key) + return self.cache[key] + + def put(self, key: str, value: bytes) -> None: + """Put item in cache, evicting oldest if at capacity.""" + if key in self.cache: + self.cache.move_to_end(key) + self.cache[key] = value + if len(self.cache) > self.max_size: + self.cache.popitem(last=False) + + def __contains__(self, key: str) -> bool: + """Check if key exists in cache.""" + return key in self.cache diff --git a/tests/test_search.py b/tests/test_search.py index ad82b72..ee22f06 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from knowledge_search_mcp.main import ( +from knowledge_search_mcp import ( GoogleCloudFileStorage, GoogleCloudVectorSearch, LRUCache,