Initial implementation
This commit is contained in:
10
.gitignore
vendored
Normal file
10
.gitignore
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
# Python-generated files
|
||||
__pycache__/
|
||||
*.py[oc]
|
||||
build/
|
||||
dist/
|
||||
wheels/
|
||||
*.egg-info
|
||||
|
||||
# Virtual environments
|
||||
.venv
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.12
|
||||
17
pyproject.toml
Normal file
17
pyproject.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[project]
|
||||
name = "adk-firestore-sessionmanager"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "A8065384", email = "anibal.angulo.cardoza@banorte.com" }
|
||||
]
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"google-adk>=1.0.0",
|
||||
"google-cloud-firestore>=2.19.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.9.22,<0.10.0"]
|
||||
build-backend = "uv_build"
|
||||
3
src/adk_firestore_sessionmanager/__init__.py
Normal file
3
src/adk_firestore_sessionmanager/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .firestore_session_service import FirestoreSessionService
|
||||
|
||||
__all__ = ["FirestoreSessionService"]
|
||||
318
src/adk_firestore_sessionmanager/firestore_session_service.py
Normal file
318
src/adk_firestore_sessionmanager/firestore_session_service.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""Firestore-backed session service for Google ADK."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
import uuid
|
||||
|
||||
from google.adk.errors.already_exists_error import AlreadyExistsError
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.sessions import _session_util
|
||||
from google.adk.sessions.base_session_service import (
|
||||
BaseSessionService,
|
||||
GetSessionConfig,
|
||||
ListSessionsResponse,
|
||||
)
|
||||
from google.adk.sessions.session import Session
|
||||
from google.adk.sessions.state import State
|
||||
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||
from google.cloud.firestore_v1.field_path import FieldPath
|
||||
from typing_extensions import override
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
|
||||
class FirestoreSessionService(BaseSessionService):
|
||||
"""A Firestore-backed implementation of BaseSessionService.
|
||||
|
||||
Firestore document layout (given ``collection_prefix="adk"``)::
|
||||
|
||||
adk_app_states/{app_name}
|
||||
→ app-scoped state key/values
|
||||
|
||||
adk_user_states/{app_name}__{user_id}
|
||||
→ user-scoped state key/values
|
||||
|
||||
adk_sessions/{app_name}__{user_id}__{session_id}
|
||||
→ {app_name, user_id, session_id, state: {…}, last_update_time}
|
||||
└─ events/{event_id} → serialised Event
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
db: AsyncClient,
|
||||
collection_prefix: str = "adk",
|
||||
) -> None:
|
||||
self._db = db
|
||||
self._prefix = collection_prefix
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Document-reference helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _app_state_ref(self, app_name: str):
|
||||
return self._db.collection(f"{self._prefix}_app_states").document(
|
||||
app_name
|
||||
)
|
||||
|
||||
def _user_state_ref(self, app_name: str, user_id: str):
|
||||
return self._db.collection(f"{self._prefix}_user_states").document(
|
||||
f"{app_name}__{user_id}"
|
||||
)
|
||||
|
||||
def _session_ref(self, app_name: str, user_id: str, session_id: str):
|
||||
return self._db.collection(f"{self._prefix}_sessions").document(
|
||||
f"{app_name}__{user_id}__{session_id}"
|
||||
)
|
||||
|
||||
def _events_col(self, app_name: str, user_id: str, session_id: str):
|
||||
return self._session_ref(app_name, user_id, session_id).collection(
|
||||
"events"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# State helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _get_app_state(self, app_name: str) -> dict[str, Any]:
|
||||
snap = await self._app_state_ref(app_name).get()
|
||||
return snap.to_dict() or {} if snap.exists else {}
|
||||
|
||||
async def _get_user_state(
|
||||
self, app_name: str, user_id: str
|
||||
) -> dict[str, Any]:
|
||||
snap = await self._user_state_ref(app_name, user_id).get()
|
||||
return snap.to_dict() or {} if snap.exists else {}
|
||||
|
||||
@staticmethod
|
||||
def _merge_state(
|
||||
app_state: dict[str, Any],
|
||||
user_state: dict[str, Any],
|
||||
session_state: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
merged = dict(session_state)
|
||||
for key, value in app_state.items():
|
||||
merged[State.APP_PREFIX + key] = value
|
||||
for key, value in user_state.items():
|
||||
merged[State.USER_PREFIX + key] = value
|
||||
return merged
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# BaseSessionService implementation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@override
|
||||
async def create_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
state: Optional[dict[str, Any]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Session:
|
||||
if session_id and session_id.strip():
|
||||
session_id = session_id.strip()
|
||||
existing = await self._session_ref(
|
||||
app_name, user_id, session_id
|
||||
).get()
|
||||
if existing.exists:
|
||||
raise AlreadyExistsError(
|
||||
f"Session with id {session_id} already exists."
|
||||
)
|
||||
else:
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
state_deltas = _session_util.extract_state_delta(state)
|
||||
app_state_delta = state_deltas["app"]
|
||||
user_state_delta = state_deltas["user"]
|
||||
session_state = state_deltas["session"]
|
||||
|
||||
if app_state_delta:
|
||||
await self._app_state_ref(app_name).set(
|
||||
app_state_delta, merge=True
|
||||
)
|
||||
if user_state_delta:
|
||||
await self._user_state_ref(app_name, user_id).set(
|
||||
user_state_delta, merge=True
|
||||
)
|
||||
|
||||
now = time.time()
|
||||
await self._session_ref(app_name, user_id, session_id).set(
|
||||
{
|
||||
"app_name": app_name,
|
||||
"user_id": user_id,
|
||||
"session_id": session_id,
|
||||
"state": session_state or {},
|
||||
"last_update_time": now,
|
||||
}
|
||||
)
|
||||
|
||||
app_state = await self._get_app_state(app_name)
|
||||
user_state = await self._get_user_state(app_name, user_id)
|
||||
merged = self._merge_state(app_state, user_state, session_state or {})
|
||||
|
||||
return Session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
id=session_id,
|
||||
state=merged,
|
||||
last_update_time=now,
|
||||
)
|
||||
|
||||
@override
|
||||
async def get_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
config: Optional[GetSessionConfig] = None,
|
||||
) -> Optional[Session]:
|
||||
snap = await self._session_ref(app_name, user_id, session_id).get()
|
||||
if not snap.exists:
|
||||
return None
|
||||
|
||||
session_data = snap.to_dict()
|
||||
|
||||
# Build events query
|
||||
events_ref = self._events_col(app_name, user_id, session_id)
|
||||
query = events_ref
|
||||
if config and config.after_timestamp:
|
||||
query = query.where("timestamp", ">=", config.after_timestamp)
|
||||
query = query.order_by("timestamp")
|
||||
|
||||
event_docs = await query.get()
|
||||
events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
|
||||
|
||||
if config and config.num_recent_events:
|
||||
events = events[-config.num_recent_events :]
|
||||
|
||||
# Merge scoped state
|
||||
app_state = await self._get_app_state(app_name)
|
||||
user_state = await self._get_user_state(app_name, user_id)
|
||||
merged = self._merge_state(
|
||||
app_state, user_state, session_data.get("state", {})
|
||||
)
|
||||
|
||||
return Session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
id=session_id,
|
||||
state=merged,
|
||||
events=events,
|
||||
last_update_time=session_data.get("last_update_time", 0.0),
|
||||
)
|
||||
|
||||
@override
|
||||
async def list_sessions(
|
||||
self, *, app_name: str, user_id: Optional[str] = None
|
||||
) -> ListSessionsResponse:
|
||||
query = self._db.collection(f"{self._prefix}_sessions").where(
|
||||
"app_name", "==", app_name
|
||||
)
|
||||
if user_id is not None:
|
||||
query = query.where("user_id", "==", user_id)
|
||||
|
||||
docs = await query.get()
|
||||
if not docs:
|
||||
return ListSessionsResponse()
|
||||
|
||||
# Pre-fetch app state (shared across all sessions in this app)
|
||||
app_state = await self._get_app_state(app_name)
|
||||
|
||||
# Cache user states to avoid repeated reads
|
||||
user_state_cache: dict[str, dict[str, Any]] = {}
|
||||
sessions: list[Session] = []
|
||||
|
||||
for doc in docs:
|
||||
data = doc.to_dict()
|
||||
s_user_id = data["user_id"]
|
||||
|
||||
if s_user_id not in user_state_cache:
|
||||
user_state_cache[s_user_id] = await self._get_user_state(
|
||||
app_name, s_user_id
|
||||
)
|
||||
|
||||
merged = self._merge_state(
|
||||
app_state,
|
||||
user_state_cache[s_user_id],
|
||||
data.get("state", {}),
|
||||
)
|
||||
|
||||
sessions.append(
|
||||
Session(
|
||||
app_name=app_name,
|
||||
user_id=s_user_id,
|
||||
id=data["session_id"],
|
||||
state=merged,
|
||||
events=[],
|
||||
last_update_time=data.get("last_update_time", 0.0),
|
||||
)
|
||||
)
|
||||
|
||||
return ListSessionsResponse(sessions=sessions)
|
||||
|
||||
@override
|
||||
async def delete_session(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> None:
|
||||
ref = self._session_ref(app_name, user_id, session_id)
|
||||
await self._db.recursive_delete(ref)
|
||||
|
||||
@override
|
||||
async def append_event(self, session: Session, event: Event) -> Event:
|
||||
if event.partial:
|
||||
return event
|
||||
|
||||
app_name = session.app_name
|
||||
user_id = session.user_id
|
||||
session_id = session.id
|
||||
|
||||
# Base class: strips temp state, applies delta to in-memory session,
|
||||
# appends event to session.events
|
||||
event = await super().append_event(session=session, event=event)
|
||||
session.last_update_time = event.timestamp
|
||||
|
||||
# Persist event document
|
||||
event_data = event.model_dump(mode="json", exclude_none=True)
|
||||
await self._events_col(app_name, user_id, session_id).document(
|
||||
event.id
|
||||
).set(event_data)
|
||||
|
||||
# Persist state deltas
|
||||
session_ref = self._session_ref(app_name, user_id, session_id)
|
||||
|
||||
if event.actions and event.actions.state_delta:
|
||||
state_deltas = _session_util.extract_state_delta(
|
||||
event.actions.state_delta
|
||||
)
|
||||
|
||||
if state_deltas["app"]:
|
||||
await self._app_state_ref(app_name).set(
|
||||
state_deltas["app"], merge=True
|
||||
)
|
||||
if state_deltas["user"]:
|
||||
await self._user_state_ref(app_name, user_id).set(
|
||||
state_deltas["user"], merge=True
|
||||
)
|
||||
|
||||
if state_deltas["session"]:
|
||||
# Use FieldPath to safely update individual state keys
|
||||
# (handles keys that contain dots)
|
||||
field_updates: dict[str | FieldPath, Any] = {
|
||||
FieldPath("state", k): v
|
||||
for k, v in state_deltas["session"].items()
|
||||
}
|
||||
field_updates["last_update_time"] = event.timestamp
|
||||
await session_ref.update(field_updates)
|
||||
else:
|
||||
await session_ref.update(
|
||||
{"last_update_time": event.timestamp}
|
||||
)
|
||||
else:
|
||||
await session_ref.update({"last_update_time": event.timestamp})
|
||||
|
||||
return event
|
||||
0
src/adk_firestore_sessionmanager/py.typed
Normal file
0
src/adk_firestore_sessionmanager/py.typed
Normal file
Reference in New Issue
Block a user