444 lines
16 KiB
Python
444 lines
16 KiB
Python
"""Firestore-backed session service for Google ADK."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from typing import TYPE_CHECKING, Any, override
|
|
|
|
from google.adk.errors.already_exists_error import AlreadyExistsError
|
|
from google.adk.events.event import Event
|
|
from google.adk.sessions import _session_util
|
|
from google.adk.sessions.base_session_service import (
|
|
BaseSessionService,
|
|
GetSessionConfig,
|
|
ListSessionsResponse,
|
|
)
|
|
from google.adk.sessions.session import Session
|
|
from google.adk.sessions.state import State
|
|
from google.cloud.firestore_v1.base_query import FieldFilter
|
|
from google.cloud.firestore_v1.field_path import FieldPath
|
|
from google.genai.types import Content, Part
|
|
|
|
from .compaction import SessionCompactor
|
|
|
|
if TYPE_CHECKING:
|
|
from google import genai
|
|
from google.cloud.firestore_v1.async_client import AsyncClient
|
|
|
|
# Module logger, namespaced under "google_adk." so ADK's logging configuration applies to it.
logger = logging.getLogger(f"google_adk.{__name__}")
|
|
|
|
|
|
class FirestoreSessionService(BaseSessionService):
    """A Firestore-backed implementation of BaseSessionService.

    Firestore document layout (given ``collection_prefix="adk"``)::

        adk_app_states/{app_name}
            → app-scoped state key/values

        adk_user_states/{app_name}__{user_id}
            → user-scoped state key/values

        adk_sessions/{app_name}__{user_id}__{session_id}
            → {app_name, user_id, session_id, state: {…}, last_update_time,
               conversation_summary (optional, read by get_session)}
            └─ events/{event_id} → serialised Event

    App- and user-scoped keys are re-prefixed with ``State.APP_PREFIX`` /
    ``State.USER_PREFIX`` when a session's state is assembled. Writes that
    touch several documents are committed as a single Firestore batch, so a
    failure cannot leave them half-applied.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        db: AsyncClient,
        collection_prefix: str = "adk",
        compaction_token_threshold: int | None = None,
        compaction_model: str = "gemini-2.5-flash",
        compaction_keep_recent: int = 10,
        genai_client: genai.Client | None = None,
    ) -> None:
        """Initialize FirestoreSessionService.

        Args:
            db: Firestore async client used for every read and write.
            collection_prefix: Prefix for the Firestore collection names.
            compaction_token_threshold: When set, appending an event whose
                ``usage_metadata.total_token_count`` is at or above this value
                schedules a background compaction. ``None`` disables it.
            compaction_model: Model name forwarded to the compactor for
                summarization.
            compaction_keep_recent: Forwarded to the compactor (number of
                recent events to keep).
            genai_client: GenAI client for compaction summaries; required when
                ``compaction_token_threshold`` is set.

        Raises:
            ValueError: If ``compaction_token_threshold`` is set without a
                ``genai_client``.
        """
        if compaction_token_threshold is not None and genai_client is None:
            msg = "genai_client is required when compaction_token_threshold is set."
            raise ValueError(msg)
        self._db = db
        self._prefix = collection_prefix
        self._compaction_threshold = compaction_token_threshold
        self._compactor = SessionCompactor(
            db=db,
            genai_client=genai_client,
            compaction_model=compaction_model,
            compaction_keep_recent=compaction_keep_recent,
        )
        # Strong references to in-flight compaction tasks: asyncio keeps only
        # weak references to tasks, and close() needs to await them.
        self._active_tasks: set[asyncio.Task[Any]] = set()

    # ------------------------------------------------------------------
    # Document-reference helpers
    # ------------------------------------------------------------------

    def _app_state_ref(self, app_name: str) -> Any:
        """Return the document holding app-scoped state for ``app_name``."""
        return self._db.collection(f"{self._prefix}_app_states").document(app_name)

    def _user_state_ref(self, app_name: str, user_id: str) -> Any:
        """Return the document holding user-scoped state for one user of an app."""
        return self._db.collection(f"{self._prefix}_user_states").document(
            f"{app_name}__{user_id}"
        )

    def _session_ref(self, app_name: str, user_id: str, session_id: str) -> Any:
        """Return the session document."""
        return self._db.collection(f"{self._prefix}_sessions").document(
            f"{app_name}__{user_id}__{session_id}"
        )

    def _events_col(self, app_name: str, user_id: str, session_id: str) -> Any:
        """Return the ``events`` subcollection under the session document."""
        return self._session_ref(app_name, user_id, session_id).collection("events")

    # ------------------------------------------------------------------
    # State helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _snapshot_to_dict(snap: Any) -> dict[str, Any]:
        """Return a snapshot's data, or ``{}`` if the document is missing or empty."""
        if not snap.exists:
            return {}
        return snap.to_dict() or {}

    async def _get_app_state(self, app_name: str) -> dict[str, Any]:
        """Return the stored (unprefixed) app-scoped state, or ``{}``."""
        snap = await self._app_state_ref(app_name).get()
        return self._snapshot_to_dict(snap)

    async def _get_user_state(self, app_name: str, user_id: str) -> dict[str, Any]:
        """Return the stored (unprefixed) user-scoped state, or ``{}``."""
        snap = await self._user_state_ref(app_name, user_id).get()
        return self._snapshot_to_dict(snap)

    @staticmethod
    def _merge_state(
        app_state: dict[str, Any],
        user_state: dict[str, Any],
        session_state: dict[str, Any],
    ) -> dict[str, Any]:
        """Combine the three state scopes into one dict.

        Session keys are copied as-is; app and user keys are re-prefixed with
        ``State.APP_PREFIX`` / ``State.USER_PREFIX``. Prefixed keys overwrite a
        session key of the same name.
        """
        merged = dict(session_state)
        for key, value in app_state.items():
            merged[State.APP_PREFIX + key] = value
        for key, value in user_state.items():
            merged[State.USER_PREFIX + key] = value
        return merged

    async def close(self) -> None:
        """Await all in-flight compaction tasks. Call before shutdown.

        Tasks scheduled while waiting are awaited too. Failures are logged
        instead of being silently discarded.
        """
        while self._active_tasks:
            pending = list(self._active_tasks)
            results = await asyncio.gather(*pending, return_exceptions=True)
            # Guarantee progress even if a done-callback has not run yet.
            self._active_tasks.difference_update(pending)
            for result in results:
                if isinstance(result, BaseException):
                    logger.warning(
                        "Background compaction task failed: %r",
                        result,
                        exc_info=result,
                    )

    # ------------------------------------------------------------------
    # BaseSessionService implementation
    # ------------------------------------------------------------------

    @override
    async def create_session(
        self,
        *,
        app_name: str,
        user_id: str,
        state: dict[str, Any] | None = None,
        session_id: str | None = None,
    ) -> Session:
        """Create and persist a new session.

        A blank or whitespace-only ``session_id`` is replaced by a random
        UUID; otherwise it is stripped. ``state`` is split by prefix into
        app, user and session scopes. The scopes are written together with
        the new session document in one atomic batch.

        Raises:
            AlreadyExistsError: If a session with ``session_id`` already exists.
                If another caller creates the same id concurrently, the batch
                commit fails with Firestore's own already-exists error rather
                than overwriting the other session.
        """
        if session_id and session_id.strip():
            session_id = session_id.strip()
            existing = await self._session_ref(app_name, user_id, session_id).get()
            if existing.exists:
                msg = f"Session with id {session_id} already exists."
                raise AlreadyExistsError(msg)
        else:
            session_id = str(uuid.uuid4())

        state_deltas = _session_util.extract_state_delta(state)  # type: ignore[attr-defined]
        app_state_delta = state_deltas["app"]
        user_state_delta = state_deltas["user"]
        session_state = state_deltas["session"] or {}

        now = time.time()
        batch = self._db.batch()
        if app_state_delta:
            batch.set(self._app_state_ref(app_name), app_state_delta, merge=True)
        if user_state_delta:
            batch.set(
                self._user_state_ref(app_name, user_id), user_state_delta, merge=True
            )
        # ``create`` (not ``set``) makes the commit fail if the document
        # appeared after the existence check above, instead of clobbering it.
        batch.create(
            self._session_ref(app_name, user_id, session_id),
            {
                "app_name": app_name,
                "user_id": user_id,
                "session_id": session_id,
                "state": session_state,
                "last_update_time": now,
            },
        )
        await batch.commit()

        # Re-read so the returned state includes app/user keys written earlier
        # by other sessions, not just this call's deltas.
        app_state, user_state = await asyncio.gather(
            self._get_app_state(app_name),
            self._get_user_state(app_name, user_id),
        )
        merged = self._merge_state(app_state, user_state, session_state)

        return Session(
            app_name=app_name,
            user_id=user_id,
            id=session_id,
            state=merged,
            last_update_time=now,
        )

    @staticmethod
    def _summary_events(conversation_summary: Any, app_name: str) -> list[Event]:
        """Build the synthetic summary + acknowledgement events for a compacted session."""
        # Small fixed timestamps, presumably so these sort ahead of real events.
        summary_event = Event(
            id="summary-context",
            author="user",
            content=Content(
                role="user",
                parts=[
                    Part(
                        text=(
                            "[Conversation context from previous"
                            " messages]\n"
                            f"{conversation_summary}"
                        )
                    )
                ],
            ),
            timestamp=0.0,
            invocation_id="compaction-summary",
        )
        ack_event = Event(
            id="summary-ack",
            author=app_name,
            content=Content(
                role="model",
                parts=[
                    Part(
                        text=(
                            "Understood, I have the context from our"
                            " previous conversation and will continue"
                            " accordingly."
                        )
                    )
                ],
            ),
            timestamp=0.001,
            invocation_id="compaction-summary",
        )
        return [summary_event, ack_event]

    @override
    async def get_session(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        config: GetSessionConfig | None = None,
    ) -> Session | None:
        """Load a session with its events and merged app/user/session state.

        ``config.after_timestamp`` keeps only events with ``timestamp >=`` that
        value. A positive ``config.num_recent_events`` keeps only that many of
        the newest events, fetched with a Firestore ``limit``. If the session
        document carries a ``conversation_summary``, two synthetic events are
        prepended; they do not count against ``num_recent_events``.

        Returns:
            The session, or ``None`` if its document does not exist.
        """
        snap = await self._session_ref(app_name, user_id, session_id).get()
        if not snap.exists:
            return None
        session_data: dict[str, Any] = snap.to_dict() or {}

        query: Any = self._events_col(app_name, user_id, session_id)
        if config and config.after_timestamp:
            query = query.where(
                filter=FieldFilter("timestamp", ">=", config.after_timestamp)
            )

        # Only a positive count limits the result; None/0 mean "all events".
        recent = (
            config.num_recent_events
            if config and config.num_recent_events and config.num_recent_events > 0
            else None
        )
        if recent is not None:
            # Fetch just the newest ``recent`` events instead of reading (and
            # paying for) the whole subcollection. The newest-first result is
            # reversed back to chronological order below.
            query = query.order_by("timestamp", direction="DESCENDING").limit(recent)
        else:
            query = query.order_by("timestamp")

        event_docs, app_state, user_state = await asyncio.gather(
            query.get(),
            self._get_app_state(app_name),
            self._get_user_state(app_name, user_id),
        )
        events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
        if recent is not None:
            events.reverse()

        conversation_summary = session_data.get("conversation_summary")
        if conversation_summary:
            events = [*self._summary_events(conversation_summary, app_name), *events]

        merged = self._merge_state(app_state, user_state, session_data.get("state", {}))

        return Session(
            app_name=app_name,
            user_id=user_id,
            id=session_id,
            state=merged,
            events=events,
            last_update_time=session_data.get("last_update_time", 0.0),
        )

    @override
    async def list_sessions(
        self, *, app_name: str, user_id: str | None = None
    ) -> ListSessionsResponse:
        """List sessions of an app, optionally restricted to one user.

        Returned sessions carry merged state but no events.
        """
        query = self._db.collection(f"{self._prefix}_sessions").where(
            filter=FieldFilter("app_name", "==", app_name)
        )
        if user_id is not None:
            query = query.where(filter=FieldFilter("user_id", "==", user_id))

        docs = await query.get()
        if not docs:
            return ListSessionsResponse()

        doc_dicts: list[dict[str, Any]] = [doc.to_dict() or {} for doc in docs]

        # One read for the app state plus one per distinct user, all in parallel.
        unique_user_ids = list({d["user_id"] for d in doc_dicts})
        app_state, *user_states = await asyncio.gather(
            self._get_app_state(app_name),
            *(self._get_user_state(app_name, uid) for uid in unique_user_ids),
        )
        user_state_cache = dict(zip(unique_user_ids, user_states, strict=True))

        sessions: list[Session] = []
        for data in doc_dicts:
            s_user_id = data["user_id"]
            sessions.append(
                Session(
                    app_name=app_name,
                    user_id=s_user_id,
                    id=data["session_id"],
                    state=self._merge_state(
                        app_state,
                        user_state_cache[s_user_id],
                        data.get("state", {}),
                    ),
                    events=[],
                    last_update_time=data.get("last_update_time", 0.0),
                )
            )

        return ListSessionsResponse(sessions=sessions)

    @override
    async def delete_session(
        self, *, app_name: str, user_id: str, session_id: str
    ) -> None:
        """Delete the session document together with its subcollections (events)."""
        ref = self._session_ref(app_name, user_id, session_id)
        await self._db.recursive_delete(ref)

    def _log_token_usage(self, session_id: str, event: Event) -> None:
        """Log the event's token counts, if it reports usage metadata."""
        meta = event.usage_metadata
        if not meta:
            return
        logger.info(
            "Token usage for session %s event %s: "
            "prompt=%s, candidates=%s, total=%s",
            session_id,
            event.id,
            meta.prompt_token_count,
            meta.candidates_token_count,
            meta.total_token_count,
        )

    def _maybe_schedule_compaction(
        self, session: Session, event: Event, session_ref: Any
    ) -> None:
        """Start a background compaction if ``event`` reached the token threshold."""
        threshold = self._compaction_threshold
        usage = event.usage_metadata
        total = usage.total_token_count if usage else None
        if threshold is None or not total or total < threshold:
            return

        logger.info(
            "Compaction triggered for session %s: "
            "total_token_count=%d >= threshold=%d",
            session.id,
            total,
            threshold,
        )
        events_ref = self._events_col(session.app_name, session.user_id, session.id)
        task = asyncio.create_task(
            self._compactor.guarded_compact(session, events_ref, session_ref)
        )
        self._active_tasks.add(task)
        task.add_done_callback(self._active_tasks.discard)

    @override
    async def append_event(self, session: Session, event: Event) -> Event:
        """Append ``event`` to ``session`` and persist it.

        Partial events are returned untouched and never stored. Otherwise the
        base class updates the in-memory session first. Then the event
        document, any app/user/session state deltas and the session's
        ``last_update_time`` are committed in a single atomic batch. Finally,
        a background compaction may be scheduled (see :meth:`close`).

        Returns:
            The event as returned by the base class (temp state removed).
        """
        if event.partial:
            return event

        t0 = time.monotonic()

        app_name = session.app_name
        user_id = session.user_id
        session_id = session.id

        # Base class: strips temp state, applies delta to in-memory session,
        # appends event to session.events
        event = await super().append_event(session=session, event=event)
        session.last_update_time = event.timestamp

        session_ref = self._session_ref(app_name, user_id, session_id)
        batch = self._db.batch()

        batch.set(
            self._events_col(app_name, user_id, session_id).document(event.id),
            event.model_dump(mode="json", exclude_none=True),
        )

        # Session-scoped keys are updated individually via field paths so that
        # keys absent from the delta are left untouched.
        session_updates: dict[str, Any] = {"last_update_time": event.timestamp}
        if event.actions and event.actions.state_delta:
            state_deltas = _session_util.extract_state_delta(event.actions.state_delta)
            if state_deltas["app"]:
                batch.set(
                    self._app_state_ref(app_name), state_deltas["app"], merge=True
                )
            if state_deltas["user"]:
                batch.set(
                    self._user_state_ref(app_name, user_id),
                    state_deltas["user"],
                    merge=True,
                )
            for key, value in state_deltas["session"].items():
                session_updates[FieldPath("state", key).to_api_repr()] = value
        batch.update(session_ref, session_updates)

        # Event, state and timestamp land together or not at all.
        await batch.commit()

        self._log_token_usage(session_id, event)
        self._maybe_schedule_compaction(session, event, session_ref)

        elapsed = time.monotonic() - t0
        logger.info(
            "append_event completed for session %s event %s in %.3fs",
            session_id,
            event.id,
            elapsed,
        )

        return event
|