_COMPACTION_LOCK_TTL = 300  # seconds


@async_transactional
async def _try_claim_compaction_txn(transaction, session_ref):
    """Try to take the compaction lock stored on a session document.

    The session document is read through *transaction*.  The claim is
    refused when the document does not exist, or when its
    ``compaction_lock`` field holds a truthy timestamp that is younger than
    ``_COMPACTION_LOCK_TTL`` seconds.  Otherwise the current time is written
    to ``compaction_lock`` through the same transaction.

    Args:
        transaction: Transaction used for both the read and the write.
        session_ref: Reference to the session document carrying the lock.

    Returns:
        ``True`` if the lock was written by this call, ``False`` otherwise.
    """
    snap = await session_ref.get(transaction=transaction)
    if not snap.exists:
        return False

    fields = snap.to_dict() or {}
    held_since = fields.get("compaction_lock")
    # A missing / None / zero value means "free"; an old value means the
    # previous holder is considered stale and may be overridden.
    lock_is_fresh = bool(held_since) and (
        (time.time() - held_since) < _COMPACTION_LOCK_TTL
    )
    if lock_is_fresh:
        return False

    transaction.update(session_ref, {"compaction_lock": time.time()})
    return True
async def _guarded_compact(self, session: Session) -> None:
    """Compact *session* in the background, guarded by two locks.

    An in-process ``asyncio.Lock`` (one per session key) keeps this service
    instance from compacting the same session twice at once, and the
    ``compaction_lock`` field claimed through ``_try_claim_compaction_txn``
    does the same across instances.  If either lock is already held the
    call returns without compacting.

    Failures while claiming the lock, compacting, or releasing the lock are
    logged and swallowed, so no ``Exception`` escapes into the task that
    runs this coroutine.

    Args:
        session: The session to compact; its ``app_name``, ``user_id`` and
            ``id`` form the local lock key and locate the Firestore
            document.
    """
    key = f"{session.app_name}__{session.user_id}__{session.id}"

    lock = self._compaction_locks.get(key)
    if lock is None:
        # Create lazily: ``setdefault(key, asyncio.Lock())`` would build a
        # throwaway Lock on every call.
        lock = self._compaction_locks[key] = asyncio.Lock()
    elif lock.locked():
        logger.debug(
            "Compaction already running locally for %s; skipping.", key
        )
        return

    try:
        async with lock:
            session_ref = self._session_ref(
                session.app_name, session.user_id, session.id
            )
            try:
                transaction = self._db.transaction()
                claimed = await _try_claim_compaction_txn(
                    transaction, session_ref
                )
            except Exception:
                logger.exception(
                    "Failed to claim compaction lock for %s", key
                )
                return

            if not claimed:
                logger.debug(
                    "Compaction lock held by another instance for %s;"
                    " skipping.",
                    key,
                )
                return

            try:
                await self._compact_session(session)
            except Exception:
                logger.exception("Background compaction failed for %s", key)
            finally:
                # NOTE(review): this clears the field unconditionally.  If
                # compaction outlives the lock TTL and another instance
                # reclaims the lock, this write releases *their* claim.
                # Consider storing an owner token and releasing only on
                # match.
                try:
                    await session_ref.update({"compaction_lock": None})
                except Exception:
                    logger.exception(
                        "Failed to release compaction lock for %s", key
                    )
    finally:
        # Nobody ever waits on this lock (contenders return early above),
        # so once it is free it can be dropped.  Without this the dict
        # keeps one Lock per session for the lifetime of the service.
        if self._compaction_locks.get(key) is lock and not lock.locked():
            del self._compaction_locks[key]


async def close(self) -> None:
    """Await all in-flight compaction tasks. Call before shutdown.

    Loops until ``_active_tasks`` is empty so that tasks registered while
    an earlier batch was being awaited are also drained.  Task exceptions
    are collected by ``gather`` rather than raised.
    """
    while self._active_tasks:
        pending = tuple(self._active_tasks)
        await asyncio.gather(*pending, return_exceptions=True)
        # Do not rely solely on the tasks' done-callbacks for removal;
        # discarding here guarantees the loop terminates.
        self._active_tasks.difference_update(pending)
