Merge pull request 'Switch to shttp transport' (#16) from streamable-http into main
Reviewed-on: #16
This commit was merged in pull request #16.
This commit is contained in:
@@ -3,7 +3,7 @@ google_cloud_location: us-central1
|
||||
|
||||
firestore_db: bnt-orquestador-cognitivo-firestore-bdo-dev
|
||||
|
||||
mcp_remote_url: "https://ap01194-orq-cog-rag-connector-1007577023101.us-central1.run.app/sse"
|
||||
mcp_remote_url: "https://ap01194-orq-cog-rag-connector-1007577023101.us-central1.run.app/mcp"
|
||||
# audience without the path, used to issue the ID token:
|
||||
mcp_audience: "https://ap01194-orq-cog-rag-connector-1007577023101.us-central1.run.app"
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from google import genai
|
||||
from google.adk.agents.llm_agent import Agent
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.tools.mcp_tool import McpToolset
|
||||
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
|
||||
from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
|
||||
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||
|
||||
from va_agent.auth import auth_headers_provider
|
||||
@@ -12,7 +12,7 @@ from va_agent.config import settings
|
||||
from va_agent.session import FirestoreSessionService
|
||||
|
||||
toolset = McpToolset(
|
||||
connection_params=SseConnectionParams(url=settings.mcp_remote_url),
|
||||
connection_params=StreamableHTTPConnectionParams(url=settings.mcp_remote_url),
|
||||
header_provider=auth_headers_provider,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""ID-token auth for Cloud Run → Cloud Run calls."""
|
||||
|
||||
import threading
|
||||
import logging
|
||||
import time
|
||||
|
||||
from google.adk.agents.readonly_context import ReadonlyContext
|
||||
@@ -10,18 +10,33 @@ from google.oauth2 import id_token
|
||||
|
||||
from va_agent.config import settings
|
||||
|
||||
_REFRESH_MARGIN = 300 # refresh 5 min before expiry
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_REFRESH_MARGIN = 900 # refresh 15 min before expiry
|
||||
|
||||
_lock = threading.Lock()
|
||||
_token: str | None = None
|
||||
_token_exp: float = 0.0
|
||||
|
||||
|
||||
def _fetch_token() -> tuple[str, float]:
    """Obtain a new ID token and its expiry claim (blocking I/O).

    The token is requested for the audience configured in
    ``settings.mcp_audience``.

    Returns:
        A ``(token, exp)`` pair, where ``exp`` is the ``"exp"`` claim read
        from the token payload.
    """
    fresh = id_token.fetch_id_token(GAuthRequest(), settings.mcp_audience)
    # The token was just issued to us; it is decoded only to read its expiry,
    # so verification is skipped.
    claims = jwt.decode(fresh, verify=False)
    return fresh, claims["exp"]
|
||||
|
||||
|
||||
def auth_headers_provider(_ctx: ReadonlyContext | None = None) -> dict[str, str]:
|
||||
"""Return Authorization headers with a cached ID token."""
|
||||
global _token, _token_exp # noqa: PLW0603
|
||||
with _lock:
|
||||
if _token is None or time.time() >= _token_exp - _REFRESH_MARGIN:
|
||||
_token = id_token.fetch_id_token(GAuthRequest(), settings.mcp_audience)
|
||||
_token_exp = jwt.decode(_token, verify=False)["exp"]
|
||||
"""Return Authorization headers, refreshing the cached token when needed.
|
||||
|
||||
With Streamable HTTP transport every tool call is a fresh HTTP
|
||||
request, so returning a valid token here is sufficient — no
|
||||
background refresh loop required.
|
||||
"""
|
||||
global _token, _token_exp
|
||||
|
||||
if _token is not None and time.time() < _token_exp - _REFRESH_MARGIN:
|
||||
return {"Authorization": f"Bearer {_token}"}
|
||||
|
||||
tok, exp = _fetch_token()
|
||||
_token, _token_exp = tok, exp
|
||||
return {"Authorization": f"Bearer {tok}"}
|
||||
|
||||
@@ -75,8 +75,7 @@ class TestAuthHeadersProvider:
|
||||
) -> None:
|
||||
mock_settings.mcp_audience = "https://my-service"
|
||||
|
||||
# First token expires within the refresh margin (15 min).
|
||||
first_exp = time.time() + 100 # < 300s margin
|
||||
first_exp = time.time() + 100 # < 900s margin
|
||||
second_exp = time.time() + 3600
|
||||
mock_fetch.side_effect = [
|
||||
_make_fake_token(first_exp),
|
||||
|
||||
Reference in New Issue
Block a user