310 lines
12 KiB
Python
310 lines
12 KiB
Python
import datetime
|
|
import os
|
|
import random
|
|
|
|
import pandas as pd
|
|
import typer
|
|
from google import genai
|
|
from google.cloud import storage
|
|
from google.genai import types
|
|
from pydantic import BaseModel
|
|
from rich.console import Console
|
|
from rich.progress import track
|
|
|
|
from va_evaluator.config import Settings
|
|
|
|
|
|
# --- Prompts & Schemas ---
|
|
|
|
# Prompt for single-turn question generation.
# Placeholders: {qtype} / {qtype_def} (question type and its definition),
# {context} (document text), {id} (source document ID).
# Fix: instruction 4 asks the model to cite source IDs and ResponseSchema.ids
# expects them, but the template previously never showed the document's ID.
# The caller (generate_synthetic_question) already passes id=file_path.
PROMPT_TEMPLATE = """
Eres un experto en generación de preguntas sintéticas. Tu tarea es crear preguntas sintéticas en español basadas en documentos de referencia proporcionados.

## INSTRUCCIONES:

### Requisitos obligatorios:
1. **Idioma**: La pregunta DEBE estar completamente en español
2. **Basada en documentos**: La pregunta DEBE poder responderse ÚNICAMENTE con la información contenida en los documentos proporcionados
3. **Tipo de pregunta**: Sigue estrictamente la definición del tipo de pregunta especificado
4. **Identificación de fuentes**: Incluye el ID de fuente de todos los documentos necesarios para responder la pregunta
5. **Salida esperada**: Incluye la respuesta perfecta basada en los documentos necesarios para responder la pregunta

### Tono de pregunta:
La pregunta debe ser similar a la que haría un usuario sin contexto sobre el sistema o la información disponible. Ingenuo y curioso.

### Tipo de pregunta solicitado:
**Tipo**: {qtype}
**Definición**: {qtype_def}

### Documentos de referencia:
**ID de fuente**: {id}

{context}

Por favor, genera una pregunta siguiendo estas instrucciones.
""".strip()
|
|
|
|
MULTI_STEP_PROMPT_TEMPLATE = """
|
|
Eres un experto en la generación de conversaciones sintéticas. Tu tarea es crear una conversación en español con múltiples turnos basada en los documentos de referencia proporcionados.
|
|
|
|
## INSTRUCCIONES:
|
|
|
|
### Requisitos obligatorios:
|
|
1. **Idioma**: La conversación DEBE estar completamente en español.
|
|
2. **Basada en documentos**: Todas las respuestas DEBEN poder responderse ÚNICAMENTE con la información contenida en los documentos de referencia.
|
|
3. **Número de turnos**: La conversación debe tener exactamente {num_turns} turnos. Un turno consiste en una pregunta del usuario y una respuesta del asistente.
|
|
4. **Flujo conversacional**: Las preguntas deben seguir un orden lógico, como si un usuario estuviera explorando un tema paso a paso. La segunda pregunta debe ser una continuación de la primera, y así sucesivamente.
|
|
5. **Salida esperada**: Proporciona la respuesta perfecta para cada pregunta, basada en los documentos de referencia.
|
|
|
|
### Tono de las preguntas:
|
|
Las preguntas deben ser similares a las que haría un usuario sin contexto sobre el sistema o la información disponible. Deben ser ingenuas y curiosas.
|
|
|
|
### Documentos de referencia:
|
|
{context}
|
|
|
|
Por favor, genera una conversación de {num_turns} turnos siguiendo estas instrucciones.
|
|
""".strip()
|
|
|
|
# Question taxonomy for single-turn generation: maps a question type to the
# English definition injected into PROMPT_TEMPLATE as {qtype_def}. The
# generated questions themselves are in Spanish.
# Fix: "within adocument" -> "within a document" in the Multi-hop definition.
# NOTE(review): the "Unanswerable" definition mentions "schema-to-article
# generation", which looks copied from a benchmark paper and may not describe
# this pipeline — confirm whether it should be reworded.
QUESTION_TYPE_MAP = {
    "Factual": "Questions targeting specific details within a reference (e.g., a company's profit in a report, a verdict in a legal case, or symptoms in a medical record) to test RAG's retrieval accuracy.",
    "Summarization": "Questions that require comprehensive answers, covering all relevant information, to mainly evaluate the recall rate of RAG retrieval.",
    "Multi-hop Reasoning": "Questions involve logical relationships among events and details within a document, forming a reasoning chain to assess RAG's logical reasoning ability.",
    "Unanswerable": "Questions arise from potential information loss during the schema-to-article generation, where no corresponding information fragment exists, or the information is insufficient for an answer.",
}
|
|
|
|
|
|
class ResponseSchema(BaseModel):
    """Structured JSON output for a single-turn synthetic question."""

    # Generated question, in Spanish.
    pregunta: str
    # Gold answer, answerable solely from the provided reference documents.
    expected_output: str
    # Source IDs of the documents needed to answer the question.
    ids: list[str]
