Compare commits

1 Commits

Author SHA1 Message Date
Anibal Angulo
7926d9881c Add notification model
Some checks failed
CI / ci (pull_request) Failing after 12s
2026-03-10 23:47:11 +00:00
2 changed files with 24 additions and 41 deletions

View File

@@ -78,7 +78,9 @@ class NotificationDocument(BaseModel):
class NotificationBackend(Protocol): class NotificationBackend(Protocol):
"""Backend-agnostic interface for notification storage.""" """Backend-agnostic interface for notification storage."""
async def get_recent_notifications(self, phone_number: str) -> list[Notification]: async def get_recent_notifications(
self, phone_number: str
) -> list[Notification]:
"""Return recent notifications for *phone_number*.""" """Return recent notifications for *phone_number*."""
... ...
@@ -111,7 +113,9 @@ class FirestoreNotificationBackend:
self._max_to_notify = max_to_notify self._max_to_notify = max_to_notify
self._window_hours = window_hours self._window_hours = window_hours
async def get_recent_notifications(self, phone_number: str) -> list[Notification]: async def get_recent_notifications(
self, phone_number: str
) -> list[Notification]:
"""Get recent notifications for a user. """Get recent notifications for a user.
Retrieves notifications created within the configured time window, Retrieves notifications created within the configured time window,
@@ -144,7 +148,9 @@ class FirestoreNotificationBackend:
cutoff = time.time() - (self._window_hours * 3600) cutoff = time.time() - (self._window_hours * 3600)
parsed = [ parsed = [
n for n in document.notificaciones if n.timestamp_creacion >= cutoff n
for n in document.notificaciones
if n.timestamp_creacion >= cutoff
] ]
if not parsed: if not parsed:
@@ -206,7 +212,9 @@ class RedisNotificationBackend:
self._max_to_notify = max_to_notify self._max_to_notify = max_to_notify
self._window_hours = window_hours self._window_hours = window_hours
async def get_recent_notifications(self, phone_number: str) -> list[Notification]: async def get_recent_notifications(
self, phone_number: str
) -> list[Notification]:
"""Get recent notifications for a user from Redis. """Get recent notifications for a user from Redis.
Reads from the ``notification:{phone}`` key, parses the JSON Reads from the ``notification:{phone}`` key, parses the JSON
@@ -238,7 +246,9 @@ class RedisNotificationBackend:
cutoff = time.time() - (self._window_hours * 3600) cutoff = time.time() - (self._window_hours * 3600)
parsed = [ parsed = [
n for n in document.notificaciones if n.timestamp_creacion >= cutoff n
for n in document.notificaciones
if n.timestamp_creacion >= cutoff
] ]
if not parsed: if not parsed:

View File

@@ -6,7 +6,6 @@ import asyncio
import logging import logging
import time import time
import uuid import uuid
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, override from typing import TYPE_CHECKING, Any, override
from google.adk.errors.already_exists_error import AlreadyExistsError from google.adk.errors.already_exists_error import AlreadyExistsError
@@ -43,9 +42,8 @@ class FirestoreSessionService(BaseSessionService):
adk_user_states/{app_name}__{user_id} adk_user_states/{app_name}__{user_id}
→ user-scoped state key/values → user-scoped state key/values
adk_sessions/{app_name}__{user_id} adk_sessions/{app_name}__{user_id}__{session_id}
{app_name, user_id, session_id, state: {…}, last_update_time} {app_name, user_id, session_id, state: {…}, last_update_time}
→ Single continuous session per user (session_id is ignored)
└─ events/{event_id} → serialised Event └─ events/{event_id} → serialised Event
""" """
@@ -97,32 +95,13 @@ class FirestoreSessionService(BaseSessionService):
) )
def _session_ref(self, app_name: str, user_id: str, session_id: str) -> Any: def _session_ref(self, app_name: str, user_id: str, session_id: str) -> Any:
# Single continuous session per user: use only user_id, ignore session_id
return self._db.collection(f"{self._prefix}_sessions").document( return self._db.collection(f"{self._prefix}_sessions").document(
f"{app_name}__{user_id}" f"{app_name}__{user_id}__{session_id}"
) )
def _events_col(self, app_name: str, user_id: str, session_id: str) -> Any: def _events_col(self, app_name: str, user_id: str, session_id: str) -> Any:
return self._session_ref(app_name, user_id, session_id).collection("events") return self._session_ref(app_name, user_id, session_id).collection("events")
@staticmethod
def _timestamp_to_float(value: Any, default: float = 0.0) -> float:
if value is None:
return default
if isinstance(value, (int, float)):
return float(value)
if hasattr(value, "timestamp"):
try:
return float(value.timestamp())
except (
TypeError,
ValueError,
OSError,
OverflowError,
) as exc: # pragma: no cover
logger.debug("Failed to convert timestamp %r: %s", value, exc)
return default
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# State helpers # State helpers
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@@ -192,7 +171,7 @@ class FirestoreSessionService(BaseSessionService):
) )
) )
now = datetime.now(UTC) now = time.time()
write_coros.append( write_coros.append(
self._session_ref(app_name, user_id, session_id).set( self._session_ref(app_name, user_id, session_id).set(
{ {
@@ -217,7 +196,7 @@ class FirestoreSessionService(BaseSessionService):
user_id=user_id, user_id=user_id,
id=session_id, id=session_id,
state=merged, state=merged,
last_update_time=now.timestamp(), last_update_time=now,
) )
@override @override
@@ -304,9 +283,7 @@ class FirestoreSessionService(BaseSessionService):
id=session_id, id=session_id,
state=merged, state=merged,
events=events, events=events,
last_update_time=self._timestamp_to_float( last_update_time=session_data.get("last_update_time", 0.0),
session_data.get("last_update_time"), 0.0
),
) )
@override @override
@@ -349,9 +326,7 @@ class FirestoreSessionService(BaseSessionService):
id=data["session_id"], id=data["session_id"],
state=merged, state=merged,
events=[], events=[],
last_update_time=self._timestamp_to_float( last_update_time=data.get("last_update_time", 0.0),
data.get("last_update_time"), 0.0
),
) )
) )
@@ -391,8 +366,6 @@ class FirestoreSessionService(BaseSessionService):
# Persist state deltas # Persist state deltas
session_ref = self._session_ref(app_name, user_id, session_id) session_ref = self._session_ref(app_name, user_id, session_id)
last_update_dt = datetime.fromtimestamp(event.timestamp, UTC)
if event.actions and event.actions.state_delta: if event.actions and event.actions.state_delta:
state_deltas = _session_util.extract_state_delta(event.actions.state_delta) state_deltas = _session_util.extract_state_delta(event.actions.state_delta)
@@ -413,16 +386,16 @@ class FirestoreSessionService(BaseSessionService):
FieldPath("state", k).to_api_repr(): v FieldPath("state", k).to_api_repr(): v
for k, v in state_deltas["session"].items() for k, v in state_deltas["session"].items()
} }
field_updates["last_update_time"] = last_update_dt field_updates["last_update_time"] = event.timestamp
write_coros.append(session_ref.update(field_updates)) write_coros.append(session_ref.update(field_updates))
else: else:
write_coros.append( write_coros.append(
session_ref.update({"last_update_time": last_update_dt}) session_ref.update({"last_update_time": event.timestamp})
) )
await asyncio.gather(*write_coros) await asyncio.gather(*write_coros)
else: else:
await session_ref.update({"last_update_time": last_update_dt}) await session_ref.update({"last_update_time": event.timestamp})
# Log token usage # Log token usage
if event.usage_metadata: if event.usage_metadata: