"""Tests for conversation compaction in FirestoreSessionService.""" from __future__ import annotations import asyncio import time from unittest.mock import AsyncMock, MagicMock, patch import uuid import pytest import pytest_asyncio from google import genai from google.adk.events.event import Event from google.cloud.firestore_v1.async_client import AsyncClient from google.genai.types import Content, GenerateContentResponseUsageMetadata, Part from va_agent.session import FirestoreSessionService, _try_claim_compaction_txn pytestmark = pytest.mark.asyncio @pytest_asyncio.fixture async def mock_genai_client(): client = MagicMock(spec=genai.Client) response = MagicMock() response.text = "Summary of the conversation so far." client.aio.models.generate_content = AsyncMock(return_value=response) return client @pytest_asyncio.fixture async def compaction_service(db: AsyncClient, mock_genai_client): prefix = f"test_{uuid.uuid4().hex[:8]}" return FirestoreSessionService( db=db, collection_prefix=prefix, compaction_token_threshold=100, compaction_keep_recent=2, genai_client=mock_genai_client, ) # ------------------------------------------------------------------ # __init__ validation # ------------------------------------------------------------------ class TestCompactionInit: async def test_requires_genai_client(self, db): with pytest.raises(ValueError, match="genai_client is required"): FirestoreSessionService( db=db, compaction_token_threshold=1000, ) async def test_no_threshold_no_client_ok(self, db): svc = FirestoreSessionService(db=db) assert svc._compaction_threshold is None # ------------------------------------------------------------------ # Compaction trigger # ------------------------------------------------------------------ class TestCompactionTrigger: async def test_compaction_triggered_above_threshold( self, compaction_service, mock_genai_client, app_name, user_id ): session = await compaction_service.create_session( app_name=app_name, user_id=user_id ) # Add 5 
events, last one with usage_metadata above threshold base = time.time() for i in range(4): e = Event( author="user" if i % 2 == 0 else app_name, content=Content( role="user" if i % 2 == 0 else "model", parts=[Part(text=f"message {i}")], ), timestamp=base + i, invocation_id=f"inv-{i}", ) await compaction_service.append_event(session, e) # This event crosses the threshold trigger_event = Event( author=app_name, content=Content( role="model", parts=[Part(text="final response")] ), timestamp=base + 4, invocation_id="inv-4", usage_metadata=GenerateContentResponseUsageMetadata( total_token_count=200, ), ) await compaction_service.append_event(session, trigger_event) await compaction_service.close() # Summary generation should have been called mock_genai_client.aio.models.generate_content.assert_called_once() # Fetch session: should have summary + only keep_recent events fetched = await compaction_service.get_session( app_name=app_name, user_id=user_id, session_id=session.id ) # 2 synthetic summary events + 2 kept real events assert len(fetched.events) == 4 assert fetched.events[0].id == "summary-context" assert fetched.events[1].id == "summary-ack" assert "Summary of the conversation" in fetched.events[0].content.parts[0].text async def test_no_compaction_below_threshold( self, compaction_service, mock_genai_client, app_name, user_id ): session = await compaction_service.create_session( app_name=app_name, user_id=user_id ) event = Event( author=app_name, content=Content( role="model", parts=[Part(text="short reply")] ), timestamp=time.time(), invocation_id="inv-1", usage_metadata=GenerateContentResponseUsageMetadata( total_token_count=50, ), ) await compaction_service.append_event(session, event) mock_genai_client.aio.models.generate_content.assert_not_called() async def test_no_compaction_without_usage_metadata( self, compaction_service, mock_genai_client, app_name, user_id ): session = await compaction_service.create_session( app_name=app_name, user_id=user_id ) event 
= Event( author="user", content=Content( role="user", parts=[Part(text="hello")] ), timestamp=time.time(), invocation_id="inv-1", ) await compaction_service.append_event(session, event) mock_genai_client.aio.models.generate_content.assert_not_called() # ------------------------------------------------------------------ # Compaction with too few events (nothing to compact) # ------------------------------------------------------------------ class TestCompactionEdgeCases: async def test_skip_when_fewer_events_than_keep_recent( self, compaction_service, mock_genai_client, app_name, user_id ): session = await compaction_service.create_session( app_name=app_name, user_id=user_id ) # Only 2 events, keep_recent=2 → nothing to summarize for i in range(2): e = Event( author="user", content=Content( role="user", parts=[Part(text=f"msg {i}")] ), timestamp=time.time() + i, invocation_id=f"inv-{i}", ) await compaction_service.append_event(session, e) # Trigger compaction manually even though threshold wouldn't fire await compaction_service._compact_session(session) mock_genai_client.aio.models.generate_content.assert_not_called() async def test_summary_generation_failure_is_non_fatal( self, compaction_service, mock_genai_client, app_name, user_id ): session = await compaction_service.create_session( app_name=app_name, user_id=user_id ) for i in range(5): e = Event( author="user", content=Content( role="user", parts=[Part(text=f"msg {i}")] ), timestamp=time.time() + i, invocation_id=f"inv-{i}", ) await compaction_service.append_event(session, e) # Make summary generation fail mock_genai_client.aio.models.generate_content = AsyncMock( side_effect=RuntimeError("API error") ) # Should not raise await compaction_service._compact_session(session) # All events should still be present fetched = await compaction_service.get_session( app_name=app_name, user_id=user_id, session_id=session.id ) assert len(fetched.events) == 5 # 
# ------------------------------------------------------------------
# get_session with summary
# ------------------------------------------------------------------


