Optimization

This commit is contained in:
2026-02-20 15:59:19 +00:00
parent ade4689ab7
commit 383efed319
12 changed files with 168 additions and 78 deletions

View File

@@ -74,7 +74,7 @@ filterwarnings = [
] ]
env = [ env = [
"FIRESTORE_EMULATOR_HOST=[::1]:8469", "FIRESTORE_EMULATOR_HOST=[::1]:8462",
"GCP_PROJECT_ID=test-project", "GCP_PROJECT_ID=test-project",
"GCP_LOCATION=us-central1", "GCP_LOCATION=us-central1",
"GCP_FIRESTORE_DATABASE_ID=(default)", "GCP_FIRESTORE_DATABASE_ID=(default)",

View File

@@ -1,5 +1,7 @@
"""Dependency injection and service lifecycle management.""" """Dependency injection and service lifecycle management."""
import asyncio
import logging
from functools import lru_cache from functools import lru_cache
from capa_de_integracion.services.rag import ( from capa_de_integracion.services.rag import (
@@ -16,8 +18,12 @@ from .services import (
QuickReplyContentService, QuickReplyContentService,
QuickReplySessionService, QuickReplySessionService,
) )
from .services.conversation import get_background_tasks as conv_bg_tasks
from .services.notifications import get_background_tasks as notif_bg_tasks
from .services.storage import FirestoreService, RedisService from .services.storage import FirestoreService, RedisService
logger = logging.getLogger(__name__)
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def get_redis_service() -> RedisService: def get_redis_service() -> RedisService:
@@ -106,6 +112,12 @@ async def startup_services() -> None:
async def shutdown_services() -> None: async def shutdown_services() -> None:
"""Close all service connections on shutdown.""" """Close all service connections on shutdown."""
# Drain in-flight background tasks before closing connections
all_tasks = conv_bg_tasks() | notif_bg_tasks()
if all_tasks:
logger.info("Draining %d background tasks before shutdown…", len(all_tasks))
await asyncio.gather(*all_tasks, return_exceptions=True)
# Close Redis # Close Redis
redis = get_redis_service() redis = get_redis_service()
await redis.close() await redis.close()

View File

@@ -1,5 +1,6 @@
"""Conversation manager service for orchestrating user conversations.""" """Conversation manager service for orchestrating user conversations."""
import asyncio
import logging import logging
import re import re
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
@@ -22,6 +23,14 @@ from capa_de_integracion.services.storage.redis import RedisService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Keep strong references to fire-and-forget tasks. asyncio itself only holds
# weak references to tasks, so a task nobody else references could be
# garbage-collected before it finishes. Tasks created in this module are added
# here and discard themselves via a done-callback when they complete.
_background_tasks: set[asyncio.Task[None]] = set()


def get_background_tasks() -> set[asyncio.Task[None]]:
    """Return the set of pending background tasks (for graceful shutdown).

    The returned object is the live module-level registry, not a copy, so its
    contents change as tasks are scheduled and completed. Callers that need a
    stable view while awaiting the tasks should build their own collection
    from it first.

    Returns:
        The module-level set of background tasks that are still tracked.
    """
    return _background_tasks
MSG_EMPTY_MESSAGE = "Message cannot be empty" MSG_EMPTY_MESSAGE = "Message cannot be empty"
@@ -88,16 +97,16 @@ class ConversationManagerService:
# Step 1: Validate message is not empty # Step 1: Validate message is not empty
self._validate_message(request.mensaje) self._validate_message(request.mensaje)
# Step 2: Apply DLP security # Step 2+3: Apply DLP security and obtain session in parallel
obfuscated_message = await self.dlp_service.get_obfuscated_string( telefono = request.usuario.telefono
request.mensaje, obfuscated_message, session = await asyncio.gather(
self.settings.dlp_template_complete_flow, self.dlp_service.get_obfuscated_string(
request.mensaje,
self.settings.dlp_template_complete_flow,
),
self._obtain_or_create_session(telefono),
) )
request.mensaje = obfuscated_message request.mensaje = obfuscated_message
telefono = request.usuario.telefono
# Step 3: Obtain or create session
session = await self._obtain_or_create_session(telefono)
# Step 4: Try quick reply path first # Step 4: Try quick reply path first
response = await self._handle_quick_reply_path(request, session) response = await self._handle_quick_reply_path(request, session)
@@ -131,6 +140,8 @@ class ConversationManagerService:
# Try Firestore if Redis miss # Try Firestore if Redis miss
session = await self.firestore_service.get_session_by_phone(telefono) session = await self.firestore_service.get_session_by_phone(telefono)
if session: if session:
# Cache to Redis for subsequent requests
await self.redis_service.save_session(session)
return session return session
# Create new session if both miss # Create new session if both miss
@@ -165,27 +176,31 @@ class ConversationManagerService:
canal: Communication channel canal: Communication channel
""" """
# Save user entry # Save user and assistant entries in parallel.
# Use a single timestamp for both, but offset the assistant entry by 1µs
# to avoid Firestore document ID collision (save_entry uses isoformat()
# as the document ID).
now = datetime.now(UTC)
user_entry = ConversationEntry( user_entry = ConversationEntry(
entity="user", entity="user",
type=entry_type, type=entry_type,
timestamp=datetime.now(UTC), timestamp=now,
text=user_text, text=user_text,
parameters=None, parameters=None,
canal=canal, canal=canal,
) )
await self.firestore_service.save_entry(session_id, user_entry)
# Save assistant entry
assistant_entry = ConversationEntry( assistant_entry = ConversationEntry(
entity="assistant", entity="assistant",
type=entry_type, type=entry_type,
timestamp=datetime.now(UTC), timestamp=now + timedelta(microseconds=1),
text=assistant_text, text=assistant_text,
parameters=None, parameters=None,
canal=canal, canal=canal,
) )
await self.firestore_service.save_entry(session_id, assistant_entry) await asyncio.gather(
self.firestore_service.save_entry(session_id, user_entry),
self.firestore_service.save_entry(session_id, assistant_entry),
)
async def _update_session_after_turn( async def _update_session_after_turn(
self, self,
@@ -204,8 +219,10 @@ class ConversationManagerService:
""" """
session.last_message = last_message session.last_message = last_message
session.last_modified = datetime.now(UTC) session.last_modified = datetime.now(UTC)
await self.firestore_service.save_session(session) await asyncio.gather(
await self.redis_service.save_session(session) self.firestore_service.save_session(session),
self.redis_service.save_session(session),
)
async def _handle_quick_reply_path( async def _handle_quick_reply_path(
self, self,
@@ -253,17 +270,25 @@ class ConversationManagerService:
response.query_result.response_text if response.query_result else "" response.query_result.response_text if response.query_result else ""
) or "" ) or ""
# Save conversation turn # Fire-and-forget: persist conversation turn and update session
await self._save_conversation_turn( async def _post_response() -> None:
session_id=session.session_id, try:
user_text=request.mensaje, await asyncio.gather(
assistant_text=response_text, self._save_conversation_turn(
entry_type="CONVERSACION", session_id=session.session_id,
canal=getattr(request, "canal", None), user_text=request.mensaje,
) assistant_text=response_text,
entry_type="CONVERSACION",
canal=getattr(request, "canal", None),
),
self._update_session_after_turn(session, response_text),
)
except Exception:
logger.exception("Error in quick-reply post-response work")
# Update session task = asyncio.create_task(_post_response())
await self._update_session_after_turn(session, response_text) _background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
return response return response
@@ -292,13 +317,19 @@ class ConversationManagerService:
telefono, telefono,
) )
# Load conversation history only if session is older than threshold # Load conversation history and notifications in parallel
# (optimization: new/recent sessions don't need history context)
session_age = datetime.now(UTC) - session.created_at session_age = datetime.now(UTC) - session.created_at
if session_age > timedelta(minutes=self.SESSION_RESET_THRESHOLD_MINUTES): load_history = session_age > timedelta(
entries = await self.firestore_service.get_entries( minutes=self.SESSION_RESET_THRESHOLD_MINUTES,
session.session_id, )
limit=self.settings.conversation_context_message_limit,
if load_history:
entries, notifications = await asyncio.gather(
self.firestore_service.get_entries(
session.session_id,
limit=self.settings.conversation_context_message_limit,
),
self._get_active_notifications(telefono),
) )
logger.info( logger.info(
"Session is %s minutes old. Loaded %s conversation entries.", "Session is %s minutes old. Loaded %s conversation entries.",
@@ -307,13 +338,12 @@ class ConversationManagerService:
) )
else: else:
entries = [] entries = []
notifications = await self._get_active_notifications(telefono)
logger.info( logger.info(
"Session is only %s minutes old. Skipping history load.", "Session is only %s minutes old. Skipping history load.",
session_age.total_seconds() / 60, session_age.total_seconds() / 60,
) )
# Retrieve active notifications for this user
notifications = await self._get_active_notifications(telefono)
logger.info("Retrieved %s active notifications", len(notifications)) logger.info("Retrieved %s active notifications", len(notifications))
# Prepare current user message # Prepare current user message
@@ -344,27 +374,8 @@ class ConversationManagerService:
assistant_response[:100], assistant_response[:100],
) )
# Save conversation turn # Build response object first, then fire-and-forget persistence
await self._save_conversation_turn( response = DetectIntentResponse(
session_id=session.session_id,
user_text=request.mensaje,
assistant_text=assistant_response,
entry_type="LLM",
canal=getattr(request, "canal", None),
)
logger.info("Saved user message and assistant response to Firestore")
# Update session
await self._update_session_after_turn(session, assistant_response)
logger.info("Updated session in Firestore and Redis")
# Mark notifications as processed if any were included
if notifications:
await self._mark_notifications_as_processed(telefono)
logger.info("Marked %s notifications as processed", len(notifications))
# Return response object
return DetectIntentResponse(
responseId=str(uuid4()), responseId=str(uuid4()),
queryResult=QueryResult( queryResult=QueryResult(
responseText=assistant_response, responseText=assistant_response,
@@ -373,6 +384,31 @@ class ConversationManagerService:
quick_replies=None, quick_replies=None,
) )
# Fire-and-forget: persist conversation and update session
async def _post_response() -> None:
try:
coros = [
self._save_conversation_turn(
session_id=session.session_id,
user_text=request.mensaje,
assistant_text=assistant_response,
entry_type="LLM",
canal=getattr(request, "canal", None),
),
self._update_session_after_turn(session, assistant_response),
]
if notifications:
coros.append(self._mark_notifications_as_processed(telefono))
await asyncio.gather(*coros)
except Exception:
logger.exception("Error in post-response background work")
task = asyncio.create_task(_post_response())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
return response
def _is_pantalla_context_valid(self, last_modified: datetime) -> bool: def _is_pantalla_context_valid(self, last_modified: datetime) -> bool:
"""Check if pantallaContexto is still valid (not stale).""" """Check if pantallaContexto is still valid (not stale)."""
time_diff = datetime.now(UTC) - last_modified time_diff = datetime.now(UTC) - last_modified

View File

@@ -21,6 +21,11 @@ PREFIX_PO_PARAM = "notification_po_"
_background_tasks: set[asyncio.Task] = set() _background_tasks: set[asyncio.Task] = set()
def get_background_tasks() -> set[asyncio.Task]:
    """Return the set of pending background tasks (for graceful shutdown).

    The module-level ``_background_tasks`` set is returned as-is (the same
    object, not a copy), so callers observe later additions and removals.

    Returns:
        The module-level set of tracked background tasks.
    """
    return _background_tasks
class NotificationManagerService: class NotificationManagerService:
"""Manages notification processing and integration with conversations. """Manages notification processing and integration with conversations.

View File

@@ -149,13 +149,8 @@ class QuickReplyContentService:
if quick_reply is None: if quick_reply is None:
logger.warning("Quick reply not found in cache for screen: %s", screen_id) logger.warning("Quick reply not found in cache for screen: %s", screen_id)
return QuickReplyScreen( msg = f"Quick reply not found for screen_id: {screen_id}"
header=None, raise ValueError(msg)
body=None,
button=None,
header_section=None,
preguntas=[],
)
logger.info( logger.info(
"Retrieved %s quick replies for screen: %s from cache", "Retrieved %s quick replies for screen: %s from cache",

View File

@@ -1,6 +1,7 @@
"""Quick reply session service for managing FAQ sessions.""" """Quick reply session service for managing FAQ sessions."""
import logging import logging
from datetime import UTC, datetime
from uuid import uuid4 from uuid import uuid4
from capa_de_integracion.models.quick_replies import QuickReplyScreen from capa_de_integracion.models.quick_replies import QuickReplyScreen
@@ -99,6 +100,7 @@ class QuickReplySessionService:
pantalla_contexto, pantalla_contexto,
) )
session.pantalla_contexto = pantalla_contexto session.pantalla_contexto = pantalla_contexto
session.last_modified = datetime.now(UTC)
else: else:
session_id = str(uuid4()) session_id = str(uuid4())
user_id = f"user_by_phone_{telefono.replace(' ', '').replace('-', '')}" user_id = f"user_by_phone_{telefono.replace(' ', '').replace('-', '')}"

View File

@@ -95,14 +95,13 @@ class FirestoreService:
return session return session
logger.debug("No session found in Firestore for phone: %s", telefono) logger.debug("No session found in Firestore for phone: %s", telefono)
return None
except Exception: except Exception:
logger.exception( logger.exception(
"Error querying session by phone %s from Firestore:", "Error querying session by phone %s from Firestore:",
telefono, telefono,
) )
return None return None
else:
return None
async def save_session(self, session: ConversationSession) -> bool: async def save_session(self, session: ConversationSession) -> bool:
"""Save conversation session to Firestore.""" """Save conversation session to Firestore."""

View File

@@ -104,12 +104,12 @@ class RedisService:
phone_key = self._phone_to_session_key(session.telefono) phone_key = self._phone_to_session_key(session.telefono)
try: try:
# Save session data # Save session data and phone mapping in a single pipeline
data = session.model_dump_json(by_alias=False) data = session.model_dump_json(by_alias=False)
await self.redis.setex(key, self.session_ttl, data) async with self.redis.pipeline(transaction=False) as pipe:
pipe.setex(key, self.session_ttl, data)
# Save phone-to-session mapping pipe.setex(phone_key, self.session_ttl, session.session_id)
await self.redis.setex(phone_key, self.session_ttl, session.session_id) await pipe.execute()
logger.debug( logger.debug(
"Saved session to Redis: %s for phone: %s", "Saved session to Redis: %s for phone: %s",
@@ -384,8 +384,10 @@ class RedisService:
try: try:
logger.info("Deleting notification session for phone %s", phone_number) logger.info("Deleting notification session for phone %s", phone_number)
await self.redis.delete(notification_key) async with self.redis.pipeline(transaction=False) as pipe:
await self.redis.delete(phone_key) pipe.delete(notification_key)
pipe.delete(phone_key)
await pipe.execute()
except Exception: except Exception:
logger.exception( logger.exception(
"Error deleting notification session for phone %s:", "Error deleting notification session for phone %s:",

View File

@@ -1,5 +1,6 @@
"""Unit tests for ConversationManagerService.""" """Unit tests for ConversationManagerService."""
import asyncio
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from typing import Literal from typing import Literal
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
@@ -471,6 +472,7 @@ class TestQuickReplyPath:
session=sample_session, session=sample_session,
) )
await asyncio.sleep(0.01) # Let fire-and-forget background tasks complete
assert mock_firestore.save_entry.await_count == 2 assert mock_firestore.save_entry.await_count == 2
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -499,6 +501,7 @@ class TestQuickReplyPath:
session=sample_session, session=sample_session,
) )
await asyncio.sleep(0.01) # Let fire-and-forget background tasks complete
mock_firestore.save_session.assert_awaited_once() mock_firestore.save_session.assert_awaited_once()
mock_redis.save_session.assert_awaited_once() mock_redis.save_session.assert_awaited_once()
@@ -571,6 +574,7 @@ class TestStandardConversation:
session=sample_session, session=sample_session,
) )
await asyncio.sleep(0.01) # Let fire-and-forget background tasks complete
assert mock_firestore.save_entry.await_count == 2 assert mock_firestore.save_entry.await_count == 2
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -588,6 +592,7 @@ class TestStandardConversation:
session=sample_session, session=sample_session,
) )
await asyncio.sleep(0.01) # Let fire-and-forget background tasks complete
# save_session is called in _update_session_after_turn # save_session is called in _update_session_after_turn
assert mock_firestore.save_session.await_count >= 1 assert mock_firestore.save_session.await_count >= 1
assert mock_redis.save_session.await_count >= 1 assert mock_redis.save_session.await_count >= 1
@@ -611,6 +616,7 @@ class TestStandardConversation:
session=sample_session, session=sample_session,
) )
await asyncio.sleep(0.01) # Let fire-and-forget background tasks complete
mock_firestore.update_notification_status.assert_awaited_once() mock_firestore.update_notification_status.assert_awaited_once()
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -120,10 +120,8 @@ class TestSessionManagement:
mock_collection = MagicMock() mock_collection = MagicMock()
mock_where = MagicMock() mock_where = MagicMock()
mock_order = MagicMock()
mock_collection.where.return_value = mock_where mock_collection.where.return_value = mock_where
mock_where.order_by.return_value = mock_order mock_where.limit.return_value = mock_query
mock_order.limit.return_value = mock_query
original_collection = clean_firestore.db.collection original_collection = clean_firestore.db.collection
clean_firestore.db.collection = MagicMock(return_value=mock_collection) clean_firestore.db.collection = MagicMock(return_value=mock_collection)

