Lean MCP implementation

This commit is contained in:
2026-02-23 03:29:21 +00:00
parent a9bc36b5fc
commit 159e8ee433
37 changed files with 2380 additions and 3541 deletions

View File

@@ -1 +0,0 @@
"""RAG evaluation agent package."""

View File

@@ -1,92 +0,0 @@
"""Pydantic AI agent with RAG tool for vector search."""
import time
import structlog
from pydantic import BaseModel
from pydantic_ai import Agent, Embedder, RunContext
from pydantic_ai.models.google import GoogleModel
from rag_eval.config import settings
from rag_eval.vector_search.vertex_ai import GoogleCloudVectorSearch
logger = structlog.get_logger(__name__)
class Deps(BaseModel):
"""Dependencies injected into the agent at runtime."""
vector_search: GoogleCloudVectorSearch
embedder: Embedder
model_config = {"arbitrary_types_allowed": True}
model = GoogleModel(
settings.agent_language_model,
provider=settings.provider,
)
agent = Agent(
model,
deps_type=Deps,
system_prompt=settings.agent_instructions,
)
@agent.tool
async def conocimiento(ctx: RunContext[Deps], query: str) -> str:
"""Search the vector index for the given query.
Args:
ctx: The run context containing dependencies.
query: The query to search for.
Returns:
A formatted string containing the search results.
"""
t0 = time.perf_counter()
min_sim = 0.6
query_embedding = await ctx.deps.embedder.embed_query(query)
t_embed = time.perf_counter()
search_results = await ctx.deps.vector_search.async_run_query(
deployed_index_id=settings.index_deployed_id,
query=list(query_embedding.embeddings[0]),
limit=5,
)
t_search = time.perf_counter()
if search_results:
max_sim = max(r["distance"] for r in search_results)
cutoff = max_sim * 0.9
search_results = [
s
for s in search_results
if s["distance"] > cutoff and s["distance"] > min_sim
]
logger.info(
"conocimiento.timing",
embedding_ms=round((t_embed - t0) * 1000, 1),
vector_search_ms=round((t_search - t_embed) * 1000, 1),
total_ms=round((t_search - t0) * 1000, 1),
chunks=[s["id"] for s in search_results],
)
formatted_results = [
f"<document {i} name={result['id']}>\n"
f"{result['content']}\n"
f"</document {i}>"
for i, result in enumerate(search_results, start=1)
]
return "\n".join(formatted_results)
if __name__ == "__main__":
deps = Deps(
vector_search=settings.vector_search,
embedder=settings.embedder,
)
agent.to_cli_sync(deps=deps)

View File

@@ -1,92 +0,0 @@
"""Application settings loaded from YAML and environment variables."""
import os
from functools import cached_property
from pydantic_ai import Embedder
from pydantic_ai.providers.google import GoogleProvider
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
YamlConfigSettingsSource,
)
from rag_eval.vector_search.vertex_ai import GoogleCloudVectorSearch
CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")
class Settings(BaseSettings):
"""Application settings loaded from config.yaml and env vars."""
project_id: str
location: str
bucket: str
agent_name: str
agent_instructions: str
agent_language_model: str
agent_embedding_model: str
agent_thinking: int
index_name: str
index_deployed_id: str
index_endpoint: str
index_dimensions: int
index_machine_type: str = "e2-standard-16"
index_origin: str
index_destination: str
index_chunk_limit: int
model_config = SettingsConfigDict(yaml_file=CONFIG_FILE_PATH)
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource, # noqa: ARG003
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource, # noqa: ARG003
file_secret_settings: PydanticBaseSettingsSource, # noqa: ARG003
) -> tuple[PydanticBaseSettingsSource, ...]:
"""Use env vars and YAML as settings sources."""
return (
env_settings,
YamlConfigSettingsSource(settings_cls),
)
@cached_property
def provider(self) -> GoogleProvider:
"""Return a Google provider configured for Vertex AI."""
return GoogleProvider(
project=self.project_id,
location=self.location,
)
@cached_property
def vector_search(self) -> GoogleCloudVectorSearch:
"""Return a configured vector search client."""
vs = GoogleCloudVectorSearch(
project_id=self.project_id,
location=self.location,
bucket=self.bucket,
index_name=self.index_name,
)
vs.load_index_endpoint(self.index_endpoint)
return vs
@cached_property
def embedder(self) -> Embedder:
"""Return an embedder configured for the agent's embedding model."""
from pydantic_ai.embeddings.google import GoogleEmbeddingModel # noqa: PLC0415
model = GoogleEmbeddingModel(
self.agent_embedding_model,
provider=self.provider,
)
return Embedder(model)
settings = Settings.model_validate({})

View File

@@ -1 +0,0 @@
"""File storage provider implementations."""

View File

