Compare commits

..

7 Commits

Author SHA1 Message Date
ajac-zero
547296fb2d Add compaction flow lock 2026-02-21 23:23:26 -06:00
ajac-zero
5f2f0474a5 Pool async calls 2026-02-21 22:45:48 -06:00
ajac-zero
dff25bcff0 Improve scripts 2026-02-21 22:35:17 -06:00
ajac-zero
ffcb2f4b90 Add logs 2026-02-21 22:26:16 -06:00
ajac-zero
52aa8bfe0d Use filters instead of string formatting 2026-02-21 22:09:54 -06:00
ajac-zero
3cb78afc3a Add compaction flow 2026-02-21 21:46:01 -06:00
ajac-zero
89b4d7ce73 Add gitignore 2026-02-21 20:53:48 -06:00
9 changed files with 1196 additions and 51 deletions

BIN
.coverage

Binary file not shown.

216
.gitignore vendored
View File

@@ -8,3 +8,219 @@ wheels/
# Virtual environments
.venv
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
# Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
# poetry.lock
# poetry.toml
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
# pdm.lock
# pdm.toml
.pdm-python
.pdm-build/
# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
# pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# Redis
*.rdb
*.aof
*.pid
# RabbitMQ
mnesia/
rabbitmq/
rabbitmq-data/
# ActiveMQ
activemq-data/
# SageMath parsed files
*.sage.py
# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
# .idea/
# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/
# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/
# Streamlit
.streamlit/secrets.toml

59
chat.py
View File

@@ -1,18 +1,27 @@
#!/usr/bin/env python3
"""Minimal CLI chat agent to test FirestoreSessionService."""
import argparse
import asyncio import asyncio
import logging
from google import genai
from google.adk.agents import LlmAgent from google.adk.agents import LlmAgent
from google.adk.runners import Runner from google.adk.runners import Runner
from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.async_client import AsyncClient
from google.genai import types from google.genai import types
from rich.console import Console
from rich.logging import RichHandler
from rich.markdown import Markdown
from rich.panel import Panel
from adk_firestore_sessionmanager import FirestoreSessionService from adk_firestore_sessionmanager import FirestoreSessionService
APP_NAME = "test_agent"
USER_ID = "dev_user"
console = Console()
root_agent = LlmAgent( root_agent = LlmAgent(
name=APP_NAME, name=APP_NAME,
model="gemini-2.5-flash", model="gemini-2.5-flash",
@@ -20,9 +29,31 @@ root_agent = LlmAgent(
) )
async def main() -> None: def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Chat with a Firestore-backed ADK agent.")
parser.add_argument(
"--log-level",
default="WARNING",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Set the log level (default: WARNING)",
)
return parser.parse_args()
async def main(args: argparse.Namespace) -> None:
logging.basicConfig(
level=args.log_level,
format="%(message)s",
datefmt="[%X]",
handlers=[RichHandler(console=console, rich_tracebacks=True)],
)
db = AsyncClient() db = AsyncClient()
session_service = FirestoreSessionService(db=db) session_service = FirestoreSessionService(
db=db,
compaction_token_threshold=10_000,
genai_client=genai.Client(),
)
runner = Runner( runner = Runner(
app_name=APP_NAME, app_name=APP_NAME,
@@ -30,15 +61,29 @@ async def main() -> None:
session_service=session_service, session_service=session_service,
) )
# Reuse existing session or create a new one
resp = await session_service.list_sessions(
app_name=APP_NAME, user_id=USER_ID
)
if resp.sessions:
session = await session_service.get_session(
app_name=APP_NAME,
user_id=USER_ID,
session_id=resp.sessions[0].id,
)
console.print(f"Resuming session [bold cyan]{session.id}[/]")
else:
session = await session_service.create_session( session = await session_service.create_session(
app_name=APP_NAME, app_name=APP_NAME,
user_id=USER_ID, user_id=USER_ID,
) )
print(f"Session {session.id} created. Type 'exit' to quit.\n") console.print(f"Session [bold cyan]{session.id}[/] created.")
console.print("Type [bold]exit[/] to quit.\n")
while True: while True:
try: try:
user_input = input("You: ").strip() user_input = console.input("[bold green]You:[/] ").strip()
except (EOFError, KeyboardInterrupt): except (EOFError, KeyboardInterrupt):
break break
if not user_input or user_input.lower() == "exit": if not user_input or user_input.lower() == "exit":
@@ -54,11 +99,11 @@ async def main() -> None:
if event.content and event.content.parts and not event.partial: if event.content and event.content.parts and not event.partial:
text = "".join(p.text or "" for p in event.content.parts) text = "".join(p.text or "" for p in event.content.parts)
if text: if text:
print(f"Agent: {text}") console.print(Panel(Markdown(text), title="Agent", border_style="blue"))
await runner.close() await runner.close()
print("\nGoodbye!") console.print("\n[dim]Goodbye![/]")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main()) asyncio.run(main(parse_args()))

