Compare commits
8 Commits
14e4043d44
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 7bd0df4147 | |||
|
|
547296fb2d | ||
|
|
5f2f0474a5 | ||
|
|
dff25bcff0 | ||
|
|
ffcb2f4b90 | ||
|
|
52aa8bfe0d | ||
|
|
3cb78afc3a | ||
|
|
89b4d7ce73 |
216
.gitignore
vendored
216
.gitignore
vendored
@@ -8,3 +8,219 @@ wheels/
|
|||||||
|
|
||||||
# Virtual environments
|
# Virtual environments
|
||||||
.venv
|
.venv
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[codz]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
# Pipfile.lock
|
||||||
|
|
||||||
|
# UV
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# uv.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
# poetry.lock
|
||||||
|
# poetry.toml
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
||||||
|
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
||||||
|
# pdm.lock
|
||||||
|
# pdm.toml
|
||||||
|
.pdm-python
|
||||||
|
.pdm-build/
|
||||||
|
|
||||||
|
# pixi
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
||||||
|
# pixi.lock
|
||||||
|
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
||||||
|
# in the .venv directory. It is recommended not to include this directory in version control.
|
||||||
|
.pixi
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# Redis
|
||||||
|
*.rdb
|
||||||
|
*.aof
|
||||||
|
*.pid
|
||||||
|
|
||||||
|
# RabbitMQ
|
||||||
|
mnesia/
|
||||||
|
rabbitmq/
|
||||||
|
rabbitmq-data/
|
||||||
|
|
||||||
|
# ActiveMQ
|
||||||
|
activemq-data/
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.envrc
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
# .idea/
|
||||||
|
|
||||||
|
# Abstra
|
||||||
|
# Abstra is an AI-powered process automation framework.
|
||||||
|
# Ignore directories containing user credentials, local state, and settings.
|
||||||
|
# Learn more at https://abstra.io/docs
|
||||||
|
.abstra/
|
||||||
|
|
||||||
|
# Visual Studio Code
|
||||||
|
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
||||||
|
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
||||||
|
# you could uncomment the following to ignore the entire vscode folder
|
||||||
|
# .vscode/
|
||||||
|
|
||||||
|
# Ruff stuff:
|
||||||
|
.ruff_cache/
|
||||||
|
|
||||||
|
# PyPI configuration file
|
||||||
|
.pypirc
|
||||||
|
|
||||||
|
# Marimo
|
||||||
|
marimo/_static/
|
||||||
|
marimo/_lsp/
|
||||||
|
__marimo__/
|
||||||
|
|
||||||
|
# Streamlit
|
||||||
|
.streamlit/secrets.toml
|
||||||
|
|||||||
@@ -0,0 +1 @@
|
|||||||
|
Old experimental repo. Final implementation moved into [va/agent](https://gitea.ia-innovacion.work/va/agent)
|
||||||
59
chat.py
59
chat.py
@@ -1,18 +1,27 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Minimal CLI chat agent to test FirestoreSessionService."""
|
"""Minimal CLI chat agent to test FirestoreSessionService."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from google import genai
|
||||||
from google.adk.agents import LlmAgent
|
from google.adk.agents import LlmAgent
|
||||||
from google.adk.runners import Runner
|
from google.adk.runners import Runner
|
||||||
from google.cloud.firestore_v1.async_client import AsyncClient
|
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||||
from google.genai import types
|
from google.genai import types
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.logging import RichHandler
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
from rich.panel import Panel
|
||||||
|
|
||||||
from adk_firestore_sessionmanager import FirestoreSessionService
|
from adk_firestore_sessionmanager import FirestoreSessionService
|
||||||
|
|
||||||
APP_NAME = "test_agent"
|
APP_NAME = "test_agent"
|
||||||
USER_ID = "dev_user"
|
USER_ID = "dev_user"
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
root_agent = LlmAgent(
|
root_agent = LlmAgent(
|
||||||
name=APP_NAME,
|
name=APP_NAME,
|
||||||
model="gemini-2.5-flash",
|
model="gemini-2.5-flash",
|
||||||
@@ -20,9 +29,31 @@ root_agent = LlmAgent(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def main() -> None:
|
def parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(description="Chat with a Firestore-backed ADK agent.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-level",
|
||||||
|
default="WARNING",
|
||||||
|
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||||
|
help="Set the log level (default: WARNING)",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
async def main(args: argparse.Namespace) -> None:
|
||||||
|
logging.basicConfig(
|
||||||
|
level=args.log_level,
|
||||||
|
format="%(message)s",
|
||||||
|
datefmt="[%X]",
|
||||||
|
handlers=[RichHandler(console=console, rich_tracebacks=True)],
|
||||||
|
)
|
||||||
|
|
||||||
db = AsyncClient()
|
db = AsyncClient()
|
||||||
session_service = FirestoreSessionService(db=db)
|
session_service = FirestoreSessionService(
|
||||||
|
db=db,
|
||||||
|
compaction_token_threshold=10_000,
|
||||||
|
genai_client=genai.Client(),
|
||||||
|
)
|
||||||
|
|
||||||
runner = Runner(
|
runner = Runner(
|
||||||
app_name=APP_NAME,
|
app_name=APP_NAME,
|
||||||
@@ -30,15 +61,29 @@ async def main() -> None:
|
|||||||
session_service=session_service,
|
session_service=session_service,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Reuse existing session or create a new one
|
||||||
|
resp = await session_service.list_sessions(
|
||||||
|
app_name=APP_NAME, user_id=USER_ID
|
||||||
|
)
|
||||||
|
if resp.sessions:
|
||||||
|
session = await session_service.get_session(
|
||||||
|
app_name=APP_NAME,
|
||||||
|
user_id=USER_ID,
|
||||||
|
session_id=resp.sessions[0].id,
|
||||||
|
)
|
||||||
|
console.print(f"Resuming session [bold cyan]{session.id}[/]")
|
||||||
|
else:
|
||||||
session = await session_service.create_session(
|
session = await session_service.create_session(
|
||||||
app_name=APP_NAME,
|
app_name=APP_NAME,
|
||||||
user_id=USER_ID,
|
user_id=USER_ID,
|
||||||
)
|
)
|
||||||
print(f"Session {session.id} created. Type 'exit' to quit.\n")
|
console.print(f"Session [bold cyan]{session.id}[/] created.")
|
||||||
|
|
||||||
|
console.print("Type [bold]exit[/] to quit.\n")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
user_input = input("You: ").strip()
|
user_input = console.input("[bold green]You:[/] ").strip()
|
||||||
except (EOFError, KeyboardInterrupt):
|
except (EOFError, KeyboardInterrupt):
|
||||||
break
|
break
|
||||||
if not user_input or user_input.lower() == "exit":
|
if not user_input or user_input.lower() == "exit":
|
||||||
@@ -54,11 +99,11 @@ async def main() -> None:
|
|||||||
if event.content and event.content.parts and not event.partial:
|
if event.content and event.content.parts and not event.partial:
|
||||||
text = "".join(p.text or "" for p in event.content.parts)
|
text = "".join(p.text or "" for p in event.content.parts)
|
||||||
if text:
|
if text:
|
||||||
print(f"Agent: {text}")
|
console.print(Panel(Markdown(text), title="Agent", border_style="blue"))
|
||||||
|
|
||||||
await runner.close()
|
await runner.close()
|
||||||
print("\nGoodbye!")
|
console.print("\n[dim]Goodbye![/]")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main(parse_args()))
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ dev = [
|
|||||||
"pytest>=9.0.2",
|
"pytest>=9.0.2",
|
||||||
"pytest-asyncio>=0.24",
|
"pytest-asyncio>=0.24",
|
||||||
"pytest-cov>=7.0.0",
|
"pytest-cov>=7.0.0",
|
||||||
|
"rich>=14.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
|||||||
@@ -2,11 +2,13 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
from google import genai
|
||||||
from google.adk.errors.already_exists_error import AlreadyExistsError
|
from google.adk.errors.already_exists_error import AlreadyExistsError
|
||||||
from google.adk.events.event import Event
|
from google.adk.events.event import Event
|
||||||
from google.adk.sessions import _session_util
|
from google.adk.sessions import _session_util
|
||||||
@@ -18,11 +20,30 @@ from google.adk.sessions.base_session_service import (
|
|||||||
from google.adk.sessions.session import Session
|
from google.adk.sessions.session import Session
|
||||||
from google.adk.sessions.state import State
|
from google.adk.sessions.state import State
|
||||||
from google.cloud.firestore_v1.async_client import AsyncClient
|
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||||
|
from google.cloud.firestore_v1.async_transaction import async_transactional
|
||||||
|
from google.cloud.firestore_v1.base_query import FieldFilter
|
||||||
from google.cloud.firestore_v1.field_path import FieldPath
|
from google.cloud.firestore_v1.field_path import FieldPath
|
||||||
|
from google.genai.types import Content, Part
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
logger = logging.getLogger("google_adk." + __name__)
|
logger = logging.getLogger("google_adk." + __name__)
|
||||||
|
|
||||||
|
_COMPACTION_LOCK_TTL = 300 # seconds
|
||||||
|
|
||||||
|
|
||||||
|
@async_transactional
|
||||||
|
async def _try_claim_compaction_txn(transaction, session_ref):
|
||||||
|
"""Atomically claim the compaction lock if it is free or stale."""
|
||||||
|
snapshot = await session_ref.get(transaction=transaction)
|
||||||
|
if not snapshot.exists:
|
||||||
|
return False
|
||||||
|
data = snapshot.to_dict() or {}
|
||||||
|
lock_time = data.get("compaction_lock")
|
||||||
|
if lock_time and (time.time() - lock_time) < _COMPACTION_LOCK_TTL:
|
||||||
|
return False
|
||||||
|
transaction.update(session_ref, {"compaction_lock": time.time()})
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class FirestoreSessionService(BaseSessionService):
|
class FirestoreSessionService(BaseSessionService):
|
||||||
"""A Firestore-backed implementation of BaseSessionService.
|
"""A Firestore-backed implementation of BaseSessionService.
|
||||||
@@ -45,9 +66,23 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
*,
|
*,
|
||||||
db: AsyncClient,
|
db: AsyncClient,
|
||||||
collection_prefix: str = "adk",
|
collection_prefix: str = "adk",
|
||||||
|
compaction_token_threshold: int | None = None,
|
||||||
|
compaction_model: str = "gemini-2.5-flash",
|
||||||
|
compaction_keep_recent: int = 10,
|
||||||
|
genai_client: genai.Client | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if compaction_token_threshold is not None and genai_client is None:
|
||||||
|
raise ValueError(
|
||||||
|
"genai_client is required when compaction_token_threshold is set."
|
||||||
|
)
|
||||||
self._db = db
|
self._db = db
|
||||||
self._prefix = collection_prefix
|
self._prefix = collection_prefix
|
||||||
|
self._compaction_threshold = compaction_token_threshold
|
||||||
|
self._compaction_model = compaction_model
|
||||||
|
self._compaction_keep_recent = compaction_keep_recent
|
||||||
|
self._genai_client = genai_client
|
||||||
|
self._compaction_locks: dict[str, asyncio.Lock] = {}
|
||||||
|
self._active_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Document-reference helpers
|
# Document-reference helpers
|
||||||
@@ -100,6 +135,153 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
merged[State.USER_PREFIX + key] = value
|
merged[State.USER_PREFIX + key] = value
|
||||||
return merged
|
return merged
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Compaction helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _events_to_text(events: list[Event]) -> str:
|
||||||
|
lines: list[str] = []
|
||||||
|
for event in events:
|
||||||
|
if event.content and event.content.parts:
|
||||||
|
text = "".join(p.text or "" for p in event.content.parts)
|
||||||
|
if text:
|
||||||
|
role = "User" if event.author == "user" else "Assistant"
|
||||||
|
lines.append(f"{role}: {text}")
|
||||||
|
return "\n\n".join(lines)
|
||||||
|
|
||||||
|
async def _generate_summary(
|
||||||
|
self, existing_summary: str, events: list[Event]
|
||||||
|
) -> str:
|
||||||
|
conversation_text = self._events_to_text(events)
|
||||||
|
previous = (
|
||||||
|
"Previous summary of earlier conversation:\n"
|
||||||
|
f"{existing_summary}\n\n"
|
||||||
|
if existing_summary
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
prompt = (
|
||||||
|
"Summarize the following conversation between a user and an "
|
||||||
|
"assistant. Preserve:\n"
|
||||||
|
"- Key decisions and conclusions\n"
|
||||||
|
"- User preferences and requirements\n"
|
||||||
|
"- Important facts, names, and numbers\n"
|
||||||
|
"- The overall topic and direction of the conversation\n"
|
||||||
|
"- Any pending tasks or open questions\n\n"
|
||||||
|
f"{previous}"
|
||||||
|
f"Conversation:\n{conversation_text}\n\n"
|
||||||
|
"Provide a clear, comprehensive summary."
|
||||||
|
)
|
||||||
|
assert self._genai_client is not None
|
||||||
|
response = await self._genai_client.aio.models.generate_content(
|
||||||
|
model=self._compaction_model,
|
||||||
|
contents=prompt,
|
||||||
|
)
|
||||||
|
return response.text or ""
|
||||||
|
|
||||||
|
async def _compact_session(self, session: Session) -> None:
|
||||||
|
app_name = session.app_name
|
||||||
|
user_id = session.user_id
|
||||||
|
session_id = session.id
|
||||||
|
|
||||||
|
events_ref = self._events_col(app_name, user_id, session_id)
|
||||||
|
query = events_ref.order_by("timestamp")
|
||||||
|
event_docs = await query.get()
|
||||||
|
|
||||||
|
if len(event_docs) <= self._compaction_keep_recent:
|
||||||
|
return
|
||||||
|
|
||||||
|
all_events = [
|
||||||
|
Event.model_validate(doc.to_dict()) for doc in event_docs
|
||||||
|
]
|
||||||
|
events_to_summarize = all_events[: -self._compaction_keep_recent]
|
||||||
|
|
||||||
|
session_snap = await self._session_ref(
|
||||||
|
app_name, user_id, session_id
|
||||||
|
).get()
|
||||||
|
existing_summary = (session_snap.to_dict() or {}).get(
|
||||||
|
"conversation_summary", ""
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
summary = await self._generate_summary(
|
||||||
|
existing_summary, events_to_summarize
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Compaction summary generation failed; skipping.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Write summary BEFORE deleting events so a crash between the two
|
||||||
|
# steps leaves safe duplication rather than data loss.
|
||||||
|
await self._session_ref(app_name, user_id, session_id).update(
|
||||||
|
{"conversation_summary": summary}
|
||||||
|
)
|
||||||
|
|
||||||
|
docs_to_delete = event_docs[: -self._compaction_keep_recent]
|
||||||
|
for i in range(0, len(docs_to_delete), 500):
|
||||||
|
batch = self._db.batch()
|
||||||
|
for doc in docs_to_delete[i : i + 500]:
|
||||||
|
batch.delete(doc.reference)
|
||||||
|
await batch.commit()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Compacted session %s: summarised %d events, kept %d.",
|
||||||
|
session_id,
|
||||||
|
len(docs_to_delete),
|
||||||
|
self._compaction_keep_recent,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _guarded_compact(self, session: Session) -> None:
|
||||||
|
"""Run compaction in the background with per-session locking."""
|
||||||
|
key = f"{session.app_name}__{session.user_id}__{session.id}"
|
||||||
|
lock = self._compaction_locks.setdefault(key, asyncio.Lock())
|
||||||
|
|
||||||
|
if lock.locked():
|
||||||
|
logger.debug(
|
||||||
|
"Compaction already running locally for %s; skipping.", key
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
async with lock:
|
||||||
|
session_ref = self._session_ref(
|
||||||
|
session.app_name, session.user_id, session.id
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
transaction = self._db.transaction()
|
||||||
|
claimed = await _try_claim_compaction_txn(
|
||||||
|
transaction, session_ref
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to claim compaction lock for %s", key
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not claimed:
|
||||||
|
logger.debug(
|
||||||
|
"Compaction lock held by another instance for %s;"
|
||||||
|
" skipping.",
|
||||||
|
key,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._compact_session(session)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Background compaction failed for %s", key)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await session_ref.update({"compaction_lock": None})
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to release compaction lock for %s", key
|
||||||
|
)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Await all in-flight compaction tasks. Call before shutdown."""
|
||||||
|
if self._active_tasks:
|
||||||
|
await asyncio.gather(*self._active_tasks, return_exceptions=True)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# BaseSessionService implementation
|
# BaseSessionService implementation
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -130,17 +312,23 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
user_state_delta = state_deltas["user"]
|
user_state_delta = state_deltas["user"]
|
||||||
session_state = state_deltas["session"]
|
session_state = state_deltas["session"]
|
||||||
|
|
||||||
|
write_coros: list = []
|
||||||
if app_state_delta:
|
if app_state_delta:
|
||||||
await self._app_state_ref(app_name).set(
|
write_coros.append(
|
||||||
|
self._app_state_ref(app_name).set(
|
||||||
app_state_delta, merge=True
|
app_state_delta, merge=True
|
||||||
)
|
)
|
||||||
|
)
|
||||||
if user_state_delta:
|
if user_state_delta:
|
||||||
await self._user_state_ref(app_name, user_id).set(
|
write_coros.append(
|
||||||
|
self._user_state_ref(app_name, user_id).set(
|
||||||
user_state_delta, merge=True
|
user_state_delta, merge=True
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
now = time.time()
|
now = time.time()
|
||||||
await self._session_ref(app_name, user_id, session_id).set(
|
write_coros.append(
|
||||||
|
self._session_ref(app_name, user_id, session_id).set(
|
||||||
{
|
{
|
||||||
"app_name": app_name,
|
"app_name": app_name,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
@@ -149,9 +337,13 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
"last_update_time": now,
|
"last_update_time": now,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
await asyncio.gather(*write_coros)
|
||||||
|
|
||||||
app_state = await self._get_app_state(app_name)
|
app_state, user_state = await asyncio.gather(
|
||||||
user_state = await self._get_user_state(app_name, user_id)
|
self._get_app_state(app_name),
|
||||||
|
self._get_user_state(app_name, user_id),
|
||||||
|
)
|
||||||
merged = self._merge_state(app_state, user_state, session_state or {})
|
merged = self._merge_state(app_state, user_state, session_state or {})
|
||||||
|
|
||||||
return Session(
|
return Session(
|
||||||
@@ -181,18 +373,61 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
events_ref = self._events_col(app_name, user_id, session_id)
|
events_ref = self._events_col(app_name, user_id, session_id)
|
||||||
query = events_ref
|
query = events_ref
|
||||||
if config and config.after_timestamp:
|
if config and config.after_timestamp:
|
||||||
query = query.where("timestamp", ">=", config.after_timestamp)
|
query = query.where(filter=FieldFilter("timestamp", ">=", config.after_timestamp))
|
||||||
query = query.order_by("timestamp")
|
query = query.order_by("timestamp")
|
||||||
|
|
||||||
event_docs = await query.get()
|
event_docs, app_state, user_state = await asyncio.gather(
|
||||||
|
query.get(),
|
||||||
|
self._get_app_state(app_name),
|
||||||
|
self._get_user_state(app_name, user_id),
|
||||||
|
)
|
||||||
events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
|
events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
|
||||||
|
|
||||||
if config and config.num_recent_events:
|
if config and config.num_recent_events:
|
||||||
events = events[-config.num_recent_events :]
|
events = events[-config.num_recent_events :]
|
||||||
|
|
||||||
|
# Prepend conversation summary as synthetic context events
|
||||||
|
conversation_summary = session_data.get("conversation_summary")
|
||||||
|
if conversation_summary:
|
||||||
|
summary_event = Event(
|
||||||
|
id="summary-context",
|
||||||
|
author="user",
|
||||||
|
content=Content(
|
||||||
|
role="user",
|
||||||
|
parts=[
|
||||||
|
Part(
|
||||||
|
text=(
|
||||||
|
"[Conversation context from previous"
|
||||||
|
" messages]\n"
|
||||||
|
f"{conversation_summary}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
timestamp=0.0,
|
||||||
|
invocation_id="compaction-summary",
|
||||||
|
)
|
||||||
|
ack_event = Event(
|
||||||
|
id="summary-ack",
|
||||||
|
author=app_name,
|
||||||
|
content=Content(
|
||||||
|
role="model",
|
||||||
|
parts=[
|
||||||
|
Part(
|
||||||
|
text=(
|
||||||
|
"Understood, I have the context from our"
|
||||||
|
" previous conversation and will continue"
|
||||||
|
" accordingly."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
timestamp=0.001,
|
||||||
|
invocation_id="compaction-summary",
|
||||||
|
)
|
||||||
|
events = [summary_event, ack_event] + events
|
||||||
|
|
||||||
# Merge scoped state
|
# Merge scoped state
|
||||||
app_state = await self._get_app_state(app_name)
|
|
||||||
user_state = await self._get_user_state(app_name, user_id)
|
|
||||||
merged = self._merge_state(
|
merged = self._merge_state(
|
||||||
app_state, user_state, session_data.get("state", {})
|
app_state, user_state, session_data.get("state", {})
|
||||||
)
|
)
|
||||||
@@ -211,31 +446,31 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
self, *, app_name: str, user_id: Optional[str] = None
|
self, *, app_name: str, user_id: Optional[str] = None
|
||||||
) -> ListSessionsResponse:
|
) -> ListSessionsResponse:
|
||||||
query = self._db.collection(f"{self._prefix}_sessions").where(
|
query = self._db.collection(f"{self._prefix}_sessions").where(
|
||||||
"app_name", "==", app_name
|
filter=FieldFilter("app_name", "==", app_name)
|
||||||
)
|
)
|
||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
query = query.where("user_id", "==", user_id)
|
query = query.where(filter=FieldFilter("user_id", "==", user_id))
|
||||||
|
|
||||||
docs = await query.get()
|
docs = await query.get()
|
||||||
if not docs:
|
if not docs:
|
||||||
return ListSessionsResponse()
|
return ListSessionsResponse()
|
||||||
|
|
||||||
# Pre-fetch app state (shared across all sessions in this app)
|
doc_dicts = [doc.to_dict() for doc in docs]
|
||||||
app_state = await self._get_app_state(app_name)
|
|
||||||
|
|
||||||
# Cache user states to avoid repeated reads
|
# Pre-fetch app state and all distinct user states in parallel
|
||||||
user_state_cache: dict[str, dict[str, Any]] = {}
|
unique_user_ids = list({d["user_id"] for d in doc_dicts})
|
||||||
sessions: list[Session] = []
|
app_state, *user_states = await asyncio.gather(
|
||||||
|
self._get_app_state(app_name),
|
||||||
for doc in docs:
|
*(
|
||||||
data = doc.to_dict()
|
self._get_user_state(app_name, uid)
|
||||||
s_user_id = data["user_id"]
|
for uid in unique_user_ids
|
||||||
|
),
|
||||||
if s_user_id not in user_state_cache:
|
|
||||||
user_state_cache[s_user_id] = await self._get_user_state(
|
|
||||||
app_name, s_user_id
|
|
||||||
)
|
)
|
||||||
|
user_state_cache = dict(zip(unique_user_ids, user_states))
|
||||||
|
|
||||||
|
sessions: list[Session] = []
|
||||||
|
for data in doc_dicts:
|
||||||
|
s_user_id = data["user_id"]
|
||||||
merged = self._merge_state(
|
merged = self._merge_state(
|
||||||
app_state,
|
app_state,
|
||||||
user_state_cache[s_user_id],
|
user_state_cache[s_user_id],
|
||||||
@@ -267,6 +502,8 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
if event.partial:
|
if event.partial:
|
||||||
return event
|
return event
|
||||||
|
|
||||||
|
t0 = time.monotonic()
|
||||||
|
|
||||||
app_name = session.app_name
|
app_name = session.app_name
|
||||||
user_id = session.user_id
|
user_id = session.user_id
|
||||||
session_id = session.id
|
session_id = session.id
|
||||||
@@ -290,14 +527,19 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
event.actions.state_delta
|
event.actions.state_delta
|
||||||
)
|
)
|
||||||
|
|
||||||
|
write_coros: list = []
|
||||||
if state_deltas["app"]:
|
if state_deltas["app"]:
|
||||||
await self._app_state_ref(app_name).set(
|
write_coros.append(
|
||||||
|
self._app_state_ref(app_name).set(
|
||||||
state_deltas["app"], merge=True
|
state_deltas["app"], merge=True
|
||||||
)
|
)
|
||||||
|
)
|
||||||
if state_deltas["user"]:
|
if state_deltas["user"]:
|
||||||
await self._user_state_ref(app_name, user_id).set(
|
write_coros.append(
|
||||||
|
self._user_state_ref(app_name, user_id).set(
|
||||||
state_deltas["user"], merge=True
|
state_deltas["user"], merge=True
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if state_deltas["session"]:
|
if state_deltas["session"]:
|
||||||
field_updates: dict[str, Any] = {
|
field_updates: dict[str, Any] = {
|
||||||
@@ -305,12 +547,54 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
for k, v in state_deltas["session"].items()
|
for k, v in state_deltas["session"].items()
|
||||||
}
|
}
|
||||||
field_updates["last_update_time"] = event.timestamp
|
field_updates["last_update_time"] = event.timestamp
|
||||||
await session_ref.update(field_updates)
|
write_coros.append(session_ref.update(field_updates))
|
||||||
else:
|
else:
|
||||||
await session_ref.update(
|
write_coros.append(
|
||||||
{"last_update_time": event.timestamp}
|
session_ref.update({"last_update_time": event.timestamp})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await asyncio.gather(*write_coros)
|
||||||
else:
|
else:
|
||||||
await session_ref.update({"last_update_time": event.timestamp})
|
await session_ref.update({"last_update_time": event.timestamp})
|
||||||
|
|
||||||
|
# Log token usage
|
||||||
|
if event.usage_metadata:
|
||||||
|
meta = event.usage_metadata
|
||||||
|
logger.info(
|
||||||
|
"Token usage for session %s event %s: "
|
||||||
|
"prompt=%s, candidates=%s, total=%s",
|
||||||
|
session_id,
|
||||||
|
event.id,
|
||||||
|
meta.prompt_token_count,
|
||||||
|
meta.candidates_token_count,
|
||||||
|
meta.total_token_count,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Trigger compaction if total token count exceeds threshold
|
||||||
|
if (
|
||||||
|
self._compaction_threshold is not None
|
||||||
|
and event.usage_metadata
|
||||||
|
and event.usage_metadata.total_token_count
|
||||||
|
and event.usage_metadata.total_token_count
|
||||||
|
>= self._compaction_threshold
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"Compaction triggered for session %s: "
|
||||||
|
"total_token_count=%d >= threshold=%d",
|
||||||
|
session_id,
|
||||||
|
event.usage_metadata.total_token_count,
|
||||||
|
self._compaction_threshold,
|
||||||
|
)
|
||||||
|
task = asyncio.create_task(self._guarded_compact(session))
|
||||||
|
self._active_tasks.add(task)
|
||||||
|
task.add_done_callback(self._active_tasks.discard)
|
||||||
|
|
||||||
|
elapsed = time.monotonic() - t0
|
||||||
|
logger.info(
|
||||||
|
"append_event completed for session %s event %s in %.3fs",
|
||||||
|
session_id,
|
||||||
|
event.id,
|
||||||
|
elapsed,
|
||||||
|
)
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from google.cloud.firestore_v1.async_client import AsyncClient
|
|||||||
|
|
||||||
from adk_firestore_sessionmanager import FirestoreSessionService
|
from adk_firestore_sessionmanager import FirestoreSessionService
|
||||||
|
|
||||||
os.environ.setdefault("FIRESTORE_EMULATOR_HOST", "localhost:8161")
|
os.environ.setdefault("FIRESTORE_EMULATOR_HOST", "localhost:8219")
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest_asyncio.fixture
|
||||||
|
|||||||
519
tests/test_compaction.py
Normal file
519
tests/test_compaction.py
Normal file
@@ -0,0 +1,519 @@
|
|||||||
|
"""Tests for conversation compaction in FirestoreSessionService."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from google import genai
|
||||||
|
from google.adk.events.event import Event
|
||||||
|
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||||
|
from google.genai.types import Content, GenerateContentResponseUsageMetadata, Part
|
||||||
|
|
||||||
|
from adk_firestore_sessionmanager import FirestoreSessionService
|
||||||
|
from adk_firestore_sessionmanager.firestore_session_service import (
|
||||||
|
_try_claim_compaction_txn,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def mock_genai_client():
|
||||||
|
client = MagicMock(spec=genai.Client)
|
||||||
|
response = MagicMock()
|
||||||
|
response.text = "Summary of the conversation so far."
|
||||||
|
client.aio.models.generate_content = AsyncMock(return_value=response)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def compaction_service(db: AsyncClient, mock_genai_client):
|
||||||
|
prefix = f"test_{uuid.uuid4().hex[:8]}"
|
||||||
|
return FirestoreSessionService(
|
||||||
|
db=db,
|
||||||
|
collection_prefix=prefix,
|
||||||
|
compaction_token_threshold=100,
|
||||||
|
compaction_keep_recent=2,
|
||||||
|
genai_client=mock_genai_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# __init__ validation
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompactionInit:
|
||||||
|
async def test_requires_genai_client(self, db):
|
||||||
|
with pytest.raises(ValueError, match="genai_client is required"):
|
||||||
|
FirestoreSessionService(
|
||||||
|
db=db,
|
||||||
|
compaction_token_threshold=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_no_threshold_no_client_ok(self, db):
|
||||||
|
svc = FirestoreSessionService(db=db)
|
||||||
|
assert svc._compaction_threshold is None
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Compaction trigger
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompactionTrigger:
|
||||||
|
async def test_compaction_triggered_above_threshold(
|
||||||
|
self, compaction_service, mock_genai_client, app_name, user_id
|
||||||
|
):
|
||||||
|
session = await compaction_service.create_session(
|
||||||
|
app_name=app_name, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add 5 events, last one with usage_metadata above threshold
|
||||||
|
base = time.time()
|
||||||
|
for i in range(4):
|
||||||
|
e = Event(
|
||||||
|
author="user" if i % 2 == 0 else app_name,
|
||||||
|
content=Content(
|
||||||
|
role="user" if i % 2 == 0 else "model",
|
||||||
|
parts=[Part(text=f"message {i}")],
|
||||||
|
),
|
||||||
|
timestamp=base + i,
|
||||||
|
invocation_id=f"inv-{i}",
|
||||||
|
)
|
||||||
|
await compaction_service.append_event(session, e)
|
||||||
|
|
||||||
|
# This event crosses the threshold
|
||||||
|
trigger_event = Event(
|
||||||
|
author=app_name,
|
||||||
|
content=Content(
|
||||||
|
role="model", parts=[Part(text="final response")]
|
||||||
|
),
|
||||||
|
timestamp=base + 4,
|
||||||
|
invocation_id="inv-4",
|
||||||
|
usage_metadata=GenerateContentResponseUsageMetadata(
|
||||||
|
total_token_count=200,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await compaction_service.append_event(session, trigger_event)
|
||||||
|
await compaction_service.close()
|
||||||
|
|
||||||
|
# Summary generation should have been called
|
||||||
|
mock_genai_client.aio.models.generate_content.assert_called_once()
|
||||||
|
|
||||||
|
# Fetch session: should have summary + only keep_recent events
|
||||||
|
fetched = await compaction_service.get_session(
|
||||||
|
app_name=app_name, user_id=user_id, session_id=session.id
|
||||||
|
)
|
||||||
|
# 2 synthetic summary events + 2 kept real events
|
||||||
|
assert len(fetched.events) == 4
|
||||||
|
assert fetched.events[0].id == "summary-context"
|
||||||
|
assert fetched.events[1].id == "summary-ack"
|
||||||
|
assert "Summary of the conversation" in fetched.events[0].content.parts[0].text
|
||||||
|
|
||||||
|
async def test_no_compaction_below_threshold(
|
||||||
|
self, compaction_service, mock_genai_client, app_name, user_id
|
||||||
|
):
|
||||||
|
session = await compaction_service.create_session(
|
||||||
|
app_name=app_name, user_id=user_id
|
||||||
|
)
|
||||||
|
event = Event(
|
||||||
|
author=app_name,
|
||||||
|
content=Content(
|
||||||
|
role="model", parts=[Part(text="short reply")]
|
||||||
|
),
|
||||||
|
timestamp=time.time(),
|
||||||
|
invocation_id="inv-1",
|
||||||
|
usage_metadata=GenerateContentResponseUsageMetadata(
|
||||||
|
total_token_count=50,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await compaction_service.append_event(session, event)
|
||||||
|
|
||||||
|
mock_genai_client.aio.models.generate_content.assert_not_called()
|
||||||
|
|
||||||
|
async def test_no_compaction_without_usage_metadata(
|
||||||
|
self, compaction_service, mock_genai_client, app_name, user_id
|
||||||
|
):
|
||||||
|
session = await compaction_service.create_session(
|
||||||
|
app_name=app_name, user_id=user_id
|
||||||
|
)
|
||||||
|
event = Event(
|
||||||
|
author="user",
|
||||||
|
content=Content(
|
||||||
|
role="user", parts=[Part(text="hello")]
|
||||||
|
),
|
||||||
|
timestamp=time.time(),
|
||||||
|
invocation_id="inv-1",
|
||||||
|
)
|
||||||
|
await compaction_service.append_event(session, event)
|
||||||
|
|
||||||
|
mock_genai_client.aio.models.generate_content.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Compaction with too few events (nothing to compact)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompactionEdgeCases:
|
||||||
|
async def test_skip_when_fewer_events_than_keep_recent(
|
||||||
|
self, compaction_service, mock_genai_client, app_name, user_id
|
||||||
|
):
|
||||||
|
session = await compaction_service.create_session(
|
||||||
|
app_name=app_name, user_id=user_id
|
||||||
|
)
|
||||||
|
# Only 2 events, keep_recent=2 → nothing to summarize
|
||||||
|
for i in range(2):
|
||||||
|
e = Event(
|
||||||
|
author="user",
|
||||||
|
content=Content(
|
||||||
|
role="user", parts=[Part(text=f"msg {i}")]
|
||||||
|
),
|
||||||
|
timestamp=time.time() + i,
|
||||||
|
invocation_id=f"inv-{i}",
|
||||||
|
)
|
||||||
|
await compaction_service.append_event(session, e)
|
||||||
|
|
||||||
|
# Trigger compaction manually even though threshold wouldn't fire
|
||||||
|
await compaction_service._compact_session(session)
|
||||||
|
|
||||||
|
mock_genai_client.aio.models.generate_content.assert_not_called()
|
||||||
|
|
||||||
|
async def test_summary_generation_failure_is_non_fatal(
|
||||||
|
self, compaction_service, mock_genai_client, app_name, user_id
|
||||||
|
):
|
||||||
|
session = await compaction_service.create_session(
|
||||||
|
app_name=app_name, user_id=user_id
|
||||||
|
)
|
||||||
|
for i in range(5):
|
||||||
|
e = Event(
|
||||||
|
author="user",
|
||||||
|
content=Content(
|
||||||
|
role="user", parts=[Part(text=f"msg {i}")]
|
||||||
|
),
|
||||||
|
timestamp=time.time() + i,
|
||||||
|
invocation_id=f"inv-{i}",
|
||||||
|
)
|
||||||
|
await compaction_service.append_event(session, e)
|
||||||
|
|
||||||
|
# Make summary generation fail
|
||||||
|
mock_genai_client.aio.models.generate_content = AsyncMock(
|
||||||
|
side_effect=RuntimeError("API error")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not raise
|
||||||
|
await compaction_service._compact_session(session)
|
||||||
|
|
||||||
|
# All events should still be present
|
||||||
|
fetched = await compaction_service.get_session(
|
||||||
|
app_name=app_name, user_id=user_id, session_id=session.id
|
||||||
|
)
|
||||||
|
assert len(fetched.events) == 5
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# get_session with summary
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetSessionWithSummary:
|
||||||
|
async def test_no_summary_no_synthetic_events(
|
||||||
|
self, compaction_service, app_name, user_id
|
||||||
|
):
|
||||||
|
session = await compaction_service.create_session(
|
||||||
|
app_name=app_name, user_id=user_id
|
||||||
|
)
|
||||||
|
event = Event(
|
||||||
|
author="user",
|
||||||
|
content=Content(
|
||||||
|
role="user", parts=[Part(text="hello")]
|
||||||
|
),
|
||||||
|
timestamp=time.time(),
|
||||||
|
invocation_id="inv-1",
|
||||||
|
)
|
||||||
|
await compaction_service.append_event(session, event)
|
||||||
|
|
||||||
|
fetched = await compaction_service.get_session(
|
||||||
|
app_name=app_name, user_id=user_id, session_id=session.id
|
||||||
|
)
|
||||||
|
assert len(fetched.events) == 1
|
||||||
|
assert fetched.events[0].author == "user"
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# _events_to_text
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEventsToText:
|
||||||
|
async def test_formats_user_and_assistant(self):
|
||||||
|
events = [
|
||||||
|
Event(
|
||||||
|
author="user",
|
||||||
|
content=Content(
|
||||||
|
role="user", parts=[Part(text="Hi there")]
|
||||||
|
),
|
||||||
|
timestamp=1.0,
|
||||||
|
invocation_id="inv-1",
|
||||||
|
),
|
||||||
|
Event(
|
||||||
|
author="bot",
|
||||||
|
content=Content(
|
||||||
|
role="model", parts=[Part(text="Hello!")]
|
||||||
|
),
|
||||||
|
timestamp=2.0,
|
||||||
|
invocation_id="inv-2",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
text = FirestoreSessionService._events_to_text(events)
|
||||||
|
assert "User: Hi there" in text
|
||||||
|
assert "Assistant: Hello!" in text
|
||||||
|
|
||||||
|
async def test_skips_events_without_text(self):
|
||||||
|
events = [
|
||||||
|
Event(
|
||||||
|
author="user",
|
||||||
|
timestamp=1.0,
|
||||||
|
invocation_id="inv-1",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
text = FirestoreSessionService._events_to_text(events)
|
||||||
|
assert text == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Firestore distributed lock
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompactionLock:
|
||||||
|
async def test_claim_and_release(
|
||||||
|
self, compaction_service, app_name, user_id
|
||||||
|
):
|
||||||
|
session = await compaction_service.create_session(
|
||||||
|
app_name=app_name, user_id=user_id
|
||||||
|
)
|
||||||
|
session_ref = compaction_service._session_ref(
|
||||||
|
app_name, user_id, session.id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Claim the lock
|
||||||
|
transaction = compaction_service._db.transaction()
|
||||||
|
claimed = await _try_claim_compaction_txn(transaction, session_ref)
|
||||||
|
assert claimed is True
|
||||||
|
|
||||||
|
# Lock is now held — second claim should fail
|
||||||
|
transaction2 = compaction_service._db.transaction()
|
||||||
|
claimed2 = await _try_claim_compaction_txn(transaction2, session_ref)
|
||||||
|
assert claimed2 is False
|
||||||
|
|
||||||
|
# Release the lock
|
||||||
|
await session_ref.update({"compaction_lock": None})
|
||||||
|
|
||||||
|
# Can claim again after release
|
||||||
|
transaction3 = compaction_service._db.transaction()
|
||||||
|
claimed3 = await _try_claim_compaction_txn(transaction3, session_ref)
|
||||||
|
assert claimed3 is True
|
||||||
|
|
||||||
|
async def test_stale_lock_can_be_reclaimed(
|
||||||
|
self, compaction_service, app_name, user_id
|
||||||
|
):
|
||||||
|
session = await compaction_service.create_session(
|
||||||
|
app_name=app_name, user_id=user_id
|
||||||
|
)
|
||||||
|
session_ref = compaction_service._session_ref(
|
||||||
|
app_name, user_id, session.id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set a stale lock (older than TTL)
|
||||||
|
await session_ref.update({"compaction_lock": time.time() - 600})
|
||||||
|
|
||||||
|
# Should be able to reclaim a stale lock
|
||||||
|
transaction = compaction_service._db.transaction()
|
||||||
|
claimed = await _try_claim_compaction_txn(transaction, session_ref)
|
||||||
|
assert claimed is True
|
||||||
|
|
||||||
|
async def test_claim_nonexistent_session(self, compaction_service):
|
||||||
|
ref = compaction_service._session_ref("no_app", "no_user", "no_id")
|
||||||
|
transaction = compaction_service._db.transaction()
|
||||||
|
claimed = await _try_claim_compaction_txn(transaction, ref)
|
||||||
|
assert claimed is False
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Guarded compact
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGuardedCompact:
|
||||||
|
async def test_local_lock_skips_concurrent(
|
||||||
|
self, compaction_service, mock_genai_client, app_name, user_id
|
||||||
|
):
|
||||||
|
session = await compaction_service.create_session(
|
||||||
|
app_name=app_name, user_id=user_id
|
||||||
|
)
|
||||||
|
for i in range(5):
|
||||||
|
e = Event(
|
||||||
|
author="user",
|
||||||
|
content=Content(
|
||||||
|
role="user", parts=[Part(text=f"msg {i}")]
|
||||||
|
),
|
||||||
|
timestamp=time.time() + i,
|
||||||
|
invocation_id=f"inv-{i}",
|
||||||
|
)
|
||||||
|
await compaction_service.append_event(session, e)
|
||||||
|
|
||||||
|
# Hold the in-process lock so _guarded_compact skips
|
||||||
|
key = f"{app_name}__{user_id}__{session.id}"
|
||||||
|
lock = compaction_service._compaction_locks.setdefault(
|
||||||
|
key, asyncio.Lock()
|
||||||
|
)
|
||||||
|
async with lock:
|
||||||
|
await compaction_service._guarded_compact(session)
|
||||||
|
|
||||||
|
mock_genai_client.aio.models.generate_content.assert_not_called()
|
||||||
|
|
||||||
|
async def test_firestore_lock_held_skips(
|
||||||
|
self, compaction_service, mock_genai_client, app_name, user_id
|
||||||
|
):
|
||||||
|
session = await compaction_service.create_session(
|
||||||
|
app_name=app_name, user_id=user_id
|
||||||
|
)
|
||||||
|
for i in range(5):
|
||||||
|
e = Event(
|
||||||
|
author="user",
|
||||||
|
content=Content(
|
||||||
|
role="user", parts=[Part(text=f"msg {i}")]
|
||||||
|
),
|
||||||
|
timestamp=time.time() + i,
|
||||||
|
invocation_id=f"inv-{i}",
|
||||||
|
)
|
||||||
|
await compaction_service.append_event(session, e)
|
||||||
|
|
||||||
|
# Set a fresh Firestore lock (simulating another instance)
|
||||||
|
session_ref = compaction_service._session_ref(
|
||||||
|
app_name, user_id, session.id
|
||||||
|
)
|
||||||
|
await session_ref.update({"compaction_lock": time.time()})
|
||||||
|
|
||||||
|
await compaction_service._guarded_compact(session)
|
||||||
|
|
||||||
|
mock_genai_client.aio.models.generate_content.assert_not_called()
|
||||||
|
|
||||||
|
async def test_claim_failure_logs_and_skips(
|
||||||
|
self, compaction_service, mock_genai_client, app_name, user_id
|
||||||
|
):
|
||||||
|
session = await compaction_service.create_session(
|
||||||
|
app_name=app_name, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"adk_firestore_sessionmanager.firestore_session_service"
|
||||||
|
"._try_claim_compaction_txn",
|
||||||
|
side_effect=RuntimeError("Firestore down"),
|
||||||
|
):
|
||||||
|
await compaction_service._guarded_compact(session)
|
||||||
|
|
||||||
|
mock_genai_client.aio.models.generate_content.assert_not_called()
|
||||||
|
|
||||||
|
async def test_compaction_failure_releases_lock(
|
||||||
|
self, compaction_service, mock_genai_client, app_name, user_id
|
||||||
|
):
|
||||||
|
session = await compaction_service.create_session(
|
||||||
|
app_name=app_name, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make _compact_session raise an unhandled exception
|
||||||
|
with patch.object(
|
||||||
|
compaction_service,
|
||||||
|
"_compact_session",
|
||||||
|
side_effect=RuntimeError("unexpected crash"),
|
||||||
|
):
|
||||||
|
await compaction_service._guarded_compact(session)
|
||||||
|
|
||||||
|
# Lock should be released even after failure
|
||||||
|
session_ref = compaction_service._session_ref(
|
||||||
|
app_name, user_id, session.id
|
||||||
|
)
|
||||||
|
snap = await session_ref.get()
|
||||||
|
assert snap.to_dict().get("compaction_lock") is None
|
||||||
|
|
||||||
|
async def test_lock_release_failure_is_non_fatal(
|
||||||
|
self, compaction_service, mock_genai_client, app_name, user_id
|
||||||
|
):
|
||||||
|
session = await compaction_service.create_session(
|
||||||
|
app_name=app_name, user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
original_session_ref = compaction_service._session_ref
|
||||||
|
|
||||||
|
def patched_session_ref(an, uid, sid):
|
||||||
|
ref = original_session_ref(an, uid, sid)
|
||||||
|
original_update = ref.update
|
||||||
|
|
||||||
|
async def failing_update(data):
|
||||||
|
if "compaction_lock" in data:
|
||||||
|
raise RuntimeError("Firestore write failed")
|
||||||
|
return await original_update(data)
|
||||||
|
|
||||||
|
ref.update = failing_update
|
||||||
|
return ref
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
compaction_service,
|
||||||
|
"_session_ref",
|
||||||
|
side_effect=patched_session_ref,
|
||||||
|
):
|
||||||
|
# Should not raise despite lock release failure
|
||||||
|
await compaction_service._guarded_compact(session)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# close()
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestClose:
|
||||||
|
async def test_close_no_tasks(self, compaction_service):
|
||||||
|
await compaction_service.close()
|
||||||
|
|
||||||
|
async def test_close_awaits_tasks(
|
||||||
|
self, compaction_service, mock_genai_client, app_name, user_id
|
||||||
|
):
|
||||||
|
session = await compaction_service.create_session(
|
||||||
|
app_name=app_name, user_id=user_id
|
||||||
|
)
|
||||||
|
base = time.time()
|
||||||
|
for i in range(4):
|
||||||
|
e = Event(
|
||||||
|
author="user",
|
||||||
|
content=Content(
|
||||||
|
role="user", parts=[Part(text=f"msg {i}")]
|
||||||
|
),
|
||||||
|
timestamp=base + i,
|
||||||
|
invocation_id=f"inv-{i}",
|
||||||
|
)
|
||||||
|
await compaction_service.append_event(session, e)
|
||||||
|
|
||||||
|
trigger = Event(
|
||||||
|
author=app_name,
|
||||||
|
content=Content(
|
||||||
|
role="model", parts=[Part(text="trigger")]
|
||||||
|
),
|
||||||
|
timestamp=base + 4,
|
||||||
|
invocation_id="inv-4",
|
||||||
|
usage_metadata=GenerateContentResponseUsageMetadata(
|
||||||
|
total_token_count=200,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await compaction_service.append_event(session, trigger)
|
||||||
|
assert len(compaction_service._active_tasks) > 0
|
||||||
|
|
||||||
|
await compaction_service.close()
|
||||||
|
assert len(compaction_service._active_tasks) == 0
|
||||||
36
uv.lock
generated
36
uv.lock
generated
@@ -21,6 +21,7 @@ dev = [
|
|||||||
{ name = "pytest" },
|
{ name = "pytest" },
|
||||||
{ name = "pytest-asyncio" },
|
{ name = "pytest-asyncio" },
|
||||||
{ name = "pytest-cov" },
|
{ name = "pytest-cov" },
|
||||||
|
{ name = "rich" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
@@ -34,6 +35,7 @@ dev = [
|
|||||||
{ name = "pytest", specifier = ">=9.0.2" },
|
{ name = "pytest", specifier = ">=9.0.2" },
|
||||||
{ name = "pytest-asyncio", specifier = ">=0.24" },
|
{ name = "pytest-asyncio", specifier = ">=0.24" },
|
||||||
{ name = "pytest-cov", specifier = ">=7.0.0" },
|
{ name = "pytest-cov", specifier = ">=7.0.0" },
|
||||||
|
{ name = "rich", specifier = ">=14.0.0" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1432,6 +1434,18 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" },
|
{ url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "markdown-it-py"
|
||||||
|
version = "4.0.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "mdurl" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "markupsafe"
|
name = "markupsafe"
|
||||||
version = "3.0.3"
|
version = "3.0.3"
|
||||||
@@ -1520,6 +1534,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/fd/d9/eaa1f80170d2b7c5ba23f3b59f766f3a0bb41155fbc32a69adfa1adaaef9/mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca", size = 233615, upload-time = "2026-01-24T19:40:30.652Z" },
|
{ url = "https://files.pythonhosted.org/packages/fd/d9/eaa1f80170d2b7c5ba23f3b59f766f3a0bb41155fbc32a69adfa1adaaef9/mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca", size = 233615, upload-time = "2026-01-24T19:40:30.652Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mdurl"
|
||||||
|
version = "0.1.2"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mmh3"
|
name = "mmh3"
|
||||||
version = "5.2.0"
|
version = "5.2.0"
|
||||||
@@ -2352,6 +2375,19 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
|
{ url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rich"
|
||||||
|
version = "14.3.3"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "markdown-it-py" },
|
||||||
|
{ name = "pygments" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/b3/c6/f3b320c27991c46f43ee9d856302c70dc2d0fb2dba4842ff739d5f46b393/rich-14.3.3.tar.gz", hash = "sha256:b8daa0b9e4eef54dd8cf7c86c03713f53241884e814f4e2f5fb342fe520f639b", size = 230582, upload-time = "2026-02-19T17:23:12.474Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/14/25/b208c5683343959b670dc001595f2f3737e051da617f66c31f7c4fa93abc/rich-14.3.3-py3-none-any.whl", hash = "sha256:793431c1f8619afa7d3b52b2cdec859562b950ea0d4b6b505397612db8d5362d", size = 310458, upload-time = "2026-02-19T17:23:13.732Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rpds-py"
|
name = "rpds-py"
|
||||||
version = "0.30.0"
|
version = "0.30.0"
|
||||||
|
|||||||
44
view_summary.py
Normal file
44
view_summary.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Print the conversation summary for a specific user's session."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
from rich.panel import Panel
|
||||||
|
|
||||||
|
from adk_firestore_sessionmanager import FirestoreSessionService
|
||||||
|
|
||||||
|
APP_NAME = "test_agent"
|
||||||
|
USER_ID = "dev_user"
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
|
async def main() -> None:
|
||||||
|
db = AsyncClient()
|
||||||
|
session_service = FirestoreSessionService(db=db)
|
||||||
|
|
||||||
|
resp = await session_service.list_sessions(
|
||||||
|
app_name=APP_NAME, user_id=USER_ID
|
||||||
|
)
|
||||||
|
|
||||||
|
if not resp.sessions:
|
||||||
|
console.print("[dim]No sessions found.[/]")
|
||||||
|
return
|
||||||
|
|
||||||
|
for s in resp.sessions:
|
||||||
|
ref = session_service._session_ref(APP_NAME, USER_ID, s.id)
|
||||||
|
snap = await ref.get()
|
||||||
|
data = snap.to_dict() or {}
|
||||||
|
summary = data.get("conversation_summary")
|
||||||
|
|
||||||
|
if summary:
|
||||||
|
console.print(Panel(Markdown(summary), title=f"Session {s.id}", border_style="cyan"))
|
||||||
|
else:
|
||||||
|
console.print(Panel("[dim]No summary yet.[/]", title=f"Session {s.id}", border_style="yellow"))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
Reference in New Issue
Block a user