Merge pull request 'Improve auth implementation' (#14) from robust-auth into main

Reviewed-on: #14
This commit was merged in pull request #14.
This commit is contained in:
2026-02-25 18:28:27 +00:00
6 changed files with 125 additions and 58 deletions

BIN
.config.yaml.swp Normal file

Binary file not shown.

View File

@@ -29,10 +29,10 @@ dev = [
] ]
[tool.ruff] [tool.ruff]
exclude = ["scripts"] exclude = ["utils", "tests"]
[tool.ty.src] [tool.ty.src]
exclude = ["scripts"] exclude = ["utils", "tests"]
[tool.ruff.lint] [tool.ruff.lint]
select = ['ALL'] select = ['ALL']

View File

@@ -1,71 +1,19 @@
"""ADK agent with vector search RAG tool.""" """ADK agent with vector search RAG tool."""
import base64
import concurrent.futures
import json
import time
from typing import Any
from google import genai from google import genai
from google.adk.agents.llm_agent import Agent from google.adk.agents.llm_agent import Agent
from google.adk.runners import Runner from google.adk.runners import Runner
from google.adk.tools.mcp_tool import McpToolset from google.adk.tools.mcp_tool import McpToolset
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
from google.auth.transport.requests import Request as GAuthRequest
from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.async_client import AsyncClient
# --- Autenticación Cloud Run → Cloud Run (ID Token) --- from va_agent.auth import auth_headers_provider
from google.oauth2 import id_token
from va_agent.config import settings from va_agent.config import settings
from va_agent.session import FirestoreSessionService from va_agent.session import FirestoreSessionService
def _fetch_id_token(audience: str) -> str:
"""Return an ID token for a protected Cloud Run service."""
return id_token.fetch_id_token(GAuthRequest(), audience)
def _jwt_exp(token: str) -> float:
"""Return the ``exp`` claim (epoch seconds) from a JWT without verification."""
payload = token.split(".")[1]
# Fix base64url padding
padded = payload + "=" * (-len(payload) % 4)
return float(json.loads(base64.urlsafe_b64decode(padded))["exp"])
# Audience = URL del MCP remoto
_MCP_URL = settings.mcp_remote_url
_MCP_AUDIENCE = settings.mcp_audience
# Reusable pool for the blocking metadata-server call inside the sync callback.
_token_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
# Cached token and its expiry (epoch seconds).
_cached_token: str | None = None
_cached_token_exp: float = 0.0
_TOKEN_REFRESH_MARGIN = 300 # refresh 5 min before expiry
def _auth_headers_provider(_ctx: Any = None) -> dict[str, str]:
global _cached_token, _cached_token_exp # noqa: PLW0603
# Reuse a valid token; refresh only when near expiry.
expired = time.time() >= _cached_token_exp - _TOKEN_REFRESH_MARGIN
if _cached_token is None or expired:
# header_provider is called synchronously by ADK inside an async path.
# Run the blocking HTTP call in a thread so we don't stall the event loop.
token = _token_pool.submit(_fetch_id_token, _MCP_AUDIENCE).result()
_cached_token = token
_cached_token_exp = _jwt_exp(token)
return {"Authorization": f"Bearer {_cached_token}"}
connection_params = SseConnectionParams(url=_MCP_URL)
toolset = McpToolset( toolset = McpToolset(
connection_params=connection_params, connection_params=SseConnectionParams(url=settings.mcp_remote_url),
header_provider=_auth_headers_provider, header_provider=auth_headers_provider,
) )
agent = Agent( agent = Agent(

27
src/va_agent/auth.py Normal file
View File

@@ -0,0 +1,27 @@
"""ID-token auth for Cloud Run → Cloud Run calls."""
import threading
import time
from google.adk.agents.readonly_context import ReadonlyContext
from google.auth import jwt
from google.auth.transport.requests import Request as GAuthRequest
from google.oauth2 import id_token
from va_agent.config import settings
_REFRESH_MARGIN = 300 # refresh 5 min before expiry
_lock = threading.Lock()
_token: str | None = None
_token_exp: float = 0.0
def auth_headers_provider(_ctx: ReadonlyContext | None = None) -> dict[str, str]:
"""Return Authorization headers with a cached ID token."""
global _token, _token_exp # noqa: PLW0603
with _lock:
if _token is None or time.time() >= _token_exp - _REFRESH_MARGIN:
_token = id_token.fetch_id_token(GAuthRequest(), settings.mcp_audience)
_token_exp = jwt.decode(_token, verify=False)["exp"]
return {"Authorization": f"Bearer {_token}"}

View File

@@ -11,7 +11,7 @@ from google.cloud.firestore_v1.async_client import AsyncClient
from va_agent.session import FirestoreSessionService from va_agent.session import FirestoreSessionService
os.environ.setdefault("FIRESTORE_EMULATOR_HOST", "localhost:8153") os.environ.setdefault("FIRESTORE_EMULATOR_HOST", "localhost:8602")
@pytest_asyncio.fixture @pytest_asyncio.fixture

92
tests/test_auth.py Normal file
View File

@@ -0,0 +1,92 @@
"""Tests for ID-token auth caching and refresh logic."""
from __future__ import annotations
import time
from unittest.mock import MagicMock, patch
import va_agent.auth as auth_mod
def _reset_module_state() -> None:
"""Reset the module-level token cache between tests."""
auth_mod._token = None # noqa: SLF001
auth_mod._token_exp = 0.0 # noqa: SLF001
def _make_fake_token(exp: float) -> str:
"""Return a dummy token string (content doesn't matter, jwt.decode is mocked)."""
return f"fake-token-exp-{exp}"
class TestAuthHeadersProvider:
"""Tests for auth_headers_provider."""
def setup_method(self) -> None:
_reset_module_state()
@patch("va_agent.auth.jwt.decode")
@patch("va_agent.auth.id_token.fetch_id_token")
@patch("va_agent.auth.settings", new_callable=MagicMock)
def test_fetches_token_on_first_call(
self,
mock_settings: MagicMock,
mock_fetch: MagicMock,
mock_decode: MagicMock,
) -> None:
mock_settings.mcp_audience = "https://my-service"
exp = time.time() + 3600
mock_fetch.return_value = _make_fake_token(exp)
mock_decode.return_value = {"exp": exp}
headers = auth_mod.auth_headers_provider()
assert headers == {"Authorization": f"Bearer {_make_fake_token(exp)}"}
mock_fetch.assert_called_once()
@patch("va_agent.auth.jwt.decode")
@patch("va_agent.auth.id_token.fetch_id_token")
@patch("va_agent.auth.settings", new_callable=MagicMock)
def test_caches_token_on_subsequent_calls(
self,
mock_settings: MagicMock,
mock_fetch: MagicMock,
mock_decode: MagicMock,
) -> None:
mock_settings.mcp_audience = "https://my-service"
exp = time.time() + 3600
mock_fetch.return_value = _make_fake_token(exp)
mock_decode.return_value = {"exp": exp}
auth_mod.auth_headers_provider()
auth_mod.auth_headers_provider()
auth_mod.auth_headers_provider()
mock_fetch.assert_called_once()
@patch("va_agent.auth.jwt.decode")
@patch("va_agent.auth.id_token.fetch_id_token")
@patch("va_agent.auth.settings", new_callable=MagicMock)
def test_refreshes_token_when_near_expiry(
self,
mock_settings: MagicMock,
mock_fetch: MagicMock,
mock_decode: MagicMock,
) -> None:
mock_settings.mcp_audience = "https://my-service"
# First token expires within the refresh margin (5 min).
first_exp = time.time() + 100 # < 300s margin
second_exp = time.time() + 3600
mock_fetch.side_effect = [
_make_fake_token(first_exp),
_make_fake_token(second_exp),
]
mock_decode.side_effect = [{"exp": first_exp}, {"exp": second_exp}]
first = auth_mod.auth_headers_provider()
second = auth_mod.auth_headers_provider()
assert first == {"Authorization": f"Bearer {_make_fake_token(first_exp)}"}
assert second == {"Authorization": f"Bearer {_make_fake_token(second_exp)}"}
assert mock_fetch.call_count == 2