Add compaction flow lock
This commit is contained in:
@@ -2,8 +2,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
@@ -14,6 +15,9 @@ from google.cloud.firestore_v1.async_client import AsyncClient
|
||||
from google.genai.types import Content, GenerateContentResponseUsageMetadata, Part
|
||||
|
||||
from adk_firestore_sessionmanager import FirestoreSessionService
|
||||
from adk_firestore_sessionmanager.firestore_session_service import (
|
||||
_try_claim_compaction_txn,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
@@ -97,6 +101,7 @@ class TestCompactionTrigger:
|
||||
),
|
||||
)
|
||||
await compaction_service.append_event(session, trigger_event)
|
||||
await compaction_service.close()
|
||||
|
||||
# Summary generation should have been called
|
||||
mock_genai_client.aio.models.generate_content.assert_called_once()
|
||||
@@ -280,3 +285,235 @@ class TestEventsToText:
|
||||
]
|
||||
text = FirestoreSessionService._events_to_text(events)
|
||||
assert text == ""
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Firestore distributed lock
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompactionLock:
|
||||
async def test_claim_and_release(
|
||||
self, compaction_service, app_name, user_id
|
||||
):
|
||||
session = await compaction_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
session_ref = compaction_service._session_ref(
|
||||
app_name, user_id, session.id
|
||||
)
|
||||
|
||||
# Claim the lock
|
||||
transaction = compaction_service._db.transaction()
|
||||
claimed = await _try_claim_compaction_txn(transaction, session_ref)
|
||||
assert claimed is True
|
||||
|
||||
# Lock is now held — second claim should fail
|
||||
transaction2 = compaction_service._db.transaction()
|
||||
claimed2 = await _try_claim_compaction_txn(transaction2, session_ref)
|
||||
assert claimed2 is False
|
||||
|
||||
# Release the lock
|
||||
await session_ref.update({"compaction_lock": None})
|
||||
|
||||
# Can claim again after release
|
||||
transaction3 = compaction_service._db.transaction()
|
||||
claimed3 = await _try_claim_compaction_txn(transaction3, session_ref)
|
||||
assert claimed3 is True
|
||||
|
||||
async def test_stale_lock_can_be_reclaimed(
|
||||
self, compaction_service, app_name, user_id
|
||||
):
|
||||
session = await compaction_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
session_ref = compaction_service._session_ref(
|
||||
app_name, user_id, session.id
|
||||
)
|
||||
|
||||
# Set a stale lock (older than TTL)
|
||||
await session_ref.update({"compaction_lock": time.time() - 600})
|
||||
|
||||
# Should be able to reclaim a stale lock
|
||||
transaction = compaction_service._db.transaction()
|
||||
claimed = await _try_claim_compaction_txn(transaction, session_ref)
|
||||
assert claimed is True
|
||||
|
||||
async def test_claim_nonexistent_session(self, compaction_service):
|
||||
ref = compaction_service._session_ref("no_app", "no_user", "no_id")
|
||||
transaction = compaction_service._db.transaction()
|
||||
claimed = await _try_claim_compaction_txn(transaction, ref)
|
||||
assert claimed is False
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Guarded compact
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGuardedCompact:
|
||||
async def test_local_lock_skips_concurrent(
|
||||
self, compaction_service, mock_genai_client, app_name, user_id
|
||||
):
|
||||
session = await compaction_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
for i in range(5):
|
||||
e = Event(
|
||||
author="user",
|
||||
content=Content(
|
||||
role="user", parts=[Part(text=f"msg {i}")]
|
||||
),
|
||||
timestamp=time.time() + i,
|
||||
invocation_id=f"inv-{i}",
|
||||
)
|
||||
await compaction_service.append_event(session, e)
|
||||
|
||||
# Hold the in-process lock so _guarded_compact skips
|
||||
key = f"{app_name}__{user_id}__{session.id}"
|
||||
lock = compaction_service._compaction_locks.setdefault(
|
||||
key, asyncio.Lock()
|
||||
)
|
||||
async with lock:
|
||||
await compaction_service._guarded_compact(session)
|
||||
|
||||
mock_genai_client.aio.models.generate_content.assert_not_called()
|
||||
|
||||
async def test_firestore_lock_held_skips(
|
||||
self, compaction_service, mock_genai_client, app_name, user_id
|
||||
):
|
||||
session = await compaction_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
for i in range(5):
|
||||
e = Event(
|
||||
author="user",
|
||||
content=Content(
|
||||
role="user", parts=[Part(text=f"msg {i}")]
|
||||
),
|
||||
timestamp=time.time() + i,
|
||||
invocation_id=f"inv-{i}",
|
||||
)
|
||||
await compaction_service.append_event(session, e)
|
||||
|
||||
# Set a fresh Firestore lock (simulating another instance)
|
||||
session_ref = compaction_service._session_ref(
|
||||
app_name, user_id, session.id
|
||||
)
|
||||
await session_ref.update({"compaction_lock": time.time()})
|
||||
|
||||
await compaction_service._guarded_compact(session)
|
||||
|
||||
mock_genai_client.aio.models.generate_content.assert_not_called()
|
||||
|
||||
async def test_claim_failure_logs_and_skips(
|
||||
self, compaction_service, mock_genai_client, app_name, user_id
|
||||
):
|
||||
session = await compaction_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
|
||||
with patch(
|
||||
"adk_firestore_sessionmanager.firestore_session_service"
|
||||
"._try_claim_compaction_txn",
|
||||
side_effect=RuntimeError("Firestore down"),
|
||||
):
|
||||
await compaction_service._guarded_compact(session)
|
||||
|
||||
mock_genai_client.aio.models.generate_content.assert_not_called()
|
||||
|
||||
async def test_compaction_failure_releases_lock(
|
||||
self, compaction_service, mock_genai_client, app_name, user_id
|
||||
):
|
||||
session = await compaction_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
|
||||
# Make _compact_session raise an unhandled exception
|
||||
with patch.object(
|
||||
compaction_service,
|
||||
"_compact_session",
|
||||
side_effect=RuntimeError("unexpected crash"),
|
||||
):
|
||||
await compaction_service._guarded_compact(session)
|
||||
|
||||
# Lock should be released even after failure
|
||||
session_ref = compaction_service._session_ref(
|
||||
app_name, user_id, session.id
|
||||
)
|
||||
snap = await session_ref.get()
|
||||
assert snap.to_dict().get("compaction_lock") is None
|
||||
|
||||
async def test_lock_release_failure_is_non_fatal(
|
||||
self, compaction_service, mock_genai_client, app_name, user_id
|
||||
):
|
||||
session = await compaction_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
|
||||
original_session_ref = compaction_service._session_ref
|
||||
|
||||
def patched_session_ref(an, uid, sid):
|
||||
ref = original_session_ref(an, uid, sid)
|
||||
original_update = ref.update
|
||||
|
||||
async def failing_update(data):
|
||||
if "compaction_lock" in data:
|
||||
raise RuntimeError("Firestore write failed")
|
||||
return await original_update(data)
|
||||
|
||||
ref.update = failing_update
|
||||
return ref
|
||||
|
||||
with patch.object(
|
||||
compaction_service,
|
||||
"_session_ref",
|
||||
side_effect=patched_session_ref,
|
||||
):
|
||||
# Should not raise despite lock release failure
|
||||
await compaction_service._guarded_compact(session)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# close()
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClose:
|
||||
async def test_close_no_tasks(self, compaction_service):
|
||||
await compaction_service.close()
|
||||
|
||||
async def test_close_awaits_tasks(
|
||||
self, compaction_service, mock_genai_client, app_name, user_id
|
||||
):
|
||||
session = await compaction_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
base = time.time()
|
||||
for i in range(4):
|
||||
e = Event(
|
||||
author="user",
|
||||
content=Content(
|
||||
role="user", parts=[Part(text=f"msg {i}")]
|
||||
),
|
||||
timestamp=base + i,
|
||||
invocation_id=f"inv-{i}",
|
||||
)
|
||||
await compaction_service.append_event(session, e)
|
||||
|
||||
trigger = Event(
|
||||
author=app_name,
|
||||
content=Content(
|
||||
role="model", parts=[Part(text="trigger")]
|
||||
),
|
||||
timestamp=base + 4,
|
||||
invocation_id="inv-4",
|
||||
usage_metadata=GenerateContentResponseUsageMetadata(
|
||||
total_token_count=200,
|
||||
),
|
||||
)
|
||||
await compaction_service.append_event(session, trigger)
|
||||
assert len(compaction_service._active_tasks) > 0
|
||||
|
||||
await compaction_service.close()
|
||||
assert len(compaction_service._active_tasks) == 0
|
||||
|
||||
Reference in New Issue
Block a user