First commit
This commit is contained in:
98
scripts/diagnose_rag_endpoint.py
Normal file
98
scripts/diagnose_rag_endpoint.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import typer
|
||||
import httpx
|
||||
|
||||
CONTENT_LIST = [
|
||||
"¿Cuáles son los beneficios de una tarjeta de crédito?",
|
||||
"¿Cómo puedo abrir una cuenta de ahorros?",
|
||||
"¿Qué es una hipoteca y cómo funciona?",
|
||||
"¿Cuáles son las tasas de interés para un préstamo personal?",
|
||||
"¿Cómo puedo solicitar un préstamo para un coche?",
|
||||
"¿Qué es la banca en línea y cómo me registro?",
|
||||
"¿Cómo puedo reportar una tarjeta de crédito perdida o robada?",
|
||||
"¿Qué es el phishing y cómo puedo protegerme?",
|
||||
"¿Cuáles son los diferentes tipos de cuentas corrientes que ofrecen?",
|
||||
"¿Cómo puedo transferir dinero a una cuenta internacional?",
|
||||
]
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
async def call_rag_endpoint_task(client: httpx.AsyncClient, url: str):
|
||||
"""A single task to send one request to the RAG endpoint."""
|
||||
question = random.choice(CONTENT_LIST)
|
||||
json_payload = {
|
||||
"sessionInfo": {
|
||||
"parameters": {
|
||||
"query": question
|
||||
}
|
||||
}
|
||||
}
|
||||
response = await client.post(url, json=json_payload)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
response_data = response.json()
|
||||
response_text = response_data["sessionInfo"]["parameters"]["response"]
|
||||
logger.info(f"Question: {question[:50]}... Response: {response_text[:100]}...")
|
||||
|
||||
async def run_test(concurrency: int, url: str, timeout_seconds: float):
|
||||
"""Continuously calls the RAG endpoint and tracks requests."""
|
||||
total_requests = 0
|
||||
|
||||
logger.info(f"Starting diagnostic test with {concurrency} concurrent requests on endpoint '{url}'.")
|
||||
logger.info(f"Request timeout is set to {timeout_seconds} seconds.")
|
||||
logger.info("Press Ctrl+C to stop.")
|
||||
|
||||
timeout = httpx.Timeout(timeout_seconds)
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
while True:
|
||||
tasks = [call_rag_endpoint_task(client, url) for _ in range(concurrency)]
|
||||
|
||||
try:
|
||||
await asyncio.gather(*tasks)
|
||||
total_requests += concurrency
|
||||
logger.info(f"Successfully completed batch. Total requests so far: {total_requests}")
|
||||
except httpx.TimeoutException as e:
|
||||
logger.error(f"A request timed out: {e.request.method} {e.request.url}")
|
||||
logger.error("Consider increasing the timeout with the --timeout option.")
|
||||
break
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"An HTTP error occurred: {e.response.status_code} - {e.request.method} {e.request.url}")
|
||||
logger.error(f"Response body: {e.response.text}")
|
||||
break
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"A request error occurred: {e.request.method} {e.request.url}")
|
||||
logger.error(f"Error details: {e}")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Caught an unexpected error. Stopping test.")
|
||||
print("\n--- STATS ---")
|
||||
print(f"Total successful requests: {total_requests}")
|
||||
print(f"Concurrent requests during failure: {concurrency}")
|
||||
print(f"Error Type: {e.__class__.__name__}")
|
||||
print(f"Error Details: {e}")
|
||||
print("-------------")
|
||||
break
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
concurrency: int = typer.Option(
|
||||
10, "--concurrency", "-c", help="Number of concurrent requests to send in each batch."
|
||||
),
|
||||
url: str = typer.Option(
|
||||
"http://127.0.0.1:8000/sigma-rag", "--url", "-u", help="The URL of the RAG endpoint to test."
|
||||
),
|
||||
timeout_seconds: float = typer.Option(
|
||||
30.0, "--timeout", "-t", help="Request timeout in seconds."
|
||||
)
|
||||
):
|
||||
try:
|
||||
asyncio.run(run_test(concurrency, url, timeout_seconds))
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\nKeyboard interrupt received. Exiting.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
Reference in New Issue
Block a user