Compare commits
34 Commits
56e181a772
...
push-omyxs
| Author | SHA1 | Date | |
|---|---|---|---|
| 132ea1c04f | |||
| 0cdf9cd44e | |||
| d39b8a6ea7 | |||
| 86ed34887b | |||
| 694b060fa4 | |||
| d69c4e4f4a | |||
| f6e122b5a9 | |||
| dba94107a5 | |||
| d3cd8d5291 | |||
| 8dfd2048a5 | |||
| 3e2386b9b6 | |||
|
|
42e1660143 | ||
| 208d5ebebf | |||
| 83ed64326f | |||
| 4c59da0c22 | |||
|
|
cbf3ca7df4 | ||
| e9b4c93a20 | |||
| b95bb72b24 | |||
| a3ba340224 | |||
| 0fd97a31a5 | |||
| a8c611fbec | |||
| 13c8e122de | |||
| 753b5c7871 | |||
| 6feeeff4f3 | |||
| 3b7dd91a71 | |||
| 427de45522 | |||
|
|
f0b9d1b27a | ||
|
|
bf2cc2f556 | ||
|
|
b92a5d5b0e | ||
|
|
bd107a027a | ||
|
|
dcc05d697e | ||
|
|
82764bd60b | ||
|
|
54eb6f240c | ||
|
|
bb19770663 |
9
.dockerignore
Normal file
9
.dockerignore
Normal file
@@ -0,0 +1,9 @@
|
||||
.git/
|
||||
.venv/
|
||||
.ruff_cache/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
.env
|
||||
agent.py
|
||||
AGENTS.md
|
||||
README.md
|
||||
43
.gitea/workflows/ci.yml
Normal file
43
.gitea/workflows/ci.yml
Normal file
@@ -0,0 +1,43 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- run: uv sync --frozen
|
||||
- name: Ruff check
|
||||
run: uv run ruff check
|
||||
- name: Ruff format check
|
||||
run: uv run ruff format --check
|
||||
|
||||
typecheck:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- run: uv sync --frozen
|
||||
- name: Type check
|
||||
run: uv run ty check
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- run: uv sync --frozen
|
||||
- name: Run tests
|
||||
run: uv run pytest --cov
|
||||
216
.gitignore
vendored
Normal file
216
.gitignore
vendored
Normal file
@@ -0,0 +1,216 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[codz]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
# Pipfile.lock
|
||||
|
||||
# UV
|
||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# uv.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
# poetry.lock
|
||||
# poetry.toml
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
||||
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
||||
# pdm.lock
|
||||
# pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# pixi
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
||||
# pixi.lock
|
||||
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
||||
# in the .venv directory. It is recommended not to include this directory in version control.
|
||||
.pixi
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# Redis
|
||||
*.rdb
|
||||
*.aof
|
||||
*.pid
|
||||
|
||||
# RabbitMQ
|
||||
mnesia/
|
||||
rabbitmq/
|
||||
rabbitmq-data/
|
||||
|
||||
# ActiveMQ
|
||||
activemq-data/
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.envrc
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
# .idea/
|
||||
|
||||
# Abstra
|
||||
# Abstra is an AI-powered process automation framework.
|
||||
# Ignore directories containing user credentials, local state, and settings.
|
||||
# Learn more at https://abstra.io/docs
|
||||
.abstra/
|
||||
|
||||
# Visual Studio Code
|
||||
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
||||
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
||||
# you could uncomment the following to ignore the entire vscode folder
|
||||
# .vscode/
|
||||
|
||||
# Ruff stuff:
|
||||
.ruff_cache/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
# Marimo
|
||||
marimo/_static/
|
||||
marimo/_lsp/
|
||||
__marimo__/
|
||||
|
||||
# Streamlit
|
||||
.streamlit/secrets.toml
|
||||
3
CLAUDE.md
Normal file
3
CLAUDE.md
Normal file
@@ -0,0 +1,3 @@
|
||||
Use `uv` for project management
|
||||
Linter: `uv run ruff check`
|
||||
Type-checking: `uv run ty check`
|
||||
14
DockerfileConnector
Normal file
14
DockerfileConnector
Normal file
@@ -0,0 +1,14 @@
|
||||
FROM quay.ocp.banorte.com/golden/python-312:latest
|
||||
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY pyproject.toml uv.lock ./
|
||||
RUN uv sync --no-dev --frozen
|
||||
|
||||
COPY src/ src/
|
||||
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
|
||||
CMD ["uv", "run", "python", "-m", "knowledge_search_mcp", "--transport", "streamable-http", "--port", "8000"]
|
||||
106
README.md
106
README.md
@@ -0,0 +1,106 @@
|
||||
# knowledge-search-mcp
|
||||
|
||||
An MCP (Model Context Protocol) server that exposes a `knowledge_search` tool for semantic search over a knowledge base backed by Vertex AI Vector Search and Google Cloud Storage.
|
||||
|
||||
## How it works
|
||||
|
||||
1. A natural-language query is embedded using a Gemini embedding model.
|
||||
2. The embedding is sent to a Vertex AI Matching Engine index endpoint to find nearest neighbors.
|
||||
3. The matched document contents are fetched from a GCS bucket and returned to the caller.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Python ≥ 3.12
|
||||
- [uv](https://docs.astral.sh/uv/) for dependency management
|
||||
- A Google Cloud project with:
|
||||
- A Vertex AI Vector Search index and deployed endpoint
|
||||
- A GCS bucket containing the indexed document chunks
|
||||
- Application Default Credentials (or a service account) with appropriate permissions
|
||||
|
||||
## Configuration
|
||||
|
||||
Create a `config.yaml` file or `.env` file (see `Settings` in `src/knowledge_search_mcp/config.py` for all options):
|
||||
|
||||
```env
|
||||
PROJECT_ID=my-gcp-project
|
||||
LOCATION=us-central1
|
||||
BUCKET=my-knowledge-bucket
|
||||
INDEX_NAME=my-index
|
||||
DEPLOYED_INDEX_ID=my-deployed-index
|
||||
ENDPOINT_NAME=projects/…/locations/…/indexEndpoints/…
|
||||
ENDPOINT_DOMAIN=123456789.us-central1-aiplatform.googleapis.com
|
||||
# optional
|
||||
EMBEDDING_MODEL=gemini-embedding-001
|
||||
SEARCH_LIMIT=10
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Install dependencies
|
||||
|
||||
```bash
|
||||
uv sync
|
||||
```
|
||||
|
||||
### Run the MCP server
|
||||
|
||||
**Using the installed command (recommended):**
|
||||
|
||||
```bash
|
||||
# stdio transport (default)
|
||||
uv run knowledge-search-mcp
|
||||
|
||||
# SSE transport for remote clients
|
||||
uv run knowledge-search-mcp --transport sse --port 8080
|
||||
|
||||
# streamable-http transport
|
||||
uv run knowledge-search-mcp --transport streamable-http --port 8080
|
||||
```
|
||||
|
||||
**Or run directly:**
|
||||
|
||||
```bash
|
||||
uv run python -m knowledge_search_mcp.main
|
||||
```
|
||||
|
||||
### Run the interactive agent (ADK)
|
||||
|
||||
The bundled agent spawns the MCP server as a subprocess and provides a REPL:
|
||||
|
||||
```bash
|
||||
uv run python agent.py
|
||||
```
|
||||
|
||||
Or connect to an already-running SSE server:
|
||||
|
||||
```bash
|
||||
uv run python agent.py --remote http://localhost:8080/sse
|
||||
```
|
||||
|
||||
### Run tests
|
||||
|
||||
```bash
|
||||
uv run pytest
|
||||
```
|
||||
|
||||
## Docker
|
||||
|
||||
```bash
|
||||
docker build -t knowledge-search-mcp .
|
||||
docker run -p 8080:8080 --env-file .env knowledge-search-mcp
|
||||
```
|
||||
|
||||
The container starts the server in SSE mode on the port specified by `PORT` (default `8080`).
|
||||
|
||||
## Project structure
|
||||
|
||||
```
|
||||
src/knowledge_search_mcp/
|
||||
├── __init__.py Package initialization
|
||||
├── config.py Configuration management (Settings, args parsing)
|
||||
├── logging.py Cloud Logging setup
|
||||
└── main.py MCP server, vector search client, and GCS storage helper
|
||||
agent.py Interactive ADK agent that consumes the MCP server
|
||||
tests/ Test suite
|
||||
pyproject.toml Project metadata, dependencies, and entry points
|
||||
```
|
||||
|
||||
401
main.py
401
main.py
@@ -1,401 +0,0 @@
|
||||
# ruff: noqa: INP001
|
||||
"""Async helpers for querying Vertex AI vector search via MCP."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import BinaryIO, TypedDict
|
||||
|
||||
import aiohttp
|
||||
from gcloud.aio.auth import Token
|
||||
from gcloud.aio.storage import Storage
|
||||
from google import genai
|
||||
from google.genai import types as genai_types
|
||||
from mcp.server.fastmcp import Context, FastMCP
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HTTP_TOO_MANY_REQUESTS = 429
|
||||
HTTP_SERVER_ERROR = 500
|
||||
|
||||
|
||||
class GoogleCloudFileStorage:
|
||||
"""Cache-aware helper for downloading files from Google Cloud Storage."""
|
||||
|
||||
def __init__(self, bucket: str) -> None:
|
||||
"""Initialize the storage helper."""
|
||||
self.bucket_name = bucket
|
||||
self._aio_session: aiohttp.ClientSession | None = None
|
||||
self._aio_storage: Storage | None = None
|
||||
self._cache: dict[str, bytes] = {}
|
||||
|
||||
def _get_aio_session(self) -> aiohttp.ClientSession:
|
||||
if self._aio_session is None or self._aio_session.closed:
|
||||
connector = aiohttp.TCPConnector(
|
||||
limit=300,
|
||||
limit_per_host=50,
|
||||
)
|
||||
timeout = aiohttp.ClientTimeout(total=60)
|
||||
self._aio_session = aiohttp.ClientSession(
|
||||
timeout=timeout,
|
||||
connector=connector,
|
||||
)
|
||||
return self._aio_session
|
||||
|
||||
def _get_aio_storage(self) -> Storage:
|
||||
if self._aio_storage is None:
|
||||
self._aio_storage = Storage(
|
||||
session=self._get_aio_session(),
|
||||
)
|
||||
return self._aio_storage
|
||||
|
||||
async def async_get_file_stream(
|
||||
self,
|
||||
file_name: str,
|
||||
max_retries: int = 3,
|
||||
) -> BinaryIO:
|
||||
"""Get a file asynchronously with retry on transient errors.
|
||||
|
||||
Args:
|
||||
file_name: The blob name to retrieve.
|
||||
max_retries: Maximum number of retry attempts.
|
||||
|
||||
Returns:
|
||||
A BytesIO stream with the file contents.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If all retry attempts fail.
|
||||
|
||||
"""
|
||||
if file_name in self._cache:
|
||||
file_stream = io.BytesIO(self._cache[file_name])
|
||||
file_stream.name = file_name
|
||||
return file_stream
|
||||
|
||||
storage_client = self._get_aio_storage()
|
||||
last_exception: Exception | None = None
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
self._cache[file_name] = await storage_client.download(
|
||||
self.bucket_name,
|
||||
file_name,
|
||||
)
|
||||
file_stream = io.BytesIO(self._cache[file_name])
|
||||
file_stream.name = file_name
|
||||
except TimeoutError as exc:
|
||||
last_exception = exc
|
||||
logger.warning(
|
||||
"Timeout downloading gs://%s/%s (attempt %d/%d)",
|
||||
self.bucket_name,
|
||||
file_name,
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
)
|
||||
except aiohttp.ClientResponseError as exc:
|
||||
last_exception = exc
|
||||
if (
|
||||
exc.status == HTTP_TOO_MANY_REQUESTS
|
||||
or exc.status >= HTTP_SERVER_ERROR
|
||||
):
|
||||
logger.warning(
|
||||
"HTTP %d downloading gs://%s/%s (attempt %d/%d)",
|
||||
exc.status,
|
||||
self.bucket_name,
|
||||
file_name,
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
return file_stream
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
delay = 0.5 * (2**attempt)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
msg = (
|
||||
f"Failed to download gs://{self.bucket_name}/{file_name} "
|
||||
f"after {max_retries} attempts"
|
||||
)
|
||||
raise TimeoutError(msg) from last_exception
|
||||
|
||||
|
||||
class SearchResult(TypedDict):
|
||||
"""Structured response item returned by the vector search API."""
|
||||
|
||||
id: str
|
||||
distance: float
|
||||
content: str
|
||||
|
||||
|
||||
class GoogleCloudVectorSearch:
|
||||
"""Minimal async client for the Vertex AI Matching Engine REST API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
project_id: str,
|
||||
location: str,
|
||||
bucket: str,
|
||||
index_name: str | None = None,
|
||||
) -> None:
|
||||
"""Store configuration used to issue Matching Engine queries."""
|
||||
self.project_id = project_id
|
||||
self.location = location
|
||||
self.storage = GoogleCloudFileStorage(bucket=bucket)
|
||||
self.index_name = index_name
|
||||
self._aio_session: aiohttp.ClientSession | None = None
|
||||
self._async_token: Token | None = None
|
||||
self._endpoint_domain: str | None = None
|
||||
self._endpoint_name: str | None = None
|
||||
|
||||
async def _async_get_auth_headers(self) -> dict[str, str]:
|
||||
if self._async_token is None:
|
||||
self._async_token = Token(
|
||||
session=self._get_aio_session(),
|
||||
scopes=[
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
],
|
||||
)
|
||||
access_token = await self._async_token.get()
|
||||
return {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
def _get_aio_session(self) -> aiohttp.ClientSession:
|
||||
if self._aio_session is None or self._aio_session.closed:
|
||||
connector = aiohttp.TCPConnector(
|
||||
limit=300,
|
||||
limit_per_host=50,
|
||||
)
|
||||
timeout = aiohttp.ClientTimeout(total=60)
|
||||
self._aio_session = aiohttp.ClientSession(
|
||||
timeout=timeout,
|
||||
connector=connector,
|
||||
)
|
||||
return self._aio_session
|
||||
|
||||
def configure_index_endpoint(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
public_domain: str,
|
||||
) -> None:
|
||||
"""Persist the metadata needed to access a deployed endpoint."""
|
||||
if not name:
|
||||
msg = "Index endpoint name must be a non-empty string."
|
||||
raise ValueError(msg)
|
||||
if not public_domain:
|
||||
msg = "Index endpoint domain must be a non-empty public domain."
|
||||
raise ValueError(msg)
|
||||
self._endpoint_name = name
|
||||
self._endpoint_domain = public_domain
|
||||
|
||||
async def async_run_query(
|
||||
self,
|
||||
deployed_index_id: str,
|
||||
query: Sequence[float],
|
||||
limit: int,
|
||||
) -> list[SearchResult]:
|
||||
"""Run an async similarity search via the REST API.
|
||||
|
||||
Args:
|
||||
deployed_index_id: The ID of the deployed index.
|
||||
query: The embedding vector for the search query.
|
||||
limit: Maximum number of nearest neighbors to return.
|
||||
|
||||
Returns:
|
||||
A list of matched items with id, distance, and content.
|
||||
|
||||
"""
|
||||
if self._endpoint_domain is None or self._endpoint_name is None:
|
||||
msg = (
|
||||
"Missing endpoint metadata. Call "
|
||||
"`configure_index_endpoint` before querying."
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
domain = self._endpoint_domain
|
||||
endpoint_id = self._endpoint_name.split("/")[-1]
|
||||
url = (
|
||||
f"https://{domain}/v1/projects/{self.project_id}"
|
||||
f"/locations/{self.location}"
|
||||
f"/indexEndpoints/{endpoint_id}:findNeighbors"
|
||||
)
|
||||
payload = {
|
||||
"deployed_index_id": deployed_index_id,
|
||||
"queries": [
|
||||
{
|
||||
"datapoint": {"feature_vector": list(query)},
|
||||
"neighbor_count": limit,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
headers = await self._async_get_auth_headers()
|
||||
session = self._get_aio_session()
|
||||
async with session.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
data = await response.json()
|
||||
|
||||
neighbors = data.get("nearestNeighbors", [{}])[0].get("neighbors", [])
|
||||
content_tasks = []
|
||||
for neighbor in neighbors:
|
||||
datapoint_id = neighbor["datapoint"]["datapointId"]
|
||||
file_path = f"{self.index_name}/contents/{datapoint_id}.md"
|
||||
content_tasks.append(
|
||||
self.storage.async_get_file_stream(file_path),
|
||||
)
|
||||
|
||||
file_streams = await asyncio.gather(*content_tasks)
|
||||
results: list[SearchResult] = []
|
||||
for neighbor, stream in zip(
|
||||
neighbors,
|
||||
file_streams,
|
||||
strict=True,
|
||||
):
|
||||
results.append(
|
||||
SearchResult(
|
||||
id=neighbor["datapoint"]["datapointId"],
|
||||
distance=neighbor["distance"],
|
||||
content=stream.read().decode("utf-8"),
|
||||
),
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCP Server
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Server configuration populated from environment variables."""
|
||||
|
||||
project_id: str
|
||||
location: str
|
||||
bucket: str
|
||||
index_name: str
|
||||
deployed_index_id: str
|
||||
endpoint_name: str
|
||||
endpoint_domain: str
|
||||
embedding_model: str = "text-embedding-005"
|
||||
search_limit: int = 10
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppContext:
|
||||
"""Shared resources initialised once at server startup."""
|
||||
|
||||
vector_search: GoogleCloudVectorSearch
|
||||
genai_client: genai.Client
|
||||
settings: Settings
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
|
||||
"""Create and configure the vector-search client for the server lifetime."""
|
||||
cfg = Settings.model_validate({})
|
||||
|
||||
vs = GoogleCloudVectorSearch(
|
||||
project_id=cfg.project_id,
|
||||
location=cfg.location,
|
||||
bucket=cfg.bucket,
|
||||
index_name=cfg.index_name,
|
||||
)
|
||||
vs.configure_index_endpoint(
|
||||
name=cfg.endpoint_name,
|
||||
public_domain=cfg.endpoint_domain,
|
||||
)
|
||||
|
||||
genai_client = genai.Client(
|
||||
vertexai=True,
|
||||
project=cfg.project_id,
|
||||
location=cfg.location,
|
||||
)
|
||||
|
||||
yield AppContext(
|
||||
vector_search=vs,
|
||||
genai_client=genai_client,
|
||||
settings=cfg,
|
||||
)
|
||||
|
||||
|
||||
mcp = FastMCP("knowledge-search", lifespan=lifespan)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def knowledge_search(
|
||||
query: str,
|
||||
ctx: Context,
|
||||
) -> str:
|
||||
"""Search a knowledge base using a natural-language query.
|
||||
|
||||
Args:
|
||||
query: The text query to search for.
|
||||
ctx: MCP request context (injected automatically).
|
||||
|
||||
Returns:
|
||||
A formatted string containing matched documents with id and content.
|
||||
|
||||
"""
|
||||
import time
|
||||
|
||||
app: AppContext = ctx.request_context.lifespan_context
|
||||
t0 = time.perf_counter()
|
||||
min_sim = 0.6
|
||||
|
||||
response = await app.genai_client.aio.models.embed_content(
|
||||
model=app.settings.embedding_model,
|
||||
contents=query,
|
||||
config=genai_types.EmbedContentConfig(
|
||||
task_type="RETRIEVAL_QUERY",
|
||||
),
|
||||
)
|
||||
embedding = response.embeddings[0].values
|
||||
t_embed = time.perf_counter()
|
||||
|
||||
search_results = await app.vector_search.async_run_query(
|
||||
deployed_index_id=app.settings.deployed_index_id,
|
||||
query=embedding,
|
||||
limit=app.settings.search_limit,
|
||||
)
|
||||
t_search = time.perf_counter()
|
||||
|
||||
# Apply similarity filtering
|
||||
if search_results:
|
||||
max_sim = max(r["distance"] for r in search_results)
|
||||
cutoff = max_sim * 0.9
|
||||
search_results = [
|
||||
s
|
||||
for s in search_results
|
||||
if s["distance"] > cutoff and s["distance"] > min_sim
|
||||
]
|
||||
|
||||
logger.info(
|
||||
"knowledge_search timing: embedding=%sms, vector_search=%sms, total=%sms, chunks=%s",
|
||||
round((t_embed - t0) * 1000, 1),
|
||||
round((t_search - t_embed) * 1000, 1),
|
||||
round((t_search - t0) * 1000, 1),
|
||||
[s["id"] for s in search_results],
|
||||
)
|
||||
|
||||
# Format results as XML-like documents
|
||||
formatted_results = [
|
||||
f"<document {i} name={result['id']}>\n{result['content']}\n</document {i}>"
|
||||
for i, result in enumerate(search_results, start=1)
|
||||
]
|
||||
return "\n".join(formatted_results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mcp.run()
|
||||
@@ -1,7 +1,7 @@
|
||||
[project]
|
||||
name = "knowledge-search-mcp"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
description = "MCP server for semantic search over Vertex AI Vector Search"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
@@ -12,10 +12,45 @@ dependencies = [
|
||||
"google-genai>=1.64.0",
|
||||
"mcp[cli]>=1.26.0",
|
||||
"pydantic-settings>=2.9.1",
|
||||
"pyyaml>=6.0",
|
||||
"redis[hiredis]>=5.0.0,<7",
|
||||
"redisvl>=0.6.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
knowledge-search-mcp = "knowledge_search_mcp.__main__:main"
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"google-adk>=1.25.1",
|
||||
"pytest>=8.0.0",
|
||||
"pytest-asyncio>=0.24.0",
|
||||
"pytest-cov>=6.0.0",
|
||||
"ruff>=0.15.2",
|
||||
"ty>=0.0.18",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests"]
|
||||
pythonpath = ["."]
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.8.3,<0.9.0"]
|
||||
build-backend = "uv_build"
|
||||
|
||||
[tool.ruff]
|
||||
exclude = ["scripts", "tests"]
|
||||
|
||||
[tool.ty.src]
|
||||
exclude = ["scripts", "tests"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ['ALL']
|
||||
ignore = [
|
||||
'D203', # one-blank-line-before-class
|
||||
'D213', # multi-line-summary-second-line
|
||||
'COM812', # missing-trailing-comma
|
||||
'ANN401', # dynamically-typed-any
|
||||
'ERA001', # commented-out-code
|
||||
]
|
||||
|
||||
111
scripts/agent.py
Normal file
111
scripts/agent.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# ruff: noqa: INP001
|
||||
"""ADK agent that connects to the knowledge-search MCP server."""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.sessions import InMemorySessionService
|
||||
from google.adk.tools.mcp_tool import McpToolset
|
||||
from google.adk.tools.mcp_tool.mcp_session_manager import (
|
||||
SseConnectionParams,
|
||||
StdioConnectionParams,
|
||||
)
|
||||
from google.genai import types
|
||||
from mcp import StdioServerParameters
|
||||
|
||||
# ADK needs these env vars for Vertex AI; reuse the ones from .env
|
||||
os.environ.setdefault("GOOGLE_GENAI_USE_VERTEXAI", "True")
|
||||
if project := os.environ.get("PROJECT_ID"):
|
||||
os.environ.setdefault("GOOGLE_CLOUD_PROJECT", project)
|
||||
if location := os.environ.get("LOCATION"):
|
||||
os.environ.setdefault("GOOGLE_CLOUD_LOCATION", location)
|
||||
|
||||
SERVER_SCRIPT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "src", "knowledge_search_mcp", "main.py")
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Knowledge Search Agent")
|
||||
parser.add_argument(
|
||||
"--remote",
|
||||
metavar="URL",
|
||||
help="Connect to an already-running MCP server at this SSE URL "
|
||||
"(e.g. http://localhost:8080/sse). Without this flag the agent "
|
||||
"spawns the server as a subprocess.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
async def async_main() -> None:
|
||||
args = _parse_args()
|
||||
|
||||
if args.remote:
|
||||
connection_params = SseConnectionParams(url=args.remote)
|
||||
else:
|
||||
connection_params = StdioConnectionParams(
|
||||
server_params=StdioServerParameters(
|
||||
command="uv",
|
||||
args=["run", "python", SERVER_SCRIPT],
|
||||
),
|
||||
)
|
||||
|
||||
toolset = McpToolset(connection_params=connection_params)
|
||||
|
||||
agent = LlmAgent(
|
||||
model="gemini-2.0-flash",
|
||||
name="knowledge_agent",
|
||||
instruction=(
|
||||
"You are a helpful assistant with access to a knowledge base. "
|
||||
"Use the knowledge_search tool to find relevant information "
|
||||
"when the user asks questions. Summarize the results clearly."
|
||||
),
|
||||
tools=[toolset],
|
||||
)
|
||||
|
||||
session_service = InMemorySessionService()
|
||||
session = await session_service.create_session(
|
||||
state={},
|
||||
app_name="knowledge_agent",
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
runner = Runner(
|
||||
app_name="knowledge_agent",
|
||||
agent=agent,
|
||||
session_service=session_service,
|
||||
)
|
||||
|
||||
print("Knowledge Search Agent ready. Type your query (Ctrl+C to exit):")
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
query = input("\n> ").strip()
|
||||
except EOFError:
|
||||
break
|
||||
if not query:
|
||||
continue
|
||||
|
||||
content = types.Content(
|
||||
role="user",
|
||||
parts=[types.Part(text=query)],
|
||||
)
|
||||
|
||||
async for event in runner.run_async(
|
||||
session_id=session.id,
|
||||
user_id=session.user_id,
|
||||
new_message=content,
|
||||
):
|
||||
if event.is_final_response() and event.content and event.content.parts:
|
||||
for part in event.content.parts:
|
||||
if part.text:
|
||||
print(part.text)
|
||||
except KeyboardInterrupt:
|
||||
print("\nShutting down...")
|
||||
finally:
|
||||
await toolset.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(async_main())
|
||||
15
src/knowledge_search_mcp/__init__.py
Normal file
15
src/knowledge_search_mcp/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""MCP server for semantic search over Vertex AI Vector Search."""
|
||||
|
||||
from .clients.storage import GoogleCloudFileStorage
|
||||
from .clients.vector_search import GoogleCloudVectorSearch
|
||||
from .models import AppContext, SearchResult, SourceNamespace
|
||||
from .utils.cache import LRUCache
|
||||
|
||||
__all__ = [
|
||||
"AppContext",
|
||||
"GoogleCloudFileStorage",
|
||||
"GoogleCloudVectorSearch",
|
||||
"LRUCache",
|
||||
"SearchResult",
|
||||
"SourceNamespace",
|
||||
]
|
||||
152
src/knowledge_search_mcp/__main__.py
Normal file
152
src/knowledge_search_mcp/__main__.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""MCP server for semantic search over Vertex AI Vector Search."""
|
||||
|
||||
import time
|
||||
|
||||
from mcp.server.fastmcp import Context, FastMCP
|
||||
|
||||
from .config import _args
|
||||
from .logging import log_structured_entry
|
||||
from .models import AppContext, SourceNamespace
|
||||
from .server import lifespan
|
||||
from .services.search import (
|
||||
filter_search_results,
|
||||
format_search_results,
|
||||
generate_query_embedding,
|
||||
)
|
||||
|
||||
mcp = FastMCP(
|
||||
"knowledge-search",
|
||||
host=_args.host,
|
||||
port=_args.port,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool()
async def knowledge_search(
    query: str,
    ctx: Context,
    source: SourceNamespace | None = None,
) -> str:
    """Search a knowledge base using a natural-language query.

    The query is embedded, optionally matched against a semantic cache,
    then run through Vertex AI Vector Search; results are similarity-filtered
    and formatted as text. All failures are returned as error strings rather
    than raised, so the MCP client always receives a response.

    Args:
        query: The text query to search for.
        ctx: MCP request context (injected automatically).
        source: Optional filter to restrict results by source.
            Allowed values: 'Educacion Financiera',
            'Productos y Servicios', 'Funcionalidades de la App Movil'.

    Returns:
        A formatted string containing matched documents with id and content,
        or a human-readable error message on failure.

    """
    # Shared clients and settings created once in the server lifespan.
    app: AppContext = ctx.request_context.lifespan_context
    t0 = time.perf_counter()

    log_structured_entry(
        "knowledge_search request received",
        "INFO",
        {"query": query[:100]},  # Log first 100 chars of query
    )

    try:
        # Generate embedding for the query; on failure `error` carries a
        # user-facing message and is returned as-is.
        embedding, error = await generate_query_embedding(
            app.genai_client,
            app.settings.embedding_model,
            query,
        )
        if error:
            return error

        t_embed = time.perf_counter()
        log_structured_entry(
            "Query embedding generated successfully",
            "INFO",
            {"time_ms": round((t_embed - t0) * 1000, 1)},
        )

        # Check semantic cache before vector search. Only unfiltered queries
        # are cached (cache entries do not record a source filter).
        if app.semantic_cache is not None and source is None:
            cached = await app.semantic_cache.check(embedding)
            if cached is not None:
                t_cache = time.perf_counter()
                log_structured_entry(
                    "knowledge_search completed from cache",
                    "INFO",
                    {
                        "embedding_ms": f"{round((t_embed - t0) * 1000, 1)}ms",
                        "cache_check_ms": f"{round((t_cache - t_embed) * 1000, 1)}ms",
                        "total_ms": f"{round((t_cache - t0) * 1000, 1)}ms",
                        "cache_hit": True,
                    },
                )
                return cached

        # Perform vector search
        log_structured_entry("Performing vector search", "INFO")
        try:
            search_results = await app.vector_search.async_run_query(
                deployed_index_id=app.settings.deployed_index_id,
                query=embedding,
                limit=app.settings.search_limit,
                source=source,
            )
            t_search = time.perf_counter()
        except Exception as e:  # noqa: BLE001
            log_structured_entry(
                "Vector search failed",
                "ERROR",
                {"error": str(e), "error_type": type(e).__name__, "query": query[:100]},
            )
            return f"Error performing vector search: {e!s}"

        # Apply similarity filtering
        filtered_results = filter_search_results(search_results)

        log_structured_entry(
            "knowledge_search completed successfully",
            "INFO",
            {
                "embedding_ms": f"{round((t_embed - t0) * 1000, 1)}ms",
                "vector_search_ms": f"{round((t_search - t_embed) * 1000, 1)}ms",
                "total_ms": f"{round((t_search - t0) * 1000, 1)}ms",
                "source_filter": source.value if source is not None else None,
                "results_count": len(filtered_results),
                "chunks": [s["id"] for s in filtered_results],
                "cache_hit": False,
            },
        )

        # Format and return results
        formatted = format_search_results(filtered_results)

        if not filtered_results:
            log_structured_entry(
                "No results found for query", "INFO", {"query": query[:100]}
            )

        # Store in semantic cache (only for unfiltered queries with results)
        if app.semantic_cache is not None and source is None and filtered_results:
            await app.semantic_cache.store(query, formatted, embedding)

        return formatted

    except Exception as e:  # noqa: BLE001
        # Catch-all for any unexpected errors
        log_structured_entry(
            "Unexpected error in knowledge_search",
            "ERROR",
            {"error": str(e), "error_type": type(e).__name__, "query": query[:100]},
        )
        return f"Unexpected error during search: {e!s}"
|
||||
|
||||
|
||||
def main() -> None:
    """Launch the MCP server on the transport selected via the CLI."""
    # _args.transport defaults to "stdio"; see config._parse_args.
    mcp.run(transport=_args.transport)


if __name__ == "__main__":
    main()
|
||||
11
src/knowledge_search_mcp/clients/__init__.py
Normal file
11
src/knowledge_search_mcp/clients/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""Client modules for Google Cloud services."""
|
||||
|
||||
from .base import BaseGoogleCloudClient
|
||||
from .storage import GoogleCloudFileStorage
|
||||
from .vector_search import GoogleCloudVectorSearch
|
||||
|
||||
__all__ = [
|
||||
"BaseGoogleCloudClient",
|
||||
"GoogleCloudFileStorage",
|
||||
"GoogleCloudVectorSearch",
|
||||
]
|
||||
30
src/knowledge_search_mcp/clients/base.py
Normal file
30
src/knowledge_search_mcp/clients/base.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Base client with shared aiohttp session management."""
|
||||
|
||||
import aiohttp
|
||||
|
||||
|
||||
class BaseGoogleCloudClient:
    """Base class providing a lazily-created, pooled aiohttp session.

    Subclasses share one ``ClientSession`` per instance; the session is
    (re)built on demand and must be released via :meth:`close`.
    """

    def __init__(self) -> None:
        """Start with no live session; one is created on first use."""
        self._aio_session: aiohttp.ClientSession | None = None

    def _get_aio_session(self) -> aiohttp.ClientSession:
        """Return the cached aiohttp session, building a fresh one if needed."""
        session = self._aio_session
        if session is not None and not session.closed:
            return session
        # Connection pool sized for high fan-out against Google APIs.
        session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=60),
            connector=aiohttp.TCPConnector(limit=300, limit_per_host=50),
        )
        self._aio_session = session
        return session

    async def close(self) -> None:
        """Dispose of the session unless it was never opened or is closed."""
        if self._aio_session and not self._aio_session.closed:
            await self._aio_session.close()
|
||||
150
src/knowledge_search_mcp/clients/storage.py
Normal file
150
src/knowledge_search_mcp/clients/storage.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""Google Cloud Storage client with caching."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
from typing import BinaryIO
|
||||
|
||||
import aiohttp
|
||||
from gcloud.aio.storage import Storage
|
||||
|
||||
from knowledge_search_mcp.logging import log_structured_entry
|
||||
from knowledge_search_mcp.utils.cache import LRUCache
|
||||
|
||||
from .base import BaseGoogleCloudClient
|
||||
|
||||
# Status codes treated as transient in async_get_file_stream: 429 exactly,
# and anything >= 500 (the comparison is `status >= HTTP_SERVER_ERROR`).
HTTP_TOO_MANY_REQUESTS = 429
HTTP_SERVER_ERROR = 500
|
||||
|
||||
|
||||
class GoogleCloudFileStorage(BaseGoogleCloudClient):
    """Cache-aware helper for downloading files from Google Cloud Storage.

    Downloads are served from an in-process LRU cache when possible and
    retried with exponential backoff on transient failures (timeouts,
    HTTP 429, HTTP 5xx).
    """

    def __init__(self, bucket: str, cache_size: int = 100) -> None:
        """Initialize the storage helper with LRU cache.

        Args:
            bucket: Name of the GCS bucket to read from.
            cache_size: Maximum number of files kept in the LRU cache.

        """
        super().__init__()
        self.bucket_name = bucket
        self._aio_storage: Storage | None = None
        self._cache = LRUCache(max_size=cache_size)

    def _get_aio_storage(self) -> Storage:
        # Lazily build the async Storage client on the shared aiohttp session.
        if self._aio_storage is None:
            self._aio_storage = Storage(
                session=self._get_aio_session(),
            )
        return self._aio_storage

    async def async_get_file_stream(
        self,
        file_name: str,
        max_retries: int = 3,
    ) -> BinaryIO:
        """Get a file asynchronously with retry on transient errors.

        Args:
            file_name: The blob name to retrieve.
            max_retries: Maximum number of retry attempts.

        Returns:
            A BytesIO stream with the file contents.

        Raises:
            TimeoutError: If all retry attempts fail. NOTE: the final
                exception is always a TimeoutError even when the last
                failure was an HTTP 429/5xx (the original cause is chained).
            aiohttp.ClientResponseError: Re-raised immediately on a
                non-retryable HTTP status (anything other than 429/5xx).

        """
        # Fast path: serve from the LRU cache without touching the network.
        cached_content = self._cache.get(file_name)
        if cached_content is not None:
            log_structured_entry(
                "File retrieved from cache",
                "INFO",
                {"file": file_name, "bucket": self.bucket_name},
            )
            file_stream = io.BytesIO(cached_content)
            file_stream.name = file_name
            return file_stream

        log_structured_entry(
            "Starting file download from GCS",
            "INFO",
            {"file": file_name, "bucket": self.bucket_name},
        )

        storage_client = self._get_aio_storage()
        last_exception: Exception | None = None

        for attempt in range(max_retries):
            try:
                content = await storage_client.download(
                    self.bucket_name,
                    file_name,
                )
                # Populate the cache before returning so subsequent calls
                # hit the fast path above.
                self._cache.put(file_name, content)
                file_stream = io.BytesIO(content)
                file_stream.name = file_name
                log_structured_entry(
                    "File downloaded successfully",
                    "INFO",
                    {
                        "file": file_name,
                        "bucket": self.bucket_name,
                        "size_bytes": len(content),
                        "attempt": attempt + 1,
                    },
                )
            except TimeoutError as exc:
                last_exception = exc
                log_structured_entry(
                    (
                        f"Timeout downloading gs://{self.bucket_name}/{file_name} "
                        f"(attempt {attempt + 1}/{max_retries})"
                    ),
                    "WARNING",
                    {"error": str(exc)},
                )
            except aiohttp.ClientResponseError as exc:
                last_exception = exc
                # Only 429 and 5xx are considered transient; anything else
                # (e.g. 403/404) is re-raised immediately.
                if (
                    exc.status == HTTP_TOO_MANY_REQUESTS
                    or exc.status >= HTTP_SERVER_ERROR
                ):
                    log_structured_entry(
                        (
                            f"HTTP {exc.status} downloading gs://{self.bucket_name}/"
                            f"{file_name} (attempt {attempt + 1}/{max_retries})"
                        ),
                        "WARNING",
                        {"status": exc.status, "message": str(exc)},
                    )
                else:
                    log_structured_entry(
                        f"Non-retryable HTTP error downloading gs://{self.bucket_name}/{file_name}",
                        "ERROR",
                        {"status": exc.status, "message": str(exc)},
                    )
                    raise
            else:
                # The `else` branch runs only when the try body raised nothing,
                # i.e. the download succeeded.
                return file_stream

            # Exponential backoff (0.5s, 1s, 2s, ...) before the next attempt;
            # skipped after the final attempt.
            if attempt < max_retries - 1:
                delay = 0.5 * (2**attempt)
                log_structured_entry(
                    "Retrying file download",
                    "INFO",
                    {"file": file_name, "delay_seconds": delay},
                )
                await asyncio.sleep(delay)

        msg = (
            f"Failed to download gs://{self.bucket_name}/{file_name} "
            f"after {max_retries} attempts"
        )
        log_structured_entry(
            "File download failed after all retries",
            "ERROR",
            {
                "file": file_name,
                "bucket": self.bucket_name,
                "max_retries": max_retries,
                "last_error": str(last_exception),
            },
        )
        raise TimeoutError(msg) from last_exception
|
||||
223
src/knowledge_search_mcp/clients/vector_search.py
Normal file
223
src/knowledge_search_mcp/clients/vector_search.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""Google Cloud Vector Search client."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
|
||||
from gcloud.aio.auth import Token
|
||||
|
||||
from knowledge_search_mcp.logging import log_structured_entry
|
||||
from knowledge_search_mcp.models import SearchResult, SourceNamespace
|
||||
|
||||
from .base import BaseGoogleCloudClient
|
||||
from .storage import GoogleCloudFileStorage
|
||||
|
||||
|
||||
class GoogleCloudVectorSearch(BaseGoogleCloudClient):
    """Minimal async client for the Vertex AI Matching Engine REST API.

    Issues `findNeighbors` queries against a deployed index endpoint and
    hydrates each match's content from Google Cloud Storage.
    """

    def __init__(
        self,
        project_id: str,
        location: str,
        bucket: str,
        index_name: str | None = None,
    ) -> None:
        """Store configuration used to issue Matching Engine queries.

        Args:
            project_id: GCP project hosting the index endpoint.
            location: GCP region of the index endpoint.
            bucket: GCS bucket holding the chunk content files.
            index_name: Prefix under which chunk content is stored in GCS.

        """
        super().__init__()
        self.project_id = project_id
        self.location = location
        self.storage = GoogleCloudFileStorage(bucket=bucket)
        self.index_name = index_name
        self._async_token: Token | None = None
        self._endpoint_domain: str | None = None
        self._endpoint_name: str | None = None

    async def _async_get_auth_headers(self) -> dict[str, str]:
        # The Token object caches and refreshes credentials internally,
        # so it is created once and reused.
        if self._async_token is None:
            self._async_token = Token(
                session=self._get_aio_session(),
                scopes=[
                    "https://www.googleapis.com/auth/cloud-platform",
                ],
            )
        access_token = await self._async_token.get()
        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    async def close(self) -> None:
        """Close aiohttp sessions for both vector search and storage."""
        await super().close()
        await self.storage.close()

    def configure_index_endpoint(
        self,
        *,
        name: str,
        public_domain: str,
    ) -> None:
        """Persist the metadata needed to access a deployed endpoint.

        Args:
            name: Full endpoint resource name or bare endpoint ID.
            public_domain: Public domain of the deployed index endpoint.

        Raises:
            ValueError: If either argument is empty.

        """
        if not name:
            msg = "Index endpoint name must be a non-empty string."
            raise ValueError(msg)
        if not public_domain:
            msg = "Index endpoint domain must be a non-empty public domain."
            raise ValueError(msg)
        self._endpoint_name = name
        self._endpoint_domain = public_domain

    async def async_run_query(
        self,
        deployed_index_id: str,
        query: Sequence[float],
        limit: int,
        source: SourceNamespace | None = None,
    ) -> list[SearchResult]:
        """Run an async similarity search via the REST API.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.
            source: Optional namespace filter to restrict results by source.

        Returns:
            A list of matched items with id, distance, and content.

        Raises:
            RuntimeError: If the endpoint was never configured or the
                findNeighbors call returns a non-OK status.

        """
        if self._endpoint_domain is None or self._endpoint_name is None:
            msg = (
                "Missing endpoint metadata. Call "
                "`configure_index_endpoint` before querying."
            )
            log_structured_entry(
                "Vector search query failed - endpoint not configured",
                "ERROR",
                {"error": msg},
            )
            raise RuntimeError(msg)

        domain = self._endpoint_domain
        # The endpoint name may be a full resource path
        # ("projects/.../indexEndpoints/<id>"); only the trailing ID is
        # used to build the URL.
        endpoint_id = self._endpoint_name.split("/")[-1]
        url = (
            f"https://{domain}/v1/projects/{self.project_id}"
            f"/locations/{self.location}"
            f"/indexEndpoints/{endpoint_id}:findNeighbors"
        )

        log_structured_entry(
            "Starting vector search query",
            "INFO",
            {
                "deployed_index_id": deployed_index_id,
                "neighbor_count": limit,
                "endpoint_id": endpoint_id,
                "embedding_dimension": len(query),
            },
        )

        datapoint: dict = {"feature_vector": list(query)}
        if source is not None:
            # Token-restrict filtering on the "source" namespace.
            datapoint["restricts"] = [
                {"namespace": "source", "allow_list": [source.value]},
            ]
        payload = {
            "deployed_index_id": deployed_index_id,
            "queries": [
                {
                    "datapoint": datapoint,
                    "neighbor_count": limit,
                },
            ],
        }

        try:
            headers = await self._async_get_auth_headers()
            session = self._get_aio_session()
            async with session.post(
                url,
                json=payload,
                headers=headers,
            ) as response:
                if not response.ok:
                    body = await response.text()
                    msg = f"findNeighbors returned {response.status}: {body}"
                    log_structured_entry(
                        "Vector search API request failed",
                        "ERROR",
                        {
                            "status": response.status,
                            "response_body": body,
                            "deployed_index_id": deployed_index_id,
                        },
                    )
                    raise RuntimeError(msg)  # noqa: TRY301
                data = await response.json()

            # Guard against an empty "nearestNeighbors" list: the previous
            # `data.get("nearestNeighbors", [{}])[0]` only substituted the
            # default when the key was absent, so a present-but-empty list
            # raised IndexError instead of yielding "no results".
            nearest = data.get("nearestNeighbors") or [{}]
            neighbors = nearest[0].get("neighbors", [])
            log_structured_entry(
                "Vector search API request successful",
                "INFO",
                {
                    "neighbors_found": len(neighbors),
                    "deployed_index_id": deployed_index_id,
                },
            )

            if not neighbors:
                log_structured_entry(
                    "No neighbors found in vector search",
                    "WARNING",
                    {"deployed_index_id": deployed_index_id},
                )
                return []

            # Fetch content for all neighbors concurrently from GCS.
            content_tasks = []
            for neighbor in neighbors:
                datapoint_id = neighbor["datapoint"]["datapointId"]
                file_path = f"{self.index_name}/contents/{datapoint_id}.md"
                content_tasks.append(
                    self.storage.async_get_file_stream(file_path),
                )

            log_structured_entry(
                "Fetching content for search results",
                "INFO",
                {"file_count": len(content_tasks)},
            )

            file_streams = await asyncio.gather(*content_tasks)
            results: list[SearchResult] = []
            for neighbor, stream in zip(
                neighbors,
                file_streams,
                strict=True,
            ):
                results.append(
                    SearchResult(
                        id=neighbor["datapoint"]["datapointId"],
                        distance=neighbor["distance"],
                        content=stream.read().decode("utf-8"),
                    ),
                )

            log_structured_entry(
                "Vector search completed successfully",
                "INFO",
                {"results_count": len(results), "deployed_index_id": deployed_index_id},
            )
            return results  # noqa: TRY300

        except Exception as e:
            log_structured_entry(
                "Vector search query failed with exception",
                "ERROR",
                {
                    "error": str(e),
                    "error_type": type(e).__name__,
                    "deployed_index_id": deployed_index_id,
                },
            )
            raise
|
||||
111
src/knowledge_search_mcp/config.py
Normal file
111
src/knowledge_search_mcp/config.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Configuration management for the MCP server."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
PydanticBaseSettingsSource,
|
||||
YamlConfigSettingsSource,
|
||||
)
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
    """Parse command-line arguments.

    Returns:
        A namespace with ``transport``, ``host``, ``port`` and ``config``.
        When running under pytest, sys.argv belongs to the test runner and
        would not parse, so the defaults are returned directly without
        touching argparse. (Fix: the pytest branch previously built an
        ArgumentParser it never used.)

    """
    if "pytest" in sys.modules:
        return argparse.Namespace(
            transport="stdio",
            host="0.0.0.0",  # noqa: S104
            port=8080,
            config=os.environ.get("CONFIG_FILE", "config.yaml"),
        )

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transport",
        choices=["stdio", "sse", "streamable-http"],
        default="stdio",
    )
    parser.add_argument("--host", default="0.0.0.0")  # noqa: S104
    parser.add_argument("--port", type=int, default=8080)
    parser.add_argument(
        "--config",
        default=os.environ.get("CONFIG_FILE", "config.yaml"),
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
_args = _parse_args()
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Server configuration populated from env vars and a YAML config file.

    Source precedence (see settings_customise_sources): init kwargs,
    environment variables, .env file, YAML config file, secret files.
    """

    model_config = {"env_file": ".env", "yaml_file": _args.config}

    # Required Google Cloud wiring — no defaults, must be supplied.
    project_id: str
    location: str
    bucket: str
    index_name: str
    deployed_index_id: str
    endpoint_name: str
    endpoint_domain: str
    # Search behaviour.
    embedding_model: str = "gemini-embedding-001"
    search_limit: int = 10
    # Logging.
    log_name: str = "va_agent_evaluation_logs"
    log_level: str = "INFO"
    cloud_logging_enabled: bool = False

    # Semantic cache (Redis) — disabled entirely when redis_url is None.
    redis_url: str | None = None
    cache_name: str = "knowledge_search_cache"
    cache_vector_dims: int = 3072
    cache_distance_threshold: float = 0.12
    # TTL in seconds; None presumably means entries never expire — TODO confirm
    # against KnowledgeSemanticCache.
    cache_ttl: int | None = 3600

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Customize the order of settings sources to include YAML config."""
        return (
            init_settings,
            env_settings,
            dotenv_settings,
            YamlConfigSettingsSource(settings_cls),
            file_secret_settings,
        )
|
||||
|
||||
|
||||
# Lazy singleton instance of Settings
|
||||
_cfg: Settings | None = None
|
||||
|
||||
|
||||
def get_config() -> Settings:
    """Return the process-wide Settings, instantiating it on first call."""
    global _cfg  # noqa: PLW0603
    if _cfg is not None:
        return _cfg
    # First access: build the singleton from env/.env/YAML sources.
    _cfg = Settings.model_validate({})
    return _cfg
|
||||
|
||||
|
||||
# For backwards compatibility, provide cfg as a property-like accessor
|
||||
class _ConfigProxy:
    """Indirection layer so ``cfg.attr`` resolves against the lazy Settings.

    Kept for backwards compatibility with code that imported a module-level
    ``cfg`` object; attribute access defers configuration loading until the
    first lookup.
    """

    def __getattr__(self, name: str) -> object:
        """Forward attribute access to the singleton Settings instance."""
        settings = get_config()
        return getattr(settings, name)


cfg = _ConfigProxy()
|
||||
67
src/knowledge_search_mcp/logging.py
Normal file
67
src/knowledge_search_mcp/logging.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Centralized Cloud Logging setup.
|
||||
|
||||
Uses CloudLoggingHandler (background thread) so logging does not add latency.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
import google.cloud.logging
|
||||
from google.cloud.logging.handlers import CloudLoggingHandler
|
||||
|
||||
from .config import get_config
|
||||
|
||||
_eval_log: logging.Logger | None = None
|
||||
|
||||
|
||||
def _get_logger() -> logging.Logger:
    """Get or create the singleton evaluation logger.

    On first call, attaches a CloudLoggingHandler when cloud logging is
    enabled, falling back to console logging if setup fails or cloud
    logging is disabled. Subsequent calls return the cached logger.
    """
    global _eval_log  # noqa: PLW0603
    if _eval_log is not None:
        return _eval_log

    cfg = get_config()
    logger = logging.getLogger(cfg.log_name)
    # If a CloudLoggingHandler is already attached (e.g. by a previous
    # interpreter-level setup), reuse the logger without reconfiguring.
    if any(isinstance(h, CloudLoggingHandler) for h in logger.handlers):
        _eval_log = logger
        return logger

    if cfg.cloud_logging_enabled:
        try:
            client = google.cloud.logging.Client(project=cfg.project_id)
            handler = CloudLoggingHandler(client, name=cfg.log_name)  # async transport
            logger.addHandler(handler)
            logger.setLevel(getattr(logging, cfg.log_level.upper()))
        except Exception as e:  # noqa: BLE001
            # Fallback to console if Cloud Logging is unavailable (local dev)
            logging.basicConfig(level=getattr(logging, cfg.log_level.upper()))
            logger = logging.getLogger(cfg.log_name)
            logger.warning("Cloud Logging setup failed; using console. Error: %s", e)
    else:
        # Cloud logging disabled: plain console logging at the configured level.
        logging.basicConfig(level=getattr(logging, cfg.log_level.upper()))
        logger = logging.getLogger(cfg.log_name)

    _eval_log = logger
    return logger
|
||||
|
||||
|
||||
def log_structured_entry(
    message: str,
    severity: Literal["INFO", "WARNING", "ERROR"],
    custom_log: dict | None = None,
) -> None:
    """Write one structured (JSON-payload) log record.

    Args:
        message: Short label for the row (e.g., "Final agent turn").
        severity: "INFO" | "WARNING" | "ERROR"
        custom_log: A dict with your structured payload.

    """
    payload = {"message": message, "custom": custom_log or {}}
    # Unknown severity strings degrade gracefully to INFO.
    _get_logger().log(
        getattr(logging, severity.upper(), logging.INFO),
        message,
        extra={"json_fields": payload},
    )
|
||||
38
src/knowledge_search_mcp/models.py
Normal file
38
src/knowledge_search_mcp/models.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Domain models for knowledge search MCP server."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google import genai
|
||||
|
||||
from .clients.vector_search import GoogleCloudVectorSearch
|
||||
from .config import Settings
|
||||
from .services.semantic_cache import KnowledgeSemanticCache
|
||||
|
||||
|
||||
class SourceNamespace(StrEnum):
    """Allowed values for the 'source' namespace filter.

    Each value is sent verbatim in the `allow_list` of the vector search
    "source" restrict namespace (note: deliberately unaccented).
    """

    EDUCACION_FINANCIERA = "Educacion Financiera"
    PRODUCTOS_Y_SERVICIOS = "Productos y Servicios"
    FUNCIONALIDADES_APP_MOVIL = "Funcionalidades de la App Movil"
|
||||
|
||||
|
||||
class SearchResult(TypedDict):
    """Structured response item returned by the vector search API."""

    id: str  # datapoint ID of the matched chunk
    distance: float  # score reported by findNeighbors for this neighbor
    content: str  # chunk text fetched from GCS (decoded UTF-8)
|
||||
|
||||
|
||||
@dataclass
class AppContext:
    """Shared resources initialised once at server startup."""

    vector_search: "GoogleCloudVectorSearch"  # configured Matching Engine client
    genai_client: "genai.Client"  # Vertex AI client used for query embeddings
    settings: "Settings"  # resolved server configuration
    semantic_cache: "KnowledgeSemanticCache | None" = None  # None when Redis unset
|
||||
168
src/knowledge_search_mcp/server.py
Normal file
168
src/knowledge_search_mcp/server.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""MCP server lifecycle management."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from google import genai
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from .clients.vector_search import GoogleCloudVectorSearch
|
||||
from .config import get_config
|
||||
from .logging import log_structured_entry
|
||||
from .models import AppContext
|
||||
from .services.semantic_cache import KnowledgeSemanticCache
|
||||
from .services.validation import (
|
||||
validate_gcs_access,
|
||||
validate_genai_access,
|
||||
validate_vector_search_access,
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
    """Create and configure the vector-search client for the server lifetime.

    Builds the vector-search and GenAI clients, runs non-blocking access
    validations, optionally initialises the Redis semantic cache, then
    yields the shared AppContext. Sessions are always closed in `finally`.
    """
    # Get config with proper types for initialization
    config_for_init = get_config()

    log_structured_entry(
        "Initializing MCP server",
        "INFO",
        {
            "project_id": config_for_init.project_id,
            "location": config_for_init.location,
            "bucket": config_for_init.bucket,
            "index_name": config_for_init.index_name,
        },
    )

    # Declared before the try so the finally block can clean up even when
    # construction itself fails.
    vs: GoogleCloudVectorSearch | None = None
    try:
        # Initialize vector search client
        log_structured_entry("Creating GoogleCloudVectorSearch client", "INFO")
        vs = GoogleCloudVectorSearch(
            project_id=config_for_init.project_id,
            location=config_for_init.location,
            bucket=config_for_init.bucket,
            index_name=config_for_init.index_name,
        )

        # Configure endpoint
        log_structured_entry(
            "Configuring index endpoint",
            "INFO",
            {
                "endpoint_name": config_for_init.endpoint_name,
                "endpoint_domain": config_for_init.endpoint_domain,
            },
        )
        vs.configure_index_endpoint(
            name=config_for_init.endpoint_name,
            public_domain=config_for_init.endpoint_domain,
        )

        # Initialize GenAI client
        log_structured_entry(
            "Creating GenAI client",
            "INFO",
            {
                "project_id": config_for_init.project_id,
                "location": config_for_init.location,
            },
        )
        genai_client = genai.Client(
            vertexai=True,
            project=config_for_init.project_id,
            location=config_for_init.location,
        )

        # Validate credentials and configuration by testing actual resources
        # These validations are non-blocking - errors are logged but won't stop startup
        log_structured_entry("Starting validation of credentials and resources", "INFO")

        validation_errors = []

        # Run all validations
        config = get_config()
        genai_error = await validate_genai_access(genai_client, config)
        if genai_error:
            validation_errors.append(genai_error)

        gcs_error = await validate_gcs_access(vs, config)
        if gcs_error:
            validation_errors.append(gcs_error)

        vs_error = await validate_vector_search_access(vs, config)
        if vs_error:
            validation_errors.append(vs_error)

        # Summary of validations
        if validation_errors:
            log_structured_entry(
                (
                    "MCP server started with validation errors - "
                    "service may not work correctly"
                ),
                "WARNING",
                {
                    "validation_errors": validation_errors,
                    "error_count": len(validation_errors),
                },
            )
        else:
            log_structured_entry(
                "All validations passed - MCP server initialization complete", "INFO"
            )

        # Initialize semantic cache if Redis is configured; failure here is
        # tolerated and the server simply runs without caching.
        semantic_cache = None
        if config_for_init.redis_url:
            try:
                semantic_cache = KnowledgeSemanticCache(
                    redis_url=config_for_init.redis_url,
                    name=config_for_init.cache_name,
                    vector_dims=config_for_init.cache_vector_dims,
                    distance_threshold=config_for_init.cache_distance_threshold,
                    ttl=config_for_init.cache_ttl,
                )
                log_structured_entry(
                    "Semantic cache initialized",
                    "INFO",
                    {"redis_url": config_for_init.redis_url, "cache_name": config_for_init.cache_name},
                )
            except Exception as e:
                log_structured_entry(
                    "Semantic cache initialization failed, continuing without cache",
                    "WARNING",
                    {"error": str(e), "error_type": type(e).__name__},
                )

        # Server runs while this generator is suspended at the yield.
        yield AppContext(
            vector_search=vs,
            genai_client=genai_client,
            settings=config,
            semantic_cache=semantic_cache,
        )

    except Exception as e:
        # NOTE(review): this also catches exceptions thrown into the yield
        # during the server's lifetime, not just startup failures.
        log_structured_entry(
            "Failed to initialize MCP server",
            "ERROR",
            {
                "error": str(e),
                "error_type": type(e).__name__,
            },
        )
        raise
    finally:
        log_structured_entry("MCP server lifespan ending", "INFO")
        # Clean up resources
        if vs is not None:
            try:
                await vs.close()
                log_structured_entry("Closed aiohttp sessions", "INFO")
            except Exception as e:  # noqa: BLE001
                log_structured_entry(
                    "Error closing aiohttp sessions",
                    "WARNING",
                    {"error": str(e), "error_type": type(e).__name__},
                )
|
||||
21
src/knowledge_search_mcp/services/__init__.py
Normal file
21
src/knowledge_search_mcp/services/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Service modules for business logic."""
|
||||
|
||||
from .search import (
|
||||
filter_search_results,
|
||||
format_search_results,
|
||||
generate_query_embedding,
|
||||
)
|
||||
from .validation import (
|
||||
validate_gcs_access,
|
||||
validate_genai_access,
|
||||
validate_vector_search_access,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"filter_search_results",
|
||||
"format_search_results",
|
||||
"generate_query_embedding",
|
||||
"validate_gcs_access",
|
||||
"validate_genai_access",
|
||||
"validate_vector_search_access",
|
||||
]
|
||||
101
src/knowledge_search_mcp/services/search.py
Normal file
101
src/knowledge_search_mcp/services/search.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Search helper functions."""
|
||||
|
||||
from google import genai
|
||||
from google.genai import types as genai_types
|
||||
|
||||
from knowledge_search_mcp.logging import log_structured_entry
|
||||
from knowledge_search_mcp.models import SearchResult
|
||||
|
||||
|
||||
async def generate_query_embedding(
    genai_client: genai.Client,
    embedding_model: str,
    query: str,
) -> tuple[list[float], str | None]:
    """Embed a search query with the configured GenAI model.

    Args:
        genai_client: Async-capable GenAI client used for the embedding call.
        embedding_model: Name of the embedding model to invoke.
        query: Raw user query text.

    Returns:
        Tuple of (embedding vector, error message). Error message is None on success.

    """
    # Guard clause: blank queries never reach the API.
    if not query or not query.strip():
        return ([], "Error: Query cannot be empty")

    log_structured_entry("Generating query embedding", "INFO")
    try:
        api_response = await genai_client.aio.models.embed_content(
            model=embedding_model,
            contents=query,
            config=genai_types.EmbedContentConfig(
                task_type="RETRIEVAL_QUERY",
            ),
        )
        embeddings = api_response.embeddings
        # A syntactically successful call can still come back empty.
        if not embeddings or not embeddings[0].values:
            return ([], "Error: Failed to generate embedding - empty response")
        return (embeddings[0].values, None)  # noqa: TRY300
    except Exception as exc:  # noqa: BLE001
        failure_kind = type(exc).__name__
        failure_text = str(exc)
        log_ctx = {
            "error": failure_text,
            "error_type": failure_kind,
            "query": query[:100],
        }

        # Rate-limit failures get a distinct, user-friendly message.
        if "429" in failure_text or "RESOURCE_EXHAUSTED" in failure_text:
            log_structured_entry(
                "Rate limit exceeded while generating embedding",
                "WARNING",
                log_ctx,
            )
            return ([], "Error: API rate limit exceeded. Please try again later.")
        log_structured_entry(
            "Failed to generate query embedding",
            "ERROR",
            log_ctx,
        )
        return ([], f"Error generating embedding: {failure_text}")
|
||||
|
||||
|
||||
def filter_search_results(
    results: list[SearchResult],
    min_similarity: float = 0.6,
    top_percent: float = 0.9,
) -> list[SearchResult]:
    """Filter search results by similarity thresholds.

    Args:
        results: Raw search results from vector search.
        min_similarity: Minimum similarity score (distance) to include.
        top_percent: Keep results within this percentage of the top score.

    Returns:
        Filtered list of search results.

    """
    if not results:
        return []

    # Relative cutoff: keep only hits scoring close to the best one.
    best_score = max(entry["distance"] for entry in results)
    relative_floor = best_score * top_percent

    kept: list[SearchResult] = []
    for entry in results:
        score = entry["distance"]
        if score > relative_floor and score > min_similarity:
            kept.append(entry)
    return kept
|
||||
|
||||
|
||||
def format_search_results(results: list[SearchResult]) -> str:
    """Format search results as XML-like documents.

    Args:
        results: List of search results to format.

    Returns:
        Formatted string with document tags.

    """
    if not results:
        return "No relevant documents found for your query."

    rendered: list[str] = []
    # Documents are numbered 1..N in ranking order.
    for position, hit in enumerate(results, start=1):
        rendered.append(
            f"<document {position} name={hit['id']}>\n"
            f"{hit['content']}\n"
            f"</document {position}>"
        )
    return "\n".join(rendered)
|
||||
97
src/knowledge_search_mcp/services/semantic_cache.py
Normal file
97
src/knowledge_search_mcp/services/semantic_cache.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# ruff: noqa: INP001
|
||||
"""Semantic cache backed by Redis for knowledge search results."""
|
||||
|
||||
from redisvl.extensions.cache.llm.semantic import SemanticCache
|
||||
from redisvl.utils.vectorize.custom import CustomVectorizer
|
||||
|
||||
from ..logging import log_structured_entry
|
||||
|
||||
|
||||
def _stub_embed(content: object) -> list[float]:
|
||||
"""Stub vectorizer so SemanticCache creates an index with the right dims.
|
||||
|
||||
Never called at runtime — we always pass pre-computed vectors to
|
||||
``acheck`` and ``astore``. Only invoked once by ``CustomVectorizer``
|
||||
at init time to discover the dimensionality.
|
||||
"""
|
||||
return [0.0] * _stub_embed.dims # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class KnowledgeSemanticCache:
    """Thin wrapper around RedisVL SemanticCache with FLAT indexing."""

    def __init__(
        self,
        redis_url: str,
        name: str = "knowledge_search_cache",
        vector_dims: int = 3072,
        distance_threshold: float = 0.12,
        ttl: int | None = 3600,
    ) -> None:
        """Create the cache and its backing Redis index.

        Args:
            redis_url: Connection URL for the Redis instance.
            name: Index/cache name in Redis.
            vector_dims: Dimensionality of stored embedding vectors.
            distance_threshold: Max vector distance counted as a cache hit.
            ttl: Entry time-to-live in seconds, or None for no expiry.

        """

        # Closure instead of mutating a module-level function attribute:
        # each instance carries its own dimensionality, so two caches with
        # different dims can coexist without clobbering shared state.
        def _zero_vector(_content: object) -> list[float]:
            # Only invoked once by CustomVectorizer at init time to discover
            # the output dimensionality; real calls always pass pre-computed
            # vectors to ``acheck``/``astore``.
            return [0.0] * vector_dims

        vectorizer = CustomVectorizer(embed=_zero_vector)

        self._cache = SemanticCache(
            name=name,
            distance_threshold=distance_threshold,
            ttl=ttl,
            redis_url=redis_url,
            vectorizer=vectorizer,
            overwrite=False,
        )
        self._name = name

    async def check(
        self,
        embedding: list[float],
    ) -> str | None:
        """Return cached response for a semantically similar query, or None.

        Cache failures are logged and treated as a miss so search keeps
        working when Redis is unavailable.
        """
        try:
            results = await self._cache.acheck(
                vector=embedding,
                num_results=1,
                return_fields=["response", "prompt", "vector_distance"],
            )
        except Exception as e:  # noqa: BLE001
            log_structured_entry(
                "Semantic cache check failed, skipping cache",
                "WARNING",
                {"error": str(e), "error_type": type(e).__name__},
            )
            return None

        if not results:
            return None

        hit = results[0]
        log_structured_entry(
            "Semantic cache hit",
            "INFO",
            {
                "vector_distance": hit.get("vector_distance"),
                "original_prompt": hit.get("prompt", "")[:100],
            },
        )
        return hit.get("response")

    async def store(
        self,
        query: str,
        response: str,
        embedding: list[float],
        metadata: dict | None = None,
    ) -> None:
        """Store a query/response pair in the cache.

        Best effort: store failures are logged as warnings and swallowed so
        a Redis outage never breaks the search path.
        """
        try:
            await self._cache.astore(
                prompt=query,
                response=response,
                vector=embedding,
                metadata=metadata,
            )
        except Exception as e:  # noqa: BLE001
            log_structured_entry(
                "Semantic cache store failed",
                "WARNING",
                {"error": str(e), "error_type": type(e).__name__},
            )
|
||||
214
src/knowledge_search_mcp/services/validation.py
Normal file
214
src/knowledge_search_mcp/services/validation.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""Validation functions for Google Cloud services."""
|
||||
|
||||
from gcloud.aio.auth import Token
|
||||
from google import genai
|
||||
from google.genai import types as genai_types
|
||||
|
||||
from knowledge_search_mcp.clients.vector_search import GoogleCloudVectorSearch
|
||||
from knowledge_search_mcp.config import Settings
|
||||
from knowledge_search_mcp.logging import log_structured_entry
|
||||
|
||||
# HTTP status codes
|
||||
HTTP_FORBIDDEN = 403
|
||||
HTTP_NOT_FOUND = 404
|
||||
|
||||
|
||||
async def validate_genai_access(
    genai_client: genai.Client, cfg: Settings
) -> str | None:
    """Validate GenAI embedding access.

    Issues a tiny probe embedding ("test") to confirm the configured model
    is reachable with the current credentials.

    Returns:
        Error message if validation fails, None if successful.

    """
    log_structured_entry("Validating GenAI embedding access", "INFO")
    try:
        probe = await genai_client.aio.models.embed_content(
            model=cfg.embedding_model,
            contents="test",
            config=genai_types.EmbedContentConfig(
                task_type="RETRIEVAL_QUERY",
            ),
        )
        # Guard clause: a call that returns no embeddings is a soft failure.
        if not probe or not probe.embeddings:
            msg = "Embedding validation returned empty response"
            log_structured_entry(msg, "WARNING")
            return msg
        values = probe.embeddings[0].values
        dimension = len(values) if values else 0
        log_structured_entry(
            "GenAI embedding validation successful",
            "INFO",
            {"embedding_dimension": dimension},
        )
        return None  # noqa: TRY300
    except Exception as e:  # noqa: BLE001
        log_structured_entry(
            (
                "Failed to validate GenAI embedding access - "
                "service may not work correctly"
            ),
            "WARNING",
            {"error": str(e), "error_type": type(e).__name__},
        )
        return f"GenAI: {e!s}"
|
||||
|
||||
|
||||
async def validate_gcs_access(vs: GoogleCloudVectorSearch, cfg: Settings) -> str | None:
    """Validate GCS bucket access.

    Performs a minimal authenticated list request (``maxResults=1``) against
    the configured bucket to prove credentials and permissions work.

    Args:
        vs: Vector search client whose storage aiohttp session is reused.
        cfg: Application settings providing the bucket name.

    Returns:
        Error message if validation fails, None if successful.

    """
    log_structured_entry("Validating GCS bucket access", "INFO", {"bucket": cfg.bucket})
    try:
        # Reuse the storage client's aiohttp session and mint a fresh token.
        session = vs.storage._get_aio_session()  # noqa: SLF001
        token_obj = Token(
            session=session,
            scopes=["https://www.googleapis.com/auth/cloud-platform"],
        )
        access_token = await token_obj.get()
        headers = {"Authorization": f"Bearer {access_token}"}

        # Cheapest possible probe: list at most one object in the bucket.
        async with session.get(
            f"https://storage.googleapis.com/storage/v1/b/{cfg.bucket}/o?maxResults=1",
            headers=headers,
        ) as response:
            if response.status == HTTP_FORBIDDEN:
                msg = f"Access denied to bucket '{cfg.bucket}'. Check permissions."
                log_structured_entry(
                    (
                        "GCS bucket validation failed - access denied - "
                        "service may not work correctly"
                    ),
                    "WARNING",
                    {"bucket": cfg.bucket, "status": response.status},
                )
                return msg
            if response.status == HTTP_NOT_FOUND:
                msg = f"Bucket '{cfg.bucket}' not found. Check bucket name and project."
                log_structured_entry(
                    (
                        "GCS bucket validation failed - not found - "
                        "service may not work correctly"
                    ),
                    "WARNING",
                    {"bucket": cfg.bucket, "status": response.status},
                )
                return msg
            if not response.ok:
                # Any other non-2xx: include the body in the log for debugging.
                body = await response.text()
                msg = f"Failed to access bucket '{cfg.bucket}': {response.status}"
                log_structured_entry(
                    "GCS bucket validation failed - service may not work correctly",
                    "WARNING",
                    {"bucket": cfg.bucket, "status": response.status, "response": body},
                )
                return msg
            log_structured_entry(
                "GCS bucket validation successful", "INFO", {"bucket": cfg.bucket}
            )
            return None
    except Exception as e:  # noqa: BLE001
        # Network/auth failures are downgraded to a warning string; the
        # caller decides whether a failed validation is fatal.
        log_structured_entry(
            "Failed to validate GCS bucket access - service may not work correctly",
            "WARNING",
            {"error": str(e), "error_type": type(e).__name__, "bucket": cfg.bucket},
        )
        return f"GCS: {e!s}"
|
||||
|
||||
|
||||
async def validate_vector_search_access(
    vs: GoogleCloudVectorSearch, cfg: Settings
) -> str | None:
    """Validate vector search endpoint access.

    Fetches the index endpoint resource via the Vertex AI REST API to prove
    credentials and the configured endpoint name are both correct.

    Args:
        vs: Vector search client providing auth headers and the HTTP session.
        cfg: Application settings providing location and endpoint name.

    Returns:
        Error message if validation fails, None if successful.

    """
    log_structured_entry(
        "Validating vector search endpoint access",
        "INFO",
        {"endpoint_name": cfg.endpoint_name},
    )
    try:
        # Reuse the client's auth headers and session rather than building new ones.
        headers = await vs._async_get_auth_headers()  # noqa: SLF001
        session = vs._get_aio_session()  # noqa: SLF001
        endpoint_url = (
            f"https://{cfg.location}-aiplatform.googleapis.com/v1/{cfg.endpoint_name}"
        )

        async with session.get(endpoint_url, headers=headers) as response:
            if response.status == HTTP_FORBIDDEN:
                msg = (
                    f"Access denied to endpoint '{cfg.endpoint_name}'. "
                    "Check permissions."
                )
                log_structured_entry(
                    (
                        "Vector search endpoint validation failed - "
                        "access denied - service may not work correctly"
                    ),
                    "WARNING",
                    {"endpoint": cfg.endpoint_name, "status": response.status},
                )
                return msg
            if response.status == HTTP_NOT_FOUND:
                msg = (
                    f"Endpoint '{cfg.endpoint_name}' not found. "
                    "Check endpoint name and project."
                )
                log_structured_entry(
                    (
                        "Vector search endpoint validation failed - "
                        "not found - service may not work correctly"
                    ),
                    "WARNING",
                    {"endpoint": cfg.endpoint_name, "status": response.status},
                )
                return msg
            if not response.ok:
                # Unexpected status: capture the response body for troubleshooting.
                body = await response.text()
                msg = (
                    f"Failed to access endpoint '{cfg.endpoint_name}': "
                    f"{response.status}"
                )
                log_structured_entry(
                    (
                        "Vector search endpoint validation failed - "
                        "service may not work correctly"
                    ),
                    "WARNING",
                    {
                        "endpoint": cfg.endpoint_name,
                        "status": response.status,
                        "response": body,
                    },
                )
                return msg
            log_structured_entry(
                "Vector search endpoint validation successful",
                "INFO",
                {"endpoint": cfg.endpoint_name},
            )
            return None
    except Exception as e:  # noqa: BLE001
        # Downgraded to a warning string; callers aggregate validation results.
        log_structured_entry(
            (
                "Failed to validate vector search endpoint access - "
                "service may not work correctly"
            ),
            "WARNING",
            {
                "error": str(e),
                "error_type": type(e).__name__,
                "endpoint": cfg.endpoint_name,
            },
        )
        return f"Vector Search: {e!s}"
|
||||
5
src/knowledge_search_mcp/utils/__init__.py
Normal file
5
src/knowledge_search_mcp/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Utility modules for knowledge search MCP server."""
|
||||
|
||||
from .cache import LRUCache
|
||||
|
||||
__all__ = ["LRUCache"]
|
||||
32
src/knowledge_search_mcp/utils/cache.py
Normal file
32
src/knowledge_search_mcp/utils/cache.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""LRU cache implementation."""
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class LRUCache:
|
||||
"""Simple LRU cache with size limit."""
|
||||
|
||||
def __init__(self, max_size: int = 100) -> None:
|
||||
"""Initialize cache with maximum size."""
|
||||
self.cache: OrderedDict[str, bytes] = OrderedDict()
|
||||
self.max_size = max_size
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
"""Get item from cache, returning None if not found."""
|
||||
if key not in self.cache:
|
||||
return None
|
||||
# Move to end to mark as recently used
|
||||
self.cache.move_to_end(key)
|
||||
return self.cache[key]
|
||||
|
||||
def put(self, key: str, value: bytes) -> None:
|
||||
"""Put item in cache, evicting oldest if at capacity."""
|
||||
if key in self.cache:
|
||||
self.cache.move_to_end(key)
|
||||
self.cache[key] = value
|
||||
if len(self.cache) > self.max_size:
|
||||
self.cache.popitem(last=False)
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
"""Check if key exists in cache."""
|
||||
return key in self.cache
|
||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for knowledge-search-mcp."""
|
||||
36
tests/conftest.py
Normal file
36
tests/conftest.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Pytest configuration and shared fixtures."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def mock_env_vars(monkeypatch):
    """Set required environment variables for testing."""
    # Applied to every test automatically (autouse) so Settings always loads.
    monkeypatch.setenv("PROJECT_ID", "test-project")
    monkeypatch.setenv("LOCATION", "us-central1")
    monkeypatch.setenv("BUCKET", "test-bucket")
    monkeypatch.setenv("INDEX_NAME", "test-index")
    monkeypatch.setenv("DEPLOYED_INDEX_ID", "test-deployed-index")
    monkeypatch.setenv(
        "ENDPOINT_NAME", "projects/test/locations/us-central1/indexEndpoints/test"
    )
    monkeypatch.setenv(
        "ENDPOINT_DOMAIN", "test.us-central1-aiplatform.googleapis.com"
    )
|
||||
|
||||
|
||||
@pytest.fixture
def mock_gcs_storage():
    """Mock Google Cloud Storage client."""
    return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_vector_search():
    """Mock vector search client."""
    return MagicMock()
|
||||
56
tests/test_config.py
Normal file
56
tests/test_config.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Tests for configuration management."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from knowledge_search_mcp.config import Settings
|
||||
|
||||
|
||||
def test_settings_from_env():
    """Test that Settings can be loaded from environment variables."""
    # Environment is supplied by the autouse fixture in conftest.py.
    settings = Settings.model_validate({})

    expected = {
        "project_id": "test-project",
        "location": "us-central1",
        "bucket": "test-bucket",
        "index_name": "test-index",
        "deployed_index_id": "test-deployed-index",
    }
    for field_name, field_value in expected.items():
        assert getattr(settings, field_name) == field_value
|
||||
|
||||
|
||||
def test_settings_defaults():
    """Test that Settings has correct default values."""
    settings = Settings.model_validate({})

    expected_defaults = {
        "embedding_model": "gemini-embedding-001",
        "search_limit": 10,
        "log_name": "va_agent_evaluation_logs",
        "log_level": "INFO",
    }
    for field_name, default in expected_defaults.items():
        assert getattr(settings, field_name) == default
|
||||
|
||||
|
||||
def test_settings_custom_values(monkeypatch):
    """Test that Settings can be customized via environment."""
    overrides = {
        "EMBEDDING_MODEL": "custom-embedding-model",
        "SEARCH_LIMIT": "20",
        "LOG_LEVEL": "DEBUG",
    }
    for env_name, env_value in overrides.items():
        monkeypatch.setenv(env_name, env_value)

    settings = Settings.model_validate({})

    # SEARCH_LIMIT is coerced from str to int by pydantic.
    assert settings.embedding_model == "custom-embedding-model"
    assert settings.search_limit == 20
    assert settings.log_level == "DEBUG"
|
||||
|
||||
|
||||
def test_settings_validation_error(monkeypatch):
    """Test that Settings raises ValidationError when required fields are missing.

    The autouse ``mock_env_vars`` fixture sets every required variable, so we
    delete them here to force validation to fail. (The previous version of
    this test never exercised the error path: ``required_vars`` was unused
    and it only re-asserted the happy path.)
    """
    required_vars = [
        "PROJECT_ID", "LOCATION", "BUCKET", "INDEX_NAME",
        "DEPLOYED_INDEX_ID", "ENDPOINT_NAME", "ENDPOINT_DOMAIN",
    ]
    for var in required_vars:
        monkeypatch.delenv(var, raising=False)

    with pytest.raises(ValidationError):
        Settings.model_validate({})
|
||||
411
tests/test_main_tool.py
Normal file
411
tests/test_main_tool.py
Normal file
@@ -0,0 +1,411 @@
|
||||
"""Tests for the main knowledge_search tool."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from knowledge_search_mcp.__main__ import knowledge_search
|
||||
from knowledge_search_mcp.models import AppContext, SourceNamespace, SearchResult
|
||||
|
||||
|
||||
class TestKnowledgeSearch:
|
||||
"""Tests for knowledge_search tool function."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_context(self):
|
||||
"""Create a mock AppContext."""
|
||||
app = MagicMock(spec=AppContext)
|
||||
|
||||
# Mock genai_client
|
||||
app.genai_client = MagicMock()
|
||||
|
||||
# Mock vector_search
|
||||
app.vector_search = MagicMock()
|
||||
app.vector_search.async_run_query = AsyncMock()
|
||||
|
||||
# Mock settings
|
||||
app.settings = MagicMock()
|
||||
app.settings.embedding_model = "models/text-embedding-004"
|
||||
app.settings.deployed_index_id = "test-deployed-index"
|
||||
app.settings.search_limit = 10
|
||||
|
||||
# No semantic cache by default
|
||||
app.semantic_cache = None
|
||||
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context(self, mock_app_context):
|
||||
"""Create a mock MCP Context."""
|
||||
ctx = MagicMock()
|
||||
ctx.request_context = MagicMock()
|
||||
ctx.request_context.lifespan_context = mock_app_context
|
||||
return ctx
|
||||
|
||||
@pytest.fixture
|
||||
def sample_embedding(self):
|
||||
"""Create a sample embedding vector."""
|
||||
return [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
|
||||
@pytest.fixture
|
||||
def sample_search_results(self):
|
||||
"""Create sample search results."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1.txt", "distance": 0.95, "content": "First document content"},
|
||||
{"id": "doc2.txt", "distance": 0.85, "content": "Second document content"},
|
||||
{"id": "doc3.txt", "distance": 0.75, "content": "Third document content"},
|
||||
]
|
||||
return results
|
||||
|
||||
    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    @patch('knowledge_search_mcp.__main__.filter_search_results')
    @patch('knowledge_search_mcp.__main__.format_search_results')
    async def test_successful_search(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
        sample_search_results
    ):
        """Test successful search workflow."""
        # NOTE: @patch injects mocks bottom-up, so the parameter order
        # (format, filter, generate) is the reverse of the decorator list.
        # Setup mocks
        mock_generate.return_value = (sample_embedding, None)
        mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
        mock_filter.return_value = sample_search_results
        mock_format.return_value = "<document 1 name=doc1.txt>\nFirst document content\n</document 1>"

        # Execute
        result = await knowledge_search("What is financial education?", mock_context)

        # Assert
        assert result == "<document 1 name=doc1.txt>\nFirst document content\n</document 1>"
        mock_generate.assert_called_once()
        # deployed_index_id and limit come straight from the mocked settings.
        mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_called_once_with(
            deployed_index_id="test-deployed-index",
            query=sample_embedding,
            limit=10,
            source=None,
        )
        mock_filter.assert_called_once_with(sample_search_results)
        mock_format.assert_called_once_with(sample_search_results)
|
||||
|
||||
    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    async def test_embedding_generation_error(self, mock_generate, mock_context):
        """Test handling of embedding generation error."""
        # Setup mock to return error
        mock_generate.return_value = ([], "Error: API rate limit exceeded. Please try again later.")

        # Execute
        result = await knowledge_search("test query", mock_context)

        # Assert
        # The embedding error message is surfaced verbatim to the caller.
        assert result == "Error: API rate limit exceeded. Please try again later."
        mock_generate.assert_called_once()
        # Vector search should not be called
        mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_not_called()
|
||||
|
||||
    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    async def test_empty_query_error(self, mock_generate, mock_context):
        """Test handling of empty query."""
        # Setup mock to return error for empty query
        mock_generate.return_value = ([], "Error: Query cannot be empty")

        # Execute
        result = await knowledge_search("", mock_context)

        # Assert
        # Empty-query rejection happens inside generate_query_embedding;
        # the tool just relays the error string.
        assert result == "Error: Query cannot be empty"
        mock_generate.assert_called_once()
|
||||
|
||||
    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    async def test_vector_search_error(self, mock_generate, mock_context, sample_embedding):
        """Test handling of vector search error."""
        # Setup mocks
        mock_generate.return_value = (sample_embedding, None)
        mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = Exception(
            "Vector search service unavailable"
        )

        # Execute
        result = await knowledge_search("test query", mock_context)

        # Assert
        # Search failures are converted to an error string, never raised.
        assert "Error performing vector search:" in result
        assert "Vector search service unavailable" in result
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_empty_search_results(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding
|
||||
):
|
||||
"""Test handling of empty search results."""
|
||||
# Setup mocks
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = []
|
||||
mock_filter.return_value = []
|
||||
mock_format.return_value = "No relevant documents found for your query."
|
||||
|
||||
# Execute
|
||||
result = await knowledge_search("obscure query", mock_context)
|
||||
|
||||
# Assert
|
||||
assert result == "No relevant documents found for your query."
|
||||
mock_format.assert_called_once_with([])
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_filtered_results_empty(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test when filtering removes all results."""
|
||||
# Setup mocks - results exist but get filtered out
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = [] # All filtered out
|
||||
mock_format.return_value = "No relevant documents found for your query."
|
||||
|
||||
# Execute
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
# Assert
|
||||
assert result == "No relevant documents found for your query."
|
||||
mock_filter.assert_called_once_with(sample_search_results)
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_source_filter_parameter(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test that source filter is passed correctly to vector search."""
|
||||
# Setup mocks
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = sample_search_results
|
||||
mock_format.return_value = "formatted results"
|
||||
|
||||
# Execute with source filter
|
||||
source_filter = SourceNamespace.EDUCACION_FINANCIERA
|
||||
result = await knowledge_search("test query", mock_context, source=source_filter)
|
||||
|
||||
# Assert
|
||||
assert result == "formatted results"
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_called_once_with(
|
||||
deployed_index_id="test-deployed-index",
|
||||
query=sample_embedding,
|
||||
limit=10,
|
||||
source=source_filter,
|
||||
)
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_all_source_filters(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test all available source filter values."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = sample_search_results
|
||||
mock_format.return_value = "results"
|
||||
|
||||
# Test each source filter
|
||||
for source in SourceNamespace:
|
||||
result = await knowledge_search("test query", mock_context, source=source)
|
||||
assert result == "results"
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
async def test_vector_search_timeout(self, mock_generate, mock_context, sample_embedding):
|
||||
"""Test handling of vector search timeout."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = TimeoutError(
|
||||
"Request timed out"
|
||||
)
|
||||
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert "Error performing vector search:" in result
|
||||
assert "Request timed out" in result
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
async def test_vector_search_connection_error(self, mock_generate, mock_context, sample_embedding):
|
||||
"""Test handling of vector search connection error."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = ConnectionError(
|
||||
"Connection refused"
|
||||
)
|
||||
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert "Error performing vector search:" in result
|
||||
assert "Connection refused" in result
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
async def test_format_results_unexpected_error(
|
||||
self,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test handling of unexpected error in format_search_results."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = sample_search_results
|
||||
|
||||
# Mock format_search_results to raise an error
|
||||
with patch('knowledge_search_mcp.__main__.format_search_results', side_effect=ValueError("Format error")):
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert "Unexpected error during search:" in result
|
||||
assert "Format error" in result
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
async def test_filter_results_unexpected_error(self, mock_generate, mock_context, sample_embedding):
|
||||
"""Test handling of unexpected error in filter_search_results."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = [
|
||||
{"id": "doc1", "distance": 0.9, "content": "test"}
|
||||
]
|
||||
|
||||
# Mock filter_search_results to raise an error
|
||||
with patch('knowledge_search_mcp.__main__.filter_search_results', side_effect=TypeError("Filter error")):
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert "Unexpected error during search:" in result
|
||||
assert "Filter error" in result
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_long_query_truncation_in_logs(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test that long queries are handled correctly."""
|
||||
# Setup mocks
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = sample_search_results
|
||||
mock_format.return_value = "results"
|
||||
|
||||
# Execute with very long query
|
||||
long_query = "a" * 500
|
||||
result = await knowledge_search(long_query, mock_context)
|
||||
|
||||
# Assert - should succeed
|
||||
assert result == "results"
|
||||
# Verify generate_query_embedding was called with full query
|
||||
assert mock_generate.call_args[0][2] == long_query
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_multiple_results_returned(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding
|
||||
):
|
||||
"""Test handling of multiple search results."""
|
||||
# Create larger result set
|
||||
large_results: list[SearchResult] = [
|
||||
{"id": f"doc{i}.txt", "distance": 0.9 - (i * 0.05), "content": f"Content {i}"}
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = large_results
|
||||
mock_filter.return_value = large_results[:5] # Filter to top 5
|
||||
mock_format.return_value = "formatted 5 results"
|
||||
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert result == "formatted 5 results"
|
||||
mock_filter.assert_called_once_with(large_results)
|
||||
mock_format.assert_called_once_with(large_results[:5])
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||
async def test_settings_values_used_correctly(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_search_results
|
||||
):
|
||||
"""Test that settings values are used correctly."""
|
||||
# Customize settings
|
||||
mock_context.request_context.lifespan_context.settings.embedding_model = "custom-model"
|
||||
mock_context.request_context.lifespan_context.settings.deployed_index_id = "custom-index"
|
||||
mock_context.request_context.lifespan_context.settings.search_limit = 20
|
||||
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||
mock_filter.return_value = sample_search_results
|
||||
mock_format.return_value = "results"
|
||||
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
# Verify embedding model
|
||||
assert mock_generate.call_args[0][1] == "custom-model"
|
||||
|
||||
# Verify vector search parameters
|
||||
call_kwargs = mock_context.request_context.lifespan_context.vector_search.async_run_query.call_args.kwargs
|
||||
assert call_kwargs["deployed_index_id"] == "custom-index"
|
||||
assert call_kwargs["limit"] == 20
|
||||
|
||||
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||
async def test_graceful_degradation_on_partial_failure(
|
||||
self, mock_generate, mock_context, sample_embedding
|
||||
):
|
||||
"""Test that errors are caught and returned as strings, not raised."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = RuntimeError(
|
||||
"Critical failure"
|
||||
)
|
||||
|
||||
# Should not raise, should return error message
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "Error performing vector search:" in result
|
||||
assert "Critical failure" in result
|
||||
110
tests/test_search.py
Normal file
110
tests/test_search.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Tests for vector search functionality."""
|
||||
|
||||
import io
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from knowledge_search_mcp import (
|
||||
GoogleCloudFileStorage,
|
||||
GoogleCloudVectorSearch,
|
||||
LRUCache,
|
||||
SourceNamespace,
|
||||
)
|
||||
|
||||
|
||||
class TestGoogleCloudFileStorage:
|
||||
"""Tests for GoogleCloudFileStorage."""
|
||||
|
||||
def test_init(self):
|
||||
"""Test storage initialization."""
|
||||
storage = GoogleCloudFileStorage(bucket="test-bucket")
|
||||
assert storage.bucket_name == "test-bucket"
|
||||
assert isinstance(storage._cache, LRUCache)
|
||||
assert storage._cache.max_size == 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_hit(self):
|
||||
"""Test that cached files are returned without fetching."""
|
||||
storage = GoogleCloudFileStorage(bucket="test-bucket")
|
||||
test_content = b"cached content"
|
||||
storage._cache.put("test.md", test_content)
|
||||
|
||||
result = await storage.async_get_file_stream("test.md")
|
||||
|
||||
assert result.read() == test_content
|
||||
assert result.name == "test.md"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_miss(self):
|
||||
"""Test that uncached files are fetched from GCS."""
|
||||
storage = GoogleCloudFileStorage(bucket="test-bucket")
|
||||
test_content = b"fetched content"
|
||||
|
||||
# Mock the storage download
|
||||
with patch.object(storage, '_get_aio_storage') as mock_storage_getter:
|
||||
mock_storage = AsyncMock()
|
||||
mock_storage.download = AsyncMock(return_value=test_content)
|
||||
mock_storage_getter.return_value = mock_storage
|
||||
|
||||
result = await storage.async_get_file_stream("test.md")
|
||||
|
||||
assert result.read() == test_content
|
||||
assert storage._cache.get("test.md") == test_content
|
||||
|
||||
|
||||
class TestGoogleCloudVectorSearch:
|
||||
"""Tests for GoogleCloudVectorSearch."""
|
||||
|
||||
def test_init(self):
|
||||
"""Test vector search client initialization."""
|
||||
vs = GoogleCloudVectorSearch(
|
||||
project_id="test-project",
|
||||
location="us-central1",
|
||||
bucket="test-bucket",
|
||||
index_name="test-index",
|
||||
)
|
||||
|
||||
assert vs.project_id == "test-project"
|
||||
assert vs.location == "us-central1"
|
||||
assert vs.index_name == "test-index"
|
||||
|
||||
def test_configure_index_endpoint(self):
|
||||
"""Test endpoint configuration."""
|
||||
vs = GoogleCloudVectorSearch(
|
||||
project_id="test-project",
|
||||
location="us-central1",
|
||||
bucket="test-bucket",
|
||||
)
|
||||
|
||||
vs.configure_index_endpoint(
|
||||
name="test-endpoint",
|
||||
public_domain="test.domain.com",
|
||||
)
|
||||
|
||||
assert vs._endpoint_name == "test-endpoint"
|
||||
assert vs._endpoint_domain == "test.domain.com"
|
||||
|
||||
def test_configure_index_endpoint_validation(self):
|
||||
"""Test that endpoint configuration validates inputs."""
|
||||
vs = GoogleCloudVectorSearch(
|
||||
project_id="test-project",
|
||||
location="us-central1",
|
||||
bucket="test-bucket",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="endpoint name"):
|
||||
vs.configure_index_endpoint(name="", public_domain="test.com")
|
||||
|
||||
with pytest.raises(ValueError, match="endpoint domain"):
|
||||
vs.configure_index_endpoint(name="test", public_domain="")
|
||||
|
||||
|
||||
class TestSourceNamespace:
|
||||
"""Tests for SourceNamespace enum."""
|
||||
|
||||
def test_source_namespace_values(self):
|
||||
"""Test that SourceNamespace has expected values."""
|
||||
assert SourceNamespace.EDUCACION_FINANCIERA.value == "Educacion Financiera"
|
||||
assert SourceNamespace.PRODUCTOS_Y_SERVICIOS.value == "Productos y Servicios"
|
||||
assert SourceNamespace.FUNCIONALIDADES_APP_MOVIL.value == "Funcionalidades de la App Movil"
|
||||
381
tests/test_search_services.py
Normal file
381
tests/test_search_services.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""Tests for search service functions."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from knowledge_search_mcp.services.search import (
|
||||
generate_query_embedding,
|
||||
filter_search_results,
|
||||
format_search_results,
|
||||
)
|
||||
from knowledge_search_mcp.models import SearchResult
|
||||
|
||||
|
||||
class TestGenerateQueryEmbedding:
|
||||
"""Tests for generate_query_embedding function."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_genai_client(self):
|
||||
"""Create a mock genai client."""
|
||||
client = MagicMock()
|
||||
client.aio = MagicMock()
|
||||
client.aio.models = MagicMock()
|
||||
return client
|
||||
|
||||
async def test_successful_embedding_generation(self, mock_genai_client):
|
||||
"""Test successful embedding generation."""
|
||||
# Setup mock response
|
||||
mock_response = MagicMock()
|
||||
mock_embedding = MagicMock()
|
||||
mock_embedding.values = [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
mock_response.embeddings = [mock_embedding]
|
||||
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Execute
|
||||
embedding, error = await generate_query_embedding(
|
||||
mock_genai_client,
|
||||
"models/text-embedding-004",
|
||||
"What is financial education?"
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert error is None
|
||||
assert embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
mock_genai_client.aio.models.embed_content.assert_called_once()
|
||||
call_kwargs = mock_genai_client.aio.models.embed_content.call_args.kwargs
|
||||
assert call_kwargs["model"] == "models/text-embedding-004"
|
||||
assert call_kwargs["contents"] == "What is financial education?"
|
||||
assert call_kwargs["config"].task_type == "RETRIEVAL_QUERY"
|
||||
|
||||
async def test_empty_query_string(self, mock_genai_client):
|
||||
"""Test handling of empty query string."""
|
||||
embedding, error = await generate_query_embedding(
|
||||
mock_genai_client,
|
||||
"models/text-embedding-004",
|
||||
""
|
||||
)
|
||||
|
||||
assert embedding == []
|
||||
assert error == "Error: Query cannot be empty"
|
||||
mock_genai_client.aio.models.embed_content.assert_not_called()
|
||||
|
||||
async def test_whitespace_only_query(self, mock_genai_client):
|
||||
"""Test handling of whitespace-only query."""
|
||||
embedding, error = await generate_query_embedding(
|
||||
mock_genai_client,
|
||||
"models/text-embedding-004",
|
||||
" \t\n "
|
||||
)
|
||||
|
||||
assert embedding == []
|
||||
assert error == "Error: Query cannot be empty"
|
||||
mock_genai_client.aio.models.embed_content.assert_not_called()
|
||||
|
||||
async def test_rate_limit_error_429(self, mock_genai_client):
|
||||
"""Test handling of 429 rate limit error."""
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||
side_effect=Exception("429 Too Many Requests")
|
||||
)
|
||||
|
||||
embedding, error = await generate_query_embedding(
|
||||
mock_genai_client,
|
||||
"models/text-embedding-004",
|
||||
"test query"
|
||||
)
|
||||
|
||||
assert embedding == []
|
||||
assert error == "Error: API rate limit exceeded. Please try again later."
|
||||
|
||||
async def test_rate_limit_error_resource_exhausted(self, mock_genai_client):
|
||||
"""Test handling of RESOURCE_EXHAUSTED error."""
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||
side_effect=Exception("RESOURCE_EXHAUSTED: Quota exceeded")
|
||||
)
|
||||
|
||||
embedding, error = await generate_query_embedding(
|
||||
mock_genai_client,
|
||||
"models/text-embedding-004",
|
||||
"test query"
|
||||
)
|
||||
|
||||
assert embedding == []
|
||||
assert error == "Error: API rate limit exceeded. Please try again later."
|
||||
|
||||
async def test_generic_api_error(self, mock_genai_client):
|
||||
"""Test handling of generic API error."""
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||
side_effect=ValueError("Invalid model name")
|
||||
)
|
||||
|
||||
embedding, error = await generate_query_embedding(
|
||||
mock_genai_client,
|
||||
"invalid-model",
|
||||
"test query"
|
||||
)
|
||||
|
||||
assert embedding == []
|
||||
assert "Error generating embedding: Invalid model name" in error
|
||||
|
||||
async def test_network_error(self, mock_genai_client):
|
||||
"""Test handling of network error."""
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||
side_effect=ConnectionError("Network unreachable")
|
||||
)
|
||||
|
||||
embedding, error = await generate_query_embedding(
|
||||
mock_genai_client,
|
||||
"models/text-embedding-004",
|
||||
"test query"
|
||||
)
|
||||
|
||||
assert embedding == []
|
||||
assert "Error generating embedding: Network unreachable" in error
|
||||
|
||||
async def test_long_query_truncation_in_logging(self, mock_genai_client):
|
||||
"""Test that long queries are truncated in error logging."""
|
||||
long_query = "a" * 200
|
||||
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||
side_effect=Exception("API error")
|
||||
)
|
||||
|
||||
embedding, error = await generate_query_embedding(
|
||||
mock_genai_client,
|
||||
"models/text-embedding-004",
|
||||
long_query
|
||||
)
|
||||
|
||||
assert embedding == []
|
||||
assert error is not None
|
||||
|
||||
|
||||
class TestFilterSearchResults:
|
||||
"""Tests for filter_search_results function."""
|
||||
|
||||
def test_empty_results(self):
|
||||
"""Test filtering empty results list."""
|
||||
filtered = filter_search_results([])
|
||||
assert filtered == []
|
||||
|
||||
def test_single_result_above_thresholds(self):
|
||||
"""Test single result above both thresholds."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1", "distance": 0.85, "content": "test content"}
|
||||
]
|
||||
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]["id"] == "doc1"
|
||||
|
||||
def test_single_result_below_min_similarity(self):
|
||||
"""Test single result below minimum similarity threshold."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1", "distance": 0.5, "content": "test content"}
|
||||
]
|
||||
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
|
||||
assert filtered == []
|
||||
|
||||
def test_multiple_results_all_above_thresholds(self):
|
||||
"""Test multiple results all above thresholds."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1", "distance": 0.95, "content": "content 1"},
|
||||
{"id": "doc2", "distance": 0.90, "content": "content 2"},
|
||||
{"id": "doc3", "distance": 0.85, "content": "content 3"},
|
||||
]
|
||||
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.8)
|
||||
# max_sim = 0.95, cutoff = 0.95 * 0.8 = 0.76
|
||||
# Results with distance > 0.76 and > 0.6: all three
|
||||
assert len(filtered) == 3
|
||||
|
||||
def test_top_percent_filtering(self):
|
||||
"""Test filtering by top_percent threshold."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1", "distance": 1.0, "content": "content 1"},
|
||||
{"id": "doc2", "distance": 0.95, "content": "content 2"},
|
||||
{"id": "doc3", "distance": 0.85, "content": "content 3"},
|
||||
{"id": "doc4", "distance": 0.70, "content": "content 4"},
|
||||
]
|
||||
# max_sim = 1.0, cutoff = 1.0 * 0.9 = 0.9
|
||||
# Results with distance > 0.9: doc1 (1.0), doc2 (0.95)
|
||||
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
|
||||
assert len(filtered) == 2
|
||||
assert filtered[0]["id"] == "doc1"
|
||||
assert filtered[1]["id"] == "doc2"
|
||||
|
||||
def test_min_similarity_filtering(self):
|
||||
"""Test filtering by minimum similarity threshold."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1", "distance": 0.95, "content": "content 1"},
|
||||
{"id": "doc2", "distance": 0.75, "content": "content 2"},
|
||||
{"id": "doc3", "distance": 0.55, "content": "content 3"},
|
||||
]
|
||||
# max_sim = 0.95, cutoff = 0.95 * 0.9 = 0.855
|
||||
# doc1 > 0.855 and > 0.7: included
|
||||
# doc2 < 0.855: excluded by top_percent
|
||||
# doc3 < 0.7: excluded by min_similarity
|
||||
filtered = filter_search_results(results, min_similarity=0.7, top_percent=0.9)
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]["id"] == "doc1"
|
||||
|
||||
def test_default_parameters(self):
|
||||
"""Test filtering with default parameters."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1", "distance": 0.95, "content": "content 1"},
|
||||
{"id": "doc2", "distance": 0.85, "content": "content 2"},
|
||||
{"id": "doc3", "distance": 0.50, "content": "content 3"},
|
||||
]
|
||||
# Default: min_similarity=0.6, top_percent=0.9
|
||||
# max_sim = 0.95, cutoff = 0.95 * 0.9 = 0.855
|
||||
# doc1 > 0.855 and > 0.6: included
|
||||
# doc2 < 0.855: excluded
|
||||
# doc3 < 0.6: excluded
|
||||
filtered = filter_search_results(results)
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]["id"] == "doc1"
|
||||
|
||||
def test_all_results_filtered_out(self):
|
||||
"""Test when all results are filtered out."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1", "distance": 0.55, "content": "content 1"},
|
||||
{"id": "doc2", "distance": 0.45, "content": "content 2"},
|
||||
{"id": "doc3", "distance": 0.35, "content": "content 3"},
|
||||
]
|
||||
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
|
||||
assert filtered == []
|
||||
|
||||
def test_exact_threshold_boundaries(self):
|
||||
"""Test behavior at exact threshold boundaries."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1", "distance": 0.9, "content": "content 1"},
|
||||
{"id": "doc2", "distance": 0.6, "content": "content 2"},
|
||||
]
|
||||
# max_sim = 0.9, cutoff = 0.9 * 0.9 = 0.81
|
||||
# doc1: 0.9 > 0.81 and 0.9 > 0.6: included
|
||||
# doc2: 0.6 < 0.81: excluded
|
||||
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]["id"] == "doc1"
|
||||
|
||||
def test_identical_distances(self):
|
||||
"""Test filtering with identical distance values."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1", "distance": 0.8, "content": "content 1"},
|
||||
{"id": "doc2", "distance": 0.8, "content": "content 2"},
|
||||
{"id": "doc3", "distance": 0.8, "content": "content 3"},
|
||||
]
|
||||
# max_sim = 0.8, cutoff = 0.8 * 0.9 = 0.72
|
||||
# All have distance 0.8 > 0.72 and > 0.6: all included
|
||||
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
|
||||
assert len(filtered) == 3
|
||||
|
||||
|
||||
class TestFormatSearchResults:
|
||||
"""Tests for format_search_results function."""
|
||||
|
||||
def test_empty_results(self):
|
||||
"""Test formatting empty results list."""
|
||||
formatted = format_search_results([])
|
||||
assert formatted == "No relevant documents found for your query."
|
||||
|
||||
def test_single_result(self):
|
||||
"""Test formatting single result."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1.txt", "distance": 0.95, "content": "This is the content."}
|
||||
]
|
||||
formatted = format_search_results(results)
|
||||
expected = "<document 1 name=doc1.txt>\nThis is the content.\n</document 1>"
|
||||
assert formatted == expected
|
||||
|
||||
def test_multiple_results(self):
|
||||
"""Test formatting multiple results."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1.txt", "distance": 0.95, "content": "First document content."},
|
||||
{"id": "doc2.txt", "distance": 0.85, "content": "Second document content."},
|
||||
{"id": "doc3.txt", "distance": 0.75, "content": "Third document content."},
|
||||
]
|
||||
formatted = format_search_results(results)
|
||||
expected = (
|
||||
"<document 1 name=doc1.txt>\nFirst document content.\n</document 1>\n"
|
||||
"<document 2 name=doc2.txt>\nSecond document content.\n</document 2>\n"
|
||||
"<document 3 name=doc3.txt>\nThird document content.\n</document 3>"
|
||||
)
|
||||
assert formatted == expected
|
||||
|
||||
def test_multiline_content(self):
|
||||
"""Test formatting results with multiline content."""
|
||||
results: list[SearchResult] = [
|
||||
{
|
||||
"id": "doc1.txt",
|
||||
"distance": 0.95,
|
||||
"content": "Line 1\nLine 2\nLine 3"
|
||||
}
|
||||
]
|
||||
formatted = format_search_results(results)
|
||||
expected = "<document 1 name=doc1.txt>\nLine 1\nLine 2\nLine 3\n</document 1>"
|
||||
assert formatted == expected
|
||||
|
||||
def test_special_characters_in_content(self):
|
||||
"""Test formatting with special characters in content."""
|
||||
results: list[SearchResult] = [
|
||||
{
|
||||
"id": "doc1.txt",
|
||||
"distance": 0.95,
|
||||
"content": "Content with <special> & \"characters\""
|
||||
}
|
||||
]
|
||||
formatted = format_search_results(results)
|
||||
expected = '<document 1 name=doc1.txt>\nContent with <special> & "characters"\n</document 1>'
|
||||
assert formatted == expected
|
||||
|
||||
def test_special_characters_in_document_id(self):
|
||||
"""Test formatting with special characters in document ID."""
|
||||
results: list[SearchResult] = [
|
||||
{
|
||||
"id": "path/to/doc-name_v2.txt",
|
||||
"distance": 0.95,
|
||||
"content": "Some content"
|
||||
}
|
||||
]
|
||||
formatted = format_search_results(results)
|
||||
expected = "<document 1 name=path/to/doc-name_v2.txt>\nSome content\n</document 1>"
|
||||
assert formatted == expected
|
||||
|
||||
def test_empty_content(self):
|
||||
"""Test formatting result with empty content."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1.txt", "distance": 0.95, "content": ""}
|
||||
]
|
||||
formatted = format_search_results(results)
|
||||
expected = "<document 1 name=doc1.txt>\n\n</document 1>"
|
||||
assert formatted == expected
|
||||
|
||||
def test_document_numbering(self):
|
||||
"""Test that document numbering starts at 1 and increments correctly."""
|
||||
results: list[SearchResult] = [
|
||||
{"id": "a.txt", "distance": 0.9, "content": "A"},
|
||||
{"id": "b.txt", "distance": 0.8, "content": "B"},
|
||||
{"id": "c.txt", "distance": 0.7, "content": "C"},
|
||||
{"id": "d.txt", "distance": 0.6, "content": "D"},
|
||||
{"id": "e.txt", "distance": 0.5, "content": "E"},
|
||||
]
|
||||
formatted = format_search_results(results)
|
||||
|
||||
assert "<document 1 name=a.txt>" in formatted
|
||||
assert "</document 1>" in formatted
|
||||
assert "<document 2 name=b.txt>" in formatted
|
||||
assert "</document 2>" in formatted
|
||||
assert "<document 3 name=c.txt>" in formatted
|
||||
assert "</document 3>" in formatted
|
||||
assert "<document 4 name=d.txt>" in formatted
|
||||
assert "</document 4>" in formatted
|
||||
assert "<document 5 name=e.txt>" in formatted
|
||||
assert "</document 5>" in formatted
|
||||
|
||||
def test_very_long_content(self):
|
||||
"""Test formatting with very long content."""
|
||||
long_content = "A" * 10000
|
||||
results: list[SearchResult] = [
|
||||
{"id": "doc1.txt", "distance": 0.95, "content": long_content}
|
||||
]
|
||||
formatted = format_search_results(results)
|
||||
assert f"<document 1 name=doc1.txt>\n{long_content}\n</document 1>" == formatted
|
||||
assert len(formatted) > 10000
|
||||
272
tests/test_semantic_cache.py
Normal file
272
tests/test_semantic_cache.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""Tests for the semantic cache service and its integration."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from knowledge_search_mcp.__main__ import knowledge_search
|
||||
from knowledge_search_mcp.models import AppContext, SearchResult, SourceNamespace
|
||||
from knowledge_search_mcp.services.semantic_cache import KnowledgeSemanticCache
|
||||
|
||||
|
||||
class TestKnowledgeSemanticCache:
|
||||
"""Unit tests for the KnowledgeSemanticCache wrapper."""
|
||||
|
||||
@patch("knowledge_search_mcp.services.semantic_cache.CustomVectorizer")
|
||||
@patch("knowledge_search_mcp.services.semantic_cache.SemanticCache")
|
||||
def test_init_creates_cache(self, mock_sc_cls, mock_vec_cls):
|
||||
"""Test that __init__ creates the SemanticCache with correct params."""
|
||||
mock_vectorizer = MagicMock()
|
||||
mock_vec_cls.return_value = mock_vectorizer
|
||||
|
||||
KnowledgeSemanticCache(
|
||||
redis_url="redis://localhost:6379",
|
||||
name="test_cache",
|
||||
vector_dims=3072,
|
||||
distance_threshold=0.12,
|
||||
ttl=3600,
|
||||
)
|
||||
|
||||
mock_vec_cls.assert_called_once()
|
||||
mock_sc_cls.assert_called_once_with(
|
||||
name="test_cache",
|
||||
distance_threshold=0.12,
|
||||
ttl=3600,
|
||||
redis_url="redis://localhost:6379",
|
||||
vectorizer=mock_vectorizer,
|
||||
overwrite=False,
|
||||
)
|
||||
|
||||
@patch("knowledge_search_mcp.services.semantic_cache.CustomVectorizer")
|
||||
@patch("knowledge_search_mcp.services.semantic_cache.SemanticCache")
|
||||
async def test_check_returns_response_on_hit(self, mock_sc_cls, _mock_vec_cls):
|
||||
"""Test cache check returns response when a similar vector is found."""
|
||||
mock_inner = MagicMock()
|
||||
mock_inner.acheck = AsyncMock(return_value=[
|
||||
{"response": "cached answer", "prompt": "original q", "vector_distance": 0.05},
|
||||
])
|
||||
mock_sc_cls.return_value = mock_inner
|
||||
|
||||
cache = KnowledgeSemanticCache(redis_url="redis://localhost:6379")
|
||||
result = await cache.check([0.1] * 3072)
|
||||
|
||||
assert result == "cached answer"
|
||||
mock_inner.acheck.assert_awaited_once_with(
|
||||
vector=[0.1] * 3072,
|
||||
num_results=1,
|
||||
)
|
||||
|
||||
@patch("knowledge_search_mcp.services.semantic_cache.CustomVectorizer")
|
||||
@patch("knowledge_search_mcp.services.semantic_cache.SemanticCache")
|
||||
async def test_check_returns_none_on_miss(self, mock_sc_cls, _mock_vec_cls):
|
||||
"""Test cache check returns None when no similar vector is found."""
|
||||
mock_inner = MagicMock()
|
||||
mock_inner.acheck = AsyncMock(return_value=[])
|
||||
mock_sc_cls.return_value = mock_inner
|
||||
|
||||
cache = KnowledgeSemanticCache(redis_url="redis://localhost:6379")
|
||||
result = await cache.check([0.1] * 3072)
|
||||
|
||||
assert result is None
|
||||
|
||||
@patch("knowledge_search_mcp.services.semantic_cache.CustomVectorizer")
|
||||
@patch("knowledge_search_mcp.services.semantic_cache.SemanticCache")
|
||||
async def test_check_returns_none_on_error(self, mock_sc_cls, _mock_vec_cls):
|
||||
"""Test cache check degrades gracefully on Redis errors."""
|
||||
mock_inner = MagicMock()
|
||||
mock_inner.acheck = AsyncMock(side_effect=ConnectionError("Redis down"))
|
||||
mock_sc_cls.return_value = mock_inner
|
||||
|
||||
cache = KnowledgeSemanticCache(redis_url="redis://localhost:6379")
|
||||
result = await cache.check([0.1] * 3072)
|
||||
|
||||
assert result is None
|
||||
|
||||
@patch("knowledge_search_mcp.services.semantic_cache.CustomVectorizer")
|
||||
@patch("knowledge_search_mcp.services.semantic_cache.SemanticCache")
|
||||
async def test_store_calls_astore(self, mock_sc_cls, _mock_vec_cls):
|
||||
"""Test store delegates to SemanticCache.astore."""
|
||||
mock_inner = MagicMock()
|
||||
mock_inner.astore = AsyncMock()
|
||||
mock_sc_cls.return_value = mock_inner
|
||||
|
||||
cache = KnowledgeSemanticCache(redis_url="redis://localhost:6379")
|
||||
await cache.store("query", "response", [0.1] * 3072, {"key": "val"})
|
||||
|
||||
mock_inner.astore.assert_awaited_once_with(
|
||||
prompt="query",
|
||||
response="response",
|
||||
vector=[0.1] * 3072,
|
||||
metadata={"key": "val"},
|
||||
)
|
||||
|
||||
@patch("knowledge_search_mcp.services.semantic_cache.CustomVectorizer")
|
||||
@patch("knowledge_search_mcp.services.semantic_cache.SemanticCache")
|
||||
async def test_store_does_not_raise_on_error(self, mock_sc_cls, _mock_vec_cls):
|
||||
"""Test store degrades gracefully on Redis errors."""
|
||||
mock_inner = MagicMock()
|
||||
mock_inner.astore = AsyncMock(side_effect=ConnectionError("Redis down"))
|
||||
mock_sc_cls.return_value = mock_inner
|
||||
|
||||
cache = KnowledgeSemanticCache(redis_url="redis://localhost:6379")
|
||||
await cache.store("query", "response", [0.1] * 3072)
|
||||
|
||||
|
||||
class TestKnowledgeSearchCacheIntegration:
|
||||
"""Tests for cache integration in the knowledge_search tool."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cache(self):
|
||||
"""Create a mock KnowledgeSemanticCache."""
|
||||
cache = MagicMock(spec=KnowledgeSemanticCache)
|
||||
cache.check = AsyncMock(return_value=None)
|
||||
cache.store = AsyncMock()
|
||||
return cache
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_context(self, mock_cache):
|
||||
"""Create a mock AppContext with semantic cache."""
|
||||
app = MagicMock(spec=AppContext)
|
||||
app.genai_client = MagicMock()
|
||||
app.vector_search = MagicMock()
|
||||
app.vector_search.async_run_query = AsyncMock()
|
||||
app.settings = MagicMock()
|
||||
app.settings.embedding_model = "gemini-embedding-001"
|
||||
app.settings.deployed_index_id = "test-deployed-index"
|
||||
app.settings.search_limit = 10
|
||||
app.semantic_cache = mock_cache
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context(self, mock_app_context):
|
||||
"""Create a mock MCP Context."""
|
||||
ctx = MagicMock()
|
||||
ctx.request_context.lifespan_context = mock_app_context
|
||||
return ctx
|
||||
|
||||
@pytest.fixture
|
||||
def sample_embedding(self):
|
||||
return [0.1] * 3072
|
||||
|
||||
@pytest.fixture
|
||||
def sample_results(self) -> list[SearchResult]:
|
||||
return [
|
||||
{"id": "doc1", "distance": 0.95, "content": "Content 1"},
|
||||
{"id": "doc2", "distance": 0.90, "content": "Content 2"},
|
||||
]
|
||||
|
||||
@patch("knowledge_search_mcp.__main__.generate_query_embedding")
|
||||
async def test_cache_hit_skips_vector_search(
|
||||
self, mock_generate, mock_context, sample_embedding, mock_cache
|
||||
):
|
||||
"""On cache hit, vector search is never called."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_cache.check.return_value = "cached result"
|
||||
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert result == "cached result"
|
||||
mock_cache.check.assert_awaited_once_with(sample_embedding)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_not_called()
|
||||
mock_cache.store.assert_not_awaited()
|
||||
|
||||
@patch("knowledge_search_mcp.__main__.generate_query_embedding")
|
||||
@patch("knowledge_search_mcp.__main__.filter_search_results")
|
||||
@patch("knowledge_search_mcp.__main__.format_search_results")
|
||||
async def test_cache_miss_stores_result(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_results,
|
||||
mock_cache,
|
||||
):
|
||||
"""On cache miss, results are fetched and stored in cache."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_cache.check.return_value = None
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_results
|
||||
mock_filter.return_value = sample_results
|
||||
mock_format.return_value = "formatted results"
|
||||
|
||||
result = await knowledge_search("test query", mock_context)
|
||||
|
||||
assert result == "formatted results"
|
||||
mock_cache.check.assert_awaited_once_with(sample_embedding)
|
||||
mock_cache.store.assert_awaited_once_with(
|
||||
"test query", "formatted results", sample_embedding,
|
||||
)
|
||||
|
||||
@patch("knowledge_search_mcp.__main__.generate_query_embedding")
|
||||
@patch("knowledge_search_mcp.__main__.filter_search_results")
|
||||
@patch("knowledge_search_mcp.__main__.format_search_results")
|
||||
async def test_cache_skipped_when_source_filter_set(
|
||||
self,
|
||||
mock_format,
|
||||
mock_filter,
|
||||
mock_generate,
|
||||
mock_context,
|
||||
sample_embedding,
|
||||
sample_results,
|
||||
mock_cache,
|
||||
):
|
||||
"""Cache is bypassed when a source filter is specified."""
|
||||
mock_generate.return_value = (sample_embedding, None)
|
||||
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_results
|
||||
mock_filter.return_value = sample_results
|
||||
mock_format.return_value = "formatted results"
|
||||
|
||||
result = await knowledge_search(
|
||||
"test query", mock_context, source=SourceNamespace.EDUCACION_FINANCIERA,
|
||||
)
|
||||
|
||||
assert result == "formatted results"
|
||||
mock_cache.check.assert_not_awaited()
|
||||
mock_cache.store.assert_not_awaited()
|
||||
|
||||
@patch("knowledge_search_mcp.__main__.generate_query_embedding")
@patch("knowledge_search_mcp.__main__.filter_search_results")
@patch("knowledge_search_mcp.__main__.format_search_results")
async def test_cache_not_stored_when_no_results(
    self,
    mock_format,
    mock_filter,
    mock_generate,
    mock_context,
    sample_embedding,
    mock_cache,
):
    """A search that yields nothing must not pollute the cache."""
    # Arrange: cache misses and the vector search comes back empty.
    lifespan = mock_context.request_context.lifespan_context
    lifespan.vector_search.async_run_query.return_value = []
    mock_generate.return_value = (sample_embedding, None)
    mock_cache.check.return_value = None
    mock_filter.return_value = []
    mock_format.return_value = "No relevant documents found for your query."

    result = await knowledge_search("test query", mock_context)

    assert result == "No relevant documents found for your query."
    # Empty outcomes are never persisted.
    mock_cache.store.assert_not_awaited()
||||
|
||||
@patch("knowledge_search_mcp.__main__.generate_query_embedding")
@patch("knowledge_search_mcp.__main__.filter_search_results")
@patch("knowledge_search_mcp.__main__.format_search_results")
async def test_works_without_cache(
    self,
    mock_format,
    mock_filter,
    mock_generate,
    mock_context,
    sample_embedding,
    sample_results,
):
    """The tool degrades gracefully when no semantic cache is configured."""
    # Arrange: lifespan context carries no cache at all (semantic_cache is None).
    lifespan = mock_context.request_context.lifespan_context
    lifespan.semantic_cache = None
    lifespan.vector_search.async_run_query.return_value = sample_results
    mock_generate.return_value = (sample_embedding, None)
    mock_filter.return_value = sample_results
    mock_format.return_value = "formatted results"

    result = await knowledge_search("test query", mock_context)

    assert result == "formatted results"
# ---- New file in this diff: tests/test_validation_services.py (436 lines added) ----
|
||||
"""Tests for validation service functions."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from aiohttp import ClientResponse
|
||||
|
||||
from knowledge_search_mcp.services.validation import (
|
||||
validate_genai_access,
|
||||
validate_gcs_access,
|
||||
validate_vector_search_access,
|
||||
)
|
||||
from knowledge_search_mcp.config import Settings
|
||||
|
||||
|
||||
class TestValidateGenAIAccess:
    """Behavioral tests for validate_genai_access."""

    @pytest.fixture
    def mock_settings(self):
        """Settings stub carrying only the fields the validator reads."""
        stub = MagicMock(spec=Settings)
        stub.embedding_model = "models/text-embedding-004"
        stub.project_id = "test-project"
        stub.location = "us-central1"
        return stub

    @pytest.fixture
    def mock_genai_client(self):
        """GenAI client stub exposing the async `aio.models` surface."""
        client = MagicMock()
        client.aio = MagicMock()
        client.aio.models = MagicMock()
        return client

    @staticmethod
    def _embedding(values):
        """Build a fake embedding object with the given `.values` list."""
        emb = MagicMock()
        emb.values = values
        return emb

    @staticmethod
    def _response_with(embeddings):
        """Build a fake embed_content response carrying *embeddings*."""
        response = MagicMock()
        response.embeddings = embeddings
        return response

    async def _error_from(self, client, settings, exc):
        """Make embed_content raise *exc* and return the validator's verdict."""
        client.aio.models.embed_content = AsyncMock(side_effect=exc)
        return await validate_genai_access(client, settings)

    async def test_successful_validation(self, mock_genai_client, mock_settings):
        """A non-empty embedding response means access is valid (no error)."""
        response = self._response_with([self._embedding([0.1] * 768)])
        mock_genai_client.aio.models.embed_content = AsyncMock(return_value=response)

        error = await validate_genai_access(mock_genai_client, mock_settings)

        assert error is None
        mock_genai_client.aio.models.embed_content.assert_called_once()
        # The probe call must use the configured model with a RETRIEVAL_QUERY config.
        kwargs = mock_genai_client.aio.models.embed_content.call_args.kwargs
        assert kwargs["model"] == "models/text-embedding-004"
        assert kwargs["contents"] == "test"
        assert kwargs["config"].task_type == "RETRIEVAL_QUERY"

    async def test_empty_response(self, mock_genai_client, mock_settings):
        """A response with no embeddings is reported as an error."""
        mock_genai_client.aio.models.embed_content = AsyncMock(
            return_value=self._response_with([])
        )

        error = await validate_genai_access(mock_genai_client, mock_settings)

        assert error == "Embedding validation returned empty response"

    async def test_none_response(self, mock_genai_client, mock_settings):
        """A None response is treated the same as an empty one."""
        mock_genai_client.aio.models.embed_content = AsyncMock(return_value=None)

        error = await validate_genai_access(mock_genai_client, mock_settings)

        assert error == "Embedding validation returned empty response"

    async def test_api_permission_error(self, mock_genai_client, mock_settings):
        """A PermissionError surfaces in a GenAI-prefixed error string."""
        error = await self._error_from(
            mock_genai_client,
            mock_settings,
            PermissionError("Permission denied for GenAI API"),
        )

        assert error is not None
        assert "GenAI:" in error
        assert "Permission denied for GenAI API" in error

    async def test_api_quota_error(self, mock_genai_client, mock_settings):
        """A quota failure surfaces in a GenAI-prefixed error string."""
        error = await self._error_from(
            mock_genai_client, mock_settings, Exception("Quota exceeded")
        )

        assert error is not None
        assert "GenAI:" in error
        assert "Quota exceeded" in error

    async def test_network_error(self, mock_genai_client, mock_settings):
        """A connection failure surfaces in a GenAI-prefixed error string."""
        error = await self._error_from(
            mock_genai_client, mock_settings, ConnectionError("Network unreachable")
        )

        assert error is not None
        assert "GenAI:" in error
        assert "Network unreachable" in error

    async def test_invalid_model_error(self, mock_genai_client, mock_settings):
        """A bad model name surfaces in a GenAI-prefixed error string."""
        error = await self._error_from(
            mock_genai_client, mock_settings, ValueError("Invalid model name")
        )

        assert error is not None
        assert "GenAI:" in error
        assert "Invalid model name" in error

    async def test_embeddings_with_zero_values(self, mock_genai_client, mock_settings):
        """An embedding with an empty values list still counts as success."""
        response = self._response_with([self._embedding([])])
        mock_genai_client.aio.models.embed_content = AsyncMock(return_value=response)

        error = await validate_genai_access(mock_genai_client, mock_settings)

        # Presence of an embeddings entry is sufficient; its contents are not inspected.
        assert error is None
|
||||
class TestValidateGCSAccess:
    """Behavioral tests for validate_gcs_access."""

    @pytest.fixture
    def mock_settings(self):
        """Settings stub carrying bucket and project identifiers."""
        stub = MagicMock(spec=Settings)
        stub.bucket = "test-bucket"
        stub.project_id = "test-project"
        return stub

    @pytest.fixture
    def mock_vector_search(self):
        """Vector search stub exposing a storage attribute."""
        vs = MagicMock()
        vs.storage = MagicMock()
        return vs

    @pytest.fixture
    def mock_session(self):
        """Bare aiohttp-session stand-in; tests wire up `.get` as needed."""
        return MagicMock()

    @pytest.fixture
    def mock_response(self):
        """HTTP response stub whose body reads as a small JSON document."""
        response = MagicMock()
        response.text = AsyncMock(return_value='{"items": []}')
        return response

    @staticmethod
    def _wire_get(session, response):
        """Make session.get(...) act as an async context manager yielding *response*."""
        session.get = MagicMock()
        session.get.return_value.__aenter__ = AsyncMock(return_value=response)
        session.get.return_value.__aexit__ = AsyncMock(return_value=None)

    async def _run_with_fake_token(self, vector_search, settings):
        """Run the validator with Token patched to yield a fake access token."""
        with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
            MockToken.return_value.get = AsyncMock(return_value="fake-access-token")
            return await validate_gcs_access(vector_search, settings)

    async def test_successful_validation(self, mock_vector_search, mock_settings, mock_session, mock_response):
        """A 200 from the bucket GET yields no error and a well-formed request."""
        mock_response.status = 200
        mock_response.ok = True
        self._wire_get(mock_session, mock_response)
        mock_vector_search.storage._get_aio_session.return_value = mock_session

        error = await self._run_with_fake_token(mock_vector_search, mock_settings)

        assert error is None
        mock_session.get.assert_called_once()
        # URL names the bucket; the token ends up in the Authorization header.
        call_args = mock_session.get.call_args
        assert "test-bucket" in call_args[0][0]
        assert call_args[1]["headers"]["Authorization"] == "Bearer fake-access-token"

    async def test_access_denied_403(self, mock_vector_search, mock_settings, mock_session, mock_response):
        """A 403 maps to an access-denied message naming the bucket."""
        mock_response.status = 403
        mock_response.ok = False
        self._wire_get(mock_session, mock_response)
        mock_vector_search.storage._get_aio_session.return_value = mock_session

        error = await self._run_with_fake_token(mock_vector_search, mock_settings)

        assert error is not None
        assert "Access denied to bucket 'test-bucket'" in error
        assert "permissions" in error.lower()

    async def test_bucket_not_found_404(self, mock_vector_search, mock_settings, mock_session, mock_response):
        """A 404 maps to a bucket-not-found message."""
        mock_response.status = 404
        mock_response.ok = False
        self._wire_get(mock_session, mock_response)
        mock_vector_search.storage._get_aio_session.return_value = mock_session

        error = await self._run_with_fake_token(mock_vector_search, mock_settings)

        assert error is not None
        assert "Bucket 'test-bucket' not found" in error
        assert "bucket name" in error.lower()

    async def test_server_error_500(self, mock_vector_search, mock_settings, mock_session, mock_response):
        """A 500 maps to a generic failure message including the status code."""
        mock_response.status = 500
        mock_response.ok = False
        mock_response.text = AsyncMock(return_value='{"error": "Internal server error"}')
        self._wire_get(mock_session, mock_response)
        mock_vector_search.storage._get_aio_session.return_value = mock_session

        error = await self._run_with_fake_token(mock_vector_search, mock_settings)

        assert error is not None
        assert "Failed to access bucket 'test-bucket': 500" in error

    async def test_token_acquisition_error(self, mock_vector_search, mock_settings, mock_session):
        """A failure to obtain a token surfaces in a GCS-prefixed error string."""
        mock_vector_search.storage._get_aio_session.return_value = mock_session

        with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
            MockToken.return_value.get = AsyncMock(
                side_effect=Exception("Failed to get access token")
            )
            error = await validate_gcs_access(mock_vector_search, mock_settings)

        assert error is not None
        assert "GCS:" in error
        assert "Failed to get access token" in error

    async def test_network_error(self, mock_vector_search, mock_settings, mock_session):
        """A connection failure surfaces in a GCS-prefixed error string."""
        mock_session.get = MagicMock(side_effect=ConnectionError("Network unreachable"))
        mock_vector_search.storage._get_aio_session.return_value = mock_session

        error = await self._run_with_fake_token(mock_vector_search, mock_settings)

        assert error is not None
        assert "GCS:" in error
        assert "Network unreachable" in error
|
||||
|
||||
class TestValidateVectorSearchAccess:
    """Behavioral tests for validate_vector_search_access."""

    @pytest.fixture
    def mock_settings(self):
        """Settings stub carrying an endpoint resource name and location."""
        stub = MagicMock(spec=Settings)
        stub.endpoint_name = "projects/test/locations/us-central1/indexEndpoints/test-endpoint"
        stub.location = "us-central1"
        return stub

    @pytest.fixture
    def mock_vector_search(self):
        """Vector search stub whose auth headers resolve to a fake token."""
        vs = MagicMock()
        vs._async_get_auth_headers = AsyncMock(
            return_value={"Authorization": "Bearer fake-token"}
        )
        return vs

    @pytest.fixture
    def mock_session(self):
        """Bare aiohttp-session stand-in; tests wire up `.get` as needed."""
        return MagicMock()

    @pytest.fixture
    def mock_response(self):
        """HTTP response stub whose body reads as a small JSON document."""
        response = MagicMock()
        response.text = AsyncMock(return_value='{"name": "test-endpoint"}')
        return response

    @staticmethod
    def _wire_get(session, response):
        """Make session.get(...) act as an async context manager yielding *response*."""
        session.get = MagicMock()
        session.get.return_value.__aenter__ = AsyncMock(return_value=response)
        session.get.return_value.__aexit__ = AsyncMock(return_value=None)

    async def test_successful_validation(self, mock_vector_search, mock_settings, mock_session, mock_response):
        """A 200 from the endpoint GET yields no error and a well-formed request."""
        mock_response.status = 200
        mock_response.ok = True
        self._wire_get(mock_session, mock_response)
        mock_vector_search._get_aio_session.return_value = mock_session

        error = await validate_vector_search_access(mock_vector_search, mock_settings)

        assert error is None
        mock_vector_search._async_get_auth_headers.assert_called_once()
        mock_session.get.assert_called_once()
        # URL targets the regional aiplatform host and names the endpoint.
        call_args = mock_session.get.call_args
        assert "us-central1-aiplatform.googleapis.com" in call_args[0][0]
        assert "test-endpoint" in call_args[0][0]
        assert call_args[1]["headers"]["Authorization"] == "Bearer fake-token"

    async def test_access_denied_403(self, mock_vector_search, mock_settings, mock_session, mock_response):
        """A 403 maps to an access-denied message naming the endpoint."""
        mock_response.status = 403
        mock_response.ok = False
        self._wire_get(mock_session, mock_response)
        mock_vector_search._get_aio_session.return_value = mock_session

        error = await validate_vector_search_access(mock_vector_search, mock_settings)

        assert error is not None
        assert "Access denied to endpoint" in error
        assert "test-endpoint" in error
        assert "permissions" in error.lower()

    async def test_endpoint_not_found_404(self, mock_vector_search, mock_settings, mock_session, mock_response):
        """A 404 maps to an endpoint-not-found message."""
        mock_response.status = 404
        mock_response.ok = False
        self._wire_get(mock_session, mock_response)
        mock_vector_search._get_aio_session.return_value = mock_session

        error = await validate_vector_search_access(mock_vector_search, mock_settings)

        assert error is not None
        assert "not found" in error.lower()
        assert "test-endpoint" in error

    async def test_server_error_503(self, mock_vector_search, mock_settings, mock_session, mock_response):
        """A 503 maps to a generic failure message including the status code."""
        mock_response.status = 503
        mock_response.ok = False
        mock_response.text = AsyncMock(return_value='{"error": "Service unavailable"}')
        self._wire_get(mock_session, mock_response)
        mock_vector_search._get_aio_session.return_value = mock_session

        error = await validate_vector_search_access(mock_vector_search, mock_settings)

        assert error is not None
        assert "Failed to access endpoint" in error
        assert "503" in error

    async def test_auth_header_error(self, mock_vector_search, mock_settings):
        """An auth-header failure surfaces in a Vector-Search-prefixed error."""
        mock_vector_search._async_get_auth_headers = AsyncMock(
            side_effect=Exception("Failed to get auth headers")
        )

        error = await validate_vector_search_access(mock_vector_search, mock_settings)

        assert error is not None
        assert "Vector Search:" in error
        assert "Failed to get auth headers" in error

    async def test_network_timeout(self, mock_vector_search, mock_settings, mock_session):
        """A timeout surfaces in a Vector-Search-prefixed error."""
        mock_session.get = MagicMock(side_effect=TimeoutError("Request timed out"))
        mock_vector_search._get_aio_session.return_value = mock_session

        error = await validate_vector_search_access(mock_vector_search, mock_settings)

        assert error is not None
        assert "Vector Search:" in error
        assert "Request timed out" in error

    async def test_connection_error(self, mock_vector_search, mock_settings, mock_session):
        """A refused connection surfaces in a Vector-Search-prefixed error."""
        mock_session.get = MagicMock(side_effect=ConnectionError("Connection refused"))
        mock_vector_search._get_aio_session.return_value = mock_session

        error = await validate_vector_search_access(mock_vector_search, mock_settings)

        assert error is not None
        assert "Vector Search:" in error
        assert "Connection refused" in error

    async def test_endpoint_url_construction(self, mock_vector_search, mock_settings, mock_session, mock_response):
        """The request URL reflects the configured location and endpoint name."""
        mock_response.status = 200
        mock_response.ok = True
        self._wire_get(mock_session, mock_response)
        mock_vector_search._get_aio_session.return_value = mock_session

        # Override the defaults with a non-default region and endpoint.
        mock_settings.location = "europe-west1"
        mock_settings.endpoint_name = "projects/my-project/locations/europe-west1/indexEndpoints/my-endpoint"

        error = await validate_vector_search_access(mock_vector_search, mock_settings)

        assert error is None
        url = mock_session.get.call_args[0][0]
        assert "europe-west1-aiplatform.googleapis.com" in url
        assert "my-endpoint" in url
||||
Reference in New Issue
Block a user