Lean MCP implementation

This commit is contained in:
2026-02-23 03:29:21 +00:00
parent 98d23b80e4
commit 1c6d942177
37 changed files with 2380 additions and 3541 deletions

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@

35
tests/conftest.py Normal file
View File

@@ -0,0 +1,35 @@
"""Shared fixtures for Firestore session service tests."""
from __future__ import annotations
import os
import uuid
import pytest
import pytest_asyncio
from google.cloud.firestore_v1.async_client import AsyncClient
from va_agent.session import FirestoreSessionService
os.environ.setdefault("FIRESTORE_EMULATOR_HOST", "localhost:8153")
@pytest_asyncio.fixture
async def db():
return AsyncClient(project="test-project")
@pytest_asyncio.fixture
async def service(db: AsyncClient):
prefix = f"test_{uuid.uuid4().hex[:8]}"
return FirestoreSessionService(db=db, collection_prefix=prefix)
@pytest.fixture
def app_name():
return f"app_{uuid.uuid4().hex[:8]}"
@pytest.fixture
def user_id():
return f"user_{uuid.uuid4().hex[:8]}"

515
tests/test_compaction.py Normal file
View File

@@ -0,0 +1,515 @@
"""Tests for conversation compaction in FirestoreSessionService."""
from __future__ import annotations
import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch
import uuid
import pytest
import pytest_asyncio
from google import genai
from google.adk.events.event import Event
from google.cloud.firestore_v1.async_client import AsyncClient
from google.genai.types import Content, GenerateContentResponseUsageMetadata, Part
from va_agent.session import FirestoreSessionService, _try_claim_compaction_txn
pytestmark = pytest.mark.asyncio
@pytest_asyncio.fixture
async def mock_genai_client():
client = MagicMock(spec=genai.Client)
response = MagicMock()
response.text = "Summary of the conversation so far."
client.aio.models.generate_content = AsyncMock(return_value=response)
return client
@pytest_asyncio.fixture
async def compaction_service(db: AsyncClient, mock_genai_client):
prefix = f"test_{uuid.uuid4().hex[:8]}"
return FirestoreSessionService(
db=db,
collection_prefix=prefix,
compaction_token_threshold=100,
compaction_keep_recent=2,
genai_client=mock_genai_client,
)
# ------------------------------------------------------------------
# __init__ validation
# ------------------------------------------------------------------
class TestCompactionInit:
async def test_requires_genai_client(self, db):
with pytest.raises(ValueError, match="genai_client is required"):
FirestoreSessionService(
db=db,
compaction_token_threshold=1000,
)
async def test_no_threshold_no_client_ok(self, db):
svc = FirestoreSessionService(db=db)
assert svc._compaction_threshold is None
# ------------------------------------------------------------------
# Compaction trigger
# ------------------------------------------------------------------
class TestCompactionTrigger:
async def test_compaction_triggered_above_threshold(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
# Add 5 events, last one with usage_metadata above threshold
base = time.time()
for i in range(4):
e = Event(
author="user" if i % 2 == 0 else app_name,
content=Content(
role="user" if i % 2 == 0 else "model",
parts=[Part(text=f"message {i}")],
),
timestamp=base + i,
invocation_id=f"inv-{i}",
)
await compaction_service.append_event(session, e)
# This event crosses the threshold
trigger_event = Event(
author=app_name,
content=Content(
role="model", parts=[Part(text="final response")]
),
timestamp=base + 4,
invocation_id="inv-4",
usage_metadata=GenerateContentResponseUsageMetadata(
total_token_count=200,
),
)
await compaction_service.append_event(session, trigger_event)
await compaction_service.close()
# Summary generation should have been called
mock_genai_client.aio.models.generate_content.assert_called_once()
# Fetch session: should have summary + only keep_recent events
fetched = await compaction_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
# 2 synthetic summary events + 2 kept real events
assert len(fetched.events) == 4
assert fetched.events[0].id == "summary-context"
assert fetched.events[1].id == "summary-ack"
assert "Summary of the conversation" in fetched.events[0].content.parts[0].text
async def test_no_compaction_below_threshold(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author=app_name,
content=Content(
role="model", parts=[Part(text="short reply")]
),
timestamp=time.time(),
invocation_id="inv-1",
usage_metadata=GenerateContentResponseUsageMetadata(
total_token_count=50,
),
)
await compaction_service.append_event(session, event)
mock_genai_client.aio.models.generate_content.assert_not_called()
async def test_no_compaction_without_usage_metadata(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="user",
content=Content(
role="user", parts=[Part(text="hello")]
),
timestamp=time.time(),
invocation_id="inv-1",
)
await compaction_service.append_event(session, event)
mock_genai_client.aio.models.generate_content.assert_not_called()
# ------------------------------------------------------------------
# Compaction with too few events (nothing to compact)
# ------------------------------------------------------------------
class TestCompactionEdgeCases:
async def test_skip_when_fewer_events_than_keep_recent(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
# Only 2 events, keep_recent=2 → nothing to summarize
for i in range(2):
e = Event(
author="user",
content=Content(
role="user", parts=[Part(text=f"msg {i}")]
),
timestamp=time.time() + i,
invocation_id=f"inv-{i}",
)
await compaction_service.append_event(session, e)
# Trigger compaction manually even though threshold wouldn't fire
await compaction_service._compact_session(session)
mock_genai_client.aio.models.generate_content.assert_not_called()
async def test_summary_generation_failure_is_non_fatal(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
for i in range(5):
e = Event(
author="user",
content=Content(
role="user", parts=[Part(text=f"msg {i}")]
),
timestamp=time.time() + i,
invocation_id=f"inv-{i}",
)
await compaction_service.append_event(session, e)
# Make summary generation fail
mock_genai_client.aio.models.generate_content = AsyncMock(
side_effect=RuntimeError("API error")
)
# Should not raise
await compaction_service._compact_session(session)
# All events should still be present
fetched = await compaction_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert len(fetched.events) == 5
# ------------------------------------------------------------------
# get_session with summary
# ------------------------------------------------------------------
class TestGetSessionWithSummary:
async def test_no_summary_no_synthetic_events(
self, compaction_service, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="user",
content=Content(
role="user", parts=[Part(text="hello")]
),
timestamp=time.time(),
invocation_id="inv-1",
)
await compaction_service.append_event(session, event)
fetched = await compaction_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert len(fetched.events) == 1
assert fetched.events[0].author == "user"
# ------------------------------------------------------------------
# _events_to_text
# ------------------------------------------------------------------
class TestEventsToText:
async def test_formats_user_and_assistant(self):
events = [
Event(
author="user",
content=Content(
role="user", parts=[Part(text="Hi there")]
),
timestamp=1.0,
invocation_id="inv-1",
),
Event(
author="bot",
content=Content(
role="model", parts=[Part(text="Hello!")]
),
timestamp=2.0,
invocation_id="inv-2",
),
]
text = FirestoreSessionService._events_to_text(events)
assert "User: Hi there" in text
assert "Assistant: Hello!" in text
async def test_skips_events_without_text(self):
events = [
Event(
author="user",
timestamp=1.0,
invocation_id="inv-1",
),
]
text = FirestoreSessionService._events_to_text(events)
assert text == ""
# ------------------------------------------------------------------
# Firestore distributed lock
# ------------------------------------------------------------------
class TestCompactionLock:
async def test_claim_and_release(
self, compaction_service, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
session_ref = compaction_service._session_ref(
app_name, user_id, session.id
)
# Claim the lock
transaction = compaction_service._db.transaction()
claimed = await _try_claim_compaction_txn(transaction, session_ref)
assert claimed is True
# Lock is now held — second claim should fail
transaction2 = compaction_service._db.transaction()
claimed2 = await _try_claim_compaction_txn(transaction2, session_ref)
assert claimed2 is False
# Release the lock
await session_ref.update({"compaction_lock": None})
# Can claim again after release
transaction3 = compaction_service._db.transaction()
claimed3 = await _try_claim_compaction_txn(transaction3, session_ref)
assert claimed3 is True
async def test_stale_lock_can_be_reclaimed(
self, compaction_service, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
session_ref = compaction_service._session_ref(
app_name, user_id, session.id
)
# Set a stale lock (older than TTL)
await session_ref.update({"compaction_lock": time.time() - 600})
# Should be able to reclaim a stale lock
transaction = compaction_service._db.transaction()
claimed = await _try_claim_compaction_txn(transaction, session_ref)
assert claimed is True
async def test_claim_nonexistent_session(self, compaction_service):
ref = compaction_service._session_ref("no_app", "no_user", "no_id")
transaction = compaction_service._db.transaction()
claimed = await _try_claim_compaction_txn(transaction, ref)
assert claimed is False
# ------------------------------------------------------------------
# Guarded compact
# ------------------------------------------------------------------
class TestGuardedCompact:
async def test_local_lock_skips_concurrent(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
for i in range(5):
e = Event(
author="user",
content=Content(
role="user", parts=[Part(text=f"msg {i}")]
),
timestamp=time.time() + i,
invocation_id=f"inv-{i}",
)
await compaction_service.append_event(session, e)
# Hold the in-process lock so _guarded_compact skips
key = f"{app_name}__{user_id}__{session.id}"
lock = compaction_service._compaction_locks.setdefault(
key, asyncio.Lock()
)
async with lock:
await compaction_service._guarded_compact(session)
mock_genai_client.aio.models.generate_content.assert_not_called()
async def test_firestore_lock_held_skips(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
for i in range(5):
e = Event(
author="user",
content=Content(
role="user", parts=[Part(text=f"msg {i}")]
),
timestamp=time.time() + i,
invocation_id=f"inv-{i}",
)
await compaction_service.append_event(session, e)
# Set a fresh Firestore lock (simulating another instance)
session_ref = compaction_service._session_ref(
app_name, user_id, session.id
)
await session_ref.update({"compaction_lock": time.time()})
await compaction_service._guarded_compact(session)
mock_genai_client.aio.models.generate_content.assert_not_called()
async def test_claim_failure_logs_and_skips(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
with patch(
"va_agent.session._try_claim_compaction_txn",
side_effect=RuntimeError("Firestore down"),
):
await compaction_service._guarded_compact(session)
mock_genai_client.aio.models.generate_content.assert_not_called()
async def test_compaction_failure_releases_lock(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
# Make _compact_session raise an unhandled exception
with patch.object(
compaction_service,
"_compact_session",
side_effect=RuntimeError("unexpected crash"),
):
await compaction_service._guarded_compact(session)
# Lock should be released even after failure
session_ref = compaction_service._session_ref(
app_name, user_id, session.id
)
snap = await session_ref.get()
assert snap.to_dict().get("compaction_lock") is None
async def test_lock_release_failure_is_non_fatal(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
original_session_ref = compaction_service._session_ref
def patched_session_ref(an, uid, sid):
ref = original_session_ref(an, uid, sid)
original_update = ref.update
async def failing_update(data):
if "compaction_lock" in data:
raise RuntimeError("Firestore write failed")
return await original_update(data)
ref.update = failing_update
return ref
with patch.object(
compaction_service,
"_session_ref",
side_effect=patched_session_ref,
):
# Should not raise despite lock release failure
await compaction_service._guarded_compact(session)
# ------------------------------------------------------------------
# close()
# ------------------------------------------------------------------
class TestClose:
async def test_close_no_tasks(self, compaction_service):
await compaction_service.close()
async def test_close_awaits_tasks(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
base = time.time()
for i in range(4):
e = Event(
author="user",
content=Content(
role="user", parts=[Part(text=f"msg {i}")]
),
timestamp=base + i,
invocation_id=f"inv-{i}",
)
await compaction_service.append_event(session, e)
trigger = Event(
author=app_name,
content=Content(
role="model", parts=[Part(text="trigger")]
),
timestamp=base + 4,
invocation_id="inv-4",
usage_metadata=GenerateContentResponseUsageMetadata(
total_token_count=200,
),
)
await compaction_service.append_event(session, trigger)
assert len(compaction_service._active_tasks) > 0
await compaction_service.close()
assert len(compaction_service._active_tasks) == 0

View File

@@ -0,0 +1,428 @@
"""Tests for FirestoreSessionService against the Firestore emulator."""
from __future__ import annotations
import time
import uuid
import pytest
from google.adk.errors.already_exists_error import AlreadyExistsError
from google.adk.events.event import Event
from google.adk.events.event_actions import EventActions
from google.adk.sessions.base_session_service import GetSessionConfig
from google.genai.types import Content, Part
pytestmark = pytest.mark.asyncio
# ------------------------------------------------------------------
# create_session
# ------------------------------------------------------------------
class TestCreateSession:
async def test_auto_generates_id(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
assert session.id
assert session.app_name == app_name
assert session.user_id == user_id
assert session.last_update_time > 0
async def test_custom_id(self, service, app_name, user_id):
sid = "my-custom-session"
session = await service.create_session(
app_name=app_name, user_id=user_id, session_id=sid
)
assert session.id == sid
async def test_duplicate_id_raises(self, service, app_name, user_id):
sid = "dup-session"
await service.create_session(
app_name=app_name, user_id=user_id, session_id=sid
)
with pytest.raises(AlreadyExistsError):
await service.create_session(
app_name=app_name, user_id=user_id, session_id=sid
)
async def test_session_state(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name,
user_id=user_id,
state={"count": 42},
)
assert session.state["count"] == 42
async def test_scoped_state(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name,
user_id=user_id,
state={
"app:global_flag": True,
"user:lang": "es",
"local_key": "val",
},
)
assert session.state["app:global_flag"] is True
assert session.state["user:lang"] == "es"
assert session.state["local_key"] == "val"
async def test_temp_state_not_persisted(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name,
user_id=user_id,
state={"temp:scratch": "gone", "keep": "yes"},
)
retrieved = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert "temp:scratch" not in retrieved.state
assert retrieved.state["keep"] == "yes"
# ------------------------------------------------------------------
# get_session
# ------------------------------------------------------------------
class TestGetSession:
async def test_nonexistent_returns_none(self, service, app_name, user_id):
result = await service.get_session(
app_name=app_name, user_id=user_id, session_id="nope"
)
assert result is None
async def test_roundtrip(self, service, app_name, user_id):
created = await service.create_session(
app_name=app_name,
user_id=user_id,
state={"foo": "bar"},
)
fetched = await service.get_session(
app_name=app_name, user_id=user_id, session_id=created.id
)
assert fetched is not None
assert fetched.id == created.id
assert fetched.state["foo"] == "bar"
assert fetched.last_update_time == pytest.approx(
created.last_update_time, abs=0.01
)
async def test_returns_events(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="user",
content=Content(parts=[Part(text="hello")]),
timestamp=time.time(),
invocation_id="inv-1",
)
await service.append_event(session, event)
fetched = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert len(fetched.events) == 1
assert fetched.events[0].author == "user"
async def test_num_recent_events(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
for i in range(5):
e = Event(
author="user",
timestamp=time.time() + i,
invocation_id=f"inv-{i}",
)
await service.append_event(session, e)
fetched = await service.get_session(
app_name=app_name,
user_id=user_id,
session_id=session.id,
config=GetSessionConfig(num_recent_events=2),
)
assert len(fetched.events) == 2
async def test_after_timestamp(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
base = time.time()
for i in range(3):
e = Event(
author="user",
timestamp=base + i,
invocation_id=f"inv-{i}",
)
await service.append_event(session, e)
fetched = await service.get_session(
app_name=app_name,
user_id=user_id,
session_id=session.id,
config=GetSessionConfig(after_timestamp=base + 1),
)
assert len(fetched.events) == 2
# ------------------------------------------------------------------
# list_sessions
# ------------------------------------------------------------------
class TestListSessions:
async def test_empty(self, service, app_name, user_id):
resp = await service.list_sessions(
app_name=app_name, user_id=user_id
)
assert resp.sessions == [] or resp.sessions is None
async def test_returns_created_sessions(
self, service, app_name, user_id
):
s1 = await service.create_session(
app_name=app_name, user_id=user_id
)
s2 = await service.create_session(
app_name=app_name, user_id=user_id
)
resp = await service.list_sessions(
app_name=app_name, user_id=user_id
)
ids = {s.id for s in resp.sessions}
assert s1.id in ids
assert s2.id in ids
async def test_filter_by_user(self, service, app_name):
uid1 = f"user_{uuid.uuid4().hex[:8]}"
uid2 = f"user_{uuid.uuid4().hex[:8]}"
await service.create_session(app_name=app_name, user_id=uid1)
await service.create_session(app_name=app_name, user_id=uid2)
resp = await service.list_sessions(
app_name=app_name, user_id=uid1
)
assert len(resp.sessions) == 1
assert resp.sessions[0].user_id == uid1
async def test_sessions_have_merged_state(
self, service, app_name, user_id
):
await service.create_session(
app_name=app_name,
user_id=user_id,
state={"app:shared": "yes", "local": "val"},
)
resp = await service.list_sessions(
app_name=app_name, user_id=user_id
)
s = resp.sessions[0]
assert s.state["app:shared"] == "yes"
assert s.state["local"] == "val"
async def test_sessions_have_no_events(
self, service, app_name, user_id
):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="user", timestamp=time.time(), invocation_id="inv-1"
)
await service.append_event(session, event)
resp = await service.list_sessions(
app_name=app_name, user_id=user_id
)
assert resp.sessions[0].events == []
# ------------------------------------------------------------------
# delete_session
# ------------------------------------------------------------------
class TestDeleteSession:
async def test_delete(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
await service.delete_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
result = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert result is None
async def test_delete_removes_events(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="user", timestamp=time.time(), invocation_id="inv-1"
)
await service.append_event(session, event)
await service.delete_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
result = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert result is None
# ------------------------------------------------------------------
# append_event
# ------------------------------------------------------------------
class TestAppendEvent:
async def test_basic(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="user",
content=Content(parts=[Part(text="hi")]),
timestamp=time.time(),
invocation_id="inv-1",
)
returned = await service.append_event(session, event)
assert returned.id == event.id
assert returned.timestamp > 0
async def test_partial_event_not_persisted(
self, service, app_name, user_id
):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="user",
partial=True,
timestamp=time.time(),
invocation_id="inv-1",
)
await service.append_event(session, event)
fetched = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert len(fetched.events) == 0
async def test_session_state_delta(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="agent",
actions=EventActions(state_delta={"counter": 1}),
timestamp=time.time(),
invocation_id="inv-1",
)
await service.append_event(session, event)
fetched = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert fetched.state["counter"] == 1
async def test_app_state_delta(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="agent",
actions=EventActions(state_delta={"app:version": "2.0"}),
timestamp=time.time(),
invocation_id="inv-1",
)
await service.append_event(session, event)
fetched = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert fetched.state["app:version"] == "2.0"
async def test_user_state_delta(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="agent",
actions=EventActions(state_delta={"user:pref": "dark"}),
timestamp=time.time(),
invocation_id="inv-1",
)
await service.append_event(session, event)
fetched = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert fetched.state["user:pref"] == "dark"
async def test_updates_last_update_time(
self, service, app_name, user_id
):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
original_time = session.last_update_time
event = Event(
author="user",
timestamp=time.time() + 10,
invocation_id="inv-1",
)
await service.append_event(session, event)
fetched = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert fetched.last_update_time > original_time
async def test_multiple_events_accumulate(
self, service, app_name, user_id
):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
for i in range(3):
e = Event(
author="user",
content=Content(parts=[Part(text=f"msg {i}")]),
timestamp=time.time() + i,
invocation_id=f"inv-{i}",
)
await service.append_event(session, e)
fetched = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert len(fetched.events) == 3
async def test_app_state_shared_across_sessions(
self, service, app_name, user_id
):
s1 = await service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="agent",
actions=EventActions(state_delta={"app:shared_val": 99}),
timestamp=time.time(),
invocation_id="inv-1",
)
await service.append_event(s1, event)
s2 = await service.create_session(
app_name=app_name, user_id=user_id
)
assert s2.state["app:shared_val"] == 99