Add compaction flow
This commit is contained in:
@@ -7,6 +7,7 @@ import time
|
||||
from typing import Any, Optional
|
||||
import uuid
|
||||
|
||||
from google import genai
|
||||
from google.adk.errors.already_exists_error import AlreadyExistsError
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.sessions import _session_util
|
||||
@@ -19,6 +20,7 @@ from google.adk.sessions.session import Session
|
||||
from google.adk.sessions.state import State
|
||||
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||
from google.cloud.firestore_v1.field_path import FieldPath
|
||||
from google.genai.types import Content, Part
|
||||
from typing_extensions import override
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
@@ -45,9 +47,21 @@ class FirestoreSessionService(BaseSessionService):
|
||||
*,
|
||||
db: AsyncClient,
|
||||
collection_prefix: str = "adk",
|
||||
compaction_token_threshold: int | None = None,
|
||||
compaction_model: str = "gemini-2.5-flash",
|
||||
compaction_keep_recent: int = 10,
|
||||
genai_client: genai.Client | None = None,
|
||||
) -> None:
|
||||
if compaction_token_threshold is not None and genai_client is None:
|
||||
raise ValueError(
|
||||
"genai_client is required when compaction_token_threshold is set."
|
||||
)
|
||||
self._db = db
|
||||
self._prefix = collection_prefix
|
||||
self._compaction_threshold = compaction_token_threshold
|
||||
self._compaction_model = compaction_model
|
||||
self._compaction_keep_recent = compaction_keep_recent
|
||||
self._genai_client = genai_client
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Document-reference helpers
|
||||
@@ -100,6 +114,100 @@ class FirestoreSessionService(BaseSessionService):
|
||||
merged[State.USER_PREFIX + key] = value
|
||||
return merged
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Compaction helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _events_to_text(events: list[Event]) -> str:
|
||||
lines: list[str] = []
|
||||
for event in events:
|
||||
if event.content and event.content.parts:
|
||||
text = "".join(p.text or "" for p in event.content.parts)
|
||||
if text:
|
||||
role = "User" if event.author == "user" else "Assistant"
|
||||
lines.append(f"{role}: {text}")
|
||||
return "\n\n".join(lines)
|
||||
|
||||
async def _generate_summary(
|
||||
self, existing_summary: str, events: list[Event]
|
||||
) -> str:
|
||||
conversation_text = self._events_to_text(events)
|
||||
previous = (
|
||||
"Previous summary of earlier conversation:\n"
|
||||
f"{existing_summary}\n\n"
|
||||
if existing_summary
|
||||
else ""
|
||||
)
|
||||
prompt = (
|
||||
"Summarize the following conversation between a user and an "
|
||||
"assistant. Preserve:\n"
|
||||
"- Key decisions and conclusions\n"
|
||||
"- User preferences and requirements\n"
|
||||
"- Important facts, names, and numbers\n"
|
||||
"- The overall topic and direction of the conversation\n"
|
||||
"- Any pending tasks or open questions\n\n"
|
||||
f"{previous}"
|
||||
f"Conversation:\n{conversation_text}\n\n"
|
||||
"Provide a clear, comprehensive summary."
|
||||
)
|
||||
assert self._genai_client is not None
|
||||
response = await self._genai_client.aio.models.generate_content(
|
||||
model=self._compaction_model,
|
||||
contents=prompt,
|
||||
)
|
||||
return response.text or ""
|
||||
|
||||
async def _compact_session(self, session: Session) -> None:
|
||||
app_name = session.app_name
|
||||
user_id = session.user_id
|
||||
session_id = session.id
|
||||
|
||||
events_ref = self._events_col(app_name, user_id, session_id)
|
||||
query = events_ref.order_by("timestamp")
|
||||
event_docs = await query.get()
|
||||
|
||||
if len(event_docs) <= self._compaction_keep_recent:
|
||||
return
|
||||
|
||||
all_events = [
|
||||
Event.model_validate(doc.to_dict()) for doc in event_docs
|
||||
]
|
||||
events_to_summarize = all_events[: -self._compaction_keep_recent]
|
||||
|
||||
session_snap = await self._session_ref(
|
||||
app_name, user_id, session_id
|
||||
).get()
|
||||
existing_summary = (session_snap.to_dict() or {}).get(
|
||||
"conversation_summary", ""
|
||||
)
|
||||
|
||||
try:
|
||||
summary = await self._generate_summary(
|
||||
existing_summary, events_to_summarize
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Compaction summary generation failed; skipping.")
|
||||
return
|
||||
|
||||
docs_to_delete = event_docs[: -self._compaction_keep_recent]
|
||||
for i in range(0, len(docs_to_delete), 500):
|
||||
batch = self._db.batch()
|
||||
for doc in docs_to_delete[i : i + 500]:
|
||||
batch.delete(doc.reference)
|
||||
await batch.commit()
|
||||
|
||||
await self._session_ref(app_name, user_id, session_id).update(
|
||||
{"conversation_summary": summary}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Compacted session %s: summarised %d events, kept %d.",
|
||||
session_id,
|
||||
len(docs_to_delete),
|
||||
self._compaction_keep_recent,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# BaseSessionService implementation
|
||||
# ------------------------------------------------------------------
|
||||
@@ -190,6 +298,47 @@ class FirestoreSessionService(BaseSessionService):
|
||||
if config and config.num_recent_events:
|
||||
events = events[-config.num_recent_events :]
|
||||
|
||||
# Prepend conversation summary as synthetic context events
|
||||
conversation_summary = session_data.get("conversation_summary")
|
||||
if conversation_summary:
|
||||
summary_event = Event(
|
||||
id="summary-context",
|
||||
author="user",
|
||||
content=Content(
|
||||
role="user",
|
||||
parts=[
|
||||
Part(
|
||||
text=(
|
||||
"[Conversation context from previous"
|
||||
" messages]\n"
|
||||
f"{conversation_summary}"
|
||||
)
|
||||
)
|
||||
],
|
||||
),
|
||||
timestamp=0.0,
|
||||
invocation_id="compaction-summary",
|
||||
)
|
||||
ack_event = Event(
|
||||
id="summary-ack",
|
||||
author=app_name,
|
||||
content=Content(
|
||||
role="model",
|
||||
parts=[
|
||||
Part(
|
||||
text=(
|
||||
"Understood, I have the context from our"
|
||||
" previous conversation and will continue"
|
||||
" accordingly."
|
||||
)
|
||||
)
|
||||
],
|
||||
),
|
||||
timestamp=0.001,
|
||||
invocation_id="compaction-summary",
|
||||
)
|
||||
events = [summary_event, ack_event] + events
|
||||
|
||||
# Merge scoped state
|
||||
app_state = await self._get_app_state(app_name)
|
||||
user_state = await self._get_user_state(app_name, user_id)
|
||||
@@ -313,4 +462,14 @@ class FirestoreSessionService(BaseSessionService):
|
||||
else:
|
||||
await session_ref.update({"last_update_time": event.timestamp})
|
||||
|
||||
# Trigger compaction if total token count exceeds threshold
|
||||
if (
|
||||
self._compaction_threshold is not None
|
||||
and event.usage_metadata
|
||||
and event.usage_metadata.total_token_count
|
||||
and event.usage_metadata.total_token_count
|
||||
>= self._compaction_threshold
|
||||
):
|
||||
await self._compact_session(session)
|
||||
|
||||
return event
|
||||
|
||||
Reference in New Issue
Block a user