Add linting/formatting
@@ -4,12 +4,12 @@ import random
 import pandas as pd
 import typer
-import vertexai
+from google import genai
 from google.cloud import storage
+from google.genai import types
 from pydantic import BaseModel
 from rich.console import Console
 from rich.progress import track
-from vertexai.generative_models import GenerationConfig, GenerativeModel
 
 from va_evaluator.config import Settings
 
@@ -89,20 +89,25 @@ class MultiStepResponseSchema(BaseModel):
 
 
 def generate_structured(
-    model: GenerativeModel,
+    client: genai.Client,
+    model: str,
     prompt: str,
     response_model: type[BaseModel],
 ) -> BaseModel:
-    generation_config = GenerationConfig(
-        response_mime_type="application/json",
-        response_schema=response_model,
-    )
-    response = model.generate_content(prompt, generation_config=generation_config)
+    response = client.models.generate_content(
+        model=model,
+        contents=prompt,
+        config=types.GenerateContentConfig(
+            response_mime_type="application/json",
+            response_schema=response_model,
+        ),
+    )
     return response_model.model_validate_json(response.text)
 
 
 def generate_synthetic_question(
-    model: GenerativeModel,
+    client: genai.Client,
+    model: str,
     file_content: str,
     file_path: str,
     q_type: str,
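The rewritten generate_structured moves structured output from a model-level GenerationConfig to a per-request types.GenerateContentConfig. A minimal standalone sketch of the same pattern (the schema, model name, and prompt are illustrative placeholders, not values from this repo):

from google import genai
from google.genai import types
from pydantic import BaseModel


class QAPair(BaseModel):
    # Hypothetical schema standing in for ResponseSchema.
    question: str
    answer: str


client = genai.Client(vertexai=True, project="my-project", location="us-central1")
response = client.models.generate_content(
    model="gemini-1.5-flash",  # placeholder model name
    contents="Write one question/answer pair about CSV files as JSON.",
    config=types.GenerateContentConfig(
        response_mime_type="application/json",
        response_schema=QAPair,
    ),
)
# response.text carries the raw JSON string; validating it through the same
# Pydantic model mirrors the model_validate_json call in generate_structured.
qa = QAPair.model_validate_json(response.text)

When a response_schema is supplied, the SDK also populates response.parsed with an already-constructed instance, so the explicit model_validate_json round-trip here is a defensive choice rather than a requirement.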
@@ -111,11 +116,12 @@ def generate_synthetic_question(
     prompt = PROMPT_TEMPLATE.format(
         context=file_content, id=file_path, qtype=q_type, qtype_def=q_def
     )
-    return generate_structured(model, prompt, ResponseSchema)
+    return generate_structured(client, model, prompt, ResponseSchema)
 
 
 def generate_synthetic_conversation(
-    model: GenerativeModel,
+    client: genai.Client,
+    model: str,
     file_content: str,
     file_path: str,
     num_turns: int,
@@ -123,7 +129,7 @@ def generate_synthetic_conversation(
     prompt = MULTI_STEP_PROMPT_TEMPLATE.format(
         context=file_content, num_turns=num_turns
     )
-    return generate_structured(model, prompt, MultiStepResponseSchema)
+    return generate_structured(client, model, prompt, MultiStepResponseSchema)
 
 
 def generate(
@@ -134,8 +140,12 @@ def generate(
     console = Console()
     settings = Settings()
 
-    vertexai.init(project=settings.project_id, location=settings.location)
-    model = GenerativeModel(settings.agent.language_model)
+    client = genai.Client(
+        vertexai=True,
+        project=settings.project_id,
+        location=settings.location,
+    )
+    model_name = settings.agent.language_model
     gcs_client = storage.Client(project=settings.project_id)
     bucket = gcs_client.bucket(settings.require_bucket)
 
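Here the two-step Vertex setup (vertexai.init followed by a bound GenerativeModel instance) collapses into one client object, and the model becomes a plain string passed per request. A minimal sketch with placeholder project values; with vertexai=True the client resolves credentials the same way the google-cloud libraries do, via Application Default Credentials:

from google import genai

# vertexai=True routes requests through Vertex AI rather than the
# Gemini Developer API; project and location are placeholders.
client = genai.Client(
    vertexai=True,
    project="my-project",
    location="us-central1",
)
response = client.models.generate_content(
    model="gemini-1.5-flash",  # placeholder model name
    contents="Say hello.",
)
print(response.text)

One client can serve every call in the module, which is why the helper functions above now take a client plus a model string instead of a bound GenerativeModel.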
@@ -165,9 +175,7 @@ def generate(
         console.print("[yellow]No files found. Skipping.[/yellow]")
         return ""
 
-    files_to_process = random.sample(
-        all_files, k=min(num_questions, len(all_files))
-    )
+    files_to_process = random.sample(all_files, k=min(num_questions, len(all_files)))
     console.print(
         f"Randomly selected {len(files_to_process)} files to generate questions from."
     )
@@ -182,7 +190,7 @@ def generate(
             conversation_data = None
             for attempt in range(3):
                 conversation_data = generate_synthetic_conversation(
-                    model, file_content, file_path, num_turns
+                    client, model_name, file_content, file_path, num_turns
                 )
                 if (
                     conversation_data
@@ -219,7 +227,7 @@ def generate(
             generated_data = None
             for attempt in range(3):
                 generated_data = generate_synthetic_question(
-                    model, file_content, file_path, q_type, q_def
+                    client, model_name, file_content, file_path, q_type, q_def
                 )
                 if (
                     generated_data
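Both updated call sites sit inside the same retry-until-valid shape: up to three attempts, breaking out as soon as the response passes the validity check that follows. A self-contained sketch of that loop with stand-in functions (the real code calls generate_synthetic_conversation or generate_synthetic_question and checks specific fields):

import random


def generate_once() -> dict | None:
    # Stand-in for one structured-generation call that sometimes fails.
    return {"question": "q", "answer": "a"} if random.random() > 0.3 else None


def is_valid(data: dict | None) -> bool:
    # Stand-in for the field checks the real code applies to the response.
    return data is not None and "question" in data


generated = None
for attempt in range(3):
    generated = generate_once()
    if is_valid(generated):
        break
# After the loop, generated is either a valid result or the last failed attempt.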
@@ -249,7 +257,9 @@ def generate(
                 all_rows.append(row)
 
         except Exception as e:
-            console.print(f"[bold red]Error processing file {file_path}: {e}[/bold red]")
+            console.print(
+                f"[bold red]Error processing file {file_path}: {e}[/bold red]"
+            )
 
     if not all_rows:
         console.print("[bold yellow]No questions were generated.[/bold yellow]")
@@ -262,7 +272,9 @@ def generate(
             f"\n[bold green]Saving {len(df)} generated questions to {output_csv}...[/bold green]"
         )
         df.to_csv(output_csv, index=False, encoding="utf-8-sig")
-        console.print("[bold green]Synthetic question generation complete.[/bold green]")
+        console.print(
+            "[bold green]Synthetic question generation complete.[/bold green]"
+        )
     else:
         console.print(
             f"\n[bold green]Saving {len(df)} generated questions to BigQuery...[/bold green]"
@@ -295,4 +307,3 @@ def generate(
 
     console.print(f"[bold yellow]Finished run with ID: {run_id}[/bold yellow]")
     return run_id
-