Initial implementation
This commit is contained in:
10
.gitignore
vendored
Normal file
10
.gitignore
vendored
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# Python-generated files
|
||||||
|
__pycache__/
|
||||||
|
*.py[oc]
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
wheels/
|
||||||
|
*.egg-info
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
.venv
|
||||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
|||||||
|
3.12
|
||||||
17
pyproject.toml
Normal file
17
pyproject.toml
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
[project]
|
||||||
|
name = "adk-firestore-sessionmanager"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
readme = "README.md"
|
||||||
|
authors = [
|
||||||
|
{ name = "A8065384", email = "anibal.angulo.cardoza@banorte.com" }
|
||||||
|
]
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = [
|
||||||
|
"google-adk>=1.0.0",
|
||||||
|
"google-cloud-firestore>=2.19.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["uv_build>=0.9.22,<0.10.0"]
|
||||||
|
build-backend = "uv_build"
|
||||||
3
src/adk_firestore_sessionmanager/__init__.py
Normal file
3
src/adk_firestore_sessionmanager/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .firestore_session_service import FirestoreSessionService
|
||||||
|
|
||||||
|
__all__ = ["FirestoreSessionService"]
|
||||||
318
src/adk_firestore_sessionmanager/firestore_session_service.py
Normal file
318
src/adk_firestore_sessionmanager/firestore_session_service.py
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
"""Firestore-backed session service for Google ADK."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Optional
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from google.adk.errors.already_exists_error import AlreadyExistsError
|
||||||
|
from google.adk.events.event import Event
|
||||||
|
from google.adk.sessions import _session_util
|
||||||
|
from google.adk.sessions.base_session_service import (
|
||||||
|
BaseSessionService,
|
||||||
|
GetSessionConfig,
|
||||||
|
ListSessionsResponse,
|
||||||
|
)
|
||||||
|
from google.adk.sessions.session import Session
|
||||||
|
from google.adk.sessions.state import State
|
||||||
|
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||||
|
from google.cloud.firestore_v1.field_path import FieldPath
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
logger = logging.getLogger("google_adk." + __name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FirestoreSessionService(BaseSessionService):
|
||||||
|
"""A Firestore-backed implementation of BaseSessionService.
|
||||||
|
|
||||||
|
Firestore document layout (given ``collection_prefix="adk"``)::
|
||||||
|
|
||||||
|
adk_app_states/{app_name}
|
||||||
|
→ app-scoped state key/values
|
||||||
|
|
||||||
|
adk_user_states/{app_name}__{user_id}
|
||||||
|
→ user-scoped state key/values
|
||||||
|
|
||||||
|
adk_sessions/{app_name}__{user_id}__{session_id}
|
||||||
|
→ {app_name, user_id, session_id, state: {…}, last_update_time}
|
||||||
|
└─ events/{event_id} → serialised Event
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: AsyncClient,
|
||||||
|
collection_prefix: str = "adk",
|
||||||
|
) -> None:
|
||||||
|
self._db = db
|
||||||
|
self._prefix = collection_prefix
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Document-reference helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _app_state_ref(self, app_name: str):
|
||||||
|
return self._db.collection(f"{self._prefix}_app_states").document(
|
||||||
|
app_name
|
||||||
|
)
|
||||||
|
|
||||||
|
def _user_state_ref(self, app_name: str, user_id: str):
|
||||||
|
return self._db.collection(f"{self._prefix}_user_states").document(
|
||||||
|
f"{app_name}__{user_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _session_ref(self, app_name: str, user_id: str, session_id: str):
|
||||||
|
return self._db.collection(f"{self._prefix}_sessions").document(
|
||||||
|
f"{app_name}__{user_id}__{session_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _events_col(self, app_name: str, user_id: str, session_id: str):
|
||||||
|
return self._session_ref(app_name, user_id, session_id).collection(
|
||||||
|
"events"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# State helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _get_app_state(self, app_name: str) -> dict[str, Any]:
|
||||||
|
snap = await self._app_state_ref(app_name).get()
|
||||||
|
return snap.to_dict() or {} if snap.exists else {}
|
||||||
|
|
||||||
|
async def _get_user_state(
|
||||||
|
self, app_name: str, user_id: str
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
snap = await self._user_state_ref(app_name, user_id).get()
|
||||||
|
return snap.to_dict() or {} if snap.exists else {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_state(
|
||||||
|
app_state: dict[str, Any],
|
||||||
|
user_state: dict[str, Any],
|
||||||
|
session_state: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
merged = dict(session_state)
|
||||||
|
for key, value in app_state.items():
|
||||||
|
merged[State.APP_PREFIX + key] = value
|
||||||
|
for key, value in user_state.items():
|
||||||
|
merged[State.USER_PREFIX + key] = value
|
||||||
|
return merged
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# BaseSessionService implementation
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def create_session(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
app_name: str,
|
||||||
|
user_id: str,
|
||||||
|
state: Optional[dict[str, Any]] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
) -> Session:
|
||||||
|
if session_id and session_id.strip():
|
||||||
|
session_id = session_id.strip()
|
||||||
|
existing = await self._session_ref(
|
||||||
|
app_name, user_id, session_id
|
||||||
|
).get()
|
||||||
|
if existing.exists:
|
||||||
|
raise AlreadyExistsError(
|
||||||
|
f"Session with id {session_id} already exists."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
state_deltas = _session_util.extract_state_delta(state)
|
||||||
|
app_state_delta = state_deltas["app"]
|
||||||
|
user_state_delta = state_deltas["user"]
|
||||||
|
session_state = state_deltas["session"]
|
||||||
|
|
||||||
|
if app_state_delta:
|
||||||
|
await self._app_state_ref(app_name).set(
|
||||||
|
app_state_delta, merge=True
|
||||||
|
)
|
||||||
|
if user_state_delta:
|
||||||
|
await self._user_state_ref(app_name, user_id).set(
|
||||||
|
user_state_delta, merge=True
|
||||||
|
)
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
await self._session_ref(app_name, user_id, session_id).set(
|
||||||
|
{
|
||||||
|
"app_name": app_name,
|
||||||
|
"user_id": user_id,
|
||||||
|
"session_id": session_id,
|
||||||
|
"state": session_state or {},
|
||||||
|
"last_update_time": now,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
app_state = await self._get_app_state(app_name)
|
||||||
|
user_state = await self._get_user_state(app_name, user_id)
|
||||||
|
merged = self._merge_state(app_state, user_state, session_state or {})
|
||||||
|
|
||||||
|
return Session(
|
||||||
|
app_name=app_name,
|
||||||
|
user_id=user_id,
|
||||||
|
id=session_id,
|
||||||
|
state=merged,
|
||||||
|
last_update_time=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def get_session(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
app_name: str,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
|
config: Optional[GetSessionConfig] = None,
|
||||||
|
) -> Optional[Session]:
|
||||||
|
snap = await self._session_ref(app_name, user_id, session_id).get()
|
||||||
|
if not snap.exists:
|
||||||
|
return None
|
||||||
|
|
||||||
|
session_data = snap.to_dict()
|
||||||
|
|
||||||
|
# Build events query
|
||||||
|
events_ref = self._events_col(app_name, user_id, session_id)
|
||||||
|
query = events_ref
|
||||||
|
if config and config.after_timestamp:
|
||||||
|
query = query.where("timestamp", ">=", config.after_timestamp)
|
||||||
|
query = query.order_by("timestamp")
|
||||||
|
|
||||||
|
event_docs = await query.get()
|
||||||
|
events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
|
||||||
|
|
||||||
|
if config and config.num_recent_events:
|
||||||
|
events = events[-config.num_recent_events :]
|
||||||
|
|
||||||
|
# Merge scoped state
|
||||||
|
app_state = await self._get_app_state(app_name)
|
||||||
|
user_state = await self._get_user_state(app_name, user_id)
|
||||||
|
merged = self._merge_state(
|
||||||
|
app_state, user_state, session_data.get("state", {})
|
||||||
|
)
|
||||||
|
|
||||||
|
return Session(
|
||||||
|
app_name=app_name,
|
||||||
|
user_id=user_id,
|
||||||
|
id=session_id,
|
||||||
|
state=merged,
|
||||||
|
events=events,
|
||||||
|
last_update_time=session_data.get("last_update_time", 0.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def list_sessions(
|
||||||
|
self, *, app_name: str, user_id: Optional[str] = None
|
||||||
|
) -> ListSessionsResponse:
|
||||||
|
query = self._db.collection(f"{self._prefix}_sessions").where(
|
||||||
|
"app_name", "==", app_name
|
||||||
|
)
|
||||||
|
if user_id is not None:
|
||||||
|
query = query.where("user_id", "==", user_id)
|
||||||
|
|
||||||
|
docs = await query.get()
|
||||||
|
if not docs:
|
||||||
|
return ListSessionsResponse()
|
||||||
|
|
||||||
|
# Pre-fetch app state (shared across all sessions in this app)
|
||||||
|
app_state = await self._get_app_state(app_name)
|
||||||
|
|
||||||
|
# Cache user states to avoid repeated reads
|
||||||
|
user_state_cache: dict[str, dict[str, Any]] = {}
|
||||||
|
sessions: list[Session] = []
|
||||||
|
|
||||||
|
for doc in docs:
|
||||||
|
data = doc.to_dict()
|
||||||
|
s_user_id = data["user_id"]
|
||||||
|
|
||||||
|
if s_user_id not in user_state_cache:
|
||||||
|
user_state_cache[s_user_id] = await self._get_user_state(
|
||||||
|
app_name, s_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
merged = self._merge_state(
|
||||||
|
app_state,
|
||||||
|
user_state_cache[s_user_id],
|
||||||
|
data.get("state", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
sessions.append(
|
||||||
|
Session(
|
||||||
|
app_name=app_name,
|
||||||
|
user_id=s_user_id,
|
||||||
|
id=data["session_id"],
|
||||||
|
state=merged,
|
||||||
|
events=[],
|
||||||
|
last_update_time=data.get("last_update_time", 0.0),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return ListSessionsResponse(sessions=sessions)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def delete_session(
|
||||||
|
self, *, app_name: str, user_id: str, session_id: str
|
||||||
|
) -> None:
|
||||||
|
ref = self._session_ref(app_name, user_id, session_id)
|
||||||
|
await self._db.recursive_delete(ref)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def append_event(self, session: Session, event: Event) -> Event:
|
||||||
|
if event.partial:
|
||||||
|
return event
|
||||||
|
|
||||||
|
app_name = session.app_name
|
||||||
|
user_id = session.user_id
|
||||||
|
session_id = session.id
|
||||||
|
|
||||||
|
# Base class: strips temp state, applies delta to in-memory session,
|
||||||
|
# appends event to session.events
|
||||||
|
event = await super().append_event(session=session, event=event)
|
||||||
|
session.last_update_time = event.timestamp
|
||||||
|
|
||||||
|
# Persist event document
|
||||||
|
event_data = event.model_dump(mode="json", exclude_none=True)
|
||||||
|
await self._events_col(app_name, user_id, session_id).document(
|
||||||
|
event.id
|
||||||
|
).set(event_data)
|
||||||
|
|
||||||
|
# Persist state deltas
|
||||||
|
session_ref = self._session_ref(app_name, user_id, session_id)
|
||||||
|
|
||||||
|
if event.actions and event.actions.state_delta:
|
||||||
|
state_deltas = _session_util.extract_state_delta(
|
||||||
|
event.actions.state_delta
|
||||||
|
)
|
||||||
|
|
||||||
|
if state_deltas["app"]:
|
||||||
|
await self._app_state_ref(app_name).set(
|
||||||
|
state_deltas["app"], merge=True
|
||||||
|
)
|
||||||
|
if state_deltas["user"]:
|
||||||
|
await self._user_state_ref(app_name, user_id).set(
|
||||||
|
state_deltas["user"], merge=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if state_deltas["session"]:
|
||||||
|
# Use FieldPath to safely update individual state keys
|
||||||
|
# (handles keys that contain dots)
|
||||||
|
field_updates: dict[str | FieldPath, Any] = {
|
||||||
|
FieldPath("state", k): v
|
||||||
|
for k, v in state_deltas["session"].items()
|
||||||
|
}
|
||||||
|
field_updates["last_update_time"] = event.timestamp
|
||||||
|
await session_ref.update(field_updates)
|
||||||
|
else:
|
||||||
|
await session_ref.update(
|
||||||
|
{"last_update_time": event.timestamp}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await session_ref.update({"last_update_time": event.timestamp})
|
||||||
|
|
||||||
|
return event
|
||||||
0
src/adk_firestore_sessionmanager/py.typed
Normal file
0
src/adk_firestore_sessionmanager/py.typed
Normal file
Reference in New Issue
Block a user