feat: Add emojis filter for LLM response #21
80
src/va_agent/governance.py
Normal file
80
src/va_agent/governance.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
"""GovernancePlugin: Guardrails for VAia, the virtual assistant for VA."""
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
from google.adk.agents.callback_context import CallbackContext
|
||||||
|
from google.adk.models import LlmResponse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)


# Emojis that must never reach the user. Multi-codepoint sequences (kiss and
# couple ZWJ sequences, skin-tone variants of the middle finger) are not listed
# here; GovernancePlugin builds dedicated regex alternatives for those.
FORBIDDEN_EMOJIS = [
    "🥵", "🔪", "🎰", "🎲", "🃏", "😤", "🤬",
    "😡", "😠", "🩸", "🧨", "🪓", "☠️", "💀",
    "💣", "🔫", "👗", "💦", "🍑", "🍆", "👄",
    "👅", "🫦", "💩", "⚖️", "⚔️", "✝️", "🕍",
    "🕌", "⛪", "🍻", "🍸", "🥃", "🍷", "🍺",
    "🚬", "👹", "👺", "👿", "😈", "🤡", "🧙",
    "🧛", "🔞", "🧿", "💊", "💏",
]
|
||||||
|
|
||||||
|
|
||||||
|
class GovernancePlugin:
    """Guardrail executor for VAia, exposed as agent-engine callbacks.

    The only guardrail implemented here is an output filter: forbidden emojis
    are removed from the text of the model response in
    ``after_model_callback``.
    """

    def __init__(self) -> None:
        """Compile the forbidden-emoji pattern once for reuse on every call."""
        self._combined_pattern = self._get_combined_pattern()

    def _get_combined_pattern(self) -> re.Pattern[str]:
        """Build the single regex that matches every forbidden emoji.

        The pattern is an alternation of, in priority order:

        1. the "kiss" ZWJ sequence (person, heart, kiss mark, person),
        2. the "couple with heart" ZWJ sequence (person, heart, person),
        3. the middle finger with an optional skin tone,
        4. every entry of ``FORBIDDEN_EMOJIS``.

        Zero-width joiners and variation selectors are consumed only when they
        are attached to a forbidden emoji, so allowed emojis elsewhere in the
        text keep their joiners and their emoji presentation.

        Returns:
            The compiled pattern. It contains no capturing groups, so
            ``findall`` yields the full matched strings.
        """
        zwj = r"\u200d"  # zero-width joiner
        vs16 = r"\uFE0F"  # emoji presentation selector
        skin_tone = r"[\U0001F3FB-\U0001F3FF]"
        tone_opt = f"{skin_tone}?"
        person = r"[\U0001F9D1\U0001F469\U0001F468]"  # 🧑 👩 👨
        heart = rf"\u2764{vs16}?"  # ❤ with optional VS16
        kiss_mark = r"\U0001F48B"  # 💋
        middle_finger = r"\U0001F595"  # 🖕
        gender_sign = r"[\u2640\u2642]"  # ♀ ♂

        alternatives = [
            f"{person}{tone_opt}{zwj}{heart}{zwj}{kiss_mark}{zwj}{person}{tone_opt}",
            f"{person}{tone_opt}{zwj}{heart}{zwj}{person}{tone_opt}",
            f"{middle_finger}{tone_opt}",
        ]

        # Drop VS16 from the configured entries and re-add it as optional, so
        # that e.g. the skull matches with or without the presentation selector.
        bases = {emoji.replace("\ufe0f", "") for emoji in FORBIDDEN_EMOJIS}
        bases.discard("")
        if bases:
            # Longest first so a longer entry is never shadowed by its prefix;
            # the secondary key only makes the pattern text deterministic.
            ordered = sorted(bases, key=lambda emoji: (-len(emoji), emoji))
            simple = "|".join(map(re.escape, ordered))
            alternatives.append(
                # Optional leading ZWJ: the forbidden emoji may be the tail of
                # a ZWJ sequence (e.g. flag + ZWJ + skull).
                f"{zwj}?(?:{simple})"
                # Optional skin tone or VS16 belonging to the forbidden base.
                f"(?:{skin_tone}|{vs16})?"
                # Optional trailing ZWJ, plus a gender sign if one follows
                # (e.g. mage + ZWJ + male sign), so no orphan is left behind.
                f"(?:{zwj}(?:{gender_sign}{vs16}?)?)?"
            )

        return re.compile("|".join(alternatives))

    def _remove_emojis(self, text: str) -> tuple[str, list[str]]:
        """Remove every forbidden emoji from ``text``.

        Args:
            text: The text to clean.

        Returns:
            A tuple ``(cleaned, removed)`` where ``cleaned`` is the text with
            all matches deleted and surrounding whitespace stripped, and
            ``removed`` lists the removed emojis in order of appearance
            (without any joiner that was consumed along with them).
        """
        matches = self._combined_pattern.findall(text)
        # Joiners are an implementation detail of the match; do not report them.
        removed = [match.strip("\u200d") for match in matches]
        cleaned = self._combined_pattern.sub("", text)
        return cleaned.strip(), removed

    def after_model_callback(
        self,
        callback_context: CallbackContext | None = None,
        llm_response: LlmResponse | None = None,
    ) -> None:
        """Guardrail post-processing.

        Remove forbidden emojis from every text part of the model response.
        The response is modified in place, and a part is only rewritten when
        something was actually removed from it, so clean text keeps its
        original whitespace.

        Args:
            callback_context: Callback context; when provided and emojis were
                removed, they are stored under ``state["removed_emojis"]``.
            llm_response: The model response to clean. ``None`` or a response
                without content/parts is ignored.

        Any exception is logged and swallowed so that a guardrail failure
        never breaks the response flow.
        """
        try:
            content = getattr(llm_response, "content", None)
            parts = getattr(content, "parts", None) or []

            removed_total: list[str] = []
            for part in parts:
                text_value = getattr(part, "text", None)
                # Non-text parts (or empty text) have nothing to filter.
                if not isinstance(text_value, str) or not text_value:
                    continue
                cleaned, removed = self._remove_emojis(text_value)
                if removed:
                    part.text = cleaned
                    removed_total.extend(removed)

            if removed_total:
                if callback_context is not None:
                    callback_context.state["removed_emojis"] = removed_total
                logger.warning(
                    "Removed forbidden emojis from response: %s",
                    removed_total,
                )

        except Exception:
            logger.exception("Error in after_model_callback")
|
||||||
Reference in New Issue
Block a user