@@ -1,56 +0,0 @@
"""Abstract base class for file storage providers."""
from abc import ABC, abstractmethod
from typing import BinaryIO
class BaseFileStorage(ABC):
"""Abstract base class for a remote file processor.
Defines the interface for listing and processing files from
a remote source.
"""
@abstractmethod
def upload_file(
self,
file_path: str,
destination_blob_name: str,
content_type: str | None = None,
) -> None:
"""Upload a file to the remote source.
Args:
file_path: The local path to the file to upload.
destination_blob_name: Name of the file in remote storage.
content_type: The content type of the file.
"""
...
@abstractmethod
def list_files(self, path: str | None = None) -> list[str]:
"""List files from a remote location.
Args:
path: Path to a specific file or directory. If None,
recursively lists all files in the bucket.
Returns:
A list of file paths.
"""
...
@abstractmethod
def get_file_stream(self, file_name: str) -> BinaryIO:
"""Get a file from the remote source as a file-like object.
Args:
file_name: The name of the file to retrieve.
Returns:
A file-like object containing the file data.
"""
...

View File

@@ -1,188 +0,0 @@
"""Google Cloud Storage file storage implementation."""
import asyncio
import io
import logging
from typing import BinaryIO
import aiohttp
from gcloud.aio.storage import Storage
from google.cloud import storage
from rag_eval.file_storage.base import BaseFileStorage
logger = logging.getLogger(__name__)
HTTP_TOO_MANY_REQUESTS = 429
HTTP_SERVER_ERROR = 500
class GoogleCloudFileStorage(BaseFileStorage):
"""File storage backed by Google Cloud Storage."""
def __init__(self, bucket: str) -> None: # noqa: D107
self.bucket_name = bucket
self.storage_client = storage.Client()
self.bucket_client = self.storage_client.bucket(self.bucket_name)
self._aio_session: aiohttp.ClientSession | None = None
self._aio_storage: Storage | None = None
self._cache: dict[str, bytes] = {}
def upload_file(
self,
file_path: str,
destination_blob_name: str,
content_type: str | None = None,
) -> None:
"""Upload a file to Cloud Storage.
Args:
file_path: The local path to the file to upload.
destination_blob_name: Name of the blob in the bucket.
content_type: The content type of the file.
"""
blob = self.bucket_client.blob(destination_blob_name)
blob.upload_from_filename(
file_path,
content_type=content_type,
if_generation_match=0,
)
self._cache.pop(destination_blob_name, None)
def list_files(self, path: str | None = None) -> list[str]:
"""List all files at the given path in the bucket.
If path is None, recursively lists all files.
Args:
path: Prefix to filter files by.
Returns:
A list of blob names.
"""
blobs = self.storage_client.list_blobs(
self.bucket_name, prefix=path,
)
return [blob.name for blob in blobs]
def get_file_stream(self, file_name: str) -> BinaryIO:
"""Get a file as a file-like object, using cache.
Args:
file_name: The blob name to retrieve.
Returns:
A BytesIO stream with the file contents.
"""
if file_name not in self._cache:
blob = self.bucket_client.blob(file_name)
self._cache[file_name] = blob.download_as_bytes()
file_stream = io.BytesIO(self._cache[file_name])
file_stream.name = file_name
return file_stream
def _get_aio_session(self) -> aiohttp.ClientSession:
if self._aio_session is None or self._aio_session.closed:
connector = aiohttp.TCPConnector(
limit=300, limit_per_host=50,
)
timeout = aiohttp.ClientTimeout(total=60)
self._aio_session = aiohttp.ClientSession(
timeout=timeout, connector=connector,
)
return self._aio_session
def _get_aio_storage(self) -> Storage:
if self._aio_storage is None:
self._aio_storage = Storage(
session=self._get_aio_session(),
)
return self._aio_storage
async def async_get_file_stream(
self, file_name: str, max_retries: int = 3,
) -> BinaryIO:
"""Get a file asynchronously with retry on transient errors.
Args:
file_name: The blob name to retrieve.
max_retries: Maximum number of retry attempts.
Returns:
A BytesIO stream with the file contents.
Raises:
TimeoutError: If all retry attempts fail.
"""
if file_name in self._cache:
file_stream = io.BytesIO(self._cache[file_name])
file_stream.name = file_name
return file_stream
storage_client = self._get_aio_storage()
last_exception: Exception | None = None
for attempt in range(max_retries):
try:
self._cache[file_name] = await storage_client.download(
self.bucket_name, file_name,
)
file_stream = io.BytesIO(self._cache[file_name])
file_stream.name = file_name
except TimeoutError as exc:
last_exception = exc
logger.warning(
"Timeout downloading gs://%s/%s (attempt %d/%d)",
self.bucket_name,
file_name,
attempt + 1,
max_retries,
)
except aiohttp.ClientResponseError as exc:
last_exception = exc
if (
exc.status == HTTP_TOO_MANY_REQUESTS
or exc.status >= HTTP_SERVER_ERROR
):
logger.warning(
"HTTP %d downloading gs://%s/%s "
"(attempt %d/%d)",
exc.status,
self.bucket_name,
file_name,
attempt + 1,
max_retries,
)
else:
raise
else:
return file_stream
if attempt < max_retries - 1:
delay = 0.5 * (2**attempt)
await asyncio.sleep(delay)
msg = (
f"Failed to download gs://{self.bucket_name}/{file_name} "
f"after {max_retries} attempts"
)
raise TimeoutError(msg) from last_exception
def delete_files(self, path: str) -> None:
"""Delete all files at the given path in the bucket.
Args:
path: Prefix of blobs to delete.
"""
blobs = self.storage_client.list_blobs(
self.bucket_name, prefix=path,
)
for blob in blobs:
blob.delete()
self._cache.pop(blob.name, None)

