81 lines
2.9 KiB
Python
81 lines
2.9 KiB
Python
import asyncio
|
|
import os
|
|
import logging
|
|
import typer
|
|
import random
|
|
from google import genai
|
|
from google.genai import types
|
|
from dotenv import load_dotenv
|
|
from embedder.vertex_ai import VertexAIEmbedder
|
|
|
|
# Load variables from a .env file into the process environment before the
# Vertex AI settings below are read.
load_dotenv()

# No fallback is applied here: either value is None when its variable is unset.
project = os.environ.get("GOOGLE_CLOUD_PROJECT")
location = os.environ.get("GOOGLE_CLOUD_LOCATION")
|
|
|
|
# Embedding model targeted by the load test.
MODEL_NAME = "gemini-embedding-001"

# NOTE(review): not referenced by the code visible in this file; confirm
# whether it was meant to be forwarded to the embedder call.
TASK_TYPE = "RETRIEVAL_DOCUMENT"

# Sample Spanish-language banking questions. Each request embeds one of
# these, picked at random.
CONTENT_LIST = [
    "¿Cuáles son los beneficios de una tarjeta de crédito?",
    "¿Cómo puedo abrir una cuenta de ahorros?",
    "¿Qué es una hipoteca y cómo funciona?",
    "¿Cuáles son las tasas de interés para un préstamo personal?",
    "¿Cómo puedo solicitar un préstamo para un coche?",
    "¿Qué es la banca en línea y cómo me registro?",
    "¿Cómo puedo reportar una tarjeta de crédito perdida o robada?",
    "¿Qué es el phishing y cómo puedo protegerme?",
    "¿Cuáles son los diferentes tipos de cuentas corrientes que ofrecen?",
    "¿Cómo puedo transferir dinero a una cuenta internacional?",
]
|
|
|
|
# Root logging: INFO and above, prefixed with timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Typer application object; CLI commands are registered on it via decorators.
app = typer.Typer()

# One embedder instance, built at import time and reused by every request.
logger.info(f"Initializing GenAI Client for project '{project}' in '{location}'")
embedder = VertexAIEmbedder(MODEL_NAME, project, location)
|
|
|
|
async def embed_content_task() -> None:
    """Send one embedding request for a randomly chosen sample question.

    The question comes from ``CONTENT_LIST``. It is sent through the shared
    module-level ``embedder``. The embedding itself is discarded: callers only
    care whether the call succeeded or raised.
    """
    question = random.choice(CONTENT_LIST)
    await embedder.async_generate_embedding(question)
|
|
|
|
async def run_test(concurrency: int):
    """Fire batches of concurrent embedding requests until one of them fails.

    Each iteration launches ``concurrency`` calls to ``embed_content_task`` at
    once and waits for all of them to settle. Successes are counted per
    request, so the reported total stays accurate even when a batch only
    partially fails. The first batch that contains a failure ends the test
    and prints a summary. On cancellation (e.g. Ctrl+C under ``asyncio.run``),
    the running total is logged and the cancellation is re-raised.

    Args:
        concurrency: Number of requests per batch; must be at least 1.

    Raises:
        ValueError: If ``concurrency`` is less than 1. An empty batch completes
            without ever suspending, so the loop would busy-spin and could
            not be cancelled.
    """
    if concurrency < 1:
        raise ValueError(f"concurrency must be >= 1, got {concurrency}")

    total_requests = 0

    logger.info(
        "Starting diagnostic test with %d concurrent requests on model '%s'.",
        concurrency,
        MODEL_NAME,
    )
    logger.info("Press Ctrl+C to stop.")

    try:
        while True:
            # return_exceptions=True lets every request in the batch settle.
            # Successes in a partially failed batch are still counted, and no
            # sibling request is left running un-awaited in the background.
            results = await asyncio.gather(
                *(embed_content_task() for _ in range(concurrency)),
                return_exceptions=True,
            )
            errors = [r for r in results if isinstance(r, BaseException)]
            total_requests += len(results) - len(errors)

            if not errors:
                logger.info(
                    "Successfully completed batch. Total requests so far: %d",
                    total_requests,
                )
                continue

            first_error = errors[0]
            # exc_info keeps the traceback of the first failure in the log.
            logger.error("Caught an error. Stopping test.", exc_info=first_error)
            print("\n--- STATS ---")
            print(f"Total successful requests: {total_requests}")
            print(f"Concurrent requests during failure: {concurrency}")
            print(f"Failed requests in final batch: {len(errors)}")
            print(f"Error Type: {first_error.__class__.__name__}")
            print(f"Error Details: {first_error}")
            print("-------------")
            break
    except asyncio.CancelledError:
        # Report progress before letting asyncio.run finish the shutdown.
        logger.info("Test cancelled. Total successful requests: %d", total_requests)
        raise
|
|
|
|
@app.command()
def main(
    concurrency: int = typer.Option(
        10,
        "--concurrency",
        "-c",
        # Reject 0/negative values at parse time. An empty batch would make
        # the test loop spin forever without yielding to the event loop.
        min=1,
        help="Number of concurrent requests to send in each batch.",
    ),
):
    """Load-test the embedding endpoint with batches of concurrent requests."""
    try:
        asyncio.run(run_test(concurrency))
    except KeyboardInterrupt:
        logger.info("\nKeyboard interrupt received. Exiting.")


if __name__ == "__main__":
    # Hand sys.argv to the Typer app, which parses options and runs the CLI.
    app()
|