Compare commits

..

1 Commit

Author SHA1 Message Date
e81aac2e29 Add semantic caching 2026-03-04 06:17:47 +00:00
17 changed files with 223 additions and 355 deletions

View File

@@ -1,43 +0,0 @@
name: CI
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v6
with:
python-version: "3.12"
- run: uv sync --frozen
- name: Ruff check
run: uv run ruff check
- name: Ruff format check
run: uv run ruff format --check
typecheck:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v6
with:
python-version: "3.12"
- run: uv sync --frozen
- name: Type check
run: uv run ty check
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v6
with:
python-version: "3.12"
- run: uv sync --frozen
- name: Run tests
run: uv run pytest --cov

View File

@@ -1,3 +0,0 @@
Use `uv` for project management
Linter: `uv run ruff check`
Type-checking: `uv run ty check`

View File

@@ -38,19 +38,3 @@ pythonpath = ["."]
[build-system]
requires = ["uv_build>=0.8.3,<0.9.0"]
build-backend = "uv_build"
[tool.ruff]
exclude = ["scripts", "tests"]
[tool.ty.src]
exclude = ["scripts", "tests"]
[tool.ruff.lint]
select = ['ALL']
ignore = [
'D203', # one-blank-line-before-class
'D213', # multi-line-summary-second-line
'COM812', # missing-trailing-comma
'ANN401', # dynamically-typed-any
'ERA001', # commented-out-code
]

View File

@@ -6,10 +6,10 @@ from .models import AppContext, SearchResult, SourceNamespace
from .utils.cache import LRUCache
__all__ = [
"AppContext",
"GoogleCloudFileStorage",
"GoogleCloudVectorSearch",
"LRUCache",
"SearchResult",
"SourceNamespace",
"SearchResult",
"AppContext",
"LRUCache",
]

View File

