"""Firestore-backed session service for Google ADK."""

from __future__ import annotations

import asyncio
import logging
import time
import uuid
from typing import TYPE_CHECKING, Any, override

from google.adk.errors.already_exists_error import AlreadyExistsError
from google.adk.events.event import Event
from google.adk.sessions import _session_util
from google.adk.sessions.base_session_service import (
    BaseSessionService,
    GetSessionConfig,
    ListSessionsResponse,
)
from google.adk.sessions.session import Session
from google.adk.sessions.state import State
from google.cloud.firestore_v1.base_query import FieldFilter
from google.cloud.firestore_v1.field_path import FieldPath
from google.genai.types import Content, Part

from .compaction import SessionCompactor

if TYPE_CHECKING:
    from google import genai
    from google.cloud.firestore_v1.async_client import AsyncClient

logger = logging.getLogger("google_adk." + __name__)


class FirestoreSessionService(BaseSessionService):
    """A Firestore-backed implementation of BaseSessionService.

    Firestore document layout (given ``collection_prefix="adk"``)::

        adk_app_states/{app_name}              → app-scoped state key/values
        adk_user_states/{app_name}__{user_id}  → user-scoped state key/values
        adk_sessions/{app_name}__{user_id}__{session_id}
            → {app_name, user_id, session_id, state: {…}, last_update_time}
            └─ events/{event_id}               → serialised Event

    App- and user-scoped state is stored without its prefix and is re-prefixed
    with ``State.APP_PREFIX`` / ``State.USER_PREFIX`` when a session is read.

    NOTE(review): document IDs are built by joining identifiers with ``"__"``,
    so identifiers that themselves contain ``"__"`` (or ``"/"``) can collide or
    produce invalid paths — confirm callers constrain these values.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        db: AsyncClient,
        collection_prefix: str = "adk",
        compaction_token_threshold: int | None = None,
        compaction_model: str = "gemini-2.5-flash",
        compaction_keep_recent: int = 10,
        genai_client: genai.Client | None = None,
    ) -> None:
        """Initialize FirestoreSessionService.

        Args:
            db: Firestore async client
            collection_prefix: Prefix for Firestore collections
            compaction_token_threshold: Token count threshold for compaction;
                ``None`` disables compaction.
            compaction_model: Model to use for summarization
            compaction_keep_recent: Number of recent events to keep
            genai_client: GenAI client for compaction summaries

        Raises:
            ValueError: If ``compaction_token_threshold`` is set but no
                ``genai_client`` is supplied.
        """
        if compaction_token_threshold is not None and genai_client is None:
            msg = "genai_client is required when compaction_token_threshold is set."
            raise ValueError(msg)
        self._db = db
        self._prefix = collection_prefix
        self._compaction_threshold = compaction_token_threshold
        self._compactor = SessionCompactor(
            db=db,
            genai_client=genai_client,
            compaction_model=compaction_model,
            compaction_keep_recent=compaction_keep_recent,
        )
        # Strong references to background compaction tasks so they are not
        # garbage-collected while running; entries remove themselves when done
        # and ``close()`` awaits whatever is still in flight.
        self._active_tasks: set[asyncio.Task] = set()

    # ------------------------------------------------------------------
    # Document-reference helpers
    # ------------------------------------------------------------------

    def _app_state_ref(self, app_name: str) -> Any:
        """Return the document reference holding app-scoped state."""
        return self._db.collection(f"{self._prefix}_app_states").document(app_name)

    def _user_state_ref(self, app_name: str, user_id: str) -> Any:
        """Return the document reference holding user-scoped state."""
        return self._db.collection(f"{self._prefix}_user_states").document(
            f"{app_name}__{user_id}"
        )

    def _session_ref(self, app_name: str, user_id: str, session_id: str) -> Any:
        """Return the document reference of a session."""
        return self._db.collection(f"{self._prefix}_sessions").document(
            f"{app_name}__{user_id}__{session_id}"
        )

    def _events_col(self, app_name: str, user_id: str, session_id: str) -> Any:
        """Return the ``events`` sub-collection of a session document."""
        return self._session_ref(app_name, user_id, session_id).collection("events")

    # ------------------------------------------------------------------
    # State helpers
    # ------------------------------------------------------------------

    async def _get_app_state(self, app_name: str) -> dict[str, Any]:
        """Fetch app-scoped state; ``{}`` when the document is absent or empty."""
        snap = await self._app_state_ref(app_name).get()
        return (snap.to_dict() or {}) if snap.exists else {}

    async def _get_user_state(self, app_name: str, user_id: str) -> dict[str, Any]:
        """Fetch user-scoped state; ``{}`` when the document is absent or empty."""
        snap = await self._user_state_ref(app_name, user_id).get()
        return (snap.to_dict() or {}) if snap.exists else {}

    @staticmethod
    def _merge_state(
        app_state: dict[str, Any],
        user_state: dict[str, Any],
        session_state: dict[str, Any],
    ) -> dict[str, Any]:
        """Combine the three state scopes into a single session-visible dict.

        Session keys are copied as-is; app and user keys are re-prefixed with
        ``State.APP_PREFIX`` / ``State.USER_PREFIX``. The inputs are not mutated.
        """
        merged = dict(session_state)
        for key, value in app_state.items():
            merged[State.APP_PREFIX + key] = value
        for key, value in user_state.items():
            merged[State.USER_PREFIX + key] = value
        return merged

    @staticmethod
    def _summary_context_events(app_name: str, summary: str) -> list[Event]:
        """Build the synthetic user/model exchange that replays a compaction summary.

        The two events use timestamps ``0.0`` and ``0.001`` so they precede the
        real events (assuming those carry wall-clock timestamps).
        """
        summary_event = Event(
            id="summary-context",
            author="user",
            content=Content(
                role="user",
                parts=[
                    Part(
                        text=(
                            "[Conversation context from previous"
                            " messages]\n"
                            f"{summary}"
                        )
                    )
                ],
            ),
            timestamp=0.0,
            invocation_id="compaction-summary",
        )
        ack_event = Event(
            id="summary-ack",
            author=app_name,
            content=Content(
                role="model",
                parts=[
                    Part(
                        text=(
                            "Understood, I have the context from our"
                            " previous conversation and will continue"
                            " accordingly."
                        )
                    )
                ],
            ),
            timestamp=0.001,
            invocation_id="compaction-summary",
        )
        return [summary_event, ack_event]

    async def close(self) -> None:
        """Await all in-flight compaction tasks. Call before shutdown.

        Failures raised by background compaction tasks are logged rather than
        propagated so that shutdown always completes.
        """
        if not self._active_tasks:
            return
        # Snapshot: done-callbacks discard from the set while we wait.
        tasks = list(self._active_tasks)
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            if isinstance(result, BaseException):
                logger.error(
                    "Background compaction task failed during close",
                    exc_info=result,
                )

    # ------------------------------------------------------------------
    # BaseSessionService implementation
    # ------------------------------------------------------------------

    @override
    async def create_session(
        self,
        *,
        app_name: str,
        user_id: str,
        state: dict[str, Any] | None = None,
        session_id: str | None = None,
    ) -> Session:
        """Create and persist a new session.

        Args:
            app_name: Application the session belongs to.
            user_id: Owner of the session.
            state: Optional initial state; prefixed keys are routed to the
                app/user state documents, the rest stay on the session.
            session_id: Optional explicit id (surrounding whitespace is
                stripped); a UUID4 is generated when absent or blank.

        Returns:
            The new session with app/user/session state merged.

        Raises:
            AlreadyExistsError: If ``session_id`` is given and already exists.
        """
        if session_id and session_id.strip():
            session_id = session_id.strip()
            # NOTE(review): this existence check and the write below are not
            # one atomic operation; concurrent creates with the same id can
            # both pass the check, and the later write wins.
            existing = await self._session_ref(app_name, user_id, session_id).get()
            if existing.exists:
                msg = f"Session with id {session_id} already exists."
                raise AlreadyExistsError(msg)
        else:
            session_id = str(uuid.uuid4())

        state_deltas = _session_util.extract_state_delta(state)  # type: ignore[attr-defined]
        app_state_delta = state_deltas["app"]
        user_state_delta = state_deltas["user"]
        session_state = state_deltas["session"] or {}

        now = time.time()
        # A single batch commits the scoped-state writes and the session
        # document together, so a failure leaves no partial writes behind.
        batch = self._db.batch()
        if app_state_delta:
            batch.set(self._app_state_ref(app_name), app_state_delta, merge=True)
        if user_state_delta:
            batch.set(
                self._user_state_ref(app_name, user_id), user_state_delta, merge=True
            )
        batch.set(
            self._session_ref(app_name, user_id, session_id),
            {
                "app_name": app_name,
                "user_id": user_id,
                "session_id": session_id,
                "state": session_state,
                "last_update_time": now,
            },
        )
        await batch.commit()

        # Re-read scoped state: the stored documents may hold keys beyond the
        # deltas just written.
        app_state, user_state = await asyncio.gather(
            self._get_app_state(app_name),
            self._get_user_state(app_name, user_id),
        )
        merged = self._merge_state(app_state, user_state, session_state)
        return Session(
            app_name=app_name,
            user_id=user_id,
            id=session_id,
            state=merged,
            last_update_time=now,
        )

    @override
    async def get_session(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        config: GetSessionConfig | None = None,
    ) -> Session | None:
        """Load a session with its events and merged state.

        Events are ordered by ``timestamp``. ``config.after_timestamp`` keeps
        events with ``timestamp >= after_timestamp``, and
        ``config.num_recent_events`` keeps only the last N. If the session
        document carries a ``conversation_summary``, a synthetic user/model
        exchange containing it is prepended to the events.

        Returns:
            The session, or ``None`` if the session document does not exist.
        """
        snap = await self._session_ref(app_name, user_id, session_id).get()
        if not snap.exists:
            return None
        session_data = snap.to_dict() or {}

        query = self._events_col(app_name, user_id, session_id)
        if config and config.after_timestamp:
            query = query.where(
                filter=FieldFilter("timestamp", ">=", config.after_timestamp)
            )
        query = query.order_by("timestamp")

        event_docs, app_state, user_state = await asyncio.gather(
            query.get(),
            self._get_app_state(app_name),
            self._get_user_state(app_name, user_id),
        )

        events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
        if config and config.num_recent_events:
            events = events[-config.num_recent_events :]

        # Summary of compacted history (presumably written by SessionCompactor).
        conversation_summary = session_data.get("conversation_summary")
        if conversation_summary:
            events = [
                *self._summary_context_events(app_name, conversation_summary),
                *events,
            ]

        merged = self._merge_state(app_state, user_state, session_data.get("state", {}))
        return Session(
            app_name=app_name,
            user_id=user_id,
            id=session_id,
            state=merged,
            events=events,
            last_update_time=session_data.get("last_update_time", 0.0),
        )

    @override
    async def list_sessions(
        self, *, app_name: str, user_id: str | None = None
    ) -> ListSessionsResponse:
        """List sessions of an app, optionally restricted to one user.

        Returned sessions carry merged state but no events.
        """
        query = self._db.collection(f"{self._prefix}_sessions").where(
            filter=FieldFilter("app_name", "==", app_name)
        )
        if user_id is not None:
            query = query.where(filter=FieldFilter("user_id", "==", user_id))

        docs = await query.get()
        if not docs:
            return ListSessionsResponse()

        doc_dicts: list[dict[str, Any]] = [doc.to_dict() or {} for doc in docs]

        # Fetch app state and every distinct user's state once, in parallel.
        unique_user_ids = list({d["user_id"] for d in doc_dicts})
        app_state, *user_states = await asyncio.gather(
            self._get_app_state(app_name),
            *(self._get_user_state(app_name, uid) for uid in unique_user_ids),
        )
        user_state_cache = dict(zip(unique_user_ids, user_states, strict=False))

        sessions: list[Session] = []
        for data in doc_dicts:
            s_user_id = data["user_id"]
            merged = self._merge_state(
                app_state,
                user_state_cache[s_user_id],
                data.get("state", {}),
            )
            sessions.append(
                Session(
                    app_name=app_name,
                    user_id=s_user_id,
                    id=data["session_id"],
                    state=merged,
                    events=[],
                    last_update_time=data.get("last_update_time", 0.0),
                )
            )
        return ListSessionsResponse(sessions=sessions)

    @override
    async def delete_session(
        self, *, app_name: str, user_id: str, session_id: str
    ) -> None:
        """Delete a session document together with its sub-collections."""
        ref = self._session_ref(app_name, user_id, session_id)
        await self._db.recursive_delete(ref)

    @override
    async def append_event(self, session: Session, event: Event) -> Event:
        """Append an event to ``session`` and persist it with its state deltas.

        Partial events are returned untouched. Otherwise the base class updates
        the in-memory session, and then the event document, the app/user state
        deltas and the session's state/``last_update_time`` are committed in a
        single Firestore batch, so they are applied together or not at all.
        If compaction is enabled and the event's total token count reaches the
        threshold, a background compaction task is started.

        Returns:
            The event as processed by the base class.
        """
        if event.partial:
            return event

        t0 = time.monotonic()
        app_name = session.app_name
        user_id = session.user_id
        session_id = session.id

        # Base class: strips temp state, applies delta to in-memory session,
        # appends event to session.events
        event = await super().append_event(session=session, event=event)
        session.last_update_time = event.timestamp

        session_ref = self._session_ref(app_name, user_id, session_id)
        events_ref = self._events_col(app_name, user_id, session_id)

        batch = self._db.batch()
        batch.set(
            events_ref.document(event.id),
            event.model_dump(mode="json", exclude_none=True),
        )

        session_updates: dict[str, Any] = {"last_update_time": event.timestamp}
        if event.actions and event.actions.state_delta:
            state_deltas = _session_util.extract_state_delta(event.actions.state_delta)
            if state_deltas["app"]:
                batch.set(
                    self._app_state_ref(app_name), state_deltas["app"], merge=True
                )
            if state_deltas["user"]:
                batch.set(
                    self._user_state_ref(app_name, user_id),
                    state_deltas["user"],
                    merge=True,
                )
            # FieldPath quoting keeps keys containing dots or other special
            # characters from being interpreted as nested paths.
            for key, value in state_deltas["session"].items():
                session_updates[FieldPath("state", key).to_api_repr()] = value
        # ``update`` fails if the session document is missing; because it shares
        # the batch, the event is then not written either.
        batch.update(session_ref, session_updates)
        await batch.commit()

        if event.usage_metadata:
            meta = event.usage_metadata
            logger.info(
                "Token usage for session %s event %s: "
                "prompt=%s, candidates=%s, total=%s",
                session_id,
                event.id,
                meta.prompt_token_count,
                meta.candidates_token_count,
                meta.total_token_count,
            )

        # Trigger compaction if total token count exceeds threshold
        if (
            self._compaction_threshold is not None
            and event.usage_metadata
            and event.usage_metadata.total_token_count
            and event.usage_metadata.total_token_count >= self._compaction_threshold
        ):
            logger.info(
                "Compaction triggered for session %s: "
                "total_token_count=%d >= threshold=%d",
                session_id,
                event.usage_metadata.total_token_count,
                self._compaction_threshold,
            )
            # NOTE(review): the live ``session`` object is handed to the
            # background task; confirm SessionCompactor tolerates concurrent
            # mutation by later append_event calls.
            task = asyncio.create_task(
                self._compactor.guarded_compact(session, events_ref, session_ref)
            )
            self._active_tasks.add(task)
            task.add_done_callback(self._active_tasks.discard)

        elapsed = time.monotonic() - t0
        logger.info(
            "append_event completed for session %s event %s in %.3fs",
            session_id,
            event.id,
            elapsed,
        )
        return event