Files
agent/src/va_agent/session.py
Jorge Juarez f3515ee71c
All checks were successful
CI / ci (pull_request) Successful in 19s
fix(session): use datetime UTC and tighten timestamp logging
2026-03-10 21:24:11 +00:00

469 lines
16 KiB
Python

"""Firestore-backed session service for Google ADK."""
from __future__ import annotations
import asyncio
import logging
import time
import uuid
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, override
from google.adk.errors.already_exists_error import AlreadyExistsError
from google.adk.events.event import Event
from google.adk.sessions import _session_util
from google.adk.sessions.base_session_service import (
BaseSessionService,
GetSessionConfig,
ListSessionsResponse,
)
from google.adk.sessions.session import Session
from google.adk.sessions.state import State
from google.cloud.firestore_v1.base_query import FieldFilter
from google.cloud.firestore_v1.field_path import FieldPath
from google.genai.types import Content, Part
from .compaction import SessionCompactor
if TYPE_CHECKING:
from google import genai
from google.cloud.firestore_v1.async_client import AsyncClient
logger = logging.getLogger("google_adk." + __name__)
class FirestoreSessionService(BaseSessionService):
"""A Firestore-backed implementation of BaseSessionService.
Firestore document layout (given ``collection_prefix="adk"``)::
adk_app_states/{app_name}
→ app-scoped state key/values
adk_user_states/{app_name}__{user_id}
→ user-scoped state key/values
adk_sessions/{app_name}__{user_id}__{session_id}
{app_name, user_id, session_id, state: {…}, last_update_time}
└─ events/{event_id} → serialised Event
"""
def __init__( # noqa: PLR0913
self,
*,
db: AsyncClient,
collection_prefix: str = "adk",
compaction_token_threshold: int | None = None,
compaction_model: str = "gemini-2.5-flash",
compaction_keep_recent: int = 10,
genai_client: genai.Client | None = None,
) -> None:
"""Initialize FirestoreSessionService.
Args:
db: Firestore async client
collection_prefix: Prefix for Firestore collections
compaction_token_threshold: Token count threshold for compaction
compaction_model: Model to use for summarization
compaction_keep_recent: Number of recent events to keep
genai_client: GenAI client for compaction summaries
"""
if compaction_token_threshold is not None and genai_client is None:
msg = "genai_client is required when compaction_token_threshold is set."
raise ValueError(msg)
self._db = db
self._prefix = collection_prefix
self._compaction_threshold = compaction_token_threshold
self._compactor = SessionCompactor(
db=db,
genai_client=genai_client,
compaction_model=compaction_model,
compaction_keep_recent=compaction_keep_recent,
)
self._active_tasks: set[asyncio.Task] = set()
# ------------------------------------------------------------------
# Document-reference helpers
# ------------------------------------------------------------------
def _app_state_ref(self, app_name: str) -> Any:
return self._db.collection(f"{self._prefix}_app_states").document(app_name)
def _user_state_ref(self, app_name: str, user_id: str) -> Any:
return self._db.collection(f"{self._prefix}_user_states").document(
f"{app_name}__{user_id}"
)
def _session_ref(self, app_name: str, user_id: str, session_id: str) -> Any:
return self._db.collection(f"{self._prefix}_sessions").document(
f"{app_name}__{user_id}__{session_id}"
)
def _events_col(self, app_name: str, user_id: str, session_id: str) -> Any:
return self._session_ref(app_name, user_id, session_id).collection("events")
@staticmethod
def _timestamp_to_float(value: Any, default: float = 0.0) -> float:
if value is None:
return default
if isinstance(value, (int, float)):
return float(value)
if hasattr(value, "timestamp"):
try:
return float(value.timestamp())
except (
TypeError,
ValueError,
OSError,
OverflowError,
) as exc: # pragma: no cover
logger.debug("Failed to convert timestamp %r: %s", value, exc)
return default
# ------------------------------------------------------------------
# State helpers
# ------------------------------------------------------------------
async def _get_app_state(self, app_name: str) -> dict[str, Any]:
snap = await self._app_state_ref(app_name).get()
return snap.to_dict() or {} if snap.exists else {}
async def _get_user_state(self, app_name: str, user_id: str) -> dict[str, Any]:
snap = await self._user_state_ref(app_name, user_id).get()
return snap.to_dict() or {} if snap.exists else {}
@staticmethod
def _merge_state(
app_state: dict[str, Any],
user_state: dict[str, Any],
session_state: dict[str, Any],
) -> dict[str, Any]:
merged = dict(session_state)
for key, value in app_state.items():
merged[State.APP_PREFIX + key] = value
for key, value in user_state.items():
merged[State.USER_PREFIX + key] = value
return merged
async def close(self) -> None:
"""Await all in-flight compaction tasks. Call before shutdown."""
if self._active_tasks:
await asyncio.gather(*self._active_tasks, return_exceptions=True)
# ------------------------------------------------------------------
# BaseSessionService implementation
# ------------------------------------------------------------------
    @override
    async def create_session(
        self,
        *,
        app_name: str,
        user_id: str,
        state: dict[str, Any] | None = None,
        session_id: str | None = None,
    ) -> Session:
        """Create and persist a new session document.

        Args:
            app_name: Application the session belongs to.
            user_id: Owner of the session.
            state: Initial state; split by ``_session_util.extract_state_delta``
                into app-, user- and session-scoped deltas (presumably keyed by
                the ADK ``app:``/``user:`` prefixes — TODO confirm).
            session_id: Explicit id; blank/None generates a UUID4.

        Returns:
            The created Session with app/user/session state merged in.

        Raises:
            AlreadyExistsError: If a session with ``session_id`` already exists.
        """
        if session_id and session_id.strip():
            session_id = session_id.strip()
            # Fail fast rather than silently overwriting an existing session.
            existing = await self._session_ref(app_name, user_id, session_id).get()
            if existing.exists:
                msg = f"Session with id {session_id} already exists."
                raise AlreadyExistsError(msg)
        else:
            session_id = str(uuid.uuid4())
        # Route state keys to their scope-specific documents.
        state_deltas = _session_util.extract_state_delta(state)  # type: ignore[attr-defined]
        app_state_delta = state_deltas["app"]
        user_state_delta = state_deltas["user"]
        session_state = state_deltas["session"]
        # Batch all document writes and run them concurrently.
        write_coros: list = []
        if app_state_delta:
            write_coros.append(
                self._app_state_ref(app_name).set(app_state_delta, merge=True)
            )
        if user_state_delta:
            write_coros.append(
                self._user_state_ref(app_name, user_id).set(
                    user_state_delta, merge=True
                )
            )
        now = datetime.now(UTC)
        write_coros.append(
            self._session_ref(app_name, user_id, session_id).set(
                {
                    "app_name": app_name,
                    "user_id": user_id,
                    "session_id": session_id,
                    "state": session_state or {},
                    "last_update_time": now,
                }
            )
        )
        await asyncio.gather(*write_coros)
        # Re-read scoped state AFTER the writes so the returned session
        # reflects pre-existing app/user state merged with the new deltas.
        app_state, user_state = await asyncio.gather(
            self._get_app_state(app_name),
            self._get_user_state(app_name, user_id),
        )
        merged = self._merge_state(app_state, user_state, session_state or {})
        return Session(
            app_name=app_name,
            user_id=user_id,
            id=session_id,
            state=merged,
            last_update_time=now.timestamp(),
        )
    @override
    async def get_session(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        config: GetSessionConfig | None = None,
    ) -> Session | None:
        """Load a session, its (optionally filtered) events, and merged state.

        Args:
            app_name: Application the session belongs to.
            user_id: Owner of the session.
            session_id: Session document id.
            config: Optional filters — ``after_timestamp`` lower-bounds the
                events query, ``num_recent_events`` keeps only the trailing N.

        Returns:
            The hydrated Session, or ``None`` when no such document exists.
        """
        snap = await self._session_ref(app_name, user_id, session_id).get()
        if not snap.exists:
            return None
        session_data = snap.to_dict()
        # Build events query (ordered by timestamp, optionally lower-bounded).
        events_ref = self._events_col(app_name, user_id, session_id)
        query = events_ref
        if config and config.after_timestamp:
            query = query.where(
                filter=FieldFilter("timestamp", ">=", config.after_timestamp)
            )
        query = query.order_by("timestamp")
        # Fetch events and both scoped-state docs concurrently.
        event_docs, app_state, user_state = await asyncio.gather(
            query.get(),
            self._get_app_state(app_name),
            self._get_user_state(app_name, user_id),
        )
        events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
        if config and config.num_recent_events:
            events = events[-config.num_recent_events :]
        # Prepend a compaction summary (if one was written) as a synthetic
        # user/model exchange so the model regains prior-conversation context.
        # Timestamps 0.0/0.001 mark these as pre-history relative to real events.
        conversation_summary = session_data.get("conversation_summary")
        if conversation_summary:
            summary_event = Event(
                id="summary-context",
                author="user",
                content=Content(
                    role="user",
                    parts=[
                        Part(
                            text=(
                                "[Conversation context from previous"
                                " messages]\n"
                                f"{conversation_summary}"
                            )
                        )
                    ],
                ),
                timestamp=0.0,
                invocation_id="compaction-summary",
            )
            # Paired model acknowledgement keeps the transcript well-formed
            # (user turn followed by model turn).
            ack_event = Event(
                id="summary-ack",
                author=app_name,
                content=Content(
                    role="model",
                    parts=[
                        Part(
                            text=(
                                "Understood, I have the context from our"
                                " previous conversation and will continue"
                                " accordingly."
                            )
                        )
                    ],
                ),
                timestamp=0.001,
                invocation_id="compaction-summary",
            )
            events = [summary_event, ack_event, *events]
        # Merge scoped state
        merged = self._merge_state(app_state, user_state, session_data.get("state", {}))
        return Session(
            app_name=app_name,
            user_id=user_id,
            id=session_id,
            state=merged,
            events=events,
            last_update_time=self._timestamp_to_float(
                session_data.get("last_update_time"), 0.0
            ),
        )
