First commit
This commit is contained in:
80
scripts/diagnose_embeddings.py
Normal file
80
scripts/diagnose_embeddings.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
import typer
|
||||
import random
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from dotenv import load_dotenv
|
||||
from embedder.vertex_ai import VertexAIEmbedder
|
||||
|
||||
load_dotenv()
|
||||
project = os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||
location = os.getenv("GOOGLE_CLOUD_LOCATION")
|
||||
|
||||
MODEL_NAME = "gemini-embedding-001"
|
||||
CONTENT_LIST = [
|
||||
"¿Cuáles son los beneficios de una tarjeta de crédito?",
|
||||
"¿Cómo puedo abrir una cuenta de ahorros?",
|
||||
"¿Qué es una hipoteca y cómo funciona?",
|
||||
"¿Cuáles son las tasas de interés para un préstamo personal?",
|
||||
"¿Cómo puedo solicitar un préstamo para un coche?",
|
||||
"¿Qué es la banca en línea y cómo me registro?",
|
||||
"¿Cómo puedo reportar una tarjeta de crédito perdida o robada?",
|
||||
"¿Qué es el phishing y cómo puedo protegerme?",
|
||||
"¿Cuáles son los diferentes tipos de cuentas corrientes que ofrecen?",
|
||||
"¿Cómo puedo transferir dinero a una cuenta internacional?",
|
||||
]
|
||||
TASK_TYPE = "RETRIEVAL_DOCUMENT"
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
logger.info(f"Initializing GenAI Client for project '{project}' in '{location}'")
|
||||
embedder = VertexAIEmbedder(MODEL_NAME, project, location)
|
||||
|
||||
async def embed_content_task():
|
||||
"""A single task to send one embedding request using the global client."""
|
||||
content_to_embed = random.choice(CONTENT_LIST)
|
||||
await embedder.async_generate_embedding(content_to_embed)
|
||||
|
||||
async def run_test(concurrency: int):
|
||||
"""Continuously calls the embedding API and tracks requests."""
|
||||
total_requests = 0
|
||||
|
||||
logger.info(f"Starting diagnostic test with {concurrency} concurrent requests on model '{MODEL_NAME}'.")
|
||||
logger.info("Press Ctrl+C to stop.")
|
||||
|
||||
while True:
|
||||
# Create tasks, passing project_id and location
|
||||
tasks = [embed_content_task() for _ in range(concurrency)]
|
||||
|
||||
try:
|
||||
await asyncio.gather(*tasks)
|
||||
total_requests += concurrency
|
||||
logger.info(f"Successfully completed batch. Total requests so far: {total_requests}")
|
||||
except Exception as e:
|
||||
logger.error("Caught an error. Stopping test.")
|
||||
print("\n--- STATS ---")
|
||||
print(f"Total successful requests: {total_requests}")
|
||||
print(f"Concurrent requests during failure: {concurrency}")
|
||||
print(f"Error Type: {e.__class__.__name__}")
|
||||
print(f"Error Details: {e}")
|
||||
print("-------------")
|
||||
break
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
concurrency: int = typer.Option(
|
||||
10, "--concurrency", "-c", help="Number of concurrent requests to send in each batch."
|
||||
),
|
||||
):
|
||||
try:
|
||||
asyncio.run(run_test(concurrency))
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\nKeyboard interrupt received. Exiting.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
Reference in New Issue
Block a user