Initial implementation
This commit is contained in:
298
src/va_evaluator/synthetic_question_generator.py
Normal file
298
src/va_evaluator/synthetic_question_generator.py
Normal file
@@ -0,0 +1,298 @@
|
||||
import datetime
|
||||
import os
|
||||
import random
|
||||
|
||||
import pandas as pd
|
||||
import typer
|
||||
import vertexai
|
||||
from google.cloud import storage
|
||||
from pydantic import BaseModel
|
||||
from rich.console import Console
|
||||
from rich.progress import track
|
||||
from vertexai.generative_models import GenerationConfig, GenerativeModel
|
||||
|
||||
from va_evaluator.config import Settings
|
||||
|
||||
|
||||
# --- Prompts & Schemas ---
|
||||
|
||||
PROMPT_TEMPLATE = """
|
||||
Eres un experto en generación de preguntas sintéticas. Tu tarea es crear preguntas sintéticas en español basadas en documentos de referencia proporcionados.
|
||||
|
||||
## INSTRUCCIONES:
|
||||
|
||||
### Requisitos obligatorios:
|
||||
1. **Idioma**: La pregunta DEBE estar completamente en español
|
||||
2. **Basada en documentos**: La pregunta DEBE poder responderse ÚNICAMENTE con la información contenida en los documentos proporcionados
|
||||
3. **Tipo de pregunta**: Sigue estrictamente la definición del tipo de pregunta especificado
|
||||
4. **Identificación de fuentes**: Incluye el ID de fuente de todos los documentos necesarios para responder la pregunta
|
||||
5. **Salida esperada**: Incluye la respuesta perfecta basada en los documentos necesarios para responder la pregunta
|
||||
|
||||
### Tono de pregunta:
|
||||
La pregunta debe ser similar a la que haría un usuario sin contexto sobre el sistema o la información disponible. Ingenuo y curioso.
|
||||
|
||||
### Tipo de pregunta solicitado:
|
||||
**Tipo**: {qtype}
|
||||
**Definición**: {qtype_def}
|
||||
|
||||
### Documentos de referencia:
|
||||
{context}
|
||||
|
||||
Por favor, genera una pregunta siguiendo estas instrucciones.
|
||||
""".strip()
|
||||
|
||||
MULTI_STEP_PROMPT_TEMPLATE = """
|
||||
Eres un experto en la generación de conversaciones sintéticas. Tu tarea es crear una conversación en español con múltiples turnos basada en los documentos de referencia proporcionados.
|
||||
|
||||
## INSTRUCCIONES:
|
||||
|
||||
### Requisitos obligatorios:
|
||||
1. **Idioma**: La conversación DEBE estar completamente en español.
|
||||
2. **Basada en documentos**: Todas las respuestas DEBEN poder responderse ÚNICAMENTE con la información contenida en los documentos de referencia.
|
||||
3. **Número de turnos**: La conversación debe tener exactamente {num_turns} turnos. Un turno consiste en una pregunta del usuario y una respuesta del asistente.
|
||||
4. **Flujo conversacional**: Las preguntas deben seguir un orden lógico, como si un usuario estuviera explorando un tema paso a paso. La segunda pregunta debe ser una continuación de la primera, y así sucesivamente.
|
||||
5. **Salida esperada**: Proporciona la respuesta perfecta para cada pregunta, basada en los documentos de referencia.
|
||||
|
||||
### Tono de las preguntas:
|
||||
Las preguntas deben ser similares a las que haría un usuario sin contexto sobre el sistema o la información disponible. Deben ser ingenuas y curiosas.
|
||||
|
||||
### Documentos de referencia:
|
||||
{context}
|
||||
|
||||
Por favor, genera una conversación de {num_turns} turnos siguiendo estas instrucciones.
|
||||
""".strip()
|
||||
|
||||
QUESTION_TYPE_MAP = {
|
||||
"Factual": "Questions targeting specific details within a reference (e.g., a company's profit in a report, a verdict in a legal case, or symptoms in a medical record) to test RAG's retrieval accuracy.",
|
||||
"Summarization": "Questions that require comprehensive answers, covering all relevant information, to mainly evaluate the recall rate of RAG retrieval.",
|
||||
"Multi-hop Reasoning": "Questions involve logical relationships among events and details within adocument, forming a reasoning chain to assess RAG's logical reasoning ability.",
|
||||
"Unanswerable": "Questions arise from potential information loss during the schema-to-article generation, where no corresponding information fragment exists, or the information is insufficient for an answer.",
|
||||
}
|
||||
|
||||
|
||||
class ResponseSchema(BaseModel):
|
||||
pregunta: str
|
||||
expected_output: str
|
||||
ids: list[str]
|
||||
|
||||
|
||||
class Turn(BaseModel):
|
||||
pregunta: str
|
||||
expected_output: str
|
||||
|
||||
|
||||
class MultiStepResponseSchema(BaseModel):
|
||||
conversation: list[Turn]
|
||||
|
||||
|
||||
# --- Core Logic ---
|
||||
|
||||
|
||||
def generate_structured(
|
||||
model: GenerativeModel,
|
||||
prompt: str,
|
||||
response_model: type[BaseModel],
|
||||
) -> BaseModel:
|
||||
generation_config = GenerationConfig(
|
||||
response_mime_type="application/json",
|
||||
response_schema=response_model,
|
||||
)
|
||||
response = model.generate_content(prompt, generation_config=generation_config)
|
||||
return response_model.model_validate_json(response.text)
|
||||
|
||||
|
||||
def generate_synthetic_question(
|
||||
model: GenerativeModel,
|
||||
file_content: str,
|
||||
file_path: str,
|
||||
q_type: str,
|
||||
q_def: str,
|
||||
) -> ResponseSchema:
|
||||
prompt = PROMPT_TEMPLATE.format(
|
||||
context=file_content, id=file_path, qtype=q_type, qtype_def=q_def
|
||||
)
|
||||
return generate_structured(model, prompt, ResponseSchema)
|
||||
|
||||
|
||||
def generate_synthetic_conversation(
|
||||
model: GenerativeModel,
|
||||
file_content: str,
|
||||
file_path: str,
|
||||
num_turns: int,
|
||||
) -> MultiStepResponseSchema:
|
||||
prompt = MULTI_STEP_PROMPT_TEMPLATE.format(
|
||||
context=file_content, num_turns=num_turns
|
||||
)
|
||||
return generate_structured(model, prompt, MultiStepResponseSchema)
|
||||
|
||||
|
||||
def generate(
|
||||
num_questions: int,
|
||||
output_csv: str | None = None,
|
||||
num_turns: int = 1,
|
||||
) -> str:
|
||||
console = Console()
|
||||
settings = Settings()
|
||||
|
||||
vertexai.init(project=settings.project_id, location=settings.location)
|
||||
model = GenerativeModel(settings.agent.language_model)
|
||||
gcs_client = storage.Client(project=settings.project_id)
|
||||
bucket = gcs_client.bucket(settings.require_bucket)
|
||||
|
||||
run_id = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d-%H%M%S")
|
||||
console.print(f"[bold yellow]Generated Run ID: {run_id}[/bold yellow]")
|
||||
|
||||
all_rows = []
|
||||
if not settings.index.name:
|
||||
console.print("[yellow]Skipping as no index is configured.[/yellow]")
|
||||
return ""
|
||||
|
||||
gcs_path = f"{settings.index.name}/contents/"
|
||||
console.print(f"[green]Fetching files from GCS path: {gcs_path}[/green]")
|
||||
|
||||
try:
|
||||
all_files = [
|
||||
blob.name
|
||||
for blob in bucket.list_blobs(prefix=gcs_path)
|
||||
if not blob.name.endswith("/")
|
||||
]
|
||||
console.print(f"Found {len(all_files)} total files to process.")
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]Error listing files: {e}[/bold red]")
|
||||
return ""
|
||||
|
||||
if not all_files:
|
||||
console.print("[yellow]No files found. Skipping.[/yellow]")
|
||||
return ""
|
||||
|
||||
files_to_process = random.sample(
|
||||
all_files, k=min(num_questions, len(all_files))
|
||||
)
|
||||
console.print(
|
||||
f"Randomly selected {len(files_to_process)} files to generate questions from."
|
||||
)
|
||||
|
||||
for file_path in track(files_to_process, description="Generating questions..."):
|
||||
try:
|
||||
blob = bucket.blob(file_path)
|
||||
file_content = blob.download_as_text(encoding="utf-8-sig")
|
||||
q_type, q_def = random.choice(list(QUESTION_TYPE_MAP.items()))
|
||||
|
||||
if num_turns > 1:
|
||||
conversation_data = None
|
||||
for attempt in range(3):
|
||||
conversation_data = generate_synthetic_conversation(
|
||||
model, file_content, file_path, num_turns
|
||||
)
|
||||
if (
|
||||
conversation_data
|
||||
and conversation_data.conversation
|
||||
and len(conversation_data.conversation) == num_turns
|
||||
):
|
||||
break
|
||||
console.print(
|
||||
f"[yellow]Failed to generate valid conversation for {os.path.basename(file_path)}. Retrying ({attempt + 1}/3)...[/yellow]"
|
||||
)
|
||||
conversation_data = None
|
||||
|
||||
if not conversation_data:
|
||||
console.print(
|
||||
f"[bold red]Failed to generate valid conversation for {os.path.basename(file_path)} after 3 attempts. Skipping.[/bold red]"
|
||||
)
|
||||
continue
|
||||
|
||||
conversation_id = str(random.randint(10000, 99999))
|
||||
for i, turn in enumerate(conversation_data.conversation):
|
||||
row = {
|
||||
"input": turn.pregunta,
|
||||
"expected_output": turn.expected_output,
|
||||
"source": os.path.splitext(os.path.basename(file_path))[0],
|
||||
"type": "Multi-turn",
|
||||
"agent": settings.agent.name,
|
||||
"run_id": run_id,
|
||||
"conversation_id": conversation_id,
|
||||
"turn": i + 1,
|
||||
}
|
||||
all_rows.append(row)
|
||||
|
||||
else:
|
||||
generated_data = None
|
||||
for attempt in range(3):
|
||||
generated_data = generate_synthetic_question(
|
||||
model, file_content, file_path, q_type, q_def
|
||||
)
|
||||
if (
|
||||
generated_data
|
||||
and generated_data.expected_output
|
||||
and generated_data.expected_output.strip()
|
||||
):
|
||||
break
|
||||
console.print(
|
||||
f"[yellow]Empty answer for {q_type} on {os.path.basename(file_path)}. Retrying ({attempt + 1}/3)...[/yellow]"
|
||||
)
|
||||
generated_data = None
|
||||
|
||||
if not generated_data:
|
||||
console.print(
|
||||
f"[bold red]Failed to generate valid answer for {q_type} on {os.path.basename(file_path)} after 3 attempts. Skipping.[/bold red]"
|
||||
)
|
||||
continue
|
||||
|
||||
row = {
|
||||
"input": generated_data.pregunta,
|
||||
"expected_output": generated_data.expected_output,
|
||||
"source": os.path.splitext(os.path.basename(file_path))[0],
|
||||
"type": q_type,
|
||||
"agent": settings.agent.name,
|
||||
"run_id": run_id,
|
||||
}
|
||||
all_rows.append(row)
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]Error processing file {file_path}: {e}[/bold red]")
|
||||
|
||||
if not all_rows:
|
||||
console.print("[bold yellow]No questions were generated.[/bold yellow]")
|
||||
return ""
|
||||
|
||||
df = pd.DataFrame(all_rows)
|
||||
|
||||
if output_csv:
|
||||
console.print(
|
||||
f"\n[bold green]Saving {len(df)} generated questions to {output_csv}...[/bold green]"
|
||||
)
|
||||
df.to_csv(output_csv, index=False, encoding="utf-8-sig")
|
||||
console.print("[bold green]Synthetic question generation complete.[/bold green]")
|
||||
else:
|
||||
console.print(
|
||||
f"\n[bold green]Saving {len(df)} generated questions to BigQuery...[/bold green]"
|
||||
)
|
||||
project_id = settings.bigquery.project_id or settings.project_id
|
||||
dataset_id = settings.bigquery.require_dataset_id
|
||||
table_name = settings.bigquery.synth_gen_table
|
||||
table_id = f"{project_id}.{dataset_id}.{table_name}"
|
||||
|
||||
console.print(f"Saving to BigQuery table: [bold cyan]{table_id}[/bold cyan]")
|
||||
try:
|
||||
if "conversation_id" not in df.columns:
|
||||
df["conversation_id"] = None
|
||||
if "turn" not in df.columns:
|
||||
df["turn"] = None
|
||||
|
||||
df.to_gbq(
|
||||
destination_table=f"{dataset_id}.{table_name}",
|
||||
project_id=project_id,
|
||||
if_exists="append",
|
||||
)
|
||||
console.print(
|
||||
f"Successfully saved {len(df)} rows to [bold green]{table_id}[/bold green]"
|
||||
)
|
||||
except Exception as e:
|
||||
console.print(
|
||||
f"[bold red]An error occurred while saving to BigQuery: {e}[/bold red]"
|
||||
)
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
console.print(f"[bold yellow]Finished run with ID: {run_id}[/bold yellow]")
|
||||
return run_id
|
||||
|
||||
Reference in New Issue
Block a user