Pool async calls

This commit is contained in:
ajac-zero
2026-02-21 22:45:12 -06:00
parent dff25bcff0
commit 5f2f0474a5

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
import time import time
from typing import Any, Optional from typing import Any, Optional
@@ -239,28 +240,38 @@ class FirestoreSessionService(BaseSessionService):
user_state_delta = state_deltas["user"] user_state_delta = state_deltas["user"]
session_state = state_deltas["session"] session_state = state_deltas["session"]
write_coros: list = []
if app_state_delta: if app_state_delta:
await self._app_state_ref(app_name).set( write_coros.append(
app_state_delta, merge=True self._app_state_ref(app_name).set(
app_state_delta, merge=True
)
) )
if user_state_delta: if user_state_delta:
await self._user_state_ref(app_name, user_id).set( write_coros.append(
user_state_delta, merge=True self._user_state_ref(app_name, user_id).set(
user_state_delta, merge=True
)
) )
now = time.time() now = time.time()
await self._session_ref(app_name, user_id, session_id).set( write_coros.append(
{ self._session_ref(app_name, user_id, session_id).set(
"app_name": app_name, {
"user_id": user_id, "app_name": app_name,
"session_id": session_id, "user_id": user_id,
"state": session_state or {}, "session_id": session_id,
"last_update_time": now, "state": session_state or {},
} "last_update_time": now,
}
)
) )
await asyncio.gather(*write_coros)
app_state = await self._get_app_state(app_name) app_state, user_state = await asyncio.gather(
user_state = await self._get_user_state(app_name, user_id) self._get_app_state(app_name),
self._get_user_state(app_name, user_id),
)
merged = self._merge_state(app_state, user_state, session_state or {}) merged = self._merge_state(app_state, user_state, session_state or {})
return Session( return Session(
@@ -293,7 +304,11 @@ class FirestoreSessionService(BaseSessionService):
query = query.where(filter=FieldFilter("timestamp", ">=", config.after_timestamp)) query = query.where(filter=FieldFilter("timestamp", ">=", config.after_timestamp))
query = query.order_by("timestamp") query = query.order_by("timestamp")
event_docs = await query.get() event_docs, app_state, user_state = await asyncio.gather(
query.get(),
self._get_app_state(app_name),
self._get_user_state(app_name, user_id),
)
events = [Event.model_validate(doc.to_dict()) for doc in event_docs] events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
if config and config.num_recent_events: if config and config.num_recent_events:
@@ -341,8 +356,6 @@ class FirestoreSessionService(BaseSessionService):
events = [summary_event, ack_event] + events events = [summary_event, ack_event] + events
# Merge scoped state # Merge scoped state
app_state = await self._get_app_state(app_name)
user_state = await self._get_user_state(app_name, user_id)
merged = self._merge_state( merged = self._merge_state(
app_state, user_state, session_data.get("state", {}) app_state, user_state, session_data.get("state", {})
) )
@@ -370,22 +383,22 @@ class FirestoreSessionService(BaseSessionService):
if not docs: if not docs:
return ListSessionsResponse() return ListSessionsResponse()
# Pre-fetch app state (shared across all sessions in this app) doc_dicts = [doc.to_dict() for doc in docs]
app_state = await self._get_app_state(app_name)
# Pre-fetch app state and all distinct user states in parallel
unique_user_ids = list({d["user_id"] for d in doc_dicts})
app_state, *user_states = await asyncio.gather(
self._get_app_state(app_name),
*(
self._get_user_state(app_name, uid)
for uid in unique_user_ids
),
)
user_state_cache = dict(zip(unique_user_ids, user_states))
# Cache user states to avoid repeated reads
user_state_cache: dict[str, dict[str, Any]] = {}
sessions: list[Session] = [] sessions: list[Session] = []
for data in doc_dicts:
for doc in docs:
data = doc.to_dict()
s_user_id = data["user_id"] s_user_id = data["user_id"]
if s_user_id not in user_state_cache:
user_state_cache[s_user_id] = await self._get_user_state(
app_name, s_user_id
)
merged = self._merge_state( merged = self._merge_state(
app_state, app_state,
user_state_cache[s_user_id], user_state_cache[s_user_id],
@@ -442,13 +455,18 @@ class FirestoreSessionService(BaseSessionService):
event.actions.state_delta event.actions.state_delta
) )
write_coros: list = []
if state_deltas["app"]: if state_deltas["app"]:
await self._app_state_ref(app_name).set( write_coros.append(
state_deltas["app"], merge=True self._app_state_ref(app_name).set(
state_deltas["app"], merge=True
)
) )
if state_deltas["user"]: if state_deltas["user"]:
await self._user_state_ref(app_name, user_id).set( write_coros.append(
state_deltas["user"], merge=True self._user_state_ref(app_name, user_id).set(
state_deltas["user"], merge=True
)
) )
if state_deltas["session"]: if state_deltas["session"]:
@@ -457,11 +475,13 @@ class FirestoreSessionService(BaseSessionService):
for k, v in state_deltas["session"].items() for k, v in state_deltas["session"].items()
} }
field_updates["last_update_time"] = event.timestamp field_updates["last_update_time"] = event.timestamp
await session_ref.update(field_updates) write_coros.append(session_ref.update(field_updates))
else: else:
await session_ref.update( write_coros.append(
{"last_update_time": event.timestamp} session_ref.update({"last_update_time": event.timestamp})
) )
await asyncio.gather(*write_coros)
else: else:
await session_ref.update({"last_update_time": event.timestamp}) await session_ref.update({"last_update_time": event.timestamp})