diff --git a/src/va_agent/agent.py b/src/va_agent/agent.py index 11b458b..07aa959 100644 --- a/src/va_agent/agent.py +++ b/src/va_agent/agent.py @@ -1,5 +1,7 @@ """ADK agent with vector search RAG tool.""" +from functools import partial + from google import genai from google.adk.agents.llm_agent import Agent from google.adk.runners import Runner @@ -10,8 +12,9 @@ from google.genai.types import Content, Part from va_agent.auth import auth_headers_provider from va_agent.config import settings +from va_agent.dynamic_instruction import provide_dynamic_instruction from va_agent.governance import GovernancePlugin -from va_agent.notifications import NotificationService +from va_agent.notifications import FirestoreNotificationBackend from va_agent.session import FirestoreSessionService # MCP Toolset for RAG knowledge search @@ -32,7 +35,7 @@ session_service = FirestoreSessionService( ) # Notification service -notification_service = NotificationService( +notification_service = FirestoreNotificationBackend( db=firestore_db, collection_path=settings.notifications_collection_path, max_to_notify=settings.notifications_max_to_notify, @@ -43,11 +46,11 @@ governance = GovernancePlugin() agent = Agent( model=settings.agent_model, name=settings.agent_name, + instruction=partial(provide_dynamic_instruction, notification_service), static_instruction=Content( role="user", parts=[Part(text=settings.agent_instructions)], ), - instruction=settings.agent_instructions, tools=[toolset], after_model_callback=governance.after_model_callback, ) diff --git a/src/va_agent/dynamic_instruction.py b/src/va_agent/dynamic_instruction.py index cc64fb9..4854480 100644 --- a/src/va_agent/dynamic_instruction.py +++ b/src/va_agent/dynamic_instruction.py @@ -8,13 +8,13 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from google.adk.agents.readonly_context import ReadonlyContext - from va_agent.notifications import NotificationService + from va_agent.notifications import NotificationBackend logger = logging.getLogger(__name__) async def provide_dynamic_instruction( - notification_service: NotificationService, + notification_service: NotificationBackend, ctx: ReadonlyContext | None = None, ) -> str: """Provide dynamic instructions based on pending notifications. diff --git a/src/va_agent/notifications.py b/src/va_agent/notifications.py index 8536fb2..958281a 100644 --- a/src/va_agent/notifications.py +++ b/src/va_agent/notifications.py @@ -4,7 +4,7 @@ from __future__ import annotations import logging import time -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable if TYPE_CHECKING: from google.cloud.firestore_v1.async_client import AsyncClient @@ -12,8 +12,25 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class NotificationService: - """Service for fetching and managing user notifications from Firestore.""" +@runtime_checkable +class NotificationBackend(Protocol): + """Backend-agnostic interface for notification storage.""" + + async def get_pending_notifications( + self, phone_number: str + ) -> list[dict[str, Any]]: + """Return pending (unread) notifications for *phone_number*.""" + ... + + async def mark_as_notified( + self, phone_number: str, notification_ids: list[str] + ) -> bool: + """Mark the given notification IDs as notified. Return success.""" + ... + + +class FirestoreNotificationBackend: + """Firestore-backed notification backend.""" def __init__( self, @@ -22,14 +39,7 @@ class NotificationService: collection_path: str, max_to_notify: int = 5, ) -> None: - """Initialize NotificationService. - - Args: - db: Firestore async client - collection_path: Path to notifications collection - max_to_notify: Maximum number of notifications to return - - """ + """Initialize with Firestore client and collection path.""" self._db = db self._collection_path = collection_path self._max_to_notify = max_to_notify @@ -57,7 +67,6 @@ class NotificationService: """ try: - # Query Firestore document by phone number doc_ref = self._db.collection(self._collection_path).document(phone_number) doc = await doc_ref.get() @@ -184,37 +193,3 @@ class NotificationService: return False else: return True - - def format_notification_summary(self, notifications: list[dict[str, Any]]) -> str: - """Format notifications into a human-readable summary. - - Args: - notifications: List of notification dictionaries - - Returns: - Formatted string summarizing the notifications - - """ - if not notifications: - return "" - - count = len(notifications) - summary_lines = [f"El usuario tiene {count} notificación(es) pendiente(s):"] - - for i, notif in enumerate(notifications, 1): - texto = notif.get("texto", "Sin texto") - params = notif.get("parametros", {}) - - # Extract key parameters if available - amount = params.get("notification_po_amount") - tx_id = params.get("notification_po_transaction_id") - - line = f"{i}. {texto}" - if amount: - line += f" (monto: ${amount})" - if tx_id: - line += f" [ID: {tx_id}]" - - summary_lines.append(line) - - return "\n".join(summary_lines) diff --git a/src/va_agent/server.py b/src/va_agent/server.py index b511653..bfa2597 100644 --- a/src/va_agent/server.py +++ b/src/va_agent/server.py @@ -22,20 +22,11 @@ app = FastAPI(title="Vaia Agent") # --------------------------------------------------------------------------- -class NotificationPayload(BaseModel): - """Notification context sent alongside a user query.""" - - text: str | None = None - parameters: dict[str, Any] = Field(default_factory=dict) - - class QueryRequest(BaseModel): """Incoming query request from the integration layer.""" phone_number: str text: str - type: str = "conversation" - notification: NotificationPayload | None = None language_code: str = "es" @@ -56,26 +47,6 @@ class ErrorResponse(BaseModel): status: int -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _build_user_message(request: QueryRequest) -> str: - """Compose the text sent to the agent, including notification context.""" - if request.type == "notification" and request.notification: - parts = [request.text] - if request.notification.text: - parts.append(f"\n[Notificación recibida]: {request.notification.text}") - if request.notification.parameters: - formatted = ", ".join( - f"{k}: {v}" for k, v in request.notification.parameters.items() - ) - parts.append(f"[Parámetros de notificación]: {formatted}") - return "\n".join(parts) - return request.text - - # --------------------------------------------------------------------------- # Endpoints # --------------------------------------------------------------------------- @@ -92,13 +63,12 @@ def _build_user_message(request: QueryRequest) -> str: ) async def query(request: QueryRequest) -> QueryResponse: """Process a user message and return a generated response.""" - user_message = _build_user_message(request) session_id = request.phone_number user_id = request.phone_number new_message = Content( role="user", - parts=[Part(text=user_message)], + parts=[Part(text=request.text)], ) try: