Merge pull request 'Add auto-refresh, non-blocking auth' (#13) from auth into main
Reviewed-on: #13
This commit was merged in pull request #13.
This commit is contained in:
@@ -1,44 +1,73 @@
|
||||
"""ADK agent with vector search RAG tool."""
|
||||
|
||||
import base64
|
||||
import concurrent.futures
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from google import genai
|
||||
from google.adk.agents.llm_agent import Agent
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.tools.mcp_tool import McpToolset
|
||||
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
|
||||
from google.auth.transport.requests import Request as GAuthRequest
|
||||
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||
|
||||
# --- Autenticación Cloud Run → Cloud Run (ID Token) ---
|
||||
from google.oauth2 import id_token
|
||||
|
||||
from va_agent.config import settings
|
||||
from va_agent.session import FirestoreSessionService
|
||||
|
||||
|
||||
|
||||
# --- Autenticación Cloud Run → Cloud Run (ID Token) ---
|
||||
from google.oauth2 import id_token
|
||||
from google.auth.transport.requests import Request as GAuthRequest
|
||||
|
||||
def _fetch_id_token(audience: str) -> str:
|
||||
"""Emite un ID Token para invocar un servicio Cloud Run protegido."""
|
||||
"""Return an ID token for a protected Cloud Run service."""
|
||||
return id_token.fetch_id_token(GAuthRequest(), audience)
|
||||
|
||||
|
||||
def _jwt_exp(token: str) -> float:
|
||||
"""Return the ``exp`` claim (epoch seconds) from a JWT without verification."""
|
||||
payload = token.split(".")[1]
|
||||
# Fix base64url padding
|
||||
padded = payload + "=" * (-len(payload) % 4)
|
||||
return float(json.loads(base64.urlsafe_b64decode(padded))["exp"])
|
||||
|
||||
|
||||
# Audience = URL del MCP remoto
|
||||
_MCP_URL = settings.mcp_remote_url
|
||||
_MCP_AUDIENCE = getattr(settings, "mcp_audience", None) or _MCP_URL
|
||||
_MCP_AUDIENCE = settings.mcp_audience
|
||||
|
||||
# Reusable pool for the blocking metadata-server call inside the sync callback.
|
||||
_token_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
# Cached token and its expiry (epoch seconds).
|
||||
_cached_token: str | None = None
|
||||
_cached_token_exp: float = 0.0
|
||||
|
||||
_TOKEN_REFRESH_MARGIN = 300 # refresh 5 min before expiry
|
||||
|
||||
|
||||
|
||||
def _auth_headers_provider() -> dict[str, str]:
|
||||
token = _fetch_id_token(_MCP_AUDIENCE)
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
def _auth_headers_provider(_ctx: Any = None) -> dict[str, str]:
|
||||
global _cached_token, _cached_token_exp # noqa: PLW0603
|
||||
# Reuse a valid token; refresh only when near expiry.
|
||||
expired = time.time() >= _cached_token_exp - _TOKEN_REFRESH_MARGIN
|
||||
if _cached_token is None or expired:
|
||||
# header_provider is called synchronously by ADK inside an async path.
|
||||
# Run the blocking HTTP call in a thread so we don't stall the event loop.
|
||||
token = _token_pool.submit(_fetch_id_token, _MCP_AUDIENCE).result()
|
||||
_cached_token = token
|
||||
_cached_token_exp = _jwt_exp(token)
|
||||
return {"Authorization": f"Bearer {_cached_token}"}
|
||||
|
||||
|
||||
connection_params = SseConnectionParams(
|
||||
url=_MCP_URL,
|
||||
headers=_auth_headers_provider()
|
||||
connection_params = SseConnectionParams(url=_MCP_URL)
|
||||
|
||||
toolset = McpToolset(
|
||||
connection_params=connection_params,
|
||||
header_provider=_auth_headers_provider,
|
||||
)
|
||||
|
||||
# connection_params = SseConnectionParams(url=settings.mcp_remote_url)
|
||||
toolset = McpToolset(connection_params=connection_params)
|
||||
|
||||
agent = Agent(
|
||||
model=settings.agent_model,
|
||||
name=settings.agent_name,
|
||||
|
||||
@@ -28,8 +28,6 @@ class AgentSettings(BaseSettings):
|
||||
|
||||
# MCP configuration
|
||||
mcp_audience: str
|
||||
|
||||
# MCP configuration audience
|
||||
mcp_remote_url: str
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
|
||||
2
uv.lock
generated
2
uv.lock
generated
@@ -1922,6 +1922,7 @@ version = "0.1.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "google-adk" },
|
||||
{ name = "google-auth" },
|
||||
{ name = "google-cloud-firestore" },
|
||||
{ name = "pydantic-settings", extra = ["yaml"] },
|
||||
]
|
||||
@@ -1938,6 +1939,7 @@ dev = [
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "google-adk", specifier = ">=1.14.1" },
|
||||
{ name = "google-auth", specifier = ">=2.34.0" },
|
||||
{ name = "google-cloud-firestore", specifier = ">=2.23.0" },
|
||||
{ name = "pydantic-settings", extras = ["yaml"], specifier = ">=2.13.1" },
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user