diff --git a/.config.yaml.swp b/.config.yaml.swp new file mode 100644 index 0000000..99271d2 Binary files /dev/null and b/.config.yaml.swp differ diff --git a/pyproject.toml b/pyproject.toml index bdf1185..65f8d2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,10 +29,10 @@ dev = [ ] [tool.ruff] -exclude = ["scripts"] +exclude = ["utils", "tests"] [tool.ty.src] -exclude = ["scripts"] +exclude = ["utils", "tests"] [tool.ruff.lint] select = ['ALL'] diff --git a/src/va_agent/agent.py b/src/va_agent/agent.py index 15e09d3..48deae1 100644 --- a/src/va_agent/agent.py +++ b/src/va_agent/agent.py @@ -1,71 +1,19 @@ """ADK agent with vector search RAG tool.""" -import base64 -import concurrent.futures -import json -import time -from typing import Any - from google import genai from google.adk.agents.llm_agent import Agent from google.adk.runners import Runner from google.adk.tools.mcp_tool import McpToolset from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams -from google.auth.transport.requests import Request as GAuthRequest from google.cloud.firestore_v1.async_client import AsyncClient -# --- Autenticación Cloud Run → Cloud Run (ID Token) --- -from google.oauth2 import id_token - +from va_agent.auth import auth_headers_provider from va_agent.config import settings from va_agent.session import FirestoreSessionService - -def _fetch_id_token(audience: str) -> str: - """Return an ID token for a protected Cloud Run service.""" - return id_token.fetch_id_token(GAuthRequest(), audience) - - -def _jwt_exp(token: str) -> float: - """Return the ``exp`` claim (epoch seconds) from a JWT without verification.""" - payload = token.split(".")[1] - # Fix base64url padding - padded = payload + "=" * (-len(payload) % 4) - return float(json.loads(base64.urlsafe_b64decode(padded))["exp"]) - - -# Audience = URL del MCP remoto -_MCP_URL = settings.mcp_remote_url -_MCP_AUDIENCE = settings.mcp_audience - -# Reusable pool for the blocking metadata-server call inside the sync callback. -_token_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) - -# Cached token and its expiry (epoch seconds). -_cached_token: str | None = None -_cached_token_exp: float = 0.0 - -_TOKEN_REFRESH_MARGIN = 300 # refresh 5 min before expiry - - -def _auth_headers_provider(_ctx: Any = None) -> dict[str, str]: - global _cached_token, _cached_token_exp # noqa: PLW0603 - # Reuse a valid token; refresh only when near expiry. - expired = time.time() >= _cached_token_exp - _TOKEN_REFRESH_MARGIN - if _cached_token is None or expired: - # header_provider is called synchronously by ADK inside an async path. - # Run the blocking HTTP call in a thread so we don't stall the event loop. - token = _token_pool.submit(_fetch_id_token, _MCP_AUDIENCE).result() - _cached_token = token - _cached_token_exp = _jwt_exp(token) - return {"Authorization": f"Bearer {_cached_token}"} - - -connection_params = SseConnectionParams(url=_MCP_URL) - toolset = McpToolset( - connection_params=connection_params, - header_provider=_auth_headers_provider, + connection_params=SseConnectionParams(url=settings.mcp_remote_url), + header_provider=auth_headers_provider, ) agent = Agent( diff --git a/src/va_agent/auth.py b/src/va_agent/auth.py new file mode 100644 index 0000000..e51fcef --- /dev/null +++ b/src/va_agent/auth.py @@ -0,0 +1,27 @@ +"""ID-token auth for Cloud Run → Cloud Run calls.""" + +import threading +import time + +from google.adk.agents.readonly_context import ReadonlyContext +from google.auth import jwt +from google.auth.transport.requests import Request as GAuthRequest +from google.oauth2 import id_token + +from va_agent.config import settings + +_REFRESH_MARGIN = 300 # refresh 5 min before expiry + +_lock = threading.Lock() +_token: str | None = None +_token_exp: float = 0.0 + + +def auth_headers_provider(_ctx: ReadonlyContext | None = None) -> dict[str, str]: + """Return Authorization headers with a cached ID token.""" + global _token, _token_exp # noqa: PLW0603 + with _lock: + if _token is None or time.time() >= _token_exp - _REFRESH_MARGIN: + _token = id_token.fetch_id_token(GAuthRequest(), settings.mcp_audience) + _token_exp = jwt.decode(_token, verify=False)["exp"] + return {"Authorization": f"Bearer {_token}"} diff --git a/tests/conftest.py b/tests/conftest.py index ed72390..959b677 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,7 @@ from google.cloud.firestore_v1.async_client import AsyncClient from va_agent.session import FirestoreSessionService -os.environ.setdefault("FIRESTORE_EMULATOR_HOST", "localhost:8153") +os.environ.setdefault("FIRESTORE_EMULATOR_HOST", "localhost:8602") @pytest_asyncio.fixture diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..9a467db --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,92 @@ +"""Tests for ID-token auth caching and refresh logic.""" + +from __future__ import annotations + +import time +from unittest.mock import MagicMock, patch + +import va_agent.auth as auth_mod + + +def _reset_module_state() -> None: + """Reset the module-level token cache between tests.""" + auth_mod._token = None # noqa: SLF001 + auth_mod._token_exp = 0.0 # noqa: SLF001 + + +def _make_fake_token(exp: float) -> str: + """Return a dummy token string (content doesn't matter, jwt.decode is mocked).""" + return f"fake-token-exp-{exp}" + + +class TestAuthHeadersProvider: + """Tests for auth_headers_provider.""" + + def setup_method(self) -> None: + _reset_module_state() + + @patch("va_agent.auth.jwt.decode") + @patch("va_agent.auth.id_token.fetch_id_token") + @patch("va_agent.auth.settings", new_callable=MagicMock) + def test_fetches_token_on_first_call( + self, + mock_settings: MagicMock, + mock_fetch: MagicMock, + mock_decode: MagicMock, + ) -> None: + mock_settings.mcp_audience = "https://my-service" + exp = time.time() + 3600 + mock_fetch.return_value = _make_fake_token(exp) + mock_decode.return_value = {"exp": exp} + + headers = auth_mod.auth_headers_provider() + + assert headers == {"Authorization": f"Bearer {_make_fake_token(exp)}"} + mock_fetch.assert_called_once() + + @patch("va_agent.auth.jwt.decode") + @patch("va_agent.auth.id_token.fetch_id_token") + @patch("va_agent.auth.settings", new_callable=MagicMock) + def test_caches_token_on_subsequent_calls( + self, + mock_settings: MagicMock, + mock_fetch: MagicMock, + mock_decode: MagicMock, + ) -> None: + mock_settings.mcp_audience = "https://my-service" + exp = time.time() + 3600 + mock_fetch.return_value = _make_fake_token(exp) + mock_decode.return_value = {"exp": exp} + + auth_mod.auth_headers_provider() + auth_mod.auth_headers_provider() + auth_mod.auth_headers_provider() + + mock_fetch.assert_called_once() + + @patch("va_agent.auth.jwt.decode") + @patch("va_agent.auth.id_token.fetch_id_token") + @patch("va_agent.auth.settings", new_callable=MagicMock) + def test_refreshes_token_when_near_expiry( + self, + mock_settings: MagicMock, + mock_fetch: MagicMock, + mock_decode: MagicMock, + ) -> None: + mock_settings.mcp_audience = "https://my-service" + + # First token expires within the refresh margin (5 min). + first_exp = time.time() + 100 # < 300s margin + second_exp = time.time() + 3600 + mock_fetch.side_effect = [ + _make_fake_token(first_exp), + _make_fake_token(second_exp), + ] + mock_decode.side_effect = [{"exp": first_exp}, {"exp": second_exp}] + + first = auth_mod.auth_headers_provider() + second = auth_mod.auth_headers_provider() + + assert first == {"Authorization": f"Bearer {_make_fake_token(first_exp)}"} + assert second == {"Authorization": f"Bearer {_make_fake_token(second_exp)}"} + assert mock_fetch.call_count == 2