From 2b610d30a995e63a94e30d77f874995dbae39374 Mon Sep 17 00:00:00 2001 From: Jorge Juarez Date: Fri, 13 Mar 2026 22:41:57 +0000 Subject: [PATCH] refactor(session): extract guardrail censor helpers Pass all checks. Also change the guardrail placeholder and label strings in config.yaml and re-wrap two long assignments in governance.py. --- config.yaml | 10 +- src/va_agent/governance.py | 8 +- src/va_agent/session.py | 246 +++++++++++++++++++++++-------------- 3 files changed, 168 insertions(+), 96 deletions(-) diff --git a/config.yaml b/config.yaml index 635c75d..610c7a7 100644 --- a/config.yaml +++ b/config.yaml @@ -52,11 +52,11 @@ agent_instructions: | El teléfono de centro de contacto de VA es: +52 1 55 5140 5655 # Guardrail config -guardrail_censored_user_message: "[pregunta mala]" -guardrail_censored_model_response: "[respuesta de adversidad]" -guardrail_blocked_label: "[GUARDRAIL_BLOCKED]" -guardrail_passed_label: "[GUARDRAIL_PASSED]" -guardrail_error_label: "[GUARDRAIL_ERROR]" +guardrail_censored_user_message: "[ILLEGAL QUESTION]" +guardrail_censored_model_response: "[ADVERSITY RESPONSE]" +guardrail_blocked_label: "[GUARDRAIL BLOCKED]" +guardrail_passed_label: "[GUARDRAIL PASSED]" +guardrail_error_label: "[GUARDRAIL ERROR]" guardrail_instruction: | Eres una capa de seguridad y protección de marca para VAia, el asistente virtual de VA en WhatsApp. 
diff --git a/src/va_agent/governance.py b/src/va_agent/governance.py index fb8ab26..7118df6 100644 --- a/src/va_agent/governance.py +++ b/src/va_agent/governance.py @@ -171,14 +171,18 @@ class GovernancePlugin: if decision == "unsafe": callback_context.state["guardrail_blocked"] = True - callback_context.state["guardrail_message"] = settings.guardrail_blocked_label + callback_context.state["guardrail_message"] = ( + settings.guardrail_blocked_label + ) callback_context.state["guardrail_reasoning"] = reasoning return LlmResponse( content=Content(role="model", parts=[Part(text=blocking_response)]), usage_metadata=resp.usage_metadata or None, ) callback_context.state["guardrail_blocked"] = False - callback_context.state["guardrail_message"] = settings.guardrail_passed_label + callback_context.state["guardrail_message"] = ( + settings.guardrail_passed_label + ) callback_context.state["guardrail_reasoning"] = reasoning except Exception: diff --git a/src/va_agent/session.py b/src/va_agent/session.py index bca3788..f283c14 100644 --- a/src/va_agent/session.py +++ b/src/va_agent/session.py @@ -380,8 +380,85 @@ class FirestoreSessionService(BaseSessionService): event = await super().append_event(session=session, event=event) session.last_update_time = event.timestamp - # Determine if we need to censor this event (model response when guardrail blocked) - should_censor_model = ( + event_data = await self._prepare_event_data( + session=session, + event=event, + app_name=app_name, + user_id=user_id, + session_id=session_id, + ) + + await ( + self._events_col(app_name, user_id, session_id) + .document(event.id) + .set(event_data) + ) + + session_ref = self._session_ref(app_name, user_id, session_id) + await self._persist_state_deltas( + event=event, + app_name=app_name, + user_id=user_id, + session_ref=session_ref, + ) + + self._log_token_usage(event=event, session_id=session_id) + + self._maybe_trigger_compaction( + session=session, + event=event, + app_name=app_name, + 
+ user_id=user_id, + session_id=session_id, + ) + + elapsed = time.monotonic() - t0 + logger.info( + "append_event completed for session %s event %s in %.3fs", + session_id, + event.id, + elapsed, + ) + + return event + + async def _prepare_event_data( + self, + *, + session: Session, + event: Event, + app_name: str, + user_id: str, + session_id: str, + ) -> dict[str, Any]: + """Handle guardrail censoring and return serialized event data.""" + if not self._should_censor_model(session=session, event=event): + return event.model_dump(mode="json", exclude_none=True) + + session.state["guardrail_censored"] = True + event_to_save = copy.deepcopy(event) + content = event_to_save.content + + if content is None or not content.parts: + logger.warning( + "Guardrail censor requested but event content is missing; " + "falling back to original event payload." + ) + event_data = event.model_dump(mode="json", exclude_none=True) + else: + content.parts[0].text = settings.guardrail_censored_model_response + event_data = event_to_save.model_dump(mode="json", exclude_none=True) + + await self._censor_previous_user_event( + session=session, + app_name=app_name, + user_id=user_id, + session_id=session_id, + ) + return event_data + + def _should_censor_model(self, *, session: Session, event: Event) -> bool: + return bool( session.state.get("guardrail_blocked", False) and event.author != "user" and hasattr(event, "content") @@ -390,58 +467,49 @@ class FirestoreSessionService(BaseSessionService): and not session.state.get("guardrail_censored", False) ) - # Prepare event data for Firestore - if should_censor_model: - # Mark as censored to avoid double-censoring - session.state["guardrail_censored"] = True - - # Create a censored version of the model response - event_to_save = copy.deepcopy(event) - event_to_save.content.parts[0].text = settings.guardrail_censored_model_response - event_data = event_to_save.model_dump(mode="json", exclude_none=True) - - # Also censor the previous user 
message in Firestore - # Find the last user event in the session - prev_user_event = next( - ( - e - for e in reversed(session.events[:-1]) - if e.author == "user" and e.content and e.content.parts - ), - None, - ) - if prev_user_event: - # Update this event in Firestore with censored content - censored_user_content = Content( - role="user", - parts=[Part(text=settings.guardrail_censored_user_message)], - ) - await ( - self._events_col(app_name, user_id, session_id) - .document(prev_user_event.id) - .update( - { - "content": censored_user_content.model_dump( - mode="json", exclude_none=True - ) - } - ) - ) - else: - event_data = event.model_dump(mode="json", exclude_none=True) - - # Persist event document + async def _censor_previous_user_event( + self, + *, + session: Session, + app_name: str, + user_id: str, + session_id: str, + ) -> None: + prev_user_event = next( + ( + e + for e in reversed(session.events[:-1]) + if e.author == "user" and e.content and e.content.parts + ), + None, + ) + if not prev_user_event: + return + censored_user_content = Content( + role="user", + parts=[Part(text=settings.guardrail_censored_user_message)], + ) await ( self._events_col(app_name, user_id, session_id) - .document(event.id) - .set(event_data) + .document(prev_user_event.id) + .update( + { + "content": censored_user_content.model_dump( + mode="json", exclude_none=True + ) + } + ) ) - # Persist state deltas - session_ref = self._session_ref(app_name, user_id, session_id) - + async def _persist_state_deltas( + self, + *, + event: Event, + app_name: str, + user_id: str, + session_ref: Any, + ) -> None: last_update_dt = datetime.fromtimestamp(event.timestamp, UTC) - if event.actions and event.actions.state_delta: state_deltas = _session_util.extract_state_delta(event.actions.state_delta) @@ -470,50 +538,50 @@ class FirestoreSessionService(BaseSessionService): ) await asyncio.gather(*write_coros) - else: - await session_ref.update({"last_update_time": last_update_dt}) + return - # 
Log token usage - if event.usage_metadata: - meta = event.usage_metadata - logger.info( - "Token usage for session %s event %s: " - "prompt=%s, candidates=%s, total=%s", - session_id, - event.id, - meta.prompt_token_count, - meta.candidates_token_count, - meta.total_token_count, - ) + await session_ref.update({"last_update_time": last_update_dt}) - # Trigger compaction if total token count exceeds threshold - if ( - self._compaction_threshold is not None - and event.usage_metadata - and event.usage_metadata.total_token_count - and event.usage_metadata.total_token_count >= self._compaction_threshold - ): - logger.info( - "Compaction triggered for session %s: " - "total_token_count=%d >= threshold=%d", - session_id, - event.usage_metadata.total_token_count, - self._compaction_threshold, - ) - events_ref = self._events_col(app_name, user_id, session_id) - session_ref = self._session_ref(app_name, user_id, session_id) - task = asyncio.create_task( - self._compactor.guarded_compact(session, events_ref, session_ref) - ) - self._active_tasks.add(task) - task.add_done_callback(self._active_tasks.discard) - - elapsed = time.monotonic() - t0 + def _log_token_usage(self, *, event: Event, session_id: str) -> None: + if not event.usage_metadata: + return + meta = event.usage_metadata logger.info( - "append_event completed for session %s event %s in %.3fs", + "Token usage for session %s event %s: prompt=%s, candidates=%s, total=%s", session_id, event.id, - elapsed, + meta.prompt_token_count, + meta.candidates_token_count, + meta.total_token_count, ) - return event + def _maybe_trigger_compaction( + self, + *, + session: Session, + event: Event, + app_name: str, + user_id: str, + session_id: str, + ) -> None: + if ( + self._compaction_threshold is None + or not event.usage_metadata + or not event.usage_metadata.total_token_count + or event.usage_metadata.total_token_count < self._compaction_threshold + ): + return + + logger.info( + "Compaction triggered for session %s: 
total_token_count=%d >= threshold=%d", + session_id, + event.usage_metadata.total_token_count, + self._compaction_threshold, + ) + events_ref = self._events_col(app_name, user_id, session_id) + session_ref = self._session_ref(app_name, user_id, session_id) + task = asyncio.create_task( + self._compactor.guarded_compact(session, events_ref, session_ref) + ) + self._active_tasks.add(task) + task.add_done_callback(self._active_tasks.discard)