Lean MCP implementation
This commit is contained in:
6
src/va_agent/__init__.py
Normal file
6
src/va_agent/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Package export for the ADK root agent."""

import os

# Route the Google GenAI SDK to Vertex AI rather than the public Gemini API,
# unless the deployment has already made an explicit choice via the env var.
if "GOOGLE_GENAI_USE_VERTEXAI" not in os.environ:
    os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "true"
29
src/va_agent/agent.py
Normal file
29
src/va_agent/agent.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""ADK agent with vector search RAG tool."""
|
||||
|
||||
from google import genai
|
||||
from google.adk.agents.llm_agent import Agent
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.tools.mcp_tool import McpToolset
|
||||
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
|
||||
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||
|
||||
from va_agent.config import settings
|
||||
from va_agent.session import FirestoreSessionService
|
||||
|
||||
# MCP toolset reached over SSE at the URL configured in config.yaml/env.
connection_params = SseConnectionParams(url=settings.mcp_remote_url)
toolset = McpToolset(connection_params=connection_params)

# Root agent: model, name, and instructions all come from AgentSettings.
agent = Agent(
    model=settings.agent_model,
    name=settings.agent_name,
    instruction=settings.agent_instructions,
    tools=[toolset],
)

# Firestore-backed session store; compaction kicks in once an event reports
# a total token count of 10k or more (see FirestoreSessionService).
session_service = FirestoreSessionService(
    db=AsyncClient(database=settings.firestore_db),
    compaction_token_threshold=10_000,
    genai_client=genai.Client(),
)

# Module-level runner consumed by the serving layer.
runner = Runner(app_name="va_agent", agent=agent, session_service=session_service)
||||
53
src/va_agent/config.py
Normal file
53
src/va_agent/config.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Configuration helper for ADK agent."""
|
||||
|
||||
import os
|
||||
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
PydanticBaseSettingsSource,
|
||||
SettingsConfigDict,
|
||||
YamlConfigSettingsSource,
|
||||
)
|
||||
|
||||
# Path to the YAML settings file; overridable via the CONFIG_YAML env var.
CONFIG_FILE_PATH = os.environ.get("CONFIG_YAML", "config.yaml")
||||
|
||||
|
||||
class AgentSettings(BaseSettings):
    """Settings for ADK agent with vector search.

    Values are resolved from environment variables first, then from the
    YAML file at ``CONFIG_FILE_PATH`` (see ``settings_customise_sources``).
    """

    # Google Cloud project and region the agent runs against.
    google_cloud_project: str
    google_cloud_location: str

    # Agent configuration
    agent_name: str
    agent_instructions: str
    agent_model: str

    # Firestore configuration
    firestore_db: str

    # MCP configuration
    mcp_remote_url: str

    model_config = SettingsConfigDict(
        yaml_file=CONFIG_FILE_PATH,
        extra="ignore",  # Ignore extra fields from config.yaml
    )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,  # noqa: ARG003
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,  # noqa: ARG003
        file_secret_settings: PydanticBaseSettingsSource,  # noqa: ARG003
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Use env vars and YAML as settings sources."""
        # Order matters: earlier sources take precedence, so environment
        # variables override values loaded from the YAML file.
        return (
            env_settings,
            YamlConfigSettingsSource(settings_cls),
        )
||||
|
||||
# Eagerly validated module-level singleton: missing or invalid configuration
# fails at import time instead of at first use.
settings = AgentSettings.model_validate({})
||||
10
src/va_agent/server.py
Normal file
10
src/va_agent/server.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""FastAPI server exposing the RAG agent endpoint.

NOTE: This file is a stub. The rag_eval module was removed in the
lean MCP implementation. This file is kept for reference but is not
functional.
"""

from fastapi import FastAPI

# Bare app with no routes registered; see the module docstring above.
app = FastAPI(title="RAG Agent")
||||
582
src/va_agent/session.py
Normal file
582
src/va_agent/session.py
Normal file
@@ -0,0 +1,582 @@
|
||||
"""Firestore-backed session service for Google ADK."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, override
|
||||
|
||||
from google.adk.errors.already_exists_error import AlreadyExistsError
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.sessions import _session_util
|
||||
from google.adk.sessions.base_session_service import (
|
||||
BaseSessionService,
|
||||
GetSessionConfig,
|
||||
ListSessionsResponse,
|
||||
)
|
||||
from google.adk.sessions.session import Session
|
||||
from google.adk.sessions.state import State
|
||||
from google.cloud.firestore_v1.async_transaction import async_transactional
|
||||
from google.cloud.firestore_v1.base_query import FieldFilter
|
||||
from google.cloud.firestore_v1.field_path import FieldPath
|
||||
from google.genai.types import Content, Part
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google import genai
|
||||
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||
|
||||
# Nested under the "google_adk" logger namespace so ADK-wide logging
# configuration also applies to this module.
logger = logging.getLogger("google_adk." + __name__)