View File

@@ -21,6 +21,7 @@ dev = [
"pytest>=9.0.2", "pytest>=9.0.2",
"pytest-asyncio>=0.24", "pytest-asyncio>=0.24",
"pytest-cov>=7.0.0", "pytest-cov>=7.0.0",
"rich>=14.0.0",
] ]
[tool.pytest.ini_options] [tool.pytest.ini_options]

View File

@@ -2,11 +2,13 @@
from __future__ import annotations
import asyncio
import logging import logging
import time import time
from typing import Any, Optional from typing import Any, Optional
import uuid import uuid
from google import genai
from google.adk.errors.already_exists_error import AlreadyExistsError from google.adk.errors.already_exists_error import AlreadyExistsError
from google.adk.events.event import Event from google.adk.events.event import Event
from google.adk.sessions import _session_util from google.adk.sessions import _session_util
@@ -18,11 +20,30 @@ from google.adk.sessions.base_session_service import (
from google.adk.sessions.session import Session from google.adk.sessions.session import Session
from google.adk.sessions.state import State from google.adk.sessions.state import State
from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.async_client import AsyncClient
from google.cloud.firestore_v1.async_transaction import async_transactional
from google.cloud.firestore_v1.base_query import FieldFilter
from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.field_path import FieldPath
from google.genai.types import Content, Part
from typing_extensions import override from typing_extensions import override
logger = logging.getLogger("google_adk." + __name__)
_COMPACTION_LOCK_TTL = 300 # seconds
@async_transactional
async def _try_claim_compaction_txn(transaction, session_ref):
"""Atomically claim the compaction lock if it is free or stale."""
snapshot = await session_ref.get(transaction=transaction)
if not snapshot.exists:
return False
data = snapshot.to_dict() or {}
lock_time = data.get("compaction_lock")
if lock_time and (time.time() - lock_time) < _COMPACTION_LOCK_TTL:
return False
transaction.update(session_ref, {"compaction_lock": time.time()})
return True
class FirestoreSessionService(BaseSessionService): class FirestoreSessionService(BaseSessionService):
"""A Firestore-backed implementation of BaseSessionService. """A Firestore-backed implementation of BaseSessionService.
@@ -45,9 +66,23 @@ class FirestoreSessionService(BaseSessionService):
*, *,
db: AsyncClient, db: AsyncClient,
collection_prefix: str = "adk", collection_prefix: str = "adk",
compaction_token_threshold: int | None = None,
compaction_model: str = "gemini-2.5-flash",
compaction_keep_recent: int = 10,
genai_client: genai.Client | None = None,
) -> None: ) -> None:
if compaction_token_threshold is not None and genai_client is None:
raise ValueError(
"genai_client is required when compaction_token_threshold is set."
)
self._db = db self._db = db
self._prefix = collection_prefix self._prefix = collection_prefix
self._compaction_threshold = compaction_token_threshold
self._compaction_model = compaction_model
self._compaction_keep_recent = compaction_keep_recent
self._genai_client = genai_client
self._compaction_locks: dict[str, asyncio.Lock] = {}
self._active_tasks: set[asyncio.Task] = set()
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Document-reference helpers # Document-reference helpers
@@ -100,6 +135,153 @@ class FirestoreSessionService(BaseSessionService):
merged[State.USER_PREFIX + key] = value merged[State.USER_PREFIX + key] = value
return merged return merged
# ------------------------------------------------------------------
# Compaction helpers
# ------------------------------------------------------------------
@staticmethod
def _events_to_text(events: list[Event]) -> str:
lines: list[str] = []
for event in events:
if event.content and event.content.parts:
text = "".join(p.text or "" for p in event.content.parts)
if text:
role = "User" if event.author == "user" else "Assistant"
lines.append(f"{role}: {text}")
return "\n\n".join(lines)
async def _generate_summary(
self, existing_summary: str, events: list[Event]
) -> str:
conversation_text = self._events_to_text(events)
previous = (
"Previous summary of earlier conversation:\n"
f"{existing_summary}\n\n"
if existing_summary
else ""
)
prompt = (
"Summarize the following conversation between a user and an "
"assistant. Preserve:\n"
"- Key decisions and conclusions\n"
"- User preferences and requirements\n"
"- Important facts, names, and numbers\n"
"- The overall topic and direction of the conversation\n"
"- Any pending tasks or open questions\n\n"
f"{previous}"
f"Conversation:\n{conversation_text}\n\n"
"Provide a clear, comprehensive summary."
)
assert self._genai_client is not None
response = await self._genai_client.aio.models.generate_content(
model=self._compaction_model,
contents=prompt,
)
return response.text or ""
async def _compact_session(self, session: Session) -> None:
app_name = session.app_name
user_id = session.user_id
session_id = session.id
events_ref = self._events_col(app_name, user_id, session_id)
query = events_ref.order_by("timestamp")
event_docs = await query.get()
if len(event_docs) <= self._compaction_keep_recent:
return
all_events = [
Event.model_validate(doc.to_dict()) for doc in event_docs
]
events_to_summarize = all_events[: -self._compaction_keep_recent]
session_snap = await self._session_ref(
app_name, user_id, session_id
).get()
existing_summary = (session_snap.to_dict() or {}).get(
"conversation_summary", ""
)
try:
summary = await self._generate_summary(
existing_summary, events_to_summarize
)
except Exception:
logger.exception("Compaction summary generation failed; skipping.")
return
# Write summary BEFORE deleting events so a crash between the two
# steps leaves safe duplication rather than data loss.
await self._session_ref(app_name, user_id, session_id).update(
{"conversation_summary": summary}
)
docs_to_delete = event_docs[: -self._compaction_keep_recent]
for i in range(0, len(docs_to_delete), 500):
batch = self._db.batch()
for doc in docs_to_delete[i : i + 500]:
batch.delete(doc.reference)
await batch.commit()
logger.info(
"Compacted session %s: summarised %d events, kept %d.",
session_id,
len(docs_to_delete),
self._compaction_keep_recent,
)
async def _guarded_compact(self, session: Session) -> None:
"""Run compaction in the background with per-session locking."""
key = f"{session.app_name}__{session.user_id}__{session.id}"
lock = self._compaction_locks.setdefault(key, asyncio.Lock())
if lock.locked():
logger.debug(
"Compaction already running locally for %s; skipping.", key
)
return
async with lock:
session_ref = self._session_ref(
session.app_name, session.user_id, session.id
)
try:
transaction = self._db.transaction()
claimed = await _try_claim_compaction_txn(
transaction, session_ref
)
except Exception:
logger.exception(
"Failed to claim compaction lock for %s", key
)
return
if not claimed:
logger.debug(
"Compaction lock held by another instance for %s;"
" skipping.",
key,
)
return
try:
await self._compact_session(session)
except Exception:
logger.exception("Background compaction failed for %s", key)
finally:
try:
await session_ref.update({"compaction_lock": None})
except Exception:
logger.exception(
"Failed to release compaction lock for %s", key
)
async def close(self) -> None:
"""Await all in-flight compaction tasks. Call before shutdown."""
if self._active_tasks:
await asyncio.gather(*self._active_tasks, return_exceptions=True)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# BaseSessionService implementation # BaseSessionService implementation
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@@ -130,17 +312,23 @@ class FirestoreSessionService(BaseSessionService):
user_state_delta = state_deltas["user"] user_state_delta = state_deltas["user"]
session_state = state_deltas["session"] session_state = state_deltas["session"]
write_coros: list = []
if app_state_delta: if app_state_delta:
await self._app_state_ref(app_name).set( write_coros.append(
self._app_state_ref(app_name).set(
app_state_delta, merge=True app_state_delta, merge=True
) )
)
if user_state_delta: if user_state_delta:
await self._user_state_ref(app_name, user_id).set( write_coros.append(
self._user_state_ref(app_name, user_id).set(
user_state_delta, merge=True user_state_delta, merge=True
) )
)
now = time.time() now = time.time()
await self._session_ref(app_name, user_id, session_id).set( write_coros.append(
self._session_ref(app_name, user_id, session_id).set(
{ {
"app_name": app_name, "app_name": app_name,
"user_id": user_id, "user_id": user_id,
@@ -149,9 +337,13 @@ class FirestoreSessionService(BaseSessionService):
"last_update_time": now, "last_update_time": now,
} }
) )
)
await asyncio.gather(*write_coros)
app_state = await self._get_app_state(app_name) app_state, user_state = await asyncio.gather(
user_state = await self._get_user_state(app_name, user_id) self._get_app_state(app_name),
self._get_user_state(app_name, user_id),
)
merged = self._merge_state(app_state, user_state, session_state or {}) merged = self._merge_state(app_state, user_state, session_state or {})
return Session( return Session(
@@ -181,18 +373,61 @@ class FirestoreSessionService(BaseSessionService):
events_ref = self._events_col(app_name, user_id, session_id) events_ref = self._events_col(app_name, user_id, session_id)
query = events_ref query = events_ref
if config and config.after_timestamp: if config and config.after_timestamp:
query = query.where("timestamp", ">=", config.after_timestamp) query = query.where(filter=FieldFilter("timestamp", ">=", config.after_timestamp))
query = query.order_by("timestamp") query = query.order_by("timestamp")
event_docs = await query.get() event_docs, app_state, user_state = await asyncio.gather(
query.get(),
self._get_app_state(app_name),
self._get_user_state(app_name, user_id),
)
events = [Event.model_validate(doc.to_dict()) for doc in event_docs] events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
if config and config.num_recent_events: if config and config.num_recent_events:
events = events[-config.num_recent_events :] events = events[-config.num_recent_events :]
# Prepend conversation summary as synthetic context events
conversation_summary = session_data.get("conversation_summary")
if conversation_summary:
summary_event = Event(
id="summary-context",
author="user",
content=Content(
role="user",
parts=[
Part(
text=(
"[Conversation context from previous"
" messages]\n"
f"{conversation_summary}"
)
)
],
),
timestamp=0.0,
invocation_id="compaction-summary",
)
ack_event = Event(
id="summary-ack",
author=app_name,
content=Content(
role="model",
parts=[
Part(
text=(
"Understood, I have the context from our"
" previous conversation and will continue"
" accordingly."
)
)
],
),
timestamp=0.001,
invocation_id="compaction-summary",
)
events = [summary_event, ack_event] + events
# Merge scoped state # Merge scoped state
app_state = await self._get_app_state(app_name)
user_state = await self._get_user_state(app_name, user_id)
merged = self._merge_state( merged = self._merge_state(
app_state, user_state, session_data.get("state", {}) app_state, user_state, session_data.get("state", {})
) )
@@ -211,31 +446,31 @@ class FirestoreSessionService(BaseSessionService):
self, *, app_name: str, user_id: Optional[str] = None self, *, app_name: str, user_id: Optional[str] = None
) -> ListSessionsResponse: ) -> ListSessionsResponse:
query = self._db.collection(f"{self._prefix}_sessions").where( query = self._db.collection(f"{self._prefix}_sessions").where(
"app_name", "==", app_name filter=FieldFilter("app_name", "==", app_name)
) )
if user_id is not None: if user_id is not None:
query = query.where("user_id", "==", user_id) query = query.where(filter=FieldFilter("user_id", "==", user_id))
docs = await query.get() docs = await query.get()
if not docs: if not docs:
return ListSessionsResponse() return ListSessionsResponse()
# Pre-fetch app state (shared across all sessions in this app) doc_dicts = [doc.to_dict() for doc in docs]
app_state = await self._get_app_state(app_name)
# Cache user states to avoid repeated reads # Pre-fetch app state and all distinct user states in parallel
user_state_cache: dict[str, dict[str, Any]] = {} unique_user_ids = list({d["user_id"] for d in doc_dicts})
sessions: list[Session] = [] app_state, *user_states = await asyncio.gather(
self._get_app_state(app_name),
for doc in docs: *(
data = doc.to_dict() self._get_user_state(app_name, uid)
s_user_id = data["user_id"] for uid in unique_user_ids
),
if s_user_id not in user_state_cache:
user_state_cache[s_user_id] = await self._get_user_state(
app_name, s_user_id
) )
user_state_cache = dict(zip(unique_user_ids, user_states))
sessions: list[Session] = []
for data in doc_dicts:
s_user_id = data["user_id"]
merged = self._merge_state( merged = self._merge_state(
app_state, app_state,
user_state_cache[s_user_id], user_state_cache[s_user_id],
@@ -267,6 +502,8 @@ class FirestoreSessionService(BaseSessionService):
if event.partial: if event.partial:
return event return event
t0 = time.monotonic()
app_name = session.app_name app_name = session.app_name
user_id = session.user_id user_id = session.user_id
session_id = session.id session_id = session.id
@@ -290,14 +527,19 @@ class FirestoreSessionService(BaseSessionService):
event.actions.state_delta event.actions.state_delta
) )
write_coros: list = []
if state_deltas["app"]: if state_deltas["app"]:
await self._app_state_ref(app_name).set( write_coros.append(
self._app_state_ref(app_name).set(
state_deltas["app"], merge=True state_deltas["app"], merge=True
) )
)
if state_deltas["user"]: if state_deltas["user"]:
await self._user_state_ref(app_name, user_id).set( write_coros.append(
self._user_state_ref(app_name, user_id).set(
state_deltas["user"], merge=True state_deltas["user"], merge=True
) )
)
if state_deltas["session"]: if state_deltas["session"]:
field_updates: dict[str, Any] = { field_updates: dict[str, Any] = {
@@ -305,12 +547,54 @@ class FirestoreSessionService(BaseSessionService):
for k, v in state_deltas["session"].items() for k, v in state_deltas["session"].items()
} }
field_updates["last_update_time"] = event.timestamp field_updates["last_update_time"] = event.timestamp
await session_ref.update(field_updates) write_coros.append(session_ref.update(field_updates))
else: else:
await session_ref.update( write_coros.append(
{"last_update_time": event.timestamp} session_ref.update({"last_update_time": event.timestamp})
) )
await asyncio.gather(*write_coros)
else: else:
await session_ref.update({"last_update_time": event.timestamp}) await session_ref.update({"last_update_time": event.timestamp})
# Log token usage
if event.usage_metadata:
meta = event.usage_metadata
logger.info(
"Token usage for session %s event %s: "
"prompt=%s, candidates=%s, total=%s",
session_id,
event.id,
meta.prompt_token_count,
meta.candidates_token_count,
meta.total_token_count,
)
# Trigger compaction if total token count exceeds threshold
if (
self._compaction_threshold is not None
and event.usage_metadata
and event.usage_metadata.total_token_count
and event.usage_metadata.total_token_count
>= self._compaction_threshold
):
logger.info(
"Compaction triggered for session %s: "
"total_token_count=%d >= threshold=%d",
session_id,
event.usage_metadata.total_token_count,
self._compaction_threshold,
)
task = asyncio.create_task(self._guarded_compact(session))
self._active_tasks.add(task)
task.add_done_callback(self._active_tasks.discard)
elapsed = time.monotonic() - t0
logger.info(
"append_event completed for session %s event %s in %.3fs",
session_id,
event.id,
elapsed,
)
return event return event

