Merge branch 'main' into prompt
This commit is contained in:
@@ -4,7 +4,7 @@ from google import genai
|
|||||||
from google.adk.agents.llm_agent import Agent
|
from google.adk.agents.llm_agent import Agent
|
||||||
from google.adk.runners import Runner
|
from google.adk.runners import Runner
|
||||||
from google.adk.tools.mcp_tool import McpToolset
|
from google.adk.tools.mcp_tool import McpToolset
|
||||||
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
|
from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
|
||||||
from google.cloud.firestore_v1.async_client import AsyncClient
|
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||||
|
|
||||||
from va_agent.auth import auth_headers_provider
|
from va_agent.auth import auth_headers_provider
|
||||||
@@ -12,7 +12,7 @@ from va_agent.config import settings
|
|||||||
from va_agent.session import FirestoreSessionService
|
from va_agent.session import FirestoreSessionService
|
||||||
|
|
||||||
toolset = McpToolset(
|
toolset = McpToolset(
|
||||||
connection_params=SseConnectionParams(url=settings.mcp_remote_url),
|
connection_params=StreamableHTTPConnectionParams(url=settings.mcp_remote_url),
|
||||||
header_provider=auth_headers_provider,
|
header_provider=auth_headers_provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""ID-token auth for Cloud Run → Cloud Run calls."""
|
"""ID-token auth for Cloud Run → Cloud Run calls."""
|
||||||
|
|
||||||
import threading
|
import logging
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from google.adk.agents.readonly_context import ReadonlyContext
|
from google.adk.agents.readonly_context import ReadonlyContext
|
||||||
@@ -10,18 +10,33 @@ from google.oauth2 import id_token
|
|||||||
|
|
||||||
from va_agent.config import settings
|
from va_agent.config import settings
|
||||||
|
|
||||||
_REFRESH_MARGIN = 300 # refresh 5 min before expiry
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_REFRESH_MARGIN = 900 # refresh 15 min before expiry
|
||||||
|
|
||||||
_lock = threading.Lock()
|
|
||||||
_token: str | None = None
|
_token: str | None = None
|
||||||
_token_exp: float = 0.0
|
_token_exp: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_token() -> tuple[str, float]:
|
||||||
|
"""Fetch a fresh ID token (blocking I/O)."""
|
||||||
|
tok = id_token.fetch_id_token(GAuthRequest(), settings.mcp_audience)
|
||||||
|
exp = jwt.decode(tok, verify=False)["exp"]
|
||||||
|
return tok, exp
|
||||||
|
|
||||||
|
|
||||||
def auth_headers_provider(_ctx: ReadonlyContext | None = None) -> dict[str, str]:
|
def auth_headers_provider(_ctx: ReadonlyContext | None = None) -> dict[str, str]:
|
||||||
"""Return Authorization headers with a cached ID token."""
|
"""Return Authorization headers, refreshing the cached token when needed.
|
||||||
global _token, _token_exp # noqa: PLW0603
|
|
||||||
with _lock:
|
With Streamable HTTP transport every tool call is a fresh HTTP
|
||||||
if _token is None or time.time() >= _token_exp - _REFRESH_MARGIN:
|
request, so returning a valid token here is sufficient — no
|
||||||
_token = id_token.fetch_id_token(GAuthRequest(), settings.mcp_audience)
|
background refresh loop required.
|
||||||
_token_exp = jwt.decode(_token, verify=False)["exp"]
|
"""
|
||||||
return {"Authorization": f"Bearer {_token}"}
|
global _token, _token_exp
|
||||||
|
|
||||||
|
if _token is not None and time.time() < _token_exp - _REFRESH_MARGIN:
|
||||||
|
return {"Authorization": f"Bearer {_token}"}
|
||||||
|
|
||||||
|
tok, exp = _fetch_token()
|
||||||
|
_token, _token_exp = tok, exp
|
||||||
|
return {"Authorization": f"Bearer {tok}"}
|
||||||
|
|||||||
@@ -75,8 +75,7 @@ class TestAuthHeadersProvider:
|
|||||||
) -> None:
|
) -> None:
|
||||||
mock_settings.mcp_audience = "https://my-service"
|
mock_settings.mcp_audience = "https://my-service"
|
||||||
|
|
||||||
# First token expires within the refresh margin (5 min).
|
first_exp = time.time() + 100 # < 900s margin
|
||||||
first_exp = time.time() + 100 # < 300s margin
|
|
||||||
second_exp = time.time() + 3600
|
second_exp = time.time() + 3600
|
||||||
mock_fetch.side_effect = [
|
mock_fetch.side_effect = [
|
||||||
_make_fake_token(first_exp),
|
_make_fake_token(first_exp),
|
||||||
|
|||||||
Reference in New Issue
Block a user