# Age (seconds) after which a persisted compaction lock is treated as stale
# and may be reclaimed by another instance.
_COMPACTION_LOCK_TTL = 300  # seconds
||||
|
||||
|
||||
@async_transactional
async def _try_claim_compaction_txn(transaction: Any, session_ref: Any) -> bool:
    """Atomically claim the compaction lock if it is free or stale.

    Runs inside a Firestore transaction so only one instance can win the
    claim. Returns True when the lock was claimed, False otherwise.
    """
    snapshot = await session_ref.get(transaction=transaction)
    # Session document vanished (e.g. deleted concurrently): nothing to do.
    if not snapshot.exists:
        return False
    data = snapshot.to_dict() or {}
    lock_time = data.get("compaction_lock")
    # A lock younger than the TTL is assumed held by another live instance.
    if lock_time and (time.time() - lock_time) < _COMPACTION_LOCK_TTL:
        return False
    # Stamp the lock with the current wall-clock time to claim it.
    transaction.update(session_ref, {"compaction_lock": time.time()})
    return True
|
||||
|
||||
class FirestoreSessionService(BaseSessionService):
|
||||
"""A Firestore-backed implementation of BaseSessionService.
|
||||
|
||||
Firestore document layout (given ``collection_prefix="adk"``)::
|
||||
|
||||
adk_app_states/{app_name}
|
||||
→ app-scoped state key/values
|
||||
|
||||
adk_user_states/{app_name}__{user_id}
|
||||
→ user-scoped state key/values
|
||||
|
||||
adk_sessions/{app_name}__{user_id}__{session_id}
|
||||
→ {app_name, user_id, session_id, state: {…}, last_update_time}
|
||||
└─ events/{event_id} → serialised Event
|
||||
"""
|
||||
|
||||
    def __init__(  # noqa: PLR0913
        self,
        *,
        db: AsyncClient,
        collection_prefix: str = "adk",
        compaction_token_threshold: int | None = None,
        compaction_model: str = "gemini-2.5-flash",
        compaction_keep_recent: int = 10,
        genai_client: genai.Client | None = None,
    ) -> None:
        """Initialize FirestoreSessionService.

        Args:
            db: Firestore async client
            collection_prefix: Prefix for Firestore collections
            compaction_token_threshold: Token count threshold for compaction
            compaction_model: Model to use for summarization
            compaction_keep_recent: Number of recent events to keep
            genai_client: GenAI client for compaction summaries

        Raises:
            ValueError: If compaction is enabled without a genai_client.

        """
        # Compaction needs an LLM client to produce summaries; fail fast at
        # construction time rather than on the first compaction attempt.
        if compaction_token_threshold is not None and genai_client is None:
            msg = "genai_client is required when compaction_token_threshold is set."
            raise ValueError(msg)
        self._db = db
        self._prefix = collection_prefix
        self._compaction_threshold = compaction_token_threshold
        self._compaction_model = compaction_model
        self._compaction_keep_recent = compaction_keep_recent
        self._genai_client = genai_client
        # Per-session asyncio locks guarding compaction within this process.
        # NOTE(review): entries are never evicted, so this dict grows with
        # the number of distinct sessions seen — confirm acceptable.
        self._compaction_locks: dict[str, asyncio.Lock] = {}
        # Strong references to in-flight background compaction tasks so the
        # event loop does not garbage-collect them; tasks discard themselves
        # via a done-callback (see append_event).
        self._active_tasks: set[asyncio.Task] = set()