View File

@@ -1,47 +0,0 @@
"""Structured logging configuration using structlog."""
import logging
import sys
import structlog
def setup_logging(*, json: bool = True, level: int = logging.INFO) -> None:
"""Configure structlog with JSON or console output."""
shared_processors: list[structlog.types.Processor] = [
structlog.contextvars.merge_contextvars,
structlog.stdlib.add_log_level,
structlog.stdlib.add_logger_name,
structlog.processors.TimeStamper(fmt="iso"),
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
]
if json:
formatter = structlog.stdlib.ProcessorFormatter(
processors=[
structlog.stdlib.ProcessorFormatter.remove_processors_meta,
structlog.processors.JSONRenderer(),
],
)
else:
formatter = structlog.stdlib.ProcessorFormatter(
processors=[
structlog.stdlib.ProcessorFormatter.remove_processors_meta,
structlog.dev.ConsoleRenderer(),
],
)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
root = logging.getLogger()
root.handlers.clear()
root.addHandler(handler)
root.setLevel(level)
structlog.configure(
processors=shared_processors,
logger_factory=structlog.stdlib.LoggerFactory(),
wrapper_class=structlog.stdlib.BoundLogger,
cache_logger_on_first_use=True,
)

View File

@@ -1,61 +0,0 @@
"""FastAPI server exposing the RAG agent endpoint."""
import time
from typing import Literal
from uuid import uuid4
import structlog
from fastapi import FastAPI
from pydantic import BaseModel
from rag_eval.agent import Deps, agent
from rag_eval.config import settings
from rag_eval.logging import setup_logging
logger = structlog.get_logger(__name__)
setup_logging()
app = FastAPI(title="RAG Agent")
class Message(BaseModel):
"""A single chat message."""
role: Literal["system", "user", "assistant"]
content: str
class AgentRequest(BaseModel):
"""Request body for the agent endpoint."""
messages: list[Message]
class AgentResponse(BaseModel):
"""Response body from the agent endpoint."""
response: str
@app.post("/agent")
async def run_agent(request: AgentRequest) -> AgentResponse:
"""Run the RAG agent with the provided messages."""
request_id = uuid4().hex[:8]
structlog.contextvars.clear_contextvars()
structlog.contextvars.bind_contextvars(request_id=request_id)
prompt = request.messages[-1].content
logger.info("request.start", prompt_length=len(prompt))
t0 = time.perf_counter()
deps = Deps(
vector_search=settings.vector_search,
embedder=settings.embedder,
)
result = await agent.run(prompt, deps=deps)
elapsed = round((time.perf_counter() - t0) * 1000, 1)
logger.info("request.end", elapsed_ms=elapsed)
return AgentResponse(response=result.output)

View File

@@ -1 +0,0 @@
"""Vector search provider implementations."""

View File

@@ -1,68 +0,0 @@
"""Abstract base class for vector search providers."""
from abc import ABC, abstractmethod
from typing import Any, TypedDict
class SearchResult(TypedDict):
"""A single vector search result."""
id: str
distance: float
content: str
class BaseVectorSearch(ABC):
"""Abstract base class for a vector search provider.
This class defines the standard interface for creating a vector search
index and running queries against it.
"""
@abstractmethod
def create_index(
self, name: str, content_path: str, **kwargs: Any # noqa: ANN401
) -> None:
"""Create a new vector search index with the provided content.
Args:
name: The desired name for the new index.
content_path: Path to the data used to populate the index.
**kwargs: Additional provider-specific arguments.
"""
...
@abstractmethod
def update_index(
self, index_name: str, content_path: str, **kwargs: Any # noqa: ANN401
) -> None:
"""Update an existing vector search index with new content.
Args:
index_name: The name of the index to update.
content_path: Path to the data used to populate the index.
**kwargs: Additional provider-specific arguments.
"""
...
@abstractmethod
def run_query(
self,
deployed_index_id: str,
query: list[float],
limit: int,
) -> list[SearchResult]:
"""Run a similarity search query against the index.
Args:
deployed_index_id: The ID of the deployed index.
query: The embedding vector for the search query.
limit: Maximum number of nearest neighbors to return.
Returns:
A list of matched items with id, distance, and content.
"""
...

View File

