diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..b98199c --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,33 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + ci: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - uses: astral-sh/setup-uv@v6 + with: + enable-cache: true + + - name: Install dependencies + run: uv sync --frozen + + - name: Format check + run: uv run ruff format --check + + - name: Lint + run: uv run ruff check + + - name: Type check + run: uv run ty check + + - name: Test + run: uv run pytest diff --git a/AGENTS.md b/AGENTS.md index 3434537..a4792cf 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,3 +1,4 @@ Use `uv` for project management. -Use `uv run ruff check` for linting, and `uv run ty check` for type checking +Use `uv run ruff check` for linting. +Use `uv run ty check` for type checking. Use `uv run pytest` for testing. diff --git a/src/va_agent/agent.py b/src/va_agent/agent.py index 72085d6..11b458b 100644 --- a/src/va_agent/agent.py +++ b/src/va_agent/agent.py @@ -1,7 +1,5 @@ """ADK agent with vector search RAG tool.""" -from functools import partial - from google import genai from google.adk.agents.llm_agent import Agent from google.adk.runners import Runner @@ -12,10 +10,9 @@ from google.genai.types import Content, Part from va_agent.auth import auth_headers_provider from va_agent.config import settings -from va_agent.dynamic_instruction import provide_dynamic_instruction +from va_agent.governance import GovernancePlugin from va_agent.notifications import NotificationService from va_agent.session import FirestoreSessionService -from va_agent.governance import GovernancePlugin # MCP Toolset for RAG knowledge search toolset = McpToolset( diff --git a/src/va_agent/config.py b/src/va_agent/config.py index f3502b5..ae33d2d 100644 --- a/src/va_agent/config.py +++ b/src/va_agent/config.py @@ -39,7 +39,7 @@ class 
AgentSettings(BaseSettings): model_config = SettingsConfigDict( yaml_file=CONFIG_FILE_PATH, extra="ignore", # Ignore extra fields from config.yaml - env_file=".env" + env_file=".env", ) @classmethod diff --git a/src/va_agent/dynamic_instruction.py b/src/va_agent/dynamic_instruction.py index 86f190d..cc64fb9 100644 --- a/src/va_agent/dynamic_instruction.py +++ b/src/va_agent/dynamic_instruction.py @@ -34,17 +34,19 @@ async def provide_dynamic_instruction( """ # Only check notifications on the first message - if not ctx or not ctx._invocation_context: + if not ctx: logger.debug("No context available for dynamic instruction") return "" - session = ctx._invocation_context.session + session = ctx.session if not session: logger.debug("No session available for dynamic instruction") return "" - # FOR TESTING: Always check for notifications (comment out to enable first-message-only) - # Only check on first message (when events list is empty or has only 1-2 events) + # FOR TESTING: Always check for notifications + # (comment out to enable first-message-only) + # Only check on first message (when events list is empty + # or has only 1-2 events) # Events include both user and agent messages, so < 2 means first interaction # event_count = len(session.events) if session.events else 0 # @@ -74,7 +76,11 @@ async def provide_dynamic_instruction( return "" # Build dynamic instruction with notification details - notification_ids = [n.get("id_notificacion") for n in pending_notifications] + notification_ids = [ + nid + for n in pending_notifications + if (nid := n.get("id_notificacion")) is not None + ] count = len(pending_notifications) # Format notification details for the agent @@ -97,9 +103,11 @@ INSTRUCCIONES: - Menciona estas notificaciones de forma natural en tu respuesta inicial - No necesitas leerlas todas literalmente, solo hazle saber que las tiene - Sé breve y directo según tu personalidad (directo y cálido) -- Si el usuario pregunta algo específico, prioriza responder eso 
primero y luego menciona las notificaciones +- Si el usuario pregunta algo específico, prioriza responder eso primero\ + y luego menciona las notificaciones -Ejemplo: "¡Hola! 👋 Antes de empezar, veo que tienes {count} notificación(es) pendiente(s) en tu cuenta. ¿Te gustaría revisarlas o prefieres que te ayude con algo más?" +Ejemplo: "¡Hola! 👋 Tienes {count} notificación(es)\ + pendiente(s). ¿Te gustaría revisarlas?" """ # Mark notifications as notified in Firestore @@ -111,10 +119,11 @@ Ejemplo: "¡Hola! 👋 Antes de empezar, veo que tienes {count} notificación(es phone_number, ) - return instruction - except Exception: logger.exception( - "Error building dynamic instruction for user %s", phone_number + "Error building dynamic instruction for user %s", + phone_number, ) return "" + else: + return instruction diff --git a/src/va_agent/governance.py b/src/va_agent/governance.py index 936c668..a65d5a3 100644 --- a/src/va_agent/governance.py +++ b/src/va_agent/governance.py @@ -1,4 +1,5 @@ """GovernancePlugin: Guardrails for VAia, the virtual assistant for VA.""" + import logging import re @@ -9,10 +10,57 @@ logger = logging.getLogger(__name__) FORBIDDEN_EMOJIS = [ - "🥵","🔪","🎰","🎲","🃏","😤","🤬","😡","😠","🩸","🧨","🪓","☠️","💀", - "💣","🔫","👗","💦","🍑","🍆","👄","👅","🫦","💩","⚖️","⚔️","✝️","🕍", - "🕌","⛪","🍻","🍸","🥃","🍷","🍺","🚬","👹","👺","👿","😈","🤡","🧙", - "🧙‍♀️", "🧙‍♂️", "🧛", "🧛‍♀️", "🧛‍♂️", "🔞","🧿","💊", "💏" + "🥵", + "🔪", + "🎰", + "🎲", + "🃏", + "😤", + "🤬", + "😡", + "😠", + "🩸", + "🧨", + "🪓", + "☠️", + "💀", + "💣", + "🔫", + "👗", + "💦", + "🍑", + "🍆", + "👄", + "👅", + "🫦", + "💩", + "⚖️", + "⚔️", + "✝️", + "🕍", + "🕌", + "⛪", + "🍻", + "🍸", + "🥃", + "🍷", + "🍺", + "🚬", + "👹", + "👺", + "👿", + "😈", + "🤡", + "🧙", + "🧙‍♀️", + "🧙‍♂️", + "🧛", + "🧛‍♀️", + "🧛‍♂️", + "🔞", + "🧿", + "💊", + "💏", ] @@ -20,29 +68,31 @@ class GovernancePlugin: """Guardrail executor for VAia requests as a Agent engine callbacks.""" def __init__(self) -> None: - """Initialize guardrail model (structured output), prompt and 
emojis patterns.""" + """Initialize guardrail model, prompt and emojis patterns.""" self._combined_pattern = self._get_combined_pattern() - def _get_combined_pattern(self): - person_pattern = r"(?:🧑|👩|👨)" - tone_pattern = r"[\U0001F3FB-\U0001F3FF]?" - - # Unique pattern that combines all forbidden emojis, including complex ones with skin tones - combined_pattern = re.compile( - rf"{person_pattern}{tone_pattern}\u200d❤️?\u200d💋\u200d{person_pattern}{tone_pattern}" # kiss - rf"|{person_pattern}{tone_pattern}\u200d❤️?\u200d{person_pattern}{tone_pattern}" # lovers - rf"|🖕{tone_pattern}" # middle finger with all skin tone variations - rf"|{'|'.join(map(re.escape, sorted(FORBIDDEN_EMOJIS, key=len, reverse=True)))}" # simple emojis - rf"|\u200d|\uFE0F" # residual ZWJ and variation selectors + def _get_combined_pattern(self) -> re.Pattern[str]: + person = r"(?:🧑|👩|👨)" + tone = r"[\U0001F3FB-\U0001F3FF]?" + simple = "|".join( + map(re.escape, sorted(FORBIDDEN_EMOJIS, key=len, reverse=True)) ) - return combined_pattern - + + # Combines all forbidden emojis, including complex + # ones with skin tones + return re.compile( + rf"{person}{tone}\u200d❤️?\u200d💋\u200d{person}{tone}" + rf"|{person}{tone}\u200d❤️?\u200d{person}{tone}" + rf"|🖕{tone}" + rf"|{simple}" + rf"|\u200d|\uFE0F" + ) + def _remove_emojis(self, text: str) -> tuple[str, list[str]]: removed = self._combined_pattern.findall(text) text = self._combined_pattern.sub("", text) return text.strip(), removed - def after_model_callback( self, callback_context: CallbackContext | None = None, diff --git a/src/va_agent/notifications.py b/src/va_agent/notifications.py index e7cc57a..8536fb2 100644 --- a/src/va_agent/notifications.py +++ b/src/va_agent/notifications.py @@ -58,9 +58,7 @@ class NotificationService: """ try: # Query Firestore document by phone number - doc_ref = self._db.collection(self._collection_path).document( - phone_number - ) + doc_ref = self._db.collection(self._collection_path).document(phone_number) doc 
= await doc_ref.get() if not doc.exists: @@ -78,9 +76,7 @@ class NotificationService: # Filter notifications that have NOT been notified by the agent pending = [ - n - for n in all_notifications - if not n.get("notified_by_agent", False) + n for n in all_notifications if not n.get("notified_by_agent", False) ] if not pending: @@ -90,9 +86,7 @@ class NotificationService: return [] # Sort by timestamp_creacion (most recent first) - pending.sort( - key=lambda n: n.get("timestamp_creacion", 0), reverse=True - ) + pending.sort(key=lambda n: n.get("timestamp_creacion", 0), reverse=True) # Return top N most recent result = pending[: self._max_to_notify] @@ -104,13 +98,13 @@ class NotificationService: len(result), ) - return result - except Exception: logger.exception( "Failed to fetch notifications for phone: %s", phone_number ) return [] + else: + return result async def mark_as_notified( self, phone_number: str, notification_ids: list[str] @@ -133,9 +127,7 @@ class NotificationService: return True try: - doc_ref = self._db.collection(self._collection_path).document( - phone_number - ) + doc_ref = self._db.collection(self._collection_path).document(phone_number) doc = await doc_ref.get() if not doc.exists: @@ -184,18 +176,16 @@ class NotificationService: phone_number, ) - return True - except Exception: logger.exception( "Failed to mark notifications as notified for phone: %s", phone_number, ) return False + else: + return True - def format_notification_summary( - self, notifications: list[dict[str, Any]] - ) -> str: + def format_notification_summary(self, notifications: list[dict[str, Any]]) -> str: """Format notifications into a human-readable summary. 
Args:
@@ -209,9 +199,7 @@ class NotificationService:
             return ""
 
         count = len(notifications)
-        summary_lines = [
-            f"El usuario tiene {count} notificación(es) pendiente(s):"
-        ]
+        summary_lines = [f"El usuario tiene {count} notificación(es) pendiente(s):"]
 
         for i, notif in enumerate(notifications, 1):
             texto = notif.get("texto", "Sin texto")
diff --git a/tests/conftest.py b/tests/conftest.py
index 959b677..2c1c3ea 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,25 +2,23 @@
 
 from __future__ import annotations
 
-import os
 import uuid
 
 import pytest
 import pytest_asyncio
-from google.cloud.firestore_v1.async_client import AsyncClient
 
 from va_agent.session import FirestoreSessionService
 
-os.environ.setdefault("FIRESTORE_EMULATOR_HOST", "localhost:8602")
+from .fake_firestore import FakeAsyncClient
 
 
 @pytest_asyncio.fixture
 async def db():
-    return AsyncClient(project="test-project")
+    return FakeAsyncClient()
 
 
 @pytest_asyncio.fixture
-async def service(db: AsyncClient):
+async def service(db):
     prefix = f"test_{uuid.uuid4().hex[:8]}"
     return FirestoreSessionService(db=db, collection_prefix=prefix)
 
diff --git a/tests/fake_firestore.py b/tests/fake_firestore.py
new file mode 100644
index 0000000..f7fce9b
--- /dev/null
+++ b/tests/fake_firestore.py
@@ -0,0 +1,388 @@
+"""In-memory fake of the Firestore async surface used by this project.
+
+Covers: AsyncClient, DocumentReference, CollectionReference, Query,
+DocumentSnapshot, WriteBatch, and basic transaction support (enough for
+``@async_transactional``).
+
+Data written or read through references, queries and transactions is
+deep-copied, so tests cannot mutate stored documents through aliases.
+"""
+
+from __future__ import annotations
+
+import copy
+import operator
+from typing import Any
+
+
+# ------------------------------------------------------------------ #
+# DocumentSnapshot
+# ------------------------------------------------------------------ #
+class FakeDocumentSnapshot:
+    """Point-in-time view of a document, as returned by ``get()``."""
+
+    def __init__(
+        self,
+        *,
+        exists: bool,
+        data: dict[str, Any] | None,
+        reference: FakeDocumentReference,
+    ) -> None:
+        self._exists = exists
+        self._data = data
+        self._reference = reference
+
+    @property
+    def exists(self) -> bool:
+        """Whether the document existed when the snapshot was taken."""
+        return self._exists
+
+    @property
+    def reference(self) -> FakeDocumentReference:
+        """Reference to the document this snapshot was read from."""
+        return self._reference
+
+    def to_dict(self) -> dict[str, Any] | None:
+        """Return a deep copy of the data, or ``None`` if the doc is missing."""
+        if not self._exists:
+            return None
+        return copy.deepcopy(self._data)
+
+
+# ------------------------------------------------------------------ #
+# DocumentReference
+# ------------------------------------------------------------------ #
+class FakeDocumentReference:
+    """Handle to one document path inside a ``FakeStore``."""
+
+    def __init__(self, store: FakeStore, path: str) -> None:
+        self._store = store
+        self._path = path
+
+    @property
+    def path(self) -> str:
+        """Slash-separated path of the document."""
+        return self._path
+
+    # --- read ---
+
+    async def get(
+        self, *, transaction: FakeTransaction | None = None
+    ) -> FakeDocumentSnapshot:
+        """Read the document; ``transaction`` is accepted but not used."""
+        data = self._store.get_doc(self._path)
+        if data is None:
+            return FakeDocumentSnapshot(exists=False, data=None, reference=self)
+        return FakeDocumentSnapshot(
+            exists=True, data=copy.deepcopy(data), reference=self
+        )
+
+    # --- write ---
+
+    async def set(self, document_data: dict[str, Any], merge: bool = False) -> None:
+        """Write the document, replacing it unless ``merge`` is true.
+
+        The incoming data is always deep-copied so that later mutation of the
+        caller's dict cannot leak into the store. ``merge`` is a shallow,
+        top-level merge: nested maps are replaced rather than merged.
+        """
+        incoming = copy.deepcopy(document_data)
+        if merge:
+            current = self._store.get_doc(self._path) or {}
+            incoming = {**current, **incoming}
+        self._store.set_doc(self._path, incoming)
+
+    async def update(self, field_updates: dict[str, Any]) -> None:
+        """Apply field-path updates (e.g. ``"state.counter"``) to the document.
+
+        Raises:
+            ValueError: If the document does not exist.
+        """
+        data = self._store.get_doc(self._path)
+        if data is None:
+            msg = f"Document {self._path} does not exist"
+            raise ValueError(msg)
+        for key, value in field_updates.items():
+            _nested_set(data, key, copy.deepcopy(value))
+        self._store.set_doc(self._path, data)
+
+    # --- subcollection ---
+
+    def collection(self, subcollection_name: str) -> FakeCollectionReference:
+        """Return the named subcollection nested under this document."""
+        return FakeCollectionReference(
+            self._store, f"{self._path}/{subcollection_name}"
+        )
+
+
+# ------------------------------------------------------------------ #
+# Helpers for nested field-path updates ("state.counter" → data["state"]["counter"])
+# ------------------------------------------------------------------ #
+def _split_field_path(dotted_key: str) -> list[str]:
+    """Split a field path on dots, keeping backtick-quoted segments whole.
+
+    Inside backticks a dot is literal and a backslash escapes the next
+    character, so the path a.`b.c` splits into "a" and "b.c".
+    """
+    parts: list[str] = []
+    current: list[str] = []
+    quoted = escaped = False
+    for char in dotted_key:
+        if escaped:
+            current.append(char)
+            escaped = False
+        elif quoted and char == "\\":
+            escaped = True
+        elif char == "`":
+            quoted = not quoted
+        elif char == "." and not quoted:
+            parts.append("".join(current))
+            current = []
+        else:
+            current.append(char)
+    parts.append("".join(current))
+    return parts
+
+
+def _nested_set(data: dict[str, Any], dotted_key: str, value: Any) -> None:
+    """Assign ``value`` at ``dotted_key`` inside ``data``, creating parent maps."""
+    *parents, final = _split_field_path(dotted_key)
+    for part in parents:
+        data = data.setdefault(part, {})
+    data[final] = value
+
+
+# ------------------------------------------------------------------ #
+# Query
+# ------------------------------------------------------------------ #
+class FakeQuery:
+    """Supports chained .where() / .order_by() / .get()."""
+
+    def __init__(self, store: FakeStore, collection_path: str) -> None:
+        self._store = store
+        self._collection_path = collection_path
+        self._filters: list[tuple[str, str, Any]] = []
+        self._order_by_field: str | None = None
+
+    def _clone(self) -> FakeQuery:
+        """Copy this query so chained calls never mutate the receiver."""
+        clone = FakeQuery(self._store, self._collection_path)
+        clone._filters = list(self._filters)
+        clone._order_by_field = self._order_by_field
+        return clone
+
+    def where(self, *, filter: Any) -> FakeQuery:  # noqa: A002
+        """Return a new query with an extra ``FieldFilter``-style condition."""
+        clone = self._clone()
+        clone._filters.append((filter.field_path, filter.op_string, filter.value))
+        return clone
+
+    def order_by(self, field_path: str) -> FakeQuery:
+        """Return a new query sorted ascending by ``field_path``."""
+        clone = self._clone()
+        clone._order_by_field = field_path
+        return clone
+
+    async def get(self) -> list[FakeDocumentSnapshot]:
+        """Run the query and return a snapshot for every matching document."""
+        results = [
+            (doc_path, data)
+            for doc_path, data in self._store.list_collection(self._collection_path)
+            if all(_match(data, field, op, val) for field, op, val in self._filters)
+        ]
+
+        if self._order_by_field:
+            field = self._order_by_field
+            # Documents missing the sort field are ordered as if it were 0.
+            results.sort(key=lambda item: item[1].get(field, 0))
+
+        return [
+            FakeDocumentSnapshot(
+                exists=True,
+                data=copy.deepcopy(data),
+                reference=FakeDocumentReference(self._store, path),
+            )
+            for path, data in results
+        ]
+
+
+# Comparison operators understood by ``_match``, keyed by Firestore op string.
+_COMPARATORS = {
+    "==": operator.eq,
+    "!=": operator.ne,
+    "<": operator.lt,
+    "<=": operator.le,
+    ">": operator.gt,
+    ">=": operator.ge,
+}
+
+
+def _match(data: dict[str, Any], field: str, op: str, value: Any) -> bool:
+    """Evaluate one filter against a document's top-level ``field``.
+
+    Raises:
+        NotImplementedError: If ``op`` is not a supported operator, so an
+            unsupported query fails loudly instead of matching nothing.
+    """
+    compare = _COMPARATORS.get(op)
+    if compare is None:
+        msg = f"FakeQuery does not support operator {op!r}"
+        raise NotImplementedError(msg)
+    doc_val = data.get(field)
+    if op == "==":
+        return doc_val == value
+    # A missing (or null) field never satisfies an inequality or range filter.
+    return doc_val is not None and compare(doc_val, value)
+
+
+# ------------------------------------------------------------------ #
+# CollectionReference (extends Query behaviour)
+# ------------------------------------------------------------------ #
+class FakeCollectionReference(FakeQuery):
+    """An unfiltered query that can also hand out document references."""
+
+    def document(self, document_id: str) -> FakeDocumentReference:
+        """Return a reference to ``document_id`` inside this collection."""
+        return FakeDocumentReference(
+            self._store, f"{self._collection_path}/{document_id}"
+        )
+
+
+# ------------------------------------------------------------------ #
+# WriteBatch
+# ------------------------------------------------------------------ #
+class FakeWriteBatch:
+    """Write batch supporting queued deletes only."""
+
+    def __init__(self, store: FakeStore) -> None:
+        self._store = store
+        self._deletes: list[str] = []
+
+    def delete(self, doc_ref: FakeDocumentReference) -> None:
+        """Queue ``doc_ref`` for deletion on the next ``commit()``."""
+        self._deletes.append(doc_ref.path)
+
+    async def commit(self) -> None:
+        """Delete every queued document from the store."""
+        for path in self._deletes:
+            self._store.delete_doc(path)
+
+
+# ------------------------------------------------------------------ #
+# Transaction (minimal, supports @async_transactional)
+# ------------------------------------------------------------------ #
+class FakeTransaction:
+    """Minimal transaction compatible with ``@async_transactional``.
+
+    The decorator calls ``_clean_up()``, ``_begin()``, the wrapped function,
+    then ``_commit()``. On error it calls ``_rollback()``.
+    ``in_progress`` is a property that checks ``_id is not None``.
+    """
+
+    def __init__(self, store: FakeStore) -> None:
+        self._store = store
+        self._staged_updates: list[tuple[str, dict[str, Any]]] = []
+        self._id: bytes | None = None
+        self._max_attempts = 1
+        self._read_only = False
+
+    @property
+    def in_progress(self) -> bool:
+        """Whether ``_begin()`` has run without a commit or rollback since."""
+        return self._id is not None
+
+    def _clean_up(self) -> None:
+        self._id = None
+
+    async def _begin(self, retry_id: bytes | None = None) -> None:
+        self._id = b"fake-txn"
+
+    async def _commit(self) -> list[Any]:
+        # Staged updates to documents that no longer exist are dropped.
+        for path, updates in self._staged_updates:
+            data = self._store.get_doc(path)
+            if data is not None:
+                for key, value in updates.items():
+                    _nested_set(data, key, value)
+                self._store.set_doc(path, data)
+        self._staged_updates.clear()
+        self._clean_up()
+        return []
+
+    async def _rollback(self) -> None:
+        self._staged_updates.clear()
+        self._clean_up()
+
+    def update(
+        self, doc_ref: FakeDocumentReference, field_updates: dict[str, Any]
+    ) -> None:
+        """Stage field-path updates for ``doc_ref``; applied on commit."""
+        self._staged_updates.append((doc_ref.path, copy.deepcopy(field_updates)))
+
+
+# ------------------------------------------------------------------ #
+# Document store (flat dict keyed by path)
+# ------------------------------------------------------------------ #
+class FakeStore:
+    """Flat ``path -> data`` mapping shared by every reference of a client."""
+
+    def __init__(self) -> None:
+        self._docs: dict[str, dict[str, Any]] = {}
+
+    def get_doc(self, path: str) -> dict[str, Any] | None:
+        """Return the stored dict itself (not a copy), or ``None`` if absent."""
+        return self._docs.get(path)
+
+    def set_doc(self, path: str, data: dict[str, Any]) -> None:
+        """Store ``data`` at ``path`` without copying it."""
+        self._docs[path] = data
+
+    def delete_doc(self, path: str) -> None:
+        """Remove the document at ``path`` if it exists."""
+        self._docs.pop(path, None)
+
+    def list_collection(
+        self,
+        collection_path: str,
+    ) -> list[tuple[str, dict[str, Any]]]:
+        """Return (path, data) for every direct child doc of *collection_path*."""
+        prefix = collection_path + "/"
+        # Direct children have no further "/" after the collection prefix;
+        # deeper paths belong to subcollections.
+        return [
+            (doc_path, data)
+            for doc_path, data in self._docs.items()
+            if doc_path.startswith(prefix) and "/" not in doc_path.removeprefix(prefix)
+        ]
+
+    def recursive_delete(self, path: str) -> None:
+        """Delete a document and everything nested under it."""
+        to_delete = [p for p in self._docs if p == path or p.startswith(path + "/")]
+        for p in to_delete:
+            del self._docs[p]
+
+
+# ------------------------------------------------------------------ #
+# FakeAsyncClient (drop-in for AsyncClient)
+# ------------------------------------------------------------------ #
+class FakeAsyncClient:
+    """Entry point mirroring the subset of ``AsyncClient`` the project uses."""
+
+    def __init__(self, **_kwargs: Any) -> None:
+        self._store = FakeStore()
+
+    def collection(self, collection_path: str) -> FakeCollectionReference:
+        """Return a reference to the top-level collection at that path."""
+        return FakeCollectionReference(self._store, collection_path)
+
+    def batch(self) -> FakeWriteBatch:
+        """Return a new write batch bound to this client's store."""
+        return FakeWriteBatch(self._store)
+
+    def transaction(self, **_kwargs: Any) -> FakeTransaction:
+        """Return a new transaction; keyword options are accepted and ignored."""
+        return FakeTransaction(self._store)
+
+    async def recursive_delete(self, doc_ref: FakeDocumentReference) -> None:
+        """Delete ``doc_ref`` and every document nested beneath it."""
+        self._store.recursive_delete(doc_ref.path)