View File

@@ -1,5 +1,6 @@
"""Tests for QuickReplySessionService.""" """Tests for QuickReplySessionService."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, Mock from unittest.mock import AsyncMock, Mock
from uuid import uuid4 from uuid import uuid4
@@ -160,6 +161,39 @@ async def test_start_session_existing_user(service, mock_firestore, mock_redis,
mock_content.get_quick_replies.assert_called_once_with("pagos") mock_content.get_quick_replies.assert_called_once_with("pagos")
@pytest.mark.asyncio
async def test_start_session_updates_last_modified_on_existing(
    service, mock_firestore, mock_redis, mock_content
):
    """Test that last_modified is refreshed when updating pantalla_contexto.

    Ensures quick reply context won't be incorrectly marked as stale
    when the session was idle before the user opened a quick reply screen.
    """
    # Arrange: an existing session whose last activity is 20 minutes old.
    stale_time = datetime.now(UTC) - timedelta(minutes=20)
    test_session = ConversationSession.create(
        session_id="session-123",
        user_id="user_by_phone_5551234",
        telefono="555-1234",
        pantalla_contexto=None,
    )
    test_session.last_modified = stale_time
    mock_redis.get_session.return_value = test_session
    mock_content.get_quick_replies.return_value = QuickReplyScreen(
        header="H", body=None, button=None, header_section=None, preguntas=[]
    )

    # Act: open a quick reply screen for that user.
    await service.start_quick_reply_session(
        telefono="555-1234",
        _nombre="John",
        pantalla_contexto="pagos",
    )

    # Assert: the session handed to Redis carries a timestamp newer than the
    # stale one, i.e. last_modified was refreshed before saving.
    saved_session = mock_redis.save_session.call_args.args[0]
    assert saved_session.last_modified > stale_time
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_start_session_invalid_phone(service): async def test_start_session_invalid_phone(service):
"""Test starting session with invalid phone number.""" """Test starting session with invalid phone number."""

View File

@@ -47,14 +47,15 @@ def test_app_has_routers():
def test_main_entry_point(): def test_main_entry_point():
"""Test main entry point calls uvicorn.run.""" """Test main entry point calls uvicorn.run."""
with patch("capa_de_integracion.main.uvicorn.run") as mock_run: with patch("capa_de_integracion.main.uvicorn.run") as mock_run, \
patch("sys.argv", ["capa-de-integracion"]):
main() main()
mock_run.assert_called_once() mock_run.assert_called_once()
call_kwargs = mock_run.call_args.kwargs call_kwargs = mock_run.call_args.kwargs
assert call_kwargs["host"] == "0.0.0.0" assert call_kwargs["host"] == "0.0.0.0"
assert call_kwargs["port"] == 8080 assert call_kwargs["port"] == 8080
assert call_kwargs["reload"] is True assert call_kwargs["workers"] == 1
@pytest.mark.asyncio @pytest.mark.asyncio