diff --git a/src/adk_firestore_sessionmanager/firestore_session_service.py b/src/adk_firestore_sessionmanager/firestore_session_service.py index 07f61d2..b9102ce 100644 --- a/src/adk_firestore_sessionmanager/firestore_session_service.py +++ b/src/adk_firestore_sessionmanager/firestore_session_service.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import logging import time from typing import Any, Optional @@ -239,28 +240,38 @@ class FirestoreSessionService(BaseSessionService): user_state_delta = state_deltas["user"] session_state = state_deltas["session"] + write_coros: list = [] if app_state_delta: - await self._app_state_ref(app_name).set( - app_state_delta, merge=True + write_coros.append( + self._app_state_ref(app_name).set( + app_state_delta, merge=True + ) ) if user_state_delta: - await self._user_state_ref(app_name, user_id).set( - user_state_delta, merge=True + write_coros.append( + self._user_state_ref(app_name, user_id).set( + user_state_delta, merge=True + ) ) now = time.time() - await self._session_ref(app_name, user_id, session_id).set( - { - "app_name": app_name, - "user_id": user_id, - "session_id": session_id, - "state": session_state or {}, - "last_update_time": now, - } + write_coros.append( + self._session_ref(app_name, user_id, session_id).set( + { + "app_name": app_name, + "user_id": user_id, + "session_id": session_id, + "state": session_state or {}, + "last_update_time": now, + } + ) ) + await asyncio.gather(*write_coros) - app_state = await self._get_app_state(app_name) - user_state = await self._get_user_state(app_name, user_id) + app_state, user_state = await asyncio.gather( + self._get_app_state(app_name), + self._get_user_state(app_name, user_id), + ) merged = self._merge_state(app_state, user_state, session_state or {}) return Session( @@ -293,7 +304,11 @@ class FirestoreSessionService(BaseSessionService): query = query.where(filter=FieldFilter("timestamp", ">=", config.after_timestamp)) query = query.order_by("timestamp") - event_docs = await query.get() + event_docs, app_state, user_state = await asyncio.gather( + query.get(), + self._get_app_state(app_name), + self._get_user_state(app_name, user_id), + ) events = [Event.model_validate(doc.to_dict()) for doc in event_docs] if config and config.num_recent_events: @@ -341,8 +356,6 @@ class FirestoreSessionService(BaseSessionService): events = [summary_event, ack_event] + events # Merge scoped state - app_state = await self._get_app_state(app_name) - user_state = await self._get_user_state(app_name, user_id) merged = self._merge_state( app_state, user_state, session_data.get("state", {}) ) @@ -370,22 +383,22 @@ class FirestoreSessionService(BaseSessionService): if not docs: return ListSessionsResponse() - # Pre-fetch app state (shared across all sessions in this app) - app_state = await self._get_app_state(app_name) + doc_dicts = [doc.to_dict() for doc in docs] + + # Pre-fetch app state and all distinct user states in parallel + unique_user_ids = list({d["user_id"] for d in doc_dicts}) + app_state, *user_states = await asyncio.gather( + self._get_app_state(app_name), + *( + self._get_user_state(app_name, uid) + for uid in unique_user_ids + ), + ) + user_state_cache = dict(zip(unique_user_ids, user_states)) - # Cache user states to avoid repeated reads - user_state_cache: dict[str, dict[str, Any]] = {} sessions: list[Session] = [] - - for doc in docs: - data = doc.to_dict() + for data in doc_dicts: s_user_id = data["user_id"] - - if s_user_id not in user_state_cache: - user_state_cache[s_user_id] = await self._get_user_state( - app_name, s_user_id - ) - merged = self._merge_state( app_state, user_state_cache[s_user_id], @@ -442,13 +455,18 @@ class FirestoreSessionService(BaseSessionService): event.actions.state_delta ) + write_coros: list = [] if state_deltas["app"]: - await self._app_state_ref(app_name).set( - state_deltas["app"], merge=True + write_coros.append( + self._app_state_ref(app_name).set( + state_deltas["app"], merge=True + ) ) if state_deltas["user"]: - await self._user_state_ref(app_name, user_id).set( - state_deltas["user"], merge=True + write_coros.append( + self._user_state_ref(app_name, user_id).set( + state_deltas["user"], merge=True + ) ) if state_deltas["session"]: @@ -457,11 +475,13 @@ class FirestoreSessionService(BaseSessionService): for k, v in state_deltas["session"].items() } field_updates["last_update_time"] = event.timestamp - await session_ref.update(field_updates) + write_coros.append(session_ref.update(field_updates)) else: - await session_ref.update( - {"last_update_time": event.timestamp} + write_coros.append( + session_ref.update({"last_update_time": event.timestamp}) ) + + await asyncio.gather(*write_coros) else: await session_ref.update({"last_update_time": event.timestamp})