|
|
|
|
@@ -1,44 +1,73 @@
|
|
|
|
|
"""ADK agent with vector search RAG tool."""
|
|
|
|
|
|
|
|
|
|
import base64
|
|
|
|
|
import concurrent.futures
|
|
|
|
|
import json
|
|
|
|
|
import time
|
|
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
|
|
from google import genai
|
|
|
|
|
from google.adk.agents.llm_agent import Agent
|
|
|
|
|
from google.adk.runners import Runner
|
|
|
|
|
from google.adk.tools.mcp_tool import McpToolset
|
|
|
|
|
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
|
|
|
|
|
from google.auth.transport.requests import Request as GAuthRequest
|
|
|
|
|
from google.cloud.firestore_v1.async_client import AsyncClient
|
|
|
|
|
|
|
|
|
|
# --- Cloud Run → Cloud Run authentication (ID token) ---
|
|
|
|
|
from google.oauth2 import id_token
|
|
|
|
|
|
|
|
|
|
from va_agent.config import settings
|
|
|
|
|
from va_agent.session import FirestoreSessionService
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Cloud Run → Cloud Run authentication (ID token) ---
|
|
|
|
|
from google.oauth2 import id_token
|
|
|
|
|
from google.auth.transport.requests import Request as GAuthRequest
|
|
|
|
|
|
|
|
|
|
def _fetch_id_token(audience: str) -> str:
    """Return an ID token for a protected Cloud Run service.

    Thin wrapper around ``google.oauth2.id_token.fetch_id_token`` using a
    fresh ``google.auth.transport.requests.Request`` as the transport.

    Args:
        audience: Audience the token is issued for; forwarded unchanged to
            ``id_token.fetch_id_token``.

    Returns:
        Whatever ``id_token.fetch_id_token`` returns for ``audience``.
    """
    return id_token.fetch_id_token(GAuthRequest(), audience)
|
|
|
|
|
|
|
|
|
|
# Audience = URL of the remote MCP service
|
|
|
|
|
|
|
|
|
|
def _jwt_exp(token: str) -> float:
    """Return the ``exp`` claim (epoch seconds) from a JWT without verification.

    Only the payload segment is decoded; the signature is not checked, so the
    result must not be used for any trust decision.

    Args:
        token: Dot-separated JWT string; the second segment is the payload.

    Returns:
        The payload's ``exp`` claim converted to ``float``.

    Raises:
        ValueError: If the token has no second segment, the payload is not
            base64url-encoded JSON, or it has no numeric ``exp`` claim.
    """
    try:
        payload = token.split(".")[1]
        # JWT segments are unpadded base64url; restore "=" up to a multiple of 4.
        padded = payload + "=" * (-len(payload) % 4)
        claims = json.loads(base64.urlsafe_b64decode(padded))
        return float(claims["exp"])
    except (IndexError, KeyError, TypeError, ValueError) as err:
        # binascii.Error, JSONDecodeError and UnicodeDecodeError are all
        # ValueError subclasses, so every malformed-token path lands here.
        raise ValueError("token is not a JWT with a numeric 'exp' claim") from err
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Remote MCP endpoint, and the audience that ID tokens for it are requested with.
_MCP_URL = settings.mcp_remote_url
# A previous ``getattr(settings, "mcp_audience", None) or _MCP_URL`` fallback was
# overwritten by this assignment on the very next statement, so it never took
# effect; only the effective assignment is kept.
_MCP_AUDIENCE = settings.mcp_audience
|
|
|
|
|
|
|
|
|
|
# Reusable pool for the blocking metadata-server call inside the sync callback.
|
|
|
|
|
_token_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
|
|
|
|
|
|
|
|
|
# Cached token and its expiry (epoch seconds).
|
|
|
|
|
_cached_token: str | None = None
|
|
|
|
|
_cached_token_exp: float = 0.0
|
|
|
|
|
|
|
|
|
|
_TOKEN_REFRESH_MARGIN = 300 # refresh 5 min before expiry
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _auth_headers_provider(_ctx: Any = None) -> dict[str, str]:
    """Return the ``Authorization`` header carrying a cached ID token.

    The token is fetched with ``_fetch_id_token(_MCP_AUDIENCE)`` on first use
    and again once the current time is within ``_TOKEN_REFRESH_MARGIN``
    seconds of the cached token's ``exp`` claim; otherwise the cached token
    is reused.

    Args:
        _ctx: Unused; accepted so the function can be called with or without
            a positional argument.

    Returns:
        ``{"Authorization": "Bearer <token>"}``.
    """
    global _cached_token, _cached_token_exp  # noqa: PLW0603

    refresh_at = _cached_token_exp - _TOKEN_REFRESH_MARGIN
    if _cached_token is None or time.time() >= refresh_at:
        # The fetch runs on the pool's worker thread, but ``.result()`` still
        # blocks the calling thread until it completes, so a cache miss does
        # pause the caller. Caching keeps such misses infrequent.
        token = _token_pool.submit(_fetch_id_token, _MCP_AUDIENCE).result()
        # Decode the expiry before touching the cache so a token that cannot
        # be parsed never replaces a previously cached one.
        token_exp = _jwt_exp(token)
        _cached_token, _cached_token_exp = token, token_exp

    return {"Authorization": f"Bearer {_cached_token}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# No static ``headers`` here: calling ``_auth_headers_provider()`` at import
# time would bake a single token into the connection parameters. The toolset
# is given the provider itself instead.
connection_params = SseConnectionParams(url=_MCP_URL)

# Bind ``toolset`` once, with the header provider attached. A later rebinding
# without ``header_provider`` would drop the Authorization header.
toolset = McpToolset(
    connection_params=connection_params,
    header_provider=_auth_headers_provider,
)
|
|
|
|
|
|
|
|
|
|
agent = Agent(
|
|
|
|
|
model=settings.agent_model,
|
|
|
|
|
name=settings.agent_name,
|
|
|
|
|
|