Add compaction flow

This commit is contained in:
ajac-zero
2026-02-21 21:32:06 -06:00
parent 89b4d7ce73
commit 3cb78afc3a
6 changed files with 507 additions and 6 deletions

View File

@@ -7,6 +7,7 @@ import time
from typing import Any, Optional
import uuid
from google import genai
from google.adk.errors.already_exists_error import AlreadyExistsError
from google.adk.events.event import Event
from google.adk.sessions import _session_util
@@ -19,6 +20,7 @@ from google.adk.sessions.session import Session
from google.adk.sessions.state import State
from google.cloud.firestore_v1.async_client import AsyncClient
from google.cloud.firestore_v1.field_path import FieldPath
from google.genai.types import Content, Part
from typing_extensions import override
logger = logging.getLogger("google_adk." + __name__)
@@ -45,9 +47,21 @@ class FirestoreSessionService(BaseSessionService):
*,
db: AsyncClient,
collection_prefix: str = "adk",
# Total-token count at which compaction is triggered after an append.
# None (the default) disables compaction entirely.
compaction_token_threshold: int | None = None,
# Gemini model used to produce the rolling conversation summary.
compaction_model: str = "gemini-2.5-flash",
# Number of most-recent events left untouched by each compaction pass.
compaction_keep_recent: int = 10,
# Required whenever compaction_token_threshold is set; used only for
# summary generation.
genai_client: genai.Client | None = None,
) -> None:
"""Initialise the service; validates that compaction has a genai client."""
# Fail fast at construction time rather than on the first compaction.
if compaction_token_threshold is not None and genai_client is None:
raise ValueError(
"genai_client is required when compaction_token_threshold is set."
)
self._db = db
self._prefix = collection_prefix
self._compaction_threshold = compaction_token_threshold
self._compaction_model = compaction_model
self._compaction_keep_recent = compaction_keep_recent
self._genai_client = genai_client
# ------------------------------------------------------------------
# Document-reference helpers
@@ -100,6 +114,100 @@ class FirestoreSessionService(BaseSessionService):
merged[State.USER_PREFIX + key] = value
return merged
# ------------------------------------------------------------------
# Compaction helpers
# ------------------------------------------------------------------
@staticmethod
def _events_to_text(events: list[Event]) -> str:
lines: list[str] = []
for event in events:
if event.content and event.content.parts:
text = "".join(p.text or "" for p in event.content.parts)
if text:
role = "User" if event.author == "user" else "Assistant"
lines.append(f"{role}: {text}")
return "\n\n".join(lines)
async def _generate_summary(
    self, existing_summary: str, events: list[Event]
) -> str:
    """Ask the configured Gemini model to summarise *events*.

    Any previous rolling summary is folded into the prompt so the new
    summary covers the whole conversation so far.

    Args:
      existing_summary: Summary from an earlier compaction pass, or ""
        when the session has never been compacted.
      events: Events to summarise, oldest first.

    Returns:
      The model's summary text, or "" if the response carried no text.

    Raises:
      RuntimeError: If no genai client was configured.
    """
    if self._genai_client is None:
        # __init__ enforces this whenever compaction is enabled. Raise
        # explicitly rather than ``assert``: asserts are stripped under
        # ``python -O`` and would surface as an AttributeError below.
        raise RuntimeError(
            "genai_client is required to generate a compaction summary."
        )
    conversation_text = self._events_to_text(events)
    previous = (
        "Previous summary of earlier conversation:\n"
        f"{existing_summary}\n\n"
        if existing_summary
        else ""
    )
    prompt = (
        "Summarize the following conversation between a user and an "
        "assistant. Preserve:\n"
        "- Key decisions and conclusions\n"
        "- User preferences and requirements\n"
        "- Important facts, names, and numbers\n"
        "- The overall topic and direction of the conversation\n"
        "- Any pending tasks or open questions\n\n"
        f"{previous}"
        f"Conversation:\n{conversation_text}\n\n"
        "Provide a clear, comprehensive summary."
    )
    response = await self._genai_client.aio.models.generate_content(
        model=self._compaction_model,
        contents=prompt,
    )
    return response.text or ""
