refactor(session): extract guardrail censor helpers
All checks were successful
CI / ci (pull_request) Successful in 21s

Pass all checks
This commit is contained in:
2026-03-13 22:41:57 +00:00
parent c244b35e00
commit 2b610d30a9
3 changed files with 168 additions and 96 deletions

View File

@@ -52,11 +52,11 @@ agent_instructions: |
El teléfono de centro de contacto de VA es: +52 1 55 5140 5655
# Guardrail config
guardrail_censored_user_message: "[pregunta mala]"
guardrail_censored_model_response: "[respuesta de adversidad]"
guardrail_blocked_label: "[GUARDRAIL_BLOCKED]"
guardrail_passed_label: "[GUARDRAIL_PASSED]"
guardrail_error_label: "[GUARDRAIL_ERROR]"
guardrail_censored_user_message: "[ILEGAL QUESTION]"
guardrail_censored_model_response: "[ADVERSITY RESPONSE]"
guardrail_blocked_label: "[GUARDRAIL BLOCKED]"
guardrail_passed_label: "[GUARDRAIL PASSED]"
guardrail_error_label: "[GUARDRAIL ERROR]"
guardrail_instruction: |
Eres una capa de seguridad y protección de marca para VAia, el asistente virtual de VA en WhatsApp.

View File

@@ -171,14 +171,18 @@ class GovernancePlugin:
if decision == "unsafe":
callback_context.state["guardrail_blocked"] = True
callback_context.state["guardrail_message"] = settings.guardrail_blocked_label
callback_context.state["guardrail_message"] = (
settings.guardrail_blocked_label
)
callback_context.state["guardrail_reasoning"] = reasoning
return LlmResponse(
content=Content(role="model", parts=[Part(text=blocking_response)]),
usage_metadata=resp.usage_metadata or None,
)
callback_context.state["guardrail_blocked"] = False
callback_context.state["guardrail_message"] = settings.guardrail_passed_label
callback_context.state["guardrail_message"] = (
settings.guardrail_passed_label
)
callback_context.state["guardrail_reasoning"] = reasoning
except Exception:

View File

