"""Tests for ID-token auth caching and refresh logic.""" from __future__ import annotations import time from unittest.mock import MagicMock, patch import va_agent.auth as auth_mod def _reset_module_state() -> None: """Reset the module-level token cache between tests.""" auth_mod._token = None # noqa: SLF001 auth_mod._token_exp = 0.0 # noqa: SLF001 def _make_fake_token(exp: float) -> str: """Return a dummy token string (content doesn't matter, jwt.decode is mocked).""" return f"fake-token-exp-{exp}" class TestAuthHeadersProvider: """Tests for auth_headers_provider.""" def setup_method(self) -> None: _reset_module_state() @patch("va_agent.auth.jwt.decode") @patch("va_agent.auth.id_token.fetch_id_token") @patch("va_agent.auth.settings", new_callable=MagicMock) def test_fetches_token_on_first_call( self, mock_settings: MagicMock, mock_fetch: MagicMock, mock_decode: MagicMock, ) -> None: mock_settings.mcp_audience = "https://my-service" exp = time.time() + 3600 mock_fetch.return_value = _make_fake_token(exp) mock_decode.return_value = {"exp": exp} headers = auth_mod.auth_headers_provider() assert headers == {"Authorization": f"Bearer {_make_fake_token(exp)}"} mock_fetch.assert_called_once() @patch("va_agent.auth.jwt.decode") @patch("va_agent.auth.id_token.fetch_id_token") @patch("va_agent.auth.settings", new_callable=MagicMock) def test_caches_token_on_subsequent_calls( self, mock_settings: MagicMock, mock_fetch: MagicMock, mock_decode: MagicMock, ) -> None: mock_settings.mcp_audience = "https://my-service" exp = time.time() + 3600 mock_fetch.return_value = _make_fake_token(exp) mock_decode.return_value = {"exp": exp} auth_mod.auth_headers_provider() auth_mod.auth_headers_provider() auth_mod.auth_headers_provider() mock_fetch.assert_called_once() @patch("va_agent.auth.jwt.decode") @patch("va_agent.auth.id_token.fetch_id_token") @patch("va_agent.auth.settings", new_callable=MagicMock) def test_refreshes_token_when_near_expiry( self, mock_settings: MagicMock, mock_fetch: MagicMock, mock_decode: MagicMock, ) -> None: mock_settings.mcp_audience = "https://my-service" # First token expires within the refresh margin (5 min). first_exp = time.time() + 100 # < 300s margin second_exp = time.time() + 3600 mock_fetch.side_effect = [ _make_fake_token(first_exp), _make_fake_token(second_exp), ] mock_decode.side_effect = [{"exp": first_exp}, {"exp": second_exp}] first = auth_mod.auth_headers_provider() second = auth_mod.auth_headers_provider() assert first == {"Authorization": f"Bearer {_make_fake_token(first_exp)}"} assert second == {"Authorization": f"Bearer {_make_fake_token(second_exp)}"} assert mock_fetch.call_count == 2