Add compaction flow lock
This commit is contained in:
@@ -20,6 +20,7 @@ from google.adk.sessions.base_session_service import (
|
||||
from google.adk.sessions.session import Session
|
||||
from google.adk.sessions.state import State
|
||||
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||
from google.cloud.firestore_v1.async_transaction import async_transactional
|
||||
from google.cloud.firestore_v1.base_query import FieldFilter
|
||||
from google.cloud.firestore_v1.field_path import FieldPath
|
||||
from google.genai.types import Content, Part
|
||||
@@ -27,6 +28,22 @@ from typing_extensions import override
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
_COMPACTION_LOCK_TTL = 300 # seconds
|
||||
|
||||
|
||||
@async_transactional
|
||||
async def _try_claim_compaction_txn(transaction, session_ref):
|
||||
"""Atomically claim the compaction lock if it is free or stale."""
|
||||
snapshot = await session_ref.get(transaction=transaction)
|
||||
if not snapshot.exists:
|
||||
return False
|
||||
data = snapshot.to_dict() or {}
|
||||
lock_time = data.get("compaction_lock")
|
||||
if lock_time and (time.time() - lock_time) < _COMPACTION_LOCK_TTL:
|
||||
return False
|
||||
transaction.update(session_ref, {"compaction_lock": time.time()})
|
||||
return True
|
||||
|
||||
|
||||
class FirestoreSessionService(BaseSessionService):
|
||||
"""A Firestore-backed implementation of BaseSessionService.
|
||||
@@ -64,6 +81,8 @@ class FirestoreSessionService(BaseSessionService):
|
||||
self._compaction_model = compaction_model
|
||||
self._compaction_keep_recent = compaction_keep_recent
|
||||
self._genai_client = genai_client
|
||||
self._compaction_locks: dict[str, asyncio.Lock] = {}
|
||||
self._active_tasks: set[asyncio.Task] = set()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Document-reference helpers
|
||||
@@ -192,6 +211,12 @@ class FirestoreSessionService(BaseSessionService):
|
||||
logger.exception("Compaction summary generation failed; skipping.")
|
||||
return
|
||||
|
||||
# Write summary BEFORE deleting events so a crash between the two
|
||||
# steps leaves safe duplication rather than data loss.
|
||||
await self._session_ref(app_name, user_id, session_id).update(
|
||||
{"conversation_summary": summary}
|
||||
)
|
||||
|
||||
docs_to_delete = event_docs[: -self._compaction_keep_recent]
|
||||
for i in range(0, len(docs_to_delete), 500):
|
||||
batch = self._db.batch()
|
||||
@@ -199,10 +224,6 @@ class FirestoreSessionService(BaseSessionService):
|
||||
batch.delete(doc.reference)
|
||||
await batch.commit()
|
||||
|
||||
await self._session_ref(app_name, user_id, session_id).update(
|
||||
{"conversation_summary": summary}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Compacted session %s: summarised %d events, kept %d.",
|
||||
session_id,
|
||||
@@ -210,6 +231,57 @@ class FirestoreSessionService(BaseSessionService):
|
||||
self._compaction_keep_recent,
|
||||
)
|
||||
|
||||
async def _guarded_compact(self, session: Session) -> None:
|
||||
"""Run compaction in the background with per-session locking."""
|
||||
key = f"{session.app_name}__{session.user_id}__{session.id}"
|
||||
lock = self._compaction_locks.setdefault(key, asyncio.Lock())
|
||||
|
||||
if lock.locked():
|
||||
logger.debug(
|
||||
"Compaction already running locally for %s; skipping.", key
|
||||
)
|
||||
return
|
||||
|
||||
async with lock:
|
||||
session_ref = self._session_ref(
|
||||
session.app_name, session.user_id, session.id
|
||||
)
|
||||
try:
|
||||
transaction = self._db.transaction()
|
||||
claimed = await _try_claim_compaction_txn(
|
||||
transaction, session_ref
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to claim compaction lock for %s", key
|
||||
)
|
||||
return
|
||||
|
||||
if not claimed:
|
||||
logger.debug(
|
||||
"Compaction lock held by another instance for %s;"
|
||||
" skipping.",
|
||||
key,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
await self._compact_session(session)
|
||||
except Exception:
|
||||
logger.exception("Background compaction failed for %s", key)
|
||||
finally:
|
||||
try:
|
||||
await session_ref.update({"compaction_lock": None})
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to release compaction lock for %s", key
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Await all in-flight compaction tasks. Call before shutdown."""
|
||||
if self._active_tasks:
|
||||
await asyncio.gather(*self._active_tasks, return_exceptions=True)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# BaseSessionService implementation
|
||||
# ------------------------------------------------------------------
|
||||
@@ -513,7 +585,9 @@ class FirestoreSessionService(BaseSessionService):
|
||||
event.usage_metadata.total_token_count,
|
||||
self._compaction_threshold,
|
||||
)
|
||||
await self._compact_session(session)
|
||||
task = asyncio.create_task(self._guarded_compact(session))
|
||||
self._active_tasks.add(task)
|
||||
task.add_done_callback(self._active_tasks.discard)
|
||||
|
||||
elapsed = time.monotonic() - t0
|
||||
logger.info(
|
||||
|
||||
Reference in New Issue
Block a user