chore(governance): ruff and ty checks passed
This commit is contained in:
@@ -1,8 +1,9 @@
|
|||||||
|
# ruff: noqa: E501
|
||||||
"""GovernancePlugin: Guardrails for VAia, the virtual assistant for VA."""
|
"""GovernancePlugin: Guardrails for VAia, the virtual assistant for VA."""
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Literal, Optional
|
from typing import Literal
|
||||||
|
|
||||||
from google.adk.agents.callback_context import CallbackContext
|
from google.adk.agents.callback_context import CallbackContext
|
||||||
from google.adk.models import LlmRequest, LlmResponse
|
from google.adk.models import LlmRequest, LlmResponse
|
||||||
@@ -35,11 +36,11 @@ class GuardrailOutput(BaseModel):
|
|||||||
...,
|
...,
|
||||||
description="Decision for the user prompt",
|
description="Decision for the user prompt",
|
||||||
)
|
)
|
||||||
reasoning: Optional[str] = Field(
|
reasoning: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Optional reasoning for the decision"
|
description="Optional reasoning for the decision"
|
||||||
)
|
)
|
||||||
blocking_response: Optional[str] = Field(
|
blocking_response: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Optional custom blocking response to return to the user if unsafe"
|
description="Optional custom blocking response to return to the user if unsafe"
|
||||||
)
|
)
|
||||||
@@ -50,7 +51,6 @@ class GovernancePlugin:
|
|||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize guardrail model (structured output), prompt and emojis patterns."""
|
"""Initialize guardrail model (structured output), prompt and emojis patterns."""
|
||||||
|
|
||||||
self.guardrail_llm = Client(
|
self.guardrail_llm = Client(
|
||||||
vertexai=True,
|
vertexai=True,
|
||||||
project=settings.google_cloud_project,
|
project=settings.google_cloud_project,
|
||||||
@@ -94,24 +94,23 @@ Devuelve un JSON con la siguiente estructura:
|
|||||||
|
|
||||||
self._combined_pattern = self._get_combined_pattern()
|
self._combined_pattern = self._get_combined_pattern()
|
||||||
|
|
||||||
def _get_combined_pattern(self):
|
def _get_combined_pattern(self) -> re.Pattern:
|
||||||
person_pattern = r"(?:🧑|👩|👨)"
|
person_pattern = r"(?:🧑|👩|👨)"
|
||||||
tone_pattern = r"[\U0001F3FB-\U0001F3FF]?"
|
tone_pattern = r"[\U0001F3FB-\U0001F3FF]?"
|
||||||
|
|
||||||
# Unique pattern that combines all forbidden emojis, including complex ones with skin tones
|
# Unique pattern that combines all forbidden emojis, including skin tones and compound emojis
|
||||||
combined_pattern = re.compile(
|
return re.compile(
|
||||||
rf"{person_pattern}{tone_pattern}\u200d❤️?\u200d💋\u200d{person_pattern}{tone_pattern}" # kiss
|
rf"{person_pattern}{tone_pattern}\u200d❤️?\u200d💋\u200d{person_pattern}{tone_pattern}" # kissers
|
||||||
rf"|{person_pattern}{tone_pattern}\u200d❤️?\u200d{person_pattern}{tone_pattern}" # lovers
|
rf"|{person_pattern}{tone_pattern}\u200d❤️?\u200d{person_pattern}{tone_pattern}" # lovers
|
||||||
rf"|{'|'.join(map(re.escape, sorted(FORBIDDEN_EMOJIS, key=len, reverse=True)))}" # simple emojis
|
rf"|{'|'.join(map(re.escape, sorted(FORBIDDEN_EMOJIS, key=len, reverse=True)))}" # simple emojis
|
||||||
rf"|🖕{tone_pattern}" # middle finger with all skin tone variations
|
rf"|🖕{tone_pattern}" # middle finger with all skin tone variations
|
||||||
)
|
)
|
||||||
return combined_pattern
|
|
||||||
|
|
||||||
def _remove_emojis(self, text: str) -> tuple[str, list[str]]:
|
def _remove_emojis(self, text: str) -> tuple[str, list[str]]:
|
||||||
removed = self._combined_pattern.findall(text)
|
removed = self._combined_pattern.findall(text)
|
||||||
text = self._combined_pattern.sub("", text)
|
text = self._combined_pattern.sub("", text)
|
||||||
return text.strip(), removed
|
return text.strip(), removed
|
||||||
|
|
||||||
def before_model_callback(
|
def before_model_callback(
|
||||||
self,
|
self,
|
||||||
callback_context: CallbackContext | None = None,
|
callback_context: CallbackContext | None = None,
|
||||||
@@ -124,6 +123,10 @@ Devuelve un JSON con la siguiente estructura:
|
|||||||
if callback_context is None:
|
if callback_context is None:
|
||||||
error_msg = "callback_context is required"
|
error_msg = "callback_context is required"
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
if llm_request is None:
|
||||||
|
error_msg = "llm_request is required"
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = self.guardrail_llm.models.generate_content(
|
resp = self.guardrail_llm.models.generate_content(
|
||||||
@@ -134,7 +137,10 @@ Devuelve un JSON con la siguiente estructura:
|
|||||||
data = json.loads(resp.text or "{}")
|
data = json.loads(resp.text or "{}")
|
||||||
decision = data.get("decision", "safe").lower()
|
decision = data.get("decision", "safe").lower()
|
||||||
reasoning = data.get("reasoning", "")
|
reasoning = data.get("reasoning", "")
|
||||||
blocking_response = data.get("blocking_response", "Lo siento, no puedo ayudarte con esa solicitud 😅")
|
blocking_response = data.get(
|
||||||
|
"blocking_response",
|
||||||
|
"Lo siento, no puedo ayudarte con esa solicitud 😅"
|
||||||
|
)
|
||||||
|
|
||||||
if decision == "unsafe":
|
if decision == "unsafe":
|
||||||
callback_context.state["guardrail_blocked"] = True
|
callback_context.state["guardrail_blocked"] = True
|
||||||
|
|||||||
Reference in New Issue
Block a user