async def _compact_session(self, session: Session) -> None:
    """Summarise the oldest events of *session* and delete them.

    All but the most recent ``_compaction_keep_recent`` events are fed
    to ``_generate_summary`` and then removed in batches; the resulting
    text is stored on the session document as ``conversation_summary``.
    Summary-generation failures are logged and swallowed so compaction
    stays best-effort and never breaks event appending.
    """
    app_name = session.app_name
    user_id = session.user_id
    session_id = session.id

    events_ref = self._events_col(app_name, user_id, session_id)
    event_docs = await events_ref.order_by("timestamp").get()

    keep = self._compaction_keep_recent
    if len(event_docs) <= keep:
        # Not enough history to be worth an LLM call.
        return

    # keep <= 0 means "compact everything". Guard against the slice
    # [:-0], which would select nothing and make this method call the
    # model on an empty transcript while deleting no events.
    docs_to_summarize = event_docs[:-keep] if keep > 0 else event_docs
    # Only the events being pruned need to be parsed; the kept tail is
    # left untouched in Firestore.
    events_to_summarize = [
        Event.model_validate(doc.to_dict()) for doc in docs_to_summarize
    ]

    session_snap = await self._session_ref(
        app_name, user_id, session_id
    ).get()
    existing_summary = (session_snap.to_dict() or {}).get(
        "conversation_summary", ""
    )

    try:
        summary = await self._generate_summary(
            existing_summary, events_to_summarize
        )
    except Exception:
        logger.exception("Compaction summary generation failed; skipping.")
        return

    # Firestore write batches are capped at 500 operations.
    for start in range(0, len(docs_to_summarize), 500):
        batch = self._db.batch()
        for doc in docs_to_summarize[start : start + 500]:
            batch.delete(doc.reference)
        await batch.commit()

    await self._session_ref(app_name, user_id, session_id).update(
        {"conversation_summary": summary}
    )
    logger.info(
        "Compacted session %s: summarised %d events, kept %d.",
        session_id,
        len(docs_to_summarize),
        self._compaction_keep_recent,
    )
# ------------------------------------------------------------------
# BaseSessionService implementation
# ------------------------------------------------------------------
@@ -190,6 +298,47 @@ class FirestoreSessionService(BaseSessionService):
# Honor the caller's cap on how many recent events to return.
if config and config.num_recent_events:
events = events[-config.num_recent_events :]
# Prepend conversation summary as synthetic context events: one
# user-authored event carrying the stored summary and one model-authored
# acknowledgement, so the transcript keeps user/model alternation.
# Timestamps 0.0 / 0.001 sort them before any real event, and the shared
# invocation_id "compaction-summary" marks them as synthetic.
conversation_summary = session_data.get("conversation_summary")
if conversation_summary:
summary_event = Event(
id="summary-context",
author="user",
content=Content(
role="user",
parts=[
Part(
text=(
"[Conversation context from previous"
" messages]\n"
f"{conversation_summary}"
)
)
],
),
timestamp=0.0,
invocation_id="compaction-summary",
)
# NOTE(review): the ack is authored as app_name — presumably the
# agent's author name elsewhere in this service; confirm against
# how real model events are authored.
ack_event = Event(
id="summary-ack",
author=app_name,
content=Content(
role="model",
parts=[
Part(
text=(
"Understood, I have the context from our"
" previous conversation and will continue"
" accordingly."
)
)
],
),
timestamp=0.001,
invocation_id="compaction-summary",
)
events = [summary_event, ack_event] + events
# Merge scoped state
app_state = await self._get_app_state(app_name)
user_state = await self._get_user_state(app_name, user_id)
@@ -313,4 +462,14 @@ class FirestoreSessionService(BaseSessionService):
else:
await session_ref.update({"last_update_time": event.timestamp})
# Trigger compaction if total token count exceeds threshold.
# The truthiness checks skip events with no usage metadata and events
# reporting a zero token count, so only model responses that actually
# carry usage can trigger compaction.
# NOTE(review): compaction runs inline on the append path, adding an
# LLM round-trip to this call's latency — confirm that is acceptable.
if (
self._compaction_threshold is not None
and event.usage_metadata
and event.usage_metadata.total_token_count
and event.usage_metadata.total_token_count
>= self._compaction_threshold
):
await self._compact_session(session)
return event