WIP: feature: Add before Guardrail #26

Draft
A8080816 wants to merge 16 commits from feature/before-guardrail into main
2 changed files with 53 additions and 1 deletions
Showing only changes of commit d92a75a393 - Show all commits

View File

@@ -233,5 +233,9 @@ class GovernancePlugin:
deleted,
)
# Reset censorship flag for next interaction
if callback_context:
callback_context.state["guardrail_censored"] = False
except Exception:
logger.exception("Error in after_model_callback")

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import copy
import logging
import time
import uuid
@@ -378,8 +379,55 @@ class FirestoreSessionService(BaseSessionService):
event = await super().append_event(session=session, event=event)
session.last_update_time = event.timestamp
# Determine if we need to censor this event (model response when guardrail blocked)
should_censor_model = (
session.state.get("guardrail_blocked", False)
and event.author == app_name
and hasattr(event, "content")
and event.content
and event.content.parts
and not session.state.get("guardrail_censored", False)
)
# Prepare event data for Firestore
if should_censor_model:
# Mark as censored to avoid double-censoring
session.state["guardrail_censored"] = True
# Create a censored version of the model response
event_to_save = copy.deepcopy(event)
event_to_save.content.parts[0].text = "[respuesta de adversidad]"
event_data = event_to_save.model_dump(mode="json", exclude_none=True)
# Also censor the previous user message in Firestore
# Find the last user event in the session
for i in range(len(session.events) - 1, -1, -1):
prev_event = session.events[i]
if (
prev_event.author == "user"
and prev_event.content
and prev_event.content.parts
):
# Update this event in Firestore with censored content
censored_user_content = Content(
role="user", parts=[Part(text="[pregunta mala]")]
)
await (
self._events_col(app_name, user_id, session_id)
.document(prev_event.id)
.update(
{
"content": censored_user_content.model_dump(
mode="json", exclude_none=True
)
}
)
)
break
else:
event_data = event.model_dump(mode="json", exclude_none=True)
# Persist event document
event_data = event.model_dump(mode="json", exclude_none=True)
await (
self._events_col(app_name, user_id, session_id)
.document(event.id)