class TestCompactionLock:
    """Exercise ``_try_claim_compaction_txn`` against a session document."""

    @staticmethod
    async def _attempt_claim(service, ref):
        """Try to claim the lock on *ref* using a fresh transaction."""
        return await _try_claim_compaction_txn(service._db.transaction(), ref)

    async def test_claim_and_release(
        self, compaction_service, app_name, user_id
    ):
        """A held lock blocks further claims until it is cleared."""
        created = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        doc_ref = compaction_service._session_ref(
            app_name, user_id, created.id
        )

        # First claim wins.
        first = await self._attempt_claim(compaction_service, doc_ref)
        assert first is True

        # While the lock is held, a second claim is refused.
        second = await self._attempt_claim(compaction_service, doc_ref)
        assert second is False

        # Clearing the field releases the lock ...
        await doc_ref.update({"compaction_lock": None})

        # ... after which it can be claimed again.
        third = await self._attempt_claim(compaction_service, doc_ref)
        assert third is True

    async def test_stale_lock_can_be_reclaimed(
        self, compaction_service, app_name, user_id
    ):
        """A lock timestamp 600 s in the past does not block a claim."""
        created = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        doc_ref = compaction_service._session_ref(
            app_name, user_id, created.id
        )

        # Plant a lock that is older than the TTL.
        await doc_ref.update({"compaction_lock": time.time() - 600})

        outcome = await self._attempt_claim(compaction_service, doc_ref)
        assert outcome is True

    async def test_claim_nonexistent_session(self, compaction_service):
        """Claiming on a document that was never created is refused."""
        missing_ref = compaction_service._session_ref(
            "no_app", "no_user", "no_id"
        )
        outcome = await self._attempt_claim(compaction_service, missing_ref)
        assert outcome is False
async def test_claim_failure_logs_and_skips(
    self, compaction_service, mock_genai_client, app_name, user_id
):
    """An exception while claiming the lock aborts compaction quietly."""
    created = await compaction_service.create_session(
        app_name=app_name, user_id=user_id
    )

    claim_target = (
        "adk_firestore_sessionmanager.firestore_session_service"
        "._try_claim_compaction_txn"
    )
    claim_error = RuntimeError("Firestore down")
    with patch(claim_target, side_effect=claim_error):
        await compaction_service._guarded_compact(created)

    # No summary may be requested when the claim step blew up.
    mock_genai_client.aio.models.generate_content.assert_not_called()


async def test_compaction_failure_releases_lock(
    self, compaction_service, mock_genai_client, app_name, user_id
):
    """A crash inside ``_compact_session`` still clears the lock field."""
    created = await compaction_service.create_session(
        app_name=app_name, user_id=user_id
    )

    # Force the compaction step itself to raise.
    crash = RuntimeError("unexpected crash")
    with patch.object(
        compaction_service, "_compact_session", side_effect=crash
    ):
        await compaction_service._guarded_compact(created)

    # The lock field must be empty again afterwards.
    doc_ref = compaction_service._session_ref(app_name, user_id, created.id)
    stored = (await doc_ref.get()).to_dict()
    assert stored.get("compaction_lock") is None
class TestClose:
    """Checks for ``FirestoreSessionService.close()``."""

    async def test_close_no_tasks(self, compaction_service):
        """``close()`` is a no-op when nothing was scheduled."""
        await compaction_service.close()

    async def test_close_awaits_tasks(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        """``close()`` drains the task set left behind by ``append_event``."""
        created = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        start = time.time()

        # Four plain user messages, one second apart.
        for offset in range(4):
            user_event = Event(
                author="user",
                content=Content(
                    role="user", parts=[Part(text=f"msg {offset}")]
                ),
                timestamp=start + offset,
                invocation_id=f"inv-{offset}",
            )
            await compaction_service.append_event(created, user_event)

        # A model event carrying a token count; presumably above the
        # fixture's compaction threshold - confirm against the fixture.
        usage = GenerateContentResponseUsageMetadata(total_token_count=200)
        trigger_event = Event(
            author=app_name,
            content=Content(role="model", parts=[Part(text="trigger")]),
            timestamp=start + 4,
            invocation_id="inv-4",
            usage_metadata=usage,
        )
        await compaction_service.append_event(created, trigger_event)
        assert len(compaction_service._active_tasks) > 0

        await compaction_service.close()
        assert len(compaction_service._active_tasks) == 0