Lean MCP implementation

This commit is contained in:
2026-02-23 03:29:21 +00:00
parent a9bc36b5fc
commit 159e8ee433
37 changed files with 2380 additions and 3541 deletions

6
src/va_agent/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
"""Package export for the ADK root agent."""
import os
# Ensure the Google GenAI SDK talks to Vertex AI instead of the public Gemini API.
os.environ.setdefault("GOOGLE_GENAI_USE_VERTEXAI", "true")

29
src/va_agent/agent.py Normal file
View File

@@ -0,0 +1,29 @@
"""ADK agent with vector search RAG tool."""
from google import genai
from google.adk.agents.llm_agent import Agent
from google.adk.runners import Runner
from google.adk.tools.mcp_tool import McpToolset
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
from google.cloud.firestore_v1.async_client import AsyncClient
from va_agent.config import settings
from va_agent.session import FirestoreSessionService
connection_params = SseConnectionParams(url=settings.mcp_remote_url)
toolset = McpToolset(connection_params=connection_params)
agent = Agent(
model=settings.agent_model,
name=settings.agent_name,
instruction=settings.agent_instructions,
tools=[toolset],
)
session_service = FirestoreSessionService(
db=AsyncClient(database=settings.firestore_db),
compaction_token_threshold=10_000,
genai_client=genai.Client(),
)
runner = Runner(app_name="va_agent", agent=agent, session_service=session_service)

53
src/va_agent/config.py Normal file
View File

@@ -0,0 +1,53 @@
"""Configuration helper for ADK agent."""
import os
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
YamlConfigSettingsSource,
)
# Path of the YAML settings file; overridable via the CONFIG_YAML env var.
CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")


