Add test coverage

This commit is contained in:
2026-02-20 05:23:43 +00:00
parent c01f4d1ab3
commit 59a76fc226
3 changed files with 61 additions and 151 deletions

View File

@@ -1,151 +0,0 @@
"""RAG service for calling RAG endpoints with high concurrency."""
import logging
from types import TracebackType
from typing import Self
import httpx
from pydantic import BaseModel, Field
from capa_de_integracion.config import Settings
logger = logging.getLogger(__name__)
class Message(BaseModel):
"""OpenAI-style message format."""
role: str = Field(..., description="Role: system, user, or assistant")
content: str = Field(..., description="Message content")
class RAGRequest(BaseModel):
"""Request model for RAG endpoint."""
messages: list[Message] = Field(..., description="Conversation history")
class RAGResponse(BaseModel):
"""Response model from RAG endpoint."""
response: str = Field(..., description="Generated response from RAG")
class RAGService:
"""Highly concurrent HTTP client for calling RAG endpoints.
Uses httpx AsyncClient with connection pooling for optimal performance
when handling multiple concurrent requests.
"""
def __init__(
self,
settings: Settings,
max_connections: int = 100,
max_keepalive_connections: int = 20,
timeout: float = 30.0,
) -> None:
"""Initialize RAG service with connection pooling.
Args:
settings: Application settings
max_connections: Maximum number of concurrent connections
max_keepalive_connections: Maximum number of idle connections to keep alive
timeout: Request timeout in seconds
"""
self.settings = settings
self.rag_endpoint_url = settings.rag_endpoint_url
self.timeout = timeout
# Configure connection limits for high concurrency
limits = httpx.Limits(
max_connections=max_connections,
max_keepalive_connections=max_keepalive_connections,
)
# Create async client with connection pooling
self._client = httpx.AsyncClient(
limits=limits,
timeout=httpx.Timeout(timeout),
http2=True, # Enable HTTP/2 for better performance
)
logger.info(
"RAGService initialized with endpoint: %s, "
"max_connections: %s, timeout: %ss",
self.rag_endpoint_url,
max_connections,
timeout,
)
async def query(self, messages: list[dict[str, str]]) -> str:
"""Send conversation history to RAG endpoint and get response.
Args:
messages: OpenAI-style conversation history
e.g., [{"role": "user", "content": "Hello"}, ...]
Returns:
Response string from RAG endpoint
Raises:
httpx.HTTPError: If HTTP request fails
ValueError: If response format is invalid
"""
try:
# Validate and construct request
message_objects = [Message(**msg) for msg in messages]
request = RAGRequest(messages=message_objects)
# Make async HTTP POST request
logger.debug("Sending RAG request with %s messages", len(messages))
response = await self._client.post(
self.rag_endpoint_url,
json=request.model_dump(),
headers={"Content-Type": "application/json"},
)
# Raise exception for HTTP errors
response.raise_for_status()
# Parse response
response_data = response.json()
rag_response = RAGResponse(**response_data)
logger.debug("RAG response received: %s chars", len(rag_response.response))
except httpx.HTTPStatusError as e:
logger.exception(
"HTTP error calling RAG endpoint: %s - %s",
e.response.status_code,
e.response.text,
)
raise
except httpx.RequestError:
logger.exception("Request error calling RAG endpoint:")
raise
except Exception:
logger.exception("Unexpected error calling RAG endpoint")
raise
else:
return rag_response.response
async def close(self) -> None:
"""Close the HTTP client and release connections."""
await self._client.aclose()
logger.info("RAGService client closed")
async def __aenter__(self) -> Self:
"""Async context manager entry."""
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Async context manager exit."""
await self.close()