Files
agent/scripts/diagnose_rag_endpoint.py
Anibal Angulo a53f8fcf62 First commit
2026-02-18 19:57:43 +00:00

99 lines
4.0 KiB
Python

import asyncio
import logging
import random
import typer
import httpx
# Pool of sample banking questions (in Spanish); each request sends one of
# these, picked at random, as its query.
CONTENT_LIST = [
    "¿Cuáles son los beneficios de una tarjeta de crédito?",
    "¿Cómo puedo abrir una cuenta de ahorros?",
    "¿Qué es una hipoteca y cómo funciona?",
    "¿Cuáles son las tasas de interés para un préstamo personal?",
    "¿Cómo puedo solicitar un préstamo para un coche?",
    "¿Qué es la banca en línea y cómo me registro?",
    "¿Cómo puedo reportar una tarjeta de crédito perdida o robada?",
    "¿Qué es el phishing y cómo puedo protegerme?",
    "¿Cuáles son los diferentes tipos de cuentas corrientes que ofrecen?",
    "¿Cómo puedo transferir dinero a una cuenta internacional?",
]
# Root logging setup: INFO and above, each line prefixed with timestamp and level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Typer application object; commands are attached with @app.command().
app = typer.Typer()
async def call_rag_endpoint_task(client: httpx.AsyncClient, url: str):
    """Send one randomly chosen question to the RAG endpoint and log the reply.

    A question is drawn from ``CONTENT_LIST`` and POSTed to ``url`` as JSON
    under ``sessionInfo.parameters.query``. The answer is read back from
    ``sessionInfo.parameters.response`` of the JSON reply, and a truncated
    question/answer pair is written to the log.

    Args:
        client: Shared async HTTP client used to issue the POST.
        url: Address of the endpoint to call.

    Errors from the request itself, from ``raise_for_status()`` on a bad
    status code, or from a reply that lacks the expected keys are not
    handled here; they propagate to the caller.
    """
    question = random.choice(CONTENT_LIST)
    body = {"sessionInfo": {"parameters": {"query": question}}}

    reply = await client.post(url, json=body)
    reply.raise_for_status()  # bad status codes abort this task with an exception

    parameters = reply.json()["sessionInfo"]["parameters"]
    response_text = parameters["response"]
    logger.info(f"Question: {question[:50]}... Response: {response_text[:100]}...")
async def run_test(concurrency: int, url: str, timeout_seconds: float) -> None:
    """Send back-to-back batches of requests to the RAG endpoint until one fails.

    Every iteration starts ``concurrency`` calls to
    ``call_rag_endpoint_task`` on one shared ``httpx.AsyncClient`` and waits
    for the whole batch. The loop ends at the first request that raises: the
    error is logged, the remaining requests of that batch are cancelled, and
    the function returns normally.

    Args:
        concurrency: Number of requests sent in parallel per batch; must be
            at least 1.
        url: Endpoint that receives the POST requests.
        timeout_seconds: Value handed to ``httpx.Timeout`` for the client.

    Raises:
        ValueError: If ``concurrency`` is smaller than 1. An empty batch
            completes instantly, so the loop would otherwise spin forever
            reporting success without sending a single request.
    """
    if concurrency < 1:
        raise ValueError(f"concurrency must be at least 1, got {concurrency}")

    total_requests = 0
    logger.info(f"Starting diagnostic test with {concurrency} concurrent requests on endpoint '{url}'.")
    logger.info(f"Request timeout is set to {timeout_seconds} seconds.")
    logger.info("Press Ctrl+C to stop.")
    timeout = httpx.Timeout(timeout_seconds)
    async with httpx.AsyncClient(timeout=timeout) as client:
        while True:
            # Explicit tasks (instead of bare coroutines) so unfinished
            # siblings can be cancelled once one request of the batch fails.
            tasks = [
                asyncio.create_task(call_rag_endpoint_task(client, url))
                for _ in range(concurrency)
            ]
            try:
                await asyncio.gather(*tasks)
            except httpx.TimeoutException as e:
                logger.error(f"A request timed out: {e.request.method} {e.request.url}")
                logger.error("Consider increasing the timeout with the --timeout option.")
                break
            except httpx.HTTPStatusError as e:
                logger.error(f"An HTTP error occurred: {e.response.status_code} - {e.request.method} {e.request.url}")
                logger.error(f"Response body: {e.response.text}")
                break
            except httpx.RequestError as e:
                logger.error(f"A request error occurred: {e.request.method} {e.request.url}")
                logger.error(f"Error details: {e}")
                break
            except Exception as e:
                logger.error("Caught an unexpected error. Stopping test.")
                print("\n--- STATS ---")
                print(f"Total successful requests: {total_requests}")
                print(f"Concurrent requests during failure: {concurrency}")
                print(f"Error Type: {e.__class__.__name__}")
                print(f"Error Details: {e}")
                print("-------------")
                break
            else:
                total_requests += concurrency
                logger.info(f"Successfully completed batch. Total requests so far: {total_requests}")
            finally:
                # Runs on success, failure and cancellation alike, so no
                # request is still in flight when the client gets closed.
                await _cancel_unfinished(tasks)


async def _cancel_unfinished(tasks: list) -> None:
    """Cancel the tasks that are still running and wait until they have stopped."""
    pending = [task for task in tasks if not task.done()]
    if not pending:
        return
    for task in pending:
        task.cancel()
    # return_exceptions=True collects cancellations and late errors instead of
    # re-raising them; the failure that ended the batch was already reported.
    await asyncio.gather(*pending, return_exceptions=True)
@app.command()
def main(
    concurrency: int = typer.Option(
        10,
        "--concurrency",
        "-c",
        # A batch of zero (or fewer) requests finishes instantly, which would
        # make the test loop spin forever without calling the endpoint.
        min=1,
        help="Number of concurrent requests to send in each batch.",
    ),
    url: str = typer.Option(
        "http://127.0.0.1:8000/sigma-rag", "--url", "-u", help="The URL of the RAG endpoint to test."
    ),
    timeout_seconds: float = typer.Option(
        30.0, "--timeout", "-t", help="Request timeout in seconds."
    ),
) -> None:
    """Send batches of concurrent requests to the RAG endpoint until one fails.

    The test can also be stopped at any time with Ctrl+C.
    """
    try:
        asyncio.run(run_test(concurrency, url, timeout_seconds))
    except KeyboardInterrupt:
        logger.info("\nKeyboard interrupt received. Exiting.")
# Start the Typer CLI only when this file is executed as a script.
if __name__ == "__main__":
    app()