110 lines
3.1 KiB
Python
110 lines
3.1 KiB
Python
#!/usr/bin/env python3
|
|
"""Minimal CLI chat agent to test FirestoreSessionService."""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import logging
|
|
|
|
from google import genai
|
|
from google.adk.agents import LlmAgent
|
|
from google.adk.runners import Runner
|
|
from google.cloud.firestore_v1.async_client import AsyncClient
|
|
from google.genai import types
|
|
from rich.console import Console
|
|
from rich.logging import RichHandler
|
|
from rich.markdown import Markdown
|
|
from rich.panel import Panel
|
|
|
|
from adk_firestore_sessionmanager import FirestoreSessionService
|
|
|
|
# Identifiers handed to the session service and the Runner; APP_NAME also
# doubles as the agent's name.
APP_NAME: str = "test_agent"
USER_ID: str = "dev_user"
|
|
|
|
# One shared Rich console: the chat loop prints/reads through it and the
# RichHandler installed in main() logs through it too.
console: Console = Console()

# The single LLM agent driven by the Runner in main().
root_agent = LlmAgent(
    model="gemini-2.5-flash",
    name=APP_NAME,
    instruction="You are a helpful conversational assistant.",
)
|
|
|
|
|
|
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command-line options for the chat CLI.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, matching the original zero-argument behaviour.

    Returns:
        A namespace whose ``log_level`` is upper-cased and is one of
        DEBUG / INFO / WARNING / ERROR / CRITICAL.

    Raises:
        SystemExit: On ``--help`` or an unrecognised option/value (argparse).
    """
    parser = argparse.ArgumentParser(description="Chat with a Firestore-backed ADK agent.")
    parser.add_argument(
        "--log-level",
        default="WARNING",
        # argparse converts with `type` before checking `choices`, so
        # "--log-level debug" is accepted as DEBUG.
        type=str.upper,
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Set the log level (default: WARNING)",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
async def _resume_or_create_session(session_service: FirestoreSessionService):
    """Return USER_ID's most recently updated session, or create a new one.

    Also creates a fresh session when the listed session can no longer be
    fetched (``get_session`` returned ``None``).
    """
    resp = await session_service.list_sessions(app_name=APP_NAME, user_id=USER_ID)
    if resp.sessions:
        # Nothing here guarantees list_sessions returns sessions newest-first,
        # so select by last_update_time explicitly (ties keep list order).
        latest = max(resp.sessions, key=lambda s: s.last_update_time)
        session = await session_service.get_session(
            app_name=APP_NAME,
            user_id=USER_ID,
            session_id=latest.id,
        )
        if session is not None:
            console.print(f"Resuming session [bold cyan]{session.id}[/]")
            return session
        logging.getLogger(__name__).warning(
            "Session %s could not be loaded; creating a new one.", latest.id
        )

    session = await session_service.create_session(
        app_name=APP_NAME,
        user_id=USER_ID,
    )
    console.print(f"Session [bold cyan]{session.id}[/] created.")
    return session


async def main(args: argparse.Namespace) -> None:
    """Run an interactive chat loop against the Firestore-backed ADK agent.

    Args:
        args: Parsed CLI options; only ``args.log_level`` is read.
    """
    logging.basicConfig(
        level=args.log_level,
        format="%(message)s",
        datefmt="[%X]",
        handlers=[RichHandler(console=console, rich_tracebacks=True)],
    )

    db = AsyncClient()
    session_service = FirestoreSessionService(
        db=db,
        compaction_token_threshold=10_000,
        genai_client=genai.Client(),
    )

    runner = Runner(
        app_name=APP_NAME,
        agent=root_agent,
        session_service=session_service,
    )

    try:
        session = await _resume_or_create_session(session_service)
        console.print("Type [bold]exit[/] to quit.\n")

        while True:
            # NOTE(review): under asyncio.run on Python 3.11+, the first Ctrl-C
            # cancels the main task rather than raising KeyboardInterrupt
            # inside this blocking input call — confirm Ctrl-C UX is acceptable.
            try:
                user_input = console.input("[bold green]You:[/] ").strip()
            except (EOFError, KeyboardInterrupt):
                break
            if not user_input:
                # A blank line (e.g. a stray Enter) re-prompts; only "exit",
                # EOF or an interrupt end the chat.
                continue
            if user_input.lower() == "exit":
                break

            async for event in runner.run_async(
                user_id=USER_ID,
                session_id=session.id,
                new_message=types.Content(
                    role="user", parts=[types.Part(text=user_input)]
                ),
            ):
                # Render only complete (non-partial) events that carry text.
                if event.content and event.content.parts and not event.partial:
                    text = "".join(p.text or "" for p in event.content.parts)
                    if text:
                        console.print(Panel(Markdown(text), title="Agent", border_style="blue"))
    finally:
        # Close the runner on every exit path, including when the session
        # lookup or an agent turn raises.
        await runner.close()

    console.print("\n[dim]Goodbye![/]")
|
|
|
|
|
|
if __name__ == "__main__":
    # Options are parsed before main() runs, so --help or a bad flag exits
    # before any Firestore or Gemini client is constructed.
    cli_args = parse_args()
    asyncio.run(main(cli_args))
|