@@ -1,310 +0,0 @@
"""Google Cloud Vertex AI Vector Search implementation."""
import asyncio
from collections.abc import Sequence
from typing import Any
from uuid import uuid4
import aiohttp
import google.auth
import google.auth.credentials
import google.auth.transport.requests
from gcloud.aio.auth import Token
from google.cloud import aiplatform
from rag_eval.file_storage.google_cloud import GoogleCloudFileStorage
from rag_eval.vector_search.base import BaseVectorSearch, SearchResult
class GoogleCloudVectorSearch(BaseVectorSearch):
"""A vector search provider using Vertex AI Vector Search."""
def __init__(
self,
project_id: str,
location: str,
bucket: str,
index_name: str | None = None,
) -> None:
"""Initialize the GoogleCloudVectorSearch client.
Args:
project_id: The Google Cloud project ID.
location: The Google Cloud location (e.g., 'us-central1').
bucket: The GCS bucket to use for file storage.
index_name: The name of the index.
"""
aiplatform.init(project=project_id, location=location)
self.project_id = project_id
self.location = location
self.storage = GoogleCloudFileStorage(bucket=bucket)
self.index_name = index_name
self._credentials: google.auth.credentials.Credentials | None = None
self._aio_session: aiohttp.ClientSession | None = None
self._async_token: Token | None = None
def _get_auth_headers(self) -> dict[str, str]:
if self._credentials is None:
self._credentials, _ = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
if not self._credentials.token or self._credentials.expired:
self._credentials.refresh(
google.auth.transport.requests.Request(),
)
return {
"Authorization": f"Bearer {self._credentials.token}",
"Content-Type": "application/json",
}
async def _async_get_auth_headers(self) -> dict[str, str]:
if self._async_token is None:
self._async_token = Token(
session=self._get_aio_session(),
scopes=[
"https://www.googleapis.com/auth/cloud-platform",
],
)
access_token = await self._async_token.get()
return {
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
}
def _get_aio_session(self) -> aiohttp.ClientSession:
if self._aio_session is None or self._aio_session.closed:
connector = aiohttp.TCPConnector(
limit=300, limit_per_host=50,
)
timeout = aiohttp.ClientTimeout(total=60)
self._aio_session = aiohttp.ClientSession(
timeout=timeout, connector=connector,
)
return self._aio_session
def create_index(
self,
name: str,
content_path: str,
*,
dimensions: int = 3072,
approximate_neighbors_count: int = 150,
distance_measure_type: str = "DOT_PRODUCT_DISTANCE",
**kwargs: Any, # noqa: ANN401, ARG002
) -> None:
"""Create a new Vertex AI Vector Search index.
Args:
name: The display name for the new index.
content_path: GCS URI to the embeddings JSON file.
dimensions: Number of dimensions in embedding vectors.
approximate_neighbors_count: Neighbors to find per vector.
distance_measure_type: The distance measure to use.
**kwargs: Additional arguments.
"""
index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
display_name=name,
contents_delta_uri=content_path,
dimensions=dimensions,
approximate_neighbors_count=approximate_neighbors_count,
distance_measure_type=distance_measure_type, # type: ignore[arg-type]
leaf_node_embedding_count=1000,
leaf_nodes_to_search_percent=10,
)
self.index = index
def update_index(
self, index_name: str, content_path: str, **kwargs: Any, # noqa: ANN401, ARG002
) -> None:
"""Update an existing Vertex AI Vector Search index.
Args:
index_name: The resource name of the index to update.
content_path: GCS URI to the new embeddings JSON file.
**kwargs: Additional arguments.
"""
index = aiplatform.MatchingEngineIndex(index_name=index_name)
index.update_embeddings(
contents_delta_uri=content_path,
)
self.index = index
def deploy_index(
self,
index_name: str,
machine_type: str = "e2-standard-2",
) -> None:
"""Deploy a Vertex AI Vector Search index to an endpoint.
Args:
index_name: The name of the index to deploy.
machine_type: The machine type for the endpoint.
"""
index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
display_name=f"{index_name}-endpoint",
public_endpoint_enabled=True,
)
index_endpoint.deploy_index(
index=self.index,
deployed_index_id=(
f"{index_name.replace('-', '_')}_deployed_{uuid4().hex}"
),
machine_type=machine_type,
)
self.index_endpoint = index_endpoint
def load_index_endpoint(self, endpoint_name: str) -> None:
"""Load an existing Vertex AI Vector Search index endpoint.
Args:
endpoint_name: The resource name of the index endpoint.
"""
self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
endpoint_name,
)
if not self.index_endpoint.public_endpoint_domain_name:
msg = (
"The index endpoint does not have a public endpoint. "
"Ensure the endpoint is configured for public access."
)
raise ValueError(msg)
def run_query(
self,
deployed_index_id: str,
query: list[float],
limit: int,
) -> list[SearchResult]:
"""Run a similarity search query against the deployed index.
Args:
deployed_index_id: The ID of the deployed index.
query: The embedding vector for the search query.
limit: Maximum number of nearest neighbors to return.
Returns:
A list of matched items with id, distance, and content.
"""
response = self.index_endpoint.find_neighbors(
deployed_index_id=deployed_index_id,
queries=[query],
num_neighbors=limit,
)
results = []
for neighbor in response[0]:
file_path = (
f"{self.index_name}/contents/{neighbor.id}.md"
)
content = (
self.storage.get_file_stream(file_path)
.read()
.decode("utf-8")
)
results.append(
SearchResult(
id=neighbor.id,
distance=float(neighbor.distance or 0),
content=content,
),
)
return results
async def async_run_query(
self,
deployed_index_id: str,
query: Sequence[float],
limit: int,
) -> list[SearchResult]:
"""Run an async similarity search via the REST API.
Args:
deployed_index_id: The ID of the deployed index.
query: The embedding vector for the search query.
limit: Maximum number of nearest neighbors to return.
Returns:
A list of matched items with id, distance, and content.
"""
domain = self.index_endpoint.public_endpoint_domain_name
endpoint_id = self.index_endpoint.name.split("/")[-1]
url = (
f"https://{domain}/v1/projects/{self.project_id}"
f"/locations/{self.location}"
f"/indexEndpoints/{endpoint_id}:findNeighbors"
)
payload = {
"deployed_index_id": deployed_index_id,
"queries": [
{
"datapoint": {"feature_vector": list(query)},
"neighbor_count": limit,
},
],
}
headers = await self._async_get_auth_headers()
session = self._get_aio_session()
async with session.post(
url, json=payload, headers=headers,
) as response:
response.raise_for_status()
data = await response.json()
neighbors = (
data.get("nearestNeighbors", [{}])[0].get("neighbors", [])
)
content_tasks = []
for neighbor in neighbors:
datapoint_id = neighbor["datapoint"]["datapointId"]
file_path = (
f"{self.index_name}/contents/{datapoint_id}.md"
)
content_tasks.append(
self.storage.async_get_file_stream(file_path),
)
file_streams = await asyncio.gather(*content_tasks)
results: list[SearchResult] = []
for neighbor, stream in zip(
neighbors, file_streams, strict=True,
):
results.append(
SearchResult(
id=neighbor["datapoint"]["datapointId"],
distance=neighbor["distance"],
content=stream.read().decode("utf-8"),
),
)
return results
def delete_index(self, index_name: str) -> None:
"""Delete a Vertex AI Vector Search index.
Args:
index_name: The resource name of the index.
"""
index = aiplatform.MatchingEngineIndex(index_name)
index.delete()
def delete_index_endpoint(
self, index_endpoint_name: str,
) -> None:
"""Delete a Vertex AI Vector Search index endpoint.
Args:
index_endpoint_name: The resource name of the endpoint.
"""
index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
index_endpoint_name,
)
index_endpoint.undeploy_all()
index_endpoint.delete(force=True)

