diff --git a/.coverage b/.coverage deleted file mode 100644 index dab01fc..0000000 Binary files a/.coverage and /dev/null differ diff --git a/chat.py b/chat.py index bb37377..42346a2 100644 --- a/chat.py +++ b/chat.py @@ -3,6 +3,7 @@ import asyncio +from google import genai from google.adk.agents import LlmAgent from google.adk.runners import Runner from google.cloud.firestore_v1.async_client import AsyncClient @@ -22,7 +23,11 @@ root_agent = LlmAgent( async def main() -> None: db = AsyncClient() - session_service = FirestoreSessionService(db=db) + session_service = FirestoreSessionService( + db=db, + compaction_token_threshold=800_000, + genai_client=genai.Client(), + ) runner = Runner( app_name=APP_NAME, @@ -30,11 +35,25 @@ async def main() -> None: session_service=session_service, ) - session = await session_service.create_session( - app_name=APP_NAME, - user_id=USER_ID, + # Reuse existing session or create a new one + resp = await session_service.list_sessions( + app_name=APP_NAME, user_id=USER_ID ) - print(f"Session {session.id} created. Type 'exit' to quit.\n") + if resp.sessions: + session = await session_service.get_session( + app_name=APP_NAME, + user_id=USER_ID, + session_id=resp.sessions[0].id, + ) + print(f"Resuming session {session.id}.") + else: + session = await session_service.create_session( + app_name=APP_NAME, + user_id=USER_ID, + ) + print(f"Session {session.id} created.") + + print("Type 'exit' to quit.\n") while True: try: diff --git a/src/adk_firestore_sessionmanager/firestore_session_service.py b/src/adk_firestore_sessionmanager/firestore_session_service.py index f1fa67c..c0c4071 100644 --- a/src/adk_firestore_sessionmanager/firestore_session_service.py +++ b/src/adk_firestore_sessionmanager/firestore_session_service.py @@ -7,6 +7,7 @@ import time from typing import Any, Optional import uuid +from google import genai from google.adk.errors.already_exists_error import AlreadyExistsError from google.adk.events.event import Event from google.adk.sessions import _session_util @@ -19,6 +20,7 @@ from google.adk.sessions.session import Session from google.adk.sessions.state import State from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.field_path import FieldPath +from google.genai.types import Content, Part from typing_extensions import override logger = logging.getLogger("google_adk." + __name__) @@ -45,9 +47,21 @@ class FirestoreSessionService(BaseSessionService): *, db: AsyncClient, collection_prefix: str = "adk", + compaction_token_threshold: int | None = None, + compaction_model: str = "gemini-2.5-flash", + compaction_keep_recent: int = 10, + genai_client: genai.Client | None = None, ) -> None: + if compaction_token_threshold is not None and genai_client is None: + raise ValueError( + "genai_client is required when compaction_token_threshold is set." + ) self._db = db self._prefix = collection_prefix + self._compaction_threshold = compaction_token_threshold + self._compaction_model = compaction_model + self._compaction_keep_recent = compaction_keep_recent + self._genai_client = genai_client # ------------------------------------------------------------------ # Document-reference helpers @@ -100,6 +114,100 @@ class FirestoreSessionService(BaseSessionService): merged[State.USER_PREFIX + key] = value return merged + # ------------------------------------------------------------------ + # Compaction helpers + # ------------------------------------------------------------------ + + @staticmethod + def _events_to_text(events: list[Event]) -> str: + lines: list[str] = [] + for event in events: + if event.content and event.content.parts: + text = "".join(p.text or "" for p in event.content.parts) + if text: + role = "User" if event.author == "user" else "Assistant" + lines.append(f"{role}: {text}") + return "\n\n".join(lines) + + async def _generate_summary( + self, existing_summary: str, events: list[Event] + ) -> str: + conversation_text = self._events_to_text(events) + previous = ( + "Previous summary of earlier conversation:\n" + f"{existing_summary}\n\n" + if existing_summary + else "" + ) + prompt = ( + "Summarize the following conversation between a user and an " + "assistant. Preserve:\n" + "- Key decisions and conclusions\n" + "- User preferences and requirements\n" + "- Important facts, names, and numbers\n" + "- The overall topic and direction of the conversation\n" + "- Any pending tasks or open questions\n\n" + f"{previous}" + f"Conversation:\n{conversation_text}\n\n" + "Provide a clear, comprehensive summary." + ) + assert self._genai_client is not None + response = await self._genai_client.aio.models.generate_content( + model=self._compaction_model, + contents=prompt, + ) + return response.text or "" + + async def _compact_session(self, session: Session) -> None: + app_name = session.app_name + user_id = session.user_id + session_id = session.id + + events_ref = self._events_col(app_name, user_id, session_id) + query = events_ref.order_by("timestamp") + event_docs = await query.get() + + if len(event_docs) <= self._compaction_keep_recent: + return + + all_events = [ + Event.model_validate(doc.to_dict()) for doc in event_docs + ] + events_to_summarize = all_events[: -self._compaction_keep_recent] + + session_snap = await self._session_ref( + app_name, user_id, session_id + ).get() + existing_summary = (session_snap.to_dict() or {}).get( + "conversation_summary", "" + ) + + try: + summary = await self._generate_summary( + existing_summary, events_to_summarize + ) + except Exception: + logger.exception("Compaction summary generation failed; skipping.") + return + + docs_to_delete = event_docs[: -self._compaction_keep_recent] + for i in range(0, len(docs_to_delete), 500): + batch = self._db.batch() + for doc in docs_to_delete[i : i + 500]: + batch.delete(doc.reference) + await batch.commit() + + await self._session_ref(app_name, user_id, session_id).update( + {"conversation_summary": summary} + ) + + logger.info( + "Compacted session %s: summarised %d events, kept %d.", + session_id, + len(docs_to_delete), + self._compaction_keep_recent, + ) + # ------------------------------------------------------------------ # BaseSessionService implementation # ------------------------------------------------------------------ @@ -190,6 +298,47 @@ class FirestoreSessionService(BaseSessionService): if config and config.num_recent_events: events = events[-config.num_recent_events :] + # Prepend conversation summary as synthetic context events + conversation_summary = session_data.get("conversation_summary") + if conversation_summary: + summary_event = Event( + id="summary-context", + author="user", + content=Content( + role="user", + parts=[ + Part( + text=( + "[Conversation context from previous" + " messages]\n" + f"{conversation_summary}" + ) + ) + ], + ), + timestamp=0.0, + invocation_id="compaction-summary", + ) + ack_event = Event( + id="summary-ack", + author=app_name, + content=Content( + role="model", + parts=[ + Part( + text=( + "Understood, I have the context from our" + " previous conversation and will continue" + " accordingly." + ) + ) + ], + ), + timestamp=0.001, + invocation_id="compaction-summary", + ) + events = [summary_event, ack_event] + events + # Merge scoped state app_state = await self._get_app_state(app_name) user_state = await self._get_user_state(app_name, user_id) @@ -313,4 +462,14 @@ class FirestoreSessionService(BaseSessionService): else: await session_ref.update({"last_update_time": event.timestamp}) + # Trigger compaction if total token count exceeds threshold + if ( + self._compaction_threshold is not None + and event.usage_metadata + and event.usage_metadata.total_token_count + and event.usage_metadata.total_token_count + >= self._compaction_threshold + ): + await self._compact_session(session) + return event diff --git a/tests/conftest.py b/tests/conftest.py index 1ca3382..93e9424 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,7 @@ from google.cloud.firestore_v1.async_client import AsyncClient from adk_firestore_sessionmanager import FirestoreSessionService -os.environ.setdefault("FIRESTORE_EMULATOR_HOST", "localhost:8161") +os.environ.setdefault("FIRESTORE_EMULATOR_HOST", "localhost:8219") @pytest_asyncio.fixture diff --git a/tests/test_compaction.py b/tests/test_compaction.py new file mode 100644 index 0000000..788f937 --- /dev/null +++ b/tests/test_compaction.py @@ -0,0 +1,282 @@ +"""Tests for conversation compaction in FirestoreSessionService.""" + +from __future__ import annotations + +import time +from unittest.mock import AsyncMock, MagicMock +import uuid + +import pytest +import pytest_asyncio +from google import genai +from google.adk.events.event import Event +from google.cloud.firestore_v1.async_client import AsyncClient +from google.genai.types import Content, GenerateContentResponseUsageMetadata, Part + +from adk_firestore_sessionmanager import FirestoreSessionService + +pytestmark = pytest.mark.asyncio + + +@pytest_asyncio.fixture +async def mock_genai_client(): + client = MagicMock(spec=genai.Client) + response = MagicMock() + response.text = "Summary of the conversation so far." + client.aio.models.generate_content = AsyncMock(return_value=response) + return client + + +@pytest_asyncio.fixture +async def compaction_service(db: AsyncClient, mock_genai_client): + prefix = f"test_{uuid.uuid4().hex[:8]}" + return FirestoreSessionService( + db=db, + collection_prefix=prefix, + compaction_token_threshold=100, + compaction_keep_recent=2, + genai_client=mock_genai_client, + ) + + +# ------------------------------------------------------------------ +# __init__ validation +# ------------------------------------------------------------------ + + +class TestCompactionInit: + async def test_requires_genai_client(self, db): + with pytest.raises(ValueError, match="genai_client is required"): + FirestoreSessionService( + db=db, + compaction_token_threshold=1000, + ) + + async def test_no_threshold_no_client_ok(self, db): + svc = FirestoreSessionService(db=db) + assert svc._compaction_threshold is None + + +# ------------------------------------------------------------------ +# Compaction trigger +# ------------------------------------------------------------------ + + +class TestCompactionTrigger: + async def test_compaction_triggered_above_threshold( + self, compaction_service, mock_genai_client, app_name, user_id + ): + session = await compaction_service.create_session( + app_name=app_name, user_id=user_id + ) + + # Add 5 events, last one with usage_metadata above threshold + base = time.time() + for i in range(4): + e = Event( + author="user" if i % 2 == 0 else app_name, + content=Content( + role="user" if i % 2 == 0 else "model", + parts=[Part(text=f"message {i}")], + ), + timestamp=base + i, + invocation_id=f"inv-{i}", + ) + await compaction_service.append_event(session, e) + + # This event crosses the threshold + trigger_event = Event( + author=app_name, + content=Content( + role="model", parts=[Part(text="final response")] + ), + timestamp=base + 4, + invocation_id="inv-4", + usage_metadata=GenerateContentResponseUsageMetadata( + total_token_count=200, + ), + ) + await compaction_service.append_event(session, trigger_event) + + # Summary generation should have been called + mock_genai_client.aio.models.generate_content.assert_called_once() + + # Fetch session: should have summary + only keep_recent events + fetched = await compaction_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + # 2 synthetic summary events + 2 kept real events + assert len(fetched.events) == 4 + assert fetched.events[0].id == "summary-context" + assert fetched.events[1].id == "summary-ack" + assert "Summary of the conversation" in fetched.events[0].content.parts[0].text + + async def test_no_compaction_below_threshold( + self, compaction_service, mock_genai_client, app_name, user_id + ): + session = await compaction_service.create_session( + app_name=app_name, user_id=user_id + ) + event = Event( + author=app_name, + content=Content( + role="model", parts=[Part(text="short reply")] + ), + timestamp=time.time(), + invocation_id="inv-1", + usage_metadata=GenerateContentResponseUsageMetadata( + total_token_count=50, + ), + ) + await compaction_service.append_event(session, event) + + mock_genai_client.aio.models.generate_content.assert_not_called() + + async def test_no_compaction_without_usage_metadata( + self, compaction_service, mock_genai_client, app_name, user_id + ): + session = await compaction_service.create_session( + app_name=app_name, user_id=user_id + ) + event = Event( + author="user", + content=Content( + role="user", parts=[Part(text="hello")] + ), + timestamp=time.time(), + invocation_id="inv-1", + ) + await compaction_service.append_event(session, event) + + mock_genai_client.aio.models.generate_content.assert_not_called() + + +# ------------------------------------------------------------------ +# Compaction with too few events (nothing to compact) +# ------------------------------------------------------------------ + + +class TestCompactionEdgeCases: + async def test_skip_when_fewer_events_than_keep_recent( + self, compaction_service, mock_genai_client, app_name, user_id + ): + session = await compaction_service.create_session( + app_name=app_name, user_id=user_id + ) + # Only 2 events, keep_recent=2 → nothing to summarize + for i in range(2): + e = Event( + author="user", + content=Content( + role="user", parts=[Part(text=f"msg {i}")] + ), + timestamp=time.time() + i, + invocation_id=f"inv-{i}", + ) + await compaction_service.append_event(session, e) + + # Trigger compaction manually even though threshold wouldn't fire + await compaction_service._compact_session(session) + + mock_genai_client.aio.models.generate_content.assert_not_called() + + async def test_summary_generation_failure_is_non_fatal( + self, compaction_service, mock_genai_client, app_name, user_id + ): + session = await compaction_service.create_session( + app_name=app_name, user_id=user_id + ) + for i in range(5): + e = Event( + author="user", + content=Content( + role="user", parts=[Part(text=f"msg {i}")] + ), + timestamp=time.time() + i, + invocation_id=f"inv-{i}", + ) + await compaction_service.append_event(session, e) + + # Make summary generation fail + mock_genai_client.aio.models.generate_content = AsyncMock( + side_effect=RuntimeError("API error") + ) + + # Should not raise + await compaction_service._compact_session(session) + + # All events should still be present + fetched = await compaction_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert len(fetched.events) == 5 + + +# ------------------------------------------------------------------ +# get_session with summary +# ------------------------------------------------------------------ + + +class TestGetSessionWithSummary: + async def test_no_summary_no_synthetic_events( + self, compaction_service, app_name, user_id + ): + session = await compaction_service.create_session( + app_name=app_name, user_id=user_id + ) + event = Event( + author="user", + content=Content( + role="user", parts=[Part(text="hello")] + ), + timestamp=time.time(), + invocation_id="inv-1", + ) + await compaction_service.append_event(session, event) + + fetched = await compaction_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert len(fetched.events) == 1 + assert fetched.events[0].author == "user" + + +# ------------------------------------------------------------------ +# _events_to_text +# ------------------------------------------------------------------ + + +class TestEventsToText: + def test_formats_user_and_assistant(self): + events = [ + Event( + author="user", + content=Content( + role="user", parts=[Part(text="Hi there")] + ), + timestamp=1.0, + invocation_id="inv-1", + ), + Event( + author="bot", + content=Content( + role="model", parts=[Part(text="Hello!")] + ), + timestamp=2.0, + invocation_id="inv-2", + ), + ] + text = FirestoreSessionService._events_to_text(events) + assert "User: Hi there" in text + assert "Assistant: Hello!" in text + + def test_skips_events_without_text(self): + events = [ + Event( + author="user", + timestamp=1.0, + invocation_id="inv-1", + ), + ] + text = FirestoreSessionService._events_to_text(events) + assert text == "" diff --git a/view_summary.py b/view_summary.py new file mode 100644 index 0000000..e60c73e --- /dev/null +++ b/view_summary.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +"""Print the conversation summary for a specific user's session.""" + +import asyncio + +from google.cloud.firestore_v1.async_client import AsyncClient + +from adk_firestore_sessionmanager import FirestoreSessionService + +APP_NAME = "test_agent" +USER_ID = "dev_user" + + +async def main() -> None: + db = AsyncClient() + session_service = FirestoreSessionService(db=db) + + resp = await session_service.list_sessions( + app_name=APP_NAME, user_id=USER_ID + ) + + if not resp.sessions: + print("No sessions found.") + return + + for s in resp.sessions: + ref = session_service._session_ref(APP_NAME, USER_ID, s.id) + snap = await ref.get() + data = snap.to_dict() or {} + summary = data.get("conversation_summary") + + print(f"Session: {s.id}") + if summary: + print(f"Summary:\n{summary}") + else: + print("No summary yet.") + print() + + +if __name__ == "__main__": + asyncio.run(main())