import asyncio import logging import os import random import typer from dotenv import load_dotenv from embedder.vertex_ai import VertexAIEmbedder load_dotenv() project = os.getenv("GOOGLE_CLOUD_PROJECT") location = os.getenv("GOOGLE_CLOUD_LOCATION") MODEL_NAME = "gemini-embedding-001" CONTENT_LIST = [ "¿Cuáles son los beneficios de una tarjeta de crédito?", "¿Cómo puedo abrir una cuenta de ahorros?", "¿Qué es una hipoteca y cómo funciona?", "¿Cuáles son las tasas de interés para un préstamo personal?", "¿Cómo puedo solicitar un préstamo para un coche?", "¿Qué es la banca en línea y cómo me registro?", "¿Cómo puedo reportar una tarjeta de crédito perdida o robada?", "¿Qué es el phishing y cómo puedo protegerme?", "¿Cuáles son los diferentes tipos de cuentas corrientes que ofrecen?", "¿Cómo puedo transferir dinero a una cuenta internacional?", ] TASK_TYPE = "RETRIEVAL_DOCUMENT" logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) app = typer.Typer() logger.info(f"Initializing GenAI Client for project '{project}' in '{location}'") embedder = VertexAIEmbedder(MODEL_NAME, project, location) async def embed_content_task(): """A single task to send one embedding request using the global client.""" content_to_embed = random.choice(CONTENT_LIST) await embedder.async_generate_embedding(content_to_embed) async def run_test(concurrency: int): """Continuously calls the embedding API and tracks requests.""" total_requests = 0 logger.info(f"Starting diagnostic test with {concurrency} concurrent requests on model '{MODEL_NAME}'.") logger.info("Press Ctrl+C to stop.") while True: # Create tasks, passing project_id and location tasks = [embed_content_task() for _ in range(concurrency)] try: await asyncio.gather(*tasks) total_requests += concurrency logger.info(f"Successfully completed batch. Total requests so far: {total_requests}") except Exception as e: logger.error("Caught an error. Stopping test.") print("\n--- STATS ---") print(f"Total successful requests: {total_requests}") print(f"Concurrent requests during failure: {concurrency}") print(f"Error Type: {e.__class__.__name__}") print(f"Error Details: {e}") print("-------------") break @app.command() def main( concurrency: int = typer.Option( 10, "--concurrency", "-c", help="Number of concurrent requests to send in each batch." ), ): try: asyncio.run(run_test(concurrency)) except KeyboardInterrupt: logger.info("\nKeyboard interrupt received. Exiting.") if __name__ == "__main__": app()