|
|
|
|
|
|
class Turn(BaseModel):
    """One user/assistant exchange within a synthetic conversation."""

    # User question for this turn, in Spanish.
    pregunta: str
    # Gold answer for this turn, grounded in the reference documents.
    expected_output: str
|
|
|
|
|
|
class MultiStepResponseSchema(BaseModel):
    """Structured JSON output for multi-turn generation.

    Callers validate that the list length equals the requested number of
    turns before accepting the result.
    """

    # Ordered turns of the generated conversation.
    conversation: list[Turn]
|
|
|
|
|
|
# --- Core Logic ---
|
|
|
|
|
|
def generate_structured(
    client: genai.Client,
    model: str,
    prompt: str,
    response_model: type[BaseModel],
) -> BaseModel:
    """Call the model with a JSON schema constraint and parse the reply.

    The response MIME type is forced to application/json and the supplied
    pydantic model is used both to constrain generation and to validate the
    returned text into a typed object.
    """
    json_config = types.GenerateContentConfig(
        response_mime_type="application/json",
        response_schema=response_model,
    )
    raw = client.models.generate_content(
        model=model,
        contents=prompt,
        config=json_config,
    )
    # Validate the raw JSON text straight into the requested schema type.
    return response_model.model_validate_json(raw.text)
|
|
|
|
|
|
def generate_synthetic_question(
    client: genai.Client,
    model: str,
    file_content: str,
    file_path: str,
    q_type: str,
    q_def: str,
) -> ResponseSchema:
    """Generate one synthetic Spanish Q&A pair for a single document.

    Fills PROMPT_TEMPLATE with the document text, its path (as source ID),
    and the requested question type, then asks for a schema-constrained
    JSON reply parsed into ResponseSchema.
    """
    filled_prompt = PROMPT_TEMPLATE.format(
        qtype=q_type,
        qtype_def=q_def,
        id=file_path,
        context=file_content,
    )
    return generate_structured(client, model, filled_prompt, ResponseSchema)
|
|
|
|
|
|
def generate_synthetic_conversation(
    client: genai.Client,
    model: str,
    file_content: str,
    file_path: str,
    num_turns: int,
) -> MultiStepResponseSchema:
    """Generate a synthetic Spanish conversation of num_turns turns.

    Note: file_path is currently unused (MULTI_STEP_PROMPT_TEMPLATE has no
    ID placeholder); it is kept for signature parity with the single-turn
    generator.
    """
    filled_prompt = MULTI_STEP_PROMPT_TEMPLATE.format(
        num_turns=num_turns,
        context=file_content,
    )
    return generate_structured(client, model, filled_prompt, MultiStepResponseSchema)
