First commit
This commit is contained in:
80
scripts/diagnose_embeddings.py
Normal file
80
scripts/diagnose_embeddings.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
import typer
|
||||
import random
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from dotenv import load_dotenv
|
||||
from embedder.vertex_ai import VertexAIEmbedder
|
||||
|
||||
load_dotenv()
|
||||
project = os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||
location = os.getenv("GOOGLE_CLOUD_LOCATION")
|
||||
|
||||
MODEL_NAME = "gemini-embedding-001"
|
||||
CONTENT_LIST = [
|
||||
"¿Cuáles son los beneficios de una tarjeta de crédito?",
|
||||
"¿Cómo puedo abrir una cuenta de ahorros?",
|
||||
"¿Qué es una hipoteca y cómo funciona?",
|
||||
"¿Cuáles son las tasas de interés para un préstamo personal?",
|
||||
"¿Cómo puedo solicitar un préstamo para un coche?",
|
||||
"¿Qué es la banca en línea y cómo me registro?",
|
||||
"¿Cómo puedo reportar una tarjeta de crédito perdida o robada?",
|
||||
"¿Qué es el phishing y cómo puedo protegerme?",
|
||||
"¿Cuáles son los diferentes tipos de cuentas corrientes que ofrecen?",
|
||||
"¿Cómo puedo transferir dinero a una cuenta internacional?",
|
||||
]
|
||||
TASK_TYPE = "RETRIEVAL_DOCUMENT"
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
logger.info(f"Initializing GenAI Client for project '{project}' in '{location}'")
|
||||
embedder = VertexAIEmbedder(MODEL_NAME, project, location)
|
||||
|
||||
async def embed_content_task():
|
||||
"""A single task to send one embedding request using the global client."""
|
||||
content_to_embed = random.choice(CONTENT_LIST)
|
||||
await embedder.async_generate_embedding(content_to_embed)
|
||||
|
||||
async def run_test(concurrency: int):
|
||||
"""Continuously calls the embedding API and tracks requests."""
|
||||
total_requests = 0
|
||||
|
||||
logger.info(f"Starting diagnostic test with {concurrency} concurrent requests on model '{MODEL_NAME}'.")
|
||||
logger.info("Press Ctrl+C to stop.")
|
||||
|
||||
while True:
|
||||
# Create tasks, passing project_id and location
|
||||
tasks = [embed_content_task() for _ in range(concurrency)]
|
||||
|
||||
try:
|
||||
await asyncio.gather(*tasks)
|
||||
total_requests += concurrency
|
||||
logger.info(f"Successfully completed batch. Total requests so far: {total_requests}")
|
||||
except Exception as e:
|
||||
logger.error("Caught an error. Stopping test.")
|
||||
print("\n--- STATS ---")
|
||||
print(f"Total successful requests: {total_requests}")
|
||||
print(f"Concurrent requests during failure: {concurrency}")
|
||||
print(f"Error Type: {e.__class__.__name__}")
|
||||
print(f"Error Details: {e}")
|
||||
print("-------------")
|
||||
break
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
concurrency: int = typer.Option(
|
||||
10, "--concurrency", "-c", help="Number of concurrent requests to send in each batch."
|
||||
),
|
||||
):
|
||||
try:
|
||||
asyncio.run(run_test(concurrency))
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\nKeyboard interrupt received. Exiting.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
98
scripts/diagnose_rag_endpoint.py
Normal file
98
scripts/diagnose_rag_endpoint.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import typer
|
||||
import httpx
|
||||
|
||||
CONTENT_LIST = [
|
||||
"¿Cuáles son los beneficios de una tarjeta de crédito?",
|
||||
"¿Cómo puedo abrir una cuenta de ahorros?",
|
||||
"¿Qué es una hipoteca y cómo funciona?",
|
||||
"¿Cuáles son las tasas de interés para un préstamo personal?",
|
||||
"¿Cómo puedo solicitar un préstamo para un coche?",
|
||||
"¿Qué es la banca en línea y cómo me registro?",
|
||||
"¿Cómo puedo reportar una tarjeta de crédito perdida o robada?",
|
||||
"¿Qué es el phishing y cómo puedo protegerme?",
|
||||
"¿Cuáles son los diferentes tipos de cuentas corrientes que ofrecen?",
|
||||
"¿Cómo puedo transferir dinero a una cuenta internacional?",
|
||||
]
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
async def call_rag_endpoint_task(client: httpx.AsyncClient, url: str):
|
||||
"""A single task to send one request to the RAG endpoint."""
|
||||
question = random.choice(CONTENT_LIST)
|
||||
json_payload = {
|
||||
"sessionInfo": {
|
||||
"parameters": {
|
||||
"query": question
|
||||
}
|
||||
}
|
||||
}
|
||||
response = await client.post(url, json=json_payload)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
response_data = response.json()
|
||||
response_text = response_data["sessionInfo"]["parameters"]["response"]
|
||||
logger.info(f"Question: {question[:50]}... Response: {response_text[:100]}...")
|
||||
|
||||
async def run_test(concurrency: int, url: str, timeout_seconds: float):
|
||||
"""Continuously calls the RAG endpoint and tracks requests."""
|
||||
total_requests = 0
|
||||
|
||||
logger.info(f"Starting diagnostic test with {concurrency} concurrent requests on endpoint '{url}'.")
|
||||
logger.info(f"Request timeout is set to {timeout_seconds} seconds.")
|
||||
logger.info("Press Ctrl+C to stop.")
|
||||
|
||||
timeout = httpx.Timeout(timeout_seconds)
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
while True:
|
||||
tasks = [call_rag_endpoint_task(client, url) for _ in range(concurrency)]
|
||||
|
||||
try:
|
||||
await asyncio.gather(*tasks)
|
||||
total_requests += concurrency
|
||||
logger.info(f"Successfully completed batch. Total requests so far: {total_requests}")
|
||||
except httpx.TimeoutException as e:
|
||||
logger.error(f"A request timed out: {e.request.method} {e.request.url}")
|
||||
logger.error("Consider increasing the timeout with the --timeout option.")
|
||||
break
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"An HTTP error occurred: {e.response.status_code} - {e.request.method} {e.request.url}")
|
||||
logger.error(f"Response body: {e.response.text}")
|
||||
break
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"A request error occurred: {e.request.method} {e.request.url}")
|
||||
logger.error(f"Error details: {e}")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Caught an unexpected error. Stopping test.")
|
||||
print("\n--- STATS ---")
|
||||
print(f"Total successful requests: {total_requests}")
|
||||
print(f"Concurrent requests during failure: {concurrency}")
|
||||
print(f"Error Type: {e.__class__.__name__}")
|
||||
print(f"Error Details: {e}")
|
||||
print("-------------")
|
||||
break
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
concurrency: int = typer.Option(
|
||||
10, "--concurrency", "-c", help="Number of concurrent requests to send in each batch."
|
||||
),
|
||||
url: str = typer.Option(
|
||||
"http://127.0.0.1:8000/sigma-rag", "--url", "-u", help="The URL of the RAG endpoint to test."
|
||||
),
|
||||
timeout_seconds: float = typer.Option(
|
||||
30.0, "--timeout", "-t", help="Request timeout in seconds."
|
||||
)
|
||||
):
|
||||
try:
|
||||
asyncio.run(run_test(concurrency, url, timeout_seconds))
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\nKeyboard interrupt received. Exiting.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
91
scripts/stress_test.py
Normal file
91
scripts/stress_test.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import requests
|
||||
import time
|
||||
import random
|
||||
import concurrent.futures
|
||||
import threading
|
||||
|
||||
# URL for the endpoint
|
||||
url = "http://localhost:8000/sigma-rag"
|
||||
|
||||
# List of Spanish banking questions
|
||||
spanish_questions = [
|
||||
"¿Cuáles son los beneficios de una tarjeta de crédito?",
|
||||
"¿Cómo puedo abrir una cuenta de ahorros?",
|
||||
"¿Qué es una hipoteca y cómo funciona?",
|
||||
"¿Cuáles son las tasas de interés para un préstamo personal?",
|
||||
"¿Cómo puedo solicitar un préstamo para un coche?",
|
||||
"¿Qué es la banca en línea y cómo me registro?",
|
||||
"¿Cómo puedo reportar una tarjeta de crédito perdida o robada?",
|
||||
"¿Qué es el phishing y cómo puedo protegerme?",
|
||||
"¿Cuáles son los diferentes tipos de cuentas corrientes que ofrecen?",
|
||||
"¿Cómo puedo transferir dinero a una cuenta internacional?",
|
||||
]
|
||||
|
||||
# A threading Event to signal all threads to stop
|
||||
stop_event = threading.Event()
|
||||
|
||||
def send_request(question, request_id):
|
||||
"""Sends a single request and handles the response."""
|
||||
if stop_event.is_set():
|
||||
return
|
||||
|
||||
data = {"sessionInfo": {"parameters": {"query": question}}}
|
||||
try:
|
||||
response = requests.post(url, json=data)
|
||||
|
||||
if stop_event.is_set():
|
||||
return
|
||||
|
||||
if response.status_code == 500:
|
||||
print(f"Request {request_id}: Received 500 error with question: '{question}'.")
|
||||
print("Stopping stress test.")
|
||||
stop_event.set()
|
||||
else:
|
||||
print(f"Request {request_id}: Successful with status code {response.status_code}.")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
if not stop_event.is_set():
|
||||
print(f"Request {request_id}: An error occurred: {e}")
|
||||
stop_event.set()
|
||||
|
||||
def main():
|
||||
"""Runs the stress test with parallel requests."""
|
||||
num_workers = 30 # Number of parallel requests
|
||||
print(f"Starting stress test with {num_workers} parallel workers. Press Ctrl+C to stop.")
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = {
|
||||
executor.submit(send_request, random.choice(spanish_questions), i)
|
||||
for i in range(1, num_workers + 1)
|
||||
}
|
||||
request_id_counter = num_workers + 1
|
||||
|
||||
try:
|
||||
while not stop_event.is_set():
|
||||
# Wait for any future to complete
|
||||
done, _ = concurrent.futures.wait(
|
||||
futures, return_when=concurrent.futures.FIRST_COMPLETED
|
||||
)
|
||||
|
||||
for future in done:
|
||||
# Remove the completed future
|
||||
futures.remove(future)
|
||||
|
||||
# If we are not stopping, submit a new one
|
||||
if not stop_event.is_set():
|
||||
futures.add(
|
||||
executor.submit(
|
||||
send_request,
|
||||
random.choice(spanish_questions),
|
||||
request_id_counter,
|
||||
)
|
||||
)
|
||||
request_id_counter += 1
|
||||
except KeyboardInterrupt:
|
||||
print("\nKeyboard interrupt received. Stopping threads.")
|
||||
stop_event.set()
|
||||
|
||||
print("Stress test finished.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
84
scripts/submit_pipeline.py
Normal file
84
scripts/submit_pipeline.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import typer
|
||||
from google.cloud import aiplatform
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from rag_eval.config import settings
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
pipeline_spec_path: Annotated[
|
||||
str,
|
||||
typer.Option(
|
||||
"--pipeline-spec-path",
|
||||
"-p",
|
||||
help="Path to the compiled pipeline YAML file.",
|
||||
),
|
||||
],
|
||||
input_table: Annotated[
|
||||
str,
|
||||
typer.Option(
|
||||
"--input-table",
|
||||
"-i",
|
||||
help="Full BigQuery table name for input (e.g., 'project.dataset.table')",
|
||||
),
|
||||
],
|
||||
output_table: Annotated[
|
||||
str,
|
||||
typer.Option(
|
||||
"--output-table",
|
||||
"-o",
|
||||
help="Full BigQuery table name for output (e.g., 'project.dataset.table')",
|
||||
),
|
||||
],
|
||||
project_id: Annotated[
|
||||
str,
|
||||
typer.Option(
|
||||
"--project-id",
|
||||
help="Google Cloud project ID.",
|
||||
),
|
||||
] = settings.project_id,
|
||||
location: Annotated[
|
||||
str,
|
||||
typer.Option(
|
||||
"--location",
|
||||
help="Google Cloud location for the pipeline job.",
|
||||
),
|
||||
] = settings.location,
|
||||
display_name: Annotated[
|
||||
str,
|
||||
typer.Option(
|
||||
"--display-name",
|
||||
help="Display name for the pipeline job.",
|
||||
),
|
||||
] = "search-eval-pipeline-job",
|
||||
):
|
||||
"""Submits a Vertex AI pipeline job."""
|
||||
|
||||
parameter_values = {
|
||||
"project_id": project_id,
|
||||
"location": location,
|
||||
"input_table": input_table,
|
||||
"output_table": output_table,
|
||||
}
|
||||
|
||||
job = aiplatform.PipelineJob(
|
||||
display_name=display_name,
|
||||
template_path=pipeline_spec_path,
|
||||
pipeline_root=f"gs://{settings.bucket}/pipeline_root",
|
||||
parameter_values=parameter_values,
|
||||
project=project_id,
|
||||
location=location,
|
||||
)
|
||||
|
||||
print(f"Submitting pipeline job with parameters: {parameter_values}")
|
||||
job.submit(
|
||||
service_account="sa-cicd-gitlab@bnt-orquestador-cognitivo-dev.iam.gserviceaccount.com"
|
||||
)
|
||||
print(f"Pipeline job submitted. You can view it at: {job._dashboard_uri()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
42
scripts/test_rerank.py
Normal file
42
scripts/test_rerank.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from google.cloud import discoveryengine_v1 as discoveryengine
|
||||
|
||||
# TODO(developer): Uncomment these variables before running the sample.
|
||||
project_id = "bnt-orquestador-cognitivo-dev"
|
||||
|
||||
client = discoveryengine.RankServiceClient()
|
||||
|
||||
# The full resource name of the ranking config.
|
||||
# Format: projects/{project_id}/locations/{location}/rankingConfigs/default_ranking_config
|
||||
ranking_config = client.ranking_config_path(
|
||||
project=project_id,
|
||||
location="global",
|
||||
ranking_config="default_ranking_config",
|
||||
)
|
||||
request = discoveryengine.RankRequest(
|
||||
ranking_config=ranking_config,
|
||||
model="semantic-ranker-default@latest",
|
||||
top_n=10,
|
||||
query="What is Google Gemini?",
|
||||
records=[
|
||||
discoveryengine.RankingRecord(
|
||||
id="1",
|
||||
title="Gemini",
|
||||
content="The Gemini zodiac symbol often depicts two figures standing side-by-side.",
|
||||
),
|
||||
discoveryengine.RankingRecord(
|
||||
id="2",
|
||||
title="Gemini",
|
||||
content="Gemini is a cutting edge large language model created by Google.",
|
||||
),
|
||||
discoveryengine.RankingRecord(
|
||||
id="3",
|
||||
title="Gemini Constellation",
|
||||
content="Gemini is a constellation that can be seen in the night sky.",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
response = client.rank(request=request)
|
||||
|
||||
# Handle the response
|
||||
print(response)
|
||||
12
scripts/test_server.py
Normal file
12
scripts/test_server.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import requests
|
||||
|
||||
# Test the /sigma-rag endpoint
|
||||
url = "http://localhost:8000/sigma-rag"
|
||||
data = {
|
||||
"sessionInfo": {"parameters": {"query": "What are the benefits of a credit card?"}}
|
||||
}
|
||||
|
||||
response = requests.post(url, json=data)
|
||||
|
||||
print("Response from /sigma-rag:")
|
||||
print(response.json())
|
||||
Reference in New Issue
Block a user