class TestGetSessionWithSummary:
    """``get_session`` output when no compaction summary has been stored."""

    async def test_no_summary_no_synthetic_events(
        self, compaction_service, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        event = Event(
            author="user",
            content=Content(
                role="user", parts=[Part(text="hello")]
            ),
            timestamp=time.time(),
            invocation_id="inv-1",
        )
        await compaction_service.append_event(session, event)

        fetched = await compaction_service.get_session(
            app_name=app_name, user_id=user_id, session_id=session.id
        )
        assert len(fetched.events) == 1
        assert fetched.events[0].author == "user"


# ------------------------------------------------------------------
# _events_to_text
# ------------------------------------------------------------------


class TestEventsToText:
    """Plain-text transcript rendering used as summarizer input."""

    async def test_formats_user_and_assistant(self):
        events = [
            Event(
                author="user",
                content=Content(
                    role="user", parts=[Part(text="Hi there")]
                ),
                timestamp=1.0,
                invocation_id="inv-1",
            ),
            Event(
                author="bot",
                content=Content(
                    role="model", parts=[Part(text="Hello!")]
                ),
                timestamp=2.0,
                invocation_id="inv-2",
            ),
        ]
        text = FirestoreSessionService._events_to_text(events)
        assert "User: Hi there" in text
        assert "Assistant: Hello!" in text

    async def test_skips_events_without_text(self):
        events = [
            Event(
                author="user",
                timestamp=1.0,
                invocation_id="inv-1",
            ),
        ]
        text = FirestoreSessionService._events_to_text(events)
        assert text == ""


# ------------------------------------------------------------------
# Firestore distributed lock
# ------------------------------------------------------------------


class TestCompactionLock:
    """The ``compaction_lock`` field claimed via ``_try_claim_compaction_txn``."""

    async def test_claim_and_release(
        self, compaction_service, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        session_ref = compaction_service._session_ref(
            app_name, user_id, session.id
        )

        # Claim the lock
        transaction = compaction_service._db.transaction()
        claimed = await _try_claim_compaction_txn(transaction, session_ref)
        assert claimed is True

        # Lock is now held — second claim should fail
        transaction2 = compaction_service._db.transaction()
        claimed2 = await _try_claim_compaction_txn(transaction2, session_ref)
        assert claimed2 is False

        # Release the lock
        await session_ref.update({"compaction_lock": None})

        # Can claim again after release
        transaction3 = compaction_service._db.transaction()
        claimed3 = await _try_claim_compaction_txn(transaction3, session_ref)
        assert claimed3 is True

    async def test_stale_lock_can_be_reclaimed(
        self, compaction_service, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        session_ref = compaction_service._session_ref(
            app_name, user_id, session.id
        )

        # Set a stale lock (older than TTL)
        await session_ref.update({"compaction_lock": time.time() - 600})

        # Should be able to reclaim a stale lock
        transaction = compaction_service._db.transaction()
        claimed = await _try_claim_compaction_txn(transaction, session_ref)
        assert claimed is True

    async def test_claim_nonexistent_session(self, compaction_service):
        ref = compaction_service._session_ref("no_app", "no_user", "no_id")
        transaction = compaction_service._db.transaction()
        claimed = await _try_claim_compaction_txn(transaction, ref)
        assert claimed is False


# ------------------------------------------------------------------
# Guarded compact
# ------------------------------------------------------------------


class TestGuardedCompact:
    """``_guarded_compact``: the in-process lock, the Firestore lock, and
    failure handling around ``_compact_session``."""

    async def test_local_lock_skips_concurrent(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        for i in range(5):
            e = Event(
                author="user",
                content=Content(
                    role="user", parts=[Part(text=f"msg {i}")]
                ),
                timestamp=time.time() + i,
                invocation_id=f"inv-{i}",
            )
            await compaction_service.append_event(session, e)

        # Hold the in-process lock so _guarded_compact skips
        key = f"{app_name}__{user_id}__{session.id}"
        lock = compaction_service._compaction_locks.setdefault(
            key, asyncio.Lock()
        )
        async with lock:
            await compaction_service._guarded_compact(session)

        mock_genai_client.aio.models.generate_content.assert_not_called()

    async def test_firestore_lock_held_skips(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        for i in range(5):
            e = Event(
                author="user",
                content=Content(
                    role="user", parts=[Part(text=f"msg {i}")]
                ),
                timestamp=time.time() + i,
                invocation_id=f"inv-{i}",
            )
            await compaction_service.append_event(session, e)

        # Set a fresh Firestore lock (simulating another instance)
        session_ref = compaction_service._session_ref(
            app_name, user_id, session.id
        )
        await session_ref.update({"compaction_lock": time.time()})

        await compaction_service._guarded_compact(session)

        mock_genai_client.aio.models.generate_content.assert_not_called()

    async def test_claim_failure_logs_and_skips(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )

        with patch(
            "va_agent.session._try_claim_compaction_txn",
            side_effect=RuntimeError("Firestore down"),
        ):
            await compaction_service._guarded_compact(session)

        mock_genai_client.aio.models.generate_content.assert_not_called()

    async def test_compaction_failure_releases_lock(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )

        # Make _compact_session raise an unhandled exception
        with patch.object(
            compaction_service,
            "_compact_session",
            side_effect=RuntimeError("unexpected crash"),
        ) as mock_compact:
            await compaction_service._guarded_compact(session)

        # The crash must really have happened inside the locked section.
        # Otherwise the lock was never claimed and the check below would
        # pass vacuously.
        mock_compact.assert_awaited_once()

        # Lock should be released even after failure
        session_ref = compaction_service._session_ref(
            app_name, user_id, session.id
        )
        snap = await session_ref.get()
        assert snap.to_dict().get("compaction_lock") is None

    async def test_lock_release_failure_is_non_fatal(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )

        original_session_ref = compaction_service._session_ref
        # Payloads of every update that touched compaction_lock; used to
        # prove the failing-release path was actually exercised.
        release_attempts = []

        def patched_session_ref(*args, **kwargs):
            ref = original_session_ref(*args, **kwargs)
            original_update = ref.update

            async def failing_update(data, *update_args, **update_kwargs):
                if "compaction_lock" in data:
                    release_attempts.append(data)
                    raise RuntimeError("Firestore write failed")
                return await original_update(data, *update_args, **update_kwargs)

            ref.update = failing_update
            return ref

        with patch.object(
            compaction_service,
            "_session_ref",
            side_effect=patched_session_ref,
        ):
            # Should not raise despite lock release failure
            await compaction_service._guarded_compact(session)

        assert release_attempts, "compaction_lock release was never attempted"


# ------------------------------------------------------------------
# close()
# ------------------------------------------------------------------


class TestClose:
    """``close()`` awaits and clears background compaction tasks."""

    async def test_close_no_tasks(self, compaction_service):
        await compaction_service.close()

    async def test_close_awaits_tasks(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )

        base = time.time()
        for i in range(4):
            e = Event(
                author="user",
                content=Content(
                    role="user", parts=[Part(text=f"msg {i}")]
                ),
                timestamp=base + i,
                invocation_id=f"inv-{i}",
            )
            await compaction_service.append_event(session, e)

        trigger = Event(
            author=app_name,
            content=Content(
                role="model", parts=[Part(text="trigger")]
            ),
            timestamp=base + 4,
            invocation_id="inv-4",
            usage_metadata=GenerateContentResponseUsageMetadata(
                total_token_count=200,
            ),
        )
        await compaction_service.append_event(session, trigger)

        # The threshold-crossing append left a compaction task pending.
        assert len(compaction_service._active_tasks) > 0

        await compaction_service.close()

        assert len(compaction_service._active_tasks) == 0