refactor(session): extract guardrail censor helpers
All checks were successful
CI / ci (pull_request) Successful in 21s
All checks were successful
CI / ci (pull_request) Successful in 21s
Pass all checks
This commit is contained in:
10
config.yaml
10
config.yaml
@@ -52,11 +52,11 @@ agent_instructions: |
|
|||||||
El teléfono de centro de contacto de VA es: +52 1 55 5140 5655
|
El teléfono de centro de contacto de VA es: +52 1 55 5140 5655
|
||||||
|
|
||||||
# Guardrail config
|
# Guardrail config
|
||||||
guardrail_censored_user_message: "[pregunta mala]"
|
guardrail_censored_user_message: "[ILEGAL QUESTION]"
|
||||||
guardrail_censored_model_response: "[respuesta de adversidad]"
|
guardrail_censored_model_response: "[ADVERSITY RESPONSE]"
|
||||||
guardrail_blocked_label: "[GUARDRAIL_BLOCKED]"
|
guardrail_blocked_label: "[GUARDRAIL BLOCKED]"
|
||||||
guardrail_passed_label: "[GUARDRAIL_PASSED]"
|
guardrail_passed_label: "[GUARDRAIL PASSED]"
|
||||||
guardrail_error_label: "[GUARDRAIL_ERROR]"
|
guardrail_error_label: "[GUARDRAIL ERROR]"
|
||||||
|
|
||||||
guardrail_instruction: |
|
guardrail_instruction: |
|
||||||
Eres una capa de seguridad y protección de marca para VAia, el asistente virtual de VA en WhatsApp.
|
Eres una capa de seguridad y protección de marca para VAia, el asistente virtual de VA en WhatsApp.
|
||||||
|
|||||||
@@ -171,14 +171,18 @@ class GovernancePlugin:
|
|||||||
|
|
||||||
if decision == "unsafe":
|
if decision == "unsafe":
|
||||||
callback_context.state["guardrail_blocked"] = True
|
callback_context.state["guardrail_blocked"] = True
|
||||||
callback_context.state["guardrail_message"] = settings.guardrail_blocked_label
|
callback_context.state["guardrail_message"] = (
|
||||||
|
settings.guardrail_blocked_label
|
||||||
|
)
|
||||||
callback_context.state["guardrail_reasoning"] = reasoning
|
callback_context.state["guardrail_reasoning"] = reasoning
|
||||||
return LlmResponse(
|
return LlmResponse(
|
||||||
content=Content(role="model", parts=[Part(text=blocking_response)]),
|
content=Content(role="model", parts=[Part(text=blocking_response)]),
|
||||||
usage_metadata=resp.usage_metadata or None,
|
usage_metadata=resp.usage_metadata or None,
|
||||||
)
|
)
|
||||||
callback_context.state["guardrail_blocked"] = False
|
callback_context.state["guardrail_blocked"] = False
|
||||||
callback_context.state["guardrail_message"] = settings.guardrail_passed_label
|
callback_context.state["guardrail_message"] = (
|
||||||
|
settings.guardrail_passed_label
|
||||||
|
)
|
||||||
callback_context.state["guardrail_reasoning"] = reasoning
|
callback_context.state["guardrail_reasoning"] = reasoning
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -380,8 +380,85 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
event = await super().append_event(session=session, event=event)
|
event = await super().append_event(session=session, event=event)
|
||||||
session.last_update_time = event.timestamp
|
session.last_update_time = event.timestamp
|
||||||
|
|
||||||
# Determine if we need to censor this event (model response when guardrail blocked)
|
event_data = await self._prepare_event_data(
|
||||||
should_censor_model = (
|
session=session,
|
||||||
|
event=event,
|
||||||
|
app_name=app_name,
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await (
|
||||||
|
self._events_col(app_name, user_id, session_id)
|
||||||
|
.document(event.id)
|
||||||
|
.set(event_data)
|
||||||
|
)
|
||||||
|
|
||||||
|
session_ref = self._session_ref(app_name, user_id, session_id)
|
||||||
|
await self._persist_state_deltas(
|
||||||
|
event=event,
|
||||||
|
app_name=app_name,
|
||||||
|
user_id=user_id,
|
||||||
|
session_ref=session_ref,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._log_token_usage(event=event, session_id=session_id)
|
||||||
|
|
||||||
|
self._maybe_trigger_compaction(
|
||||||
|
session=session,
|
||||||
|
event=event,
|
||||||
|
app_name=app_name,
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
elapsed = time.monotonic() - t0
|
||||||
|
logger.info(
|
||||||
|
"append_event completed for session %s event %s in %.3fs",
|
||||||
|
session_id,
|
||||||
|
event.id,
|
||||||
|
elapsed,
|
||||||
|
)
|
||||||
|
|
||||||
|
return event
|
||||||
|
|
||||||
|
async def _prepare_event_data(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
session: Session,
|
||||||
|
event: Event,
|
||||||
|
app_name: str,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Handle guardrail censoring and return serialized event data."""
|
||||||
|
if not self._should_censor_model(session=session, event=event):
|
||||||
|
return event.model_dump(mode="json", exclude_none=True)
|
||||||
|
|
||||||
|
session.state["guardrail_censored"] = True
|
||||||
|
event_to_save = copy.deepcopy(event)
|
||||||
|
content = event_to_save.content
|
||||||
|
|
||||||
|
if content is None or not content.parts:
|
||||||
|
logger.warning(
|
||||||
|
"Guardrail censor requested but event content is missing; "
|
||||||
|
"falling back to original event payload."
|
||||||
|
)
|
||||||
|
event_data = event.model_dump(mode="json", exclude_none=True)
|
||||||
|
else:
|
||||||
|
content.parts[0].text = settings.guardrail_censored_model_response
|
||||||
|
event_data = event_to_save.model_dump(mode="json", exclude_none=True)
|
||||||
|
|
||||||
|
await self._censor_previous_user_event(
|
||||||
|
session=session,
|
||||||
|
app_name=app_name,
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
return event_data
|
||||||
|
|
||||||
|
def _should_censor_model(self, *, session: Session, event: Event) -> Any | bool:
|
||||||
|
return (
|
||||||
session.state.get("guardrail_blocked", False)
|
session.state.get("guardrail_blocked", False)
|
||||||
and event.author != "user"
|
and event.author != "user"
|
||||||
and hasattr(event, "content")
|
and hasattr(event, "content")
|
||||||
@@ -390,18 +467,14 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
and not session.state.get("guardrail_censored", False)
|
and not session.state.get("guardrail_censored", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare event data for Firestore
|
async def _censor_previous_user_event(
|
||||||
if should_censor_model:
|
self,
|
||||||
# Mark as censored to avoid double-censoring
|
*,
|
||||||
session.state["guardrail_censored"] = True
|
session: Session,
|
||||||
|
app_name: str,
|
||||||
# Create a censored version of the model response
|
user_id: str,
|
||||||
event_to_save = copy.deepcopy(event)
|
session_id: str,
|
||||||
event_to_save.content.parts[0].text = settings.guardrail_censored_model_response
|
) -> None:
|
||||||
event_data = event_to_save.model_dump(mode="json", exclude_none=True)
|
|
||||||
|
|
||||||
# Also censor the previous user message in Firestore
|
|
||||||
# Find the last user event in the session
|
|
||||||
prev_user_event = next(
|
prev_user_event = next(
|
||||||
(
|
(
|
||||||
e
|
e
|
||||||
@@ -410,8 +483,8 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
if prev_user_event:
|
if not prev_user_event:
|
||||||
# Update this event in Firestore with censored content
|
return
|
||||||
censored_user_content = Content(
|
censored_user_content = Content(
|
||||||
role="user",
|
role="user",
|
||||||
parts=[Part(text=settings.guardrail_censored_user_message)],
|
parts=[Part(text=settings.guardrail_censored_user_message)],
|
||||||
@@ -427,21 +500,16 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
event_data = event.model_dump(mode="json", exclude_none=True)
|
|
||||||
|
|
||||||
# Persist event document
|
|
||||||
await (
|
|
||||||
self._events_col(app_name, user_id, session_id)
|
|
||||||
.document(event.id)
|
|
||||||
.set(event_data)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Persist state deltas
|
|
||||||
session_ref = self._session_ref(app_name, user_id, session_id)
|
|
||||||
|
|
||||||
|
async def _persist_state_deltas(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
event: Event,
|
||||||
|
app_name: str,
|
||||||
|
user_id: str,
|
||||||
|
session_ref: Any,
|
||||||
|
) -> None:
|
||||||
last_update_dt = datetime.fromtimestamp(event.timestamp, UTC)
|
last_update_dt = datetime.fromtimestamp(event.timestamp, UTC)
|
||||||
|
|
||||||
if event.actions and event.actions.state_delta:
|
if event.actions and event.actions.state_delta:
|
||||||
state_deltas = _session_util.extract_state_delta(event.actions.state_delta)
|
state_deltas = _session_util.extract_state_delta(event.actions.state_delta)
|
||||||
|
|
||||||
@@ -470,15 +538,16 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
)
|
)
|
||||||
|
|
||||||
await asyncio.gather(*write_coros)
|
await asyncio.gather(*write_coros)
|
||||||
else:
|
return
|
||||||
|
|
||||||
await session_ref.update({"last_update_time": last_update_dt})
|
await session_ref.update({"last_update_time": last_update_dt})
|
||||||
|
|
||||||
# Log token usage
|
def _log_token_usage(self, *, event: Event, session_id: str) -> None:
|
||||||
if event.usage_metadata:
|
if not event.usage_metadata:
|
||||||
|
return
|
||||||
meta = event.usage_metadata
|
meta = event.usage_metadata
|
||||||
logger.info(
|
logger.info(
|
||||||
"Token usage for session %s event %s: "
|
"Token usage for session %s event %s: prompt=%s, candidates=%s, total=%s",
|
||||||
"prompt=%s, candidates=%s, total=%s",
|
|
||||||
session_id,
|
session_id,
|
||||||
event.id,
|
event.id,
|
||||||
meta.prompt_token_count,
|
meta.prompt_token_count,
|
||||||
@@ -486,16 +555,25 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
meta.total_token_count,
|
meta.total_token_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Trigger compaction if total token count exceeds threshold
|
def _maybe_trigger_compaction(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
session: Session,
|
||||||
|
event: Event,
|
||||||
|
app_name: str,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
|
) -> None:
|
||||||
if (
|
if (
|
||||||
self._compaction_threshold is not None
|
self._compaction_threshold is None
|
||||||
and event.usage_metadata
|
or not event.usage_metadata
|
||||||
and event.usage_metadata.total_token_count
|
or not event.usage_metadata.total_token_count
|
||||||
and event.usage_metadata.total_token_count >= self._compaction_threshold
|
or event.usage_metadata.total_token_count < self._compaction_threshold
|
||||||
):
|
):
|
||||||
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Compaction triggered for session %s: "
|
"Compaction triggered for session %s: total_token_count=%d >= threshold=%d",
|
||||||
"total_token_count=%d >= threshold=%d",
|
|
||||||
session_id,
|
session_id,
|
||||||
event.usage_metadata.total_token_count,
|
event.usage_metadata.total_token_count,
|
||||||
self._compaction_threshold,
|
self._compaction_threshold,
|
||||||
@@ -507,13 +585,3 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
)
|
)
|
||||||
self._active_tasks.add(task)
|
self._active_tasks.add(task)
|
||||||
task.add_done_callback(self._active_tasks.discard)
|
task.add_done_callback(self._active_tasks.discard)
|
||||||
|
|
||||||
elapsed = time.monotonic() - t0
|
|
||||||
logger.info(
|
|
||||||
"append_event completed for session %s event %s in %.3fs",
|
|
||||||
session_id,
|
|
||||||
event.id,
|
|
||||||
elapsed,
|
|
||||||
)
|
|
||||||
|
|
||||||
return event
|
|
||||||
|
|||||||
Reference in New Issue
Block a user