Pool async calls

This commit is contained in:
ajac-zero
2026-02-21 22:45:12 -06:00
parent dff25bcff0
commit 5f2f0474a5

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
import logging
import time
from typing import Any, Optional
@@ -239,28 +240,38 @@ class FirestoreSessionService(BaseSessionService):
user_state_delta = state_deltas["user"]
session_state = state_deltas["session"]
write_coros: list = []
if app_state_delta:
await self._app_state_ref(app_name).set(
app_state_delta, merge=True
write_coros.append(
self._app_state_ref(app_name).set(
app_state_delta, merge=True
)
)
if user_state_delta:
await self._user_state_ref(app_name, user_id).set(
user_state_delta, merge=True
write_coros.append(
self._user_state_ref(app_name, user_id).set(
user_state_delta, merge=True
)
)
now = time.time()
await self._session_ref(app_name, user_id, session_id).set(
{
"app_name": app_name,
"user_id": user_id,
"session_id": session_id,
"state": session_state or {},
"last_update_time": now,
}
write_coros.append(
self._session_ref(app_name, user_id, session_id).set(
{
"app_name": app_name,
"user_id": user_id,
"session_id": session_id,
"state": session_state or {},
"last_update_time": now,
}
)
)
await asyncio.gather(*write_coros)
app_state = await self._get_app_state(app_name)
user_state = await self._get_user_state(app_name, user_id)
app_state, user_state = await asyncio.gather(
self._get_app_state(app_name),
self._get_user_state(app_name, user_id),
)
merged = self._merge_state(app_state, user_state, session_state or {})
return Session(
@@ -293,7 +304,11 @@ class FirestoreSessionService(BaseSessionService):
query = query.where(filter=FieldFilter("timestamp", ">=", config.after_timestamp))
query = query.order_by("timestamp")
event_docs = await query.get()
event_docs, app_state, user_state = await asyncio.gather(
query.get(),
self._get_app_state(app_name),
self._get_user_state(app_name, user_id),
)
events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
if config and config.num_recent_events:
@@ -341,8 +356,6 @@ class FirestoreSessionService(BaseSessionService):
events = [summary_event, ack_event] + events
# Merge scoped state
app_state = await self._get_app_state(app_name)
user_state = await self._get_user_state(app_name, user_id)
merged = self._merge_state(
app_state, user_state, session_data.get("state", {})
)
@@ -370,22 +383,22 @@ class FirestoreSessionService(BaseSessionService):
if not docs:
return ListSessionsResponse()
# Pre-fetch app state (shared across all sessions in this app)
app_state = await self._get_app_state(app_name)
doc_dicts = [doc.to_dict() for doc in docs]
# Pre-fetch app state and all distinct user states in parallel
unique_user_ids = list({d["user_id"] for d in doc_dicts})
app_state, *user_states = await asyncio.gather(
self._get_app_state(app_name),
*(
self._get_user_state(app_name, uid)
for uid in unique_user_ids
),
)
user_state_cache = dict(zip(unique_user_ids, user_states))
# Cache user states to avoid repeated reads
user_state_cache: dict[str, dict[str, Any]] = {}
sessions: list[Session] = []
for doc in docs:
data = doc.to_dict()
for data in doc_dicts:
s_user_id = data["user_id"]
if s_user_id not in user_state_cache:
user_state_cache[s_user_id] = await self._get_user_state(
app_name, s_user_id
)
merged = self._merge_state(
app_state,
user_state_cache[s_user_id],
@@ -442,13 +455,18 @@ class FirestoreSessionService(BaseSessionService):
event.actions.state_delta
)
write_coros: list = []
if state_deltas["app"]:
await self._app_state_ref(app_name).set(
state_deltas["app"], merge=True
write_coros.append(
self._app_state_ref(app_name).set(
state_deltas["app"], merge=True
)
)
if state_deltas["user"]:
await self._user_state_ref(app_name, user_id).set(
state_deltas["user"], merge=True
write_coros.append(
self._user_state_ref(app_name, user_id).set(
state_deltas["user"], merge=True
)
)
if state_deltas["session"]:
@@ -457,11 +475,13 @@ class FirestoreSessionService(BaseSessionService):
for k, v in state_deltas["session"].items()
}
field_updates["last_update_time"] = event.timestamp
await session_ref.update(field_updates)
write_coros.append(session_ref.update(field_updates))
else:
await session_ref.update(
{"last_update_time": event.timestamp}
write_coros.append(
session_ref.update({"last_update_time": event.timestamp})
)
await asyncio.gather(*write_coros)
else:
await session_ref.update({"last_update_time": event.timestamp})