"""GovernancePlugin: Guardrails for VAia, the virtual assistant for VA."""
import logging
import re

from google.adk.agents.callback_context import CallbackContext
from google.adk.models import LlmResponse

# Module-level logger, standard one-logger-per-module pattern.
logger = logging.getLogger(__name__)
# Emojis VAia must never emit. Order is preserved from the original
# definition; the regex builder re-sorts by length itself, so only
# membership and ordering stability matter here.
FORBIDDEN_EMOJIS = [
    "🥵", "🔪", "🎰", "🎲", "🃏", "😤", "🤬", "😡", "😠", "🩸",
    "🧨", "🪓", "☠️", "💀", "💣", "🔫", "👗", "💦", "🍑", "🍆",
    "👄", "👅", "🫦", "💩", "⚖️", "⚔️", "✝️", "🕍", "🕌", "⛪",
    "🍻", "🍸", "🥃", "🍷", "🍺", "🚬", "👹", "👺", "👿", "😈",
    "🤡", "🧙", "🧙♀️", "🧙♂️", "🧛", "🧛♀️", "🧛♂️", "🔞", "🧿", "💊",
    "💏",
]
class GovernancePlugin:
    """Guardrail executor for VAia requests, run as Agent Engine callbacks."""

    def __init__(self) -> None:
        """Compile the combined forbidden-emoji pattern once at startup."""
        self._combined_pattern = self._get_combined_pattern()

    def _get_combined_pattern(self) -> re.Pattern[str]:
        """Build a single regex matching every forbidden emoji sequence.

        Returns:
            Compiled pattern covering ZWJ kiss/couple sequences (with
            optional skin tones), the middle finger with optional tone,
            every entry of ``FORBIDDEN_EMOJIS``, and stray ZWJ / variation
            selector characters.
        """
        person = r"(?:🧑|👩|👨)"
        # Optional Fitzpatrick skin-tone modifier (U+1F3FB..U+1F3FF).
        tone = r"[\U0001F3FB-\U0001F3FF]?"
        # Longest alternatives first so multi-codepoint emojis win over
        # their single-codepoint prefixes.
        simple = "|".join(
            map(re.escape, sorted(FORBIDDEN_EMOJIS, key=len, reverse=True))
        )

        # Alternation order matters: complex ZWJ sequences first, then the
        # simple list; bare U+200D (ZWJ) and U+FE0F (variation selector)
        # last, so leftovers of a partial match are also removed.
        return re.compile(
            rf"{person}{tone}\u200d❤️?\u200d💋\u200d{person}{tone}"
            rf"|{person}{tone}\u200d❤️?\u200d{person}{tone}"
            rf"|🖕{tone}"
            rf"|{simple}"
            rf"|\u200d|\uFE0F"
        )

    def _remove_emojis(self, text: str) -> tuple[str, list[str]]:
        """Strip forbidden emojis from ``text``.

        Returns:
            Tuple of (cleaned and stripped text, list of removed matches).
        """
        removed = self._combined_pattern.findall(text)
        cleaned = self._combined_pattern.sub("", text)
        return cleaned.strip(), removed

    def after_model_callback(
        self,
        callback_context: CallbackContext | None = None,
        llm_response: LlmResponse | None = None,
    ) -> None:
        """Guardrail post-processing.

        Remove forbidden emojis from the model response, rewriting the
        cleaned text back into each text part in place. The previous
        implementation inspected only ``parts[0]``, so emojis in later
        parts leaked through; every text part is now scanned. Removed
        emojis are recorded in ``callback_context.state`` under
        ``"removed_emojis"`` for observability.
        """
        try:
            if not (llm_response and llm_response.content):
                return
            parts = getattr(llm_response.content, "parts", None)
            if not parts:
                return

            all_removed: list[str] = []
            for part in parts:
                text_value = getattr(part, "text", "")
                # Skip non-text parts (e.g. function calls) and empty text.
                if not (isinstance(text_value, str) and text_value):
                    continue
                new_text, deleted = self._remove_emojis(text_value)
                part.text = new_text
                all_removed.extend(deleted)

            if all_removed:
                if callback_context:
                    callback_context.state["removed_emojis"] = all_removed
                logger.warning(
                    "Removed forbidden emojis from response: %s",
                    all_removed,
                )
        except Exception:
            # A guardrail must never crash the agent pipeline: log the
            # failure and let the (possibly unmodified) response through.
            logger.exception("Error in after_model_callback")