Compare commits
2 Commits
e77a2ba2ed
...
b47b84cfd1
| Author | SHA1 | Date | |
|---|---|---|---|
| b47b84cfd1 | |||
| 9a2643a029 |
BIN
.config.yaml.swp
Normal file
BIN
.config.yaml.swp
Normal file
Binary file not shown.
@@ -29,10 +29,10 @@ dev = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
exclude = ["scripts"]
|
exclude = ["utils", "tests"]
|
||||||
|
|
||||||
[tool.ty.src]
|
[tool.ty.src]
|
||||||
exclude = ["scripts"]
|
exclude = ["utils", "tests"]
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = ['ALL']
|
select = ['ALL']
|
||||||
|
|||||||
@@ -1,71 +1,19 @@
|
|||||||
"""ADK agent with vector search RAG tool."""
|
"""ADK agent with vector search RAG tool."""
|
||||||
|
|
||||||
import base64
|
|
||||||
import concurrent.futures
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from google import genai
|
from google import genai
|
||||||
from google.adk.agents.llm_agent import Agent
|
from google.adk.agents.llm_agent import Agent
|
||||||
from google.adk.runners import Runner
|
from google.adk.runners import Runner
|
||||||
from google.adk.tools.mcp_tool import McpToolset
|
from google.adk.tools.mcp_tool import McpToolset
|
||||||
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
|
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
|
||||||
from google.auth.transport.requests import Request as GAuthRequest
|
|
||||||
from google.cloud.firestore_v1.async_client import AsyncClient
|
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||||
|
|
||||||
# --- Autenticación Cloud Run → Cloud Run (ID Token) ---
|
from va_agent.auth import auth_headers_provider
|
||||||
from google.oauth2 import id_token
|
|
||||||
|
|
||||||
from va_agent.config import settings
|
from va_agent.config import settings
|
||||||
from va_agent.session import FirestoreSessionService
|
from va_agent.session import FirestoreSessionService
|
||||||
|
|
||||||
|
|
||||||
def _fetch_id_token(audience: str) -> str:
|
|
||||||
"""Return an ID token for a protected Cloud Run service."""
|
|
||||||
return id_token.fetch_id_token(GAuthRequest(), audience)
|
|
||||||
|
|
||||||
|
|
||||||
def _jwt_exp(token: str) -> float:
|
|
||||||
"""Return the ``exp`` claim (epoch seconds) from a JWT without verification."""
|
|
||||||
payload = token.split(".")[1]
|
|
||||||
# Fix base64url padding
|
|
||||||
padded = payload + "=" * (-len(payload) % 4)
|
|
||||||
return float(json.loads(base64.urlsafe_b64decode(padded))["exp"])
|
|
||||||
|
|
||||||
|
|
||||||
# Audience = URL del MCP remoto
|
|
||||||
_MCP_URL = settings.mcp_remote_url
|
|
||||||
_MCP_AUDIENCE = settings.mcp_audience
|
|
||||||
|
|
||||||
# Reusable pool for the blocking metadata-server call inside the sync callback.
|
|
||||||
_token_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
|
||||||
|
|
||||||
# Cached token and its expiry (epoch seconds).
|
|
||||||
_cached_token: str | None = None
|
|
||||||
_cached_token_exp: float = 0.0
|
|
||||||
|
|
||||||
_TOKEN_REFRESH_MARGIN = 300 # refresh 5 min before expiry
|
|
||||||
|
|
||||||
|
|
||||||
def _auth_headers_provider(_ctx: Any = None) -> dict[str, str]:
|
|
||||||
global _cached_token, _cached_token_exp # noqa: PLW0603
|
|
||||||
# Reuse a valid token; refresh only when near expiry.
|
|
||||||
expired = time.time() >= _cached_token_exp - _TOKEN_REFRESH_MARGIN
|
|
||||||
if _cached_token is None or expired:
|
|
||||||
# header_provider is called synchronously by ADK inside an async path.
|
|
||||||
# Run the blocking HTTP call in a thread so we don't stall the event loop.
|
|
||||||
token = _token_pool.submit(_fetch_id_token, _MCP_AUDIENCE).result()
|
|
||||||
_cached_token = token
|
|
||||||
_cached_token_exp = _jwt_exp(token)
|
|
||||||
return {"Authorization": f"Bearer {_cached_token}"}
|
|
||||||
|
|
||||||
|
|
||||||
connection_params = SseConnectionParams(url=_MCP_URL)
|
|
||||||
|
|
||||||
toolset = McpToolset(
|
toolset = McpToolset(
|
||||||
connection_params=connection_params,
|
connection_params=SseConnectionParams(url=settings.mcp_remote_url),
|
||||||
header_provider=_auth_headers_provider,
|
header_provider=auth_headers_provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
|
|||||||
27
src/va_agent/auth.py
Normal file
27
src/va_agent/auth.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
"""ID-token auth for Cloud Run → Cloud Run calls."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
from google.adk.agents.readonly_context import ReadonlyContext
|
||||||
|
from google.auth import jwt
|
||||||
|
from google.auth.transport.requests import Request as GAuthRequest
|
||||||
|
from google.oauth2 import id_token
|
||||||
|
|
||||||
|
from va_agent.config import settings
|
||||||
|
|
||||||
|
_REFRESH_MARGIN = 300 # refresh 5 min before expiry
|
||||||
|
|
||||||
|
_lock = threading.Lock()
|
||||||
|
_token: str | None = None
|
||||||
|
_token_exp: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def auth_headers_provider(_ctx: ReadonlyContext | None = None) -> dict[str, str]:
|
||||||
|
"""Return Authorization headers with a cached ID token."""
|
||||||
|
global _token, _token_exp # noqa: PLW0603
|
||||||
|
with _lock:
|
||||||
|
if _token is None or time.time() >= _token_exp - _REFRESH_MARGIN:
|
||||||
|
_token = id_token.fetch_id_token(GAuthRequest(), settings.mcp_audience)
|
||||||
|
_token_exp = jwt.decode(_token, verify=False)["exp"]
|
||||||
|
return {"Authorization": f"Bearer {_token}"}
|
||||||
@@ -11,7 +11,7 @@ from google.cloud.firestore_v1.async_client import AsyncClient
|
|||||||
|
|
||||||
from va_agent.session import FirestoreSessionService
|
from va_agent.session import FirestoreSessionService
|
||||||
|
|
||||||
os.environ.setdefault("FIRESTORE_EMULATOR_HOST", "localhost:8153")
|
os.environ.setdefault("FIRESTORE_EMULATOR_HOST", "localhost:8602")
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest_asyncio.fixture
|
||||||
|
|||||||
92
tests/test_auth.py
Normal file
92
tests/test_auth.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""Tests for ID-token auth caching and refresh logic."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import va_agent.auth as auth_mod
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_module_state() -> None:
|
||||||
|
"""Reset the module-level token cache between tests."""
|
||||||
|
auth_mod._token = None # noqa: SLF001
|
||||||
|
auth_mod._token_exp = 0.0 # noqa: SLF001
|
||||||
|
|
||||||
|
|
||||||
|
def _make_fake_token(exp: float) -> str:
|
||||||
|
"""Return a dummy token string (content doesn't matter, jwt.decode is mocked)."""
|
||||||
|
return f"fake-token-exp-{exp}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthHeadersProvider:
|
||||||
|
"""Tests for auth_headers_provider."""
|
||||||
|
|
||||||
|
def setup_method(self) -> None:
|
||||||
|
_reset_module_state()
|
||||||
|
|
||||||
|
@patch("va_agent.auth.jwt.decode")
|
||||||
|
@patch("va_agent.auth.id_token.fetch_id_token")
|
||||||
|
@patch("va_agent.auth.settings", new_callable=MagicMock)
|
||||||
|
def test_fetches_token_on_first_call(
|
||||||
|
self,
|
||||||
|
mock_settings: MagicMock,
|
||||||
|
mock_fetch: MagicMock,
|
||||||
|
mock_decode: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
mock_settings.mcp_audience = "https://my-service"
|
||||||
|
exp = time.time() + 3600
|
||||||
|
mock_fetch.return_value = _make_fake_token(exp)
|
||||||
|
mock_decode.return_value = {"exp": exp}
|
||||||
|
|
||||||
|
headers = auth_mod.auth_headers_provider()
|
||||||
|
|
||||||
|
assert headers == {"Authorization": f"Bearer {_make_fake_token(exp)}"}
|
||||||
|
mock_fetch.assert_called_once()
|
||||||
|
|
||||||
|
@patch("va_agent.auth.jwt.decode")
|
||||||
|
@patch("va_agent.auth.id_token.fetch_id_token")
|
||||||
|
@patch("va_agent.auth.settings", new_callable=MagicMock)
|
||||||
|
def test_caches_token_on_subsequent_calls(
|
||||||
|
self,
|
||||||
|
mock_settings: MagicMock,
|
||||||
|
mock_fetch: MagicMock,
|
||||||
|
mock_decode: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
mock_settings.mcp_audience = "https://my-service"
|
||||||
|
exp = time.time() + 3600
|
||||||
|
mock_fetch.return_value = _make_fake_token(exp)
|
||||||
|
mock_decode.return_value = {"exp": exp}
|
||||||
|
|
||||||
|
auth_mod.auth_headers_provider()
|
||||||
|
auth_mod.auth_headers_provider()
|
||||||
|
auth_mod.auth_headers_provider()
|
||||||
|
|
||||||
|
mock_fetch.assert_called_once()
|
||||||
|
|
||||||
|
@patch("va_agent.auth.jwt.decode")
|
||||||
|
@patch("va_agent.auth.id_token.fetch_id_token")
|
||||||
|
@patch("va_agent.auth.settings", new_callable=MagicMock)
|
||||||
|
def test_refreshes_token_when_near_expiry(
|
||||||
|
self,
|
||||||
|
mock_settings: MagicMock,
|
||||||
|
mock_fetch: MagicMock,
|
||||||
|
mock_decode: MagicMock,
|
||||||
|
) -> None:
|
||||||
|
mock_settings.mcp_audience = "https://my-service"
|
||||||
|
|
||||||
|
# First token expires within the refresh margin (5 min).
|
||||||
|
first_exp = time.time() + 100 # < 300s margin
|
||||||
|
second_exp = time.time() + 3600
|
||||||
|
mock_fetch.side_effect = [
|
||||||
|
_make_fake_token(first_exp),
|
||||||
|
_make_fake_token(second_exp),
|
||||||
|
]
|
||||||
|
mock_decode.side_effect = [{"exp": first_exp}, {"exp": second_exp}]
|
||||||
|
|
||||||
|
first = auth_mod.auth_headers_provider()
|
||||||
|
second = auth_mod.auth_headers_provider()
|
||||||
|
|
||||||
|
assert first == {"Authorization": f"Bearer {_make_fake_token(first_exp)}"}
|
||||||
|
assert second == {"Authorization": f"Bearer {_make_fake_token(second_exp)}"}
|
||||||
|
assert mock_fetch.call_count == 2
|
||||||
Reference in New Issue
Block a user