548 lines
19 KiB
Python
548 lines
19 KiB
Python
"""Tests for conversation compaction in FirestoreSessionService."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import time
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
import uuid
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from google import genai
|
|
from google.adk.events.event import Event
|
|
from google.cloud.firestore_v1.async_client import AsyncClient
|
|
from google.genai.types import Content, GenerateContentResponseUsageMetadata, Part
|
|
|
|
from va_agent.session import FirestoreSessionService
|
|
from va_agent.compaction import SessionCompactor, _try_claim_compaction_txn
|
|
|
|
# Every test in this module is an ``async def``; mark them all for pytest-asyncio.
pytestmark = [pytest.mark.asyncio]
|
|
|
|
|
|
@pytest_asyncio.fixture
async def mock_genai_client():
    """Provide a ``genai.Client`` double with a canned async summary response.

    ``client.aio.models.generate_content`` is an ``AsyncMock`` whose awaited
    result exposes ``.text == "Summary of the conversation so far."``; tests
    inspect its call record to see whether summarization happened.
    """
    canned_response = MagicMock()
    canned_response.text = "Summary of the conversation so far."

    fake_client = MagicMock(spec=genai.Client)
    fake_client.aio.models.generate_content = AsyncMock(return_value=canned_response)
    return fake_client
|
|
|
|
|
|
@pytest_asyncio.fixture
async def compaction_service(db: AsyncClient, mock_genai_client):
    """Yield a FirestoreSessionService with compaction enabled at a tiny threshold.

    Configuration used by the tests in this module:

    * ``compaction_token_threshold=100``: an event carrying
      ``total_token_count=200`` crosses it, while ``50`` does not.
    * ``compaction_keep_recent=2``: the two newest events are kept after
      compaction.
    * A random ``collection_prefix`` per test. This presumably keeps each
      test's Firestore documents separate.

    Teardown awaits ``close()``. Any background compaction task a test
    started but did not await is drained then, rather than being left
    pending when the test ends.
    """
    prefix = f"test_{uuid.uuid4().hex[:8]}"
    service = FirestoreSessionService(
        db=db,
        collection_prefix=prefix,
        compaction_token_threshold=100,
        compaction_keep_recent=2,
        genai_client=mock_genai_client,
    )
    yield service
    await service.close()
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# __init__ validation
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
class TestCompactionInit:
    """Constructor validation of the compaction-related arguments."""

    async def test_requires_genai_client(self, db):
        # A token threshold without a client to summarize with is rejected.
        expected_message = "genai_client is required"
        with pytest.raises(ValueError, match=expected_message):
            FirestoreSessionService(db=db, compaction_token_threshold=1000)

    async def test_no_threshold_no_client_ok(self, db):
        # With neither argument supplied, construction succeeds and no
        # compaction threshold is recorded.
        service = FirestoreSessionService(db=db)
        assert service._compaction_threshold is None
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# Compaction trigger
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
class TestCompactionTrigger:
    """Compaction is driven by ``usage_metadata.total_token_count`` on appended events.

    The fixture uses ``compaction_token_threshold=100`` and
    ``compaction_keep_recent=2``. The negative tests first seed more than
    ``keep_recent`` events, so a wrongly triggered compaction would have
    something to summarize and would reach ``generate_content``. They also
    ``close()`` the service before asserting, so any background compaction
    task has finished by then. Without both steps the negative assertions
    would pass vacuously.
    """

    @staticmethod
    async def _append_history(service, session, app_name, count, base):
        """Append ``count`` alternating user/model text events without usage metadata.

        Event ``i`` has timestamp ``base + i`` and invocation id ``inv-{i}``.
        """
        for i in range(count):
            from_user = i % 2 == 0
            await service.append_event(
                session,
                Event(
                    author="user" if from_user else app_name,
                    content=Content(
                        role="user" if from_user else "model",
                        parts=[Part(text=f"message {i}")],
                    ),
                    timestamp=base + i,
                    invocation_id=f"inv-{i}",
                ),
            )

    async def test_compaction_triggered_above_threshold(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )

        # Four ordinary events, then a fifth whose token count crosses the threshold.
        base = time.time()
        await self._append_history(compaction_service, session, app_name, 4, base)
        trigger_event = Event(
            author=app_name,
            content=Content(role="model", parts=[Part(text="final response")]),
            timestamp=base + 4,
            invocation_id="inv-4",
            usage_metadata=GenerateContentResponseUsageMetadata(
                total_token_count=200,
            ),
        )
        await compaction_service.append_event(session, trigger_event)
        # Drain the background compaction task before inspecting results.
        await compaction_service.close()

        # Summary generation should have been called
        mock_genai_client.aio.models.generate_content.assert_called_once()

        # Fetch session: should have summary + only keep_recent events
        fetched = await compaction_service.get_session(
            app_name=app_name, user_id=user_id, session_id=session.id
        )
        # 2 synthetic summary events + 2 kept real events
        assert len(fetched.events) == 4
        assert fetched.events[0].id == "summary-context"
        assert fetched.events[1].id == "summary-ack"
        assert "Summary of the conversation" in fetched.events[0].content.parts[0].text

    async def test_no_compaction_below_threshold(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        # Enough history that a (wrong) compaction would not skip for lack of events.
        base = time.time()
        await self._append_history(compaction_service, session, app_name, 4, base)

        event = Event(
            author=app_name,
            content=Content(role="model", parts=[Part(text="short reply")]),
            timestamp=base + 4,
            invocation_id="inv-4",
            usage_metadata=GenerateContentResponseUsageMetadata(
                total_token_count=50,
            ),
        )
        await compaction_service.append_event(session, event)
        # Drain any background task so a late compaction would be observed.
        await compaction_service.close()

        mock_genai_client.aio.models.generate_content.assert_not_called()

    async def test_no_compaction_without_usage_metadata(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        # Enough history that a (wrong) compaction would not skip for lack of events.
        base = time.time()
        await self._append_history(compaction_service, session, app_name, 4, base)

        event = Event(
            author="user",
            content=Content(role="user", parts=[Part(text="hello")]),
            timestamp=base + 4,
            invocation_id="inv-4",
        )
        await compaction_service.append_event(session, event)
        # Drain any background task so a late compaction would be observed.
        await compaction_service.close()

        mock_genai_client.aio.models.generate_content.assert_not_called()
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# Compaction edge cases: too few events to compact, summary-generation failure
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
class TestCompactionEdgeCases:
    """Direct ``_compact_session`` calls that must skip or fail without losing events."""

    @staticmethod
    async def _append_user_messages(service, session, count):
        """Append ``count`` user text events ``msg 0`` .. ``msg {count - 1}``."""
        for n in range(count):
            message = Event(
                author="user",
                content=Content(role="user", parts=[Part(text=f"msg {n}")]),
                timestamp=time.time() + n,
                invocation_id=f"inv-{n}",
            )
            await service.append_event(session, message)

    @staticmethod
    async def _compact_now(service, app_name, user_id, session):
        """Call the compactor directly, regardless of whether the threshold fired."""
        await service._compactor._compact_session(
            session,
            service._events_col(app_name, user_id, session.id),
            service._session_ref(app_name, user_id, session.id),
        )

    async def test_skip_when_fewer_events_than_keep_recent(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        # keep_recent is 2 and only 2 events exist, so nothing is older than
        # the retained window and there is nothing to summarize.
        await self._append_user_messages(compaction_service, session, 2)

        await self._compact_now(compaction_service, app_name, user_id, session)

        mock_genai_client.aio.models.generate_content.assert_not_called()

    async def test_summary_generation_failure_is_non_fatal(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        await self._append_user_messages(compaction_service, session, 5)

        # The summarization call now blows up.
        mock_genai_client.aio.models.generate_content = AsyncMock(
            side_effect=RuntimeError("API error")
        )

        # The error must be contained inside _compact_session ...
        await self._compact_now(compaction_service, app_name, user_id, session)

        # ... and the original history left intact.
        reloaded = await compaction_service.get_session(
            app_name=app_name, user_id=user_id, session_id=session.id
        )
        assert len(reloaded.events) == 5
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# get_session and synthetic summary events
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
class TestGetSessionWithSummary:
    """``get_session`` output with respect to synthetic summary events."""

    async def test_no_summary_no_synthetic_events(
        self, compaction_service, app_name, user_id
    ):
        created = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        greeting = Event(
            author="user",
            content=Content(role="user", parts=[Part(text="hello")]),
            timestamp=time.time(),
            invocation_id="inv-1",
        )
        await compaction_service.append_event(created, greeting)

        reloaded = await compaction_service.get_session(
            app_name=app_name, user_id=user_id, session_id=created.id
        )
        # Never compacted: only the single real event comes back.
        events = reloaded.events
        assert len(events) == 1
        assert events[0].author == "user"
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# _events_to_text
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
class TestEventsToText:
    """``SessionCompactor._events_to_text`` renders events as a plain-text transcript."""

    async def test_formats_user_and_assistant(self):
        user_turn = Event(
            author="user",
            content=Content(role="user", parts=[Part(text="Hi there")]),
            timestamp=1.0,
            invocation_id="inv-1",
        )
        # The author name is arbitrary here; the "model" role is what is
        # expected to come out labelled as "Assistant".
        model_turn = Event(
            author="bot",
            content=Content(role="model", parts=[Part(text="Hello!")]),
            timestamp=2.0,
            invocation_id="inv-2",
        )

        transcript = SessionCompactor._events_to_text([user_turn, model_turn])

        for expected_line in ("User: Hi there", "Assistant: Hello!"):
            assert expected_line in transcript

    async def test_skips_events_without_text(self):
        # An event with no content contributes nothing to the transcript.
        empty_event = Event(author="user", timestamp=1.0, invocation_id="inv-1")
        assert SessionCompactor._events_to_text([empty_event]) == ""
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# Firestore distributed lock
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
class TestCompactionLock:
    """``_try_claim_compaction_txn``: the ``compaction_lock`` field on the session doc."""

    @staticmethod
    async def _claim(service, ref):
        """Make one claim attempt in a fresh transaction; return whether it succeeded."""
        return await _try_claim_compaction_txn(service._db.transaction(), ref)

    async def test_claim_and_release(
        self, compaction_service, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        ref = compaction_service._session_ref(app_name, user_id, session.id)

        # First claimant wins.
        assert await self._claim(compaction_service, ref) is True
        # While the lock is held, a second claimant is refused.
        assert await self._claim(compaction_service, ref) is False
        # Clearing the field releases the lock ...
        await ref.update({"compaction_lock": None})
        # ... after which it can be claimed again.
        assert await self._claim(compaction_service, ref) is True

    async def test_stale_lock_can_be_reclaimed(
        self, compaction_service, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        ref = compaction_service._session_ref(app_name, user_id, session.id)

        # A lock stamped 600 s in the past is meant to be older than the lock TTL.
        await ref.update({"compaction_lock": time.time() - 600})

        assert await self._claim(compaction_service, ref) is True

    async def test_claim_nonexistent_session(self, compaction_service):
        # No session document exists, so there is nothing to lock.
        missing = compaction_service._session_ref("no_app", "no_user", "no_id")
        assert await self._claim(compaction_service, missing) is False
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# Guarded compact
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
class TestGuardedCompact:
    """``SessionCompactor.guarded_compact``: lock handling around ``_compact_session``.

    Covers three things:

    * skipping when the in-process lock is already held;
    * skipping when a fresh Firestore lock is already held;
    * staying non-fatal when claiming, compacting, or releasing fails.

    The failure tests also assert that the failing path was actually
    reached. Otherwise they could pass without exercising it.
    """

    @staticmethod
    async def _seed_user_events(service, session, count=5):
        """Append ``count`` user text events so the session has history to compact."""
        for i in range(count):
            e = Event(
                author="user",
                content=Content(role="user", parts=[Part(text=f"msg {i}")]),
                timestamp=time.time() + i,
                invocation_id=f"inv-{i}",
            )
            await service.append_event(session, e)

    async def test_local_lock_skips_concurrent(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        await self._seed_user_events(compaction_service, session)

        # Hold the in-process lock so guarded_compact skips.
        key = f"{app_name}__{user_id}__{session.id}"
        lock = compaction_service._compactor._compaction_locks.setdefault(
            key, asyncio.Lock()
        )
        events_ref = compaction_service._events_col(app_name, user_id, session.id)
        session_ref = compaction_service._session_ref(app_name, user_id, session.id)
        async with lock:
            await compaction_service._compactor.guarded_compact(
                session, events_ref, session_ref
            )

        mock_genai_client.aio.models.generate_content.assert_not_called()

    async def test_firestore_lock_held_skips(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        await self._seed_user_events(compaction_service, session)

        # Set a fresh Firestore lock (simulating another instance).
        session_ref = compaction_service._session_ref(app_name, user_id, session.id)
        await session_ref.update({"compaction_lock": time.time()})

        events_ref = compaction_service._events_col(app_name, user_id, session.id)
        await compaction_service._compactor.guarded_compact(
            session, events_ref, session_ref
        )

        mock_genai_client.aio.models.generate_content.assert_not_called()

    async def test_claim_failure_logs_and_skips(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )

        # A failing claim must not propagate, and compaction must not run.
        with patch(
            "va_agent.compaction._try_claim_compaction_txn",
            side_effect=RuntimeError("Firestore down"),
        ):
            events_ref = compaction_service._events_col(app_name, user_id, session.id)
            session_ref = compaction_service._session_ref(
                app_name, user_id, session.id
            )
            await compaction_service._compactor.guarded_compact(
                session, events_ref, session_ref
            )

        mock_genai_client.aio.models.generate_content.assert_not_called()

    async def test_compaction_failure_releases_lock(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        await self._seed_user_events(compaction_service, session)

        # Make _compact_session raise an unhandled exception.
        with patch.object(
            compaction_service._compactor,
            "_compact_session",
            side_effect=RuntimeError("unexpected crash"),
        ) as crashing_compact:
            events_ref = compaction_service._events_col(app_name, user_id, session.id)
            session_ref = compaction_service._session_ref(
                app_name, user_id, session.id
            )
            await compaction_service._compactor.guarded_compact(
                session, events_ref, session_ref
            )

        # The lock must have been claimed and compaction attempted. Otherwise
        # the "released" check below would pass on a lock that was never set.
        crashing_compact.assert_called_once()

        # Lock should be released even after failure.
        snap = await session_ref.get()
        assert snap.to_dict().get("compaction_lock") is None

    async def test_lock_release_failure_is_non_fatal(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )

        original_session_ref = compaction_service._session_ref
        # Every update() payload touching compaction_lock, recorded so we can
        # prove the failing path was actually exercised.
        lock_writes = []

        def patched_session_ref(an, uid, sid):
            ref = original_session_ref(an, uid, sid)
            original_update = ref.update

            async def failing_update(field_updates, *args, **kwargs):
                if "compaction_lock" in field_updates:
                    lock_writes.append(field_updates)
                    raise RuntimeError("Firestore write failed")
                # Forward any extra arguments the real update() accepts.
                return await original_update(field_updates, *args, **kwargs)

            ref.update = failing_update
            return ref

        with patch.object(
            compaction_service,
            "_session_ref",
            side_effect=patched_session_ref,
        ):
            events_ref = compaction_service._events_col(app_name, user_id, session.id)
            session_ref = compaction_service._session_ref(
                app_name, user_id, session.id
            )
            # Should not raise despite the lock write failing.
            await compaction_service._compactor.guarded_compact(
                session, events_ref, session_ref
            )

        assert lock_writes, "guarded_compact never wrote compaction_lock via update()"
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# close()
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
class TestClose:
    """``close()`` waits for background compaction tasks and empties ``_active_tasks``."""

    async def test_close_no_tasks(self, compaction_service):
        # Nothing has been scheduled; close() must simply return.
        await compaction_service.close()

    async def test_close_awaits_tasks(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        start = time.time()

        # Four plain user turns followed by one token-heavy model turn.
        history = [
            Event(
                author="user",
                content=Content(role="user", parts=[Part(text=f"msg {k}")]),
                timestamp=start + k,
                invocation_id=f"inv-{k}",
            )
            for k in range(4)
        ]
        history.append(
            Event(
                author=app_name,
                content=Content(role="model", parts=[Part(text="trigger")]),
                timestamp=start + 4,
                invocation_id="inv-4",
                usage_metadata=GenerateContentResponseUsageMetadata(
                    total_token_count=200,
                ),
            )
        )
        for event in history:
            await compaction_service.append_event(session, event)

        # The final event crossed the threshold, so a task is now in flight ...
        assert len(compaction_service._active_tasks) > 0

        await compaction_service.close()

        # ... and close() leaves no tasks behind.
        assert len(compaction_service._active_tasks) == 0
|