Add auto-refresh, non-blocking auth

This commit is contained in:
2026-02-25 17:10:38 +00:00
parent 63eff5bde0
commit 57a215e733
3 changed files with 49 additions and 20 deletions

View File

@@ -1,44 +1,73 @@
"""ADK agent with vector search RAG tool."""
import base64
import concurrent.futures
import json
import time
from typing import Any
from google import genai
from google.adk.agents.llm_agent import Agent
from google.adk.runners import Runner
from google.adk.tools.mcp_tool import McpToolset
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
from google.auth.transport.requests import Request as GAuthRequest
from google.cloud.firestore_v1.async_client import AsyncClient
# --- Autenticación Cloud Run → Cloud Run (ID Token) ---
from google.oauth2 import id_token
from va_agent.config import settings
from va_agent.session import FirestoreSessionService
# --- Autenticación Cloud Run → Cloud Run (ID Token) ---
from google.oauth2 import id_token
from google.auth.transport.requests import Request as GAuthRequest
def _fetch_id_token(audience: str) -> str:
"""Emite un ID Token para invocar un servicio Cloud Run protegido."""
"""Return an ID token for a protected Cloud Run service."""
return id_token.fetch_id_token(GAuthRequest(), audience)
def _jwt_exp(token: str) -> float:
"""Return the ``exp`` claim (epoch seconds) from a JWT without verification."""
payload = token.split(".")[1]
# Fix base64url padding
padded = payload + "=" * (-len(payload) % 4)
return float(json.loads(base64.urlsafe_b64decode(padded))["exp"])
# Audience = URL del MCP remoto
_MCP_URL = settings.mcp_remote_url
_MCP_AUDIENCE = getattr(settings, "mcp_audience", None) or _MCP_URL
_MCP_AUDIENCE = settings.mcp_audience
# Reusable pool for the blocking metadata-server call inside the sync callback.
_token_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
# Cached token and its expiry (epoch seconds).
_cached_token: str | None = None
_cached_token_exp: float = 0.0
_TOKEN_REFRESH_MARGIN = 300 # refresh 5 min before expiry
def _auth_headers_provider() -> dict[str, str]:
token = _fetch_id_token(_MCP_AUDIENCE)
return {"Authorization": f"Bearer {token}"}
def _auth_headers_provider(_ctx: Any = None) -> dict[str, str]:
global _cached_token, _cached_token_exp # noqa: PLW0603
# Reuse a valid token; refresh only when near expiry.
expired = time.time() >= _cached_token_exp - _TOKEN_REFRESH_MARGIN
if _cached_token is None or expired:
# header_provider is called synchronously by ADK inside an async path.
# Run the blocking HTTP call in a thread so we don't stall the event loop.
token = _token_pool.submit(_fetch_id_token, _MCP_AUDIENCE).result()
_cached_token = token
_cached_token_exp = _jwt_exp(token)
return {"Authorization": f"Bearer {_cached_token}"}
connection_params = SseConnectionParams(
url=_MCP_URL,
headers=_auth_headers_provider()
connection_params = SseConnectionParams(url=_MCP_URL)
toolset = McpToolset(
connection_params=connection_params,
header_provider=_auth_headers_provider,
)
# connection_params = SseConnectionParams(url=settings.mcp_remote_url)
toolset = McpToolset(connection_params=connection_params)
agent = Agent(
model=settings.agent_model,
name=settings.agent_name,

View File

@@ -28,8 +28,6 @@ class AgentSettings(BaseSettings):
# MCP configuration
mcp_audience: str
# MCP configuration audience
mcp_remote_url: str
model_config = SettingsConfigDict(

2
uv.lock generated
View File

@@ -1922,6 +1922,7 @@ version = "0.1.0"
source = { editable = "." }
dependencies = [
{ name = "google-adk" },
{ name = "google-auth" },
{ name = "google-cloud-firestore" },
{ name = "pydantic-settings", extra = ["yaml"] },
]
@@ -1938,6 +1939,7 @@ dev = [
[package.metadata]
requires-dist = [
{ name = "google-adk", specifier = ">=1.14.1" },
{ name = "google-auth", specifier = ">=2.34.0" },
{ name = "google-cloud-firestore", specifier = ">=2.23.0" },
{ name = "pydantic-settings", extras = ["yaml"], specifier = ">=2.13.1" },
]