6
src/va_agent/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
"""Package export for the ADK root agent."""
import os
# Ensure the Google GenAI SDK talks to Vertex AI instead of the public Gemini API.
os.environ.setdefault("GOOGLE_GENAI_USE_VERTEXAI", "true")

29
src/va_agent/agent.py Normal file
View File

@@ -0,0 +1,29 @@
"""ADK agent with vector search RAG tool."""
from google import genai
from google.adk.agents.llm_agent import Agent
from google.adk.runners import Runner
from google.adk.tools.mcp_tool import McpToolset
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
from google.cloud.firestore_v1.async_client import AsyncClient
from va_agent.config import settings
from va_agent.session import FirestoreSessionService
connection_params = SseConnectionParams(url=settings.mcp_remote_url)
toolset = McpToolset(connection_params=connection_params)
agent = Agent(
model=settings.agent_model,
name=settings.agent_name,
instruction=settings.agent_instructions,
tools=[toolset],
)
session_service = FirestoreSessionService(
db=AsyncClient(database=settings.firestore_db),
compaction_token_threshold=10_000,
genai_client=genai.Client(),
)
runner = Runner(app_name="va_agent", agent=agent, session_service=session_service)

53
src/va_agent/config.py Normal file
View File

@@ -0,0 +1,53 @@
"""Configuration helper for ADK agent."""
import os
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
YamlConfigSettingsSource,
)
CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")
class AgentSettings(BaseSettings):
"""Settings for ADK agent with vector search."""
google_cloud_project: str
google_cloud_location: str
# Agent configuration
agent_name: str
agent_instructions: str
agent_model: str
# Firestore configuration
firestore_db: str
# MCP configuration
mcp_remote_url: str
model_config = SettingsConfigDict(
yaml_file=CONFIG_FILE_PATH,
extra="ignore", # Ignore extra fields from config.yaml
)
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource, # noqa: ARG003
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource, # noqa: ARG003
file_secret_settings: PydanticBaseSettingsSource, # noqa: ARG003
) -> tuple[PydanticBaseSettingsSource, ...]:
"""Use env vars and YAML as settings sources."""
return (
env_settings,
YamlConfigSettingsSource(settings_cls),
)
settings = AgentSettings.model_validate({})

10
src/va_agent/server.py Normal file
View File

@@ -0,0 +1,10 @@
"""FastAPI server exposing the RAG agent endpoint.
NOTE: This file is a stub. The rag_eval module was removed in the
lean MCP implementation. This file is kept for reference but is not
functional.
"""
from fastapi import FastAPI
app = FastAPI(title="RAG Agent")

582
src/va_agent/session.py Normal file
View File

