diff --git a/src/va_agent/governance.py b/src/va_agent/governance.py index a5211bf..fe67617 100644 --- a/src/va_agent/governance.py +++ b/src/va_agent/governance.py @@ -1,8 +1,9 @@ +# ruff: noqa: E501 """GovernancePlugin: Guardrails for VAia, the virtual assistant for VA.""" import json import logging import re -from typing import Literal, Optional +from typing import Literal from google.adk.agents.callback_context import CallbackContext from google.adk.models import LlmRequest, LlmResponse @@ -35,11 +36,11 @@ class GuardrailOutput(BaseModel): ..., description="Decision for the user prompt", ) - reasoning: Optional[str] = Field( + reasoning: str | None = Field( default=None, description="Optional reasoning for the decision" ) - blocking_response: Optional[str] = Field( + blocking_response: str | None = Field( default=None, description="Optional custom blocking response to return to the user if unsafe" ) @@ -50,7 +51,6 @@ class GovernancePlugin: def __init__(self) -> None: """Initialize guardrail model (structured output), prompt and emojis patterns.""" - self.guardrail_llm = Client( vertexai=True, project=settings.google_cloud_project, @@ -94,24 +94,23 @@ Devuelve un JSON con la siguiente estructura: self._combined_pattern = self._get_combined_pattern() - def _get_combined_pattern(self): + def _get_combined_pattern(self) -> re.Pattern: person_pattern = r"(?:šŸ§‘|šŸ‘©|šŸ‘Ø)" tone_pattern = r"[\U0001F3FB-\U0001F3FF]?" - # Unique pattern that combines all forbidden emojis, including complex ones with skin tones - combined_pattern = re.compile( - rf"{person_pattern}{tone_pattern}\u200dā¤ļø?\u200dšŸ’‹\u200d{person_pattern}{tone_pattern}" # kiss + # Unique pattern that combines all forbidden emojis, including skin tones and compound emojis + return re.compile( + rf"{person_pattern}{tone_pattern}\u200dā¤ļø?\u200dšŸ’‹\u200d{person_pattern}{tone_pattern}" # kissers rf"|{person_pattern}{tone_pattern}\u200dā¤ļø?\u200d{person_pattern}{tone_pattern}" # lovers rf"|{'|'.join(map(re.escape, sorted(FORBIDDEN_EMOJIS, key=len, reverse=True)))}" # simple emojis rf"|šŸ–•{tone_pattern}" # middle finger with all skin tone variations ) - return combined_pattern - + def _remove_emojis(self, text: str) -> tuple[str, list[str]]: removed = self._combined_pattern.findall(text) text = self._combined_pattern.sub("", text) return text.strip(), removed - + def before_model_callback( self, callback_context: CallbackContext | None = None, @@ -124,6 +123,10 @@ Devuelve un JSON con la siguiente estructura: if callback_context is None: error_msg = "callback_context is required" raise ValueError(error_msg) + + if llm_request is None: + error_msg = "llm_request is required" + raise ValueError(error_msg) try: resp = self.guardrail_llm.models.generate_content( @@ -134,7 +137,10 @@ Devuelve un JSON con la siguiente estructura: data = json.loads(resp.text or "{}") decision = data.get("decision", "safe").lower() reasoning = data.get("reasoning", "") - blocking_response = data.get("blocking_response", "Lo siento, no puedo ayudarte con esa solicitud šŸ˜…") + blocking_response = data.get( + "blocking_response", + "Lo siento, no puedo ayudarte con esa solicitud šŸ˜…" + ) if decision == "unsafe": callback_context.state["guardrail_blocked"] = True