View File

@@ -11,7 +11,7 @@ from google.cloud.firestore_v1.async_client import AsyncClient
from adk_firestore_sessionmanager import FirestoreSessionService from adk_firestore_sessionmanager import FirestoreSessionService
os.environ.setdefault("FIRESTORE_EMULATOR_HOST", "localhost:8161") os.environ.setdefault("FIRESTORE_EMULATOR_HOST", "localhost:8219")
@pytest_asyncio.fixture @pytest_asyncio.fixture

519
tests/test_compaction.py Normal file
View File

@@ -0,0 +1,519 @@
"""Tests for conversation compaction in FirestoreSessionService."""
from __future__ import annotations
import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch
import uuid
import pytest
import pytest_asyncio
from google import genai
from google.adk.events.event import Event
from google.cloud.firestore_v1.async_client import AsyncClient
from google.genai.types import Content, GenerateContentResponseUsageMetadata, Part
from adk_firestore_sessionmanager import FirestoreSessionService
from adk_firestore_sessionmanager.firestore_session_service import (
_try_claim_compaction_txn,
)
pytestmark = pytest.mark.asyncio
@pytest_asyncio.fixture
async def mock_genai_client():
client = MagicMock(spec=genai.Client)
response = MagicMock()
response.text = "Summary of the conversation so far."
client.aio.models.generate_content = AsyncMock(return_value=response)
return client
@pytest_asyncio.fixture
async def compaction_service(db: AsyncClient, mock_genai_client):
prefix = f"test_{uuid.uuid4().hex[:8]}"
return FirestoreSessionService(
db=db,
collection_prefix=prefix,
compaction_token_threshold=100,
compaction_keep_recent=2,
genai_client=mock_genai_client,
)
# ------------------------------------------------------------------
# __init__ validation
# ------------------------------------------------------------------
class TestCompactionInit:
async def test_requires_genai_client(self, db):
with pytest.raises(ValueError, match="genai_client is required"):
FirestoreSessionService(
db=db,
compaction_token_threshold=1000,
)
async def test_no_threshold_no_client_ok(self, db):
svc = FirestoreSessionService(db=db)
assert svc._compaction_threshold is None
# ------------------------------------------------------------------
# Compaction trigger
# ------------------------------------------------------------------
class TestCompactionTrigger:
async def test_compaction_triggered_above_threshold(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
# Add 5 events, last one with usage_metadata above threshold
base = time.time()
for i in range(4):
e = Event(
author="user" if i % 2 == 0 else app_name,
content=Content(
role="user" if i % 2 == 0 else "model",
parts=[Part(text=f"message {i}")],
),
timestamp=base + i,
invocation_id=f"inv-{i}",
)
await compaction_service.append_event(session, e)
# This event crosses the threshold
trigger_event = Event(
author=app_name,
content=Content(
role="model", parts=[Part(text="final response")]
),
timestamp=base + 4,
invocation_id="inv-4",
usage_metadata=GenerateContentResponseUsageMetadata(
total_token_count=200,
),
)
await compaction_service.append_event(session, trigger_event)
await compaction_service.close()
# Summary generation should have been called
mock_genai_client.aio.models.generate_content.assert_called_once()
# Fetch session: should have summary + only keep_recent events
fetched = await compaction_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
# 2 synthetic summary events + 2 kept real events
assert len(fetched.events) == 4
assert fetched.events[0].id == "summary-context"
assert fetched.events[1].id == "summary-ack"
assert "Summary of the conversation" in fetched.events[0].content.parts[0].text
async def test_no_compaction_below_threshold(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author=app_name,
content=Content(
role="model", parts=[Part(text="short reply")]
),
timestamp=time.time(),
invocation_id="inv-1",
usage_metadata=GenerateContentResponseUsageMetadata(
total_token_count=50,
),
)
await compaction_service.append_event(session, event)
mock_genai_client.aio.models.generate_content.assert_not_called()
async def test_no_compaction_without_usage_metadata(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="user",
content=Content(
role="user", parts=[Part(text="hello")]
),
timestamp=time.time(),
invocation_id="inv-1",
)
await compaction_service.append_event(session, event)
mock_genai_client.aio.models.generate_content.assert_not_called()
# ------------------------------------------------------------------
# Compaction with too few events (nothing to compact)
# ------------------------------------------------------------------
class TestCompactionEdgeCases:
async def test_skip_when_fewer_events_than_keep_recent(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
# Only 2 events, keep_recent=2 → nothing to summarize
for i in range(2):
e = Event(
author="user",
content=Content(
role="user", parts=[Part(text=f"msg {i}")]
),
timestamp=time.time() + i,
invocation_id=f"inv-{i}",
)
await compaction_service.append_event(session, e)
# Trigger compaction manually even though threshold wouldn't fire
await compaction_service._compact_session(session)
mock_genai_client.aio.models.generate_content.assert_not_called()
async def test_summary_generation_failure_is_non_fatal(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
for i in range(5):
e = Event(
author="user",
content=Content(
role="user", parts=[Part(text=f"msg {i}")]
),
timestamp=time.time() + i,
invocation_id=f"inv-{i}",
)
await compaction_service.append_event(session, e)
# Make summary generation fail
mock_genai_client.aio.models.generate_content = AsyncMock(
side_effect=RuntimeError("API error")
)
# Should not raise
await compaction_service._compact_session(session)
# All events should still be present
fetched = await compaction_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert len(fetched.events) == 5
# ------------------------------------------------------------------
# get_session with summary
# ------------------------------------------------------------------
class TestGetSessionWithSummary:
async def test_no_summary_no_synthetic_events(
self, compaction_service, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="user",
content=Content(
role="user", parts=[Part(text="hello")]
),
timestamp=time.time(),
invocation_id="inv-1",
)
await compaction_service.append_event(session, event)
fetched = await compaction_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert len(fetched.events) == 1
assert fetched.events[0].author == "user"
# ------------------------------------------------------------------
# _events_to_text
# ------------------------------------------------------------------
class TestEventsToText:
    """Unit tests for the FirestoreSessionService._events_to_text helper."""

    async def test_formats_user_and_assistant(self):
        """User and model turns are rendered with role-prefixed lines."""
        user_turn = Event(
            author="user",
            content=Content(role="user", parts=[Part(text="Hi there")]),
            timestamp=1.0,
            invocation_id="inv-1",
        )
        model_turn = Event(
            author="bot",
            content=Content(role="model", parts=[Part(text="Hello!")]),
            timestamp=2.0,
            invocation_id="inv-2",
        )
        rendered = FirestoreSessionService._events_to_text(
            [user_turn, model_turn]
        )
        assert "User: Hi there" in rendered
        assert "Assistant: Hello!" in rendered

    async def test_skips_events_without_text(self):
        """An event with no textual content contributes nothing."""
        textless = Event(
            author="user",
            timestamp=1.0,
            invocation_id="inv-1",
        )
        assert FirestoreSessionService._events_to_text([textless]) == ""
# ------------------------------------------------------------------
# Firestore distributed lock
# ------------------------------------------------------------------
class TestCompactionLock:
    """Behaviour of the transactional Firestore compaction lock."""

    @staticmethod
    async def _claim(service, ref):
        # Every claim attempt runs in its own Firestore transaction.
        txn = service._db.transaction()
        return await _try_claim_compaction_txn(txn, ref)

    async def test_claim_and_release(
        self, compaction_service, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        ref = compaction_service._session_ref(app_name, user_id, session.id)

        # First claim takes the lock.
        assert await self._claim(compaction_service, ref) is True
        # While held, a competing claim is rejected.
        assert await self._claim(compaction_service, ref) is False
        # Clearing the marker field releases the lock...
        await ref.update({"compaction_lock": None})
        # ...after which it can be claimed again.
        assert await self._claim(compaction_service, ref) is True

    async def test_stale_lock_can_be_reclaimed(
        self, compaction_service, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        ref = compaction_service._session_ref(app_name, user_id, session.id)
        # A lock timestamp far in the past is treated as expired (stale).
        await ref.update({"compaction_lock": time.time() - 600})
        assert await self._claim(compaction_service, ref) is True

    async def test_claim_nonexistent_session(self, compaction_service):
        ghost = compaction_service._session_ref("no_app", "no_user", "no_id")
        assert await self._claim(compaction_service, ghost) is False
# ------------------------------------------------------------------
# Guarded compact
# ------------------------------------------------------------------
class TestGuardedCompact:
    """Locking and failure-handling behaviour of _guarded_compact."""

    @staticmethod
    async def _seed_messages(service, session, count=5):
        # Append `count` plain user messages so compaction has material.
        for i in range(count):
            await service.append_event(
                session,
                Event(
                    author="user",
                    content=Content(role="user", parts=[Part(text=f"msg {i}")]),
                    timestamp=time.time() + i,
                    invocation_id=f"inv-{i}",
                ),
            )

    async def test_local_lock_skips_concurrent(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        await self._seed_messages(compaction_service, session)

        # Hold the per-session in-process lock; _guarded_compact must bail.
        key = f"{app_name}__{user_id}__{session.id}"
        lock = compaction_service._compaction_locks.setdefault(
            key, asyncio.Lock()
        )
        async with lock:
            await compaction_service._guarded_compact(session)
        mock_genai_client.aio.models.generate_content.assert_not_called()

    async def test_firestore_lock_held_skips(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        await self._seed_messages(compaction_service, session)

        # A fresh lock timestamp simulates another instance mid-compaction.
        ref = compaction_service._session_ref(app_name, user_id, session.id)
        await ref.update({"compaction_lock": time.time()})

        await compaction_service._guarded_compact(session)
        mock_genai_client.aio.models.generate_content.assert_not_called()

    async def test_claim_failure_logs_and_skips(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        # If the claim transaction itself blows up, compaction is skipped.
        target = (
            "adk_firestore_sessionmanager.firestore_session_service"
            "._try_claim_compaction_txn"
        )
        with patch(target, side_effect=RuntimeError("Firestore down")):
            await compaction_service._guarded_compact(session)
        mock_genai_client.aio.models.generate_content.assert_not_called()

    async def test_compaction_failure_releases_lock(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        with patch.object(
            compaction_service,
            "_compact_session",
            side_effect=RuntimeError("unexpected crash"),
        ):
            await compaction_service._guarded_compact(session)

        # The Firestore lock must be cleared even after an unhandled crash.
        ref = compaction_service._session_ref(app_name, user_id, session.id)
        snap = await ref.get()
        assert snap.to_dict().get("compaction_lock") is None

    async def test_lock_release_failure_is_non_fatal(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        real_session_ref = compaction_service._session_ref

        def broken_release_ref(an, uid, sid):
            # Hand back a ref whose lock-release write always fails.
            ref = real_session_ref(an, uid, sid)
            real_update = ref.update

            async def failing_update(data):
                if "compaction_lock" in data:
                    raise RuntimeError("Firestore write failed")
                return await real_update(data)

            ref.update = failing_update
            return ref

        with patch.object(
            compaction_service,
            "_session_ref",
            side_effect=broken_release_ref,
        ):
            # The release error must be swallowed, not propagated.
            await compaction_service._guarded_compact(session)
# ------------------------------------------------------------------
# close()
# ------------------------------------------------------------------
class TestClose:
    """Shutdown semantics of FirestoreSessionService.close()."""

    async def test_close_no_tasks(self, compaction_service):
        """Closing with no background work pending is a no-op."""
        await compaction_service.close()

    async def test_close_awaits_tasks(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        """close() drains any in-flight background compaction tasks."""
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        base = time.time()
        for i in range(4):
            await compaction_service.append_event(
                session,
                Event(
                    author="user",
                    content=Content(role="user", parts=[Part(text=f"msg {i}")]),
                    timestamp=base + i,
                    invocation_id=f"inv-{i}",
                ),
            )
        # An agent event carrying token-usage metadata; appending it is
        # expected to schedule a background task (asserted just below).
        trigger = Event(
            author=app_name,
            content=Content(role="model", parts=[Part(text="trigger")]),
            timestamp=base + 4,
            invocation_id="inv-4",
            usage_metadata=GenerateContentResponseUsageMetadata(
                total_token_count=200,
            ),
        )
        await compaction_service.append_event(session, trigger)
        assert len(compaction_service._active_tasks) > 0

        await compaction_service.close()
        assert len(compaction_service._active_tasks) == 0

36
uv.lock generated
View File

@@ -21,6 +21,7 @@ dev = [
{ name = "pytest" }, { name = "pytest" },
{ name = "pytest-asyncio" }, { name = "pytest-asyncio" },
{ name = "pytest-cov" }, { name = "pytest-cov" },
{ name = "rich" },
] ]
[package.metadata] [package.metadata]
@@ -34,6 +35,7 @@ dev = [
{ name = "pytest", specifier = ">=9.0.2" }, { name = "pytest", specifier = ">=9.0.2" },
{ name = "pytest-asyncio", specifier = ">=0.24" }, { name = "pytest-asyncio", specifier = ">=0.24" },
{ name = "pytest-cov", specifier = ">=7.0.0" }, { name = "pytest-cov", specifier = ">=7.0.0" },
{ name = "rich", specifier = ">=14.0.0" },
] ]
[[package]] [[package]]
@@ -1432,6 +1434,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" }, { url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" },
] ]
[[package]]
name = "markdown-it-py"
version = "4.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mdurl" },
]
sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" },
]
[[package]] [[package]]
name = "markupsafe" name = "markupsafe"
version = "3.0.3" version = "3.0.3"
@@ -1520,6 +1534,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fd/d9/eaa1f80170d2b7c5ba23f3b59f766f3a0bb41155fbc32a69adfa1adaaef9/mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca", size = 233615, upload-time = "2026-01-24T19:40:30.652Z" }, { url = "https://files.pythonhosted.org/packages/fd/d9/eaa1f80170d2b7c5ba23f3b59f766f3a0bb41155fbc32a69adfa1adaaef9/mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca", size = 233615, upload-time = "2026-01-24T19:40:30.652Z" },
] ]
[[package]]
name = "mdurl"
version = "0.1.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" },
]
[[package]] [[package]]
name = "mmh3" name = "mmh3"
version = "5.2.0" version = "5.2.0"
@@ -2352,6 +2375,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
] ]
[[package]]
name = "rich"
version = "14.3.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "markdown-it-py" },
{ name = "pygments" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b3/c6/f3b320c27991c46f43ee9d856302c70dc2d0fb2dba4842ff739d5f46b393/rich-14.3.3.tar.gz", hash = "sha256:b8daa0b9e4eef54dd8cf7c86c03713f53241884e814f4e2f5fb342fe520f639b", size = 230582, upload-time = "2026-02-19T17:23:12.474Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/14/25/b208c5683343959b670dc001595f2f3737e051da617f66c31f7c4fa93abc/rich-14.3.3-py3-none-any.whl", hash = "sha256:793431c1f8619afa7d3b52b2cdec859562b950ea0d4b6b505397612db8d5362d", size = 310458, upload-time = "2026-02-19T17:23:13.732Z" },
]
[[package]] [[package]]
name = "rpds-py" name = "rpds-py"
version = "0.30.0" version = "0.30.0"