@@ -0,0 +1,582 @@
"""Firestore-backed session service for Google ADK."""
from __future__ import annotations
import asyncio
import logging
import time
import uuid
from typing import TYPE_CHECKING, Any, override
from google.adk.errors.already_exists_error import AlreadyExistsError
from google.adk.events.event import Event
from google.adk.sessions import _session_util
from google.adk.sessions.base_session_service import (
BaseSessionService,
GetSessionConfig,
ListSessionsResponse,
)
from google.adk.sessions.session import Session
from google.adk.sessions.state import State
from google.cloud.firestore_v1.async_transaction import async_transactional
from google.cloud.firestore_v1.base_query import FieldFilter
from google.cloud.firestore_v1.field_path import FieldPath
from google.genai.types import Content, Part
if TYPE_CHECKING:
from google import genai
from google.cloud.firestore_v1.async_client import AsyncClient
logger = logging.getLogger("google_adk." + __name__)
_COMPACTION_LOCK_TTL = 300 # seconds
@async_transactional
async def _try_claim_compaction_txn(transaction: Any, session_ref: Any) -> bool:
"""Atomically claim the compaction lock if it is free or stale."""
snapshot = await session_ref.get(transaction=transaction)
if not snapshot.exists:
return False
data = snapshot.to_dict() or {}
lock_time = data.get("compaction_lock")
if lock_time and (time.time() - lock_time) < _COMPACTION_LOCK_TTL:
return False
transaction.update(session_ref, {"compaction_lock": time.time()})
return True
class FirestoreSessionService(BaseSessionService):
"""A Firestore-backed implementation of BaseSessionService.
Firestore document layout (given ``collection_prefix="adk"``)::
adk_app_states/{app_name}
→ app-scoped state key/values
adk_user_states/{app_name}__{user_id}
→ user-scoped state key/values
adk_sessions/{app_name}__{user_id}__{session_id}
{app_name, user_id, session_id, state: {…}, last_update_time}
└─ events/{event_id} → serialised Event
"""
def __init__( # noqa: PLR0913
self,
*,
db: AsyncClient,
collection_prefix: str = "adk",
compaction_token_threshold: int | None = None,
compaction_model: str = "gemini-2.5-flash",
compaction_keep_recent: int = 10,
genai_client: genai.Client | None = None,
) -> None:
"""Initialize FirestoreSessionService.
Args:
db: Firestore async client
collection_prefix: Prefix for Firestore collections
compaction_token_threshold: Token count threshold for compaction
compaction_model: Model to use for summarization
compaction_keep_recent: Number of recent events to keep
genai_client: GenAI client for compaction summaries
"""
if compaction_token_threshold is not None and genai_client is None:
msg = "genai_client is required when compaction_token_threshold is set."
raise ValueError(msg)
self._db = db
self._prefix = collection_prefix
self._compaction_threshold = compaction_token_threshold
self._compaction_model = compaction_model
self._compaction_keep_recent = compaction_keep_recent
self._genai_client = genai_client
self._compaction_locks: dict[str, asyncio.Lock] = {}
self._active_tasks: set[asyncio.Task] = set()
# ------------------------------------------------------------------
# Document-reference helpers
# ------------------------------------------------------------------
def _app_state_ref(self, app_name: str) -> Any:
return self._db.collection(f"{self._prefix}_app_states").document(app_name)
def _user_state_ref(self, app_name: str, user_id: str) -> Any:
return self._db.collection(f"{self._prefix}_user_states").document(
f"{app_name}__{user_id}"
)
def _session_ref(self, app_name: str, user_id: str, session_id: str) -> Any:
return self._db.collection(f"{self._prefix}_sessions").document(
f"{app_name}__{user_id}__{session_id}"
)
def _events_col(self, app_name: str, user_id: str, session_id: str) -> Any:
return self._session_ref(app_name, user_id, session_id).collection("events")
# ------------------------------------------------------------------
# State helpers
# ------------------------------------------------------------------
async def _get_app_state(self, app_name: str) -> dict[str, Any]:
snap = await self._app_state_ref(app_name).get()
return snap.to_dict() or {} if snap.exists else {}
async def _get_user_state(self, app_name: str, user_id: str) -> dict[str, Any]:
snap = await self._user_state_ref(app_name, user_id).get()
return snap.to_dict() or {} if snap.exists else {}
@staticmethod
def _merge_state(
app_state: dict[str, Any],
user_state: dict[str, Any],
session_state: dict[str, Any],
) -> dict[str, Any]:
merged = dict(session_state)
for key, value in app_state.items():
merged[State.APP_PREFIX + key] = value
for key, value in user_state.items():
merged[State.USER_PREFIX + key] = value
return merged
# ------------------------------------------------------------------
# Compaction helpers
# ------------------------------------------------------------------
@staticmethod
def _events_to_text(events: list[Event]) -> str:
lines: list[str] = []
for event in events:
if event.content and event.content.parts:
text = "".join(p.text or "" for p in event.content.parts)
if text:
role = "User" if event.author == "user" else "Assistant"
lines.append(f"{role}: {text}")
return "\n\n".join(lines)
async def _generate_summary(
self, existing_summary: str, events: list[Event]
) -> str:
conversation_text = self._events_to_text(events)
previous = (
f"Previous summary of earlier conversation:\n{existing_summary}\n\n"
if existing_summary
else ""
)
prompt = (
"Summarize the following conversation between a user and an "
"assistant. Preserve:\n"
"- Key decisions and conclusions\n"
"- User preferences and requirements\n"
"- Important facts, names, and numbers\n"
"- The overall topic and direction of the conversation\n"
"- Any pending tasks or open questions\n\n"
f"{previous}"
f"Conversation:\n{conversation_text}\n\n"
"Provide a clear, comprehensive summary."
)
if self._genai_client is None:
msg = "genai_client is required for compaction"
raise RuntimeError(msg)
response = await self._genai_client.aio.models.generate_content(
model=self._compaction_model,
contents=prompt,
)
return response.text or ""
async def _compact_session(self, session: Session) -> None:
app_name = session.app_name
user_id = session.user_id
session_id = session.id
events_ref = self._events_col(app_name, user_id, session_id)
query = events_ref.order_by("timestamp")
event_docs = await query.get()
if len(event_docs) <= self._compaction_keep_recent:
return
all_events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
events_to_summarize = all_events[: -self._compaction_keep_recent]
session_snap = await self._session_ref(app_name, user_id, session_id).get()
existing_summary = (session_snap.to_dict() or {}).get(
"conversation_summary", ""
)
try:
summary = await self._generate_summary(
existing_summary, events_to_summarize
)
except Exception:
logger.exception("Compaction summary generation failed; skipping.")
return
# Write summary BEFORE deleting events so a crash between the two
# steps leaves safe duplication rather than data loss.
await self._session_ref(app_name, user_id, session_id).update(
{"conversation_summary": summary}
)
docs_to_delete = event_docs[: -self._compaction_keep_recent]
for i in range(0, len(docs_to_delete), 500):
batch = self._db.batch()
for doc in docs_to_delete[i : i + 500]:
batch.delete(doc.reference)
await batch.commit()
logger.info(
"Compacted session %s: summarised %d events, kept %d.",
session_id,
len(docs_to_delete),
self._compaction_keep_recent,
)
async def _guarded_compact(self, session: Session) -> None:
"""Run compaction in the background with per-session locking."""
key = f"{session.app_name}__{session.user_id}__{session.id}"
lock = self._compaction_locks.setdefault(key, asyncio.Lock())
if lock.locked():
logger.debug("Compaction already running locally for %s; skipping.", key)
return
async with lock:
session_ref = self._session_ref(
session.app_name, session.user_id, session.id
)
try:
transaction = self._db.transaction()
claimed = await _try_claim_compaction_txn(transaction, session_ref)
except Exception:
logger.exception("Failed to claim compaction lock for %s", key)
return
if not claimed:
logger.debug(
"Compaction lock held by another instance for %s; skipping.",
key,
)
return
try:
await self._compact_session(session)
except Exception:
logger.exception("Background compaction failed for %s", key)
finally:
try:
await session_ref.update({"compaction_lock": None})
except Exception:
logger.exception("Failed to release compaction lock for %s", key)
async def close(self) -> None:
"""Await all in-flight compaction tasks. Call before shutdown."""
if self._active_tasks:
await asyncio.gather(*self._active_tasks, return_exceptions=True)
# ------------------------------------------------------------------
# BaseSessionService implementation
# ------------------------------------------------------------------
@override
async def create_session(
self,
*,
app_name: str,
user_id: str,
state: dict[str, Any] | None = None,
session_id: str | None = None,
) -> Session:
if session_id and session_id.strip():
session_id = session_id.strip()
existing = await self._session_ref(app_name, user_id, session_id).get()
if existing.exists:
msg = f"Session with id {session_id} already exists."
raise AlreadyExistsError(msg)
else:
session_id = str(uuid.uuid4())
state_deltas = _session_util.extract_state_delta(state) # type: ignore[attr-defined]
app_state_delta = state_deltas["app"]
user_state_delta = state_deltas["user"]
session_state = state_deltas["session"]
write_coros: list = []
if app_state_delta:
write_coros.append(
self._app_state_ref(app_name).set(app_state_delta, merge=True)
)
if user_state_delta:
write_coros.append(
self._user_state_ref(app_name, user_id).set(
user_state_delta, merge=True
)
)
now = time.time()
write_coros.append(
self._session_ref(app_name, user_id, session_id).set(
{
"app_name": app_name,
"user_id": user_id,
"session_id": session_id,
"state": session_state or {},
"last_update_time": now,
}
)
)
await asyncio.gather(*write_coros)
app_state, user_state = await asyncio.gather(
self._get_app_state(app_name),
self._get_user_state(app_name, user_id),
)
merged = self._merge_state(app_state, user_state, session_state or {})
return Session(
app_name=app_name,
user_id=user_id,
id=session_id,
state=merged,
last_update_time=now,
)
@override
async def get_session(
self,
*,
app_name: str,
user_id: str,
session_id: str,
config: GetSessionConfig | None = None,
) -> Session | None:
snap = await self._session_ref(app_name, user_id, session_id).get()
if not snap.exists:
return None
session_data = snap.to_dict()
# Build events query
events_ref = self._events_col(app_name, user_id, session_id)
query = events_ref
if config and config.after_timestamp:
query = query.where(
filter=FieldFilter("timestamp", ">=", config.after_timestamp)
)
query = query.order_by("timestamp")
event_docs, app_state, user_state = await asyncio.gather(
query.get(),
self._get_app_state(app_name),
self._get_user_state(app_name, user_id),
)
events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
if config and config.num_recent_events:
events = events[-config.num_recent_events :]
# Prepend conversation summary as synthetic context events
conversation_summary = session_data.get("conversation_summary")
if conversation_summary:
summary_event = Event(
id="summary-context",
author="user",
content=Content(
role="user",
parts=[
Part(
text=(
"[Conversation context from previous"
" messages]\n"
f"{conversation_summary}"
)
)
],
),
timestamp=0.0,
invocation_id="compaction-summary",
)
ack_event = Event(
id="summary-ack",
author=app_name,
content=Content(
role="model",
parts=[
Part(
text=(
"Understood, I have the context from our"
" previous conversation and will continue"
" accordingly."
)
)
],
),
timestamp=0.001,
invocation_id="compaction-summary",
)
events = [summary_event, ack_event, *events]
# Merge scoped state
merged = self._merge_state(app_state, user_state, session_data.get("state", {}))
return Session(
app_name=app_name,
user_id=user_id,
id=session_id,
state=merged,
events=events,
last_update_time=session_data.get("last_update_time", 0.0),
)
@override
async def list_sessions(
self, *, app_name: str, user_id: str | None = None
) -> ListSessionsResponse:
query = self._db.collection(f"{self._prefix}_sessions").where(
filter=FieldFilter("app_name", "==", app_name)
)
if user_id is not None:
query = query.where(filter=FieldFilter("user_id", "==", user_id))
docs = await query.get()
if not docs:
return ListSessionsResponse()
doc_dicts: list[dict[str, Any]] = [doc.to_dict() or {} for doc in docs]
# Pre-fetch app state and all distinct user states in parallel
unique_user_ids = list({d["user_id"] for d in doc_dicts})
app_state, *user_states = await asyncio.gather(
self._get_app_state(app_name),
*(self._get_user_state(app_name, uid) for uid in unique_user_ids),
)
user_state_cache = dict(zip(unique_user_ids, user_states, strict=False))
sessions: list[Session] = []
for data in doc_dicts:
s_user_id = data["user_id"]
merged = self._merge_state(
app_state,
user_state_cache[s_user_id],
data.get("state", {}),
)
sessions.append(
Session(
app_name=app_name,
user_id=s_user_id,
id=data["session_id"],
state=merged,
events=[],
last_update_time=data.get("last_update_time", 0.0),
)
)
return ListSessionsResponse(sessions=sessions)
@override
async def delete_session(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
ref = self._session_ref(app_name, user_id, session_id)
await self._db.recursive_delete(ref)
@override
async def append_event(self, session: Session, event: Event) -> Event:
if event.partial:
return event
t0 = time.monotonic()
app_name = session.app_name
user_id = session.user_id
session_id = session.id
# Base class: strips temp state, applies delta to in-memory session,
# appends event to session.events
event = await super().append_event(session=session, event=event)
session.last_update_time = event.timestamp
# Persist event document
event_data = event.model_dump(mode="json", exclude_none=True)
await (
self._events_col(app_name, user_id, session_id)
.document(event.id)
.set(event_data)
)
# Persist state deltas
session_ref = self._session_ref(app_name, user_id, session_id)
if event.actions and event.actions.state_delta:
state_deltas = _session_util.extract_state_delta(event.actions.state_delta)
write_coros: list = []
if state_deltas["app"]:
write_coros.append(
self._app_state_ref(app_name).set(state_deltas["app"], merge=True)
)
if state_deltas["user"]:
write_coros.append(
self._user_state_ref(app_name, user_id).set(
state_deltas["user"], merge=True
)
)
if state_deltas["session"]:
field_updates: dict[str, Any] = {
FieldPath("state", k).to_api_repr(): v
for k, v in state_deltas["session"].items()
}
field_updates["last_update_time"] = event.timestamp
write_coros.append(session_ref.update(field_updates))
else:
write_coros.append(
session_ref.update({"last_update_time": event.timestamp})
)
await asyncio.gather(*write_coros)
else:
await session_ref.update({"last_update_time": event.timestamp})
# Log token usage
if event.usage_metadata:
meta = event.usage_metadata
logger.info(
"Token usage for session %s event %s: "
"prompt=%s, candidates=%s, total=%s",
session_id,
event.id,
meta.prompt_token_count,
meta.candidates_token_count,
meta.total_token_count,
)
# Trigger compaction if total token count exceeds threshold
if (
self._compaction_threshold is not None
and event.usage_metadata
and event.usage_metadata.total_token_count
and event.usage_metadata.total_token_count >= self._compaction_threshold
):
logger.info(
"Compaction triggered for session %s: "
"total_token_count=%d >= threshold=%d",
session_id,
event.usage_metadata.total_token_count,
self._compaction_threshold,
)
task = asyncio.create_task(self._guarded_compact(session))
self._active_tasks.add(task)
task.add_done_callback(self._active_tasks.discard)
elapsed = time.monotonic() - t0
logger.info(
"append_event completed for session %s event %s in %.3fs",
session_id,
event.id,
elapsed,
)
return event