Add linting/formatting

This commit is contained in:
Anibal Angulo
2026-02-23 17:23:22 +00:00
parent fd07f4a3e3
commit 099f6a50d1
7 changed files with 109 additions and 35 deletions

View File

@@ -4,12 +4,12 @@ import random
import pandas as pd
import typer
import vertexai
from google import genai
from google.cloud import storage
from google.genai import types
from pydantic import BaseModel
from rich.console import Console
from rich.progress import track
from vertexai.generative_models import GenerationConfig, GenerativeModel
from va_evaluator.config import Settings
@@ -89,20 +89,25 @@ class MultiStepResponseSchema(BaseModel):
def generate_structured(
model: GenerativeModel,
client: genai.Client,
model: str,
prompt: str,
response_model: type[BaseModel],
) -> BaseModel:
generation_config = GenerationConfig(
response_mime_type="application/json",
response_schema=response_model,
response = client.models.generate_content(
model=model,
contents=prompt,
config=types.GenerateContentConfig(
response_mime_type="application/json",
response_schema=response_model,
),
)
response = model.generate_content(prompt, generation_config=generation_config)
return response_model.model_validate_json(response.text)
def generate_synthetic_question(
model: GenerativeModel,
client: genai.Client,
model: str,
file_content: str,
file_path: str,
q_type: str,
@@ -111,11 +116,12 @@ def generate_synthetic_question(
prompt = PROMPT_TEMPLATE.format(
context=file_content, id=file_path, qtype=q_type, qtype_def=q_def
)
return generate_structured(model, prompt, ResponseSchema)
return generate_structured(client, model, prompt, ResponseSchema)
def generate_synthetic_conversation(
model: GenerativeModel,
client: genai.Client,
model: str,
file_content: str,
file_path: str,
num_turns: int,
@@ -123,7 +129,7 @@ def generate_synthetic_conversation(
prompt = MULTI_STEP_PROMPT_TEMPLATE.format(
context=file_content, num_turns=num_turns
)
return generate_structured(model, prompt, MultiStepResponseSchema)
return generate_structured(client, model, prompt, MultiStepResponseSchema)
def generate(
@@ -134,8 +140,12 @@ def generate(
console = Console()
settings = Settings()
vertexai.init(project=settings.project_id, location=settings.location)
model = GenerativeModel(settings.agent.language_model)
client = genai.Client(
vertexai=True,
project=settings.project_id,
location=settings.location,
)
model_name = settings.agent.language_model
gcs_client = storage.Client(project=settings.project_id)
bucket = gcs_client.bucket(settings.require_bucket)
@@ -165,9 +175,7 @@ def generate(
console.print("[yellow]No files found. Skipping.[/yellow]")
return ""
files_to_process = random.sample(
all_files, k=min(num_questions, len(all_files))
)
files_to_process = random.sample(all_files, k=min(num_questions, len(all_files)))
console.print(
f"Randomly selected {len(files_to_process)} files to generate questions from."
)
@@ -182,7 +190,7 @@ def generate(
conversation_data = None
for attempt in range(3):
conversation_data = generate_synthetic_conversation(
model, file_content, file_path, num_turns
client, model_name, file_content, file_path, num_turns
)
if (
conversation_data
@@ -219,7 +227,7 @@ def generate(
generated_data = None
for attempt in range(3):
generated_data = generate_synthetic_question(
model, file_content, file_path, q_type, q_def
client, model_name, file_content, file_path, q_type, q_def
)
if (
generated_data
@@ -249,7 +257,9 @@ def generate(
all_rows.append(row)
except Exception as e:
console.print(f"[bold red]Error processing file {file_path}: {e}[/bold red]")
console.print(
f"[bold red]Error processing file {file_path}: {e}[/bold red]"
)
if not all_rows:
console.print("[bold yellow]No questions were generated.[/bold yellow]")
@@ -262,7 +272,9 @@ def generate(
f"\n[bold green]Saving {len(df)} generated questions to {output_csv}...[/bold green]"
)
df.to_csv(output_csv, index=False, encoding="utf-8-sig")
console.print("[bold green]Synthetic question generation complete.[/bold green]")
console.print(
"[bold green]Synthetic question generation complete.[/bold green]"
)
else:
console.print(
f"\n[bold green]Saving {len(df)} generated questions to BigQuery...[/bold green]"
@@ -295,4 +307,3 @@ def generate(
console.print(f"[bold yellow]Finished run with ID: {run_id}[/bold yellow]")
return run_id