Compare commits

..

2 Commits

4 changed files with 29 additions and 15 deletions

View File

@@ -3,7 +3,7 @@ google_cloud_location: us-central1
firestore_db: bnt-orquestador-cognitivo-firestore-bdo-dev
mcp_remote_url: "https://ap01194-orq-cog-rag-connector-1007577023101.us-central1.run.app/sse"
mcp_remote_url: "https://ap01194-orq-cog-rag-connector-1007577023101.us-central1.run.app/mcp"
# audience sin la ruta, para emitir el ID Token:
mcp_audience: "https://ap01194-orq-cog-rag-connector-1007577023101.us-central1.run.app"

View File

@@ -4,7 +4,7 @@ from google import genai
from google.adk.agents.llm_agent import Agent
from google.adk.runners import Runner
from google.adk.tools.mcp_tool import McpToolset
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
from google.cloud.firestore_v1.async_client import AsyncClient
from va_agent.auth import auth_headers_provider
@@ -12,7 +12,7 @@ from va_agent.config import settings
from va_agent.session import FirestoreSessionService
toolset = McpToolset(
connection_params=SseConnectionParams(url=settings.mcp_remote_url),
connection_params=StreamableHTTPConnectionParams(url=settings.mcp_remote_url),
header_provider=auth_headers_provider,
)

View File

@@ -1,6 +1,6 @@
"""ID-token auth for Cloud Run → Cloud Run calls."""
import threading
import logging
import time
from google.adk.agents.readonly_context import ReadonlyContext
@@ -10,18 +10,33 @@ from google.oauth2 import id_token
from va_agent.config import settings
_REFRESH_MARGIN = 300 # refresh 5 min before expiry
logger = logging.getLogger(__name__)
_REFRESH_MARGIN = 900 # refresh 15 min before expiry
_lock = threading.Lock()
_token: str | None = None
_token_exp: float = 0.0
def _fetch_token() -> tuple[str, float]:
"""Fetch a fresh ID token (blocking I/O)."""
tok = id_token.fetch_id_token(GAuthRequest(), settings.mcp_audience)
exp = jwt.decode(tok, verify=False)["exp"]
return tok, exp
def auth_headers_provider(_ctx: ReadonlyContext | None = None) -> dict[str, str]:
"""Return Authorization headers with a cached ID token."""
global _token, _token_exp # noqa: PLW0603
with _lock:
if _token is None or time.time() >= _token_exp - _REFRESH_MARGIN:
_token = id_token.fetch_id_token(GAuthRequest(), settings.mcp_audience)
_token_exp = jwt.decode(_token, verify=False)["exp"]
"""Return Authorization headers, refreshing the cached token when needed.
With Streamable HTTP transport every tool call is a fresh HTTP
request, so returning a valid token here is sufficient — no
background refresh loop required.
"""
global _token, _token_exp
if _token is not None and time.time() < _token_exp - _REFRESH_MARGIN:
return {"Authorization": f"Bearer {_token}"}
tok, exp = _fetch_token()
_token, _token_exp = tok, exp
return {"Authorization": f"Bearer {tok}"}

View File

@@ -75,8 +75,7 @@ class TestAuthHeadersProvider:
) -> None:
mock_settings.mcp_audience = "https://my-service"
# First token expires within the refresh margin (5 min).
first_exp = time.time() + 100 # < 300s margin
first_exp = time.time() + 100 # < 900s margin
second_exp = time.time() + 3600
mock_fetch.side_effect = [
_make_fake_token(first_exp),