"""Async client for the RAG backend, built to serve many requests concurrently."""

import logging
from types import TracebackType
from typing import Self

import httpx
from pydantic import BaseModel, Field

from capa_de_integracion.config import Settings

logger = logging.getLogger(__name__)


# NOTE: pydantic copies a model's class docstring into the `description` of
# its JSON schema, so the model docstrings below are left exactly as they were.
class Message(BaseModel):
    """OpenAI-style message format."""

    role: str = Field(..., description="Role: system, user, or assistant")
    content: str = Field(..., description="Message content")


class RAGRequest(BaseModel):
    """Request model for RAG endpoint."""

    messages: list[Message] = Field(..., description="Conversation history")


class RAGResponse(BaseModel):
    """Response model from RAG endpoint."""

    response: str = Field(..., description="Generated response from RAG")


class RAGService:
    """Async HTTP client that forwards conversations to the RAG endpoint.

    One pooled ``httpx.AsyncClient`` is created per instance and shared by
    every call to :meth:`query`. Concurrent queries therefore reuse pooled
    connections. Release the pool with :meth:`close`, or use the instance as
    an ``async with`` context manager, which calls :meth:`close` on exit.
    """

    def __init__(
        self,
        settings: Settings,
        max_connections: int = 100,
        max_keepalive_connections: int = 20,
        timeout: float = 30.0,
    ) -> None:
        """Build the pooled HTTP client used for all RAG queries.

        Args:
            settings: Application settings. ``settings.rag_endpoint_url`` is
                the URL every query is POSTed to.
            max_connections: Upper bound on simultaneously open connections.
            max_keepalive_connections: Upper bound on idle connections kept
                open for reuse.
            timeout: Seconds passed to ``httpx.Timeout``. A single value
                applies to every phase (connect, read, write and pool).
        """
        self.settings = settings
        self.rag_endpoint_url = settings.rag_endpoint_url
        self.timeout = timeout

        pool_limits = httpx.Limits(
            max_keepalive_connections=max_keepalive_connections,
            max_connections=max_connections,
        )
        # httpx only negotiates HTTP/2 when the optional ``h2`` package is
        # installed (``httpx[http2]``). Without it, constructing the client
        # raises ImportError.
        self._client = httpx.AsyncClient(
            http2=True,
            timeout=httpx.Timeout(timeout),
            limits=pool_limits,
        )

        logger.info(
            "RAGService initialized with endpoint: %s, max_connections: %s, "
            "timeout: %ss",
            self.rag_endpoint_url,
            max_connections,
            timeout,
        )

    async def query(self, messages: list[dict[str, str]]) -> str:
        """POST a conversation to the RAG endpoint and return its answer.

        Args:
            messages: Conversation history as OpenAI-style dicts, e.g.
                ``[{"role": "user", "content": "Hello"}, ...]``. Each dict is
                expanded into ``Message`` before sending.

        Returns:
            The ``response`` field of the endpoint's JSON body.

        Raises:
            httpx.HTTPStatusError: The endpoint answered with a 4xx/5xx status.
            httpx.RequestError: The request failed at the transport level.
            ValueError: A message or the response body failed pydantic
                validation, or the body was not valid JSON.

        Every failure, including any other exception, is logged with its
        traceback and then re-raised unchanged.
        """
        try:
            payload = RAGRequest(
                messages=[Message(**entry) for entry in messages],
            )

            logger.debug("Sending RAG request with %s messages", len(messages))
            http_response = await self._client.post(
                self.rag_endpoint_url,
                json=payload.model_dump(),
                headers={"Content-Type": "application/json"},
            )
            # Convert 4xx/5xx statuses into httpx.HTTPStatusError.
            http_response.raise_for_status()

            body = http_response.json()
            parsed = RAGResponse(**body)
            logger.debug("RAG response received: %s chars", len(parsed.response))
        except httpx.HTTPStatusError as err:
            logger.exception(
                "HTTP error calling RAG endpoint: %s - %s",
                err.response.status_code,
                err.response.text,
            )
            raise
        except httpx.RequestError:
            logger.exception("Request error calling RAG endpoint:")
            raise
        except Exception:
            logger.exception("Unexpected error calling RAG endpoint")
            raise

        # Every handler above re-raises, so reaching this line means success.
        return parsed.response

    async def close(self) -> None:
        """Shut down the pooled client and free its connections."""
        await self._client.aclose()
        logger.info("RAGService client closed")

    async def __aenter__(self) -> Self:
        """Enter ``async with`` and hand back this same service instance."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Close the client when leaving ``async with``.

        Returns ``None``, so an exception raised in the body is never
        suppressed.
        """
        await self.close()