Compare commits
1 Commits
main
...
7926d9881c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7926d9881c |
@@ -78,7 +78,9 @@ class NotificationDocument(BaseModel):
|
||||
class NotificationBackend(Protocol):
|
||||
"""Backend-agnostic interface for notification storage."""
|
||||
|
||||
async def get_recent_notifications(self, phone_number: str) -> list[Notification]:
|
||||
async def get_recent_notifications(
|
||||
self, phone_number: str
|
||||
) -> list[Notification]:
|
||||
"""Return recent notifications for *phone_number*."""
|
||||
...
|
||||
|
||||
@@ -111,7 +113,9 @@ class FirestoreNotificationBackend:
|
||||
self._max_to_notify = max_to_notify
|
||||
self._window_hours = window_hours
|
||||
|
||||
async def get_recent_notifications(self, phone_number: str) -> list[Notification]:
|
||||
async def get_recent_notifications(
|
||||
self, phone_number: str
|
||||
) -> list[Notification]:
|
||||
"""Get recent notifications for a user.
|
||||
|
||||
Retrieves notifications created within the configured time window,
|
||||
@@ -144,7 +148,9 @@ class FirestoreNotificationBackend:
|
||||
cutoff = time.time() - (self._window_hours * 3600)
|
||||
|
||||
parsed = [
|
||||
n for n in document.notificaciones if n.timestamp_creacion >= cutoff
|
||||
n
|
||||
for n in document.notificaciones
|
||||
if n.timestamp_creacion >= cutoff
|
||||
]
|
||||
|
||||
if not parsed:
|
||||
@@ -206,7 +212,9 @@ class RedisNotificationBackend:
|
||||
self._max_to_notify = max_to_notify
|
||||
self._window_hours = window_hours
|
||||
|
||||
async def get_recent_notifications(self, phone_number: str) -> list[Notification]:
|
||||
async def get_recent_notifications(
|
||||
self, phone_number: str
|
||||
) -> list[Notification]:
|
||||
"""Get recent notifications for a user from Redis.
|
||||
|
||||
Reads from the ``notification:{phone}`` key, parses the JSON
|
||||
@@ -238,7 +246,9 @@ class RedisNotificationBackend:
|
||||
cutoff = time.time() - (self._window_hours * 3600)
|
||||
|
||||
parsed = [
|
||||
n for n in document.notificaciones if n.timestamp_creacion >= cutoff
|
||||
n
|
||||
for n in document.notificaciones
|
||||
if n.timestamp_creacion >= cutoff
|
||||
]
|
||||
|
||||
if not parsed:
|
||||
|
||||
@@ -6,7 +6,6 @@ import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, override
|
||||
|
||||
from google.adk.errors.already_exists_error import AlreadyExistsError
|
||||
@@ -103,24 +102,6 @@ class FirestoreSessionService(BaseSessionService):
|
||||
def _events_col(self, app_name: str, user_id: str, session_id: str) -> Any:
|
||||
return self._session_ref(app_name, user_id, session_id).collection("events")
|
||||
|
||||
@staticmethod
|
||||
def _timestamp_to_float(value: Any, default: float = 0.0) -> float:
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
if hasattr(value, "timestamp"):
|
||||
try:
|
||||
return float(value.timestamp())
|
||||
except (
|
||||
TypeError,
|
||||
ValueError,
|
||||
OSError,
|
||||
OverflowError,
|
||||
) as exc: # pragma: no cover
|
||||
logger.debug("Failed to convert timestamp %r: %s", value, exc)
|
||||
return default
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# State helpers
|
||||
# ------------------------------------------------------------------
|
||||
@@ -190,7 +171,7 @@ class FirestoreSessionService(BaseSessionService):
|
||||
)
|
||||
)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
now = time.time()
|
||||
write_coros.append(
|
||||
self._session_ref(app_name, user_id, session_id).set(
|
||||
{
|
||||
@@ -215,7 +196,7 @@ class FirestoreSessionService(BaseSessionService):
|
||||
user_id=user_id,
|
||||
id=session_id,
|
||||
state=merged,
|
||||
last_update_time=now.timestamp(),
|
||||
last_update_time=now,
|
||||
)
|
||||
|
||||
@override
|
||||
@@ -302,9 +283,7 @@ class FirestoreSessionService(BaseSessionService):
|
||||
id=session_id,
|
||||
state=merged,
|
||||
events=events,
|
||||
last_update_time=self._timestamp_to_float(
|
||||
session_data.get("last_update_time"), 0.0
|
||||
),
|
||||
last_update_time=session_data.get("last_update_time", 0.0),
|
||||
)
|
||||
|
||||
@override
|
||||
@@ -347,9 +326,7 @@ class FirestoreSessionService(BaseSessionService):
|
||||
id=data["session_id"],
|
||||
state=merged,
|
||||
events=[],
|
||||
last_update_time=self._timestamp_to_float(
|
||||
data.get("last_update_time"), 0.0
|
||||
),
|
||||
last_update_time=data.get("last_update_time", 0.0),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -389,8 +366,6 @@ class FirestoreSessionService(BaseSessionService):
|
||||
# Persist state deltas
|
||||
session_ref = self._session_ref(app_name, user_id, session_id)
|
||||
|
||||
last_update_dt = datetime.fromtimestamp(event.timestamp, UTC)
|
||||
|
||||
if event.actions and event.actions.state_delta:
|
||||
state_deltas = _session_util.extract_state_delta(event.actions.state_delta)
|
||||
|
||||
@@ -411,16 +386,16 @@ class FirestoreSessionService(BaseSessionService):
|
||||
FieldPath("state", k).to_api_repr(): v
|
||||
for k, v in state_deltas["session"].items()
|
||||
}
|
||||
field_updates["last_update_time"] = last_update_dt
|
||||
field_updates["last_update_time"] = event.timestamp
|
||||
write_coros.append(session_ref.update(field_updates))
|
||||
else:
|
||||
write_coros.append(
|
||||
session_ref.update({"last_update_time": last_update_dt})
|
||||
session_ref.update({"last_update_time": event.timestamp})
|
||||
)
|
||||
|
||||
await asyncio.gather(*write_coros)
|
||||
else:
|
||||
await session_ref.update({"last_update_time": last_update_dt})
|
||||
await session_ref.update({"last_update_time": event.timestamp})
|
||||
|
||||
# Log token usage
|
||||
if event.usage_metadata:
|
||||
|
||||
Reference in New Issue
Block a user