@@ -1,3 +1,4 @@
# ruff: noqa: INP001
"""MCP server for semantic search over Vertex AI Vector Search."""
import time
@@ -8,11 +9,7 @@ from .config import _args
from .logging import log_structured_entry
from .models import AppContext, SourceNamespace
from .server import lifespan
from .services.search import (
filter_search_results,
format_search_results,
generate_query_embedding,
)
from .services.search import filter_search_results, format_search_results, generate_query_embedding
mcp = FastMCP(
"knowledge-search",
@@ -47,7 +44,7 @@ async def knowledge_search(
log_structured_entry(
"knowledge_search request received",
"INFO",
{"query": query[:100]}, # Log first 100 chars of query
{"query": query[:100]} # Log first 100 chars of query
)
try:
@@ -64,7 +61,7 @@ async def knowledge_search(
log_structured_entry(
"Query embedding generated successfully",
"INFO",
{"time_ms": round((t_embed - t0) * 1000, 1)},
{"time_ms": round((t_embed - t0) * 1000, 1)}
)
# Check semantic cache before vector search
@@ -94,13 +91,17 @@ async def knowledge_search(
source=source,
)
t_search = time.perf_counter()
except Exception as e: # noqa: BLE001
except Exception as e:
log_structured_entry(
"Vector search failed",
"ERROR",
{"error": str(e), "error_type": type(e).__name__, "query": query[:100]},
{
"error": str(e),
"error_type": type(e).__name__,
"query": query[:100]
}
)
return f"Error performing vector search: {e!s}"
return f"Error performing vector search: {str(e)}"
# Apply similarity filtering
filtered_results = filter_search_results(search_results)
@@ -116,7 +117,7 @@ async def knowledge_search(
"results_count": len(filtered_results),
"chunks": [s["id"] for s in filtered_results],
"cache_hit": False,
},
}
)
# Format and return results
@@ -124,7 +125,9 @@ async def knowledge_search(
if not filtered_results:
log_structured_entry(
"No results found for query", "INFO", {"query": query[:100]}
"No results found for query",
"INFO",
{"query": query[:100]}
)
# Store in semantic cache (only for unfiltered queries with results)
@@ -133,14 +136,18 @@ async def knowledge_search(
return formatted
except Exception as e: # noqa: BLE001
except Exception as e:
# Catch-all for any unexpected errors
log_structured_entry(
"Unexpected error in knowledge_search",
"ERROR",
{"error": str(e), "error_type": type(e).__name__, "query": query[:100]},
{
"error": str(e),
"error_type": type(e).__name__,
"query": query[:100]
}
)
return f"Unexpected error during search: {e!s}"
return f"Unexpected error during search: {str(e)}"
def main() -> None:

View File

@@ -1,3 +1,4 @@
# ruff: noqa: INP001
"""Base client with shared aiohttp session management."""
import aiohttp

View File

@@ -1,3 +1,4 @@
# ruff: noqa: INP001
"""Google Cloud Storage client with caching."""
import asyncio
@@ -7,9 +8,8 @@ from typing import BinaryIO
import aiohttp
from gcloud.aio.storage import Storage
from knowledge_search_mcp.logging import log_structured_entry
from knowledge_search_mcp.utils.cache import LRUCache
from ..logging import log_structured_entry
from ..utils.cache import LRUCache
from .base import BaseGoogleCloudClient
HTTP_TOO_MANY_REQUESTS = 429
@@ -56,7 +56,7 @@ class GoogleCloudFileStorage(BaseGoogleCloudClient):
log_structured_entry(
"File retrieved from cache",
"INFO",
{"file": file_name, "bucket": self.bucket_name},
{"file": file_name, "bucket": self.bucket_name}
)
file_stream = io.BytesIO(cached_content)
file_stream.name = file_name
@@ -65,7 +65,7 @@ class GoogleCloudFileStorage(BaseGoogleCloudClient):
log_structured_entry(
"Starting file download from GCS",
"INFO",
{"file": file_name, "bucket": self.bucket_name},
{"file": file_name, "bucket": self.bucket_name}
)
storage_client = self._get_aio_storage()
@@ -87,18 +87,15 @@ class GoogleCloudFileStorage(BaseGoogleCloudClient):
"file": file_name,
"bucket": self.bucket_name,
"size_bytes": len(content),
"attempt": attempt + 1,
},
"attempt": attempt + 1
}
)
except TimeoutError as exc:
last_exception = exc
log_structured_entry(
(
f"Timeout downloading gs://{self.bucket_name}/{file_name} "
f"(attempt {attempt + 1}/{max_retries})"
),
f"Timeout downloading gs://{self.bucket_name}/{file_name} (attempt {attempt + 1}/{max_retries})",
"WARNING",
{"error": str(exc)},
{"error": str(exc)}
)
except aiohttp.ClientResponseError as exc:
last_exception = exc
@@ -107,18 +104,15 @@ class GoogleCloudFileStorage(BaseGoogleCloudClient):
or exc.status >= HTTP_SERVER_ERROR
):
log_structured_entry(
(
f"HTTP {exc.status} downloading gs://{self.bucket_name}/"
f"{file_name} (attempt {attempt + 1}/{max_retries})"
),
f"HTTP {exc.status} downloading gs://{self.bucket_name}/{file_name} (attempt {attempt + 1}/{max_retries})",
"WARNING",
{"status": exc.status, "message": str(exc)},
{"status": exc.status, "message": str(exc)}
)
else:
log_structured_entry(
f"Non-retryable HTTP error downloading gs://{self.bucket_name}/{file_name}",
"ERROR",
{"status": exc.status, "message": str(exc)},
{"status": exc.status, "message": str(exc)}
)
raise
else:
@@ -129,7 +123,7 @@ class GoogleCloudFileStorage(BaseGoogleCloudClient):
log_structured_entry(
"Retrying file download",
"INFO",
{"file": file_name, "delay_seconds": delay},
{"file": file_name, "delay_seconds": delay}
)
await asyncio.sleep(delay)
@@ -144,7 +138,7 @@ class GoogleCloudFileStorage(BaseGoogleCloudClient):
"file": file_name,
"bucket": self.bucket_name,
"max_retries": max_retries,
"last_error": str(last_exception),
},
"last_error": str(last_exception)
}
)
raise TimeoutError(msg) from last_exception

View File

