From 9a2643a0299cd28e6a61bd1579d0f59ae2b6a9b6 Mon Sep 17 00:00:00 2001 From: Anibal Angulo Date: Wed, 25 Feb 2026 17:28:20 +0000 Subject: [PATCH] Improve auth implementation --- .config.yaml.swp | Bin 0 -> 12288 bytes pyproject.toml | 4 +- src/va_agent/agent.py | 58 ++------------------------ src/va_agent/auth.py | 27 +++++++++++++ tests/conftest.py | 2 +- tests/test_auth.py | 92 ++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 125 insertions(+), 58 deletions(-) create mode 100644 .config.yaml.swp create mode 100644 src/va_agent/auth.py create mode 100644 tests/test_auth.py diff --git a/.config.yaml.swp b/.config.yaml.swp new file mode 100644 index 0000000000000000000000000000000000000000..99271d2df5b0fef1b475976aecb20f60576c0fd9 GIT binary patch literal 12288 zcmeHNO^6)F6)u_l+SpDU2+8evMDS|S^zN**Sv-dj*-@~JB^&Q*oj_oy>8|OiRdrRX zs(X?ZA@M~f1RO~0gA)wP-n@qe2Va6uiRWbSDLL4_7^6dS-BZr_UUkp@*pM8eLjpbM zF*7w)uipFK_r5oaX1uz6?K(Y~3qPhj7Kefnh~UWHwUef%XMK7#!P_Gj4dVQ;`*f!&3D5B5*QdjR_p zjAQ>gtb23_bO>|^{C5Zt9(}rMs1CLC+L2eGbT(6je}%NlOOiPj^-Pk(Vo|G5VNrXp zE1Bx~J2nxGmrX@hrPx|YGHNDmB|X(v((PL9NSZibNW)IB98_NGLgR-~j%3^T?OH3Z2q#tG2(*piHGWFvgR_Vg7#JL@5vFYXJU@YF#h#En=f}a54DjM`r3a3gKzh9aVMp00*enkL_7>BcemGEi`1uZEO*T~@AeR%J3I zi=$frBQPveJw}G1S?|;G4Y!PwMyFf?t4W$drV16vWaZ#g*`NVsNN7`&Iw^YUicps}>5k>+F? z(!`i^a0;KdObIdcmJab0$7jHyKubu^QC>5`2$>6qY^66cTsY&>HHB&_!@DQi;%-zO z2iyAUJ2E?)2=f7vF92I?WFQ(A;6;!Im~)JUo3$K1n~qn~!1oC`V-ib|s}s-;VMsVC zW$q%vGsF|&xq!l)mQ2qf4(>+2s9D=cX!+^b0!^0< zmuYi;*tSzvr|3Hr6S&LRMm19!{nBhkk~+6~qSIIs-r3+;AaNBGb1(WxCRSp1} zLa*b}7qC#H6$8z{$-qf!ZrQS3N3bM|vCeB&ucF>1{1y`wFtnt066#&qtuxuXNWCp( zAsnZ~rgs4a(#(aEDsZqtNh;?13`42X`5`cBCFU~X`5{Y%qWQaEr9?D0gQO>JS@%<; zvRsXD*r(Gc8s?@T-*KP`JteWXszX#@7FjNI^ixEi5WCQP$`S?!f-9OUGlfb*pxafv zY_zv&-l%Gy%>jp6HgnaH$yxnt;Jt)&u4UR4oX6Er%K0Br0wXJJoh+0`g0{_kjWcI; z2dQ3~$|_tbOB`CZD(4{)mYam<2X-QZ2IWU-4aU!87&V_j*nI)N_-I@kqbjHyH@DFY zLle&cmu!KflxavXZrY`EzHgv^xN1ArDaSBypspp*Q?WNRFfXVAgowF9$xzfgNvw=2 z+>Q>c@(oN@R(0gY(cbU&@4b2VXz#UO;pN*$dvE;uXz#bL-h1<{qrE@8j4r*7;f8ch z*D7VULZCp7pXtKUrT%3q=lkBviGw!LjT*|=$Q4M*xKALDc!}##Go!dFkb3I9xBoOo zp^JhW=7*0nKi}Jb0Rtw~B{RddOXzP%VSG>bUVfdALx>?f4MeCClaa4NOPro?$GY)i z?59RCu7X6geML)*SA&rT?Bn^F_7xhLc}4`JpgNY0+OB3DG~9P+#{Z7 z5Hp@w@b&X{73!tT5c&cl>;77@-k%_xY5S=~F@_Oc{B@a@BUxuCRf@IVw5rN5Tv?Ij z>R|BX<-YT``>D(OUgmhV_)1kd-yf{5Zme%?tgc-etPYYo^fA{lP|P6lwN3D4qbJVO z!fOl~W$|@|YQ_p6tV-i6lc8%rpe=Vx*cwc{TG0O3e|CTh;ewF_Sr3%M?1jj J-`&D~{{>-18B72G literal 0 HcmV?d00001 diff --git a/pyproject.toml b/pyproject.toml index bdf1185..65f8d2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,10 +29,10 @@ dev = [ ] [tool.ruff] -exclude = ["scripts"] +exclude = ["utils", "tests"] [tool.ty.src] -exclude = ["scripts"] +exclude = ["utils", "tests"] [tool.ruff.lint] select = ['ALL'] diff --git a/src/va_agent/agent.py b/src/va_agent/agent.py index 15e09d3..48deae1 100644 --- a/src/va_agent/agent.py +++ b/src/va_agent/agent.py @@ -1,71 +1,19 @@ """ADK agent with vector search RAG tool.""" -import base64 -import concurrent.futures -import json -import time -from typing import Any - from google import genai from google.adk.agents.llm_agent import Agent from google.adk.runners import Runner from google.adk.tools.mcp_tool import McpToolset from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams -from google.auth.transport.requests import Request as GAuthRequest from google.cloud.firestore_v1.async_client import AsyncClient -# --- Autenticación Cloud Run → Cloud Run (ID Token) --- -from google.oauth2 import id_token - +from va_agent.auth import auth_headers_provider from va_agent.config import settings from va_agent.session import FirestoreSessionService - -def _fetch_id_token(audience: str) -> str: - """Return an ID token for a protected Cloud Run service.""" - return id_token.fetch_id_token(GAuthRequest(), audience) - - -def _jwt_exp(token: str) -> float: - """Return the ``exp`` claim (epoch seconds) from a JWT without verification.""" - payload = token.split(".")[1] - # Fix base64url padding - padded = payload + "=" * (-len(payload) % 4) - return float(json.loads(base64.urlsafe_b64decode(padded))["exp"]) - - -# Audience = URL del MCP remoto -_MCP_URL = settings.mcp_remote_url -_MCP_AUDIENCE = settings.mcp_audience - -# Reusable pool for the blocking metadata-server call inside the sync callback. -_token_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) - -# Cached token and its expiry (epoch seconds). -_cached_token: str | None = None -_cached_token_exp: float = 0.0 - -_TOKEN_REFRESH_MARGIN = 300 # refresh 5 min before expiry - - -def _auth_headers_provider(_ctx: Any = None) -> dict[str, str]: - global _cached_token, _cached_token_exp # noqa: PLW0603 - # Reuse a valid token; refresh only when near expiry. - expired = time.time() >= _cached_token_exp - _TOKEN_REFRESH_MARGIN - if _cached_token is None or expired: - # header_provider is called synchronously by ADK inside an async path. - # Run the blocking HTTP call in a thread so we don't stall the event loop. - token = _token_pool.submit(_fetch_id_token, _MCP_AUDIENCE).result() - _cached_token = token - _cached_token_exp = _jwt_exp(token) - return {"Authorization": f"Bearer {_cached_token}"} - - -connection_params = SseConnectionParams(url=_MCP_URL) - toolset = McpToolset( - connection_params=connection_params, - header_provider=_auth_headers_provider, + connection_params=SseConnectionParams(url=settings.mcp_remote_url), + header_provider=auth_headers_provider, ) agent = Agent( diff --git a/src/va_agent/auth.py b/src/va_agent/auth.py new file mode 100644 index 0000000..e51fcef --- /dev/null +++ b/src/va_agent/auth.py @@ -0,0 +1,27 @@ +"""ID-token auth for Cloud Run → Cloud Run calls.""" + +import threading +import time + +from google.adk.agents.readonly_context import ReadonlyContext +from google.auth import jwt +from google.auth.transport.requests import Request as GAuthRequest +from google.oauth2 import id_token + +from va_agent.config import settings + +_REFRESH_MARGIN = 300 # refresh 5 min before expiry + +_lock = threading.Lock() +_token: str | None = None +_token_exp: float = 0.0 + + +def auth_headers_provider(_ctx: ReadonlyContext | None = None) -> dict[str, str]: + """Return Authorization headers with a cached ID token.""" + global _token, _token_exp # noqa: PLW0603 + with _lock: + if _token is None or time.time() >= _token_exp - _REFRESH_MARGIN: + _token = id_token.fetch_id_token(GAuthRequest(), settings.mcp_audience) + _token_exp = jwt.decode(_token, verify=False)["exp"] + return {"Authorization": f"Bearer {_token}"} diff --git a/tests/conftest.py b/tests/conftest.py index ed72390..959b677 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,7 @@ from google.cloud.firestore_v1.async_client import AsyncClient from va_agent.session import FirestoreSessionService -os.environ.setdefault("FIRESTORE_EMULATOR_HOST", "localhost:8153") +os.environ.setdefault("FIRESTORE_EMULATOR_HOST", "localhost:8602") @pytest_asyncio.fixture diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..9a467db --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,92 @@ +"""Tests for ID-token auth caching and refresh logic.""" + +from __future__ import annotations + +import time +from unittest.mock import MagicMock, patch + +import va_agent.auth as auth_mod + + +def _reset_module_state() -> None: + """Reset the module-level token cache between tests.""" + auth_mod._token = None # noqa: SLF001 + auth_mod._token_exp = 0.0 # noqa: SLF001 + + +def _make_fake_token(exp: float) -> str: + """Return a dummy token string (content doesn't matter, jwt.decode is mocked).""" + return f"fake-token-exp-{exp}" + + +class TestAuthHeadersProvider: + """Tests for auth_headers_provider.""" + + def setup_method(self) -> None: + _reset_module_state() + + @patch("va_agent.auth.jwt.decode") + @patch("va_agent.auth.id_token.fetch_id_token") + @patch("va_agent.auth.settings", new_callable=MagicMock) + def test_fetches_token_on_first_call( + self, + mock_settings: MagicMock, + mock_fetch: MagicMock, + mock_decode: MagicMock, + ) -> None: + mock_settings.mcp_audience = "https://my-service" + exp = time.time() + 3600 + mock_fetch.return_value = _make_fake_token(exp) + mock_decode.return_value = {"exp": exp} + + headers = auth_mod.auth_headers_provider() + + assert headers == {"Authorization": f"Bearer {_make_fake_token(exp)}"} + mock_fetch.assert_called_once() + + @patch("va_agent.auth.jwt.decode") + @patch("va_agent.auth.id_token.fetch_id_token") + @patch("va_agent.auth.settings", new_callable=MagicMock) + def test_caches_token_on_subsequent_calls( + self, + mock_settings: MagicMock, + mock_fetch: MagicMock, + mock_decode: MagicMock, + ) -> None: + mock_settings.mcp_audience = "https://my-service" + exp = time.time() + 3600 + mock_fetch.return_value = _make_fake_token(exp) + mock_decode.return_value = {"exp": exp} + + auth_mod.auth_headers_provider() + auth_mod.auth_headers_provider() + auth_mod.auth_headers_provider() + + mock_fetch.assert_called_once() + + @patch("va_agent.auth.jwt.decode") + @patch("va_agent.auth.id_token.fetch_id_token") + @patch("va_agent.auth.settings", new_callable=MagicMock) + def test_refreshes_token_when_near_expiry( + self, + mock_settings: MagicMock, + mock_fetch: MagicMock, + mock_decode: MagicMock, + ) -> None: + mock_settings.mcp_audience = "https://my-service" + + # First token expires within the refresh margin (5 min). + first_exp = time.time() + 100 # < 300s margin + second_exp = time.time() + 3600 + mock_fetch.side_effect = [ + _make_fake_token(first_exp), + _make_fake_token(second_exp), + ] + mock_decode.side_effect = [{"exp": first_exp}, {"exp": second_exp}] + + first = auth_mod.auth_headers_provider() + second = auth_mod.auth_headers_provider() + + assert first == {"Authorization": f"Bearer {_make_fake_token(first_exp)}"} + assert second == {"Authorization": f"Bearer {_make_fake_token(second_exp)}"} + assert mock_fetch.call_count == 2