|
||||
# ------------------------------------------------------------------
|
||||
# Document-reference helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _app_state_ref(self, app_name: str) -> Any:
|
||||
return self._db.collection(f"{self._prefix}_app_states").document(app_name)
|
||||
|
||||
def _user_state_ref(self, app_name: str, user_id: str) -> Any:
|
||||
return self._db.collection(f"{self._prefix}_user_states").document(
|
||||
f"{app_name}__{user_id}"
|
||||
)
|
||||
|
||||
def _session_ref(self, app_name: str, user_id: str, session_id: str) -> Any:
|
||||
return self._db.collection(f"{self._prefix}_sessions").document(
|
||||
f"{app_name}__{user_id}__{session_id}"
|
||||
)
|
||||
|
||||
def _events_col(self, app_name: str, user_id: str, session_id: str) -> Any:
|
||||
return self._session_ref(app_name, user_id, session_id).collection("events")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# State helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _get_app_state(self, app_name: str) -> dict[str, Any]:
|
||||
snap = await self._app_state_ref(app_name).get()
|
||||
return snap.to_dict() or {} if snap.exists else {}
|
||||
|
||||
async def _get_user_state(self, app_name: str, user_id: str) -> dict[str, Any]:
|
||||
snap = await self._user_state_ref(app_name, user_id).get()
|
||||
return snap.to_dict() or {} if snap.exists else {}
|
||||
|
||||
@staticmethod
|
||||
def _merge_state(
|
||||
app_state: dict[str, Any],
|
||||
user_state: dict[str, Any],
|
||||
session_state: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
merged = dict(session_state)
|
||||
for key, value in app_state.items():
|
||||
merged[State.APP_PREFIX + key] = value
|
||||
for key, value in user_state.items():
|
||||
merged[State.USER_PREFIX + key] = value
|
||||
return merged
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Compaction helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _events_to_text(events: list[Event]) -> str:
|
||||
lines: list[str] = []
|
||||
for event in events:
|
||||
if event.content and event.content.parts:
|
||||
text = "".join(p.text or "" for p in event.content.parts)
|
||||
if text:
|
||||
role = "User" if event.author == "user" else "Assistant"
|
||||
lines.append(f"{role}: {text}")
|
||||
return "\n\n".join(lines)
|
||||
|
||||
    async def _generate_summary(
        self, existing_summary: str, events: list[Event]
    ) -> str:
        """Summarize *events* (folding in any prior summary) via the LLM.

        Args:
            existing_summary: Previously stored summary, or "" if none.
            events: Events to fold into the new summary.

        Returns:
            The generated summary text ("" if the model returned no text).

        Raises:
            RuntimeError: If no genai client was configured.

        """
        conversation_text = self._events_to_text(events)
        # Include the prior summary so context survives repeated compactions.
        previous = (
            f"Previous summary of earlier conversation:\n{existing_summary}\n\n"
            if existing_summary
            else ""
        )
        prompt = (
            "Summarize the following conversation between a user and an "
            "assistant. Preserve:\n"
            "- Key decisions and conclusions\n"
            "- User preferences and requirements\n"
            "- Important facts, names, and numbers\n"
            "- The overall topic and direction of the conversation\n"
            "- Any pending tasks or open questions\n\n"
            f"{previous}"
            f"Conversation:\n{conversation_text}\n\n"
            "Provide a clear, comprehensive summary."
        )
        # Defensive re-check; __init__ already enforces this pairing.
        if self._genai_client is None:
            msg = "genai_client is required for compaction"
            raise RuntimeError(msg)
        response = await self._genai_client.aio.models.generate_content(
            model=self._compaction_model,
            contents=prompt,
        )
        return response.text or ""
||||
|
||||
    async def _compact_session(self, session: Session) -> None:
        """Summarize old events and delete them, keeping the most recent ones.

        Keeps the newest ``_compaction_keep_recent`` events intact; all older
        events are folded into the session's ``conversation_summary`` field
        and then removed. A summary-generation failure aborts the compaction
        without touching any data.
        """
        app_name = session.app_name
        user_id = session.user_id
        session_id = session.id

        events_ref = self._events_col(app_name, user_id, session_id)
        query = events_ref.order_by("timestamp")
        event_docs = await query.get()

        # Nothing to do if everything already fits in the keep-recent window.
        if len(event_docs) <= self._compaction_keep_recent:
            return

        all_events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
        events_to_summarize = all_events[: -self._compaction_keep_recent]

        session_snap = await self._session_ref(app_name, user_id, session_id).get()
        existing_summary = (session_snap.to_dict() or {}).get(
            "conversation_summary", ""
        )

        try:
            summary = await self._generate_summary(
                existing_summary, events_to_summarize
            )
        except Exception:
            # Best-effort: a failed summary must never destroy events.
            logger.exception("Compaction summary generation failed; skipping.")
            return

        # Write summary BEFORE deleting events so a crash between the two
        # steps leaves safe duplication rather than data loss.
        await self._session_ref(app_name, user_id, session_id).update(
            {"conversation_summary": summary}
        )

        # Delete in chunks of 500, the Firestore batch-write limit.
        docs_to_delete = event_docs[: -self._compaction_keep_recent]
        for i in range(0, len(docs_to_delete), 500):
            batch = self._db.batch()
            for doc in docs_to_delete[i : i + 500]:
                batch.delete(doc.reference)
            await batch.commit()

        logger.info(
            "Compacted session %s: summarised %d events, kept %d.",
            session_id,
            len(docs_to_delete),
            self._compaction_keep_recent,
        )
||||
|
||||
    async def _guarded_compact(self, session: Session) -> None:
        """Run compaction in the background with per-session locking.

        Two layers of mutual exclusion: an in-process asyncio lock (cheap,
        skips duplicate local triggers) and a Firestore document lock
        (coordinates across instances, claimed transactionally).
        """
        key = f"{session.app_name}__{session.user_id}__{session.id}"
        lock = self._compaction_locks.setdefault(key, asyncio.Lock())

        # Cheap local check: skip instead of queueing behind a running run.
        if lock.locked():
            logger.debug("Compaction already running locally for %s; skipping.", key)
            return

        async with lock:
            session_ref = self._session_ref(
                session.app_name, session.user_id, session.id
            )
            try:
                transaction = self._db.transaction()
                claimed = await _try_claim_compaction_txn(transaction, session_ref)
            except Exception:
                logger.exception("Failed to claim compaction lock for %s", key)
                return

            if not claimed:
                logger.debug(
                    "Compaction lock held by another instance for %s; skipping.",
                    key,
                )
                return

            try:
                await self._compact_session(session)
            except Exception:
                logger.exception("Background compaction failed for %s", key)
            finally:
                # Always release the distributed lock, even on failure; the
                # TTL in _try_claim_compaction_txn is the fallback if this
                # release itself fails.
                try:
                    await session_ref.update({"compaction_lock": None})
                except Exception:
                    logger.exception("Failed to release compaction lock for %s", key)
||||
|
||||
async def close(self) -> None:
|
||||
"""Await all in-flight compaction tasks. Call before shutdown."""
|
||||
if self._active_tasks:
|
||||
await asyncio.gather(*self._active_tasks, return_exceptions=True)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# BaseSessionService implementation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
    @override
    async def create_session(
        self,
        *,
        app_name: str,
        user_id: str,
        state: dict[str, Any] | None = None,
        session_id: str | None = None,
    ) -> Session:
        """Create and persist a new session, returning it with merged state.

        Args:
            app_name: Owning application name.
            user_id: Owning user id.
            state: Optional initial state; app:/user:-prefixed keys are
                routed to the app/user state documents.
            session_id: Optional explicit id; a UUID4 is generated otherwise.

        Raises:
            AlreadyExistsError: If an explicit session_id already exists.

        """
        if session_id and session_id.strip():
            session_id = session_id.strip()
            # NOTE(review): this existence check and the set() below are not
            # in one transaction, so a concurrent create with the same id
            # could race — confirm whether that matters for callers.
            existing = await self._session_ref(app_name, user_id, session_id).get()
            if existing.exists:
                msg = f"Session with id {session_id} already exists."
                raise AlreadyExistsError(msg)
        else:
            session_id = str(uuid.uuid4())

        # Split the caller-supplied state into app/user/session scopes.
        state_deltas = _session_util.extract_state_delta(state)  # type: ignore[attr-defined]
        app_state_delta = state_deltas["app"]
        user_state_delta = state_deltas["user"]
        session_state = state_deltas["session"]

        # Fan out the independent writes and run them concurrently.
        write_coros: list = []
        if app_state_delta:
            write_coros.append(
                self._app_state_ref(app_name).set(app_state_delta, merge=True)
            )
        if user_state_delta:
            write_coros.append(
                self._user_state_ref(app_name, user_id).set(
                    user_state_delta, merge=True
                )
            )

        now = time.time()
        write_coros.append(
            self._session_ref(app_name, user_id, session_id).set(
                {
                    "app_name": app_name,
                    "user_id": user_id,
                    "session_id": session_id,
                    "state": session_state or {},
                    "last_update_time": now,
                }
            )
        )
        await asyncio.gather(*write_coros)

        # Re-read app/user state so the returned session reflects values
        # written by other sessions, not just this call's deltas.
        app_state, user_state = await asyncio.gather(
            self._get_app_state(app_name),
            self._get_user_state(app_name, user_id),
        )
        merged = self._merge_state(app_state, user_state, session_state or {})

        return Session(
            app_name=app_name,
            user_id=user_id,
            id=session_id,
            state=merged,
            last_update_time=now,
        )
||||
|
||||
    @override
    async def get_session(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        config: GetSessionConfig | None = None,
    ) -> Session | None:
        """Load a session with its events, or ``None`` if it does not exist.

        Honors ``config.after_timestamp`` and ``config.num_recent_events``.
        If the session has a stored conversation summary from compaction, a
        synthetic user/model event pair carrying that summary is prepended.
        """
        snap = await self._session_ref(app_name, user_id, session_id).get()
        if not snap.exists:
            return None

        session_data = snap.to_dict()

        # Build events query
        events_ref = self._events_col(app_name, user_id, session_id)
        query = events_ref
        if config and config.after_timestamp:
            query = query.where(
                filter=FieldFilter("timestamp", ">=", config.after_timestamp)
            )
        query = query.order_by("timestamp")

        # Fetch events and both state scopes concurrently.
        event_docs, app_state, user_state = await asyncio.gather(
            query.get(),
            self._get_app_state(app_name),
            self._get_user_state(app_name, user_id),
        )
        events = [Event.model_validate(doc.to_dict()) for doc in event_docs]

        if config and config.num_recent_events:
            events = events[-config.num_recent_events :]

        # Prepend conversation summary as synthetic context events
        conversation_summary = session_data.get("conversation_summary")
        if conversation_summary:
            # Synthetic user turn carrying the compacted history; timestamp
            # 0.0 sorts it before all real events.
            summary_event = Event(
                id="summary-context",
                author="user",
                content=Content(
                    role="user",
                    parts=[
                        Part(
                            text=(
                                "[Conversation context from previous"
                                " messages]\n"
                                f"{conversation_summary}"
                            )
                        )
                    ],
                ),
                timestamp=0.0,
                invocation_id="compaction-summary",
            )
            # Matching synthetic model acknowledgement keeps the transcript
            # alternating user/model.
            ack_event = Event(
                id="summary-ack",
                author=app_name,
                content=Content(
                    role="model",
                    parts=[
                        Part(
                            text=(
                                "Understood, I have the context from our"
                                " previous conversation and will continue"
                                " accordingly."
                            )
                        )
                    ],
                ),
                timestamp=0.001,
                invocation_id="compaction-summary",
            )
            events = [summary_event, ack_event, *events]

        # Merge scoped state
        merged = self._merge_state(app_state, user_state, session_data.get("state", {}))

        return Session(
            app_name=app_name,
            user_id=user_id,
            id=session_id,
            state=merged,
            events=events,
            last_update_time=session_data.get("last_update_time", 0.0),
        )
||||
|
||||
    @override
    async def list_sessions(
        self, *, app_name: str, user_id: str | None = None
    ) -> ListSessionsResponse:
        """List sessions for an app (optionally one user), without events.

        Returned sessions carry merged state but empty event lists.
        """
        query = self._db.collection(f"{self._prefix}_sessions").where(
            filter=FieldFilter("app_name", "==", app_name)
        )
        if user_id is not None:
            query = query.where(filter=FieldFilter("user_id", "==", user_id))

        docs = await query.get()
        if not docs:
            return ListSessionsResponse()

        doc_dicts: list[dict[str, Any]] = [doc.to_dict() or {} for doc in docs]

        # Pre-fetch app state and all distinct user states in parallel
        unique_user_ids = list({d["user_id"] for d in doc_dicts})
        app_state, *user_states = await asyncio.gather(
            self._get_app_state(app_name),
            *(self._get_user_state(app_name, uid) for uid in unique_user_ids),
        )
        # gather preserves argument order, so ids and states line up.
        user_state_cache = dict(zip(unique_user_ids, user_states, strict=False))

        sessions: list[Session] = []
        for data in doc_dicts:
            s_user_id = data["user_id"]
            merged = self._merge_state(
                app_state,
                user_state_cache[s_user_id],
                data.get("state", {}),
            )

            sessions.append(
                Session(
                    app_name=app_name,
                    user_id=s_user_id,
                    id=data["session_id"],
                    state=merged,
                    events=[],
                    last_update_time=data.get("last_update_time", 0.0),
                )
            )

        return ListSessionsResponse(sessions=sessions)
||||
|
||||
@override
|
||||
async def delete_session(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> None:
|
||||
ref = self._session_ref(app_name, user_id, session_id)
|
||||
await self._db.recursive_delete(ref)
|
||||
|
||||
    @override
    async def append_event(self, session: Session, event: Event) -> Event:
        """Persist *event*, apply its state deltas, and maybe compact.

        Non-partial events are written to the session's events subcollection;
        state deltas are fanned out to app/user/session documents. When usage
        metadata crosses the compaction threshold, a background compaction
        task is scheduled (fire-and-forget).
        """
        # Partial (streaming) events are not persisted.
        if event.partial:
            return event

        t0 = time.monotonic()

        app_name = session.app_name
        user_id = session.user_id
        session_id = session.id

        # Base class: strips temp state, applies delta to in-memory session,
        # appends event to session.events
        event = await super().append_event(session=session, event=event)
        session.last_update_time = event.timestamp

        # Persist event document
        event_data = event.model_dump(mode="json", exclude_none=True)
        await (
            self._events_col(app_name, user_id, session_id)
            .document(event.id)
            .set(event_data)
        )

        # Persist state deltas
        session_ref = self._session_ref(app_name, user_id, session_id)

        if event.actions and event.actions.state_delta:
            state_deltas = _session_util.extract_state_delta(event.actions.state_delta)

            write_coros: list = []
            if state_deltas["app"]:
                write_coros.append(
                    self._app_state_ref(app_name).set(state_deltas["app"], merge=True)
                )
            if state_deltas["user"]:
                write_coros.append(
                    self._user_state_ref(app_name, user_id).set(
                        state_deltas["user"], merge=True
                    )
                )

            if state_deltas["session"]:
                # FieldPath updates only the named nested keys under "state"
                # (and escapes special characters) instead of replacing the
                # whole "state" map.
                field_updates: dict[str, Any] = {
                    FieldPath("state", k).to_api_repr(): v
                    for k, v in state_deltas["session"].items()
                }
                field_updates["last_update_time"] = event.timestamp
                write_coros.append(session_ref.update(field_updates))
            else:
                write_coros.append(
                    session_ref.update({"last_update_time": event.timestamp})
                )

            await asyncio.gather(*write_coros)
        else:
            await session_ref.update({"last_update_time": event.timestamp})

        # Log token usage
        if event.usage_metadata:
            meta = event.usage_metadata
            logger.info(
                "Token usage for session %s event %s: "
                "prompt=%s, candidates=%s, total=%s",
                session_id,
                event.id,
                meta.prompt_token_count,
                meta.candidates_token_count,
                meta.total_token_count,
            )

        # Trigger compaction if total token count exceeds threshold
        if (
            self._compaction_threshold is not None
            and event.usage_metadata
            and event.usage_metadata.total_token_count
            and event.usage_metadata.total_token_count >= self._compaction_threshold
        ):
            logger.info(
                "Compaction triggered for session %s: "
                "total_token_count=%d >= threshold=%d",
                session_id,
                event.usage_metadata.total_token_count,
                self._compaction_threshold,
            )
            # Fire-and-forget: _guarded_compact handles locking and logging;
            # the set keeps a strong reference until the task completes.
            task = asyncio.create_task(self._guarded_compact(session))
            self._active_tasks.add(task)
            task.add_done_callback(self._active_tasks.discard)

        elapsed = time.monotonic() - t0
        logger.info(
            "append_event completed for session %s event %s in %.3fs",
            session_id,
            event.id,
            elapsed,
        )

        return event
||||
Reference in New Issue
Block a user