diff --git a/pyproject.toml b/pyproject.toml index 65f8d2d..f66b064 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "google-cloud-firestore>=2.23.0", "pydantic-settings[yaml]>=2.13.1", "google-auth>=2.34.0", + "google-genai>=1.64.0", ] [build-system] diff --git a/src/va_agent/agent.py b/src/va_agent/agent.py index 3ffcc13..cbccff9 100644 --- a/src/va_agent/agent.py +++ b/src/va_agent/agent.py @@ -10,17 +10,20 @@ from google.cloud.firestore_v1.async_client import AsyncClient from va_agent.auth import auth_headers_provider from va_agent.config import settings from va_agent.session import FirestoreSessionService +from va_agent.governance import GovernancePlugin toolset = McpToolset( connection_params=StreamableHTTPConnectionParams(url=settings.mcp_remote_url), header_provider=auth_headers_provider, ) +governance = GovernancePlugin() agent = Agent( model=settings.agent_model, name=settings.agent_name, instruction=settings.agent_instructions, tools=[toolset], + after_model_callback=governance.after_model_callback, ) session_service = FirestoreSessionService( diff --git a/src/va_agent/governance.py b/src/va_agent/governance.py new file mode 100644 index 0000000..936c668 --- /dev/null +++ b/src/va_agent/governance.py @@ -0,0 +1,79 @@ +"""GovernancePlugin: Guardrails for VAia, the virtual assistant for VA.""" +import logging +import re + +from google.adk.agents.callback_context import CallbackContext +from google.adk.models import LlmResponse + +logger = logging.getLogger(__name__) + + +FORBIDDEN_EMOJIS = [ + "๐Ÿฅต","๐Ÿ”ช","๐ŸŽฐ","๐ŸŽฒ","๐Ÿƒ","๐Ÿ˜ค","๐Ÿคฌ","๐Ÿ˜ก","๐Ÿ˜ ","๐Ÿฉธ","๐Ÿงจ","๐Ÿช“","โ˜ ๏ธ","๐Ÿ’€", + "๐Ÿ’ฃ","๐Ÿ”ซ","๐Ÿ‘—","๐Ÿ’ฆ","๐Ÿ‘","๐Ÿ†","๐Ÿ‘„","๐Ÿ‘…","๐Ÿซฆ","๐Ÿ’ฉ","โš–๏ธ","โš”๏ธ","โœ๏ธ","๐Ÿ•", + "๐Ÿ•Œ","โ›ช","๐Ÿป","๐Ÿธ","๐Ÿฅƒ","๐Ÿท","๐Ÿบ","๐Ÿšฌ","๐Ÿ‘น","๐Ÿ‘บ","๐Ÿ‘ฟ","๐Ÿ˜ˆ","๐Ÿคก","๐Ÿง™", + "๐Ÿง™โ€โ™€๏ธ", "๐Ÿง™โ€โ™‚๏ธ", "๐Ÿง›", "๐Ÿง›โ€โ™€๏ธ", "๐Ÿง›โ€โ™‚๏ธ", "๐Ÿ”ž","๐Ÿงฟ","๐Ÿ’Š", "๐Ÿ’" +] + + +class GovernancePlugin: + """Guardrail executor for VAia requests as a Agent engine callbacks.""" + + def __init__(self) -> None: + """Initialize guardrail model (structured output), prompt and emojis patterns.""" + self._combined_pattern = self._get_combined_pattern() + + def _get_combined_pattern(self): + person_pattern = r"(?:๐Ÿง‘|๐Ÿ‘ฉ|๐Ÿ‘จ)" + tone_pattern = r"[\U0001F3FB-\U0001F3FF]?" + + # Unique pattern that combines all forbidden emojis, including complex ones with skin tones + combined_pattern = re.compile( + rf"{person_pattern}{tone_pattern}\u200dโค๏ธ?\u200d๐Ÿ’‹\u200d{person_pattern}{tone_pattern}" # kiss + rf"|{person_pattern}{tone_pattern}\u200dโค๏ธ?\u200d{person_pattern}{tone_pattern}" # lovers + rf"|๐Ÿ–•{tone_pattern}" # middle finger with all skin tone variations + rf"|{'|'.join(map(re.escape, sorted(FORBIDDEN_EMOJIS, key=len, reverse=True)))}" # simple emojis + rf"|\u200d|\uFE0F" # residual ZWJ and variation selectors + ) + return combined_pattern + + def _remove_emojis(self, text: str) -> tuple[str, list[str]]: + removed = self._combined_pattern.findall(text) + text = self._combined_pattern.sub("", text) + return text.strip(), removed + + + def after_model_callback( + self, + callback_context: CallbackContext | None = None, + llm_response: LlmResponse | None = None, + ) -> None: + """Guardrail post-processing. + + Remove forbidden emojis from the model response. + """ + try: + text_out = "" + if llm_response and llm_response.content: + content = llm_response.content + parts = getattr(content, "parts", None) + if parts: + part = parts[0] + text_value = getattr(part, "text", "") + if isinstance(text_value, str): + text_out = text_value + + if text_out: + new_text, deleted = self._remove_emojis(text_out) + if llm_response and llm_response.content and llm_response.content.parts: + llm_response.content.parts[0].text = new_text + if deleted: + if callback_context: + callback_context.state["removed_emojis"] = deleted + logger.warning( + "Removed forbidden emojis from response: %s", + deleted, + ) + + except Exception: + logger.exception("Error in after_model_callback") diff --git a/uv.lock b/uv.lock index 64c7fbb..5598062 100644 --- a/uv.lock +++ b/uv.lock @@ -1924,6 +1924,7 @@ dependencies = [ { name = "google-adk" }, { name = "google-auth" }, { name = "google-cloud-firestore" }, + { name = "google-genai" }, { name = "pydantic-settings", extra = ["yaml"] }, ] @@ -1941,6 +1942,7 @@ requires-dist = [ { name = "google-adk", specifier = ">=1.14.1" }, { name = "google-auth", specifier = ">=2.34.0" }, { name = "google-cloud-firestore", specifier = ">=2.23.0" }, + { name = "google-genai", specifier = ">=1.64.0" }, { name = "pydantic-settings", extras = ["yaml"], specifier = ">=2.13.1" }, ]