Pool async calls
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
@@ -239,28 +240,38 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
user_state_delta = state_deltas["user"]
|
user_state_delta = state_deltas["user"]
|
||||||
session_state = state_deltas["session"]
|
session_state = state_deltas["session"]
|
||||||
|
|
||||||
|
write_coros: list = []
|
||||||
if app_state_delta:
|
if app_state_delta:
|
||||||
await self._app_state_ref(app_name).set(
|
write_coros.append(
|
||||||
app_state_delta, merge=True
|
self._app_state_ref(app_name).set(
|
||||||
|
app_state_delta, merge=True
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if user_state_delta:
|
if user_state_delta:
|
||||||
await self._user_state_ref(app_name, user_id).set(
|
write_coros.append(
|
||||||
user_state_delta, merge=True
|
self._user_state_ref(app_name, user_id).set(
|
||||||
|
user_state_delta, merge=True
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
now = time.time()
|
now = time.time()
|
||||||
await self._session_ref(app_name, user_id, session_id).set(
|
write_coros.append(
|
||||||
{
|
self._session_ref(app_name, user_id, session_id).set(
|
||||||
"app_name": app_name,
|
{
|
||||||
"user_id": user_id,
|
"app_name": app_name,
|
||||||
"session_id": session_id,
|
"user_id": user_id,
|
||||||
"state": session_state or {},
|
"session_id": session_id,
|
||||||
"last_update_time": now,
|
"state": session_state or {},
|
||||||
}
|
"last_update_time": now,
|
||||||
|
}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
await asyncio.gather(*write_coros)
|
||||||
|
|
||||||
app_state = await self._get_app_state(app_name)
|
app_state, user_state = await asyncio.gather(
|
||||||
user_state = await self._get_user_state(app_name, user_id)
|
self._get_app_state(app_name),
|
||||||
|
self._get_user_state(app_name, user_id),
|
||||||
|
)
|
||||||
merged = self._merge_state(app_state, user_state, session_state or {})
|
merged = self._merge_state(app_state, user_state, session_state or {})
|
||||||
|
|
||||||
return Session(
|
return Session(
|
||||||
@@ -293,7 +304,11 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
query = query.where(filter=FieldFilter("timestamp", ">=", config.after_timestamp))
|
query = query.where(filter=FieldFilter("timestamp", ">=", config.after_timestamp))
|
||||||
query = query.order_by("timestamp")
|
query = query.order_by("timestamp")
|
||||||
|
|
||||||
event_docs = await query.get()
|
event_docs, app_state, user_state = await asyncio.gather(
|
||||||
|
query.get(),
|
||||||
|
self._get_app_state(app_name),
|
||||||
|
self._get_user_state(app_name, user_id),
|
||||||
|
)
|
||||||
events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
|
events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
|
||||||
|
|
||||||
if config and config.num_recent_events:
|
if config and config.num_recent_events:
|
||||||
@@ -341,8 +356,6 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
events = [summary_event, ack_event] + events
|
events = [summary_event, ack_event] + events
|
||||||
|
|
||||||
# Merge scoped state
|
# Merge scoped state
|
||||||
app_state = await self._get_app_state(app_name)
|
|
||||||
user_state = await self._get_user_state(app_name, user_id)
|
|
||||||
merged = self._merge_state(
|
merged = self._merge_state(
|
||||||
app_state, user_state, session_data.get("state", {})
|
app_state, user_state, session_data.get("state", {})
|
||||||
)
|
)
|
||||||
@@ -370,22 +383,22 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
if not docs:
|
if not docs:
|
||||||
return ListSessionsResponse()
|
return ListSessionsResponse()
|
||||||
|
|
||||||
# Pre-fetch app state (shared across all sessions in this app)
|
doc_dicts = [doc.to_dict() for doc in docs]
|
||||||
app_state = await self._get_app_state(app_name)
|
|
||||||
|
# Pre-fetch app state and all distinct user states in parallel
|
||||||
|
unique_user_ids = list({d["user_id"] for d in doc_dicts})
|
||||||
|
app_state, *user_states = await asyncio.gather(
|
||||||
|
self._get_app_state(app_name),
|
||||||
|
*(
|
||||||
|
self._get_user_state(app_name, uid)
|
||||||
|
for uid in unique_user_ids
|
||||||
|
),
|
||||||
|
)
|
||||||
|
user_state_cache = dict(zip(unique_user_ids, user_states))
|
||||||
|
|
||||||
# Cache user states to avoid repeated reads
|
|
||||||
user_state_cache: dict[str, dict[str, Any]] = {}
|
|
||||||
sessions: list[Session] = []
|
sessions: list[Session] = []
|
||||||
|
for data in doc_dicts:
|
||||||
for doc in docs:
|
|
||||||
data = doc.to_dict()
|
|
||||||
s_user_id = data["user_id"]
|
s_user_id = data["user_id"]
|
||||||
|
|
||||||
if s_user_id not in user_state_cache:
|
|
||||||
user_state_cache[s_user_id] = await self._get_user_state(
|
|
||||||
app_name, s_user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
merged = self._merge_state(
|
merged = self._merge_state(
|
||||||
app_state,
|
app_state,
|
||||||
user_state_cache[s_user_id],
|
user_state_cache[s_user_id],
|
||||||
@@ -442,13 +455,18 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
event.actions.state_delta
|
event.actions.state_delta
|
||||||
)
|
)
|
||||||
|
|
||||||
|
write_coros: list = []
|
||||||
if state_deltas["app"]:
|
if state_deltas["app"]:
|
||||||
await self._app_state_ref(app_name).set(
|
write_coros.append(
|
||||||
state_deltas["app"], merge=True
|
self._app_state_ref(app_name).set(
|
||||||
|
state_deltas["app"], merge=True
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if state_deltas["user"]:
|
if state_deltas["user"]:
|
||||||
await self._user_state_ref(app_name, user_id).set(
|
write_coros.append(
|
||||||
state_deltas["user"], merge=True
|
self._user_state_ref(app_name, user_id).set(
|
||||||
|
state_deltas["user"], merge=True
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if state_deltas["session"]:
|
if state_deltas["session"]:
|
||||||
@@ -457,11 +475,13 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
for k, v in state_deltas["session"].items()
|
for k, v in state_deltas["session"].items()
|
||||||
}
|
}
|
||||||
field_updates["last_update_time"] = event.timestamp
|
field_updates["last_update_time"] = event.timestamp
|
||||||
await session_ref.update(field_updates)
|
write_coros.append(session_ref.update(field_updates))
|
||||||
else:
|
else:
|
||||||
await session_ref.update(
|
write_coros.append(
|
||||||
{"last_update_time": event.timestamp}
|
session_ref.update({"last_update_time": event.timestamp})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await asyncio.gather(*write_coros)
|
||||||
else:
|
else:
|
||||||
await session_ref.update({"last_update_time": event.timestamp})
|
await session_ref.update({"last_update_time": event.timestamp})
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user