Improve auth implementation
This commit is contained in:
92
tests/test_auth.py
Normal file
92
tests/test_auth.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Tests for ID-token auth caching and refresh logic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import va_agent.auth as auth_mod
|
||||
|
||||
|
||||
def _reset_module_state() -> None:
|
||||
"""Reset the module-level token cache between tests."""
|
||||
auth_mod._token = None # noqa: SLF001
|
||||
auth_mod._token_exp = 0.0 # noqa: SLF001
|
||||
|
||||
|
||||
def _make_fake_token(exp: float) -> str:
|
||||
"""Return a dummy token string (content doesn't matter, jwt.decode is mocked)."""
|
||||
return f"fake-token-exp-{exp}"
|
||||
|
||||
|
||||
class TestAuthHeadersProvider:
|
||||
"""Tests for auth_headers_provider."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
_reset_module_state()
|
||||
|
||||
@patch("va_agent.auth.jwt.decode")
|
||||
@patch("va_agent.auth.id_token.fetch_id_token")
|
||||
@patch("va_agent.auth.settings", new_callable=MagicMock)
|
||||
def test_fetches_token_on_first_call(
|
||||
self,
|
||||
mock_settings: MagicMock,
|
||||
mock_fetch: MagicMock,
|
||||
mock_decode: MagicMock,
|
||||
) -> None:
|
||||
mock_settings.mcp_audience = "https://my-service"
|
||||
exp = time.time() + 3600
|
||||
mock_fetch.return_value = _make_fake_token(exp)
|
||||
mock_decode.return_value = {"exp": exp}
|
||||
|
||||
headers = auth_mod.auth_headers_provider()
|
||||
|
||||
assert headers == {"Authorization": f"Bearer {_make_fake_token(exp)}"}
|
||||
mock_fetch.assert_called_once()
|
||||
|
||||
@patch("va_agent.auth.jwt.decode")
|
||||
@patch("va_agent.auth.id_token.fetch_id_token")
|
||||
@patch("va_agent.auth.settings", new_callable=MagicMock)
|
||||
def test_caches_token_on_subsequent_calls(
|
||||
self,
|
||||
mock_settings: MagicMock,
|
||||
mock_fetch: MagicMock,
|
||||
mock_decode: MagicMock,
|
||||
) -> None:
|
||||
mock_settings.mcp_audience = "https://my-service"
|
||||
exp = time.time() + 3600
|
||||
mock_fetch.return_value = _make_fake_token(exp)
|
||||
mock_decode.return_value = {"exp": exp}
|
||||
|
||||
auth_mod.auth_headers_provider()
|
||||
auth_mod.auth_headers_provider()
|
||||
auth_mod.auth_headers_provider()
|
||||
|
||||
mock_fetch.assert_called_once()
|
||||
|
||||
@patch("va_agent.auth.jwt.decode")
|
||||
@patch("va_agent.auth.id_token.fetch_id_token")
|
||||
@patch("va_agent.auth.settings", new_callable=MagicMock)
|
||||
def test_refreshes_token_when_near_expiry(
|
||||
self,
|
||||
mock_settings: MagicMock,
|
||||
mock_fetch: MagicMock,
|
||||
mock_decode: MagicMock,
|
||||
) -> None:
|
||||
mock_settings.mcp_audience = "https://my-service"
|
||||
|
||||
# First token expires within the refresh margin (5 min).
|
||||
first_exp = time.time() + 100 # < 300s margin
|
||||
second_exp = time.time() + 3600
|
||||
mock_fetch.side_effect = [
|
||||
_make_fake_token(first_exp),
|
||||
_make_fake_token(second_exp),
|
||||
]
|
||||
mock_decode.side_effect = [{"exp": first_exp}, {"exp": second_exp}]
|
||||
|
||||
first = auth_mod.auth_headers_provider()
|
||||
second = auth_mod.auth_headers_provider()
|
||||
|
||||
assert first == {"Authorization": f"Bearer {_make_fake_token(first_exp)}"}
|
||||
assert second == {"Authorization": f"Bearer {_make_fake_token(second_exp)}"}
|
||||
assert mock_fetch.call_count == 2
|
||||
Reference in New Issue
Block a user