class AgentSettings(BaseSettings):
    """Settings for ADK agent with vector search.

    Values are resolved from environment variables first, then from the
    YAML file named by ``CONFIG_YAML`` (default ``config.yaml``); see
    :meth:`settings_customise_sources` for the precedence order.
    """

    # Google Cloud project/region the agent runs against.
    google_cloud_project: str
    google_cloud_location: str
    # Agent configuration
    agent_name: str
    agent_instructions: str
    agent_model: str
    # Firestore configuration
    firestore_db: str
    # MCP configuration
    mcp_remote_url: str

    model_config = SettingsConfigDict(
        yaml_file=CONFIG_FILE_PATH,
        extra="ignore",  # Ignore extra fields from config.yaml
    )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,  # noqa: ARG003
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,  # noqa: ARG003
        file_secret_settings: PydanticBaseSettingsSource,  # noqa: ARG003
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Use env vars and YAML as settings sources."""
        # Earlier sources take precedence, so env vars override the YAML
        # file; init kwargs, .env files and secret files are disabled.
        return (
            env_settings,
            YamlConfigSettingsSource(settings_cls),
        )


# Eagerly validated singleton imported across the package.
settings = AgentSettings.model_validate({})

10
src/va_agent/server.py Normal file
View File

@@ -0,0 +1,10 @@
"""FastAPI server exposing the RAG agent endpoint.
NOTE: This file is a stub. The rag_eval module was removed in the
lean MCP implementation. This file is kept for reference but is not
functional.
"""
from fastapi import FastAPI
app = FastAPI(title="RAG Agent")

582
src/va_agent/session.py Normal file
View File

@@ -0,0 +1,582 @@
"""Firestore-backed session service for Google ADK."""
from __future__ import annotations
import asyncio
import logging
import time
import uuid
from typing import TYPE_CHECKING, Any, override
from google.adk.errors.already_exists_error import AlreadyExistsError
from google.adk.events.event import Event
from google.adk.sessions import _session_util
from google.adk.sessions.base_session_service import (
BaseSessionService,
GetSessionConfig,
ListSessionsResponse,
)
from google.adk.sessions.session import Session
from google.adk.sessions.state import State
from google.cloud.firestore_v1.async_transaction import async_transactional
from google.cloud.firestore_v1.base_query import FieldFilter
from google.cloud.firestore_v1.field_path import FieldPath
from google.genai.types import Content, Part
if TYPE_CHECKING:
from google import genai
from google.cloud.firestore_v1.async_client import AsyncClient
logger = logging.getLogger("google_adk." + __name__)

# A Firestore-side compaction lock older than this is considered stale
# (e.g. the holder crashed) and may be reclaimed by another instance.
_COMPACTION_LOCK_TTL = 300  # seconds


@async_transactional
async def _try_claim_compaction_txn(transaction: Any, session_ref: Any) -> bool:
    """Atomically claim the compaction lock if it is free or stale.

    Runs inside a Firestore transaction (the decorator supplies the
    transaction and retries on contention), so the read-check-write
    sequence below is atomic across instances.

    Args:
        transaction: Active Firestore transaction provided by the decorator.
        session_ref: Document reference of the session to lock.

    Returns:
        True if the lock was claimed; False if the session document does
        not exist or another holder's lock is still within the TTL.
    """
    snapshot = await session_ref.get(transaction=transaction)
    if not snapshot.exists:
        return False
    data = snapshot.to_dict() or {}
    lock_time = data.get("compaction_lock")
    # A lock written less than _COMPACTION_LOCK_TTL seconds ago is fresh.
    if lock_time and (time.time() - lock_time) < _COMPACTION_LOCK_TTL:
        return False
    transaction.update(session_ref, {"compaction_lock": time.time()})
    return True
class FirestoreSessionService(BaseSessionService):
    """A Firestore-backed implementation of BaseSessionService.

    Firestore document layout (given ``collection_prefix="adk"``)::

        adk_app_states/{app_name}
            → app-scoped state key/values
        adk_user_states/{app_name}__{user_id}
            → user-scoped state key/values
        adk_sessions/{app_name}__{user_id}__{session_id}
            {app_name, user_id, session_id, state: {…}, last_update_time}
            └─ events/{event_id} → serialised Event

    When ``compaction_token_threshold`` is set, :meth:`append_event` spawns
    a background task that summarises older events via the GenAI client and
    deletes them, keeping only the most recent ones plus a rolling
    ``conversation_summary`` field on the session document.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        db: AsyncClient,
        collection_prefix: str = "adk",
        compaction_token_threshold: int | None = None,
        compaction_model: str = "gemini-2.5-flash",
        compaction_keep_recent: int = 10,
        genai_client: genai.Client | None = None,
    ) -> None:
        """Initialize FirestoreSessionService.

        Args:
            db: Firestore async client
            collection_prefix: Prefix for Firestore collections
            compaction_token_threshold: Token count threshold for compaction
            compaction_model: Model to use for summarization
            compaction_keep_recent: Number of recent events to keep
            genai_client: GenAI client for compaction summaries

        Raises:
            ValueError: If compaction is enabled without a GenAI client.
        """
        if compaction_token_threshold is not None and genai_client is None:
            msg = "genai_client is required when compaction_token_threshold is set."
            raise ValueError(msg)
        self._db = db
        self._prefix = collection_prefix
        self._compaction_threshold = compaction_token_threshold
        self._compaction_model = compaction_model
        self._compaction_keep_recent = compaction_keep_recent
        self._genai_client = genai_client
        # Per-session asyncio locks guard against concurrent compaction in
        # THIS process; cross-instance exclusion uses the Firestore
        # "compaction_lock" field (see _try_claim_compaction_txn).
        self._compaction_locks: dict[str, asyncio.Lock] = {}
        # Strong references to in-flight background tasks so they are not
        # garbage-collected mid-run; drained in close().
        self._active_tasks: set[asyncio.Task] = set()

    # ------------------------------------------------------------------
    # Document-reference helpers
    # ------------------------------------------------------------------
    def _app_state_ref(self, app_name: str) -> Any:
        """Document holding app-scoped state for ``app_name``."""
        return self._db.collection(f"{self._prefix}_app_states").document(app_name)

    def _user_state_ref(self, app_name: str, user_id: str) -> Any:
        """Document holding user-scoped state for ``app_name``/``user_id``."""
        return self._db.collection(f"{self._prefix}_user_states").document(
            f"{app_name}__{user_id}"
        )

    def _session_ref(self, app_name: str, user_id: str, session_id: str) -> Any:
        """Document holding a single session's metadata and state."""
        return self._db.collection(f"{self._prefix}_sessions").document(
            f"{app_name}__{user_id}__{session_id}"
        )

    def _events_col(self, app_name: str, user_id: str, session_id: str) -> Any:
        """Events subcollection under the session document."""
        return self._session_ref(app_name, user_id, session_id).collection("events")

    # ------------------------------------------------------------------
    # State helpers
    # ------------------------------------------------------------------
    async def _get_app_state(self, app_name: str) -> dict[str, Any]:
        """Fetch app-scoped state, or {} if the document is missing."""
        snap = await self._app_state_ref(app_name).get()
        return snap.to_dict() or {} if snap.exists else {}

    async def _get_user_state(self, app_name: str, user_id: str) -> dict[str, Any]:
        """Fetch user-scoped state, or {} if the document is missing."""
        snap = await self._user_state_ref(app_name, user_id).get()
        return snap.to_dict() or {} if snap.exists else {}

    @staticmethod
    def _merge_state(
        app_state: dict[str, Any],
        user_state: dict[str, Any],
        session_state: dict[str, Any],
    ) -> dict[str, Any]:
        """Combine the three state scopes into one flat dict.

        App and user keys are re-prefixed (``app:``/``user:`` via the ADK
        State constants) so they can later be split back out by
        ``_session_util.extract_state_delta``.
        """
        merged = dict(session_state)
        for key, value in app_state.items():
            merged[State.APP_PREFIX + key] = value
        for key, value in user_state.items():
            merged[State.USER_PREFIX + key] = value
        return merged

    # ------------------------------------------------------------------
    # Compaction helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _events_to_text(events: list[Event]) -> str:
        """Render events as a "User:/Assistant:" transcript for the LLM.

        Events without text content are skipped; any non-"user" author is
        labelled "Assistant".
        """
        lines: list[str] = []
        for event in events:
            if event.content and event.content.parts:
                text = "".join(p.text or "" for p in event.content.parts)
                if text:
                    role = "User" if event.author == "user" else "Assistant"
                    lines.append(f"{role}: {text}")
        return "\n\n".join(lines)

    async def _generate_summary(
        self, existing_summary: str, events: list[Event]
    ) -> str:
        """Ask the compaction model for a summary of ``events``.

        Any prior summary is folded into the prompt so the result stays a
        single rolling summary of the whole conversation.

        Raises:
            RuntimeError: If no GenAI client was configured.
        """
        conversation_text = self._events_to_text(events)
        previous = (
            f"Previous summary of earlier conversation:\n{existing_summary}\n\n"
            if existing_summary
            else ""
        )
        prompt = (
            "Summarize the following conversation between a user and an "
            "assistant. Preserve:\n"
            "- Key decisions and conclusions\n"
            "- User preferences and requirements\n"
            "- Important facts, names, and numbers\n"
            "- The overall topic and direction of the conversation\n"
            "- Any pending tasks or open questions\n\n"
            f"{previous}"
            f"Conversation:\n{conversation_text}\n\n"
            "Provide a clear, comprehensive summary."
        )
        if self._genai_client is None:
            msg = "genai_client is required for compaction"
            raise RuntimeError(msg)
        response = await self._genai_client.aio.models.generate_content(
            model=self._compaction_model,
            contents=prompt,
        )
        return response.text or ""

    async def _compact_session(self, session: Session) -> None:
        """Summarise and delete all but the most recent events.

        Callers must already hold the compaction lock (see
        _guarded_compact); this method performs no locking itself.
        """
        app_name = session.app_name
        user_id = session.user_id
        session_id = session.id
        events_ref = self._events_col(app_name, user_id, session_id)
        query = events_ref.order_by("timestamp")
        event_docs = await query.get()
        # Nothing to do until there are more events than we want to keep.
        if len(event_docs) <= self._compaction_keep_recent:
            return
        all_events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
        events_to_summarize = all_events[: -self._compaction_keep_recent]
        session_snap = await self._session_ref(app_name, user_id, session_id).get()
        existing_summary = (session_snap.to_dict() or {}).get(
            "conversation_summary", ""
        )
        try:
            summary = await self._generate_summary(
                existing_summary, events_to_summarize
            )
        except Exception:
            # Best-effort: a failed summary just postpones compaction.
            logger.exception("Compaction summary generation failed; skipping.")
            return
        # Write summary BEFORE deleting events so a crash between the two
        # steps leaves safe duplication rather than data loss.
        await self._session_ref(app_name, user_id, session_id).update(
            {"conversation_summary": summary}
        )
        docs_to_delete = event_docs[: -self._compaction_keep_recent]
        # Delete in chunks of 500 — the Firestore write-batch limit.
        for i in range(0, len(docs_to_delete), 500):
            batch = self._db.batch()
            for doc in docs_to_delete[i : i + 500]:
                batch.delete(doc.reference)
            await batch.commit()
        logger.info(
            "Compacted session %s: summarised %d events, kept %d.",
            session_id,
            len(docs_to_delete),
            self._compaction_keep_recent,
        )

    async def _guarded_compact(self, session: Session) -> None:
        """Run compaction in the background with per-session locking."""
        key = f"{session.app_name}__{session.user_id}__{session.id}"
        # Local (in-process) lock first: cheap skip if already running here.
        lock = self._compaction_locks.setdefault(key, asyncio.Lock())
        if lock.locked():
            logger.debug("Compaction already running locally for %s; skipping.", key)
            return
        async with lock:
            session_ref = self._session_ref(
                session.app_name, session.user_id, session.id
            )
            # Then the distributed lock, claimed transactionally so only one
            # service instance compacts a given session at a time.
            try:
                transaction = self._db.transaction()
                claimed = await _try_claim_compaction_txn(transaction, session_ref)
            except Exception:
                logger.exception("Failed to claim compaction lock for %s", key)
                return
            if not claimed:
                logger.debug(
                    "Compaction lock held by another instance for %s; skipping.",
                    key,
                )
                return
            try:
                await self._compact_session(session)
            except Exception:
                logger.exception("Background compaction failed for %s", key)
            finally:
                # Always release the Firestore lock, even on failure; the
                # TTL covers the case where this update itself fails.
                try:
                    await session_ref.update({"compaction_lock": None})
                except Exception:
                    logger.exception("Failed to release compaction lock for %s", key)

    async def close(self) -> None:
        """Await all in-flight compaction tasks. Call before shutdown."""
        if self._active_tasks:
            await asyncio.gather(*self._active_tasks, return_exceptions=True)

    # ------------------------------------------------------------------
    # BaseSessionService implementation
    # ------------------------------------------------------------------
    @override
    async def create_session(
        self,
        *,
        app_name: str,
        user_id: str,
        state: dict[str, Any] | None = None,
        session_id: str | None = None,
    ) -> Session:
        """Create a session document, splitting state into its three scopes.

        A blank/whitespace ``session_id`` is treated as absent and a UUID4
        is generated instead.

        Raises:
            AlreadyExistsError: If an explicit ``session_id`` already exists.
        """
        if session_id and session_id.strip():
            session_id = session_id.strip()
            existing = await self._session_ref(app_name, user_id, session_id).get()
            if existing.exists:
                msg = f"Session with id {session_id} already exists."
                raise AlreadyExistsError(msg)
        else:
            session_id = str(uuid.uuid4())
        # Split the caller-supplied state into app/user/session scopes by
        # key prefix.
        state_deltas = _session_util.extract_state_delta(state)  # type: ignore[attr-defined]
        app_state_delta = state_deltas["app"]
        user_state_delta = state_deltas["user"]
        session_state = state_deltas["session"]
        write_coros: list = []
        if app_state_delta:
            write_coros.append(
                self._app_state_ref(app_name).set(app_state_delta, merge=True)
            )
        if user_state_delta:
            write_coros.append(
                self._user_state_ref(app_name, user_id).set(
                    user_state_delta, merge=True
                )
            )
        now = time.time()
        write_coros.append(
            self._session_ref(app_name, user_id, session_id).set(
                {
                    "app_name": app_name,
                    "user_id": user_id,
                    "session_id": session_id,
                    "state": session_state or {},
                    "last_update_time": now,
                }
            )
        )
        await asyncio.gather(*write_coros)
        # Re-read scoped state AFTER the writes so the returned session
        # reflects pre-existing app/user state merged with the new deltas.
        app_state, user_state = await asyncio.gather(
            self._get_app_state(app_name),
            self._get_user_state(app_name, user_id),
        )
        merged = self._merge_state(app_state, user_state, session_state or {})
        return Session(
            app_name=app_name,
            user_id=user_id,
            id=session_id,
            state=merged,
            last_update_time=now,
        )

    @override
    async def get_session(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        config: GetSessionConfig | None = None,
    ) -> Session | None:
        """Load a session with its events, or None if it does not exist.

        If a compaction summary exists, two synthetic events (a summary
        message and an assistant acknowledgement) are prepended so the
        model regains earlier-conversation context.
        """
        snap = await self._session_ref(app_name, user_id, session_id).get()
        if not snap.exists:
            return None
        session_data = snap.to_dict()
        # Build events query
        events_ref = self._events_col(app_name, user_id, session_id)
        query = events_ref
        if config and config.after_timestamp:
            # NOTE(review): ">=" includes an event exactly at
            # after_timestamp — confirm against the ADK GetSessionConfig
            # contract (an exclusive ">" may be intended).
            query = query.where(
                filter=FieldFilter("timestamp", ">=", config.after_timestamp)
            )
        query = query.order_by("timestamp")
        event_docs, app_state, user_state = await asyncio.gather(
            query.get(),
            self._get_app_state(app_name),
            self._get_user_state(app_name, user_id),
        )
        events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
        if config and config.num_recent_events:
            events = events[-config.num_recent_events :]
        # Prepend conversation summary as synthetic context events
        conversation_summary = session_data.get("conversation_summary")
        if conversation_summary:
            # Timestamps 0.0/0.001 sort these before any real event.
            summary_event = Event(
                id="summary-context",
                author="user",
                content=Content(
                    role="user",
                    parts=[
                        Part(
                            text=(
                                "[Conversation context from previous"
                                " messages]\n"
                                f"{conversation_summary}"
                            )
                        )
                    ],
                ),
                timestamp=0.0,
                invocation_id="compaction-summary",
            )
            ack_event = Event(
                id="summary-ack",
                author=app_name,
                content=Content(
                    role="model",
                    parts=[
                        Part(
                            text=(
                                "Understood, I have the context from our"
                                " previous conversation and will continue"
                                " accordingly."
                            )
                        )
                    ],
                ),
                timestamp=0.001,
                invocation_id="compaction-summary",
            )
            events = [summary_event, ack_event, *events]
        # Merge scoped state
        merged = self._merge_state(app_state, user_state, session_data.get("state", {}))
        return Session(
            app_name=app_name,
            user_id=user_id,
            id=session_id,
            state=merged,
            events=events,
            last_update_time=session_data.get("last_update_time", 0.0),
        )

    @override
    async def list_sessions(
        self, *, app_name: str, user_id: str | None = None
    ) -> ListSessionsResponse:
        """List sessions for an app (optionally one user), without events."""
        query = self._db.collection(f"{self._prefix}_sessions").where(
            filter=FieldFilter("app_name", "==", app_name)
        )
        if user_id is not None:
            query = query.where(filter=FieldFilter("user_id", "==", user_id))
        docs = await query.get()
        if not docs:
            return ListSessionsResponse()
        doc_dicts: list[dict[str, Any]] = [doc.to_dict() or {} for doc in docs]
        # Pre-fetch app state and all distinct user states in parallel
        unique_user_ids = list({d["user_id"] for d in doc_dicts})
        app_state, *user_states = await asyncio.gather(
            self._get_app_state(app_name),
            *(self._get_user_state(app_name, uid) for uid in unique_user_ids),
        )
        user_state_cache = dict(zip(unique_user_ids, user_states, strict=False))
        sessions: list[Session] = []
        for data in doc_dicts:
            s_user_id = data["user_id"]
            merged = self._merge_state(
                app_state,
                user_state_cache[s_user_id],
                data.get("state", {}),
            )
            sessions.append(
                Session(
                    app_name=app_name,
                    user_id=s_user_id,
                    id=data["session_id"],
                    state=merged,
                    # Events are intentionally omitted in listings.
                    events=[],
                    last_update_time=data.get("last_update_time", 0.0),
                )
            )
        return ListSessionsResponse(sessions=sessions)

    @override
    async def delete_session(
        self, *, app_name: str, user_id: str, session_id: str
    ) -> None:
        """Delete a session and its events subcollection recursively."""
        ref = self._session_ref(app_name, user_id, session_id)
        await self._db.recursive_delete(ref)

    @override
    async def append_event(self, session: Session, event: Event) -> Event:
        """Persist an event, apply state deltas, and maybe trigger compaction.

        Partial (streaming) events are returned unchanged and never
        persisted. Compaction runs as a fire-and-forget background task
        when the event's total token count reaches the threshold.
        """
        if event.partial:
            return event
        t0 = time.monotonic()
        app_name = session.app_name
        user_id = session.user_id
        session_id = session.id
        # Base class: strips temp state, applies delta to in-memory session,
        # appends event to session.events
        event = await super().append_event(session=session, event=event)
        session.last_update_time = event.timestamp
        # Persist event document
        event_data = event.model_dump(mode="json", exclude_none=True)
        await (
            self._events_col(app_name, user_id, session_id)
            .document(event.id)
            .set(event_data)
        )
        # Persist state deltas
        session_ref = self._session_ref(app_name, user_id, session_id)
        if event.actions and event.actions.state_delta:
            state_deltas = _session_util.extract_state_delta(event.actions.state_delta)
            write_coros: list = []
            if state_deltas["app"]:
                write_coros.append(
                    self._app_state_ref(app_name).set(state_deltas["app"], merge=True)
                )
            if state_deltas["user"]:
                write_coros.append(
                    self._user_state_ref(app_name, user_id).set(
                        state_deltas["user"], merge=True
                    )
                )
            if state_deltas["session"]:
                # FieldPath escapes each key so dots/odd characters update
                # a single nested field instead of replacing "state".
                field_updates: dict[str, Any] = {
                    FieldPath("state", k).to_api_repr(): v
                    for k, v in state_deltas["session"].items()
                }
                field_updates["last_update_time"] = event.timestamp
                write_coros.append(session_ref.update(field_updates))
            else:
                write_coros.append(
                    session_ref.update({"last_update_time": event.timestamp})
                )
            await asyncio.gather(*write_coros)
        else:
            await session_ref.update({"last_update_time": event.timestamp})
        # Log token usage
        if event.usage_metadata:
            meta = event.usage_metadata
            logger.info(
                "Token usage for session %s event %s: "
                "prompt=%s, candidates=%s, total=%s",
                session_id,
                event.id,
                meta.prompt_token_count,
                meta.candidates_token_count,
                meta.total_token_count,
            )
        # Trigger compaction if total token count exceeds threshold
        if (
            self._compaction_threshold is not None
            and event.usage_metadata
            and event.usage_metadata.total_token_count
            and event.usage_metadata.total_token_count >= self._compaction_threshold
        ):
            logger.info(
                "Compaction triggered for session %s: "
                "total_token_count=%d >= threshold=%d",
                session_id,
                event.usage_metadata.total_token_count,
                self._compaction_threshold,
            )
            # Fire-and-forget; reference kept in _active_tasks so the task
            # survives until done and can be drained in close().
            task = asyncio.create_task(self._guarded_compact(session))
            self._active_tasks.add(task)
            task.add_done_callback(self._active_tasks.discard)
        elapsed = time.monotonic() - t0
        logger.info(
            "append_event completed for session %s event %s in %.3fs",
            session_id,
            event.id,
            elapsed,
        )
        return event