@@ -380,8 +380,85 @@ class FirestoreSessionService(BaseSessionService):
event = await super().append_event(session=session, event=event)
session.last_update_time = event.timestamp
# Determine if we need to censor this event (model response when guardrail blocked)
should_censor_model = (
event_data = await self._prepare_event_data(
session=session,
event=event,
app_name=app_name,
user_id=user_id,
session_id=session_id,
)
await (
self._events_col(app_name, user_id, session_id)
.document(event.id)
.set(event_data)
)
session_ref = self._session_ref(app_name, user_id, session_id)
await self._persist_state_deltas(
event=event,
app_name=app_name,
user_id=user_id,
session_ref=session_ref,
)
self._log_token_usage(event=event, session_id=session_id)
self._maybe_trigger_compaction(
session=session,
event=event,
app_name=app_name,
user_id=user_id,
session_id=session_id,
)
elapsed = time.monotonic() - t0
logger.info(
"append_event completed for session %s event %s in %.3fs",
session_id,
event.id,
elapsed,
)
return event
async def _prepare_event_data(
self,
*,
session: Session,
event: Event,
app_name: str,
user_id: str,
session_id: str,
) -> dict[str, Any]:
"""Handle guardrail censoring and return serialized event data."""
if not self._should_censor_model(session=session, event=event):
return event.model_dump(mode="json", exclude_none=True)
session.state["guardrail_censored"] = True
event_to_save = copy.deepcopy(event)
content = event_to_save.content
if content is None or not content.parts:
logger.warning(
"Guardrail censor requested but event content is missing; "
"falling back to original event payload."
)
event_data = event.model_dump(mode="json", exclude_none=True)
else:
content.parts[0].text = settings.guardrail_censored_model_response
event_data = event_to_save.model_dump(mode="json", exclude_none=True)
await self._censor_previous_user_event(
session=session,
app_name=app_name,
user_id=user_id,
session_id=session_id,
)
return event_data
def _should_censor_model(self, *, session: Session, event: Event) -> Any | bool:
return (
session.state.get("guardrail_blocked", False)
and event.author != "user"
and hasattr(event, "content")
@@ -390,58 +467,49 @@ class FirestoreSessionService(BaseSessionService):
and not session.state.get("guardrail_censored", False)
)
# Prepare event data for Firestore
if should_censor_model:
# Mark as censored to avoid double-censoring
session.state["guardrail_censored"] = True
# Create a censored version of the model response
event_to_save = copy.deepcopy(event)
event_to_save.content.parts[0].text = settings.guardrail_censored_model_response
event_data = event_to_save.model_dump(mode="json", exclude_none=True)
# Also censor the previous user message in Firestore
# Find the last user event in the session
prev_user_event = next(
(
e
for e in reversed(session.events[:-1])
if e.author == "user" and e.content and e.content.parts
),
None,
)
if prev_user_event:
# Update this event in Firestore with censored content
censored_user_content = Content(
role="user",
parts=[Part(text=settings.guardrail_censored_user_message)],
)
await (
self._events_col(app_name, user_id, session_id)
.document(prev_user_event.id)
.update(
{
"content": censored_user_content.model_dump(
mode="json", exclude_none=True
)
}
)
)
else:
event_data = event.model_dump(mode="json", exclude_none=True)
# Persist event document
async def _censor_previous_user_event(
self,
*,
session: Session,
app_name: str,
user_id: str,
session_id: str,
) -> None:
prev_user_event = next(
(
e
for e in reversed(session.events[:-1])
if e.author == "user" and e.content and e.content.parts
),
None,
)
if not prev_user_event:
return
censored_user_content = Content(
role="user",
parts=[Part(text=settings.guardrail_censored_user_message)],
)
await (
self._events_col(app_name, user_id, session_id)
.document(event.id)
.set(event_data)
.document(prev_user_event.id)
.update(
{
"content": censored_user_content.model_dump(
mode="json", exclude_none=True
)
}
)
)
# Persist state deltas
session_ref = self._session_ref(app_name, user_id, session_id)
async def _persist_state_deltas(
self,
*,
event: Event,
app_name: str,
user_id: str,
session_ref: Any,
) -> None:
last_update_dt = datetime.fromtimestamp(event.timestamp, UTC)
if event.actions and event.actions.state_delta:
state_deltas = _session_util.extract_state_delta(event.actions.state_delta)
@@ -470,50 +538,50 @@ class FirestoreSessionService(BaseSessionService):
)
await asyncio.gather(*write_coros)
else:
await session_ref.update({"last_update_time": last_update_dt})
return
# Log token usage
if event.usage_metadata:
meta = event.usage_metadata
logger.info(
"Token usage for session %s event %s: "
"prompt=%s, candidates=%s, total=%s",
session_id,
event.id,
meta.prompt_token_count,
meta.candidates_token_count,
meta.total_token_count,
)
await session_ref.update({"last_update_time": last_update_dt})
# Trigger compaction if total token count exceeds threshold
if (
self._compaction_threshold is not None
and event.usage_metadata
and event.usage_metadata.total_token_count
and event.usage_metadata.total_token_count >= self._compaction_threshold
):
logger.info(
"Compaction triggered for session %s: "
"total_token_count=%d >= threshold=%d",
session_id,
event.usage_metadata.total_token_count,
self._compaction_threshold,
)
events_ref = self._events_col(app_name, user_id, session_id)
session_ref = self._session_ref(app_name, user_id, session_id)
task = asyncio.create_task(
self._compactor.guarded_compact(session, events_ref, session_ref)
)
self._active_tasks.add(task)
task.add_done_callback(self._active_tasks.discard)
elapsed = time.monotonic() - t0
def _log_token_usage(self, *, event: Event, session_id: str) -> None:
if not event.usage_metadata:
return
meta = event.usage_metadata
logger.info(
"append_event completed for session %s event %s in %.3fs",
"Token usage for session %s event %s: prompt=%s, candidates=%s, total=%s",
session_id,
event.id,
elapsed,
meta.prompt_token_count,
meta.candidates_token_count,
meta.total_token_count,
)
return event
def _maybe_trigger_compaction(
self,
*,
session: Session,
event: Event,
app_name: str,
user_id: str,
session_id: str,
) -> None:
if (
self._compaction_threshold is None
or not event.usage_metadata
or not event.usage_metadata.total_token_count
or event.usage_metadata.total_token_count < self._compaction_threshold
):
return
logger.info(
"Compaction triggered for session %s: total_token_count=%d >= threshold=%d",
session_id,
event.usage_metadata.total_token_count,
self._compaction_threshold,
)
events_ref = self._events_col(app_name, user_id, session_id)
session_ref = self._session_ref(app_name, user_id, session_id)
task = asyncio.create_task(
self._compactor.guarded_compact(session, events_ref, session_ref)
)
self._active_tasks.add(task)
task.add_done_callback(self._active_tasks.discard)