Switch to agent arch
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
import typer
|
||||
import os
|
||||
import random
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
import typer
|
||||
from dotenv import load_dotenv
|
||||
from embedder.vertex_ai import VertexAIEmbedder
|
||||
|
||||
@@ -27,7 +26,7 @@ CONTENT_LIST = [
|
||||
]
|
||||
TASK_TYPE = "RETRIEVAL_DOCUMENT"
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = typer.Typer()
|
||||
@@ -43,14 +42,14 @@ async def embed_content_task():
|
||||
async def run_test(concurrency: int):
|
||||
"""Continuously calls the embedding API and tracks requests."""
|
||||
total_requests = 0
|
||||
|
||||
|
||||
logger.info(f"Starting diagnostic test with {concurrency} concurrent requests on model '{MODEL_NAME}'.")
|
||||
logger.info("Press Ctrl+C to stop.")
|
||||
|
||||
while True:
|
||||
# Create tasks, passing project_id and location
|
||||
tasks = [embed_content_task() for _ in range(concurrency)]
|
||||
|
||||
|
||||
try:
|
||||
await asyncio.gather(*tasks)
|
||||
total_requests += concurrency
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import typer
|
||||
|
||||
import httpx
|
||||
import typer
|
||||
|
||||
CONTENT_LIST = [
|
||||
"¿Cuáles son los beneficios de una tarjeta de crédito?",
|
||||
@@ -17,7 +18,7 @@ CONTENT_LIST = [
|
||||
"¿Cómo puedo transferir dinero a una cuenta internacional?",
|
||||
]
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = typer.Typer()
|
||||
@@ -41,7 +42,7 @@ async def call_rag_endpoint_task(client: httpx.AsyncClient, url: str):
|
||||
async def run_test(concurrency: int, url: str, timeout_seconds: float):
|
||||
"""Continuously calls the RAG endpoint and tracks requests."""
|
||||
total_requests = 0
|
||||
|
||||
|
||||
logger.info(f"Starting diagnostic test with {concurrency} concurrent requests on endpoint '{url}'.")
|
||||
logger.info(f"Request timeout is set to {timeout_seconds} seconds.")
|
||||
logger.info("Press Ctrl+C to stop.")
|
||||
@@ -50,7 +51,7 @@ async def run_test(concurrency: int, url: str, timeout_seconds: float):
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
while True:
|
||||
tasks = [call_rag_endpoint_task(client, url) for _ in range(concurrency)]
|
||||
|
||||
|
||||
try:
|
||||
await asyncio.gather(*tasks)
|
||||
total_requests += concurrency
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import requests
|
||||
import time
|
||||
import random
|
||||
import concurrent.futures
|
||||
import random
|
||||
import threading
|
||||
|
||||
import requests
|
||||
|
||||
# URL for the endpoint
|
||||
url = "http://localhost:8000/sigma-rag"
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
from google.cloud import aiplatform
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from rag_eval.config import settings
|
||||
|
||||
@@ -56,7 +57,6 @@ def main(
|
||||
] = "search-eval-pipeline-job",
|
||||
):
|
||||
"""Submits a Vertex AI pipeline job."""
|
||||
|
||||
parameter_values = {
|
||||
"project_id": project_id,
|
||||
"location": location,
|
||||
|
||||
Reference in New Issue
Block a user