@@ -1,3 +1,4 @@
# ruff: noqa: INP001
"""Google Cloud Vector Search client."""
import asyncio
@@ -5,9 +6,8 @@ from collections.abc import Sequence
from gcloud.aio.auth import Token
from knowledge_search_mcp.logging import log_structured_entry
from knowledge_search_mcp.models import SearchResult, SourceNamespace
from ..logging import log_structured_entry
from ..models import SearchResult, SourceNamespace
from .base import BaseGoogleCloudClient
from .storage import GoogleCloudFileStorage
@@ -94,7 +94,7 @@ class GoogleCloudVectorSearch(BaseGoogleCloudClient):
log_structured_entry(
"Vector search query failed - endpoint not configured",
"ERROR",
{"error": msg},
{"error": msg}
)
raise RuntimeError(msg)
@@ -113,8 +113,8 @@ class GoogleCloudVectorSearch(BaseGoogleCloudClient):
"deployed_index_id": deployed_index_id,
"neighbor_count": limit,
"endpoint_id": endpoint_id,
"embedding_dimension": len(query),
},
"embedding_dimension": len(query)
}
)
datapoint: dict = {"feature_vector": list(query)}
@@ -149,10 +149,10 @@ class GoogleCloudVectorSearch(BaseGoogleCloudClient):
{
"status": response.status,
"response_body": body,
"deployed_index_id": deployed_index_id,
},
"deployed_index_id": deployed_index_id
}
)
raise RuntimeError(msg) # noqa: TRY301
raise RuntimeError(msg)
data = await response.json()
neighbors = data.get("nearestNeighbors", [{}])[0].get("neighbors", [])
@@ -161,15 +161,15 @@ class GoogleCloudVectorSearch(BaseGoogleCloudClient):
"INFO",
{
"neighbors_found": len(neighbors),
"deployed_index_id": deployed_index_id,
},
"deployed_index_id": deployed_index_id
}
)
if not neighbors:
log_structured_entry(
"No neighbors found in vector search",
"WARNING",
{"deployed_index_id": deployed_index_id},
{"deployed_index_id": deployed_index_id}
)
return []
@@ -185,7 +185,7 @@ class GoogleCloudVectorSearch(BaseGoogleCloudClient):
log_structured_entry(
"Fetching content for search results",
"INFO",
{"file_count": len(content_tasks)},
{"file_count": len(content_tasks)}
)
file_streams = await asyncio.gather(*content_tasks)
@@ -206,9 +206,12 @@ class GoogleCloudVectorSearch(BaseGoogleCloudClient):
log_structured_entry(
"Vector search completed successfully",
"INFO",
{"results_count": len(results), "deployed_index_id": deployed_index_id},
{
"results_count": len(results),
"deployed_index_id": deployed_index_id
}
)
return results # noqa: TRY300
return results
except Exception as e:
log_structured_entry(
@@ -217,7 +220,7 @@ class GoogleCloudVectorSearch(BaseGoogleCloudClient):
{
"error": str(e),
"error_type": type(e).__name__,
"deployed_index_id": deployed_index_id,
},
"deployed_index_id": deployed_index_id
}
)
raise

View File

