diff --git a/pyproject.toml b/pyproject.toml index 9a41602..4bbd3eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ filterwarnings = [ ] env = [ - "FIRESTORE_EMULATOR_HOST=[::1]:8469", + "FIRESTORE_EMULATOR_HOST=[::1]:8462", "GCP_PROJECT_ID=test-project", "GCP_LOCATION=us-central1", "GCP_FIRESTORE_DATABASE_ID=(default)", diff --git a/src/capa_de_integracion/dependencies.py b/src/capa_de_integracion/dependencies.py index dce72c0..2de2a49 100644 --- a/src/capa_de_integracion/dependencies.py +++ b/src/capa_de_integracion/dependencies.py @@ -1,5 +1,7 @@ """Dependency injection and service lifecycle management.""" +import asyncio +import logging from functools import lru_cache from capa_de_integracion.services.rag import ( @@ -16,8 +18,12 @@ from .services import ( QuickReplyContentService, QuickReplySessionService, ) +from .services.conversation import get_background_tasks as conv_bg_tasks +from .services.notifications import get_background_tasks as notif_bg_tasks from .services.storage import FirestoreService, RedisService +logger = logging.getLogger(__name__) + @lru_cache(maxsize=1) def get_redis_service() -> RedisService: @@ -106,6 +112,12 @@ async def startup_services() -> None: async def shutdown_services() -> None: """Close all service connections on shutdown.""" + # Drain in-flight background tasks before closing connections + all_tasks = conv_bg_tasks() | notif_bg_tasks() + if all_tasks: + logger.info("Draining %d background tasks before shutdown…", len(all_tasks)) + await asyncio.gather(*all_tasks, return_exceptions=True) + # Close Redis redis = get_redis_service() await redis.close() diff --git a/src/capa_de_integracion/services/conversation.py b/src/capa_de_integracion/services/conversation.py index 009cdf8..9ce40fd 100644 --- a/src/capa_de_integracion/services/conversation.py +++ b/src/capa_de_integracion/services/conversation.py @@ -1,5 +1,6 @@ """Conversation manager service for orchestrating user conversations.""" +import asyncio import logging import re from datetime import UTC, datetime, timedelta @@ -22,6 +23,14 @@ from capa_de_integracion.services.storage.redis import RedisService logger = logging.getLogger(__name__) +# Keep references to background tasks to prevent garbage collection +_background_tasks: set[asyncio.Task[None]] = set() + + +def get_background_tasks() -> set[asyncio.Task[None]]: + """Return the set of pending background tasks (for graceful shutdown).""" + return _background_tasks + MSG_EMPTY_MESSAGE = "Message cannot be empty" @@ -88,16 +97,16 @@ class ConversationManagerService: # Step 1: Validate message is not empty self._validate_message(request.mensaje) - # Step 2: Apply DLP security - obfuscated_message = await self.dlp_service.get_obfuscated_string( - request.mensaje, - self.settings.dlp_template_complete_flow, + # Step 2+3: Apply DLP security and obtain session in parallel + telefono = request.usuario.telefono + obfuscated_message, session = await asyncio.gather( + self.dlp_service.get_obfuscated_string( + request.mensaje, + self.settings.dlp_template_complete_flow, + ), + self._obtain_or_create_session(telefono), ) request.mensaje = obfuscated_message - telefono = request.usuario.telefono - - # Step 3: Obtain or create session - session = await self._obtain_or_create_session(telefono) # Step 4: Try quick reply path first response = await self._handle_quick_reply_path(request, session) @@ -131,6 +140,8 @@ class ConversationManagerService: # Try Firestore if Redis miss session = await self.firestore_service.get_session_by_phone(telefono) if session: + # Cache to Redis for subsequent requests + await self.redis_service.save_session(session) return session # Create new session if both miss @@ -165,27 +176,31 @@ class ConversationManagerService: canal: Communication channel """ - # Save user entry + # Save user and assistant entries in parallel. + # Use a single timestamp for both, but offset the assistant entry by 1µs + # to avoid Firestore document ID collision (save_entry uses isoformat() + # as the document ID). + now = datetime.now(UTC) user_entry = ConversationEntry( entity="user", type=entry_type, - timestamp=datetime.now(UTC), + timestamp=now, text=user_text, parameters=None, canal=canal, ) - await self.firestore_service.save_entry(session_id, user_entry) - - # Save assistant entry assistant_entry = ConversationEntry( entity="assistant", type=entry_type, - timestamp=datetime.now(UTC), + timestamp=now + timedelta(microseconds=1), text=assistant_text, parameters=None, canal=canal, ) - await self.firestore_service.save_entry(session_id, assistant_entry) + await asyncio.gather( + self.firestore_service.save_entry(session_id, user_entry), + self.firestore_service.save_entry(session_id, assistant_entry), + ) async def _update_session_after_turn( self, @@ -204,8 +219,10 @@ class ConversationManagerService: """ session.last_message = last_message session.last_modified = datetime.now(UTC) - await self.firestore_service.save_session(session) - await self.redis_service.save_session(session) + await asyncio.gather( + self.firestore_service.save_session(session), + self.redis_service.save_session(session), + ) async def _handle_quick_reply_path( self, @@ -253,17 +270,25 @@ class ConversationManagerService: response.query_result.response_text if response.query_result else "" ) or "" - # Save conversation turn - await self._save_conversation_turn( - session_id=session.session_id, - user_text=request.mensaje, - assistant_text=response_text, - entry_type="CONVERSACION", - canal=getattr(request, "canal", None), - ) + # Fire-and-forget: persist conversation turn and update session + async def _post_response() -> None: + try: + await asyncio.gather( + self._save_conversation_turn( + session_id=session.session_id, + user_text=request.mensaje, + assistant_text=response_text, + entry_type="CONVERSACION", + canal=getattr(request, "canal", None), + ), + self._update_session_after_turn(session, response_text), + ) + except Exception: + logger.exception("Error in quick-reply post-response work") - # Update session - await self._update_session_after_turn(session, response_text) + task = asyncio.create_task(_post_response()) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) return response @@ -292,13 +317,19 @@ class ConversationManagerService: telefono, ) - # Load conversation history only if session is older than threshold - # (optimization: new/recent sessions don't need history context) + # Load conversation history and notifications in parallel session_age = datetime.now(UTC) - session.created_at - if session_age > timedelta(minutes=self.SESSION_RESET_THRESHOLD_MINUTES): - entries = await self.firestore_service.get_entries( - session.session_id, - limit=self.settings.conversation_context_message_limit, + load_history = session_age > timedelta( + minutes=self.SESSION_RESET_THRESHOLD_MINUTES, + ) + + if load_history: + entries, notifications = await asyncio.gather( + self.firestore_service.get_entries( + session.session_id, + limit=self.settings.conversation_context_message_limit, + ), + self._get_active_notifications(telefono), ) logger.info( "Session is %s minutes old. Loaded %s conversation entries.", @@ -307,13 +338,12 @@ class ConversationManagerService: ) else: entries = [] + notifications = await self._get_active_notifications(telefono) logger.info( "Session is only %s minutes old. Skipping history load.", session_age.total_seconds() / 60, ) - # Retrieve active notifications for this user - notifications = await self._get_active_notifications(telefono) logger.info("Retrieved %s active notifications", len(notifications)) # Prepare current user message @@ -344,27 +374,8 @@ class ConversationManagerService: assistant_response[:100], ) - # Save conversation turn - await self._save_conversation_turn( - session_id=session.session_id, - user_text=request.mensaje, - assistant_text=assistant_response, - entry_type="LLM", - canal=getattr(request, "canal", None), - ) - logger.info("Saved user message and assistant response to Firestore") - - # Update session - await self._update_session_after_turn(session, assistant_response) - logger.info("Updated session in Firestore and Redis") - - # Mark notifications as processed if any were included - if notifications: - await self._mark_notifications_as_processed(telefono) - logger.info("Marked %s notifications as processed", len(notifications)) - - # Return response object - return DetectIntentResponse( + # Build response object first, then fire-and-forget persistence + response = DetectIntentResponse( responseId=str(uuid4()), queryResult=QueryResult( responseText=assistant_response, @@ -373,6 +384,31 @@ class ConversationManagerService: quick_replies=None, ) + # Fire-and-forget: persist conversation and update session + async def _post_response() -> None: + try: + coros = [ + self._save_conversation_turn( + session_id=session.session_id, + user_text=request.mensaje, + assistant_text=assistant_response, + entry_type="LLM", + canal=getattr(request, "canal", None), + ), + self._update_session_after_turn(session, assistant_response), + ] + if notifications: + coros.append(self._mark_notifications_as_processed(telefono)) + await asyncio.gather(*coros) + except Exception: + logger.exception("Error in post-response background work") + + task = asyncio.create_task(_post_response()) + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) + + return response + def _is_pantalla_context_valid(self, last_modified: datetime) -> bool: """Check if pantallaContexto is still valid (not stale).""" time_diff = datetime.now(UTC) - last_modified diff --git a/src/capa_de_integracion/services/notifications.py b/src/capa_de_integracion/services/notifications.py index d3652fe..64d99bc 100644 --- a/src/capa_de_integracion/services/notifications.py +++ b/src/capa_de_integracion/services/notifications.py @@ -21,6 +21,11 @@ PREFIX_PO_PARAM = "notification_po_" _background_tasks: set[asyncio.Task] = set() +def get_background_tasks() -> set[asyncio.Task]: + """Return the set of pending background tasks (for graceful shutdown).""" + return _background_tasks + + class NotificationManagerService: """Manages notification processing and integration with conversations. diff --git a/src/capa_de_integracion/services/quick_reply/content.py b/src/capa_de_integracion/services/quick_reply/content.py index 415d789..1a4173d 100644 --- a/src/capa_de_integracion/services/quick_reply/content.py +++ b/src/capa_de_integracion/services/quick_reply/content.py @@ -149,13 +149,8 @@ class QuickReplyContentService: if quick_reply is None: logger.warning("Quick reply not found in cache for screen: %s", screen_id) - return QuickReplyScreen( - header=None, - body=None, - button=None, - header_section=None, - preguntas=[], - ) + msg = f"Quick reply not found for screen_id: {screen_id}" + raise ValueError(msg) logger.info( "Retrieved %s quick replies for screen: %s from cache", diff --git a/src/capa_de_integracion/services/quick_reply/session.py b/src/capa_de_integracion/services/quick_reply/session.py index 7a8c01c..325c954 100644 --- a/src/capa_de_integracion/services/quick_reply/session.py +++ b/src/capa_de_integracion/services/quick_reply/session.py @@ -1,6 +1,7 @@ """Quick reply session service for managing FAQ sessions.""" import logging +from datetime import UTC, datetime from uuid import uuid4 from capa_de_integracion.models.quick_replies import QuickReplyScreen @@ -99,6 +100,7 @@ class QuickReplySessionService: pantalla_contexto, ) session.pantalla_contexto = pantalla_contexto + session.last_modified = datetime.now(UTC) else: session_id = str(uuid4()) user_id = f"user_by_phone_{telefono.replace(' ', '').replace('-', '')}" diff --git a/src/capa_de_integracion/services/storage/firestore.py b/src/capa_de_integracion/services/storage/firestore.py index 3f49481..3ad2241 100644 --- a/src/capa_de_integracion/services/storage/firestore.py +++ b/src/capa_de_integracion/services/storage/firestore.py @@ -95,14 +95,13 @@ class FirestoreService: return session logger.debug("No session found in Firestore for phone: %s", telefono) + return None except Exception: logger.exception( "Error querying session by phone %s from Firestore:", telefono, ) return None - else: - return None async def save_session(self, session: ConversationSession) -> bool: """Save conversation session to Firestore.""" diff --git a/src/capa_de_integracion/services/storage/redis.py b/src/capa_de_integracion/services/storage/redis.py index 3868cf9..ac90b0f 100644 --- a/src/capa_de_integracion/services/storage/redis.py +++ b/src/capa_de_integracion/services/storage/redis.py @@ -104,12 +104,12 @@ class RedisService: phone_key = self._phone_to_session_key(session.telefono) try: - # Save session data + # Save session data and phone mapping in a single pipeline data = session.model_dump_json(by_alias=False) - await self.redis.setex(key, self.session_ttl, data) - - # Save phone-to-session mapping - await self.redis.setex(phone_key, self.session_ttl, session.session_id) + async with self.redis.pipeline(transaction=False) as pipe: + pipe.setex(key, self.session_ttl, data) + pipe.setex(phone_key, self.session_ttl, session.session_id) + await pipe.execute() logger.debug( "Saved session to Redis: %s for phone: %s", @@ -384,8 +384,10 @@ class RedisService: try: logger.info("Deleting notification session for phone %s", phone_number) - await self.redis.delete(notification_key) - await self.redis.delete(phone_key) + async with self.redis.pipeline(transaction=False) as pipe: + pipe.delete(notification_key) + pipe.delete(phone_key) + await pipe.execute() except Exception: logger.exception( "Error deleting notification session for phone %s:", diff --git a/tests/services/test_conversation_service.py b/tests/services/test_conversation_service.py index b20a5c4..6f573a8 100644 --- a/tests/services/test_conversation_service.py +++ b/tests/services/test_conversation_service.py @@ -1,5 +1,6 @@ """Unit tests for ConversationManagerService.""" +import asyncio from datetime import UTC, datetime, timedelta from typing import Literal from unittest.mock import AsyncMock, Mock, patch @@ -471,6 +472,7 @@ class TestQuickReplyPath: session=sample_session, ) + await asyncio.sleep(0.01) # Let fire-and-forget background tasks complete assert mock_firestore.save_entry.await_count == 2 @pytest.mark.asyncio @@ -499,6 +501,7 @@ class TestQuickReplyPath: session=sample_session, ) + await asyncio.sleep(0.01) # Let fire-and-forget background tasks complete mock_firestore.save_session.assert_awaited_once() mock_redis.save_session.assert_awaited_once() @@ -571,6 +574,7 @@ class TestStandardConversation: session=sample_session, ) + await asyncio.sleep(0.01) # Let fire-and-forget background tasks complete assert mock_firestore.save_entry.await_count == 2 @pytest.mark.asyncio @@ -588,6 +592,7 @@ class TestStandardConversation: session=sample_session, ) + await asyncio.sleep(0.01) # Let fire-and-forget background tasks complete # save_session is called in _update_session_after_turn assert mock_firestore.save_session.await_count >= 1 assert mock_redis.save_session.await_count >= 1 @@ -611,6 +616,7 @@ class TestStandardConversation: session=sample_session, ) + await asyncio.sleep(0.01) # Let fire-and-forget background tasks complete mock_firestore.update_notification_status.assert_awaited_once() @pytest.mark.asyncio diff --git a/tests/services/test_firestore_service.py b/tests/services/test_firestore_service.py index 931f603..40b64a9 100644 --- a/tests/services/test_firestore_service.py +++ b/tests/services/test_firestore_service.py @@ -120,10 +120,8 @@ class TestSessionManagement: mock_collection = MagicMock() mock_where = MagicMock() - mock_order = MagicMock() mock_collection.where.return_value = mock_where - mock_where.order_by.return_value = mock_order - mock_order.limit.return_value = mock_query + mock_where.limit.return_value = mock_query original_collection = clean_firestore.db.collection clean_firestore.db.collection = MagicMock(return_value=mock_collection) diff --git a/tests/services/test_quick_reply_session.py b/tests/services/test_quick_reply_session.py index 10a272e..4facaa5 100644 --- a/tests/services/test_quick_reply_session.py +++ b/tests/services/test_quick_reply_session.py @@ -1,5 +1,6 @@ """Tests for QuickReplySessionService.""" +from datetime import UTC, datetime, timedelta from unittest.mock import AsyncMock, Mock from uuid import uuid4 @@ -160,6 +161,39 @@ async def test_start_session_existing_user(service, mock_firestore, mock_redis, mock_content.get_quick_replies.assert_called_once_with("pagos") +@pytest.mark.asyncio +async def test_start_session_updates_last_modified_on_existing( + service, mock_firestore, mock_redis, mock_content +): + """Test that last_modified is refreshed when updating pantalla_contexto. + + Ensures quick reply context won't be incorrectly marked as stale + when the session was idle before the user opened a quick reply screen. + """ + stale_time = datetime.now(UTC) - timedelta(minutes=20) + test_session = ConversationSession.create( + session_id="session-123", + user_id="user_by_phone_5551234", + telefono="555-1234", + pantalla_contexto=None, + ) + test_session.last_modified = stale_time + + mock_redis.get_session.return_value = test_session + mock_content.get_quick_replies.return_value = QuickReplyScreen( + header="H", body=None, button=None, header_section=None, preguntas=[] + ) + + await service.start_quick_reply_session( + telefono="555-1234", + _nombre="John", + pantalla_contexto="pagos", + ) + + saved_session = mock_redis.save_session.call_args[0][0] + assert saved_session.last_modified > stale_time + + @pytest.mark.asyncio async def test_start_session_invalid_phone(service): """Test starting session with invalid phone number.""" diff --git a/tests/test_main.py b/tests/test_main.py index 47b48ff..22b4837 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -47,14 +47,15 @@ def test_app_has_routers(): def test_main_entry_point(): """Test main entry point calls uvicorn.run.""" - with patch("capa_de_integracion.main.uvicorn.run") as mock_run: + with patch("capa_de_integracion.main.uvicorn.run") as mock_run, \ + patch("sys.argv", ["capa-de-integracion"]): main() mock_run.assert_called_once() call_kwargs = mock_run.call_args.kwargs assert call_kwargs["host"] == "0.0.0.0" assert call_kwargs["port"] == 8080 - assert call_kwargs["reload"] is True + assert call_kwargs["workers"] == 1 @pytest.mark.asyncio