Pool async calls
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
@@ -239,28 +240,38 @@ class FirestoreSessionService(BaseSessionService):
|
||||
user_state_delta = state_deltas["user"]
|
||||
session_state = state_deltas["session"]
|
||||
|
||||
write_coros: list = []
|
||||
if app_state_delta:
|
||||
await self._app_state_ref(app_name).set(
|
||||
app_state_delta, merge=True
|
||||
write_coros.append(
|
||||
self._app_state_ref(app_name).set(
|
||||
app_state_delta, merge=True
|
||||
)
|
||||
)
|
||||
if user_state_delta:
|
||||
await self._user_state_ref(app_name, user_id).set(
|
||||
user_state_delta, merge=True
|
||||
write_coros.append(
|
||||
self._user_state_ref(app_name, user_id).set(
|
||||
user_state_delta, merge=True
|
||||
)
|
||||
)
|
||||
|
||||
now = time.time()
|
||||
await self._session_ref(app_name, user_id, session_id).set(
|
||||
{
|
||||
"app_name": app_name,
|
||||
"user_id": user_id,
|
||||
"session_id": session_id,
|
||||
"state": session_state or {},
|
||||
"last_update_time": now,
|
||||
}
|
||||
write_coros.append(
|
||||
self._session_ref(app_name, user_id, session_id).set(
|
||||
{
|
||||
"app_name": app_name,
|
||||
"user_id": user_id,
|
||||
"session_id": session_id,
|
||||
"state": session_state or {},
|
||||
"last_update_time": now,
|
||||
}
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*write_coros)
|
||||
|
||||
app_state = await self._get_app_state(app_name)
|
||||
user_state = await self._get_user_state(app_name, user_id)
|
||||
app_state, user_state = await asyncio.gather(
|
||||
self._get_app_state(app_name),
|
||||
self._get_user_state(app_name, user_id),
|
||||
)
|
||||
merged = self._merge_state(app_state, user_state, session_state or {})
|
||||
|
||||
return Session(
|
||||
@@ -293,7 +304,11 @@ class FirestoreSessionService(BaseSessionService):
|
||||
query = query.where(filter=FieldFilter("timestamp", ">=", config.after_timestamp))
|
||||
query = query.order_by("timestamp")
|
||||
|
||||
event_docs = await query.get()
|
||||
event_docs, app_state, user_state = await asyncio.gather(
|
||||
query.get(),
|
||||
self._get_app_state(app_name),
|
||||
self._get_user_state(app_name, user_id),
|
||||
)
|
||||
events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
|
||||
|
||||
if config and config.num_recent_events:
|
||||
@@ -341,8 +356,6 @@ class FirestoreSessionService(BaseSessionService):
|
||||
events = [summary_event, ack_event] + events
|
||||
|
||||
# Merge scoped state
|
||||
app_state = await self._get_app_state(app_name)
|
||||
user_state = await self._get_user_state(app_name, user_id)
|
||||
merged = self._merge_state(
|
||||
app_state, user_state, session_data.get("state", {})
|
||||
)
|
||||
@@ -370,22 +383,22 @@ class FirestoreSessionService(BaseSessionService):
|
||||
if not docs:
|
||||
return ListSessionsResponse()
|
||||
|
||||
# Pre-fetch app state (shared across all sessions in this app)
|
||||
app_state = await self._get_app_state(app_name)
|
||||
doc_dicts = [doc.to_dict() for doc in docs]
|
||||
|
||||
# Pre-fetch app state and all distinct user states in parallel
|
||||
unique_user_ids = list({d["user_id"] for d in doc_dicts})
|
||||
app_state, *user_states = await asyncio.gather(
|
||||
self._get_app_state(app_name),
|
||||
*(
|
||||
self._get_user_state(app_name, uid)
|
||||
for uid in unique_user_ids
|
||||
),
|
||||
)
|
||||
user_state_cache = dict(zip(unique_user_ids, user_states))
|
||||
|
||||
# Cache user states to avoid repeated reads
|
||||
user_state_cache: dict[str, dict[str, Any]] = {}
|
||||
sessions: list[Session] = []
|
||||
|
||||
for doc in docs:
|
||||
data = doc.to_dict()
|
||||
for data in doc_dicts:
|
||||
s_user_id = data["user_id"]
|
||||
|
||||
if s_user_id not in user_state_cache:
|
||||
user_state_cache[s_user_id] = await self._get_user_state(
|
||||
app_name, s_user_id
|
||||
)
|
||||
|
||||
merged = self._merge_state(
|
||||
app_state,
|
||||
user_state_cache[s_user_id],
|
||||
@@ -442,13 +455,18 @@ class FirestoreSessionService(BaseSessionService):
|
||||
event.actions.state_delta
|
||||
)
|
||||
|
||||
write_coros: list = []
|
||||
if state_deltas["app"]:
|
||||
await self._app_state_ref(app_name).set(
|
||||
state_deltas["app"], merge=True
|
||||
write_coros.append(
|
||||
self._app_state_ref(app_name).set(
|
||||
state_deltas["app"], merge=True
|
||||
)
|
||||
)
|
||||
if state_deltas["user"]:
|
||||
await self._user_state_ref(app_name, user_id).set(
|
||||
state_deltas["user"], merge=True
|
||||
write_coros.append(
|
||||
self._user_state_ref(app_name, user_id).set(
|
||||
state_deltas["user"], merge=True
|
||||
)
|
||||
)
|
||||
|
||||
if state_deltas["session"]:
|
||||
@@ -457,11 +475,13 @@ class FirestoreSessionService(BaseSessionService):
|
||||
for k, v in state_deltas["session"].items()
|
||||
}
|
||||
field_updates["last_update_time"] = event.timestamp
|
||||
await session_ref.update(field_updates)
|
||||
write_coros.append(session_ref.update(field_updates))
|
||||
else:
|
||||
await session_ref.update(
|
||||
{"last_update_time": event.timestamp}
|
||||
write_coros.append(
|
||||
session_ref.update({"last_update_time": event.timestamp})
|
||||
)
|
||||
|
||||
await asyncio.gather(*write_coros)
|
||||
else:
|
||||
await session_ref.update({"last_update_time": event.timestamp})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user