WIP: feature: Add before Guardrail #26

Draft
A8080816 wants to merge 16 commits from feature/before-guardrail into main
Showing only changes of commit db9400fcf3 - Show all commits

View File

@@ -1,5 +1,6 @@
# ruff: noqa: E501 # ruff: noqa: E501
"""GovernancePlugin: Guardrails for VAia, the virtual assistant for VA.""" """GovernancePlugin: Guardrails for VAia, the virtual assistant for VA."""
import json import json
import logging import logging
import re import re
@@ -22,10 +23,56 @@ logger = logging.getLogger(__name__)
FORBIDDEN_EMOJIS = [ FORBIDDEN_EMOJIS = [
"🥵","🔪","🎰","🎲","🃏","😤","🤬","😡","😠","🩸","🧨","🪓","☠️","💀", "🥵",
"💣","🔫","👗","💦","🍑","🍆","👄","👅","🫦","💩","⚖️","⚔️","✝️","🕍", "🔪",
"🕌","","🍻","🍸","🥃","🍷","🍺","🚬","👹","👺","👿","😈","🤡","🧙", "🎰",
"🧙‍♀️", "🧙‍♂️", "🧛", "🧛‍♀️", "🧛‍♂️", "🔞","🧿","💊" "🎲",
"🃏",
"😤",
"🤬",
"😡",
"😠",
"🩸",
"🧨",
"🪓",
"☠️",
"💀",
"💣",
"🔫",
"👗",
"💦",
"🍑",
"🍆",
"👄",
"👅",
"🫦",
"💩",
"⚖️",
"⚔️",
"✝️",
"🕍",
"🕌",
"",
"🍻",
"🍸",
"🥃",
"🍷",
"🍺",
"🚬",
"👹",
"👺",
"👿",
"😈",
"🤡",
"🧙",
"🧙‍♀️",
"🧙‍♂️",
"🧛",
"🧛‍♀️",
"🧛‍♂️",
"🔞",
"🧿",
"💊",
] ]
@@ -37,12 +84,11 @@ class GuardrailOutput(BaseModel):
description="Decision for the user prompt", description="Decision for the user prompt",
) )
reasoning: str | None = Field( reasoning: str | None = Field(
default=None, default=None, description="Optional reasoning for the decision"
description="Optional reasoning for the decision"
) )
blocking_response: str | None = Field( blocking_response: str | None = Field(
default=None, default=None,
description="Optional custom blocking response to return to the user if unsafe" description="Optional custom blocking response to return to the user if unsafe",
) )
@@ -54,7 +100,7 @@ class GovernancePlugin:
self.guardrail_llm = Client( self.guardrail_llm = Client(
vertexai=True, vertexai=True,
project=settings.google_cloud_project, project=settings.google_cloud_project,
location=settings.google_cloud_location location=settings.google_cloud_location,
) )
_guardrail_instruction = """ _guardrail_instruction = """
Eres una capa de seguridad y protección de marca para VAia, el asistente virtual de VA en WhatsApp. Eres una capa de seguridad y protección de marca para VAia, el asistente virtual de VA en WhatsApp.
@@ -85,9 +131,9 @@ Devuelve un JSON con la siguiente estructura:
_schema = GuardrailOutput.model_json_schema() _schema = GuardrailOutput.model_json_schema()
# Force strict JSON output from the guardrail LLM # Force strict JSON output from the guardrail LLM
self._guardrail_gen_config = GenerateContentConfig( self._guardrail_gen_config = GenerateContentConfig(
system_instruction = _guardrail_instruction, system_instruction=_guardrail_instruction,
response_mime_type = "application/json", response_mime_type="application/json",
response_schema = _schema, response_schema=_schema,
max_output_tokens=1000, max_output_tokens=1000,
temperature=0.1, temperature=0.1,
) )
@@ -100,13 +146,12 @@ Devuelve un JSON con la siguiente estructura:
# Single pattern that combines all forbidden emojis, including skin-tone variants and compound emojis # Single pattern that combines all forbidden emojis, including skin-tone variants and compound emojis
return re.compile( return re.compile(
rf"{person_pattern}{tone_pattern}\u200d❤?\u200d💋\u200d{person_pattern}{tone_pattern}" # kissers rf"{person_pattern}{tone_pattern}\u200d❤?\u200d💋\u200d{person_pattern}{tone_pattern}" # kissers
rf"|{person_pattern}{tone_pattern}\u200d❤?\u200d{person_pattern}{tone_pattern}" # lovers rf"|{person_pattern}{tone_pattern}\u200d❤?\u200d{person_pattern}{tone_pattern}" # lovers
rf"|{'|'.join(map(re.escape, sorted(FORBIDDEN_EMOJIS, key=len, reverse=True)))}" # simple emojis rf"|{'|'.join(map(re.escape, sorted(FORBIDDEN_EMOJIS, key=len, reverse=True)))}" # simple emojis
rf"|🖕{tone_pattern}" # middle finger with all skin tone variations rf"|🖕{tone_pattern}" # middle finger with all skin tone variations
) )
def _remove_emojis(self, text: str) -> tuple[str, list[str]]: def _remove_emojis(self, text: str) -> tuple[str, list[str]]:
removed = self._combined_pattern.findall(text) removed = self._combined_pattern.findall(text)
text = self._combined_pattern.sub("", text) text = self._combined_pattern.sub("", text)
@@ -139,8 +184,7 @@ Devuelve un JSON con la siguiente estructura:
decision = data.get("decision", "safe").lower() decision = data.get("decision", "safe").lower()
reasoning = data.get("reasoning", "") reasoning = data.get("reasoning", "")
blocking_response = data.get( blocking_response = data.get(
"blocking_response", "blocking_response", "Lo siento, no puedo ayudarte con esa solicitud 😅"
"Lo siento, no puedo ayudarte con esa solicitud 😅"
) )
if decision == "unsafe": if decision == "unsafe":
@@ -148,13 +192,8 @@ Devuelve un JSON con la siguiente estructura:
callback_context.state["guardrail_message"] = "[GUARDRAIL_BLOCKED]" callback_context.state["guardrail_message"] = "[GUARDRAIL_BLOCKED]"
callback_context.state["guardrail_reasoning"] = reasoning callback_context.state["guardrail_reasoning"] = reasoning
return LlmResponse( return LlmResponse(
content=Content( content=Content(role="model", parts=[Part(text=blocking_response)]),
role="model", usage_metadata=resp.usage_metadata or None,
parts=[
Part(text=blocking_response)
]
),
usage_metadata=resp.usage_metadata or None
) )
callback_context.state["guardrail_blocked"] = False callback_context.state["guardrail_blocked"] = False
callback_context.state["guardrail_message"] = "[GUARDRAIL_PASSED]" callback_context.state["guardrail_message"] = "[GUARDRAIL_PASSED]"
@@ -168,9 +207,7 @@ Devuelve un JSON con la siguiente estructura:
content=Content( content=Content(
role="model", role="model",
parts=[ parts=[
Part( Part(text="Lo siento, no puedo ayudarte con esa solicitud 😅")
text="Lo siento, no puedo ayudarte con esa solicitud 😅"
)
], ],
), ),
interrupted=True, interrupted=True,