@override
async def list_sessions(
self, *, app_name: str, user_id: str | None = None
) -> ListSessionsResponse:
query = self._db.collection(f"{self._prefix}_sessions").where(
filter=FieldFilter("app_name", "==", app_name)
)
if user_id is not None:
query = query.where(filter=FieldFilter("user_id", "==", user_id))
docs = await query.get()
if not docs:
return ListSessionsResponse()
doc_dicts: list[dict[str, Any]] = [doc.to_dict() or {} for doc in docs]
# Pre-fetch app state and all distinct user states in parallel
unique_user_ids = list({d["user_id"] for d in doc_dicts})
app_state, *user_states = await asyncio.gather(
self._get_app_state(app_name),
*(self._get_user_state(app_name, uid) for uid in unique_user_ids),
)
user_state_cache = dict(zip(unique_user_ids, user_states, strict=False))
sessions: list[Session] = []
for data in doc_dicts:
s_user_id = data["user_id"]
merged = self._merge_state(
app_state,
user_state_cache[s_user_id],
data.get("state", {}),
)
sessions.append(
Session(
app_name=app_name,
user_id=s_user_id,
id=data["session_id"],
state=merged,
events=[],
last_update_time=self._timestamp_to_float(
data.get("last_update_time"), 0.0
),
)
)
return ListSessionsResponse(sessions=sessions)
@override
async def delete_session(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
ref = self._session_ref(app_name, user_id, session_id)
await self._db.recursive_delete(ref)
    @override
    async def append_event(self, session: Session, event: Event) -> Event:
        """Persist an event plus its state deltas, then maybe kick off compaction.

        Args:
            session: In-memory session; the base class mutates its events and
                this method updates its ``last_update_time``.
            event: Event to append. Partial (streaming) events are returned
                unchanged and never persisted.

        Returns:
            The event as returned by the base-class ``append_event``.
        """
        if event.partial:
            return event
        t0 = time.monotonic()
        app_name = session.app_name
        user_id = session.user_id
        session_id = session.id
        # Base class: strips temp state, applies delta to in-memory session,
        # appends event to session.events
        event = await super().append_event(session=session, event=event)
        session.last_update_time = event.timestamp
        # Persist event document
        event_data = event.model_dump(mode="json", exclude_none=True)
        await (
            self._events_col(app_name, user_id, session_id)
            .document(event.id)
            .set(event_data)
        )
        # Persist state deltas
        session_ref = self._session_ref(app_name, user_id, session_id)
        # Store last_update_time as an aware UTC datetime (Firestore timestamp).
        last_update_dt = datetime.fromtimestamp(event.timestamp, UTC)
        if event.actions and event.actions.state_delta:
            state_deltas = _session_util.extract_state_delta(event.actions.state_delta)
            write_coros: list = []
            if state_deltas["app"]:
                write_coros.append(
                    self._app_state_ref(app_name).set(state_deltas["app"], merge=True)
                )
            if state_deltas["user"]:
                write_coros.append(
                    self._user_state_ref(app_name, user_id).set(
                        state_deltas["user"], merge=True
                    )
                )
            if state_deltas["session"]:
                # NOTE(review): FieldPath("state", k).to_api_repr() looks
                # intended to escape special characters so each key updates a
                # single entry of the "state" map — confirm against Firestore
                # field-path docs.
                field_updates: dict[str, Any] = {
                    FieldPath("state", k).to_api_repr(): v
                    for k, v in state_deltas["session"].items()
                }
                field_updates["last_update_time"] = last_update_dt
                write_coros.append(session_ref.update(field_updates))
            else:
                write_coros.append(
                    session_ref.update({"last_update_time": last_update_dt})
                )
            await asyncio.gather(*write_coros)
        else:
            # No state delta: only bump the session's last_update_time.
            await session_ref.update({"last_update_time": last_update_dt})
        # Log token usage
        if event.usage_metadata:
            meta = event.usage_metadata
            logger.info(
                "Token usage for session %s event %s: "
                "prompt=%s, candidates=%s, total=%s",
                session_id,
                event.id,
                meta.prompt_token_count,
                meta.candidates_token_count,
                meta.total_token_count,
            )
        # Trigger compaction if total token count exceeds threshold
        if (
            self._compaction_threshold is not None
            and event.usage_metadata
            and event.usage_metadata.total_token_count
            and event.usage_metadata.total_token_count >= self._compaction_threshold
        ):
            logger.info(
                "Compaction triggered for session %s: "
                "total_token_count=%d >= threshold=%d",
                session_id,
                event.usage_metadata.total_token_count,
                self._compaction_threshold,
            )
            events_ref = self._events_col(app_name, user_id, session_id)
            session_ref = self._session_ref(app_name, user_id, session_id)
            # Fire-and-forget compaction; the strong reference kept in
            # _active_tasks prevents the task from being garbage-collected
            # before completion (see close()).
            task = asyncio.create_task(
                self._compactor.guarded_compact(session, events_ref, session_ref)
            )
            self._active_tasks.add(task)
            task.add_done_callback(self._active_tasks.discard)
        elapsed = time.monotonic() - t0
        logger.info(
            "append_event completed for session %s event %s in %.3fs",
            session_id,
            event.id,
            elapsed,
        )
        return event