Add compaction flow lock
This commit is contained in:
@@ -20,6 +20,7 @@ from google.adk.sessions.base_session_service import (
|
|||||||
from google.adk.sessions.session import Session
|
from google.adk.sessions.session import Session
|
||||||
from google.adk.sessions.state import State
|
from google.adk.sessions.state import State
|
||||||
from google.cloud.firestore_v1.async_client import AsyncClient
|
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||||
|
from google.cloud.firestore_v1.async_transaction import async_transactional
|
||||||
from google.cloud.firestore_v1.base_query import FieldFilter
|
from google.cloud.firestore_v1.base_query import FieldFilter
|
||||||
from google.cloud.firestore_v1.field_path import FieldPath
|
from google.cloud.firestore_v1.field_path import FieldPath
|
||||||
from google.genai.types import Content, Part
|
from google.genai.types import Content, Part
|
||||||
@@ -27,6 +28,22 @@ from typing_extensions import override
|
|||||||
|
|
||||||
logger = logging.getLogger("google_adk." + __name__)
|
logger = logging.getLogger("google_adk." + __name__)
|
||||||
|
|
||||||
|
_COMPACTION_LOCK_TTL = 300 # seconds
|
||||||
|
|
||||||
|
|
||||||
|
@async_transactional
async def _try_claim_compaction_txn(transaction, session_ref):
    """Claim the session's compaction lock through a Firestore transaction.

    The session document is read via ``transaction``. When the document
    exists and its ``compaction_lock`` field is unset (falsy) or at least
    ``_COMPACTION_LOCK_TTL`` seconds old, the field is stamped with the
    current ``time.time()`` value through the same transaction.

    Args:
        transaction: Transaction the read and the write are issued through.
        session_ref: Reference to the session document that carries the lock.

    Returns:
        True when this call stamped the lock; False when the document does
        not exist or a lock younger than the TTL is already recorded.
    """
    doc = await session_ref.get(transaction=transaction)
    if doc.exists:
        held_since = (doc.to_dict() or {}).get("compaction_lock")
        # A lock only blocks the claim while it is younger than the TTL;
        # an older stamp is treated as abandoned and overwritten.
        still_fresh = bool(held_since) and (
            (time.time() - held_since) < _COMPACTION_LOCK_TTL
        )
        if not still_fresh:
            transaction.update(session_ref, {"compaction_lock": time.time()})
            return True
    return False
|
||||||
|
|
||||||
|
|
||||||
class FirestoreSessionService(BaseSessionService):
|
class FirestoreSessionService(BaseSessionService):
|
||||||
"""A Firestore-backed implementation of BaseSessionService.
|
"""A Firestore-backed implementation of BaseSessionService.
|
||||||
@@ -64,6 +81,8 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
self._compaction_model = compaction_model
|
self._compaction_model = compaction_model
|
||||||
self._compaction_keep_recent = compaction_keep_recent
|
self._compaction_keep_recent = compaction_keep_recent
|
||||||
self._genai_client = genai_client
|
self._genai_client = genai_client
|
||||||
|
self._compaction_locks: dict[str, asyncio.Lock] = {}
|
||||||
|
self._active_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Document-reference helpers
|
# Document-reference helpers
|
||||||
@@ -192,6 +211,12 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
logger.exception("Compaction summary generation failed; skipping.")
|
logger.exception("Compaction summary generation failed; skipping.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Write summary BEFORE deleting events so a crash between the two
|
||||||
|
# steps leaves safe duplication rather than data loss.
|
||||||
|
await self._session_ref(app_name, user_id, session_id).update(
|
||||||
|
{"conversation_summary": summary}
|
||||||
|
)
|
||||||
|
|
||||||
docs_to_delete = event_docs[: -self._compaction_keep_recent]
|
docs_to_delete = event_docs[: -self._compaction_keep_recent]
|
||||||
for i in range(0, len(docs_to_delete), 500):
|
for i in range(0, len(docs_to_delete), 500):
|
||||||
batch = self._db.batch()
|
batch = self._db.batch()
|
||||||
@@ -199,10 +224,6 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
batch.delete(doc.reference)
|
batch.delete(doc.reference)
|
||||||
await batch.commit()
|
await batch.commit()
|
||||||
|
|
||||||
await self._session_ref(app_name, user_id, session_id).update(
|
|
||||||
{"conversation_summary": summary}
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Compacted session %s: summarised %d events, kept %d.",
|
"Compacted session %s: summarised %d events, kept %d.",
|
||||||
session_id,
|
session_id,
|
||||||
@@ -210,6 +231,57 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
self._compaction_keep_recent,
|
self._compaction_keep_recent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _guarded_compact(self, session: Session) -> None:
    """Compact ``session`` unless a compaction for it is already running.

    Two guards are applied in order:

    1. An in-process ``asyncio.Lock`` keyed by app name, user id and session
       id. If it is already held, the call returns immediately.
    2. The ``compaction_lock`` field on the session document, claimed via
       ``_try_claim_compaction_txn``. If the claim is refused or raises, the
       call returns without compacting.

    Exceptions raised while claiming, compacting or releasing the document
    lock are logged instead of being propagated. The in-process lock entry
    is removed from ``self._compaction_locks`` once it is idle so the dict
    does not keep one lock per session forever.

    Args:
        session: Session whose events should be compacted.
    """
    key = f"{session.app_name}__{session.user_id}__{session.id}"
    lock = self._compaction_locks.setdefault(key, asyncio.Lock())

    if lock.locked():
        logger.debug(
            "Compaction already running locally for %s; skipping.", key
        )
        return

    try:
        async with lock:
            session_ref = self._session_ref(
                session.app_name, session.user_id, session.id
            )
            try:
                transaction = self._db.transaction()
                claimed = await _try_claim_compaction_txn(
                    transaction, session_ref
                )
            except Exception:
                logger.exception(
                    "Failed to claim compaction lock for %s", key
                )
                return

            if not claimed:
                logger.debug(
                    "Compaction lock held by another instance for %s;"
                    " skipping.",
                    key,
                )
                return

            try:
                await self._compact_session(session)
            except Exception:
                logger.exception("Background compaction failed for %s", key)
            finally:
                # NOTE(review): the release is unconditional. If compaction
                # ran longer than the lock TTL and another instance has
                # re-claimed the lock, this would clear that instance's
                # claim -- confirm whether an owner token is needed.
                try:
                    await session_ref.update({"compaction_lock": None})
                except Exception:
                    logger.exception(
                        "Failed to release compaction lock for %s", key
                    )
    finally:
        # Forget the per-session lock once nobody holds it; otherwise the
        # dict gains one entry for every session that was ever compacted.
        # The identity check avoids removing a lock installed by someone
        # else under the same key.
        if self._compaction_locks.get(key) is lock and not lock.locked():
            del self._compaction_locks[key]
|
||||||
|
|
||||||
|
async def close(self) -> None:
    """Await all in-flight compaction tasks. Call before shutdown.

    Tasks are awaited in batches until ``self._active_tasks`` is empty, so a
    task that gets registered while an earlier batch is still being awaited
    is waited for as well. Exceptions raised by the tasks are collected by
    ``asyncio.gather(..., return_exceptions=True)`` and not re-raised.
    """
    while self._active_tasks:
        # Snapshot first: the set is mutated by done-callbacks and by new
        # registrations while gather() is suspended.
        batch = tuple(self._active_tasks)
        await asyncio.gather(*batch, return_exceptions=True)
        # Drop the awaited tasks explicitly so the loop always makes
        # progress, even for a task registered without a discard callback.
        self._active_tasks.difference_update(batch)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# BaseSessionService implementation
|
# BaseSessionService implementation
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -513,7 +585,9 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
event.usage_metadata.total_token_count,
|
event.usage_metadata.total_token_count,
|
||||||
self._compaction_threshold,
|
self._compaction_threshold,
|
||||||
)
|
)
|
||||||
await self._compact_session(session)
|
task = asyncio.create_task(self._guarded_compact(session))
|
||||||
|
self._active_tasks.add(task)
|
||||||
|
task.add_done_callback(self._active_tasks.discard)
|
||||||
|
|
||||||
elapsed = time.monotonic() - t0
|
elapsed = time.monotonic() - t0
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
@@ -2,8 +2,9 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -14,6 +15,9 @@ from google.cloud.firestore_v1.async_client import AsyncClient
|
|||||||
from google.genai.types import Content, GenerateContentResponseUsageMetadata, Part
|
from google.genai.types import Content, GenerateContentResponseUsageMetadata, Part
|
||||||
|
|
||||||
from adk_firestore_sessionmanager import FirestoreSessionService
|
from adk_firestore_sessionmanager import FirestoreSessionService
|
||||||
|
from adk_firestore_sessionmanager.firestore_session_service import (
|
||||||
|
_try_claim_compaction_txn,
|
||||||
|
)
|
||||||
|
|
||||||
pytestmark = pytest.mark.asyncio
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
@@ -97,6 +101,7 @@ class TestCompactionTrigger:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
await compaction_service.append_event(session, trigger_event)
|
await compaction_service.append_event(session, trigger_event)
|
||||||
|
await compaction_service.close()
|
||||||
|
|
||||||
# Summary generation should have been called
|
# Summary generation should have been called
|
||||||
mock_genai_client.aio.models.generate_content.assert_called_once()
|
mock_genai_client.aio.models.generate_content.assert_called_once()
|
||||||
@@ -280,3 +285,235 @@ class TestEventsToText:
|
|||||||
]
|
]
|
||||||
text = FirestoreSessionService._events_to_text(events)
|
text = FirestoreSessionService._events_to_text(events)
|
||||||
assert text == ""
|
assert text == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Firestore distributed lock
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompactionLock:
    """Checks ``_try_claim_compaction_txn`` against a session document."""

    async def test_claim_and_release(
        self, compaction_service, app_name, user_id
    ):
        """A held lock refuses a second claim until the field is cleared."""
        created = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        doc_ref = compaction_service._session_ref(
            app_name, user_id, created.id
        )
        db = compaction_service._db

        # First claim on a fresh session succeeds.
        first = await _try_claim_compaction_txn(db.transaction(), doc_ref)
        assert first is True

        # While the lock is held, another claim is refused.
        second = await _try_claim_compaction_txn(db.transaction(), doc_ref)
        assert second is False

        # Clearing the field releases the lock ...
        await doc_ref.update({"compaction_lock": None})

        # ... after which it can be claimed again.
        third = await _try_claim_compaction_txn(db.transaction(), doc_ref)
        assert third is True

    async def test_stale_lock_can_be_reclaimed(
        self, compaction_service, app_name, user_id
    ):
        """A lock stamped 600 seconds ago (past the TTL) can be taken over."""
        created = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        doc_ref = compaction_service._session_ref(
            app_name, user_id, created.id
        )

        # Back-date the lock so it is older than the TTL.
        await doc_ref.update({"compaction_lock": time.time() - 600})

        outcome = await _try_claim_compaction_txn(
            compaction_service._db.transaction(), doc_ref
        )
        assert outcome is True

    async def test_claim_nonexistent_session(self, compaction_service):
        """Claiming on a document that does not exist returns False."""
        missing_ref = compaction_service._session_ref(
            "no_app", "no_user", "no_id"
        )
        outcome = await _try_claim_compaction_txn(
            compaction_service._db.transaction(), missing_ref
        )
        assert outcome is False
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Guarded compact
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGuardedCompact:
    """Checks the skip, failure and release paths of ``_guarded_compact``."""

    @staticmethod
    async def _append_user_messages(service, session, count):
        """Append ``count`` user events ("msg 0", "msg 1", ...) to ``session``."""
        for idx in range(count):
            message = Event(
                author="user",
                content=Content(
                    role="user", parts=[Part(text=f"msg {idx}")]
                ),
                timestamp=time.time() + idx,
                invocation_id=f"inv-{idx}",
            )
            await service.append_event(session, message)

    async def test_local_lock_skips_concurrent(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        """No summary is generated while the in-process lock is held."""
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        await self._append_user_messages(compaction_service, session, 5)

        # Occupy the in-process lock so _guarded_compact has to back off.
        lock_key = f"{app_name}__{user_id}__{session.id}"
        held = compaction_service._compaction_locks.setdefault(
            lock_key, asyncio.Lock()
        )
        async with held:
            await compaction_service._guarded_compact(session)

        mock_genai_client.aio.models.generate_content.assert_not_called()

    async def test_firestore_lock_held_skips(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        """No summary is generated while the document lock is fresh."""
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        await self._append_user_messages(compaction_service, session, 5)

        # Stamp a fresh lock on the document, as another instance would.
        doc_ref = compaction_service._session_ref(
            app_name, user_id, session.id
        )
        await doc_ref.update({"compaction_lock": time.time()})

        await compaction_service._guarded_compact(session)

        mock_genai_client.aio.models.generate_content.assert_not_called()

    async def test_claim_failure_logs_and_skips(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        """An exception from the claim step prevents compaction."""
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )

        claim_patch = patch(
            "adk_firestore_sessionmanager.firestore_session_service"
            "._try_claim_compaction_txn",
            side_effect=RuntimeError("Firestore down"),
        )
        with claim_patch:
            await compaction_service._guarded_compact(session)

        mock_genai_client.aio.models.generate_content.assert_not_called()

    async def test_compaction_failure_releases_lock(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        """The document lock is cleared even when compaction raises."""
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )

        # Force _compact_session to blow up with an unexpected error.
        crash_patch = patch.object(
            compaction_service,
            "_compact_session",
            side_effect=RuntimeError("unexpected crash"),
        )
        with crash_patch:
            await compaction_service._guarded_compact(session)

        # The lock field must be back to None afterwards.
        doc_ref = compaction_service._session_ref(
            app_name, user_id, session.id
        )
        stored = (await doc_ref.get()).to_dict()
        assert stored.get("compaction_lock") is None

    async def test_lock_release_failure_is_non_fatal(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        """A failing lock release does not raise out of ``_guarded_compact``."""
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )

        real_session_ref = compaction_service._session_ref

        def ref_with_failing_lock_update(an, uid, sid):
            # Hand out the real reference, but make update() reject any
            # payload that touches the lock field.
            ref = real_session_ref(an, uid, sid)
            real_update = ref.update

            async def failing_update(data):
                if "compaction_lock" in data:
                    raise RuntimeError("Firestore write failed")
                return await real_update(data)

            ref.update = failing_update
            return ref

        ref_patch = patch.object(
            compaction_service,
            "_session_ref",
            side_effect=ref_with_failing_lock_update,
        )
        with ref_patch:
            # Must complete without raising despite the release failure.
            await compaction_service._guarded_compact(session)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# close()
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestClose:
    """Checks that ``close()`` drains background compaction tasks."""

    async def test_close_no_tasks(self, compaction_service):
        """``close()`` completes when nothing was ever scheduled."""
        await compaction_service.close()

    async def test_close_awaits_tasks(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        """A task scheduled by ``append_event`` is gone after ``close()``."""
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        start = time.time()

        # Four plain user messages, one second apart.
        user_events = [
            Event(
                author="user",
                content=Content(
                    role="user", parts=[Part(text=f"msg {offset}")]
                ),
                timestamp=start + offset,
                invocation_id=f"inv-{offset}",
            )
            for offset in range(4)
        ]
        for user_event in user_events:
            await compaction_service.append_event(session, user_event)

        # A model event whose reported token usage is 200.
        trigger_event = Event(
            author=app_name,
            content=Content(role="model", parts=[Part(text="trigger")]),
            timestamp=start + 4,
            invocation_id="inv-4",
            usage_metadata=GenerateContentResponseUsageMetadata(
                total_token_count=200,
            ),
        )
        await compaction_service.append_event(session, trigger_event)
        assert len(compaction_service._active_tasks) > 0

        await compaction_service.close()
        assert len(compaction_service._active_tasks) == 0
|
||||||
|
|||||||
Reference in New Issue
Block a user