Initial implementation

This commit is contained in:
2026-02-21 18:06:19 +00:00
commit a223f3500d
8 changed files with 2976 additions and 0 deletions

10
.gitignore vendored Normal file
View File

@@ -0,0 +1,10 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.12

0
README.md Normal file
View File

17
pyproject.toml Normal file
View File

@@ -0,0 +1,17 @@
[project]
name = "adk-firestore-sessionmanager"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
authors = [
{ name = "A8065384", email = "anibal.angulo.cardoza@banorte.com" }
]
requires-python = ">=3.12"
dependencies = [
"google-adk>=1.0.0",
"google-cloud-firestore>=2.19.0",
]
[build-system]
requires = ["uv_build>=0.9.22,<0.10.0"]
build-backend = "uv_build"

View File

@@ -0,0 +1,3 @@
from .firestore_session_service import FirestoreSessionService
__all__ = ["FirestoreSessionService"]

View File

@@ -0,0 +1,318 @@
"""Firestore-backed session service for Google ADK."""
from __future__ import annotations
import logging
import time
from typing import Any, Optional
import uuid
from google.adk.errors.already_exists_error import AlreadyExistsError
from google.adk.events.event import Event
from google.adk.sessions import _session_util
from google.adk.sessions.base_session_service import (
BaseSessionService,
GetSessionConfig,
ListSessionsResponse,
)
from google.adk.sessions.session import Session
from google.adk.sessions.state import State
from google.cloud.firestore_v1.async_client import AsyncClient
from google.cloud.firestore_v1.field_path import FieldPath
from typing_extensions import override
logger = logging.getLogger("google_adk." + __name__)
class FirestoreSessionService(BaseSessionService):
"""A Firestore-backed implementation of BaseSessionService.
Firestore document layout (given ``collection_prefix="adk"``)::
adk_app_states/{app_name}
→ app-scoped state key/values
adk_user_states/{app_name}__{user_id}
→ user-scoped state key/values
adk_sessions/{app_name}__{user_id}__{session_id}
{app_name, user_id, session_id, state: {…}, last_update_time}
└─ events/{event_id} → serialised Event
"""
def __init__(
self,
*,
db: AsyncClient,
collection_prefix: str = "adk",
) -> None:
self._db = db
self._prefix = collection_prefix
# ------------------------------------------------------------------
# Document-reference helpers
# ------------------------------------------------------------------
def _app_state_ref(self, app_name: str):
return self._db.collection(f"{self._prefix}_app_states").document(
app_name
)
def _user_state_ref(self, app_name: str, user_id: str):
return self._db.collection(f"{self._prefix}_user_states").document(
f"{app_name}__{user_id}"
)
def _session_ref(self, app_name: str, user_id: str, session_id: str):
return self._db.collection(f"{self._prefix}_sessions").document(
f"{app_name}__{user_id}__{session_id}"
)
def _events_col(self, app_name: str, user_id: str, session_id: str):
return self._session_ref(app_name, user_id, session_id).collection(
"events"
)
# ------------------------------------------------------------------
# State helpers
# ------------------------------------------------------------------
async def _get_app_state(self, app_name: str) -> dict[str, Any]:
snap = await self._app_state_ref(app_name).get()
return snap.to_dict() or {} if snap.exists else {}
async def _get_user_state(
self, app_name: str, user_id: str
) -> dict[str, Any]:
snap = await self._user_state_ref(app_name, user_id).get()
return snap.to_dict() or {} if snap.exists else {}
@staticmethod
def _merge_state(
app_state: dict[str, Any],
user_state: dict[str, Any],
session_state: dict[str, Any],
) -> dict[str, Any]:
merged = dict(session_state)
for key, value in app_state.items():
merged[State.APP_PREFIX + key] = value
for key, value in user_state.items():
merged[State.USER_PREFIX + key] = value
return merged
# ------------------------------------------------------------------
# BaseSessionService implementation
# ------------------------------------------------------------------
@override
async def create_session(
self,
*,
app_name: str,
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
) -> Session:
if session_id and session_id.strip():
session_id = session_id.strip()
existing = await self._session_ref(
app_name, user_id, session_id
).get()
if existing.exists:
raise AlreadyExistsError(
f"Session with id {session_id} already exists."
)
else:
session_id = str(uuid.uuid4())
state_deltas = _session_util.extract_state_delta(state)
app_state_delta = state_deltas["app"]
user_state_delta = state_deltas["user"]
session_state = state_deltas["session"]
if app_state_delta:
await self._app_state_ref(app_name).set(
app_state_delta, merge=True
)
if user_state_delta:
await self._user_state_ref(app_name, user_id).set(
user_state_delta, merge=True
)
now = time.time()
await self._session_ref(app_name, user_id, session_id).set(
{
"app_name": app_name,
"user_id": user_id,
"session_id": session_id,
"state": session_state or {},
"last_update_time": now,
}
)
app_state = await self._get_app_state(app_name)
user_state = await self._get_user_state(app_name, user_id)
merged = self._merge_state(app_state, user_state, session_state or {})
return Session(
app_name=app_name,
user_id=user_id,
id=session_id,
state=merged,
last_update_time=now,
)
@override
async def get_session(
self,
*,
app_name: str,
user_id: str,
session_id: str,
config: Optional[GetSessionConfig] = None,
) -> Optional[Session]:
snap = await self._session_ref(app_name, user_id, session_id).get()
if not snap.exists:
return None
session_data = snap.to_dict()
# Build events query
events_ref = self._events_col(app_name, user_id, session_id)
query = events_ref
if config and config.after_timestamp:
query = query.where("timestamp", ">=", config.after_timestamp)
query = query.order_by("timestamp")
event_docs = await query.get()
events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
if config and config.num_recent_events:
events = events[-config.num_recent_events :]
# Merge scoped state
app_state = await self._get_app_state(app_name)
user_state = await self._get_user_state(app_name, user_id)
merged = self._merge_state(
app_state, user_state, session_data.get("state", {})
)
return Session(
app_name=app_name,
user_id=user_id,
id=session_id,
state=merged,
events=events,
last_update_time=session_data.get("last_update_time", 0.0),
)
@override
async def list_sessions(
self, *, app_name: str, user_id: Optional[str] = None
) -> ListSessionsResponse:
query = self._db.collection(f"{self._prefix}_sessions").where(
"app_name", "==", app_name
)
if user_id is not None:
query = query.where("user_id", "==", user_id)
docs = await query.get()
if not docs:
return ListSessionsResponse()
# Pre-fetch app state (shared across all sessions in this app)
app_state = await self._get_app_state(app_name)
# Cache user states to avoid repeated reads
user_state_cache: dict[str, dict[str, Any]] = {}
sessions: list[Session] = []
for doc in docs:
data = doc.to_dict()
s_user_id = data["user_id"]
if s_user_id not in user_state_cache:
user_state_cache[s_user_id] = await self._get_user_state(
app_name, s_user_id
)
merged = self._merge_state(
app_state,
user_state_cache[s_user_id],
data.get("state", {}),
)
sessions.append(
Session(
app_name=app_name,
user_id=s_user_id,
id=data["session_id"],
state=merged,
events=[],
last_update_time=data.get("last_update_time", 0.0),
)
)
return ListSessionsResponse(sessions=sessions)
@override
async def delete_session(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
ref = self._session_ref(app_name, user_id, session_id)
await self._db.recursive_delete(ref)
@override
async def append_event(self, session: Session, event: Event) -> Event:
if event.partial:
return event
app_name = session.app_name
user_id = session.user_id
session_id = session.id
# Base class: strips temp state, applies delta to in-memory session,
# appends event to session.events
event = await super().append_event(session=session, event=event)
session.last_update_time = event.timestamp
# Persist event document
event_data = event.model_dump(mode="json", exclude_none=True)
await self._events_col(app_name, user_id, session_id).document(
event.id
).set(event_data)
# Persist state deltas
session_ref = self._session_ref(app_name, user_id, session_id)
if event.actions and event.actions.state_delta:
state_deltas = _session_util.extract_state_delta(
event.actions.state_delta
)
if state_deltas["app"]:
await self._app_state_ref(app_name).set(
state_deltas["app"], merge=True
)
if state_deltas["user"]:
await self._user_state_ref(app_name, user_id).set(
state_deltas["user"], merge=True
)
if state_deltas["session"]:
# Use FieldPath to safely update individual state keys
# (handles keys that contain dots)
field_updates: dict[str | FieldPath, Any] = {
FieldPath("state", k): v
for k, v in state_deltas["session"].items()
}
field_updates["last_update_time"] = event.timestamp
await session_ref.update(field_updates)
else:
await session_ref.update(
{"last_update_time": event.timestamp}
)
else:
await session_ref.update({"last_update_time": event.timestamp})
return event

2627
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff