diff --git a/src/va_agent/agent.py b/src/va_agent/agent.py
index 48deae1..3ffcc13 100644
--- a/src/va_agent/agent.py
+++ b/src/va_agent/agent.py
@@ -4,7 +4,7 @@ from google import genai
 from google.adk.agents.llm_agent import Agent
 from google.adk.runners import Runner
 from google.adk.tools.mcp_tool import McpToolset
-from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
+from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
 from google.cloud.firestore_v1.async_client import AsyncClient
 
 from va_agent.auth import auth_headers_provider
@@ -12,7 +12,7 @@ from va_agent.config import settings
 from va_agent.session import FirestoreSessionService
 
 toolset = McpToolset(
-    connection_params=SseConnectionParams(url=settings.mcp_remote_url),
+    connection_params=StreamableHTTPConnectionParams(url=settings.mcp_remote_url),
     header_provider=auth_headers_provider,
 )
 
diff --git a/src/va_agent/auth.py b/src/va_agent/auth.py
index e51fcef..7e7435e 100644
--- a/src/va_agent/auth.py
+++ b/src/va_agent/auth.py
@@ -1,6 +1,6 @@
 """ID-token auth for Cloud Run → Cloud Run calls."""
 
-import threading
+import logging
 import time
 
 from google.adk.agents.readonly_context import ReadonlyContext
@@ -10,18 +10,33 @@ from google.oauth2 import id_token
 
 from va_agent.config import settings
 
-_REFRESH_MARGIN = 300  # refresh 5 min before expiry
+logger = logging.getLogger(__name__)
+
+_REFRESH_MARGIN = 900  # refresh 15 min before expiry
 
-_lock = threading.Lock()
 _token: str | None = None
 _token_exp: float = 0.0
 
 
+def _fetch_token() -> tuple[str, float]:
+    """Fetch a fresh ID token (blocking I/O)."""
+    tok = id_token.fetch_id_token(GAuthRequest(), settings.mcp_audience)
+    exp = jwt.decode(tok, verify=False)["exp"]
+    return tok, exp
+
+
 def auth_headers_provider(_ctx: ReadonlyContext | None = None) -> dict[str, str]:
-    """Return Authorization headers with a cached ID token."""
-    global _token, _token_exp  # noqa: PLW0603
-    with _lock:
-        if _token is None or time.time() >= _token_exp - _REFRESH_MARGIN:
-            _token = id_token.fetch_id_token(GAuthRequest(), settings.mcp_audience)
-            _token_exp = jwt.decode(_token, verify=False)["exp"]
-        return {"Authorization": f"Bearer {_token}"}
+    """Return Authorization headers, refreshing the cached token when needed.
+
+    With Streamable HTTP transport every tool call is a fresh HTTP
+    request, so returning a valid token here is sufficient — no
+    background refresh loop required.
+    """
+    global _token, _token_exp
+
+    if _token is not None and time.time() < _token_exp - _REFRESH_MARGIN:
+        return {"Authorization": f"Bearer {_token}"}
+
+    tok, exp = _fetch_token()
+    _token, _token_exp = tok, exp
+    return {"Authorization": f"Bearer {tok}"}
diff --git a/tests/test_auth.py b/tests/test_auth.py
index 9a467db..1233e0b 100644
--- a/tests/test_auth.py
+++ b/tests/test_auth.py
@@ -75,8 +75,7 @@ class TestAuthHeadersProvider:
     ) -> None:
         mock_settings.mcp_audience = "https://my-service"
 
-        # First token expires within the refresh margin (5 min).
-        first_exp = time.time() + 100  # < 300s margin
+        first_exp = time.time() + 100  # < 900s margin
         second_exp = time.time() + 3600
         mock_fetch.side_effect = [
             _make_fake_token(first_exp),