Compare commits
34 Commits
56e181a772
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 6c7635c3e3 | |||
| 0cdf9cd44e | |||
| d39b8a6ea7 | |||
| 86ed34887b | |||
| 694b060fa4 | |||
| d69c4e4f4a | |||
| f6e122b5a9 | |||
| dba94107a5 | |||
| d3cd8d5291 | |||
| 8dfd2048a5 | |||
| 3e2386b9b6 | |||
|
|
42e1660143 | ||
| 208d5ebebf | |||
| 83ed64326f | |||
| 4c59da0c22 | |||
|
|
cbf3ca7df4 | ||
| e9b4c93a20 | |||
| b95bb72b24 | |||
| a3ba340224 | |||
| 0fd97a31a5 | |||
| a8c611fbec | |||
| 13c8e122de | |||
| 753b5c7871 | |||
| 6feeeff4f3 | |||
| 3b7dd91a71 | |||
| 427de45522 | |||
|
|
f0b9d1b27a | ||
|
|
bf2cc2f556 | ||
|
|
b92a5d5b0e | ||
|
|
bd107a027a | ||
|
|
dcc05d697e | ||
|
|
82764bd60b | ||
|
|
54eb6f240c | ||
|
|
bb19770663 |
9
.dockerignore
Normal file
9
.dockerignore
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
.git/
|
||||||
|
.venv/
|
||||||
|
.ruff_cache/
|
||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
.env
|
||||||
|
agent.py
|
||||||
|
AGENTS.md
|
||||||
|
|
||||||
43
.gitea/workflows/ci.yml
Normal file
43
.gitea/workflows/ci.yml
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: astral-sh/setup-uv@v6
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
- run: uv sync --frozen
|
||||||
|
- name: Ruff check
|
||||||
|
run: uv run ruff check
|
||||||
|
- name: Ruff format check
|
||||||
|
run: uv run ruff format --check
|
||||||
|
|
||||||
|
typecheck:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: astral-sh/setup-uv@v6
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
- run: uv sync --frozen
|
||||||
|
- name: Type check
|
||||||
|
run: uv run ty check
|
||||||
|
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: astral-sh/setup-uv@v6
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
- run: uv sync --frozen
|
||||||
|
- name: Run tests
|
||||||
|
run: uv run pytest --cov
|
||||||
216
.gitignore
vendored
Normal file
216
.gitignore
vendored
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[codz]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
# Pipfile.lock
|
||||||
|
|
||||||
|
# UV
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# uv.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
# poetry.lock
|
||||||
|
# poetry.toml
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
||||||
|
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
||||||
|
# pdm.lock
|
||||||
|
# pdm.toml
|
||||||
|
.pdm-python
|
||||||
|
.pdm-build/
|
||||||
|
|
||||||
|
# pixi
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
||||||
|
# pixi.lock
|
||||||
|
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
||||||
|
# in the .venv directory. It is recommended not to include this directory in version control.
|
||||||
|
.pixi
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# Redis
|
||||||
|
*.rdb
|
||||||
|
*.aof
|
||||||
|
*.pid
|
||||||
|
|
||||||
|
# RabbitMQ
|
||||||
|
mnesia/
|
||||||
|
rabbitmq/
|
||||||
|
rabbitmq-data/
|
||||||
|
|
||||||
|
# ActiveMQ
|
||||||
|
activemq-data/
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.envrc
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
# .idea/
|
||||||
|
|
||||||
|
# Abstra
|
||||||
|
# Abstra is an AI-powered process automation framework.
|
||||||
|
# Ignore directories containing user credentials, local state, and settings.
|
||||||
|
# Learn more at https://abstra.io/docs
|
||||||
|
.abstra/
|
||||||
|
|
||||||
|
# Visual Studio Code
|
||||||
|
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
||||||
|
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
||||||
|
# you could uncomment the following to ignore the entire vscode folder
|
||||||
|
# .vscode/
|
||||||
|
|
||||||
|
# Ruff stuff:
|
||||||
|
.ruff_cache/
|
||||||
|
|
||||||
|
# PyPI configuration file
|
||||||
|
.pypirc
|
||||||
|
|
||||||
|
# Marimo
|
||||||
|
marimo/_static/
|
||||||
|
marimo/_lsp/
|
||||||
|
__marimo__/
|
||||||
|
|
||||||
|
# Streamlit
|
||||||
|
.streamlit/secrets.toml
|
||||||
3
CLAUDE.md
Normal file
3
CLAUDE.md
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
Use `uv` for project management
|
||||||
|
Linter: `uv run ruff check`
|
||||||
|
Type-checking: `uv run ty check`
|
||||||
15
DockerfileConnector
Normal file
15
DockerfileConnector
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
FROM python:3.12-slim
|
||||||
|
|
||||||
|
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY pyproject.toml uv.lock README.md ./
|
||||||
|
RUN uv sync --no-dev --frozen --no-install-project
|
||||||
|
|
||||||
|
COPY src/ src/
|
||||||
|
RUN uv sync --no-dev --frozen
|
||||||
|
|
||||||
|
ENV PATH="/app/.venv/bin:$PATH"
|
||||||
|
|
||||||
|
CMD ["uv", "run", "python", "-m", "knowledge_search_mcp", "--transport", "streamable-http", "--port", "8000"]
|
||||||
106
README.md
106
README.md
@@ -0,0 +1,106 @@
|
|||||||
|
# knowledge-search-mcp
|
||||||
|
|
||||||
|
An MCP (Model Context Protocol) server that exposes a `knowledge_search` tool for semantic search over a knowledge base backed by Vertex AI Vector Search and Google Cloud Storage.
|
||||||
|
|
||||||
|
## How it works
|
||||||
|
|
||||||
|
1. A natural-language query is embedded using a Gemini embedding model.
|
||||||
|
2. The embedding is sent to a Vertex AI Matching Engine index endpoint to find nearest neighbors.
|
||||||
|
3. The matched document contents are fetched from a GCS bucket and returned to the caller.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
- Python ≥ 3.12
|
||||||
|
- [uv](https://docs.astral.sh/uv/) for dependency management
|
||||||
|
- A Google Cloud project with:
|
||||||
|
- A Vertex AI Vector Search index and deployed endpoint
|
||||||
|
- A GCS bucket containing the indexed document chunks
|
||||||
|
- Application Default Credentials (or a service account) with appropriate permissions
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Create a `config.yaml` file or `.env` file (see `Settings` in `src/knowledge_search_mcp/config.py` for all options):
|
||||||
|
|
||||||
|
```env
|
||||||
|
PROJECT_ID=my-gcp-project
|
||||||
|
LOCATION=us-central1
|
||||||
|
BUCKET=my-knowledge-bucket
|
||||||
|
INDEX_NAME=my-index
|
||||||
|
DEPLOYED_INDEX_ID=my-deployed-index
|
||||||
|
ENDPOINT_NAME=projects/…/locations/…/indexEndpoints/…
|
||||||
|
ENDPOINT_DOMAIN=123456789.us-central1-aiplatform.googleapis.com
|
||||||
|
# optional
|
||||||
|
EMBEDDING_MODEL=gemini-embedding-001
|
||||||
|
SEARCH_LIMIT=10
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Install dependencies
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv sync
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run the MCP server
|
||||||
|
|
||||||
|
**Using the installed command (recommended):**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# stdio transport (default)
|
||||||
|
uv run knowledge-search-mcp
|
||||||
|
|
||||||
|
# SSE transport for remote clients
|
||||||
|
uv run knowledge-search-mcp --transport sse --port 8080
|
||||||
|
|
||||||
|
# streamable-http transport
|
||||||
|
uv run knowledge-search-mcp --transport streamable-http --port 8080
|
||||||
|
```
|
||||||
|
|
||||||
|
**Or run directly:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python -m knowledge_search_mcp.main
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run the interactive agent (ADK)
|
||||||
|
|
||||||
|
The bundled agent spawns the MCP server as a subprocess and provides a REPL:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python agent.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Or connect to an already-running SSE server:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python agent.py --remote http://localhost:8080/sse
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run pytest
|
||||||
|
```
|
||||||
|
|
||||||
|
## Docker
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker build -t knowledge-search-mcp .
|
||||||
|
docker run -p 8080:8080 --env-file .env knowledge-search-mcp
|
||||||
|
```
|
||||||
|
|
||||||
|
The container starts the server in SSE mode on the port specified by `PORT` (default `8080`).
|
||||||
|
|
||||||
|
## Project structure
|
||||||
|
|
||||||
|
```
|
||||||
|
src/knowledge_search_mcp/
|
||||||
|
├── __init__.py Package initialization
|
||||||
|
├── config.py Configuration management (Settings, args parsing)
|
||||||
|
├── logging.py Cloud Logging setup
|
||||||
|
└── main.py MCP server, vector search client, and GCS storage helper
|
||||||
|
agent.py Interactive ADK agent that consumes the MCP server
|
||||||
|
tests/ Test suite
|
||||||
|
pyproject.toml Project metadata, dependencies, and entry points
|
||||||
|
```
|
||||||
|
|||||||
401
main.py
401
main.py
@@ -1,401 +0,0 @@
|
|||||||
# ruff: noqa: INP001
|
|
||||||
"""Async helpers for querying Vertex AI vector search via MCP."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import io
|
|
||||||
import logging
|
|
||||||
from collections.abc import AsyncIterator, Sequence
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import BinaryIO, TypedDict
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
from gcloud.aio.auth import Token
|
|
||||||
from gcloud.aio.storage import Storage
|
|
||||||
from google import genai
|
|
||||||
from google.genai import types as genai_types
|
|
||||||
from mcp.server.fastmcp import Context, FastMCP
|
|
||||||
from pydantic_settings import BaseSettings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
HTTP_TOO_MANY_REQUESTS = 429
|
|
||||||
HTTP_SERVER_ERROR = 500
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleCloudFileStorage:
|
|
||||||
"""Cache-aware helper for downloading files from Google Cloud Storage."""
|
|
||||||
|
|
||||||
def __init__(self, bucket: str) -> None:
|
|
||||||
"""Initialize the storage helper."""
|
|
||||||
self.bucket_name = bucket
|
|
||||||
self._aio_session: aiohttp.ClientSession | None = None
|
|
||||||
self._aio_storage: Storage | None = None
|
|
||||||
self._cache: dict[str, bytes] = {}
|
|
||||||
|
|
||||||
def _get_aio_session(self) -> aiohttp.ClientSession:
|
|
||||||
if self._aio_session is None or self._aio_session.closed:
|
|
||||||
connector = aiohttp.TCPConnector(
|
|
||||||
limit=300,
|
|
||||||
limit_per_host=50,
|
|
||||||
)
|
|
||||||
timeout = aiohttp.ClientTimeout(total=60)
|
|
||||||
self._aio_session = aiohttp.ClientSession(
|
|
||||||
timeout=timeout,
|
|
||||||
connector=connector,
|
|
||||||
)
|
|
||||||
return self._aio_session
|
|
||||||
|
|
||||||
def _get_aio_storage(self) -> Storage:
|
|
||||||
if self._aio_storage is None:
|
|
||||||
self._aio_storage = Storage(
|
|
||||||
session=self._get_aio_session(),
|
|
||||||
)
|
|
||||||
return self._aio_storage
|
|
||||||
|
|
||||||
async def async_get_file_stream(
|
|
||||||
self,
|
|
||||||
file_name: str,
|
|
||||||
max_retries: int = 3,
|
|
||||||
) -> BinaryIO:
|
|
||||||
"""Get a file asynchronously with retry on transient errors.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_name: The blob name to retrieve.
|
|
||||||
max_retries: Maximum number of retry attempts.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A BytesIO stream with the file contents.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TimeoutError: If all retry attempts fail.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if file_name in self._cache:
|
|
||||||
file_stream = io.BytesIO(self._cache[file_name])
|
|
||||||
file_stream.name = file_name
|
|
||||||
return file_stream
|
|
||||||
|
|
||||||
storage_client = self._get_aio_storage()
|
|
||||||
last_exception: Exception | None = None
|
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
self._cache[file_name] = await storage_client.download(
|
|
||||||
self.bucket_name,
|
|
||||||
file_name,
|
|
||||||
)
|
|
||||||
file_stream = io.BytesIO(self._cache[file_name])
|
|
||||||
file_stream.name = file_name
|
|
||||||
except TimeoutError as exc:
|
|
||||||
last_exception = exc
|
|
||||||
logger.warning(
|
|
||||||
"Timeout downloading gs://%s/%s (attempt %d/%d)",
|
|
||||||
self.bucket_name,
|
|
||||||
file_name,
|
|
||||||
attempt + 1,
|
|
||||||
max_retries,
|
|
||||||
)
|
|
||||||
except aiohttp.ClientResponseError as exc:
|
|
||||||
last_exception = exc
|
|
||||||
if (
|
|
||||||
exc.status == HTTP_TOO_MANY_REQUESTS
|
|
||||||
or exc.status >= HTTP_SERVER_ERROR
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
"HTTP %d downloading gs://%s/%s (attempt %d/%d)",
|
|
||||||
exc.status,
|
|
||||||
self.bucket_name,
|
|
||||||
file_name,
|
|
||||||
attempt + 1,
|
|
||||||
max_retries,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
return file_stream
|
|
||||||
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
delay = 0.5 * (2**attempt)
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
|
|
||||||
msg = (
|
|
||||||
f"Failed to download gs://{self.bucket_name}/{file_name} "
|
|
||||||
f"after {max_retries} attempts"
|
|
||||||
)
|
|
||||||
raise TimeoutError(msg) from last_exception
|
|
||||||
|
|
||||||
|
|
||||||
class SearchResult(TypedDict):
|
|
||||||
"""Structured response item returned by the vector search API."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
distance: float
|
|
||||||
content: str
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleCloudVectorSearch:
|
|
||||||
"""Minimal async client for the Vertex AI Matching Engine REST API."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
project_id: str,
|
|
||||||
location: str,
|
|
||||||
bucket: str,
|
|
||||||
index_name: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Store configuration used to issue Matching Engine queries."""
|
|
||||||
self.project_id = project_id
|
|
||||||
self.location = location
|
|
||||||
self.storage = GoogleCloudFileStorage(bucket=bucket)
|
|
||||||
self.index_name = index_name
|
|
||||||
self._aio_session: aiohttp.ClientSession | None = None
|
|
||||||
self._async_token: Token | None = None
|
|
||||||
self._endpoint_domain: str | None = None
|
|
||||||
self._endpoint_name: str | None = None
|
|
||||||
|
|
||||||
async def _async_get_auth_headers(self) -> dict[str, str]:
|
|
||||||
if self._async_token is None:
|
|
||||||
self._async_token = Token(
|
|
||||||
session=self._get_aio_session(),
|
|
||||||
scopes=[
|
|
||||||
"https://www.googleapis.com/auth/cloud-platform",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
access_token = await self._async_token.get()
|
|
||||||
return {
|
|
||||||
"Authorization": f"Bearer {access_token}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
def _get_aio_session(self) -> aiohttp.ClientSession:
|
|
||||||
if self._aio_session is None or self._aio_session.closed:
|
|
||||||
connector = aiohttp.TCPConnector(
|
|
||||||
limit=300,
|
|
||||||
limit_per_host=50,
|
|
||||||
)
|
|
||||||
timeout = aiohttp.ClientTimeout(total=60)
|
|
||||||
self._aio_session = aiohttp.ClientSession(
|
|
||||||
timeout=timeout,
|
|
||||||
connector=connector,
|
|
||||||
)
|
|
||||||
return self._aio_session
|
|
||||||
|
|
||||||
def configure_index_endpoint(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
name: str,
|
|
||||||
public_domain: str,
|
|
||||||
) -> None:
|
|
||||||
"""Persist the metadata needed to access a deployed endpoint."""
|
|
||||||
if not name:
|
|
||||||
msg = "Index endpoint name must be a non-empty string."
|
|
||||||
raise ValueError(msg)
|
|
||||||
if not public_domain:
|
|
||||||
msg = "Index endpoint domain must be a non-empty public domain."
|
|
||||||
raise ValueError(msg)
|
|
||||||
self._endpoint_name = name
|
|
||||||
self._endpoint_domain = public_domain
|
|
||||||
|
|
||||||
async def async_run_query(
|
|
||||||
self,
|
|
||||||
deployed_index_id: str,
|
|
||||||
query: Sequence[float],
|
|
||||||
limit: int,
|
|
||||||
) -> list[SearchResult]:
|
|
||||||
"""Run an async similarity search via the REST API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
deployed_index_id: The ID of the deployed index.
|
|
||||||
query: The embedding vector for the search query.
|
|
||||||
limit: Maximum number of nearest neighbors to return.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of matched items with id, distance, and content.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if self._endpoint_domain is None or self._endpoint_name is None:
|
|
||||||
msg = (
|
|
||||||
"Missing endpoint metadata. Call "
|
|
||||||
"`configure_index_endpoint` before querying."
|
|
||||||
)
|
|
||||||
raise RuntimeError(msg)
|
|
||||||
domain = self._endpoint_domain
|
|
||||||
endpoint_id = self._endpoint_name.split("/")[-1]
|
|
||||||
url = (
|
|
||||||
f"https://{domain}/v1/projects/{self.project_id}"
|
|
||||||
f"/locations/{self.location}"
|
|
||||||
f"/indexEndpoints/{endpoint_id}:findNeighbors"
|
|
||||||
)
|
|
||||||
payload = {
|
|
||||||
"deployed_index_id": deployed_index_id,
|
|
||||||
"queries": [
|
|
||||||
{
|
|
||||||
"datapoint": {"feature_vector": list(query)},
|
|
||||||
"neighbor_count": limit,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
headers = await self._async_get_auth_headers()
|
|
||||||
session = self._get_aio_session()
|
|
||||||
async with session.post(
|
|
||||||
url,
|
|
||||||
json=payload,
|
|
||||||
headers=headers,
|
|
||||||
) as response:
|
|
||||||
response.raise_for_status()
|
|
||||||
data = await response.json()
|
|
||||||
|
|
||||||
neighbors = data.get("nearestNeighbors", [{}])[0].get("neighbors", [])
|
|
||||||
content_tasks = []
|
|
||||||
for neighbor in neighbors:
|
|
||||||
datapoint_id = neighbor["datapoint"]["datapointId"]
|
|
||||||
file_path = f"{self.index_name}/contents/{datapoint_id}.md"
|
|
||||||
content_tasks.append(
|
|
||||||
self.storage.async_get_file_stream(file_path),
|
|
||||||
)
|
|
||||||
|
|
||||||
file_streams = await asyncio.gather(*content_tasks)
|
|
||||||
results: list[SearchResult] = []
|
|
||||||
for neighbor, stream in zip(
|
|
||||||
neighbors,
|
|
||||||
file_streams,
|
|
||||||
strict=True,
|
|
||||||
):
|
|
||||||
results.append(
|
|
||||||
SearchResult(
|
|
||||||
id=neighbor["datapoint"]["datapointId"],
|
|
||||||
distance=neighbor["distance"],
|
|
||||||
content=stream.read().decode("utf-8"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# MCP Server
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
|
||||||
"""Server configuration populated from environment variables."""
|
|
||||||
|
|
||||||
project_id: str
|
|
||||||
location: str
|
|
||||||
bucket: str
|
|
||||||
index_name: str
|
|
||||||
deployed_index_id: str
|
|
||||||
endpoint_name: str
|
|
||||||
endpoint_domain: str
|
|
||||||
embedding_model: str = "text-embedding-005"
|
|
||||||
search_limit: int = 10
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AppContext:
|
|
||||||
"""Shared resources initialised once at server startup."""
|
|
||||||
|
|
||||||
vector_search: GoogleCloudVectorSearch
|
|
||||||
genai_client: genai.Client
|
|
||||||
settings: Settings
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
|
|
||||||
"""Create and configure the vector-search client for the server lifetime."""
|
|
||||||
cfg = Settings.model_validate({})
|
|
||||||
|
|
||||||
vs = GoogleCloudVectorSearch(
|
|
||||||
project_id=cfg.project_id,
|
|
||||||
location=cfg.location,
|
|
||||||
bucket=cfg.bucket,
|
|
||||||
index_name=cfg.index_name,
|
|
||||||
)
|
|
||||||
vs.configure_index_endpoint(
|
|
||||||
name=cfg.endpoint_name,
|
|
||||||
public_domain=cfg.endpoint_domain,
|
|
||||||
)
|
|
||||||
|
|
||||||
genai_client = genai.Client(
|
|
||||||
vertexai=True,
|
|
||||||
project=cfg.project_id,
|
|
||||||
location=cfg.location,
|
|
||||||
)
|
|
||||||
|
|
||||||
yield AppContext(
|
|
||||||
vector_search=vs,
|
|
||||||
genai_client=genai_client,
|
|
||||||
settings=cfg,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
mcp = FastMCP("knowledge-search", lifespan=lifespan)
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
async def knowledge_search(
|
|
||||||
query: str,
|
|
||||||
ctx: Context,
|
|
||||||
) -> str:
|
|
||||||
"""Search a knowledge base using a natural-language query.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: The text query to search for.
|
|
||||||
ctx: MCP request context (injected automatically).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A formatted string containing matched documents with id and content.
|
|
||||||
|
|
||||||
"""
|
|
||||||
import time
|
|
||||||
|
|
||||||
app: AppContext = ctx.request_context.lifespan_context
|
|
||||||
t0 = time.perf_counter()
|
|
||||||
min_sim = 0.6
|
|
||||||
|
|
||||||
response = await app.genai_client.aio.models.embed_content(
|
|
||||||
model=app.settings.embedding_model,
|
|
||||||
contents=query,
|
|
||||||
config=genai_types.EmbedContentConfig(
|
|
||||||
task_type="RETRIEVAL_QUERY",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
embedding = response.embeddings[0].values
|
|
||||||
t_embed = time.perf_counter()
|
|
||||||
|
|
||||||
search_results = await app.vector_search.async_run_query(
|
|
||||||
deployed_index_id=app.settings.deployed_index_id,
|
|
||||||
query=embedding,
|
|
||||||
limit=app.settings.search_limit,
|
|
||||||
)
|
|
||||||
t_search = time.perf_counter()
|
|
||||||
|
|
||||||
# Apply similarity filtering
|
|
||||||
if search_results:
|
|
||||||
max_sim = max(r["distance"] for r in search_results)
|
|
||||||
cutoff = max_sim * 0.9
|
|
||||||
search_results = [
|
|
||||||
s
|
|
||||||
for s in search_results
|
|
||||||
if s["distance"] > cutoff and s["distance"] > min_sim
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"knowledge_search timing: embedding=%sms, vector_search=%sms, total=%sms, chunks=%s",
|
|
||||||
round((t_embed - t0) * 1000, 1),
|
|
||||||
round((t_search - t_embed) * 1000, 1),
|
|
||||||
round((t_search - t0) * 1000, 1),
|
|
||||||
[s["id"] for s in search_results],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Format results as XML-like documents
|
|
||||||
formatted_results = [
|
|
||||||
f"<document {i} name={result['id']}>\n{result['content']}\n</document {i}>"
|
|
||||||
for i, result in enumerate(search_results, start=1)
|
|
||||||
]
|
|
||||||
return "\n".join(formatted_results)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
mcp.run()
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "knowledge-search-mcp"
|
name = "knowledge-search-mcp"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
description = "Add your description here"
|
description = "MCP server for semantic search over Vertex AI Vector Search"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
@@ -12,10 +12,43 @@ dependencies = [
|
|||||||
"google-genai>=1.64.0",
|
"google-genai>=1.64.0",
|
||||||
"mcp[cli]>=1.26.0",
|
"mcp[cli]>=1.26.0",
|
||||||
"pydantic-settings>=2.9.1",
|
"pydantic-settings>=2.9.1",
|
||||||
|
"pyyaml>=6.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
knowledge-search-mcp = "knowledge_search_mcp.__main__:main"
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
|
"google-adk>=1.25.1",
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
"pytest-asyncio>=0.24.0",
|
||||||
|
"pytest-cov>=6.0.0",
|
||||||
"ruff>=0.15.2",
|
"ruff>=0.15.2",
|
||||||
"ty>=0.0.18",
|
"ty>=0.0.18",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
testpaths = ["tests"]
|
||||||
|
pythonpath = ["."]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["uv_build>=0.8.3,<0.9.0"]
|
||||||
|
build-backend = "uv_build"
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
exclude = ["scripts", "tests"]
|
||||||
|
|
||||||
|
[tool.ty.src]
|
||||||
|
exclude = ["scripts", "tests"]
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = ['ALL']
|
||||||
|
ignore = [
|
||||||
|
'D203', # one-blank-line-before-class
|
||||||
|
'D213', # multi-line-summary-second-line
|
||||||
|
'COM812', # missing-trailing-comma
|
||||||
|
'ANN401', # dynamically-typed-any
|
||||||
|
'ERA001', # commented-out-code
|
||||||
|
]
|
||||||
|
|||||||
111
scripts/agent.py
Normal file
111
scripts/agent.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
# ruff: noqa: INP001
|
||||||
|
"""ADK agent that connects to the knowledge-search MCP server."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
from google.adk.agents.llm_agent import LlmAgent
|
||||||
|
from google.adk.runners import Runner
|
||||||
|
from google.adk.sessions import InMemorySessionService
|
||||||
|
from google.adk.tools.mcp_tool import McpToolset
|
||||||
|
from google.adk.tools.mcp_tool.mcp_session_manager import (
|
||||||
|
SseConnectionParams,
|
||||||
|
StdioConnectionParams,
|
||||||
|
)
|
||||||
|
from google.genai import types
|
||||||
|
from mcp import StdioServerParameters
|
||||||
|
|
||||||
|
# ADK needs these env vars for Vertex AI; reuse the ones from .env
|
||||||
|
os.environ.setdefault("GOOGLE_GENAI_USE_VERTEXAI", "True")
|
||||||
|
if project := os.environ.get("PROJECT_ID"):
|
||||||
|
os.environ.setdefault("GOOGLE_CLOUD_PROJECT", project)
|
||||||
|
if location := os.environ.get("LOCATION"):
|
||||||
|
os.environ.setdefault("GOOGLE_CLOUD_LOCATION", location)
|
||||||
|
|
||||||
|
SERVER_SCRIPT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "src", "knowledge_search_mcp", "main.py")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(description="Knowledge Search Agent")
|
||||||
|
parser.add_argument(
|
||||||
|
"--remote",
|
||||||
|
metavar="URL",
|
||||||
|
help="Connect to an already-running MCP server at this SSE URL "
|
||||||
|
"(e.g. http://localhost:8080/sse). Without this flag the agent "
|
||||||
|
"spawns the server as a subprocess.",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
async def async_main() -> None:
    """Run an interactive REPL agent backed by the knowledge-search MCP server.

    Connects either to an already-running server over SSE (``--remote URL``)
    or spawns the server as a stdio subprocess, then loops reading queries
    from stdin until EOF or Ctrl+C, printing the agent's final responses.
    """
    args = _parse_args()

    if args.remote:
        # Attach to a server that is already listening on an SSE endpoint.
        connection_params = SseConnectionParams(url=args.remote)
    else:
        # Spawn the server as a child process speaking MCP over stdio.
        connection_params = StdioConnectionParams(
            server_params=StdioServerParameters(
                command="uv",
                args=["run", "python", SERVER_SCRIPT],
            ),
        )

    toolset = McpToolset(connection_params=connection_params)

    agent = LlmAgent(
        model="gemini-2.0-flash",
        name="knowledge_agent",
        instruction=(
            "You are a helpful assistant with access to a knowledge base. "
            "Use the knowledge_search tool to find relevant information "
            "when the user asks questions. Summarize the results clearly."
        ),
        tools=[toolset],
    )

    # Single in-memory session reused across all REPL turns.
    session_service = InMemorySessionService()
    session = await session_service.create_session(
        state={},
        app_name="knowledge_agent",
        user_id="user",
    )

    runner = Runner(
        app_name="knowledge_agent",
        agent=agent,
        session_service=session_service,
    )

    print("Knowledge Search Agent ready. Type your query (Ctrl+C to exit):")
    try:
        while True:
            try:
                query = input("\n> ").strip()
            except EOFError:
                # stdin closed (e.g. piped input exhausted): exit cleanly.
                break
            if not query:
                continue

            content = types.Content(
                role="user",
                parts=[types.Part(text=query)],
            )

            # Stream agent events; print only the final response's text parts.
            async for event in runner.run_async(
                session_id=session.id,
                user_id=session.user_id,
                new_message=content,
            ):
                if event.is_final_response() and event.content and event.content.parts:
                    for part in event.content.parts:
                        if part.text:
                            print(part.text)
    except KeyboardInterrupt:
        print("\nShutting down...")
    finally:
        # Always release the MCP connection (and the subprocess, if spawned).
        await toolset.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(async_main())
|
||||||
15
src/knowledge_search_mcp/__init__.py
Normal file
15
src/knowledge_search_mcp/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""MCP server for semantic search over Vertex AI Vector Search."""
|
||||||
|
|
||||||
|
from .clients.storage import GoogleCloudFileStorage
|
||||||
|
from .clients.vector_search import GoogleCloudVectorSearch
|
||||||
|
from .models import AppContext, SearchResult, SourceNamespace
|
||||||
|
from .utils.cache import LRUCache
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AppContext",
|
||||||
|
"GoogleCloudFileStorage",
|
||||||
|
"GoogleCloudVectorSearch",
|
||||||
|
"LRUCache",
|
||||||
|
"SearchResult",
|
||||||
|
"SourceNamespace",
|
||||||
|
]
|
||||||
128
src/knowledge_search_mcp/__main__.py
Normal file
128
src/knowledge_search_mcp/__main__.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""MCP server for semantic search over Vertex AI Vector Search."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
from mcp.server.fastmcp import Context, FastMCP
|
||||||
|
|
||||||
|
from .config import _args
|
||||||
|
from .logging import log_structured_entry
|
||||||
|
from .models import AppContext, SourceNamespace
|
||||||
|
from .server import lifespan
|
||||||
|
from .services.search import (
|
||||||
|
filter_search_results,
|
||||||
|
format_search_results,
|
||||||
|
generate_query_embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
mcp = FastMCP(
|
||||||
|
"knowledge-search",
|
||||||
|
host=_args.host,
|
||||||
|
port=_args.port,
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mcp.tool()
async def knowledge_search(
    query: str,
    ctx: Context,
    source: SourceNamespace | None = None,
) -> str:
    """Search a knowledge base using a natural-language query.

    Embeds the query, runs a filtered nearest-neighbor search against the
    deployed Vertex AI index, and returns the formatted matches. All errors
    are reported to the caller as strings rather than raised, so the MCP
    tool never surfaces an exception to the client.

    Args:
        query: The text query to search for.
        ctx: MCP request context (injected automatically).
        source: Optional filter to restrict results by source.
            Allowed values: 'Educacion Financiera',
            'Productos y Servicios', 'Funcionalidades de la App Movil'.

    Returns:
        A formatted string containing matched documents with id and content.

    """
    # Shared clients/settings created once in the server lifespan.
    app: AppContext = ctx.request_context.lifespan_context
    t0 = time.perf_counter()

    log_structured_entry(
        "knowledge_search request received",
        "INFO",
        {"query": query[:100]},  # Log first 100 chars of query
    )

    try:
        # Generate embedding for the query
        embedding, error = await generate_query_embedding(
            app.genai_client,
            app.settings.embedding_model,
            query,
        )
        if error:
            # Embedding failure is returned verbatim as the tool result.
            return error

        t_embed = time.perf_counter()
        log_structured_entry(
            "Query embedding generated successfully",
            "INFO",
            {"time_ms": round((t_embed - t0) * 1000, 1)},
        )

        # Perform vector search
        log_structured_entry("Performing vector search", "INFO")
        try:
            search_results = await app.vector_search.async_run_query(
                deployed_index_id=app.settings.deployed_index_id,
                query=embedding,
                limit=app.settings.search_limit,
                source=source,
            )
            t_search = time.perf_counter()
        except Exception as e:  # noqa: BLE001
            log_structured_entry(
                "Vector search failed",
                "ERROR",
                {"error": str(e), "error_type": type(e).__name__, "query": query[:100]},
            )
            return f"Error performing vector search: {e!s}"

        # Apply similarity filtering
        filtered_results = filter_search_results(search_results)

        # One structured row with all per-stage timings for observability.
        log_structured_entry(
            "knowledge_search completed successfully",
            "INFO",
            {
                "embedding_ms": f"{round((t_embed - t0) * 1000, 1)}ms",
                "vector_search_ms": f"{round((t_search - t_embed) * 1000, 1)}ms",
                "total_ms": f"{round((t_search - t0) * 1000, 1)}ms",
                "source_filter": source.value if source is not None else None,
                "results_count": len(filtered_results),
                "chunks": [s["id"] for s in filtered_results],
            },
        )

        # Format and return results
        if not filtered_results:
            log_structured_entry(
                "No results found for query", "INFO", {"query": query[:100]}
            )

        # format_search_results also handles the empty-results case.
        return format_search_results(filtered_results)

    except Exception as e:  # noqa: BLE001
        # Catch-all for any unexpected errors
        log_structured_entry(
            "Unexpected error in knowledge_search",
            "ERROR",
            {"error": str(e), "error_type": type(e).__name__, "query": query[:100]},
        )
        return f"Unexpected error during search: {e!s}"
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Entry point for the MCP server.

    Runs the FastMCP app using the transport selected via ``--transport``
    (stdio, sse, or streamable-http).
    """
    mcp.run(transport=_args.transport)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
11
src/knowledge_search_mcp/clients/__init__.py
Normal file
11
src/knowledge_search_mcp/clients/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
"""Client modules for Google Cloud services."""
|
||||||
|
|
||||||
|
from .base import BaseGoogleCloudClient
|
||||||
|
from .storage import GoogleCloudFileStorage
|
||||||
|
from .vector_search import GoogleCloudVectorSearch
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseGoogleCloudClient",
|
||||||
|
"GoogleCloudFileStorage",
|
||||||
|
"GoogleCloudVectorSearch",
|
||||||
|
]
|
||||||
30
src/knowledge_search_mcp/clients/base.py
Normal file
30
src/knowledge_search_mcp/clients/base.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""Base client with shared aiohttp session management."""
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
|
||||||
|
class BaseGoogleCloudClient:
    """Base class with shared aiohttp session management.

    Subclasses call :meth:`_get_aio_session` to obtain a pooled client
    session that is created lazily and recreated if it was closed.
    """

    def __init__(self) -> None:
        """Start with no session; one is created on first use."""
        self._aio_session: aiohttp.ClientSession | None = None

    def _get_aio_session(self) -> aiohttp.ClientSession:
        """Return the shared aiohttp session, (re)creating it if needed."""
        session = self._aio_session
        if session is None or session.closed:
            pooled_connector = aiohttp.TCPConnector(
                limit=300,
                limit_per_host=50,
            )
            session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=60),
                connector=pooled_connector,
            )
            self._aio_session = session
        return session

    async def close(self) -> None:
        """Close aiohttp session if open."""
        session = self._aio_session
        if session and not session.closed:
            await session.close()
|
||||||
150
src/knowledge_search_mcp/clients/storage.py
Normal file
150
src/knowledge_search_mcp/clients/storage.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
"""Google Cloud Storage client with caching."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
from typing import BinaryIO
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from gcloud.aio.storage import Storage
|
||||||
|
|
||||||
|
from knowledge_search_mcp.logging import log_structured_entry
|
||||||
|
from knowledge_search_mcp.utils.cache import LRUCache
|
||||||
|
|
||||||
|
from .base import BaseGoogleCloudClient
|
||||||
|
|
||||||
|
HTTP_TOO_MANY_REQUESTS = 429
|
||||||
|
HTTP_SERVER_ERROR = 500
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleCloudFileStorage(BaseGoogleCloudClient):
    """Cache-aware helper for downloading files from Google Cloud Storage."""

    def __init__(self, bucket: str, cache_size: int = 100) -> None:
        """Initialize the storage helper with LRU cache.

        Args:
            bucket: Name of the GCS bucket to download from.
            cache_size: Maximum number of blobs kept in the in-memory cache.

        """
        super().__init__()
        self.bucket_name = bucket
        # Lazily-created gcloud-aio Storage client (reuses the base session).
        self._aio_storage: Storage | None = None
        # blob name -> raw bytes, to skip repeated downloads of hot files.
        self._cache = LRUCache(max_size=cache_size)

    def _get_aio_storage(self) -> Storage:
        """Get or lazily create the async Storage client."""
        if self._aio_storage is None:
            self._aio_storage = Storage(
                session=self._get_aio_session(),
            )
        return self._aio_storage

    async def async_get_file_stream(
        self,
        file_name: str,
        max_retries: int = 3,
    ) -> BinaryIO:
        """Get a file asynchronously with retry on transient errors.

        Args:
            file_name: The blob name to retrieve.
            max_retries: Maximum number of retry attempts.

        Returns:
            A BytesIO stream with the file contents.

        Raises:
            TimeoutError: If all retry attempts fail.
            aiohttp.ClientResponseError: On a non-retryable HTTP status
                (anything other than 429 or >= 500).

        """
        # Fast path: serve from the in-memory cache, no network round-trip.
        cached_content = self._cache.get(file_name)
        if cached_content is not None:
            log_structured_entry(
                "File retrieved from cache",
                "INFO",
                {"file": file_name, "bucket": self.bucket_name},
            )
            file_stream = io.BytesIO(cached_content)
            # Give the stream a .name so it behaves like a real file object.
            file_stream.name = file_name
            return file_stream

        log_structured_entry(
            "Starting file download from GCS",
            "INFO",
            {"file": file_name, "bucket": self.bucket_name},
        )

        storage_client = self._get_aio_storage()
        last_exception: Exception | None = None

        for attempt in range(max_retries):
            try:
                content = await storage_client.download(
                    self.bucket_name,
                    file_name,
                )
                self._cache.put(file_name, content)
                file_stream = io.BytesIO(content)
                file_stream.name = file_name
                log_structured_entry(
                    "File downloaded successfully",
                    "INFO",
                    {
                        "file": file_name,
                        "bucket": self.bucket_name,
                        "size_bytes": len(content),
                        "attempt": attempt + 1,
                    },
                )
            except TimeoutError as exc:
                # Transient: record and fall through to the backoff below.
                last_exception = exc
                log_structured_entry(
                    (
                        f"Timeout downloading gs://{self.bucket_name}/{file_name} "
                        f"(attempt {attempt + 1}/{max_retries})"
                    ),
                    "WARNING",
                    {"error": str(exc)},
                )
            except aiohttp.ClientResponseError as exc:
                last_exception = exc
                # Retry only on rate limiting (429) and server errors (5xx);
                # any other status is permanent and re-raised immediately.
                if (
                    exc.status == HTTP_TOO_MANY_REQUESTS
                    or exc.status >= HTTP_SERVER_ERROR
                ):
                    log_structured_entry(
                        (
                            f"HTTP {exc.status} downloading gs://{self.bucket_name}/"
                            f"{file_name} (attempt {attempt + 1}/{max_retries})"
                        ),
                        "WARNING",
                        {"status": exc.status, "message": str(exc)},
                    )
                else:
                    log_structured_entry(
                        f"Non-retryable HTTP error downloading gs://{self.bucket_name}/{file_name}",
                        "ERROR",
                        {"status": exc.status, "message": str(exc)},
                    )
                    raise
            else:
                # Success: return the freshly downloaded (and cached) stream.
                return file_stream

            if attempt < max_retries - 1:
                # Exponential backoff: 0.5s, 1s, 2s, ...
                delay = 0.5 * (2**attempt)
                log_structured_entry(
                    "Retrying file download",
                    "INFO",
                    {"file": file_name, "delay_seconds": delay},
                )
                await asyncio.sleep(delay)

        # All attempts exhausted on retryable errors.
        msg = (
            f"Failed to download gs://{self.bucket_name}/{file_name} "
            f"after {max_retries} attempts"
        )
        log_structured_entry(
            "File download failed after all retries",
            "ERROR",
            {
                "file": file_name,
                "bucket": self.bucket_name,
                "max_retries": max_retries,
                "last_error": str(last_exception),
            },
        )
        raise TimeoutError(msg) from last_exception
|
||||||
223
src/knowledge_search_mcp/clients/vector_search.py
Normal file
223
src/knowledge_search_mcp/clients/vector_search.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
"""Google Cloud Vector Search client."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from gcloud.aio.auth import Token
|
||||||
|
|
||||||
|
from knowledge_search_mcp.logging import log_structured_entry
|
||||||
|
from knowledge_search_mcp.models import SearchResult, SourceNamespace
|
||||||
|
|
||||||
|
from .base import BaseGoogleCloudClient
|
||||||
|
from .storage import GoogleCloudFileStorage
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleCloudVectorSearch(BaseGoogleCloudClient):
    """Minimal async client for the Vertex AI Matching Engine REST API."""

    def __init__(
        self,
        project_id: str,
        location: str,
        bucket: str,
        index_name: str | None = None,
    ) -> None:
        """Store configuration used to issue Matching Engine queries.

        Args:
            project_id: GCP project hosting the index endpoint.
            location: GCP region of the index endpoint.
            bucket: GCS bucket holding the chunk content files.
            index_name: Prefix under which chunk contents are stored
                (``{index_name}/contents/{id}.md``).

        """
        super().__init__()
        self.project_id = project_id
        self.location = location
        self.storage = GoogleCloudFileStorage(bucket=bucket)
        self.index_name = index_name
        # OAuth token helper, created lazily on first authenticated call.
        self._async_token: Token | None = None
        # Endpoint metadata; must be set via configure_index_endpoint().
        self._endpoint_domain: str | None = None
        self._endpoint_name: str | None = None

    async def _async_get_auth_headers(self) -> dict[str, str]:
        """Return Bearer-token request headers, creating the token helper once."""
        if self._async_token is None:
            self._async_token = Token(
                session=self._get_aio_session(),
                scopes=[
                    "https://www.googleapis.com/auth/cloud-platform",
                ],
            )
        access_token = await self._async_token.get()
        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    async def close(self) -> None:
        """Close aiohttp sessions for both vector search and storage."""
        await super().close()
        await self.storage.close()

    def configure_index_endpoint(
        self,
        *,
        name: str,
        public_domain: str,
    ) -> None:
        """Persist the metadata needed to access a deployed endpoint.

        Args:
            name: Resource name of the index endpoint (last path segment is
                used as the endpoint ID in request URLs).
            public_domain: Public endpoint domain used to build request URLs.

        Raises:
            ValueError: If either value is empty.

        """
        if not name:
            msg = "Index endpoint name must be a non-empty string."
            raise ValueError(msg)
        if not public_domain:
            msg = "Index endpoint domain must be a non-empty public domain."
            raise ValueError(msg)
        self._endpoint_name = name
        self._endpoint_domain = public_domain

    async def async_run_query(
        self,
        deployed_index_id: str,
        query: Sequence[float],
        limit: int,
        source: SourceNamespace | None = None,
    ) -> list[SearchResult]:
        """Run an async similarity search via the REST API.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.
            source: Optional namespace filter to restrict results by source.

        Returns:
            A list of matched items with id, distance, and content.

        Raises:
            RuntimeError: If the endpoint is not configured, or if the
                findNeighbors call returns a non-OK status.

        """
        if self._endpoint_domain is None or self._endpoint_name is None:
            msg = (
                "Missing endpoint metadata. Call "
                "`configure_index_endpoint` before querying."
            )
            log_structured_entry(
                "Vector search query failed - endpoint not configured",
                "ERROR",
                {"error": msg},
            )
            raise RuntimeError(msg)

        domain = self._endpoint_domain
        # The REST URL takes only the trailing ID segment of the resource name.
        endpoint_id = self._endpoint_name.split("/")[-1]
        url = (
            f"https://{domain}/v1/projects/{self.project_id}"
            f"/locations/{self.location}"
            f"/indexEndpoints/{endpoint_id}:findNeighbors"
        )

        log_structured_entry(
            "Starting vector search query",
            "INFO",
            {
                "deployed_index_id": deployed_index_id,
                "neighbor_count": limit,
                "endpoint_id": endpoint_id,
                "embedding_dimension": len(query),
            },
        )

        datapoint: dict = {"feature_vector": list(query)}
        if source is not None:
            # Restrict matches to datapoints tagged with this source value.
            datapoint["restricts"] = [
                {"namespace": "source", "allow_list": [source.value]},
            ]
        payload = {
            "deployed_index_id": deployed_index_id,
            "queries": [
                {
                    "datapoint": datapoint,
                    "neighbor_count": limit,
                },
            ],
        }

        try:
            headers = await self._async_get_auth_headers()
            session = self._get_aio_session()
            async with session.post(
                url,
                json=payload,
                headers=headers,
            ) as response:
                if not response.ok:
                    body = await response.text()
                    msg = f"findNeighbors returned {response.status}: {body}"
                    log_structured_entry(
                        "Vector search API request failed",
                        "ERROR",
                        {
                            "status": response.status,
                            "response_body": body,
                            "deployed_index_id": deployed_index_id,
                        },
                    )
                    raise RuntimeError(msg)  # noqa: TRY301
                data = await response.json()

            # A single query was sent, so take the first nearestNeighbors entry.
            neighbors = data.get("nearestNeighbors", [{}])[0].get("neighbors", [])
            log_structured_entry(
                "Vector search API request successful",
                "INFO",
                {
                    "neighbors_found": len(neighbors),
                    "deployed_index_id": deployed_index_id,
                },
            )

            if not neighbors:
                log_structured_entry(
                    "No neighbors found in vector search",
                    "WARNING",
                    {"deployed_index_id": deployed_index_id},
                )
                return []

            # Fetch content for all neighbors (concurrently, via gather below).
            content_tasks = []
            for neighbor in neighbors:
                datapoint_id = neighbor["datapoint"]["datapointId"]
                file_path = f"{self.index_name}/contents/{datapoint_id}.md"
                content_tasks.append(
                    self.storage.async_get_file_stream(file_path),
                )

            log_structured_entry(
                "Fetching content for search results",
                "INFO",
                {"file_count": len(content_tasks)},
            )

            file_streams = await asyncio.gather(*content_tasks)
            results: list[SearchResult] = []
            # strict=True: one content stream must exist per neighbor.
            for neighbor, stream in zip(
                neighbors,
                file_streams,
                strict=True,
            ):
                results.append(
                    SearchResult(
                        id=neighbor["datapoint"]["datapointId"],
                        distance=neighbor["distance"],
                        content=stream.read().decode("utf-8"),
                    ),
                )

            log_structured_entry(
                "Vector search completed successfully",
                "INFO",
                {"results_count": len(results), "deployed_index_id": deployed_index_id},
            )
            return results  # noqa: TRY300

        except Exception as e:
            # Log with context before propagating to the caller.
            log_structured_entry(
                "Vector search query failed with exception",
                "ERROR",
                {
                    "error": str(e),
                    "error_type": type(e).__name__,
                    "deployed_index_id": deployed_index_id,
                },
            )
            raise
|
||||||
104
src/knowledge_search_mcp/config.py
Normal file
104
src/knowledge_search_mcp/config.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
"""Configuration management for the MCP server."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from pydantic_settings import (
|
||||||
|
BaseSettings,
|
||||||
|
PydanticBaseSettingsSource,
|
||||||
|
YamlConfigSettingsSource,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_args() -> argparse.Namespace:
|
||||||
|
"""Parse command-line arguments.
|
||||||
|
|
||||||
|
Returns a namespace with default values if running under pytest.
|
||||||
|
"""
|
||||||
|
# Don't parse args if running under pytest
|
||||||
|
if "pytest" in sys.modules:
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
return argparse.Namespace(
|
||||||
|
transport="stdio",
|
||||||
|
host="0.0.0.0", # noqa: S104
|
||||||
|
port=8080,
|
||||||
|
config=os.environ.get("CONFIG_FILE", "config.yaml"),
|
||||||
|
)
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--transport",
|
||||||
|
choices=["stdio", "sse", "streamable-http"],
|
||||||
|
default="stdio",
|
||||||
|
)
|
||||||
|
parser.add_argument("--host", default="0.0.0.0") # noqa: S104
|
||||||
|
parser.add_argument("--port", type=int, default=8080)
|
||||||
|
parser.add_argument(
|
||||||
|
"--config",
|
||||||
|
default=os.environ.get("CONFIG_FILE", "config.yaml"),
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
_args = _parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
    """Server configuration populated from env vars and a YAML config file.

    Source precedence (highest first): init kwargs, environment variables,
    ``.env`` file, YAML config file, file secrets — see
    ``settings_customise_sources`` below.
    """

    model_config = {"env_file": ".env", "yaml_file": _args.config}

    # Required GCP / index identifiers (no defaults; must be configured).
    project_id: str
    location: str
    bucket: str
    index_name: str
    deployed_index_id: str
    endpoint_name: str
    endpoint_domain: str
    # Embedding model used to vectorize queries.
    embedding_model: str = "gemini-embedding-001"
    # Max neighbors requested per search.
    search_limit: int = 10
    # Cloud Logging log name and verbosity.
    log_name: str = "va_agent_evaluation_logs"
    log_level: str = "INFO"
    cloud_logging_enabled: bool = False

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Customize the order of settings sources to include YAML config.

        YAML is inserted after env/.env so environment values win over the
        config file.
        """
        return (
            init_settings,
            env_settings,
            dotenv_settings,
            YamlConfigSettingsSource(settings_cls),
            file_secret_settings,
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Lazy singleton instance of Settings
|
||||||
|
_cfg: Settings | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_config() -> Settings:
    """Return the lazily-created singleton Settings instance."""
    global _cfg  # noqa: PLW0603
    if _cfg is not None:
        return _cfg
    _cfg = Settings.model_validate({})
    return _cfg
|
||||||
|
|
||||||
|
|
||||||
|
# For backwards compatibility, provide cfg as a property-like accessor
|
||||||
|
class _ConfigProxy:
    """Proxy object that lazily loads config on attribute access."""

    def __getattr__(self, name: str) -> object:
        """Forward attribute lookups to the lazily-loaded Settings singleton."""
        settings = get_config()
        return getattr(settings, name)
|
||||||
|
|
||||||
|
|
||||||
|
cfg = _ConfigProxy()
|
||||||
67
src/knowledge_search_mcp/logging.py
Normal file
67
src/knowledge_search_mcp/logging.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""Centralized Cloud Logging setup.
|
||||||
|
|
||||||
|
Uses CloudLoggingHandler (background thread) so logging does not add latency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import google.cloud.logging
|
||||||
|
from google.cloud.logging.handlers import CloudLoggingHandler
|
||||||
|
|
||||||
|
from .config import get_config
|
||||||
|
|
||||||
|
_eval_log: logging.Logger | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_logger() -> logging.Logger:
    """Get or create the singleton evaluation logger.

    Attaches a CloudLoggingHandler (background-thread transport) when
    cloud_logging_enabled is set; otherwise — or when Cloud Logging setup
    fails, e.g. in local dev without credentials — falls back to console
    logging via basicConfig.
    """
    global _eval_log  # noqa: PLW0603
    if _eval_log is not None:
        return _eval_log

    cfg = get_config()
    logger = logging.getLogger(cfg.log_name)
    # Another code path may already have attached the cloud handler; don't
    # attach a second one.
    if any(isinstance(h, CloudLoggingHandler) for h in logger.handlers):
        _eval_log = logger
        return logger

    if cfg.cloud_logging_enabled:
        try:
            client = google.cloud.logging.Client(project=cfg.project_id)
            handler = CloudLoggingHandler(client, name=cfg.log_name)  # async transport
            logger.addHandler(handler)
            logger.setLevel(getattr(logging, cfg.log_level.upper()))
        except Exception as e:  # noqa: BLE001
            # Fallback to console if Cloud Logging is unavailable (local dev)
            logging.basicConfig(level=getattr(logging, cfg.log_level.upper()))
            logger = logging.getLogger(cfg.log_name)
            logger.warning("Cloud Logging setup failed; using console. Error: %s", e)
    else:
        # Cloud Logging disabled: plain console logging.
        logging.basicConfig(level=getattr(logging, cfg.log_level.upper()))
        logger = logging.getLogger(cfg.log_name)

    _eval_log = logger
    return logger
|
||||||
|
|
||||||
|
|
||||||
|
def log_structured_entry(
    message: str,
    severity: Literal["INFO", "WARNING", "ERROR"],
    custom_log: dict | None = None,
) -> None:
    """Emit a JSON-structured log row.

    Args:
        message: Short label for the row (e.g., "Final agent turn").
        severity: "INFO" | "WARNING" | "ERROR"
        custom_log: A dict with your structured payload.

    """
    # Unknown severity strings fall back to INFO rather than raising.
    numeric_level = getattr(logging, severity.upper(), logging.INFO)
    payload = {"message": message, "custom": custom_log or {}}
    _get_logger().log(
        numeric_level,
        message,
        extra={"json_fields": payload},
    )
|
||||||
36
src/knowledge_search_mcp/models.py
Normal file
36
src/knowledge_search_mcp/models.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""Domain models for knowledge search MCP server."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import TYPE_CHECKING, TypedDict
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from google import genai
|
||||||
|
|
||||||
|
from .clients.vector_search import GoogleCloudVectorSearch
|
||||||
|
from .config import Settings
|
||||||
|
|
||||||
|
|
||||||
|
class SourceNamespace(StrEnum):
    """Allowed values for the 'source' namespace filter.

    These string values are sent verbatim as the allow-list of the
    "source" restrict namespace in findNeighbors queries, so they must
    match the values the index was built with exactly.
    """

    EDUCACION_FINANCIERA = "Educacion Financiera"
    PRODUCTOS_Y_SERVICIOS = "Productos y Servicios"
    FUNCIONALIDADES_APP_MOVIL = "Funcionalidades de la App Movil"
|
||||||
|
|
||||||
|
|
||||||
|
class SearchResult(TypedDict):
    """Structured response item returned by the vector search API."""

    # Datapoint ID of the matched chunk (also the content file stem in GCS).
    id: str
    # Similarity distance reported by findNeighbors.
    distance: float
    # UTF-8 chunk content fetched from Cloud Storage.
    content: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AppContext:
    """Shared resources initialised once at server startup."""

    # Configured Matching Engine client (owns its own GCS storage helper).
    vector_search: "GoogleCloudVectorSearch"
    # GenAI client used to embed incoming queries.
    genai_client: "genai.Client"
    # Validated server settings.
    settings: "Settings"
|
||||||
143
src/knowledge_search_mcp/server.py
Normal file
143
src/knowledge_search_mcp/server.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
"""MCP server lifecycle management."""
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from google import genai
|
||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
|
||||||
|
from .clients.vector_search import GoogleCloudVectorSearch
|
||||||
|
from .config import get_config
|
||||||
|
from .logging import log_structured_entry
|
||||||
|
from .models import AppContext
|
||||||
|
from .services.validation import (
|
||||||
|
validate_gcs_access,
|
||||||
|
validate_genai_access,
|
||||||
|
validate_vector_search_access,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(_server: FastMCP) -> AsyncIterator[AppContext]:
    """Create and configure the vector-search client for the server lifetime.

    Builds the vector-search and GenAI clients, runs best-effort validation
    probes (failures are logged, not raised), then yields the shared
    AppContext. On exit, closes the vector-search client's sessions.

    Yields:
        AppContext with the vector-search client, GenAI client and settings.

    Raises:
        Exception: Re-raised after logging if client construction fails.

    NOTE(review): the ``yield`` sits inside the ``try`` block, so an exception
    raised while the server is *running* is also logged as "Failed to
    initialize MCP server" — confirm that wording is intended for runtime
    failures too.
    """
    # Get config with proper types for initialization
    config_for_init = get_config()

    log_structured_entry(
        "Initializing MCP server",
        "INFO",
        {
            "project_id": config_for_init.project_id,
            "location": config_for_init.location,
            "bucket": config_for_init.bucket,
            "index_name": config_for_init.index_name,
        },
    )

    # Kept as Optional so the finally-block can tell whether construction
    # got far enough to need cleanup.
    vs: GoogleCloudVectorSearch | None = None
    try:
        # Initialize vector search client
        log_structured_entry("Creating GoogleCloudVectorSearch client", "INFO")
        vs = GoogleCloudVectorSearch(
            project_id=config_for_init.project_id,
            location=config_for_init.location,
            bucket=config_for_init.bucket,
            index_name=config_for_init.index_name,
        )

        # Configure endpoint
        log_structured_entry(
            "Configuring index endpoint",
            "INFO",
            {
                "endpoint_name": config_for_init.endpoint_name,
                "endpoint_domain": config_for_init.endpoint_domain,
            },
        )
        vs.configure_index_endpoint(
            name=config_for_init.endpoint_name,
            public_domain=config_for_init.endpoint_domain,
        )

        # Initialize GenAI client
        log_structured_entry(
            "Creating GenAI client",
            "INFO",
            {
                "project_id": config_for_init.project_id,
                "location": config_for_init.location,
            },
        )
        genai_client = genai.Client(
            vertexai=True,
            project=config_for_init.project_id,
            location=config_for_init.location,
        )

        # Validate credentials and configuration by testing actual resources
        # These validations are non-blocking - errors are logged but won't stop startup
        log_structured_entry("Starting validation of credentials and resources", "INFO")

        validation_errors = []

        # Run all validations
        # NOTE(review): get_config() is called a second time here; presumably
        # it returns the same (possibly cached) Settings as config_for_init —
        # confirm, or reuse config_for_init.
        config = get_config()
        genai_error = await validate_genai_access(genai_client, config)
        if genai_error:
            validation_errors.append(genai_error)

        gcs_error = await validate_gcs_access(vs, config)
        if gcs_error:
            validation_errors.append(gcs_error)

        vs_error = await validate_vector_search_access(vs, config)
        if vs_error:
            validation_errors.append(vs_error)

        # Summary of validations
        if validation_errors:
            log_structured_entry(
                (
                    "MCP server started with validation errors - "
                    "service may not work correctly"
                ),
                "WARNING",
                {
                    "validation_errors": validation_errors,
                    "error_count": len(validation_errors),
                },
            )
        else:
            log_structured_entry(
                "All validations passed - MCP server initialization complete", "INFO"
            )

        yield AppContext(
            vector_search=vs,
            genai_client=genai_client,
            settings=config,
        )

    except Exception as e:
        log_structured_entry(
            "Failed to initialize MCP server",
            "ERROR",
            {
                "error": str(e),
                "error_type": type(e).__name__,
            },
        )
        raise
    finally:
        log_structured_entry("MCP server lifespan ending", "INFO")
        # Clean up resources
        if vs is not None:
            try:
                await vs.close()
                log_structured_entry("Closed aiohttp sessions", "INFO")
            except Exception as e:  # noqa: BLE001
                # Cleanup is best-effort: never let shutdown raise.
                log_structured_entry(
                    "Error closing aiohttp sessions",
                    "WARNING",
                    {"error": str(e), "error_type": type(e).__name__},
                )
|
||||||
21
src/knowledge_search_mcp/services/__init__.py
Normal file
21
src/knowledge_search_mcp/services/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""Service modules for business logic."""
|
||||||
|
|
||||||
|
from .search import (
|
||||||
|
filter_search_results,
|
||||||
|
format_search_results,
|
||||||
|
generate_query_embedding,
|
||||||
|
)
|
||||||
|
from .validation import (
|
||||||
|
validate_gcs_access,
|
||||||
|
validate_genai_access,
|
||||||
|
validate_vector_search_access,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"filter_search_results",
|
||||||
|
"format_search_results",
|
||||||
|
"generate_query_embedding",
|
||||||
|
"validate_gcs_access",
|
||||||
|
"validate_genai_access",
|
||||||
|
"validate_vector_search_access",
|
||||||
|
]
|
||||||
101
src/knowledge_search_mcp/services/search.py
Normal file
101
src/knowledge_search_mcp/services/search.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
"""Search helper functions."""
|
||||||
|
|
||||||
|
from google import genai
|
||||||
|
from google.genai import types as genai_types
|
||||||
|
|
||||||
|
from knowledge_search_mcp.logging import log_structured_entry
|
||||||
|
from knowledge_search_mcp.models import SearchResult
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_query_embedding(
    genai_client: genai.Client,
    embedding_model: str,
    query: str,
) -> tuple[list[float], str | None]:
    """Generate embedding for search query.

    Args:
        genai_client: GenAI client exposing ``aio.models.embed_content``.
        embedding_model: Name of the embedding model to call.
        query: Free-text search query; must be non-blank.

    Returns:
        Tuple of (embedding vector, error message). Error message is None on success.

    """
    # Reject blank queries before touching the API.
    if not query or not query.strip():
        return ([], "Error: Query cannot be empty")

    log_structured_entry("Generating query embedding", "INFO")
    try:
        api_response = await genai_client.aio.models.embed_content(
            model=embedding_model,
            contents=query,
            config=genai_types.EmbedContentConfig(
                task_type="RETRIEVAL_QUERY",
            ),
        )
        if not api_response.embeddings or not api_response.embeddings[0].values:
            return ([], "Error: Failed to generate embedding - empty response")
        return (api_response.embeddings[0].values, None)  # noqa: TRY300
    except Exception as exc:  # noqa: BLE001
        failure_kind = type(exc).__name__
        failure_text = str(exc)
        log_payload = {
            "error": failure_text,
            "error_type": failure_kind,
            "query": query[:100],
        }

        # Check if it's a rate limit error
        if "429" in failure_text or "RESOURCE_EXHAUSTED" in failure_text:
            log_structured_entry(
                "Rate limit exceeded while generating embedding",
                "WARNING",
                log_payload,
            )
            return ([], "Error: API rate limit exceeded. Please try again later.")
        log_structured_entry(
            "Failed to generate query embedding",
            "ERROR",
            log_payload,
        )
        return ([], f"Error generating embedding: {failure_text}")
|
||||||
|
|
||||||
|
|
||||||
|
def filter_search_results(
    results: list[SearchResult],
    min_similarity: float = 0.6,
    top_percent: float = 0.9,
) -> list[SearchResult]:
    """Filter search results by similarity thresholds.

    Args:
        results: Raw search results from vector search.
        min_similarity: Absolute score a result must exceed to be kept.
        top_percent: Fraction of the best score a result must exceed.

    Returns:
        Filtered list of search results (original order preserved).

    """
    if not results:
        return []

    best_score = max(item["distance"] for item in results)
    # A result must strictly beat both the relative cutoff (fraction of the
    # best score) and the absolute minimum — equivalent to one max() bound.
    threshold = max(best_score * top_percent, min_similarity)

    kept: list[SearchResult] = []
    for item in results:
        if item["distance"] > threshold:
            kept.append(item)
    return kept
|
||||||
|
|
||||||
|
|
||||||
|
def format_search_results(results: list[SearchResult]) -> str:
    """Format search results as XML-like documents.

    Args:
        results: List of search results to format.

    Returns:
        Formatted string with one numbered document tag per result, or a
        fixed "no results" message when the list is empty.

    """
    if not results:
        return "No relevant documents found for your query."

    rendered: list[str] = []
    # 1-based numbering so the first hit is <document 1 ...>.
    for position, item in enumerate(results, start=1):
        rendered.append(
            f"<document {position} name={item['id']}>\n"
            f"{item['content']}\n"
            f"</document {position}>"
        )
    return "\n".join(rendered)
|
||||||
214
src/knowledge_search_mcp/services/validation.py
Normal file
214
src/knowledge_search_mcp/services/validation.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
"""Validation functions for Google Cloud services."""
|
||||||
|
|
||||||
|
from gcloud.aio.auth import Token
|
||||||
|
from google import genai
|
||||||
|
from google.genai import types as genai_types
|
||||||
|
|
||||||
|
from knowledge_search_mcp.clients.vector_search import GoogleCloudVectorSearch
|
||||||
|
from knowledge_search_mcp.config import Settings
|
||||||
|
from knowledge_search_mcp.logging import log_structured_entry
|
||||||
|
|
||||||
|
# HTTP status codes
|
||||||
|
HTTP_FORBIDDEN = 403
|
||||||
|
HTTP_NOT_FOUND = 404
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_genai_access(
    genai_client: genai.Client, cfg: Settings
) -> str | None:
    """Validate GenAI embedding access.

    Embeds the literal string "test" to prove the configured model is
    reachable with the current credentials. Best-effort: failures are
    logged and reported as a string, never raised.

    Args:
        genai_client: Client used to call the embedding model.
        cfg: Settings; only ``cfg.embedding_model`` is read here.

    Returns:
        Error message if validation fails, None if successful.

    """
    log_structured_entry("Validating GenAI embedding access", "INFO")
    try:
        test_response = await genai_client.aio.models.embed_content(
            model=cfg.embedding_model,
            contents="test",
            config=genai_types.EmbedContentConfig(
                task_type="RETRIEVAL_QUERY",
            ),
        )
        if test_response and test_response.embeddings:
            embedding_values = test_response.embeddings[0].values
            log_structured_entry(
                "GenAI embedding validation successful",
                "INFO",
                {
                    "embedding_dimension": len(embedding_values)
                    if embedding_values
                    else 0
                },
            )
            return None
        msg = "Embedding validation returned empty response"
        log_structured_entry(msg, "WARNING")
        return msg  # noqa: TRY300
    except Exception as e:  # noqa: BLE001
        # Non-blocking validation: log the failure and return it as a message.
        log_structured_entry(
            (
                "Failed to validate GenAI embedding access - "
                "service may not work correctly"
            ),
            "WARNING",
            {"error": str(e), "error_type": type(e).__name__},
        )
        return f"GenAI: {e!s}"
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_gcs_access(vs: GoogleCloudVectorSearch, cfg: Settings) -> str | None:
    """Validate GCS bucket access.

    Lists at most one object in the bucket — the cheapest request that still
    proves the credentials can read it. Best-effort: failures are logged and
    returned as a message, never raised.

    Args:
        vs: Vector-search client whose storage aiohttp session is reused.
        cfg: Settings; only ``cfg.bucket`` is read here.

    Returns:
        Error message if validation fails, None if successful.

    """
    log_structured_entry("Validating GCS bucket access", "INFO", {"bucket": cfg.bucket})
    try:
        # Reuse the storage client's private aiohttp session (noqa SLF001)
        # instead of opening a second session just for this probe.
        session = vs.storage._get_aio_session()  # noqa: SLF001
        token_obj = Token(
            session=session,
            scopes=["https://www.googleapis.com/auth/cloud-platform"],
        )
        access_token = await token_obj.get()
        headers = {"Authorization": f"Bearer {access_token}"}

        async with session.get(
            f"https://storage.googleapis.com/storage/v1/b/{cfg.bucket}/o?maxResults=1",
            headers=headers,
        ) as response:
            # Map the two most diagnostic HTTP failures to specific messages.
            if response.status == HTTP_FORBIDDEN:
                msg = f"Access denied to bucket '{cfg.bucket}'. Check permissions."
                log_structured_entry(
                    (
                        "GCS bucket validation failed - access denied - "
                        "service may not work correctly"
                    ),
                    "WARNING",
                    {"bucket": cfg.bucket, "status": response.status},
                )
                return msg
            if response.status == HTTP_NOT_FOUND:
                msg = f"Bucket '{cfg.bucket}' not found. Check bucket name and project."
                log_structured_entry(
                    (
                        "GCS bucket validation failed - not found - "
                        "service may not work correctly"
                    ),
                    "WARNING",
                    {"bucket": cfg.bucket, "status": response.status},
                )
                return msg
            if not response.ok:
                body = await response.text()
                msg = f"Failed to access bucket '{cfg.bucket}': {response.status}"
                log_structured_entry(
                    "GCS bucket validation failed - service may not work correctly",
                    "WARNING",
                    {"bucket": cfg.bucket, "status": response.status, "response": body},
                )
                return msg
            log_structured_entry(
                "GCS bucket validation successful", "INFO", {"bucket": cfg.bucket}
            )
            return None
    except Exception as e:  # noqa: BLE001
        # Non-blocking validation: log the failure and return it as a message.
        log_structured_entry(
            "Failed to validate GCS bucket access - service may not work correctly",
            "WARNING",
            {"error": str(e), "error_type": type(e).__name__, "bucket": cfg.bucket},
        )
        return f"GCS: {e!s}"
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_vector_search_access(
    vs: GoogleCloudVectorSearch, cfg: Settings
) -> str | None:
    """Validate vector search endpoint access.

    Issues a GET on the index-endpoint resource URL to prove the endpoint
    exists and the credentials can read it. Best-effort: failures are logged
    and returned as a message, never raised.

    Args:
        vs: Vector-search client whose auth headers and session are reused.
        cfg: Settings; ``cfg.location`` and ``cfg.endpoint_name`` are read.

    Returns:
        Error message if validation fails, None if successful.

    """
    log_structured_entry(
        "Validating vector search endpoint access",
        "INFO",
        {"endpoint_name": cfg.endpoint_name},
    )
    try:
        # Reuse the client's private auth/session helpers (noqa SLF001) so
        # the probe uses exactly the credentials real queries will use.
        headers = await vs._async_get_auth_headers()  # noqa: SLF001
        session = vs._get_aio_session()  # noqa: SLF001
        # cfg.endpoint_name is expected to be the full resource path
        # (projects/.../indexEndpoints/...) — see the test fixtures.
        endpoint_url = (
            f"https://{cfg.location}-aiplatform.googleapis.com/v1/{cfg.endpoint_name}"
        )

        async with session.get(endpoint_url, headers=headers) as response:
            if response.status == HTTP_FORBIDDEN:
                msg = (
                    f"Access denied to endpoint '{cfg.endpoint_name}'. "
                    "Check permissions."
                )
                log_structured_entry(
                    (
                        "Vector search endpoint validation failed - "
                        "access denied - service may not work correctly"
                    ),
                    "WARNING",
                    {"endpoint": cfg.endpoint_name, "status": response.status},
                )
                return msg
            if response.status == HTTP_NOT_FOUND:
                msg = (
                    f"Endpoint '{cfg.endpoint_name}' not found. "
                    "Check endpoint name and project."
                )
                log_structured_entry(
                    (
                        "Vector search endpoint validation failed - "
                        "not found - service may not work correctly"
                    ),
                    "WARNING",
                    {"endpoint": cfg.endpoint_name, "status": response.status},
                )
                return msg
            if not response.ok:
                body = await response.text()
                msg = (
                    f"Failed to access endpoint '{cfg.endpoint_name}': "
                    f"{response.status}"
                )
                log_structured_entry(
                    (
                        "Vector search endpoint validation failed - "
                        "service may not work correctly"
                    ),
                    "WARNING",
                    {
                        "endpoint": cfg.endpoint_name,
                        "status": response.status,
                        "response": body,
                    },
                )
                return msg
            log_structured_entry(
                "Vector search endpoint validation successful",
                "INFO",
                {"endpoint": cfg.endpoint_name},
            )
            return None
    except Exception as e:  # noqa: BLE001
        # Non-blocking validation: log the failure and return it as a message.
        log_structured_entry(
            (
                "Failed to validate vector search endpoint access - "
                "service may not work correctly"
            ),
            "WARNING",
            {
                "error": str(e),
                "error_type": type(e).__name__,
                "endpoint": cfg.endpoint_name,
            },
        )
        return f"Vector Search: {e!s}"
|
||||||
5
src/knowledge_search_mcp/utils/__init__.py
Normal file
5
src/knowledge_search_mcp/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Utility modules for knowledge search MCP server."""
|
||||||
|
|
||||||
|
from .cache import LRUCache
|
||||||
|
|
||||||
|
__all__ = ["LRUCache"]
|
||||||
32
src/knowledge_search_mcp/utils/cache.py
Normal file
32
src/knowledge_search_mcp/utils/cache.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""LRU cache implementation."""
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
|
class LRUCache:
|
||||||
|
"""Simple LRU cache with size limit."""
|
||||||
|
|
||||||
|
def __init__(self, max_size: int = 100) -> None:
|
||||||
|
"""Initialize cache with maximum size."""
|
||||||
|
self.cache: OrderedDict[str, bytes] = OrderedDict()
|
||||||
|
self.max_size = max_size
|
||||||
|
|
||||||
|
def get(self, key: str) -> bytes | None:
|
||||||
|
"""Get item from cache, returning None if not found."""
|
||||||
|
if key not in self.cache:
|
||||||
|
return None
|
||||||
|
# Move to end to mark as recently used
|
||||||
|
self.cache.move_to_end(key)
|
||||||
|
return self.cache[key]
|
||||||
|
|
||||||
|
def put(self, key: str, value: bytes) -> None:
|
||||||
|
"""Put item in cache, evicting oldest if at capacity."""
|
||||||
|
if key in self.cache:
|
||||||
|
self.cache.move_to_end(key)
|
||||||
|
self.cache[key] = value
|
||||||
|
if len(self.cache) > self.max_size:
|
||||||
|
self.cache.popitem(last=False)
|
||||||
|
|
||||||
|
def __contains__(self, key: str) -> bool:
|
||||||
|
"""Check if key exists in cache."""
|
||||||
|
return key in self.cache
|
||||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Tests for knowledge-search-mcp."""
|
||||||
36
tests/conftest.py
Normal file
36
tests/conftest.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""Pytest configuration and shared fixtures."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def mock_env_vars(monkeypatch):
    """Set required environment variables for testing.

    Autouse: every test in the suite runs with this fake configuration.
    """
    monkeypatch.setenv("PROJECT_ID", "test-project")
    monkeypatch.setenv("LOCATION", "us-central1")
    monkeypatch.setenv("BUCKET", "test-bucket")
    monkeypatch.setenv("INDEX_NAME", "test-index")
    monkeypatch.setenv("DEPLOYED_INDEX_ID", "test-deployed-index")
    monkeypatch.setenv("ENDPOINT_NAME", "projects/test/locations/us-central1/indexEndpoints/test")
    monkeypatch.setenv("ENDPOINT_DOMAIN", "test.us-central1-aiplatform.googleapis.com")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_gcs_storage():
    """Mock Google Cloud Storage client."""
    return MagicMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_vector_search():
    """Mock vector search client."""
    return MagicMock()
|
||||||
56
tests/test_config.py
Normal file
56
tests/test_config.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""Tests for configuration management."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from knowledge_search_mcp.config import Settings
|
||||||
|
|
||||||
|
|
||||||
|
def test_settings_from_env():
    """Test that Settings can be loaded from environment variables."""
    # Environment is set by conftest.py fixture
    # (mock_env_vars is autouse, so the required vars are already present).
    settings = Settings.model_validate({})

    assert settings.project_id == "test-project"
    assert settings.location == "us-central1"
    assert settings.bucket == "test-bucket"
    assert settings.index_name == "test-index"
    assert settings.deployed_index_id == "test-deployed-index"
|
||||||
|
|
||||||
|
|
||||||
|
def test_settings_defaults():
    """Test that Settings has correct default values."""
    # Required fields come from the conftest autouse fixture; only the
    # optional fields' defaults are asserted here.
    settings = Settings.model_validate({})

    assert settings.embedding_model == "gemini-embedding-001"
    assert settings.search_limit == 10
    assert settings.log_name == "va_agent_evaluation_logs"
    assert settings.log_level == "INFO"
|
||||||
|
|
||||||
|
|
||||||
|
def test_settings_custom_values(monkeypatch):
    """Test that Settings can be customized via environment."""
    monkeypatch.setenv("EMBEDDING_MODEL", "custom-embedding-model")
    monkeypatch.setenv("SEARCH_LIMIT", "20")
    monkeypatch.setenv("LOG_LEVEL", "DEBUG")

    settings = Settings.model_validate({})

    assert settings.embedding_model == "custom-embedding-model"
    # The string "20" from the environment is coerced to an int.
    assert settings.search_limit == 20
    assert settings.log_level == "DEBUG"
|
||||||
|
|
||||||
|
|
||||||
|
def test_settings_validation_error(monkeypatch):
    """Test that Settings raises ValidationError when required fields are missing.

    The original body never triggered a validation error: it left the
    conftest environment in place, never used ``required_vars``, and merely
    repeated the assertion from ``test_settings_from_env``. This version
    removes every required variable and asserts the failure.
    """
    required_vars = [
        "PROJECT_ID", "LOCATION", "BUCKET", "INDEX_NAME",
        "DEPLOYED_INDEX_ID", "ENDPOINT_NAME", "ENDPOINT_DOMAIN",
    ]

    # Undo the autouse conftest fixture so the required fields have no value.
    for var in required_vars:
        monkeypatch.delenv(var, raising=False)

    with pytest.raises(ValidationError):
        Settings.model_validate({})
|
||||||
408
tests/test_main_tool.py
Normal file
408
tests/test_main_tool.py
Normal file
@@ -0,0 +1,408 @@
|
|||||||
|
"""Tests for the main knowledge_search tool."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from knowledge_search_mcp.__main__ import knowledge_search
|
||||||
|
from knowledge_search_mcp.models import AppContext, SourceNamespace, SearchResult
|
||||||
|
|
||||||
|
|
||||||
|
class TestKnowledgeSearch:
|
||||||
|
"""Tests for knowledge_search tool function."""
|
||||||
|
|
||||||
|
    @pytest.fixture
    def mock_app_context(self):
        """Create a mock AppContext with client and settings stubs wired in."""
        app = MagicMock(spec=AppContext)

        # Mock genai_client
        app.genai_client = MagicMock()

        # Mock vector_search
        app.vector_search = MagicMock()
        # async_run_query is awaited by the tool, so it needs an AsyncMock.
        app.vector_search.async_run_query = AsyncMock()

        # Mock settings — the concrete values here are what the tests assert
        # the tool forwards to async_run_query.
        app.settings = MagicMock()
        app.settings.embedding_model = "models/text-embedding-004"
        app.settings.deployed_index_id = "test-deployed-index"
        app.settings.search_limit = 10

        return app
|
||||||
|
|
||||||
|
    @pytest.fixture
    def mock_context(self, mock_app_context):
        """Create a mock MCP Context."""
        ctx = MagicMock()
        ctx.request_context = MagicMock()
        # The tool reads the shared AppContext via
        # ctx.request_context.lifespan_context.
        ctx.request_context.lifespan_context = mock_app_context
        return ctx
|
||||||
|
|
||||||
|
    @pytest.fixture
    def sample_embedding(self):
        """Create a sample embedding vector."""
        # Tiny fixed vector; the tests only check it is forwarded unchanged.
        return [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||||
|
|
||||||
|
    @pytest.fixture
    def sample_search_results(self):
        """Create sample search results."""
        # Three hits with descending scores.
        results: list[SearchResult] = [
            {"id": "doc1.txt", "distance": 0.95, "content": "First document content"},
            {"id": "doc2.txt", "distance": 0.85, "content": "Second document content"},
            {"id": "doc3.txt", "distance": 0.75, "content": "Third document content"},
        ]
        return results
|
||||||
|
|
||||||
|
    # Decorators apply bottom-up, so the bottom-most @patch becomes the first
    # mock parameter after self: mock_format, mock_filter, mock_generate.
    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    @patch('knowledge_search_mcp.__main__.filter_search_results')
    @patch('knowledge_search_mcp.__main__.format_search_results')
    async def test_successful_search(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
        sample_search_results
    ):
        """Test successful search workflow."""
        # Setup mocks
        mock_generate.return_value = (sample_embedding, None)
        mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
        mock_filter.return_value = sample_search_results
        mock_format.return_value = "<document 1 name=doc1.txt>\nFirst document content\n</document 1>"

        # Execute
        result = await knowledge_search("What is financial education?", mock_context)

        # Assert
        assert result == "<document 1 name=doc1.txt>\nFirst document content\n</document 1>"
        mock_generate.assert_called_once()
        # The query args must come from the mocked settings fixture values.
        mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_called_once_with(
            deployed_index_id="test-deployed-index",
            query=sample_embedding,
            limit=10,
            source=None,
        )
        mock_filter.assert_called_once_with(sample_search_results)
        mock_format.assert_called_once_with(sample_search_results)
|
||||||
|
|
||||||
|
    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    async def test_embedding_generation_error(self, mock_generate, mock_context):
        """Test handling of embedding generation error."""
        # Setup mock to return error
        mock_generate.return_value = ([], "Error: API rate limit exceeded. Please try again later.")

        # Execute
        result = await knowledge_search("test query", mock_context)

        # Assert: the tool must return the error verbatim and short-circuit.
        assert result == "Error: API rate limit exceeded. Please try again later."
        mock_generate.assert_called_once()
        # Vector search should not be called
        mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_not_called()
|
||||||
|
|
||||||
|
    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    async def test_empty_query_error(self, mock_generate, mock_context):
        """Test handling of empty query."""
        # Setup mock to return error for empty query
        # (the real generate_query_embedding rejects blank queries).
        mock_generate.return_value = ([], "Error: Query cannot be empty")

        # Execute
        result = await knowledge_search("", mock_context)

        # Assert
        assert result == "Error: Query cannot be empty"
        mock_generate.assert_called_once()
|
||||||
|
|
||||||
|
    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    async def test_vector_search_error(self, mock_generate, mock_context, sample_embedding):
        """Test handling of vector search error."""
        # Setup mocks: embedding succeeds, the index query raises.
        mock_generate.return_value = (sample_embedding, None)
        mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = Exception(
            "Vector search service unavailable"
        )

        # Execute
        result = await knowledge_search("test query", mock_context)

        # Assert: the error is surfaced as text, not raised to the caller.
        assert "Error performing vector search:" in result
        assert "Vector search service unavailable" in result
|
||||||
|
|
||||||
|
    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    @patch('knowledge_search_mcp.__main__.filter_search_results')
    @patch('knowledge_search_mcp.__main__.format_search_results')
    async def test_empty_search_results(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding
    ):
        """Test handling of empty search results."""
        # Setup mocks: the index returns nothing at all.
        mock_generate.return_value = (sample_embedding, None)
        mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = []
        mock_filter.return_value = []
        mock_format.return_value = "No relevant documents found for your query."

        # Execute
        result = await knowledge_search("obscure query", mock_context)

        # Assert: the formatter still runs, with an empty list.
        assert result == "No relevant documents found for your query."
        mock_format.assert_called_once_with([])
|
||||||
|
|
||||||
|
    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    @patch('knowledge_search_mcp.__main__.filter_search_results')
    @patch('knowledge_search_mcp.__main__.format_search_results')
    async def test_filtered_results_empty(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
        sample_search_results
    ):
        """Test when filtering removes all results."""
        # Setup mocks - results exist but get filtered out
        mock_generate.return_value = (sample_embedding, None)
        mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
        mock_filter.return_value = []  # All filtered out
        mock_format.return_value = "No relevant documents found for your query."

        # Execute
        result = await knowledge_search("test query", mock_context)

        # Assert: the raw hits were passed to the filter before being dropped.
        assert result == "No relevant documents found for your query."
        mock_filter.assert_called_once_with(sample_search_results)
|
||||||
|
|
||||||
|
    @patch('knowledge_search_mcp.__main__.generate_query_embedding')
    @patch('knowledge_search_mcp.__main__.filter_search_results')
    @patch('knowledge_search_mcp.__main__.format_search_results')
    async def test_source_filter_parameter(
        self,
        mock_format,
        mock_filter,
        mock_generate,
        mock_context,
        sample_embedding,
        sample_search_results
    ):
        """Test that source filter is passed correctly to vector search."""
        # Setup mocks
        mock_generate.return_value = (sample_embedding, None)
        mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
        mock_filter.return_value = sample_search_results
        mock_format.return_value = "formatted results"

        # Execute with source filter
        source_filter = SourceNamespace.EDUCACION_FINANCIERA
        result = await knowledge_search("test query", mock_context, source=source_filter)

        # Assert: the enum value is forwarded untouched to the index query.
        assert result == "formatted results"
        mock_context.request_context.lifespan_context.vector_search.async_run_query.assert_called_once_with(
            deployed_index_id="test-deployed-index",
            query=sample_embedding,
            limit=10,
            source=source_filter,
        )
|
||||||
|
|
||||||
|
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||||
|
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||||
|
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||||
|
async def test_all_source_filters(
|
||||||
|
self,
|
||||||
|
mock_format,
|
||||||
|
mock_filter,
|
||||||
|
mock_generate,
|
||||||
|
mock_context,
|
||||||
|
sample_embedding,
|
||||||
|
sample_search_results
|
||||||
|
):
|
||||||
|
"""Test all available source filter values."""
|
||||||
|
mock_generate.return_value = (sample_embedding, None)
|
||||||
|
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||||
|
mock_filter.return_value = sample_search_results
|
||||||
|
mock_format.return_value = "results"
|
||||||
|
|
||||||
|
# Test each source filter
|
||||||
|
for source in SourceNamespace:
|
||||||
|
result = await knowledge_search("test query", mock_context, source=source)
|
||||||
|
assert result == "results"
|
||||||
|
|
||||||
|
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||||
|
async def test_vector_search_timeout(self, mock_generate, mock_context, sample_embedding):
|
||||||
|
"""Test handling of vector search timeout."""
|
||||||
|
mock_generate.return_value = (sample_embedding, None)
|
||||||
|
mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = TimeoutError(
|
||||||
|
"Request timed out"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await knowledge_search("test query", mock_context)
|
||||||
|
|
||||||
|
assert "Error performing vector search:" in result
|
||||||
|
assert "Request timed out" in result
|
||||||
|
|
||||||
|
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||||
|
async def test_vector_search_connection_error(self, mock_generate, mock_context, sample_embedding):
|
||||||
|
"""Test handling of vector search connection error."""
|
||||||
|
mock_generate.return_value = (sample_embedding, None)
|
||||||
|
mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = ConnectionError(
|
||||||
|
"Connection refused"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await knowledge_search("test query", mock_context)
|
||||||
|
|
||||||
|
assert "Error performing vector search:" in result
|
||||||
|
assert "Connection refused" in result
|
||||||
|
|
||||||
|
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||||
|
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||||
|
async def test_format_results_unexpected_error(
|
||||||
|
self,
|
||||||
|
mock_filter,
|
||||||
|
mock_generate,
|
||||||
|
mock_context,
|
||||||
|
sample_embedding,
|
||||||
|
sample_search_results
|
||||||
|
):
|
||||||
|
"""Test handling of unexpected error in format_search_results."""
|
||||||
|
mock_generate.return_value = (sample_embedding, None)
|
||||||
|
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||||
|
mock_filter.return_value = sample_search_results
|
||||||
|
|
||||||
|
# Mock format_search_results to raise an error
|
||||||
|
with patch('knowledge_search_mcp.__main__.format_search_results', side_effect=ValueError("Format error")):
|
||||||
|
result = await knowledge_search("test query", mock_context)
|
||||||
|
|
||||||
|
assert "Unexpected error during search:" in result
|
||||||
|
assert "Format error" in result
|
||||||
|
|
||||||
|
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||||
|
async def test_filter_results_unexpected_error(self, mock_generate, mock_context, sample_embedding):
|
||||||
|
"""Test handling of unexpected error in filter_search_results."""
|
||||||
|
mock_generate.return_value = (sample_embedding, None)
|
||||||
|
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = [
|
||||||
|
{"id": "doc1", "distance": 0.9, "content": "test"}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Mock filter_search_results to raise an error
|
||||||
|
with patch('knowledge_search_mcp.__main__.filter_search_results', side_effect=TypeError("Filter error")):
|
||||||
|
result = await knowledge_search("test query", mock_context)
|
||||||
|
|
||||||
|
assert "Unexpected error during search:" in result
|
||||||
|
assert "Filter error" in result
|
||||||
|
|
||||||
|
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||||
|
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||||
|
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||||
|
async def test_long_query_truncation_in_logs(
|
||||||
|
self,
|
||||||
|
mock_format,
|
||||||
|
mock_filter,
|
||||||
|
mock_generate,
|
||||||
|
mock_context,
|
||||||
|
sample_embedding,
|
||||||
|
sample_search_results
|
||||||
|
):
|
||||||
|
"""Test that long queries are handled correctly."""
|
||||||
|
# Setup mocks
|
||||||
|
mock_generate.return_value = (sample_embedding, None)
|
||||||
|
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||||
|
mock_filter.return_value = sample_search_results
|
||||||
|
mock_format.return_value = "results"
|
||||||
|
|
||||||
|
# Execute with very long query
|
||||||
|
long_query = "a" * 500
|
||||||
|
result = await knowledge_search(long_query, mock_context)
|
||||||
|
|
||||||
|
# Assert - should succeed
|
||||||
|
assert result == "results"
|
||||||
|
# Verify generate_query_embedding was called with full query
|
||||||
|
assert mock_generate.call_args[0][2] == long_query
|
||||||
|
|
||||||
|
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||||
|
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||||
|
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||||
|
async def test_multiple_results_returned(
|
||||||
|
self,
|
||||||
|
mock_format,
|
||||||
|
mock_filter,
|
||||||
|
mock_generate,
|
||||||
|
mock_context,
|
||||||
|
sample_embedding
|
||||||
|
):
|
||||||
|
"""Test handling of multiple search results."""
|
||||||
|
# Create larger result set
|
||||||
|
large_results: list[SearchResult] = [
|
||||||
|
{"id": f"doc{i}.txt", "distance": 0.9 - (i * 0.05), "content": f"Content {i}"}
|
||||||
|
for i in range(10)
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_generate.return_value = (sample_embedding, None)
|
||||||
|
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = large_results
|
||||||
|
mock_filter.return_value = large_results[:5] # Filter to top 5
|
||||||
|
mock_format.return_value = "formatted 5 results"
|
||||||
|
|
||||||
|
result = await knowledge_search("test query", mock_context)
|
||||||
|
|
||||||
|
assert result == "formatted 5 results"
|
||||||
|
mock_filter.assert_called_once_with(large_results)
|
||||||
|
mock_format.assert_called_once_with(large_results[:5])
|
||||||
|
|
||||||
|
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||||
|
@patch('knowledge_search_mcp.__main__.filter_search_results')
|
||||||
|
@patch('knowledge_search_mcp.__main__.format_search_results')
|
||||||
|
async def test_settings_values_used_correctly(
|
||||||
|
self,
|
||||||
|
mock_format,
|
||||||
|
mock_filter,
|
||||||
|
mock_generate,
|
||||||
|
mock_context,
|
||||||
|
sample_embedding,
|
||||||
|
sample_search_results
|
||||||
|
):
|
||||||
|
"""Test that settings values are used correctly."""
|
||||||
|
# Customize settings
|
||||||
|
mock_context.request_context.lifespan_context.settings.embedding_model = "custom-model"
|
||||||
|
mock_context.request_context.lifespan_context.settings.deployed_index_id = "custom-index"
|
||||||
|
mock_context.request_context.lifespan_context.settings.search_limit = 20
|
||||||
|
|
||||||
|
mock_generate.return_value = (sample_embedding, None)
|
||||||
|
mock_context.request_context.lifespan_context.vector_search.async_run_query.return_value = sample_search_results
|
||||||
|
mock_filter.return_value = sample_search_results
|
||||||
|
mock_format.return_value = "results"
|
||||||
|
|
||||||
|
result = await knowledge_search("test query", mock_context)
|
||||||
|
|
||||||
|
# Verify embedding model
|
||||||
|
assert mock_generate.call_args[0][1] == "custom-model"
|
||||||
|
|
||||||
|
# Verify vector search parameters
|
||||||
|
call_kwargs = mock_context.request_context.lifespan_context.vector_search.async_run_query.call_args.kwargs
|
||||||
|
assert call_kwargs["deployed_index_id"] == "custom-index"
|
||||||
|
assert call_kwargs["limit"] == 20
|
||||||
|
|
||||||
|
@patch('knowledge_search_mcp.__main__.generate_query_embedding')
|
||||||
|
async def test_graceful_degradation_on_partial_failure(
|
||||||
|
self, mock_generate, mock_context, sample_embedding
|
||||||
|
):
|
||||||
|
"""Test that errors are caught and returned as strings, not raised."""
|
||||||
|
mock_generate.return_value = (sample_embedding, None)
|
||||||
|
mock_context.request_context.lifespan_context.vector_search.async_run_query.side_effect = RuntimeError(
|
||||||
|
"Critical failure"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not raise, should return error message
|
||||||
|
result = await knowledge_search("test query", mock_context)
|
||||||
|
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert "Error performing vector search:" in result
|
||||||
|
assert "Critical failure" in result
|
||||||
110
tests/test_search.py
Normal file
110
tests/test_search.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
"""Tests for vector search functionality."""
|
||||||
|
|
||||||
|
import io
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from knowledge_search_mcp import (
|
||||||
|
GoogleCloudFileStorage,
|
||||||
|
GoogleCloudVectorSearch,
|
||||||
|
LRUCache,
|
||||||
|
SourceNamespace,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGoogleCloudFileStorage:
|
||||||
|
"""Tests for GoogleCloudFileStorage."""
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
"""Test storage initialization."""
|
||||||
|
storage = GoogleCloudFileStorage(bucket="test-bucket")
|
||||||
|
assert storage.bucket_name == "test-bucket"
|
||||||
|
assert isinstance(storage._cache, LRUCache)
|
||||||
|
assert storage._cache.max_size == 100
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_hit(self):
|
||||||
|
"""Test that cached files are returned without fetching."""
|
||||||
|
storage = GoogleCloudFileStorage(bucket="test-bucket")
|
||||||
|
test_content = b"cached content"
|
||||||
|
storage._cache.put("test.md", test_content)
|
||||||
|
|
||||||
|
result = await storage.async_get_file_stream("test.md")
|
||||||
|
|
||||||
|
assert result.read() == test_content
|
||||||
|
assert result.name == "test.md"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_miss(self):
|
||||||
|
"""Test that uncached files are fetched from GCS."""
|
||||||
|
storage = GoogleCloudFileStorage(bucket="test-bucket")
|
||||||
|
test_content = b"fetched content"
|
||||||
|
|
||||||
|
# Mock the storage download
|
||||||
|
with patch.object(storage, '_get_aio_storage') as mock_storage_getter:
|
||||||
|
mock_storage = AsyncMock()
|
||||||
|
mock_storage.download = AsyncMock(return_value=test_content)
|
||||||
|
mock_storage_getter.return_value = mock_storage
|
||||||
|
|
||||||
|
result = await storage.async_get_file_stream("test.md")
|
||||||
|
|
||||||
|
assert result.read() == test_content
|
||||||
|
assert storage._cache.get("test.md") == test_content
|
||||||
|
|
||||||
|
|
||||||
|
class TestGoogleCloudVectorSearch:
|
||||||
|
"""Tests for GoogleCloudVectorSearch."""
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
"""Test vector search client initialization."""
|
||||||
|
vs = GoogleCloudVectorSearch(
|
||||||
|
project_id="test-project",
|
||||||
|
location="us-central1",
|
||||||
|
bucket="test-bucket",
|
||||||
|
index_name="test-index",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert vs.project_id == "test-project"
|
||||||
|
assert vs.location == "us-central1"
|
||||||
|
assert vs.index_name == "test-index"
|
||||||
|
|
||||||
|
def test_configure_index_endpoint(self):
|
||||||
|
"""Test endpoint configuration."""
|
||||||
|
vs = GoogleCloudVectorSearch(
|
||||||
|
project_id="test-project",
|
||||||
|
location="us-central1",
|
||||||
|
bucket="test-bucket",
|
||||||
|
)
|
||||||
|
|
||||||
|
vs.configure_index_endpoint(
|
||||||
|
name="test-endpoint",
|
||||||
|
public_domain="test.domain.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert vs._endpoint_name == "test-endpoint"
|
||||||
|
assert vs._endpoint_domain == "test.domain.com"
|
||||||
|
|
||||||
|
def test_configure_index_endpoint_validation(self):
|
||||||
|
"""Test that endpoint configuration validates inputs."""
|
||||||
|
vs = GoogleCloudVectorSearch(
|
||||||
|
project_id="test-project",
|
||||||
|
location="us-central1",
|
||||||
|
bucket="test-bucket",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="endpoint name"):
|
||||||
|
vs.configure_index_endpoint(name="", public_domain="test.com")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="endpoint domain"):
|
||||||
|
vs.configure_index_endpoint(name="test", public_domain="")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSourceNamespace:
|
||||||
|
"""Tests for SourceNamespace enum."""
|
||||||
|
|
||||||
|
def test_source_namespace_values(self):
|
||||||
|
"""Test that SourceNamespace has expected values."""
|
||||||
|
assert SourceNamespace.EDUCACION_FINANCIERA.value == "Educacion Financiera"
|
||||||
|
assert SourceNamespace.PRODUCTOS_Y_SERVICIOS.value == "Productos y Servicios"
|
||||||
|
assert SourceNamespace.FUNCIONALIDADES_APP_MOVIL.value == "Funcionalidades de la App Movil"
|
||||||
381
tests/test_search_services.py
Normal file
381
tests/test_search_services.py
Normal file
@@ -0,0 +1,381 @@
|
|||||||
|
"""Tests for search service functions."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from knowledge_search_mcp.services.search import (
|
||||||
|
generate_query_embedding,
|
||||||
|
filter_search_results,
|
||||||
|
format_search_results,
|
||||||
|
)
|
||||||
|
from knowledge_search_mcp.models import SearchResult
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerateQueryEmbedding:
|
||||||
|
"""Tests for generate_query_embedding function."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_genai_client(self):
|
||||||
|
"""Create a mock genai client."""
|
||||||
|
client = MagicMock()
|
||||||
|
client.aio = MagicMock()
|
||||||
|
client.aio.models = MagicMock()
|
||||||
|
return client
|
||||||
|
|
||||||
|
async def test_successful_embedding_generation(self, mock_genai_client):
|
||||||
|
"""Test successful embedding generation."""
|
||||||
|
# Setup mock response
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_embedding = MagicMock()
|
||||||
|
mock_embedding.values = [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||||
|
mock_response.embeddings = [mock_embedding]
|
||||||
|
|
||||||
|
mock_genai_client.aio.models.embed_content = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
# Execute
|
||||||
|
embedding, error = await generate_query_embedding(
|
||||||
|
mock_genai_client,
|
||||||
|
"models/text-embedding-004",
|
||||||
|
"What is financial education?"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert error is None
|
||||||
|
assert embedding == [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||||
|
mock_genai_client.aio.models.embed_content.assert_called_once()
|
||||||
|
call_kwargs = mock_genai_client.aio.models.embed_content.call_args.kwargs
|
||||||
|
assert call_kwargs["model"] == "models/text-embedding-004"
|
||||||
|
assert call_kwargs["contents"] == "What is financial education?"
|
||||||
|
assert call_kwargs["config"].task_type == "RETRIEVAL_QUERY"
|
||||||
|
|
||||||
|
async def test_empty_query_string(self, mock_genai_client):
|
||||||
|
"""Test handling of empty query string."""
|
||||||
|
embedding, error = await generate_query_embedding(
|
||||||
|
mock_genai_client,
|
||||||
|
"models/text-embedding-004",
|
||||||
|
""
|
||||||
|
)
|
||||||
|
|
||||||
|
assert embedding == []
|
||||||
|
assert error == "Error: Query cannot be empty"
|
||||||
|
mock_genai_client.aio.models.embed_content.assert_not_called()
|
||||||
|
|
||||||
|
async def test_whitespace_only_query(self, mock_genai_client):
|
||||||
|
"""Test handling of whitespace-only query."""
|
||||||
|
embedding, error = await generate_query_embedding(
|
||||||
|
mock_genai_client,
|
||||||
|
"models/text-embedding-004",
|
||||||
|
" \t\n "
|
||||||
|
)
|
||||||
|
|
||||||
|
assert embedding == []
|
||||||
|
assert error == "Error: Query cannot be empty"
|
||||||
|
mock_genai_client.aio.models.embed_content.assert_not_called()
|
||||||
|
|
||||||
|
async def test_rate_limit_error_429(self, mock_genai_client):
|
||||||
|
"""Test handling of 429 rate limit error."""
|
||||||
|
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||||
|
side_effect=Exception("429 Too Many Requests")
|
||||||
|
)
|
||||||
|
|
||||||
|
embedding, error = await generate_query_embedding(
|
||||||
|
mock_genai_client,
|
||||||
|
"models/text-embedding-004",
|
||||||
|
"test query"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert embedding == []
|
||||||
|
assert error == "Error: API rate limit exceeded. Please try again later."
|
||||||
|
|
||||||
|
async def test_rate_limit_error_resource_exhausted(self, mock_genai_client):
|
||||||
|
"""Test handling of RESOURCE_EXHAUSTED error."""
|
||||||
|
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||||
|
side_effect=Exception("RESOURCE_EXHAUSTED: Quota exceeded")
|
||||||
|
)
|
||||||
|
|
||||||
|
embedding, error = await generate_query_embedding(
|
||||||
|
mock_genai_client,
|
||||||
|
"models/text-embedding-004",
|
||||||
|
"test query"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert embedding == []
|
||||||
|
assert error == "Error: API rate limit exceeded. Please try again later."
|
||||||
|
|
||||||
|
async def test_generic_api_error(self, mock_genai_client):
|
||||||
|
"""Test handling of generic API error."""
|
||||||
|
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||||
|
side_effect=ValueError("Invalid model name")
|
||||||
|
)
|
||||||
|
|
||||||
|
embedding, error = await generate_query_embedding(
|
||||||
|
mock_genai_client,
|
||||||
|
"invalid-model",
|
||||||
|
"test query"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert embedding == []
|
||||||
|
assert "Error generating embedding: Invalid model name" in error
|
||||||
|
|
||||||
|
async def test_network_error(self, mock_genai_client):
|
||||||
|
"""Test handling of network error."""
|
||||||
|
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||||
|
side_effect=ConnectionError("Network unreachable")
|
||||||
|
)
|
||||||
|
|
||||||
|
embedding, error = await generate_query_embedding(
|
||||||
|
mock_genai_client,
|
||||||
|
"models/text-embedding-004",
|
||||||
|
"test query"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert embedding == []
|
||||||
|
assert "Error generating embedding: Network unreachable" in error
|
||||||
|
|
||||||
|
async def test_long_query_truncation_in_logging(self, mock_genai_client):
|
||||||
|
"""Test that long queries are truncated in error logging."""
|
||||||
|
long_query = "a" * 200
|
||||||
|
mock_genai_client.aio.models.embed_content = AsyncMock(
|
||||||
|
side_effect=Exception("API error")
|
||||||
|
)
|
||||||
|
|
||||||
|
embedding, error = await generate_query_embedding(
|
||||||
|
mock_genai_client,
|
||||||
|
"models/text-embedding-004",
|
||||||
|
long_query
|
||||||
|
)
|
||||||
|
|
||||||
|
assert embedding == []
|
||||||
|
assert error is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestFilterSearchResults:
|
||||||
|
"""Tests for filter_search_results function."""
|
||||||
|
|
||||||
|
def test_empty_results(self):
|
||||||
|
"""Test filtering empty results list."""
|
||||||
|
filtered = filter_search_results([])
|
||||||
|
assert filtered == []
|
||||||
|
|
||||||
|
def test_single_result_above_thresholds(self):
|
||||||
|
"""Test single result above both thresholds."""
|
||||||
|
results: list[SearchResult] = [
|
||||||
|
{"id": "doc1", "distance": 0.85, "content": "test content"}
|
||||||
|
]
|
||||||
|
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
|
||||||
|
assert len(filtered) == 1
|
||||||
|
assert filtered[0]["id"] == "doc1"
|
||||||
|
|
||||||
|
def test_single_result_below_min_similarity(self):
|
||||||
|
"""Test single result below minimum similarity threshold."""
|
||||||
|
results: list[SearchResult] = [
|
||||||
|
{"id": "doc1", "distance": 0.5, "content": "test content"}
|
||||||
|
]
|
||||||
|
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
|
||||||
|
assert filtered == []
|
||||||
|
|
||||||
|
def test_multiple_results_all_above_thresholds(self):
|
||||||
|
"""Test multiple results all above thresholds."""
|
||||||
|
results: list[SearchResult] = [
|
||||||
|
{"id": "doc1", "distance": 0.95, "content": "content 1"},
|
||||||
|
{"id": "doc2", "distance": 0.90, "content": "content 2"},
|
||||||
|
{"id": "doc3", "distance": 0.85, "content": "content 3"},
|
||||||
|
]
|
||||||
|
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.8)
|
||||||
|
# max_sim = 0.95, cutoff = 0.95 * 0.8 = 0.76
|
||||||
|
# Results with distance > 0.76 and > 0.6: all three
|
||||||
|
assert len(filtered) == 3
|
||||||
|
|
||||||
|
def test_top_percent_filtering(self):
|
||||||
|
"""Test filtering by top_percent threshold."""
|
||||||
|
results: list[SearchResult] = [
|
||||||
|
{"id": "doc1", "distance": 1.0, "content": "content 1"},
|
||||||
|
{"id": "doc2", "distance": 0.95, "content": "content 2"},
|
||||||
|
{"id": "doc3", "distance": 0.85, "content": "content 3"},
|
||||||
|
{"id": "doc4", "distance": 0.70, "content": "content 4"},
|
||||||
|
]
|
||||||
|
# max_sim = 1.0, cutoff = 1.0 * 0.9 = 0.9
|
||||||
|
# Results with distance > 0.9: doc1 (1.0), doc2 (0.95)
|
||||||
|
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
|
||||||
|
assert len(filtered) == 2
|
||||||
|
assert filtered[0]["id"] == "doc1"
|
||||||
|
assert filtered[1]["id"] == "doc2"
|
||||||
|
|
||||||
|
def test_min_similarity_filtering(self):
|
||||||
|
"""Test filtering by minimum similarity threshold."""
|
||||||
|
results: list[SearchResult] = [
|
||||||
|
{"id": "doc1", "distance": 0.95, "content": "content 1"},
|
||||||
|
{"id": "doc2", "distance": 0.75, "content": "content 2"},
|
||||||
|
{"id": "doc3", "distance": 0.55, "content": "content 3"},
|
||||||
|
]
|
||||||
|
# max_sim = 0.95, cutoff = 0.95 * 0.9 = 0.855
|
||||||
|
# doc1 > 0.855 and > 0.7: included
|
||||||
|
# doc2 < 0.855: excluded by top_percent
|
||||||
|
# doc3 < 0.7: excluded by min_similarity
|
||||||
|
filtered = filter_search_results(results, min_similarity=0.7, top_percent=0.9)
|
||||||
|
assert len(filtered) == 1
|
||||||
|
assert filtered[0]["id"] == "doc1"
|
||||||
|
|
||||||
|
def test_default_parameters(self):
|
||||||
|
"""Test filtering with default parameters."""
|
||||||
|
results: list[SearchResult] = [
|
||||||
|
{"id": "doc1", "distance": 0.95, "content": "content 1"},
|
||||||
|
{"id": "doc2", "distance": 0.85, "content": "content 2"},
|
||||||
|
{"id": "doc3", "distance": 0.50, "content": "content 3"},
|
||||||
|
]
|
||||||
|
# Default: min_similarity=0.6, top_percent=0.9
|
||||||
|
# max_sim = 0.95, cutoff = 0.95 * 0.9 = 0.855
|
||||||
|
# doc1 > 0.855 and > 0.6: included
|
||||||
|
# doc2 < 0.855: excluded
|
||||||
|
# doc3 < 0.6: excluded
|
||||||
|
filtered = filter_search_results(results)
|
||||||
|
assert len(filtered) == 1
|
||||||
|
assert filtered[0]["id"] == "doc1"
|
||||||
|
|
||||||
|
def test_all_results_filtered_out(self):
|
||||||
|
"""Test when all results are filtered out."""
|
||||||
|
results: list[SearchResult] = [
|
||||||
|
{"id": "doc1", "distance": 0.55, "content": "content 1"},
|
||||||
|
{"id": "doc2", "distance": 0.45, "content": "content 2"},
|
||||||
|
{"id": "doc3", "distance": 0.35, "content": "content 3"},
|
||||||
|
]
|
||||||
|
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
|
||||||
|
assert filtered == []
|
||||||
|
|
||||||
|
def test_exact_threshold_boundaries(self):
|
||||||
|
"""Test behavior at exact threshold boundaries."""
|
||||||
|
results: list[SearchResult] = [
|
||||||
|
{"id": "doc1", "distance": 0.9, "content": "content 1"},
|
||||||
|
{"id": "doc2", "distance": 0.6, "content": "content 2"},
|
||||||
|
]
|
||||||
|
# max_sim = 0.9, cutoff = 0.9 * 0.9 = 0.81
|
||||||
|
# doc1: 0.9 > 0.81 and 0.9 > 0.6: included
|
||||||
|
# doc2: 0.6 < 0.81: excluded
|
||||||
|
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
|
||||||
|
assert len(filtered) == 1
|
||||||
|
assert filtered[0]["id"] == "doc1"
|
||||||
|
|
||||||
|
def test_identical_distances(self):
|
||||||
|
"""Test filtering with identical distance values."""
|
||||||
|
results: list[SearchResult] = [
|
||||||
|
{"id": "doc1", "distance": 0.8, "content": "content 1"},
|
||||||
|
{"id": "doc2", "distance": 0.8, "content": "content 2"},
|
||||||
|
{"id": "doc3", "distance": 0.8, "content": "content 3"},
|
||||||
|
]
|
||||||
|
# max_sim = 0.8, cutoff = 0.8 * 0.9 = 0.72
|
||||||
|
# All have distance 0.8 > 0.72 and > 0.6: all included
|
||||||
|
filtered = filter_search_results(results, min_similarity=0.6, top_percent=0.9)
|
||||||
|
assert len(filtered) == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatSearchResults:
|
||||||
|
"""Tests for format_search_results function."""
|
||||||
|
|
||||||
|
def test_empty_results(self):
|
||||||
|
"""Test formatting empty results list."""
|
||||||
|
formatted = format_search_results([])
|
||||||
|
assert formatted == "No relevant documents found for your query."
|
||||||
|
|
||||||
|
def test_single_result(self):
|
||||||
|
"""Test formatting single result."""
|
||||||
|
results: list[SearchResult] = [
|
||||||
|
{"id": "doc1.txt", "distance": 0.95, "content": "This is the content."}
|
||||||
|
]
|
||||||
|
formatted = format_search_results(results)
|
||||||
|
expected = "<document 1 name=doc1.txt>\nThis is the content.\n</document 1>"
|
||||||
|
assert formatted == expected
|
||||||
|
|
||||||
|
def test_multiple_results(self):
|
||||||
|
"""Test formatting multiple results."""
|
||||||
|
results: list[SearchResult] = [
|
||||||
|
{"id": "doc1.txt", "distance": 0.95, "content": "First document content."},
|
||||||
|
{"id": "doc2.txt", "distance": 0.85, "content": "Second document content."},
|
||||||
|
{"id": "doc3.txt", "distance": 0.75, "content": "Third document content."},
|
||||||
|
]
|
||||||
|
formatted = format_search_results(results)
|
||||||
|
expected = (
|
||||||
|
"<document 1 name=doc1.txt>\nFirst document content.\n</document 1>\n"
|
||||||
|
"<document 2 name=doc2.txt>\nSecond document content.\n</document 2>\n"
|
||||||
|
"<document 3 name=doc3.txt>\nThird document content.\n</document 3>"
|
||||||
|
)
|
||||||
|
assert formatted == expected
|
||||||
|
|
||||||
|
def test_multiline_content(self):
|
||||||
|
"""Test formatting results with multiline content."""
|
||||||
|
results: list[SearchResult] = [
|
||||||
|
{
|
||||||
|
"id": "doc1.txt",
|
||||||
|
"distance": 0.95,
|
||||||
|
"content": "Line 1\nLine 2\nLine 3"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
formatted = format_search_results(results)
|
||||||
|
expected = "<document 1 name=doc1.txt>\nLine 1\nLine 2\nLine 3\n</document 1>"
|
||||||
|
assert formatted == expected
|
||||||
|
|
||||||
|
def test_special_characters_in_content(self):
|
||||||
|
"""Test formatting with special characters in content."""
|
||||||
|
results: list[SearchResult] = [
|
||||||
|
{
|
||||||
|
"id": "doc1.txt",
|
||||||
|
"distance": 0.95,
|
||||||
|
"content": "Content with <special> & \"characters\""
|
||||||
|
}
|
||||||
|
]
|
||||||
|
formatted = format_search_results(results)
|
||||||
|
expected = '<document 1 name=doc1.txt>\nContent with <special> & "characters"\n</document 1>'
|
||||||
|
assert formatted == expected
|
||||||
|
|
||||||
|
def test_special_characters_in_document_id(self):
|
||||||
|
"""Test formatting with special characters in document ID."""
|
||||||
|
results: list[SearchResult] = [
|
||||||
|
{
|
||||||
|
"id": "path/to/doc-name_v2.txt",
|
||||||
|
"distance": 0.95,
|
||||||
|
"content": "Some content"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
formatted = format_search_results(results)
|
||||||
|
expected = "<document 1 name=path/to/doc-name_v2.txt>\nSome content\n</document 1>"
|
||||||
|
assert formatted == expected
|
||||||
|
|
||||||
|
def test_empty_content(self):
|
||||||
|
"""Test formatting result with empty content."""
|
||||||
|
results: list[SearchResult] = [
|
||||||
|
{"id": "doc1.txt", "distance": 0.95, "content": ""}
|
||||||
|
]
|
||||||
|
formatted = format_search_results(results)
|
||||||
|
expected = "<document 1 name=doc1.txt>\n\n</document 1>"
|
||||||
|
assert formatted == expected
|
||||||
|
|
||||||
|
def test_document_numbering(self):
|
||||||
|
"""Test that document numbering starts at 1 and increments correctly."""
|
||||||
|
results: list[SearchResult] = [
|
||||||
|
{"id": "a.txt", "distance": 0.9, "content": "A"},
|
||||||
|
{"id": "b.txt", "distance": 0.8, "content": "B"},
|
||||||
|
{"id": "c.txt", "distance": 0.7, "content": "C"},
|
||||||
|
{"id": "d.txt", "distance": 0.6, "content": "D"},
|
||||||
|
{"id": "e.txt", "distance": 0.5, "content": "E"},
|
||||||
|
]
|
||||||
|
formatted = format_search_results(results)
|
||||||
|
|
||||||
|
assert "<document 1 name=a.txt>" in formatted
|
||||||
|
assert "</document 1>" in formatted
|
||||||
|
assert "<document 2 name=b.txt>" in formatted
|
||||||
|
assert "</document 2>" in formatted
|
||||||
|
assert "<document 3 name=c.txt>" in formatted
|
||||||
|
assert "</document 3>" in formatted
|
||||||
|
assert "<document 4 name=d.txt>" in formatted
|
||||||
|
assert "</document 4>" in formatted
|
||||||
|
assert "<document 5 name=e.txt>" in formatted
|
||||||
|
assert "</document 5>" in formatted
|
||||||
|
|
||||||
|
def test_very_long_content(self):
|
||||||
|
"""Test formatting with very long content."""
|
||||||
|
long_content = "A" * 10000
|
||||||
|
results: list[SearchResult] = [
|
||||||
|
{"id": "doc1.txt", "distance": 0.95, "content": long_content}
|
||||||
|
]
|
||||||
|
formatted = format_search_results(results)
|
||||||
|
assert f"<document 1 name=doc1.txt>\n{long_content}\n</document 1>" == formatted
|
||||||
|
assert len(formatted) > 10000
|
||||||
436
tests/test_validation_services.py
Normal file
436
tests/test_validation_services.py
Normal file
@@ -0,0 +1,436 @@
|
|||||||
|
"""Tests for validation service functions."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from aiohttp import ClientResponse
|
||||||
|
|
||||||
|
from knowledge_search_mcp.services.validation import (
|
||||||
|
validate_genai_access,
|
||||||
|
validate_gcs_access,
|
||||||
|
validate_vector_search_access,
|
||||||
|
)
|
||||||
|
from knowledge_search_mcp.config import Settings
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateGenAIAccess:
    """Tests for validate_genai_access function.

    validate_genai_access is expected to probe the embedding model with a
    tiny request and return ``None`` on success or a human-readable error
    string (prefixed with "GenAI:") on failure — it never raises.
    """

    @pytest.fixture
    def mock_settings(self):
        """Create mock settings."""
        # spec=Settings so attribute typos fail loudly instead of auto-mocking.
        settings = MagicMock(spec=Settings)
        settings.embedding_model = "models/text-embedding-004"
        settings.project_id = "test-project"
        settings.location = "us-central1"
        return settings

    @pytest.fixture
    def mock_genai_client(self):
        """Create a mock genai client."""
        # Mirrors the async client surface used by the code under test:
        # client.aio.models.embed_content (assigned per-test below).
        client = MagicMock()
        client.aio = MagicMock()
        client.aio.models = MagicMock()
        return client

    async def test_successful_validation(self, mock_genai_client, mock_settings):
        """Test successful GenAI access validation."""
        # Setup mock response
        mock_response = MagicMock()
        mock_embedding = MagicMock()
        mock_embedding.values = [0.1] * 768  # Typical embedding dimension
        mock_response.embeddings = [mock_embedding]

        mock_genai_client.aio.models.embed_content = AsyncMock(return_value=mock_response)

        # Execute
        error = await validate_genai_access(mock_genai_client, mock_settings)

        # Assert — None means success; the probe uses a fixed "test" query
        # with the configured model and RETRIEVAL_QUERY task type.
        assert error is None
        mock_genai_client.aio.models.embed_content.assert_called_once()
        call_kwargs = mock_genai_client.aio.models.embed_content.call_args.kwargs
        assert call_kwargs["model"] == "models/text-embedding-004"
        assert call_kwargs["contents"] == "test"
        assert call_kwargs["config"].task_type == "RETRIEVAL_QUERY"

    async def test_empty_response(self, mock_genai_client, mock_settings):
        """Test handling of empty response."""
        mock_response = MagicMock()
        mock_response.embeddings = []
        mock_genai_client.aio.models.embed_content = AsyncMock(return_value=mock_response)

        error = await validate_genai_access(mock_genai_client, mock_settings)

        assert error == "Embedding validation returned empty response"

    async def test_none_response(self, mock_genai_client, mock_settings):
        """Test handling of None response."""
        # A None response must yield the same error as an empty embeddings list.
        mock_genai_client.aio.models.embed_content = AsyncMock(return_value=None)

        error = await validate_genai_access(mock_genai_client, mock_settings)

        assert error == "Embedding validation returned empty response"

    async def test_api_permission_error(self, mock_genai_client, mock_settings):
        """Test handling of permission denied error."""
        mock_genai_client.aio.models.embed_content = AsyncMock(
            side_effect=PermissionError("Permission denied for GenAI API")
        )

        error = await validate_genai_access(mock_genai_client, mock_settings)

        # Errors carry the "GenAI:" subsystem prefix plus the original message.
        assert error is not None
        assert "GenAI:" in error
        assert "Permission denied for GenAI API" in error

    async def test_api_quota_error(self, mock_genai_client, mock_settings):
        """Test handling of quota exceeded error."""
        mock_genai_client.aio.models.embed_content = AsyncMock(
            side_effect=Exception("Quota exceeded")
        )

        error = await validate_genai_access(mock_genai_client, mock_settings)

        assert error is not None
        assert "GenAI:" in error
        assert "Quota exceeded" in error

    async def test_network_error(self, mock_genai_client, mock_settings):
        """Test handling of network error."""
        mock_genai_client.aio.models.embed_content = AsyncMock(
            side_effect=ConnectionError("Network unreachable")
        )

        error = await validate_genai_access(mock_genai_client, mock_settings)

        assert error is not None
        assert "GenAI:" in error
        assert "Network unreachable" in error

    async def test_invalid_model_error(self, mock_genai_client, mock_settings):
        """Test handling of invalid model error."""
        mock_genai_client.aio.models.embed_content = AsyncMock(
            side_effect=ValueError("Invalid model name")
        )

        error = await validate_genai_access(mock_genai_client, mock_settings)

        assert error is not None
        assert "GenAI:" in error
        assert "Invalid model name" in error

    async def test_embeddings_with_zero_values(self, mock_genai_client, mock_settings):
        """Test validation with empty embedding values."""
        mock_response = MagicMock()
        mock_embedding = MagicMock()
        mock_embedding.values = []
        mock_response.embeddings = [mock_embedding]

        mock_genai_client.aio.models.embed_content = AsyncMock(return_value=mock_response)

        error = await validate_genai_access(mock_genai_client, mock_settings)

        # Should succeed even with empty values, as long as embeddings exist
        assert error is None
|
class TestValidateGCSAccess:
    """Tests for validate_gcs_access function.

    validate_gcs_access probes the configured GCS bucket over HTTP using the
    vector-search client's aiohttp session and a Token for auth, returning
    ``None`` on success or a descriptive error string (prefixed "GCS:" for
    unexpected exceptions) — it never raises.
    """

    @pytest.fixture
    def mock_settings(self):
        """Create mock settings."""
        settings = MagicMock(spec=Settings)
        settings.bucket = "test-bucket"
        settings.project_id = "test-project"
        return settings

    @pytest.fixture
    def mock_vector_search(self):
        """Create a mock vector search client."""
        # NOTE(review): tests rely on the private storage._get_aio_session
        # hook — confirm against the client implementation if it changes.
        vs = MagicMock()
        vs.storage = MagicMock()
        return vs

    @pytest.fixture
    def mock_session(self):
        """Create a mock aiohttp session."""
        session = MagicMock()
        return session

    @pytest.fixture
    def mock_response(self):
        """Create a mock HTTP response with an async .text() like aiohttp's."""
        response = MagicMock()
        response.text = AsyncMock(return_value='{"items": []}')
        return response

    async def test_successful_validation(self, mock_vector_search, mock_settings, mock_session, mock_response):
        """Test successful GCS bucket access validation."""
        # Setup mocks — session.get() is used as an async context manager,
        # so __aenter__/__aexit__ must be wired explicitly on its return value.
        mock_response.status = 200
        mock_response.ok = True
        mock_session.get = MagicMock()
        mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
        mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)

        mock_vector_search.storage._get_aio_session.return_value = mock_session

        # Patch Token where validation imports it, so no real credentials load.
        with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
            mock_token = MockToken.return_value
            mock_token.get = AsyncMock(return_value="fake-access-token")

            error = await validate_gcs_access(mock_vector_search, mock_settings)

        # The request must target the configured bucket with a Bearer token.
        assert error is None
        mock_session.get.assert_called_once()
        call_args = mock_session.get.call_args
        assert "test-bucket" in call_args[0][0]
        assert call_args[1]["headers"]["Authorization"] == "Bearer fake-access-token"

    async def test_access_denied_403(self, mock_vector_search, mock_settings, mock_session, mock_response):
        """Test handling of 403 access denied."""
        mock_response.status = 403
        mock_response.ok = False
        mock_session.get = MagicMock()
        mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
        mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)

        mock_vector_search.storage._get_aio_session.return_value = mock_session

        with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
            mock_token = MockToken.return_value
            mock_token.get = AsyncMock(return_value="fake-access-token")

            error = await validate_gcs_access(mock_vector_search, mock_settings)

        # 403 gets a dedicated, actionable message about permissions.
        assert error is not None
        assert "Access denied to bucket 'test-bucket'" in error
        assert "permissions" in error.lower()

    async def test_bucket_not_found_404(self, mock_vector_search, mock_settings, mock_session, mock_response):
        """Test handling of 404 bucket not found."""
        mock_response.status = 404
        mock_response.ok = False
        mock_session.get = MagicMock()
        mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
        mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)

        mock_vector_search.storage._get_aio_session.return_value = mock_session

        with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
            mock_token = MockToken.return_value
            mock_token.get = AsyncMock(return_value="fake-access-token")

            error = await validate_gcs_access(mock_vector_search, mock_settings)

        # 404 gets a dedicated message pointing at the bucket name.
        assert error is not None
        assert "Bucket 'test-bucket' not found" in error
        assert "bucket name" in error.lower()

    async def test_server_error_500(self, mock_vector_search, mock_settings, mock_session, mock_response):
        """Test handling of 500 server error."""
        mock_response.status = 500
        mock_response.ok = False
        mock_response.text = AsyncMock(return_value='{"error": "Internal server error"}')
        mock_session.get = MagicMock()
        mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
        mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)

        mock_vector_search.storage._get_aio_session.return_value = mock_session

        with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
            mock_token = MockToken.return_value
            mock_token.get = AsyncMock(return_value="fake-access-token")

            error = await validate_gcs_access(mock_vector_search, mock_settings)

        # Other non-OK statuses fall back to a generic message with the code.
        assert error is not None
        assert "Failed to access bucket 'test-bucket': 500" in error

    async def test_token_acquisition_error(self, mock_vector_search, mock_settings, mock_session):
        """Test handling of token acquisition error."""
        mock_vector_search.storage._get_aio_session.return_value = mock_session

        with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
            mock_token = MockToken.return_value
            mock_token.get = AsyncMock(side_effect=Exception("Failed to get access token"))

            error = await validate_gcs_access(mock_vector_search, mock_settings)

        # Unexpected exceptions are caught and prefixed with "GCS:".
        assert error is not None
        assert "GCS:" in error
        assert "Failed to get access token" in error

    async def test_network_error(self, mock_vector_search, mock_settings, mock_session):
        """Test handling of network error."""
        mock_session.get = MagicMock(side_effect=ConnectionError("Network unreachable"))
        mock_vector_search.storage._get_aio_session.return_value = mock_session

        with patch('knowledge_search_mcp.services.validation.Token') as MockToken:
            mock_token = MockToken.return_value
            mock_token.get = AsyncMock(return_value="fake-access-token")

            error = await validate_gcs_access(mock_vector_search, mock_settings)

        assert error is not None
        assert "GCS:" in error
        assert "Network unreachable" in error
|
class TestValidateVectorSearchAccess:
    """Tests for validate_vector_search_access function.

    validate_vector_search_access issues a GET against the regional
    aiplatform endpoint URL derived from settings, using the client's own
    auth headers, and returns ``None`` on success or an error string
    (prefixed "Vector Search:" for unexpected exceptions) — it never raises.
    """

    @pytest.fixture
    def mock_settings(self):
        """Create mock settings."""
        settings = MagicMock(spec=Settings)
        settings.endpoint_name = "projects/test/locations/us-central1/indexEndpoints/test-endpoint"
        settings.location = "us-central1"
        return settings

    @pytest.fixture
    def mock_vector_search(self):
        """Create a mock vector search client."""
        # NOTE(review): tests rely on the private _async_get_auth_headers /
        # _get_aio_session hooks — confirm against the client if they change.
        vs = MagicMock()
        vs._async_get_auth_headers = AsyncMock(return_value={"Authorization": "Bearer fake-token"})
        return vs

    @pytest.fixture
    def mock_session(self):
        """Create a mock aiohttp session."""
        session = MagicMock()
        return session

    @pytest.fixture
    def mock_response(self):
        """Create a mock HTTP response with an async .text() like aiohttp's."""
        response = MagicMock()
        response.text = AsyncMock(return_value='{"name": "test-endpoint"}')
        return response

    async def test_successful_validation(self, mock_vector_search, mock_settings, mock_session, mock_response):
        """Test successful vector search endpoint validation."""
        # session.get() is used as an async context manager, so
        # __aenter__/__aexit__ must be wired explicitly.
        mock_response.status = 200
        mock_response.ok = True
        mock_session.get = MagicMock()
        mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
        mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)

        mock_vector_search._get_aio_session.return_value = mock_session

        error = await validate_vector_search_access(mock_vector_search, mock_settings)

        # The URL must embed the regional aiplatform host and endpoint name,
        # and the request must carry the client-provided auth header.
        assert error is None
        mock_vector_search._async_get_auth_headers.assert_called_once()
        mock_session.get.assert_called_once()
        call_args = mock_session.get.call_args
        assert "us-central1-aiplatform.googleapis.com" in call_args[0][0]
        assert "test-endpoint" in call_args[0][0]
        assert call_args[1]["headers"]["Authorization"] == "Bearer fake-token"

    async def test_access_denied_403(self, mock_vector_search, mock_settings, mock_session, mock_response):
        """Test handling of 403 access denied."""
        mock_response.status = 403
        mock_response.ok = False
        mock_session.get = MagicMock()
        mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
        mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)

        mock_vector_search._get_aio_session.return_value = mock_session

        error = await validate_vector_search_access(mock_vector_search, mock_settings)

        # 403 gets a dedicated, actionable message about permissions.
        assert error is not None
        assert "Access denied to endpoint" in error
        assert "test-endpoint" in error
        assert "permissions" in error.lower()

    async def test_endpoint_not_found_404(self, mock_vector_search, mock_settings, mock_session, mock_response):
        """Test handling of 404 endpoint not found."""
        mock_response.status = 404
        mock_response.ok = False
        mock_session.get = MagicMock()
        mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
        mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)

        mock_vector_search._get_aio_session.return_value = mock_session

        error = await validate_vector_search_access(mock_vector_search, mock_settings)

        assert error is not None
        assert "not found" in error.lower()
        assert "test-endpoint" in error

    async def test_server_error_503(self, mock_vector_search, mock_settings, mock_session, mock_response):
        """Test handling of 503 service unavailable."""
        mock_response.status = 503
        mock_response.ok = False
        mock_response.text = AsyncMock(return_value='{"error": "Service unavailable"}')
        mock_session.get = MagicMock()
        mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
        mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)

        mock_vector_search._get_aio_session.return_value = mock_session

        error = await validate_vector_search_access(mock_vector_search, mock_settings)

        # Other non-OK statuses fall back to a generic message with the code.
        assert error is not None
        assert "Failed to access endpoint" in error
        assert "503" in error

    async def test_auth_header_error(self, mock_vector_search, mock_settings):
        """Test handling of authentication header error."""
        mock_vector_search._async_get_auth_headers = AsyncMock(
            side_effect=Exception("Failed to get auth headers")
        )

        error = await validate_vector_search_access(mock_vector_search, mock_settings)

        # Unexpected exceptions are caught and prefixed with "Vector Search:".
        assert error is not None
        assert "Vector Search:" in error
        assert "Failed to get auth headers" in error

    async def test_network_timeout(self, mock_vector_search, mock_settings, mock_session):
        """Test handling of network timeout."""
        mock_session.get = MagicMock(side_effect=TimeoutError("Request timed out"))
        mock_vector_search._get_aio_session.return_value = mock_session

        error = await validate_vector_search_access(mock_vector_search, mock_settings)

        assert error is not None
        assert "Vector Search:" in error
        assert "Request timed out" in error

    async def test_connection_error(self, mock_vector_search, mock_settings, mock_session):
        """Test handling of connection error."""
        mock_session.get = MagicMock(side_effect=ConnectionError("Connection refused"))
        mock_vector_search._get_aio_session.return_value = mock_session

        error = await validate_vector_search_access(mock_vector_search, mock_settings)

        assert error is not None
        assert "Vector Search:" in error
        assert "Connection refused" in error

    async def test_endpoint_url_construction(self, mock_vector_search, mock_settings, mock_session, mock_response):
        """Test that endpoint URL is constructed correctly."""
        mock_response.status = 200
        mock_response.ok = True
        mock_session.get = MagicMock()
        mock_session.get.return_value.__aenter__ = AsyncMock(return_value=mock_response)
        mock_session.get.return_value.__aexit__ = AsyncMock(return_value=None)

        mock_vector_search._get_aio_session.return_value = mock_session

        # Custom location — the regional host must track settings.location.
        mock_settings.location = "europe-west1"
        mock_settings.endpoint_name = "projects/my-project/locations/europe-west1/indexEndpoints/my-endpoint"

        error = await validate_vector_search_access(mock_vector_search, mock_settings)

        assert error is None
        call_args = mock_session.get.call_args
        url = call_args[0][0]
        assert "europe-west1-aiplatform.googleapis.com" in url
        assert "my-endpoint" in url
Reference in New Issue
Block a user