44
view_summary.py Normal file
View File

@@ -0,0 +1,44 @@
#!/usr/bin/env python3
"""Print the conversation summary for a specific user's session."""

import asyncio

from google.cloud.firestore_v1.async_client import AsyncClient
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel

from adk_firestore_sessionmanager import FirestoreSessionService

# Hard-coded dev identifiers; edit these to inspect a different user.
APP_NAME = "test_agent"
USER_ID = "dev_user"

console = Console()


async def main() -> None:
    """List every session for USER_ID and render its stored summary.

    Sessions with a `conversation_summary` field are shown as rendered
    Markdown; the rest get a placeholder panel.
    """
    db = AsyncClient()
    session_service = FirestoreSessionService(db=db)

    resp = await session_service.list_sessions(
        app_name=APP_NAME, user_id=USER_ID
    )
    if not resp.sessions:
        console.print("[dim]No sessions found.[/]")
        return

    # Pool the per-session document reads instead of awaiting each one
    # serially; gather() preserves input order, so output is unchanged.
    refs = [
        session_service._session_ref(APP_NAME, USER_ID, s.id)
        for s in resp.sessions
    ]
    snaps = await asyncio.gather(*(ref.get() for ref in refs))

    for s, snap in zip(resp.sessions, snaps):
        data = snap.to_dict() or {}
        summary = data.get("conversation_summary")
        if summary:
            console.print(
                Panel(
                    Markdown(summary),
                    title=f"Session {s.id}",
                    border_style="cyan",
                )
            )
        else:
            console.print(
                Panel(
                    "[dim]No summary yet.[/]",
                    title=f"Session {s.id}",
                    border_style="yellow",
                )
            )


if __name__ == "__main__":
    asyncio.run(main())