Files
agent/scripts/diagnose_embeddings.py
2026-02-20 14:04:59 +00:00

80 lines
2.8 KiB
Python

import asyncio
import logging
import os
import random
import typer
from dotenv import load_dotenv
from embedder.vertex_ai import VertexAIEmbedder
load_dotenv()
project = os.getenv("GOOGLE_CLOUD_PROJECT")
location = os.getenv("GOOGLE_CLOUD_LOCATION")
MODEL_NAME = "gemini-embedding-001"
CONTENT_LIST = [
"¿Cuáles son los beneficios de una tarjeta de crédito?",
"¿Cómo puedo abrir una cuenta de ahorros?",
"¿Qué es una hipoteca y cómo funciona?",
"¿Cuáles son las tasas de interés para un préstamo personal?",
"¿Cómo puedo solicitar un préstamo para un coche?",
"¿Qué es la banca en línea y cómo me registro?",
"¿Cómo puedo reportar una tarjeta de crédito perdida o robada?",
"¿Qué es el phishing y cómo puedo protegerme?",
"¿Cuáles son los diferentes tipos de cuentas corrientes que ofrecen?",
"¿Cómo puedo transferir dinero a una cuenta internacional?",
]
TASK_TYPE = "RETRIEVAL_DOCUMENT"
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
app = typer.Typer()
logger.info(f"Initializing GenAI Client for project '{project}' in '{location}'")
embedder = VertexAIEmbedder(MODEL_NAME, project, location)
async def embed_content_task():
"""A single task to send one embedding request using the global client."""
content_to_embed = random.choice(CONTENT_LIST)
await embedder.async_generate_embedding(content_to_embed)
async def run_test(concurrency: int):
"""Continuously calls the embedding API and tracks requests."""
total_requests = 0
logger.info(f"Starting diagnostic test with {concurrency} concurrent requests on model '{MODEL_NAME}'.")
logger.info("Press Ctrl+C to stop.")
while True:
# Create tasks, passing project_id and location
tasks = [embed_content_task() for _ in range(concurrency)]
try:
await asyncio.gather(*tasks)
total_requests += concurrency
logger.info(f"Successfully completed batch. Total requests so far: {total_requests}")
except Exception as e:
logger.error("Caught an error. Stopping test.")
print("\n--- STATS ---")
print(f"Total successful requests: {total_requests}")
print(f"Concurrent requests during failure: {concurrency}")
print(f"Error Type: {e.__class__.__name__}")
print(f"Error Details: {e}")
print("-------------")
break
@app.command()
def main(
concurrency: int = typer.Option(
10, "--concurrency", "-c", help="Number of concurrent requests to send in each batch."
),
):
try:
asyncio.run(run_test(concurrency))
except KeyboardInterrupt:
logger.info("\nKeyboard interrupt received. Exiting.")
if __name__ == "__main__":
app()