140 lines
4.0 KiB
Python
140 lines
4.0 KiB
Python
"""FastAPI server exposing the RAG agent endpoint."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import uuid
|
|
from typing import Any
|
|
|
|
from fastapi import FastAPI, HTTPException
|
|
from google.genai.types import Content, Part
|
|
from pydantic import BaseModel, Field
|
|
|
|
from va_agent.agent import runner
|
|
|
|
# FastAPI application object; the endpoints in this module register on it
# via ``@app.post``.
app = FastAPI(title="Vaia Agent")

# Module-scoped logger, used to record failures while running the agent.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Request / Response models
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class NotificationPayload(BaseModel):
    """Optional notification details attached to a user query.

    Both fields may be omitted: ``text`` falls back to ``None`` and
    ``parameters`` to a fresh, per-instance empty dict.
    """

    # Free-form notification message, if any.
    text: str | None = Field(default=None)
    # Arbitrary key/value context accompanying the notification.
    parameters: dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
class QueryRequest(BaseModel):
    """Query payload posted by the integration layer.

    ``phone_number`` and ``text`` are mandatory. ``type`` defaults to
    ``"conversation"``, ``notification`` to ``None`` and ``language_code``
    to ``"es"``.
    """

    # --- required -----------------------------------------------------------
    phone_number: str
    text: str
    # --- optional -----------------------------------------------------------
    # Message kind; "notification" makes the notification payload part of
    # the agent prompt.
    type: str = Field(default="conversation")
    notification: NotificationPayload | None = Field(default=None)
    language_code: str = Field(default="es")
|
|
|
|
|
|
class QueryResponse(BaseModel):
    """Payload handed back to the integration layer for a processed query.

    Only ``response_id`` and ``response_text`` are required; ``parameters``
    defaults to an empty dict and ``confidence`` to ``None``.
    """

    # Identifier of this particular response.
    response_id: str
    # Text generated by the agent.
    response_text: str
    # Optional extra data; each instance gets its own empty dict by default.
    parameters: dict[str, Any] = Field(default_factory=dict)
    # Optional confidence score; absent unless explicitly supplied.
    confidence: float | None = Field(default=None)
|
|
|
|
|
|
class ErrorResponse(BaseModel):
    """Uniform error payload; every field is required."""

    # Short reason phrase, e.g. "Bad Request".
    error: str = Field(...)
    # Human-readable explanation of what went wrong.
    message: str = Field(...)
    # HTTP status code mirrored into the body.
    status: int = Field(...)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _build_user_message(request: QueryRequest) -> str:
    """Return the text that will be sent to the agent for *request*.

    Conversation requests, and notification requests that carry no
    notification payload, pass ``request.text`` through untouched. For a
    notification request, the notification text and parameters (when
    present) are appended as labelled lines after the user's text.
    """
    # Guard clause: nothing to enrich unless this is a notification with a
    # payload attached.
    if request.type != "notification" or not request.notification:
        return request.text

    notification = request.notification
    sections = [request.text]

    if notification.text:
        # The leading newline leaves a blank line between the user's text
        # and the notification block once sections are newline-joined.
        sections.append(f"\n[Notificación recibida]: {notification.text}")

    if notification.parameters:
        pairs = [f"{key}: {value}" for key, value in notification.parameters.items()]
        sections.append(f"[Parámetros de notificación]: {', '.join(pairs)}")

    return "\n".join(sections)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Endpoints
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _error_exception(status_code: int, error: str, message: str) -> HTTPException:
    """Build an ``HTTPException`` whose ``detail`` is a dumped ``ErrorResponse``.

    Args:
        status_code: HTTP status to return; also copied into the body.
        error: Short reason phrase (e.g. "Bad Request").
        message: Human-readable explanation for the caller.

    NOTE(review): with FastAPI's default exception handler the body is sent
    as ``{"detail": {...}}``, i.e. wrapped in a ``detail`` key, which does not
    match the bare ``ErrorResponse`` model advertised in the route's
    ``responses`` mapping. Confirm which shape clients rely on before
    aligning the two.
    """
    return HTTPException(
        status_code=status_code,
        detail=ErrorResponse(
            error=error,
            message=message,
            status=status_code,
        ).model_dump(),
    )


async def _collect_response_text(events: Any) -> str:
    """Concatenate the text of every non-user part yielded by *events*.

    *events* is consumed with ``async for``. Each item must expose
    ``author`` and ``content.parts``, the same fields this module reads from
    the events produced by ``runner.run_async``. Events authored by
    ``"user"`` are skipped so only agent output is returned. Events without
    content or parts are also skipped, as are parts whose text is empty or
    ``None``.

    NOTE(review): all non-user text parts are included, not just a final
    response. If the agent emits intermediate or "thought" text, it will
    show up here too. Confirm this is intended.
    """
    chunks: list[str] = []
    async for event in events:
        if event.author == "user":
            continue
        if not (event.content and event.content.parts):
            continue
        chunks.extend(part.text for part in event.content.parts if part.text)
    # A single join avoids repeated string reallocation in the loop.
    return "".join(chunks)


@app.post(
    "/api/v1/query",
    response_model=QueryResponse,
    responses={
        400: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
        # NOTE(review): 503 is documented but this handler never returns it.
        503: {"model": ErrorResponse},
    },
)
async def query(request: QueryRequest) -> QueryResponse:
    """Run the agent on the caller's message and return its reply.

    The request's phone number serves as both the agent ``user_id`` and
    ``session_id``. The message text is built by ``_build_user_message``.
    The reply is the concatenation of all non-user text the agent emits.

    Raises:
        HTTPException: 400 if the agent run raises ``ValueError``, in which
            case the exception text is echoed to the caller. 500 for any
            other exception, with a generic message. The original exception
            is chained in both cases.
    """
    user_message = _build_user_message(request)
    # One identifier per phone number for both the user and the session.
    user_id = session_id = request.phone_number
    # NOTE(review): request.language_code is accepted but not passed to the
    # agent. Confirm whether it should be.
    new_message = Content(role="user", parts=[Part(text=user_message)])

    try:
        response_text = await _collect_response_text(
            runner.run_async(
                user_id=user_id,
                session_id=session_id,
                new_message=new_message,
            )
        )
    except ValueError as exc:
        logger.exception("Bad request while running agent")
        raise _error_exception(400, "Bad Request", str(exc)) from exc
    except Exception as exc:
        logger.exception("Internal error while running agent")
        raise _error_exception(
            500, "Internal Server Error", "Failed to generate response"
        ) from exc

    if not response_text:
        # Still returned to the caller, but make an empty reply visible in logs.
        logger.warning("Agent returned no text for this request")

    return QueryResponse(
        response_id=f"rag-resp-{uuid.uuid4()}",
        response_text=response_text,
    )
|