@@ -1,14 +1,7 @@
"""Configuration management for the MCP server."""
import argparse
import os
import sys
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
YamlConfigSettingsSource,
)
import argparse
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, YamlConfigSettingsSource
def _parse_args() -> argparse.Namespace:
@@ -21,7 +14,7 @@ def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
return argparse.Namespace(
transport="stdio",
host="0.0.0.0", # noqa: S104
host="0.0.0.0",
port=8080,
config=os.environ.get("CONFIG_FILE", "config.yaml"),
)
@@ -32,7 +25,7 @@ def _parse_args() -> argparse.Namespace:
choices=["stdio", "sse", "streamable-http"],
default="stdio",
)
parser.add_argument("--host", default="0.0.0.0") # noqa: S104
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=8080)
parser.add_argument(
"--config",
@@ -43,7 +36,6 @@ def _parse_args() -> argparse.Namespace:
_args = _parse_args()
class Settings(BaseSettings):
"""Server configuration populated from env vars and a YAML config file."""
@@ -60,7 +52,6 @@ class Settings(BaseSettings):
search_limit: int = 10
log_name: str = "va_agent_evaluation_logs"
log_level: str = "INFO"
cloud_logging_enabled: bool = False
# Semantic cache (Redis)
redis_url: str | None = None
@@ -78,7 +69,6 @@ class Settings(BaseSettings):
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
"""Customize the order of settings sources to include YAML config."""
return (
init_settings,
env_settings,
@@ -94,7 +84,7 @@ _cfg: Settings | None = None
def get_config() -> Settings:
"""Get or create the singleton Settings instance."""
global _cfg # noqa: PLW0603
global _cfg
if _cfg is None:
_cfg = Settings.model_validate({})
return _cfg
@@ -104,8 +94,8 @@ def get_config() -> Settings:
class _ConfigProxy:
"""Proxy object that lazily loads config on attribute access."""
def __getattr__(self, name: str) -> object:
def __getattr__(self, name: str):
return getattr(get_config(), name)
cfg = _ConfigProxy()
cfg = _ConfigProxy() # type: ignore[assignment]

View File

@@ -1,22 +1,23 @@
"""Centralized Cloud Logging setup.
Uses CloudLoggingHandler (background thread) so logging does not add latency.
"""
Centralized Cloud Logging setup.
Uses CloudLoggingHandler (background thread) so logging does not add latency
"""
import logging
from typing import Literal
from typing import Optional, Dict, Literal
import google.cloud.logging
from google.cloud.logging.handlers import CloudLoggingHandler
from .config import get_config
_eval_log: logging.Logger | None = None
def _get_logger() -> logging.Logger:
"""Get or create the singleton evaluation logger."""
global _eval_log # noqa: PLW0603
global _eval_log
if _eval_log is not None:
return _eval_log
@@ -26,42 +27,30 @@ def _get_logger() -> logging.Logger:
_eval_log = logger
return logger
if cfg.cloud_logging_enabled:
try:
client = google.cloud.logging.Client(project=cfg.project_id)
handler = CloudLoggingHandler(client, name=cfg.log_name) # async transport
logger.addHandler(handler)
logger.setLevel(getattr(logging, cfg.log_level.upper()))
except Exception as e: # noqa: BLE001
except Exception as e:
# Fallback to console if Cloud Logging is unavailable (local dev)
logging.basicConfig(level=getattr(logging, cfg.log_level.upper()))
logger = logging.getLogger(cfg.log_name)
logger.warning("Cloud Logging setup failed; using console. Error: %s", e)
else:
logging.basicConfig(level=getattr(logging, cfg.log_level.upper()))
logger = logging.getLogger(cfg.log_name)
_eval_log = logger
return logger
def log_structured_entry(
message: str,
severity: Literal["INFO", "WARNING", "ERROR"],
custom_log: dict | None = None,
) -> None:
"""Emit a JSON-structured log row.
def log_structured_entry(message: str, severity: Literal["INFO", "WARNING", "ERROR"], custom_log: Optional[Dict] = None) -> None:
"""
Emit a JSON-structured log row.
Args:
message: Short label for the row (e.g., "Final agent turn").
severity: "INFO" | "WARNING" | "ERROR"
custom_log: A dict with your structured payload.
"""
level = getattr(logging, severity.upper(), logging.INFO)
logger = _get_logger()
logger.log(
level,
message,
extra={"json_fields": {"message": message, "custom": custom_log or {}}},
)
logger.log(level, message, extra={"json_fields": {"message": message, "custom": custom_log or {}}})

View File

@@ -1,7 +1,8 @@
# ruff: noqa: INP001
"""Domain models for knowledge search MCP server."""
from dataclasses import dataclass
from enum import StrEnum
from enum import Enum
from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING:
@@ -12,7 +13,7 @@ if TYPE_CHECKING:
from .services.semantic_cache import KnowledgeSemanticCache
class SourceNamespace(StrEnum):
class SourceNamespace(str, Enum):
"""Allowed values for the 'source' namespace filter."""
EDUCACION_FINANCIERA = "Educacion Financiera"

View File

@@ -1,3 +1,4 @@
# ruff: noqa: INP001
"""MCP server lifecycle management."""
from collections.abc import AsyncIterator
@@ -7,13 +8,13 @@ from google import genai
from mcp.server.fastmcp import FastMCP
from .clients.vector_search import GoogleCloudVectorSearch
from .config import get_config
from .config import Settings, cfg
from .logging import log_structured_entry
from .models import AppContext
from .services.semantic_cache import KnowledgeSemanticCache
from .services.validation import (
validate_gcs_access,
validate_genai_access,
validate_gcs_access,
validate_vector_search_access,
)
@@ -21,18 +22,15 @@ from .services.validation import (
@asynccontextmanager
async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
"""Create and configure the vector-search client for the server lifetime."""
# Get config with proper types for initialization
config_for_init = get_config()
log_structured_entry(
"Initializing MCP server",
"INFO",
{
"project_id": config_for_init.project_id,
"location": config_for_init.location,
"bucket": config_for_init.bucket,
"index_name": config_for_init.index_name,
},
"project_id": cfg.project_id,
"location": cfg.location,
"bucket": cfg.bucket,
"index_name": cfg.index_name,
}
)
vs: GoogleCloudVectorSearch | None = None
@@ -40,10 +38,10 @@ async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
# Initialize vector search client
log_structured_entry("Creating GoogleCloudVectorSearch client", "INFO")
vs = GoogleCloudVectorSearch(
project_id=config_for_init.project_id,
location=config_for_init.location,
bucket=config_for_init.bucket,
index_name=config_for_init.index_name,
project_id=cfg.project_id,
location=cfg.location,
bucket=cfg.bucket,
index_name=cfg.index_name,
)
# Configure endpoint
@@ -51,28 +49,25 @@ async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
"Configuring index endpoint",
"INFO",
{
"endpoint_name": config_for_init.endpoint_name,
"endpoint_domain": config_for_init.endpoint_domain,
},
"endpoint_name": cfg.endpoint_name,
"endpoint_domain": cfg.endpoint_domain,
}
)
vs.configure_index_endpoint(
name=config_for_init.endpoint_name,
public_domain=config_for_init.endpoint_domain,
name=cfg.endpoint_name,
public_domain=cfg.endpoint_domain,
)
# Initialize GenAI client
log_structured_entry(
"Creating GenAI client",
"INFO",
{
"project_id": config_for_init.project_id,
"location": config_for_init.location,
},
{"project_id": cfg.project_id, "location": cfg.location}
)
genai_client = genai.Client(
vertexai=True,
project=config_for_init.project_id,
location=config_for_init.location,
project=cfg.project_id,
location=cfg.location,
)
# Validate credentials and configuration by testing actual resources
@@ -82,52 +77,43 @@ async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
validation_errors = []
# Run all validations
config = get_config()
genai_error = await validate_genai_access(genai_client, config)
genai_error = await validate_genai_access(genai_client, cfg)
if genai_error:
validation_errors.append(genai_error)
gcs_error = await validate_gcs_access(vs, config)
gcs_error = await validate_gcs_access(vs, cfg)
if gcs_error:
validation_errors.append(gcs_error)
vs_error = await validate_vector_search_access(vs, config)
vs_error = await validate_vector_search_access(vs, cfg)
if vs_error:
validation_errors.append(vs_error)
# Summary of validations
if validation_errors:
log_structured_entry(
(
"MCP server started with validation errors - "
"service may not work correctly"
),
"MCP server started with validation errors - service may not work correctly",
"WARNING",
{
"validation_errors": validation_errors,
"error_count": len(validation_errors),
},
{"validation_errors": validation_errors, "error_count": len(validation_errors)}
)
else:
log_structured_entry(
"All validations passed - MCP server initialization complete", "INFO"
)
log_structured_entry("All validations passed - MCP server initialization complete", "INFO")
# Initialize semantic cache if Redis is configured
semantic_cache = None
if config_for_init.redis_url:
if cfg.redis_url:
try:
semantic_cache = KnowledgeSemanticCache(
redis_url=config_for_init.redis_url,
name=config_for_init.cache_name,
vector_dims=config_for_init.cache_vector_dims,
distance_threshold=config_for_init.cache_distance_threshold,
ttl=config_for_init.cache_ttl,
redis_url=cfg.redis_url,
name=cfg.cache_name,
vector_dims=cfg.cache_vector_dims,
distance_threshold=cfg.cache_distance_threshold,
ttl=cfg.cache_ttl,
)
log_structured_entry(
"Semantic cache initialized",
"INFO",
{"redis_url": config_for_init.redis_url, "cache_name": config_for_init.cache_name},
{"redis_url": cfg.redis_url, "cache_name": cfg.cache_name},
)
except Exception as e:
log_structured_entry(
@@ -139,7 +125,7 @@ async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
yield AppContext(
vector_search=vs,
genai_client=genai_client,
settings=config,
settings=cfg,
semantic_cache=semantic_cache,
)
@@ -150,7 +136,7 @@ async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
{
"error": str(e),
"error_type": type(e).__name__,
},
}
)
raise
finally:
@@ -160,9 +146,9 @@ async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
try:
await vs.close()
log_structured_entry("Closed aiohttp sessions", "INFO")
except Exception as e: # noqa: BLE001
except Exception as e:
log_structured_entry(
"Error closing aiohttp sessions",
"WARNING",
{"error": str(e), "error_type": type(e).__name__},
{"error": str(e), "error_type": type(e).__name__}
)

