diff --git a/README.md b/README.md index d154857..948f12b 100644 --- a/README.md +++ b/README.md @@ -90,3 +90,23 @@ For open source projects, say how it is licensed. ## Project status If you have run out of energy or time for your project, put a note at the top of the README saying that development has slowed down or stopped completely. Someone may choose to fork your project or volunteer to step in as a maintainer or owner, allowing your project to keep going. You can also make an explicit request for maintainers. + +## Tests +### Compaction +Follow these steps before running the compaction test suite: + +1. Install the required dependencies (Java and Google Cloud CLI): + ```bash + mise use -g gcloud + mise use -g java + ``` +2. Open another terminal (or create a `tmux` pane) and start the Firestore emulator: + ```bash + gcloud emulators firestore start --host-port=localhost:8153 + ``` +3. Execute the tests with `pytest` through `uv`: + ```bash + uv run pytest tests/test_compaction.py -v + ``` + +If any step fails, double-check that the tools are installed and available on your `PATH` before trying again. diff --git a/src/va_agent/compaction.py b/src/va_agent/compaction.py new file mode 100644 index 0000000..e752ffd --- /dev/null +++ b/src/va_agent/compaction.py @@ -0,0 +1,213 @@ +"""Session compaction utilities for managing conversation history.""" + +from __future__ import annotations + +import asyncio +import logging +import time +from typing import TYPE_CHECKING, Any + +from google.adk.events.event import Event +from google.cloud.firestore_v1.async_transaction import async_transactional + +if TYPE_CHECKING: + from google import genai + from google.adk.sessions.session import Session + from google.cloud.firestore_v1.async_client import AsyncClient + +logger = logging.getLogger("google_adk." + __name__) + +_COMPACTION_LOCK_TTL = 300 # seconds + + +@async_transactional +async def _try_claim_compaction_txn(transaction: Any, session_ref: Any) -> bool: + """Atomically claim the compaction lock if it is free or stale.""" + snapshot = await session_ref.get(transaction=transaction) + if not snapshot.exists: + return False + data = snapshot.to_dict() or {} + lock_time = data.get("compaction_lock") + if lock_time and (time.time() - lock_time) < _COMPACTION_LOCK_TTL: + return False + transaction.update(session_ref, {"compaction_lock": time.time()}) + return True + + +class SessionCompactor: + """Handles conversation history compaction for Firestore sessions. + + This class manages the automatic summarization and archival of older + conversation events to keep token counts manageable while preserving + context through AI-generated summaries. + """ + + def __init__( + self, + *, + db: AsyncClient, + genai_client: genai.Client | None = None, + compaction_model: str = "gemini-2.5-flash", + compaction_keep_recent: int = 10, + ) -> None: + """Initialize SessionCompactor. + + Args: + db: Firestore async client + genai_client: GenAI client for generating summaries + compaction_model: Model to use for summarization + compaction_keep_recent: Number of recent events to keep uncompacted + + """ + self._db = db + self._genai_client = genai_client + self._compaction_model = compaction_model + self._compaction_keep_recent = compaction_keep_recent + self._compaction_locks: dict[str, asyncio.Lock] = {} + + @staticmethod + def _events_to_text(events: list[Event]) -> str: + """Convert a list of events to a readable conversation text format.""" + lines: list[str] = [] + for event in events: + if event.content and event.content.parts: + text = "".join(p.text or "" for p in event.content.parts) + if text: + role = "User" if event.author == "user" else "Assistant" + lines.append(f"{role}: {text}") + return "\n\n".join(lines) + + async def _generate_summary( + self, existing_summary: str, events: list[Event] + ) -> str: + """Generate or update a conversation summary using the GenAI model.""" + conversation_text = self._events_to_text(events) + previous = ( + f"Previous summary of earlier conversation:\n{existing_summary}\n\n" + if existing_summary + else "" + ) + prompt = ( + "Summarize the following conversation between a user and an " + "assistant. Preserve:\n" + "- Key decisions and conclusions\n" + "- User preferences and requirements\n" + "- Important facts, names, and numbers\n" + "- The overall topic and direction of the conversation\n" + "- Any pending tasks or open questions\n\n" + f"{previous}" + f"Conversation:\n{conversation_text}\n\n" + "Provide a clear, comprehensive summary." + ) + if self._genai_client is None: + msg = "genai_client is required for compaction" + raise RuntimeError(msg) + response = await self._genai_client.aio.models.generate_content( + model=self._compaction_model, + contents=prompt, + ) + return response.text or "" + + async def _compact_session( + self, + session: Session, + events_col_ref: Any, + session_ref: Any, + ) -> None: + """Perform the actual compaction: summarize old events and delete them. + + Args: + session: The session to compact + events_col_ref: Firestore collection reference for events + session_ref: Firestore document reference for the session + + """ + query = events_col_ref.order_by("timestamp") + event_docs = await query.get() + + if len(event_docs) <= self._compaction_keep_recent: + return + + all_events = [Event.model_validate(doc.to_dict()) for doc in event_docs] + events_to_summarize = all_events[: -self._compaction_keep_recent] + + session_snap = await session_ref.get() + existing_summary = (session_snap.to_dict() or {}).get( + "conversation_summary", "" + ) + + try: + summary = await self._generate_summary( + existing_summary, events_to_summarize + ) + except Exception: + logger.exception("Compaction summary generation failed; skipping.") + return + + # Write summary BEFORE deleting events so a crash between the two + # steps leaves safe duplication rather than data loss. + await session_ref.update({"conversation_summary": summary}) + + docs_to_delete = event_docs[: -self._compaction_keep_recent] + for i in range(0, len(docs_to_delete), 500): + batch = self._db.batch() + for doc in docs_to_delete[i : i + 500]: + batch.delete(doc.reference) + await batch.commit() + + logger.info( + "Compacted session %s: summarised %d events, kept %d.", + session.id, + len(docs_to_delete), + self._compaction_keep_recent, + ) + + async def guarded_compact( + self, + session: Session, + events_col_ref: Any, + session_ref: Any, + ) -> None: + """Run compaction in the background with per-session locking. + + This method ensures that only one compaction process runs at a time + for a given session, both locally (using asyncio locks) and across + multiple instances (using Firestore-backed locks). + + Args: + session: The session to compact + events_col_ref: Firestore collection reference for events + session_ref: Firestore document reference for the session + + """ + key = f"{session.app_name}__{session.user_id}__{session.id}" + lock = self._compaction_locks.setdefault(key, asyncio.Lock()) + + if lock.locked(): + logger.debug("Compaction already running locally for %s; skipping.", key) + return + + async with lock: + try: + transaction = self._db.transaction() + claimed = await _try_claim_compaction_txn(transaction, session_ref) + except Exception: + logger.exception("Failed to claim compaction lock for %s", key) + return + + if not claimed: + logger.debug( + "Compaction lock held by another instance for %s; skipping.", + key, + ) + return + + try: + await self._compact_session(session, events_col_ref, session_ref) + except Exception: + logger.exception("Background compaction failed for %s", key) + finally: + try: + await session_ref.update({"compaction_lock": None}) + except Exception: + logger.exception("Failed to release compaction lock for %s", key) diff --git a/src/va_agent/session.py b/src/va_agent/session.py index bb7873a..706d7de 100644 --- a/src/va_agent/session.py +++ b/src/va_agent/session.py @@ -18,33 +18,18 @@ from google.adk.sessions.base_session_service import ( ) from google.adk.sessions.session import Session from google.adk.sessions.state import State -from google.cloud.firestore_v1.async_transaction import async_transactional from google.cloud.firestore_v1.base_query import FieldFilter from google.cloud.firestore_v1.field_path import FieldPath from google.genai.types import Content, Part +from .compaction import SessionCompactor + if TYPE_CHECKING: from google import genai from google.cloud.firestore_v1.async_client import AsyncClient logger = logging.getLogger("google_adk." + __name__) -_COMPACTION_LOCK_TTL = 300 # seconds - - -@async_transactional -async def _try_claim_compaction_txn(transaction: Any, session_ref: Any) -> bool: - """Atomically claim the compaction lock if it is free or stale.""" - snapshot = await session_ref.get(transaction=transaction) - if not snapshot.exists: - return False - data = snapshot.to_dict() or {} - lock_time = data.get("compaction_lock") - if lock_time and (time.time() - lock_time) < _COMPACTION_LOCK_TTL: - return False - transaction.update(session_ref, {"compaction_lock": time.time()}) - return True - class FirestoreSessionService(BaseSessionService): """A Firestore-backed implementation of BaseSessionService. @@ -89,10 +74,12 @@ class FirestoreSessionService(BaseSessionService): self._db = db self._prefix = collection_prefix self._compaction_threshold = compaction_token_threshold - self._compaction_model = compaction_model - self._compaction_keep_recent = compaction_keep_recent - self._genai_client = genai_client - self._compaction_locks: dict[str, asyncio.Lock] = {} + self._compactor = SessionCompactor( + db=db, + genai_client=genai_client, + compaction_model=compaction_model, + compaction_keep_recent=compaction_keep_recent, + ) self._active_tasks: set[asyncio.Task] = set() # ------------------------------------------------------------------ @@ -140,136 +127,6 @@ class FirestoreSessionService(BaseSessionService): merged[State.USER_PREFIX + key] = value return merged - # ------------------------------------------------------------------ - # Compaction helpers - # ------------------------------------------------------------------ - - @staticmethod - def _events_to_text(events: list[Event]) -> str: - lines: list[str] = [] - for event in events: - if event.content and event.content.parts: - text = "".join(p.text or "" for p in event.content.parts) - if text: - role = "User" if event.author == "user" else "Assistant" - lines.append(f"{role}: {text}") - return "\n\n".join(lines) - - async def _generate_summary( - self, existing_summary: str, events: list[Event] - ) -> str: - conversation_text = self._events_to_text(events) - previous = ( - f"Previous summary of earlier conversation:\n{existing_summary}\n\n" - if existing_summary - else "" - ) - prompt = ( - "Summarize the following conversation between a user and an " - "assistant. Preserve:\n" - "- Key decisions and conclusions\n" - "- User preferences and requirements\n" - "- Important facts, names, and numbers\n" - "- The overall topic and direction of the conversation\n" - "- Any pending tasks or open questions\n\n" - f"{previous}" - f"Conversation:\n{conversation_text}\n\n" - "Provide a clear, comprehensive summary." - ) - if self._genai_client is None: - msg = "genai_client is required for compaction" - raise RuntimeError(msg) - response = await self._genai_client.aio.models.generate_content( - model=self._compaction_model, - contents=prompt, - ) - return response.text or "" - - async def _compact_session(self, session: Session) -> None: - app_name = session.app_name - user_id = session.user_id - session_id = session.id - - events_ref = self._events_col(app_name, user_id, session_id) - query = events_ref.order_by("timestamp") - event_docs = await query.get() - - if len(event_docs) <= self._compaction_keep_recent: - return - - all_events = [Event.model_validate(doc.to_dict()) for doc in event_docs] - events_to_summarize = all_events[: -self._compaction_keep_recent] - - session_snap = await self._session_ref(app_name, user_id, session_id).get() - existing_summary = (session_snap.to_dict() or {}).get( - "conversation_summary", "" - ) - - try: - summary = await self._generate_summary( - existing_summary, events_to_summarize - ) - except Exception: - logger.exception("Compaction summary generation failed; skipping.") - return - - # Write summary BEFORE deleting events so a crash between the two - # steps leaves safe duplication rather than data loss. - await self._session_ref(app_name, user_id, session_id).update( - {"conversation_summary": summary} - ) - - docs_to_delete = event_docs[: -self._compaction_keep_recent] - for i in range(0, len(docs_to_delete), 500): - batch = self._db.batch() - for doc in docs_to_delete[i : i + 500]: - batch.delete(doc.reference) - await batch.commit() - - logger.info( - "Compacted session %s: summarised %d events, kept %d.", - session_id, - len(docs_to_delete), - self._compaction_keep_recent, - ) - - async def _guarded_compact(self, session: Session) -> None: - """Run compaction in the background with per-session locking.""" - key = f"{session.app_name}__{session.user_id}__{session.id}" - lock = self._compaction_locks.setdefault(key, asyncio.Lock()) - - if lock.locked(): - logger.debug("Compaction already running locally for %s; skipping.", key) - return - - async with lock: - session_ref = self._session_ref( - session.app_name, session.user_id, session.id - ) - try: - transaction = self._db.transaction() - claimed = await _try_claim_compaction_txn(transaction, session_ref) - except Exception: - logger.exception("Failed to claim compaction lock for %s", key) - return - - if not claimed: - logger.debug( - "Compaction lock held by another instance for %s; skipping.", - key, - ) - return - - try: - await self._compact_session(session) - except Exception: - logger.exception("Background compaction failed for %s", key) - finally: - try: - await session_ref.update({"compaction_lock": None}) - except Exception: - logger.exception("Failed to release compaction lock for %s", key) - async def close(self) -> None: """Await all in-flight compaction tasks. Call before shutdown.""" if self._active_tasks: @@ -567,7 +424,11 @@ class FirestoreSessionService(BaseSessionService): event.usage_metadata.total_token_count, self._compaction_threshold, ) - task = asyncio.create_task(self._guarded_compact(session)) + events_ref = self._events_col(app_name, user_id, session_id) + session_ref = self._session_ref(app_name, user_id, session_id) + task = asyncio.create_task( + self._compactor.guarded_compact(session, events_ref, session_ref) + ) self._active_tasks.add(task) task.add_done_callback(self._active_tasks.discard) diff --git a/tests/test_compaction.py b/tests/test_compaction.py index 3e7b232..eca75a2 100644 --- a/tests/test_compaction.py +++ b/tests/test_compaction.py @@ -14,7 +14,8 @@ from google.adk.events.event import Event from google.cloud.firestore_v1.async_client import AsyncClient from google.genai.types import Content, GenerateContentResponseUsageMetadata, Part -from va_agent.session import FirestoreSessionService, _try_claim_compaction_txn +from va_agent.session import FirestoreSessionService +from va_agent.compaction import SessionCompactor, _try_claim_compaction_txn pytestmark = pytest.mark.asyncio @@ -178,7 +179,9 @@ class TestCompactionEdgeCases: await compaction_service.append_event(session, e) # Trigger compaction manually even though threshold wouldn't fire - await compaction_service._compact_session(session) + events_ref = compaction_service._events_col(app_name, user_id, session.id) + session_ref = compaction_service._session_ref(app_name, user_id, session.id) + await compaction_service._compactor._compact_session(session, events_ref, session_ref) mock_genai_client.aio.models.generate_content.assert_not_called() @@ -205,7 +208,9 @@ class TestCompactionEdgeCases: ) # Should not raise - await compaction_service._compact_session(session) + events_ref = compaction_service._events_col(app_name, user_id, session.id) + session_ref = compaction_service._session_ref(app_name, user_id, session.id) + await compaction_service._compactor._compact_session(session, events_ref, session_ref) # All events should still be present fetched = await compaction_service.get_session( @@ -268,7 +273,7 @@ class TestEventsToText: invocation_id="inv-2", ), ] - text = FirestoreSessionService._events_to_text(events) + text = SessionCompactor._events_to_text(events) assert "User: Hi there" in text assert "Assistant: Hello!" in text @@ -280,7 +285,7 @@ class TestEventsToText: invocation_id="inv-1", ), ] - text = FirestoreSessionService._events_to_text(events) + text = SessionCompactor._events_to_text(events) assert text == "" @@ -368,11 +373,15 @@ class TestGuardedCompact: # Hold the in-process lock so _guarded_compact skips key = f"{app_name}__{user_id}__{session.id}" - lock = compaction_service._compaction_locks.setdefault( + lock = compaction_service._compactor._compaction_locks.setdefault( key, asyncio.Lock() ) + events_ref = compaction_service._events_col(app_name, user_id, session.id) + session_ref = compaction_service._session_ref(app_name, user_id, session.id) async with lock: - await compaction_service._guarded_compact(session) + await compaction_service._compactor.guarded_compact( + session, events_ref, session_ref + ) mock_genai_client.aio.models.generate_content.assert_not_called() @@ -399,7 +408,10 @@ class TestGuardedCompact: ) await session_ref.update({"compaction_lock": time.time()}) - await compaction_service._guarded_compact(session) + events_ref = compaction_service._events_col(app_name, user_id, session.id) + await compaction_service._compactor.guarded_compact( + session, events_ref, session_ref + ) mock_genai_client.aio.models.generate_content.assert_not_called() @@ -411,10 +423,18 @@ class TestGuardedCompact: ) with patch( - "va_agent.session._try_claim_compaction_txn", + "va_agent.compaction._try_claim_compaction_txn", side_effect=RuntimeError("Firestore down"), ): - await compaction_service._guarded_compact(session) + events_ref = compaction_service._events_col( + app_name, user_id, session.id + ) + session_ref = compaction_service._session_ref( + app_name, user_id, session.id + ) + await compaction_service._compactor.guarded_compact( + session, events_ref, session_ref + ) mock_genai_client.aio.models.generate_content.assert_not_called() @@ -427,11 +447,19 @@ class TestGuardedCompact: # Make _compact_session raise an unhandled exception with patch.object( - compaction_service, + compaction_service._compactor, "_compact_session", side_effect=RuntimeError("unexpected crash"), ): - await compaction_service._guarded_compact(session) + events_ref = compaction_service._events_col( + app_name, user_id, session.id + ) + session_ref = compaction_service._session_ref( + app_name, user_id, session.id + ) + await compaction_service._compactor.guarded_compact( + session, events_ref, session_ref + ) # Lock should be released even after failure session_ref = compaction_service._session_ref( @@ -467,7 +495,11 @@ class TestGuardedCompact: side_effect=patched_session_ref, ): # Should not raise despite lock release failure - await compaction_service._guarded_compact(session) + events_ref = compaction_service._events_col(app_name, user_id, session.id) + session_ref = compaction_service._session_ref(app_name, user_id, session.id) + await compaction_service._compactor.guarded_compact( + session, events_ref, session_ref + ) # ------------------------------------------------------------------