|
|
|
|
|
|
def generate(
    num_questions: int,
    output_csv: str | None = None,
    num_turns: int = 1,
) -> str:
    """Generate synthetic evaluation questions from documents stored in GCS.

    Samples up to ``num_questions`` files under the configured index's
    ``contents/`` prefix and, per file, generates either one Q&A pair
    (``num_turns == 1``) or a multi-turn conversation (``num_turns > 1``).
    Results go to ``output_csv`` when given, otherwise they are appended to
    the configured BigQuery table.

    Args:
        num_questions: Maximum number of source files to sample.
        output_csv: Optional CSV path; when None, rows go to BigQuery.
        num_turns: Turns per item; values > 1 switch to conversation mode.

    Returns:
        The run ID (UTC timestamp string) on success, or "" when nothing
        was generated (no index configured, listing failed, or no files).

    Raises:
        typer.Exit: With code 1 if the BigQuery upload fails.
    """
    console = Console()
    settings = Settings()

    # Vertex AI-backed GenAI client; project/location come from settings.
    client = genai.Client(
        vertexai=True,
        project=settings.project_id,
        location=settings.location,
    )
    model_name = settings.agent.language_model
    gcs_client = storage.Client(project=settings.project_id)
    bucket = gcs_client.bucket(settings.require_bucket)

    # Timestamp-based run ID tags every row produced by this invocation.
    run_id = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d-%H%M%S")
    console.print(f"[bold yellow]Generated Run ID: {run_id}[/bold yellow]")

    all_rows = []
    if not settings.index.name:
        console.print("[yellow]Skipping as no index is configured.[/yellow]")
        return ""

    gcs_path = f"{settings.index.name}/contents/"
    console.print(f"[green]Fetching files from GCS path: {gcs_path}[/green]")

    try:
        # Skip "directory" placeholder blobs (names ending in '/').
        all_files = [
            blob.name
            for blob in bucket.list_blobs(prefix=gcs_path)
            if not blob.name.endswith("/")
        ]
        console.print(f"Found {len(all_files)} total files to process.")
    except Exception as e:
        # Listing failure aborts the run with an empty run ID.
        console.print(f"[bold red]Error listing files: {e}[/bold red]")
        return ""

    if not all_files:
        console.print("[yellow]No files found. Skipping.[/yellow]")
        return ""

    # One question (or conversation) per sampled file, without replacement.
    files_to_process = random.sample(all_files, k=min(num_questions, len(all_files)))
    console.print(
        f"Randomly selected {len(files_to_process)} files to generate questions from."
    )

    for file_path in track(files_to_process, description="Generating questions..."):
        try:
            blob = bucket.blob(file_path)
            # utf-8-sig strips a BOM if one is present.
            file_content = blob.download_as_text(encoding="utf-8-sig")
            # Random question type; only used by the single-turn branch below.
            q_type, q_def = random.choice(list(QUESTION_TYPE_MAP.items()))

            if num_turns > 1:
                # --- Multi-turn conversation mode ---
                conversation_data = None
                # Up to 3 attempts to get exactly num_turns turns back.
                for attempt in range(3):
                    conversation_data = generate_synthetic_conversation(
                        client, model_name, file_content, file_path, num_turns
                    )
                    if (
                        conversation_data
                        and conversation_data.conversation
                        and len(conversation_data.conversation) == num_turns
                    ):
                        break
                    # NOTE(review): this also prints "Retrying (3/3)..." after
                    # the final attempt even though no retry follows.
                    console.print(
                        f"[yellow]Failed to generate valid conversation for {os.path.basename(file_path)}. Retrying ({attempt + 1}/3)...[/yellow]"
                    )
                    conversation_data = None

                if not conversation_data:
                    console.print(
                        f"[bold red]Failed to generate valid conversation for {os.path.basename(file_path)} after 3 attempts. Skipping.[/bold red]"
                    )
                    continue

                # Random 5-digit ID groups the turns of one conversation.
                # NOTE(review): IDs can collide across files/runs — confirm
                # whether downstream consumers require uniqueness.
                conversation_id = str(random.randint(10000, 99999))
                for i, turn in enumerate(conversation_data.conversation):
                    row = {
                        "input": turn.pregunta,
                        "expected_output": turn.expected_output,
                        # Source = file name without extension or directories.
                        "source": os.path.splitext(os.path.basename(file_path))[0],
                        "type": "Multi-turn",
                        "agent": settings.agent.name,
                        "run_id": run_id,
                        "conversation_id": conversation_id,
                        "turn": i + 1,
                    }
                    all_rows.append(row)

            else:
                # --- Single-turn question mode ---
                generated_data = None
                # Up to 3 attempts to get a non-empty expected answer.
                for attempt in range(3):
                    generated_data = generate_synthetic_question(
                        client, model_name, file_content, file_path, q_type, q_def
                    )
                    if (
                        generated_data
                        and generated_data.expected_output
                        and generated_data.expected_output.strip()
                    ):
                        break
                    console.print(
                        f"[yellow]Empty answer for {q_type} on {os.path.basename(file_path)}. Retrying ({attempt + 1}/3)...[/yellow]"
                    )
                    generated_data = None

                if not generated_data:
                    console.print(
                        f"[bold red]Failed to generate valid answer for {q_type} on {os.path.basename(file_path)} after 3 attempts. Skipping.[/bold red]"
                    )
                    continue

                row = {
                    "input": generated_data.pregunta,
                    "expected_output": generated_data.expected_output,
                    "source": os.path.splitext(os.path.basename(file_path))[0],
                    "type": q_type,
                    "agent": settings.agent.name,
                    "run_id": run_id,
                }
                all_rows.append(row)

        except Exception as e:
            # Per-file failures are logged and skipped so one bad file
            # doesn't abort the whole run.
            console.print(
                f"[bold red]Error processing file {file_path}: {e}[/bold red]"
            )

    if not all_rows:
        console.print("[bold yellow]No questions were generated.[/bold yellow]")
        return ""

    df = pd.DataFrame(all_rows)

    if output_csv:
        console.print(
            f"\n[bold green]Saving {len(df)} generated questions to {output_csv}...[/bold green]"
        )
        # utf-8-sig so spreadsheet apps render accented Spanish correctly.
        df.to_csv(output_csv, index=False, encoding="utf-8-sig")
        console.print(
            "[bold green]Synthetic question generation complete.[/bold green]"
        )
    else:
        console.print(
            f"\n[bold green]Saving {len(df)} generated questions to BigQuery...[/bold green]"
        )
        # BigQuery project may be overridden separately from the GCP project.
        project_id = settings.bigquery.project_id or settings.project_id
        dataset_id = settings.bigquery.require_dataset_id
        table_name = settings.bigquery.synth_gen_table
        table_id = f"{project_id}.{dataset_id}.{table_name}"

        console.print(f"Saving to BigQuery table: [bold cyan]{table_id}[/bold cyan]")
        try:
            # Single-turn runs lack these columns; add them (as NULLs) so the
            # frame's schema matches multi-turn rows already in the table.
            if "conversation_id" not in df.columns:
                df["conversation_id"] = None
            if "turn" not in df.columns:
                df["turn"] = None

            # NOTE(review): DataFrame.to_gbq is deprecated in recent pandas in
            # favor of pandas-gbq's to_gbq — confirm the pinned versions.
            df.to_gbq(
                destination_table=f"{dataset_id}.{table_name}",
                project_id=project_id,
                if_exists="append",
            )
            console.print(
                f"Successfully saved {len(df)} rows to [bold green]{table_id}[/bold green]"
            )
        except Exception as e:
            console.print(
                f"[bold red]An error occurred while saving to BigQuery: {e}[/bold red]"
            )
            raise typer.Exit(code=1)

    console.print(f"[bold yellow]Finished run with ID: {run_id}[/bold yellow]")
    return run_id
|