View File

@@ -1,21 +1,13 @@
"""Service modules for business logic."""
from .search import (
filter_search_results,
format_search_results,
generate_query_embedding,
)
from .validation import (
validate_gcs_access,
validate_genai_access,
validate_vector_search_access,
)
from .search import filter_search_results, format_search_results, generate_query_embedding
from .validation import validate_genai_access, validate_gcs_access, validate_vector_search_access
__all__ = [
"filter_search_results",
"format_search_results",
"generate_query_embedding",
"validate_gcs_access",
"validate_genai_access",
"validate_gcs_access",
"validate_vector_search_access",
]

View File

@@ -1,10 +1,11 @@
# ruff: noqa: INP001
"""Search helper functions."""
from google import genai
from google.genai import types as genai_types
from knowledge_search_mcp.logging import log_structured_entry
from knowledge_search_mcp.models import SearchResult
from ..logging import log_structured_entry
from ..models import SearchResult
async def generate_query_embedding(
@@ -16,7 +17,6 @@ async def generate_query_embedding(
Returns:
Tuple of (embedding vector, error message). Error message is None on success.
"""
if not query or not query.strip():
return ([], "Error: Query cannot be empty")
@@ -30,11 +30,9 @@ async def generate_query_embedding(
task_type="RETRIEVAL_QUERY",
),
)
if not response.embeddings or not response.embeddings[0].values:
return ([], "Error: Failed to generate embedding - empty response")
embedding = response.embeddings[0].values
return (embedding, None) # noqa: TRY300
except Exception as e: # noqa: BLE001
return (embedding, None)
except Exception as e:
error_type = type(e).__name__
error_msg = str(e)
@@ -43,13 +41,22 @@ async def generate_query_embedding(
log_structured_entry(
"Rate limit exceeded while generating embedding",
"WARNING",
{"error": error_msg, "error_type": error_type, "query": query[:100]},
{
"error": error_msg,
"error_type": error_type,
"query": query[:100]
}
)
return ([], "Error: API rate limit exceeded. Please try again later.")
else:
log_structured_entry(
"Failed to generate query embedding",
"ERROR",
{"error": error_msg, "error_type": error_type, "query": query[:100]},
{
"error": error_msg,
"error_type": error_type,
"query": query[:100]
}
)
return ([], f"Error generating embedding: {error_msg}")
@@ -68,7 +75,6 @@ def filter_search_results(
Returns:
Filtered list of search results.
"""
if not results:
return []
@@ -76,10 +82,14 @@ def filter_search_results(
max_sim = max(r["distance"] for r in results)
cutoff = max_sim * top_percent
return [
s for s in results if s["distance"] > cutoff and s["distance"] > min_similarity
filtered = [
s
for s in results
if s["distance"] > cutoff and s["distance"] > min_similarity
]
return filtered
def format_search_results(results: list[SearchResult]) -> str:
"""Format search results as XML-like documents.
@@ -89,7 +99,6 @@ def format_search_results(results: list[SearchResult]) -> str:
Returns:
Formatted string with document tags.
"""
if not results:
return "No relevant documents found for your query."

View File

@@ -1,26 +1,20 @@
# ruff: noqa: INP001
"""Validation functions for Google Cloud services."""
from gcloud.aio.auth import Token
from google import genai
from google.genai import types as genai_types
from knowledge_search_mcp.clients.vector_search import GoogleCloudVectorSearch
from knowledge_search_mcp.config import Settings
from knowledge_search_mcp.logging import log_structured_entry
# HTTP status codes
HTTP_FORBIDDEN = 403
HTTP_NOT_FOUND = 404
from ..clients.vector_search import GoogleCloudVectorSearch
from ..config import Settings
from ..logging import log_structured_entry
async def validate_genai_access(
genai_client: genai.Client, cfg: Settings
) -> str | None:
async def validate_genai_access(genai_client: genai.Client, cfg: Settings) -> str | None:
"""Validate GenAI embedding access.
Returns:
Error message if validation fails, None if successful.
"""
log_structured_entry("Validating GenAI embedding access", "INFO")
try:
@@ -36,26 +30,20 @@ async def validate_genai_access(
log_structured_entry(
"GenAI embedding validation successful",
"INFO",
{
"embedding_dimension": len(embedding_values)
if embedding_values
else 0
},
{"embedding_dimension": len(embedding_values) if embedding_values else 0}
)
return None
else:
msg = "Embedding validation returned empty response"
log_structured_entry(msg, "WARNING")
return msg # noqa: TRY300
except Exception as e: # noqa: BLE001
return msg
except Exception as e:
log_structured_entry(
(
"Failed to validate GenAI embedding access - "
"service may not work correctly"
),
"Failed to validate GenAI embedding access - service may not work correctly",
"WARNING",
{"error": str(e), "error_type": type(e).__name__},
{"error": str(e), "error_type": type(e).__name__}
)
return f"GenAI: {e!s}"
return f"GenAI: {str(e)}"
async def validate_gcs_access(vs: GoogleCloudVectorSearch, cfg: Settings) -> str | None:
@@ -63,11 +51,14 @@ async def validate_gcs_access(vs: GoogleCloudVectorSearch, cfg: Settings) -> str
Returns:
Error message if validation fails, None if successful.
"""
log_structured_entry("Validating GCS bucket access", "INFO", {"bucket": cfg.bucket})
log_structured_entry(
"Validating GCS bucket access",
"INFO",
{"bucket": cfg.bucket}
)
try:
session = vs.storage._get_aio_session() # noqa: SLF001
session = vs.storage._get_aio_session()
token_obj = Token(
session=session,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
@@ -79,136 +70,102 @@ async def validate_gcs_access(vs: GoogleCloudVectorSearch, cfg: Settings) -> str
f"https://storage.googleapis.com/storage/v1/b/{cfg.bucket}/o?maxResults=1",
headers=headers,
) as response:
if response.status == HTTP_FORBIDDEN:
if response.status == 403:
msg = f"Access denied to bucket '{cfg.bucket}'. Check permissions."
log_structured_entry(
(
"GCS bucket validation failed - access denied - "
"service may not work correctly"
),
"GCS bucket validation failed - access denied - service may not work correctly",
"WARNING",
{"bucket": cfg.bucket, "status": response.status},
{"bucket": cfg.bucket, "status": response.status}
)
return msg
if response.status == HTTP_NOT_FOUND:
elif response.status == 404:
msg = f"Bucket '{cfg.bucket}' not found. Check bucket name and project."
log_structured_entry(
(
"GCS bucket validation failed - not found - "
"service may not work correctly"
),
"GCS bucket validation failed - not found - service may not work correctly",
"WARNING",
{"bucket": cfg.bucket, "status": response.status},
{"bucket": cfg.bucket, "status": response.status}
)
return msg
if not response.ok:
elif not response.ok:
body = await response.text()
msg = f"Failed to access bucket '{cfg.bucket}': {response.status}"
log_structured_entry(
"GCS bucket validation failed - service may not work correctly",
"WARNING",
{"bucket": cfg.bucket, "status": response.status, "response": body},
{"bucket": cfg.bucket, "status": response.status, "response": body}
)
return msg
else:
log_structured_entry(
"GCS bucket validation successful", "INFO", {"bucket": cfg.bucket}
"GCS bucket validation successful",
"INFO",
{"bucket": cfg.bucket}
)
return None
except Exception as e: # noqa: BLE001
except Exception as e:
log_structured_entry(
"Failed to validate GCS bucket access - service may not work correctly",
"WARNING",
{"error": str(e), "error_type": type(e).__name__, "bucket": cfg.bucket},
{"error": str(e), "error_type": type(e).__name__, "bucket": cfg.bucket}
)
return f"GCS: {e!s}"
return f"GCS: {str(e)}"
async def validate_vector_search_access(
vs: GoogleCloudVectorSearch, cfg: Settings
) -> str | None:
async def validate_vector_search_access(vs: GoogleCloudVectorSearch, cfg: Settings) -> str | None:
"""Validate vector search endpoint access.
Returns:
Error message if validation fails, None if successful.
"""
log_structured_entry(
"Validating vector search endpoint access",
"INFO",
{"endpoint_name": cfg.endpoint_name},
{"endpoint_name": cfg.endpoint_name}
)
try:
headers = await vs._async_get_auth_headers() # noqa: SLF001
session = vs._get_aio_session() # noqa: SLF001
headers = await vs._async_get_auth_headers()
session = vs._get_aio_session()
endpoint_url = (
f"https://{cfg.location}-aiplatform.googleapis.com/v1/{cfg.endpoint_name}"
)
async with session.get(endpoint_url, headers=headers) as response:
if response.status == HTTP_FORBIDDEN:
msg = (
f"Access denied to endpoint '{cfg.endpoint_name}'. "
"Check permissions."
)
if response.status == 403:
msg = f"Access denied to endpoint '{cfg.endpoint_name}'. Check permissions."
log_structured_entry(
(
"Vector search endpoint validation failed - "
"access denied - service may not work correctly"
),
"Vector search endpoint validation failed - access denied - service may not work correctly",
"WARNING",
{"endpoint": cfg.endpoint_name, "status": response.status},
{"endpoint": cfg.endpoint_name, "status": response.status}
)
return msg
if response.status == HTTP_NOT_FOUND:
msg = (
f"Endpoint '{cfg.endpoint_name}' not found. "
"Check endpoint name and project."
)
elif response.status == 404:
msg = f"Endpoint '{cfg.endpoint_name}' not found. Check endpoint name and project."
log_structured_entry(
(
"Vector search endpoint validation failed - "
"not found - service may not work correctly"
),
"Vector search endpoint validation failed - not found - service may not work correctly",
"WARNING",
{"endpoint": cfg.endpoint_name, "status": response.status},
{"endpoint": cfg.endpoint_name, "status": response.status}
)
return msg
if not response.ok:
elif not response.ok:
body = await response.text()
msg = (
f"Failed to access endpoint '{cfg.endpoint_name}': "
f"{response.status}"
)
msg = f"Failed to access endpoint '{cfg.endpoint_name}': {response.status}"
log_structured_entry(
(
"Vector search endpoint validation failed - "
"service may not work correctly"
),
"Vector search endpoint validation failed - service may not work correctly",
"WARNING",
{
"endpoint": cfg.endpoint_name,
"status": response.status,
"response": body,
},
{"endpoint": cfg.endpoint_name, "status": response.status, "response": body}
)
return msg
else:
log_structured_entry(
"Vector search endpoint validation successful",
"INFO",
{"endpoint": cfg.endpoint_name},
{"endpoint": cfg.endpoint_name}
)
return None
except Exception as e: # noqa: BLE001
except Exception as e:
log_structured_entry(
(
"Failed to validate vector search endpoint access - "
"service may not work correctly"
),
"Failed to validate vector search endpoint access - service may not work correctly",
"WARNING",
{
"error": str(e),
"error_type": type(e).__name__,
"endpoint": cfg.endpoint_name,
},
{"error": str(e), "error_type": type(e).__name__, "endpoint": cfg.endpoint_name}
)
return f"Vector Search: {e!s}"
return f"Vector Search: {str(e)}"

View File

@@ -1,3 +1,4 @@
# ruff: noqa: INP001
"""LRU cache implementation."""
from collections import OrderedDict