refactor(session): extract guardrail censor helpers
All checks were successful
CI / ci (pull_request) Successful in 21s
All checks were successful
CI / ci (pull_request) Successful in 21s
Pass all checks
This commit is contained in:
10
config.yaml
10
config.yaml
@@ -52,11 +52,11 @@ agent_instructions: |
|
||||
El teléfono de centro de contacto de VA es: +52 1 55 5140 5655
|
||||
|
||||
# Guardrail config
|
||||
guardrail_censored_user_message: "[pregunta mala]"
|
||||
guardrail_censored_model_response: "[respuesta de adversidad]"
|
||||
guardrail_blocked_label: "[GUARDRAIL_BLOCKED]"
|
||||
guardrail_passed_label: "[GUARDRAIL_PASSED]"
|
||||
guardrail_error_label: "[GUARDRAIL_ERROR]"
|
||||
guardrail_censored_user_message: "[ILLEGAL QUESTION]"
|
||||
guardrail_censored_model_response: "[ADVERSITY RESPONSE]"
|
||||
guardrail_blocked_label: "[GUARDRAIL BLOCKED]"
|
||||
guardrail_passed_label: "[GUARDRAIL PASSED]"
|
||||
guardrail_error_label: "[GUARDRAIL ERROR]"
|
||||
|
||||
guardrail_instruction: |
|
||||
Eres una capa de seguridad y protección de marca para VAia, el asistente virtual de VA en WhatsApp.
|
||||
|
||||
@@ -171,14 +171,18 @@ class GovernancePlugin:
|
||||
|
||||
if decision == "unsafe":
|
||||
callback_context.state["guardrail_blocked"] = True
|
||||
callback_context.state["guardrail_message"] = settings.guardrail_blocked_label
|
||||
callback_context.state["guardrail_message"] = (
|
||||
settings.guardrail_blocked_label
|
||||
)
|
||||
callback_context.state["guardrail_reasoning"] = reasoning
|
||||
return LlmResponse(
|
||||
content=Content(role="model", parts=[Part(text=blocking_response)]),
|
||||
usage_metadata=resp.usage_metadata or None,
|
||||
)
|
||||
callback_context.state["guardrail_blocked"] = False
|
||||
callback_context.state["guardrail_message"] = settings.guardrail_passed_label
|
||||
callback_context.state["guardrail_message"] = (
|
||||
settings.guardrail_passed_label
|
||||
)
|
||||
callback_context.state["guardrail_reasoning"] = reasoning
|
||||
|
||||
except Exception:
|
||||
|
||||
@@ -380,8 +380,85 @@ class FirestoreSessionService(BaseSessionService):
|
||||
event = await super().append_event(session=session, event=event)
|
||||
session.last_update_time = event.timestamp
|
||||
|
||||
# Determine if we need to censor this event (model response when guardrail blocked)
|
||||
should_censor_model = (
|
||||
event_data = await self._prepare_event_data(
|
||||
session=session,
|
||||
event=event,
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
await (
|
||||
self._events_col(app_name, user_id, session_id)
|
||||
.document(event.id)
|
||||
.set(event_data)
|
||||
)
|
||||
|
||||
session_ref = self._session_ref(app_name, user_id, session_id)
|
||||
await self._persist_state_deltas(
|
||||
event=event,
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_ref=session_ref,
|
||||
)
|
||||
|
||||
self._log_token_usage(event=event, session_id=session_id)
|
||||
|
||||
self._maybe_trigger_compaction(
|
||||
session=session,
|
||||
event=event,
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
elapsed = time.monotonic() - t0
|
||||
logger.info(
|
||||
"append_event completed for session %s event %s in %.3fs",
|
||||
session_id,
|
||||
event.id,
|
||||
elapsed,
|
||||
)
|
||||
|
||||
return event
|
||||
|
||||
async def _prepare_event_data(
    self,
    *,
    session: Session,
    event: Event,
    app_name: str,
    user_id: str,
    session_id: str,
) -> dict[str, Any]:
    """Build the JSON payload to persist for ``event``, censoring it if required.

    When ``_should_censor_model`` says the event must not be censored, the
    event is serialized unchanged. Otherwise the session is flagged as
    censored, the text of the first content part of a *copy* of the event
    is replaced with ``settings.guardrail_censored_model_response``, and
    the previous user event is censored via ``_censor_previous_user_event``.

    Args:
        session: Session the event belongs to; its ``state`` is mutated
            (``guardrail_censored`` is set) when censoring applies.
        event: Event being persisted. It is never modified in place.
        app_name: Application name, forwarded to the user-event censor.
        user_id: User identifier, forwarded to the user-event censor.
        session_id: Session identifier, forwarded to the user-event censor.

    Returns:
        The serialized event (``model_dump(mode="json", exclude_none=True)``),
        censored when applicable.
    """
    if not self._should_censor_model(session=session, event=event):
        return event.model_dump(mode="json", exclude_none=True)

    # Flag the session up front: _should_censor_model checks this flag, so
    # later events are not censored a second time.
    session.state["guardrail_censored"] = True

    # Work on a copy so the in-memory event keeps its original text.
    redacted = copy.deepcopy(event)
    body = redacted.content

    if body is not None and body.parts:
        body.parts[0].text = settings.guardrail_censored_model_response
        payload = redacted.model_dump(mode="json", exclude_none=True)
    else:
        # Nothing to overwrite; persist the event as it is.
        logger.warning(
            "Guardrail censor requested but event content is missing; "
            "falling back to original event payload."
        )
        payload = event.model_dump(mode="json", exclude_none=True)

    # The user message is censored in both branches.
    await self._censor_previous_user_event(
        session=session,
        app_name=app_name,
        user_id=user_id,
        session_id=session_id,
    )
    return payload
|
||||
|
||||
def _should_censor_model(self, *, session: Session, event: Event) -> Any | bool:
|
||||
return (
|
||||
session.state.get("guardrail_blocked", False)
|
||||
and event.author != "user"
|
||||
and hasattr(event, "content")
|
||||
@@ -390,58 +467,49 @@ class FirestoreSessionService(BaseSessionService):
|
||||
and not session.state.get("guardrail_censored", False)
|
||||
)
|
||||
|
||||
# Prepare event data for Firestore
|
||||
if should_censor_model:
|
||||
# Mark as censored to avoid double-censoring
|
||||
session.state["guardrail_censored"] = True
|
||||
|
||||
# Create a censored version of the model response
|
||||
event_to_save = copy.deepcopy(event)
|
||||
event_to_save.content.parts[0].text = settings.guardrail_censored_model_response
|
||||
event_data = event_to_save.model_dump(mode="json", exclude_none=True)
|
||||
|
||||
# Also censor the previous user message in Firestore
|
||||
# Find the last user event in the session
|
||||
prev_user_event = next(
|
||||
(
|
||||
e
|
||||
for e in reversed(session.events[:-1])
|
||||
if e.author == "user" and e.content and e.content.parts
|
||||
),
|
||||
None,
|
||||
)
|
||||
if prev_user_event:
|
||||
# Update this event in Firestore with censored content
|
||||
censored_user_content = Content(
|
||||
role="user",
|
||||
parts=[Part(text=settings.guardrail_censored_user_message)],
|
||||
)
|
||||
await (
|
||||
self._events_col(app_name, user_id, session_id)
|
||||
.document(prev_user_event.id)
|
||||
.update(
|
||||
{
|
||||
"content": censored_user_content.model_dump(
|
||||
mode="json", exclude_none=True
|
||||
)
|
||||
}
|
||||
)
|
||||
)
|
||||
else:
|
||||
event_data = event.model_dump(mode="json", exclude_none=True)
|
||||
|
||||
# Persist event document
|
||||
async def _censor_previous_user_event(
    self,
    *,
    session: Session,
    app_name: str,
    user_id: str,
    session_id: str,
) -> None:
    """Replace the stored content of the latest prior user event with a placeholder.

    Scans ``session.events`` from newest to oldest, skipping the last
    entry, for the first event authored by ``"user"`` that has non-empty
    content parts. That event's document in the events collection has its
    ``content`` field overwritten with a single-part user ``Content``
    holding ``settings.guardrail_censored_user_message``. Does nothing
    when no such event exists.

    Args:
        session: Session whose in-memory event list is searched.
        app_name: Application name used to locate the events collection.
        user_id: User identifier used to locate the events collection.
        session_id: Session identifier used to locate the events collection.
    """
    prev_user_event = next(
        (
            e
            for e in reversed(session.events[:-1])
            if e.author == "user" and e.content and e.content.parts
        ),
        None,
    )
    if not prev_user_event:
        return

    censored_user_content = Content(
        role="user",
        parts=[Part(text=settings.guardrail_censored_user_message)],
    )
    # Only the ``content`` field is updated; the rest of the stored user
    # event document is left untouched.
    await (
        self._events_col(app_name, user_id, session_id)
        .document(prev_user_event.id)
        .update(
            {
                "content": censored_user_content.model_dump(
                    mode="json", exclude_none=True
                )
            }
        )
    )
|
||||
|
||||
# Persist state deltas
|
||||
session_ref = self._session_ref(app_name, user_id, session_id)
|
||||
|
||||
async def _persist_state_deltas(
|
||||
self,
|
||||
*,
|
||||
event: Event,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_ref: Any,
|
||||
) -> None:
|
||||
last_update_dt = datetime.fromtimestamp(event.timestamp, UTC)
|
||||
|
||||
if event.actions and event.actions.state_delta:
|
||||
state_deltas = _session_util.extract_state_delta(event.actions.state_delta)
|
||||
|
||||
@@ -470,50 +538,50 @@ class FirestoreSessionService(BaseSessionService):
|
||||
)
|
||||
|
||||
await asyncio.gather(*write_coros)
|
||||
else:
|
||||
await session_ref.update({"last_update_time": last_update_dt})
|
||||
return
|
||||
|
||||
# Log token usage
|
||||
if event.usage_metadata:
|
||||
meta = event.usage_metadata
|
||||
logger.info(
|
||||
"Token usage for session %s event %s: "
|
||||
"prompt=%s, candidates=%s, total=%s",
|
||||
session_id,
|
||||
event.id,
|
||||
meta.prompt_token_count,
|
||||
meta.candidates_token_count,
|
||||
meta.total_token_count,
|
||||
)
|
||||
await session_ref.update({"last_update_time": last_update_dt})
|
||||
|
||||
# Trigger compaction if total token count exceeds threshold
|
||||
if (
|
||||
self._compaction_threshold is not None
|
||||
and event.usage_metadata
|
||||
and event.usage_metadata.total_token_count
|
||||
and event.usage_metadata.total_token_count >= self._compaction_threshold
|
||||
):
|
||||
logger.info(
|
||||
"Compaction triggered for session %s: "
|
||||
"total_token_count=%d >= threshold=%d",
|
||||
session_id,
|
||||
event.usage_metadata.total_token_count,
|
||||
self._compaction_threshold,
|
||||
)
|
||||
events_ref = self._events_col(app_name, user_id, session_id)
|
||||
session_ref = self._session_ref(app_name, user_id, session_id)
|
||||
task = asyncio.create_task(
|
||||
self._compactor.guarded_compact(session, events_ref, session_ref)
|
||||
)
|
||||
self._active_tasks.add(task)
|
||||
task.add_done_callback(self._active_tasks.discard)
|
||||
|
||||
elapsed = time.monotonic() - t0
|
||||
def _log_token_usage(self, *, event: Event, session_id: str) -> None:
    """Log the token counts carried by ``event``, if any.

    Args:
        event: Event whose ``usage_metadata`` is inspected. When it is
            falsy nothing is logged.
        session_id: Session identifier included in the log line.
    """
    meta = event.usage_metadata
    if not meta:
        return
    # One format string with exactly five lazy %-style arguments.
    logger.info(
        "Token usage for session %s event %s: prompt=%s, candidates=%s, total=%s",
        session_id,
        event.id,
        meta.prompt_token_count,
        meta.candidates_token_count,
        meta.total_token_count,
    )
|
||||
|
||||
return event
|
||||
def _maybe_trigger_compaction(
    self,
    *,
    session: Session,
    event: Event,
    app_name: str,
    user_id: str,
    session_id: str,
) -> None:
    """Start a background compaction task when the event's token total is high enough.

    Nothing happens when no compaction threshold is configured, when the
    event has no usage metadata, or when its ``total_token_count`` is
    missing, zero, or below the threshold. Otherwise the trigger is logged
    and ``self._compactor.guarded_compact`` is scheduled as an asyncio
    task.

    Args:
        session: Session handed to the compactor.
        event: Event whose ``usage_metadata.total_token_count`` is checked.
        app_name: Application name used to resolve the Firestore references.
        user_id: User identifier used to resolve the Firestore references.
        session_id: Session identifier used for the references and the log.
    """
    threshold = self._compaction_threshold
    usage = event.usage_metadata
    if threshold is None or not usage:
        return

    total = usage.total_token_count
    if not total or total < threshold:
        return

    logger.info(
        "Compaction triggered for session %s: total_token_count=%d >= threshold=%d",
        session_id,
        total,
        threshold,
    )

    events_ref = self._events_col(app_name, user_id, session_id)
    session_ref = self._session_ref(app_name, user_id, session_id)
    compaction = asyncio.create_task(
        self._compactor.guarded_compact(session, events_ref, session_ref)
    )
    # Keep a reference in _active_tasks until the task finishes; the done
    # callback removes it again.
    self._active_tasks.add(compaction)
    compaction.add_done_callback(self._active_tasks.discard)
|
||||
|
||||
Reference in New Issue
Block a user