Improve scripts
This commit is contained in:
44
chat.py
44
chat.py
@@ -1,19 +1,27 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Minimal CLI chat agent to test FirestoreSessionService."""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from google import genai
|
||||
from google.adk.agents import LlmAgent
|
||||
from google.adk.runners import Runner
|
||||
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||
from google.genai import types
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
|
||||
from adk_firestore_sessionmanager import FirestoreSessionService
|
||||
|
||||
APP_NAME = "test_agent"
|
||||
USER_ID = "dev_user"
|
||||
|
||||
console = Console()
|
||||
|
||||
root_agent = LlmAgent(
|
||||
name=APP_NAME,
|
||||
model="gemini-2.5-flash",
|
||||
@@ -21,11 +29,29 @@ root_agent = LlmAgent(
|
||||
)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Chat with a Firestore-backed ADK agent.")
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
default="WARNING",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||
help="Set the log level (default: WARNING)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
async def main(args: argparse.Namespace) -> None:
|
||||
logging.basicConfig(
|
||||
level=args.log_level,
|
||||
format="%(message)s",
|
||||
datefmt="[%X]",
|
||||
handlers=[RichHandler(console=console, rich_tracebacks=True)],
|
||||
)
|
||||
|
||||
db = AsyncClient()
|
||||
session_service = FirestoreSessionService(
|
||||
db=db,
|
||||
compaction_token_threshold=800_000,
|
||||
compaction_token_threshold=10_000,
|
||||
genai_client=genai.Client(),
|
||||
)
|
||||
|
||||
@@ -45,19 +71,19 @@ async def main() -> None:
|
||||
user_id=USER_ID,
|
||||
session_id=resp.sessions[0].id,
|
||||
)
|
||||
print(f"Resuming session {session.id}.")
|
||||
console.print(f"Resuming session [bold cyan]{session.id}[/]")
|
||||
else:
|
||||
session = await session_service.create_session(
|
||||
app_name=APP_NAME,
|
||||
user_id=USER_ID,
|
||||
)
|
||||
print(f"Session {session.id} created.")
|
||||
console.print(f"Session [bold cyan]{session.id}[/] created.")
|
||||
|
||||
print("Type 'exit' to quit.\n")
|
||||
console.print("Type [bold]exit[/] to quit.\n")
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("You: ").strip()
|
||||
user_input = console.input("[bold green]You:[/] ").strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
break
|
||||
if not user_input or user_input.lower() == "exit":
|
||||
@@ -73,11 +99,11 @@ async def main() -> None:
|
||||
if event.content and event.content.parts and not event.partial:
|
||||
text = "".join(p.text or "" for p in event.content.parts)
|
||||
if text:
|
||||
print(f"Agent: {text}")
|
||||
console.print(Panel(Markdown(text), title="Agent", border_style="blue"))
|
||||
|
||||
await runner.close()
|
||||
print("\nGoodbye!")
|
||||
console.print("\n[dim]Goodbye![/]")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
asyncio.run(main(parse_args()))
|
||||
|
||||
Reference in New Issue
Block a user