WIP: feature: Add before Guardrail #26

Draft
A8080816 wants to merge 16 commits from feature/before-guardrail into main
Showing only changes of commit f8638d22fe - Show all commits

View File

@@ -1,8 +1,9 @@
# ruff: noqa: E501
"""GovernancePlugin: Guardrails for VAia, the virtual assistant for VA.""" """GovernancePlugin: Guardrails for VAia, the virtual assistant for VA."""
import json import json
import logging import logging
import re import re
from typing import Literal, Optional from typing import Literal
from google.adk.agents.callback_context import CallbackContext from google.adk.agents.callback_context import CallbackContext
from google.adk.models import LlmRequest, LlmResponse from google.adk.models import LlmRequest, LlmResponse
@@ -35,11 +36,11 @@ class GuardrailOutput(BaseModel):
..., ...,
description="Decision for the user prompt", description="Decision for the user prompt",
) )
reasoning: Optional[str] = Field( reasoning: str | None = Field(
default=None, default=None,
description="Optional reasoning for the decision" description="Optional reasoning for the decision"
) )
blocking_response: Optional[str] = Field( blocking_response: str | None = Field(
default=None, default=None,
description="Optional custom blocking response to return to the user if unsafe" description="Optional custom blocking response to return to the user if unsafe"
) )
@@ -50,7 +51,6 @@ class GovernancePlugin:
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize guardrail model (structured output), prompt and emojis patterns.""" """Initialize guardrail model (structured output), prompt and emojis patterns."""
self.guardrail_llm = Client( self.guardrail_llm = Client(
vertexai=True, vertexai=True,
project=settings.google_cloud_project, project=settings.google_cloud_project,
@@ -94,24 +94,23 @@ Devuelve un JSON con la siguiente estructura:
self._combined_pattern = self._get_combined_pattern() self._combined_pattern = self._get_combined_pattern()
def _get_combined_pattern(self) -> re.Pattern:
    """Build the compiled regex matching every forbidden emoji sequence.

    Combines three families of patterns into one alternation:
    - ZWJ compound sequences (kissing / couple-with-heart emojis, with an
      optional skin-tone modifier on each person),
    - the simple emojis listed in ``FORBIDDEN_EMOJIS``, sorted longest
      first so multi-codepoint entries match before their prefixes,
    - the middle-finger emoji with any skin-tone variation.

    Returns:
        re.Pattern: compiled pattern consumed by ``_remove_emojis``.
    """
    person_pattern = r"(?:🧑|👩|👨)"
    # Optional Fitzpatrick skin-tone modifier (U+1F3FB..U+1F3FF).
    tone_pattern = r"[\U0001F3FB-\U0001F3FF]?"
    # Unique pattern that combines all forbidden emojis, including skin
    # tones and compound (ZWJ-joined) emojis.
    return re.compile(
        rf"{person_pattern}{tone_pattern}\u200d❤?\u200d💋\u200d{person_pattern}{tone_pattern}"  # kissers
        rf"|{person_pattern}{tone_pattern}\u200d❤?\u200d{person_pattern}{tone_pattern}"  # lovers
        rf"|{'|'.join(map(re.escape, sorted(FORBIDDEN_EMOJIS, key=len, reverse=True)))}"  # simple emojis
        rf"|🖕{tone_pattern}"  # middle finger with all skin tone variations
    )
def _remove_emojis(self, text: str) -> tuple[str, list[str]]: def _remove_emojis(self, text: str) -> tuple[str, list[str]]:
removed = self._combined_pattern.findall(text) removed = self._combined_pattern.findall(text)
text = self._combined_pattern.sub("", text) text = self._combined_pattern.sub("", text)
return text.strip(), removed return text.strip(), removed
def before_model_callback( def before_model_callback(
self, self,
callback_context: CallbackContext | None = None, callback_context: CallbackContext | None = None,
@@ -124,6 +123,10 @@ Devuelve un JSON con la siguiente estructura:
if callback_context is None: if callback_context is None:
error_msg = "callback_context is required" error_msg = "callback_context is required"
raise ValueError(error_msg) raise ValueError(error_msg)
if llm_request is None:
error_msg = "llm_request is required"
raise ValueError(error_msg)
try: try:
resp = self.guardrail_llm.models.generate_content( resp = self.guardrail_llm.models.generate_content(
@@ -134,7 +137,10 @@ Devuelve un JSON con la siguiente estructura:
data = json.loads(resp.text or "{}") data = json.loads(resp.text or "{}")
decision = data.get("decision", "safe").lower() decision = data.get("decision", "safe").lower()
reasoning = data.get("reasoning", "") reasoning = data.get("reasoning", "")
blocking_response = data.get("blocking_response", "Lo siento, no puedo ayudarte con esa solicitud 😅") blocking_response = data.get(
"blocking_response",
"Lo siento, no puedo ayudarte con esa solicitud 😅"
)
if decision == "unsafe": if decision == "unsafe":
callback_context.state["guardrail_blocked"] = True callback_context.state["guardrail_blocked"] = True