Add auto-refresh, non-blocking auth #13

Merged
A8065384 merged 1 commits from auth into main 2026-02-25 17:18:20 +00:00
3 changed files with 49 additions and 20 deletions

View File

@@ -1,44 +1,73 @@
"""ADK agent with vector search RAG tool.""" """ADK agent with vector search RAG tool."""
import base64
import concurrent.futures
import json
import time
from typing import Any
from google import genai from google import genai
from google.adk.agents.llm_agent import Agent from google.adk.agents.llm_agent import Agent
from google.adk.runners import Runner from google.adk.runners import Runner
from google.adk.tools.mcp_tool import McpToolset from google.adk.tools.mcp_tool import McpToolset
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
from google.auth.transport.requests import Request as GAuthRequest
from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.async_client import AsyncClient
# --- Autenticación Cloud Run → Cloud Run (ID Token) ---
from google.oauth2 import id_token
from va_agent.config import settings from va_agent.config import settings
from va_agent.session import FirestoreSessionService from va_agent.session import FirestoreSessionService
# --- Autenticación Cloud Run → Cloud Run (ID Token) ---
from google.oauth2 import id_token
from google.auth.transport.requests import Request as GAuthRequest
def _fetch_id_token(audience: str) -> str: def _fetch_id_token(audience: str) -> str:
"""Emite un ID Token para invocar un servicio Cloud Run protegido.""" """Return an ID token for a protected Cloud Run service."""
return id_token.fetch_id_token(GAuthRequest(), audience) return id_token.fetch_id_token(GAuthRequest(), audience)
# Audience = URL del MCP remoto
def _jwt_exp(token: str) -> float:
"""Return the ``exp`` claim (epoch seconds) from a JWT without verification."""
payload = token.split(".")[1]
# Fix base64url padding
padded = payload + "=" * (-len(payload) % 4)
return float(json.loads(base64.urlsafe_b64decode(padded))["exp"])
# Audience = URL del MCP remoto
_MCP_URL = settings.mcp_remote_url _MCP_URL = settings.mcp_remote_url
_MCP_AUDIENCE = getattr(settings, "mcp_audience", None) or _MCP_URL _MCP_AUDIENCE = settings.mcp_audience
# Reusable pool for the blocking metadata-server call inside the sync callback.
_token_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
# Cached token and its expiry (epoch seconds).
_cached_token: str | None = None
_cached_token_exp: float = 0.0
_TOKEN_REFRESH_MARGIN = 300 # refresh 5 min before expiry
def _auth_headers_provider(_ctx: Any = None) -> dict[str, str]:
def _auth_headers_provider() -> dict[str, str]: global _cached_token, _cached_token_exp # noqa: PLW0603
token = _fetch_id_token(_MCP_AUDIENCE) # Reuse a valid token; refresh only when near expiry.
return {"Authorization": f"Bearer {token}"} expired = time.time() >= _cached_token_exp - _TOKEN_REFRESH_MARGIN
if _cached_token is None or expired:
# header_provider is called synchronously by ADK inside an async path.
# Run the blocking HTTP call in a thread so we don't stall the event loop.
token = _token_pool.submit(_fetch_id_token, _MCP_AUDIENCE).result()
_cached_token = token
_cached_token_exp = _jwt_exp(token)
return {"Authorization": f"Bearer {_cached_token}"}
connection_params = SseConnectionParams( connection_params = SseConnectionParams(url=_MCP_URL)
url=_MCP_URL,
headers=_auth_headers_provider() toolset = McpToolset(
connection_params=connection_params,
header_provider=_auth_headers_provider,
) )
# connection_params = SseConnectionParams(url=settings.mcp_remote_url)
toolset = McpToolset(connection_params=connection_params)
agent = Agent( agent = Agent(
model=settings.agent_model, model=settings.agent_model,
name=settings.agent_name, name=settings.agent_name,

View File

@@ -28,8 +28,6 @@ class AgentSettings(BaseSettings):
# MCP configuration # MCP configuration
mcp_audience: str mcp_audience: str
# MCP configuration audience
mcp_remote_url: str mcp_remote_url: str
model_config = SettingsConfigDict( model_config = SettingsConfigDict(

2
uv.lock generated
View File

@@ -1922,6 +1922,7 @@ version = "0.1.0"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "google-adk" }, { name = "google-adk" },
{ name = "google-auth" },
{ name = "google-cloud-firestore" }, { name = "google-cloud-firestore" },
{ name = "pydantic-settings", extra = ["yaml"] }, { name = "pydantic-settings", extra = ["yaml"] },
] ]
@@ -1938,6 +1939,7 @@ dev = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "google-adk", specifier = ">=1.14.1" }, { name = "google-adk", specifier = ">=1.14.1" },
{ name = "google-auth", specifier = ">=2.34.0" },
{ name = "google-cloud-firestore", specifier = ">=2.23.0" }, { name = "google-cloud-firestore", specifier = ">=2.23.0" },
{ name = "pydantic-settings", extras = ["yaml"], specifier = ">=2.13.1" }, { name = "pydantic-settings", extras = ["yaml"], specifier = ">=2.13.1" },
] ]