152 lines
4.6 KiB
Python
152 lines
4.6 KiB
Python
"""RAG service for calling RAG endpoints with high concurrency."""
|
|
|
|
import logging
|
|
from types import TracebackType
|
|
from typing import Self
|
|
|
|
import httpx
|
|
from pydantic import BaseModel, Field
|
|
|
|
from capa_de_integracion.config import Settings
|
|
|
|
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class Message(BaseModel):
    """OpenAI-style message format."""

    # Author of the message. Any string is accepted; the description only
    # lists the conventional values. A Field without a default is required.
    role: str = Field(description="Role: system, user, or assistant")
    # Text body of the message (required).
    content: str = Field(description="Message content")
|
|
|
|
|
|
class RAGRequest(BaseModel):
    """Request model for RAG endpoint."""

    # Conversation history sent to the endpoint; RAGService.query serializes
    # this model with model_dump() and posts it as the JSON body. Required.
    messages: list[Message] = Field(description="Conversation history")
|
|
|
|
|
|
class RAGResponse(BaseModel):
    """Response model from RAG endpoint."""

    # Generated answer text; this is the value RAGService.query returns. Required.
    response: str = Field(description="Generated response from RAG")
|
|
|
|
|
|
class RAGService:
    """Highly concurrent HTTP client for calling RAG endpoints.

    Uses httpx AsyncClient with connection pooling for optimal performance
    when handling multiple concurrent requests. Each instance owns an open
    HTTP client: release it with ``close()`` or by using the service as an
    async context manager (``async with RAGService(settings) as rag: ...``).
    """

    def __init__(
        self,
        settings: Settings,
        max_connections: int = 100,
        max_keepalive_connections: int = 20,
        timeout: float = 30.0,
        http2: bool = True,
    ) -> None:
        """Initialize RAG service with connection pooling.

        Args:
            settings: Application settings. ``settings.rag_endpoint_url`` is
                the URL every ``query()`` call POSTs to.
            max_connections: Maximum number of concurrent connections.
            max_keepalive_connections: Maximum number of idle connections to
                keep alive.
            timeout: Request timeout in seconds. A single value is passed to
                ``httpx.Timeout``, so it applies to connect, read, write and
                pool acquisition alike.
            http2: Whether the client may negotiate HTTP/2 (default ``True``,
                matching previous behavior). HTTP/2 needs the optional ``h2``
                package (``httpx[http2]``). When it is missing, httpx raises
                ``ImportError`` here; pass ``False`` to use HTTP/1.1 only.

        """
        self.settings = settings
        self.rag_endpoint_url = settings.rag_endpoint_url
        self.timeout = timeout

        # Configure connection limits for high concurrency.
        limits = httpx.Limits(
            max_connections=max_connections,
            max_keepalive_connections=max_keepalive_connections,
        )

        # One pooled client is shared by every query() call, so connections
        # are reused rather than re-established per request.
        self._client = httpx.AsyncClient(
            limits=limits,
            timeout=httpx.Timeout(timeout),
            http2=http2,
        )

        logger.info(
            "RAGService initialized with endpoint: %s, "
            "max_connections: %s, timeout: %ss",
            self.rag_endpoint_url,
            max_connections,
            timeout,
        )

    async def query(self, messages: list[dict[str, str]]) -> str:
        """Send conversation history to RAG endpoint and get response.

        Args:
            messages: OpenAI-style conversation history,
                e.g. ``[{"role": "user", "content": "Hello"}, ...]``.

        Returns:
            The ``response`` field of the endpoint's JSON reply.

        Raises:
            httpx.HTTPStatusError: If the endpoint answers with a 4xx/5xx
                status (logged, then re-raised).
            httpx.RequestError: If the request could not be sent or completed,
                e.g. connection failure or timeout (logged, then re-raised).
            ValueError: If ``messages`` is malformed (raised before any network
                I/O), or if the response body is not JSON or is not an object
                with a string ``response`` field.

        """
        # Validate caller input before touching the network. This is kept out
        # of the try below so bad input is not logged as an endpoint failure.
        # model_validate raises pydantic's ValidationError (a ValueError) for
        # every malformed item, including non-mapping ones, which
        # Message(**msg) would have rejected with a TypeError.
        request = RAGRequest(
            messages=[Message.model_validate(msg) for msg in messages],
        )

        try:
            logger.debug("Sending RAG request with %s messages", len(messages))

            response = await self._client.post(
                self.rag_endpoint_url,
                json=request.model_dump(),
                headers={"Content-Type": "application/json"},
            )

            # Turn 4xx/5xx responses into httpx.HTTPStatusError.
            response.raise_for_status()

            # response.json() raises json.JSONDecodeError (a ValueError) for a
            # non-JSON body. model_validate raises ValidationError (also a
            # ValueError) for any body that is not an object with a string
            # "response" field, JSON arrays and scalars included;
            # RAGResponse(**data) raised TypeError for those.
            rag_response = RAGResponse.model_validate(response.json())

            logger.debug("RAG response received: %s chars", len(rag_response.response))
        except httpx.HTTPStatusError as e:
            logger.exception(
                "HTTP error calling RAG endpoint: %s - %s",
                e.response.status_code,
                e.response.text,
            )
            raise
        except httpx.RequestError:
            logger.exception("Request error calling RAG endpoint:")
            raise
        except Exception:
            logger.exception("Unexpected error calling RAG endpoint")
            raise
        else:
            return rag_response.response

    async def close(self) -> None:
        """Close the HTTP client and release its pooled connections."""
        await self._client.aclose()
        logger.info("RAGService client closed")

    async def __aenter__(self) -> Self:
        """Enter the async context and return this service unchanged."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit the async context by closing the HTTP client.

        Returns ``None``, so any exception raised inside the ``async with``
        body is propagated, not suppressed.
        """
        await self.close()
|