Initial implementation
This commit is contained in:
218
.gitignore
vendored
Normal file
218
.gitignore
vendored
Normal file
@@ -0,0 +1,218 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[codz]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
# Pipfile.lock
|
||||
|
||||
# UV
|
||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# uv.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
# poetry.lock
|
||||
# poetry.toml
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
||||
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
||||
# pdm.lock
|
||||
# pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# pixi
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
||||
# pixi.lock
|
||||
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
||||
# in the .venv directory. It is recommended not to include this directory in version control.
|
||||
.pixi
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# Redis
|
||||
*.rdb
|
||||
*.aof
|
||||
*.pid
|
||||
|
||||
# RabbitMQ
|
||||
mnesia/
|
||||
rabbitmq/
|
||||
rabbitmq-data/
|
||||
|
||||
# ActiveMQ
|
||||
activemq-data/
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.envrc
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
# .idea/
|
||||
|
||||
# Abstra
|
||||
# Abstra is an AI-powered process automation framework.
|
||||
# Ignore directories containing user credentials, local state, and settings.
|
||||
# Learn more at https://abstra.io/docs
|
||||
.abstra/
|
||||
|
||||
# Visual Studio Code
|
||||
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
||||
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
||||
# you could uncomment the following to ignore the entire vscode folder
|
||||
# .vscode/
|
||||
|
||||
# Ruff stuff:
|
||||
.ruff_cache/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
# Marimo
|
||||
marimo/_static/
|
||||
marimo/_lsp/
|
||||
__marimo__/
|
||||
|
||||
# Streamlit
|
||||
.streamlit/secrets.toml
|
||||
|
||||
ref/
|
||||
1
AGENTS.md
Normal file
1
AGENTS.md
Normal file
@@ -0,0 +1 @@
|
||||
Use `uv` for project management and dependency resolution.
|
||||
26
pyproject.toml
Normal file
26
pyproject.toml
Normal file
@@ -0,0 +1,26 @@
|
||||
[project]
|
||||
name = "va-evaluator"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"google-cloud-aiplatform",
|
||||
"google-cloud-bigquery>=3.40.1",
|
||||
"google-cloud-storage",
|
||||
"google-genai>=1.64.0",
|
||||
"pandas",
|
||||
"pandas-gbq",
|
||||
"pydantic",
|
||||
"pydantic-settings",
|
||||
"ranx>=0.3.21",
|
||||
"rich",
|
||||
"typer",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
va-evaluator = "va_evaluator.cli:app"
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.8.3,<0.9.0"]
|
||||
build-backend = "uv_build"
|
||||
0
src/va_evaluator/__init__.py
Normal file
0
src/va_evaluator/__init__.py
Normal file
96
src/va_evaluator/cli.py
Normal file
96
src/va_evaluator/cli.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
|
||||
app = typer.Typer(name="va-evaluator", help="VA Evaluator toolkit.")
|
||||
|
||||
|
||||
@app.command()
|
||||
def generate(
|
||||
num_questions: Annotated[
|
||||
int,
|
||||
typer.Option("--num-questions", "-n", help="Number of questions to generate."),
|
||||
] = 10,
|
||||
output_csv: Annotated[
|
||||
str | None,
|
||||
typer.Option(
|
||||
"--output-csv", "-o", help="Optional: Path to save the output CSV file."
|
||||
),
|
||||
] = None,
|
||||
num_turns: Annotated[
|
||||
int,
|
||||
typer.Option(
|
||||
"--num-turns", "-t", help="Number of conversational turns to generate."
|
||||
),
|
||||
] = 1,
|
||||
):
|
||||
"""Generates synthetic questions and saves them to BigQuery (default) or a local CSV file."""
|
||||
from va_evaluator.synthetic_question_generator import generate as _generate
|
||||
|
||||
_generate(num_questions=num_questions, output_csv=output_csv, num_turns=num_turns)
|
||||
|
||||
|
||||
@app.command()
|
||||
def eval_search(
|
||||
input_file: Annotated[
|
||||
str | None,
|
||||
typer.Option(
|
||||
"-i",
|
||||
"--input-file",
|
||||
help="Path to a local CSV or SQLite file for evaluation data. "
|
||||
"If not provided, data will be loaded from BigQuery.",
|
||||
),
|
||||
] = None,
|
||||
output_file: Annotated[
|
||||
str | None,
|
||||
typer.Option(
|
||||
"-o",
|
||||
"--output-file",
|
||||
help="Path to save the detailed results as a CSV file. "
|
||||
"If not provided, results will be saved to BigQuery.",
|
||||
),
|
||||
] = None,
|
||||
run_id: Annotated[
|
||||
str | None,
|
||||
typer.Option(
|
||||
help="Optional: The specific run_id to filter the evaluation data by."
|
||||
),
|
||||
] = None,
|
||||
):
|
||||
"""Evaluates the search metrics by loading data from BigQuery or a local file."""
|
||||
from va_evaluator.search_metrics_evaluator import evaluate
|
||||
|
||||
evaluate(input_file=input_file, output_file=output_file, run_id=run_id)
|
||||
|
||||
|
||||
@app.command()
|
||||
def eval_keypoint(
|
||||
input_file: Annotated[
|
||||
str | None,
|
||||
typer.Option(
|
||||
"-i",
|
||||
"--input-file",
|
||||
help="Path to a local CSV or SQLite file for evaluation data. "
|
||||
"If not provided, data will be loaded from BigQuery.",
|
||||
),
|
||||
] = None,
|
||||
output_file: Annotated[
|
||||
str | None,
|
||||
typer.Option(
|
||||
"-o",
|
||||
"--output-file",
|
||||
help="Path to save the detailed results as a CSV file. "
|
||||
"If not provided, results will be saved to BigQuery.",
|
||||
),
|
||||
] = None,
|
||||
run_id: Annotated[
|
||||
str | None,
|
||||
typer.Option(
|
||||
help="Optional: The specific run_id to filter the evaluation data by."
|
||||
),
|
||||
] = None,
|
||||
):
|
||||
"""Evaluates RAG responses using the keypoint methodology."""
|
||||
from va_evaluator.keypoint_metrics_evaluator import evaluate
|
||||
|
||||
evaluate(input_file=input_file, output_file=output_file, run_id=run_id)
|
||||
57
src/va_evaluator/config.py
Normal file
57
src/va_evaluator/config.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from pydantic import BaseModel
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class IndexSettings(BaseModel):
|
||||
name: str | None = None
|
||||
endpoint: str | None = None
|
||||
deployment: str | None = None
|
||||
|
||||
@property
|
||||
def require_endpoint(self) -> str:
|
||||
if not self.endpoint:
|
||||
raise ValueError("INDEX__ENDPOINT environment variable must be set")
|
||||
return self.endpoint
|
||||
|
||||
@property
|
||||
def require_deployment(self) -> str:
|
||||
if not self.deployment:
|
||||
raise ValueError("INDEX__DEPLOYMENT environment variable must be set")
|
||||
return self.deployment
|
||||
|
||||
|
||||
class AgentSettings(BaseModel):
|
||||
name: str = "default"
|
||||
language_model: str = "gemini-2.0-flash"
|
||||
embedding_model: str = "text-embedding-004"
|
||||
|
||||
|
||||
class BigQuerySettings(BaseModel):
|
||||
project_id: str | None = None
|
||||
dataset_id: str | None = None
|
||||
synth_gen_table: str = "synthetic_questions"
|
||||
search_eval_table: str = "search_eval_results"
|
||||
keypoint_eval_table: str = "keypoint_eval_results"
|
||||
|
||||
@property
|
||||
def require_dataset_id(self) -> str:
|
||||
if not self.dataset_id:
|
||||
raise ValueError("BIGQUERY__DATASET_ID environment variable must be set")
|
||||
return self.dataset_id
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_nested_delimiter="__")
|
||||
|
||||
project_id: str
|
||||
location: str = "us-central1"
|
||||
bucket: str | None = None
|
||||
index: IndexSettings = IndexSettings()
|
||||
agent: AgentSettings = AgentSettings()
|
||||
bigquery: BigQuerySettings = BigQuerySettings()
|
||||
|
||||
@property
|
||||
def require_bucket(self) -> str:
|
||||
if not self.bucket:
|
||||
raise ValueError("BUCKET environment variable must be set")
|
||||
return self.bucket
|
||||
558
src/va_evaluator/keypoint_metrics_evaluator.py
Normal file
558
src/va_evaluator/keypoint_metrics_evaluator.py
Normal file
@@ -0,0 +1,558 @@
|
||||
import json
|
||||
import pathlib
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
import pandas as pd
|
||||
import typer
|
||||
from google import genai
|
||||
from google.cloud import bigquery
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel, Field
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn
|
||||
from rich.table import Table
|
||||
|
||||
from va_evaluator.config import Settings
|
||||
|
||||
|
||||
# --- Schemas ---
|
||||
|
||||
|
||||
class KeyPointResponse(BaseModel):
|
||||
keypoints: list[str]
|
||||
|
||||
|
||||
class KeyPointEval(BaseModel):
|
||||
keypoint: str
|
||||
analysis: str
|
||||
category: Literal["relevant", "irrelevant", "incorrect"]
|
||||
|
||||
|
||||
class KeyPointEvalList(BaseModel):
|
||||
evals: list[KeyPointEval]
|
||||
|
||||
def _count(self, category: str) -> int:
|
||||
return sum(1 for e in self.evals if e.category == category)
|
||||
|
||||
def count_relevant(self) -> int:
|
||||
return self._count("relevant")
|
||||
|
||||
def count_irrelevant(self) -> int:
|
||||
return self._count("irrelevant")
|
||||
|
||||
def count_incorrect(self) -> int:
|
||||
return self._count("incorrect")
|
||||
|
||||
def keypoint_details(self) -> list[dict]:
|
||||
return [e.model_dump() for e in self.evals]
|
||||
|
||||
|
||||
class ConcisenessScore(BaseModel):
|
||||
score: float = Field(
|
||||
description="A score from 0.0 to 1.0 evaluating the conciseness of the answer."
|
||||
)
|
||||
|
||||
|
||||
# --- Prompts ---
|
||||
|
||||
EXTRACT_KEYPOINTS_PROMPT = """En esta tarea, se te dará una pregunta y una respuesta ideal. Basado en la respuesta ideal, \
|
||||
necesitas resumir los puntos clave necesarios para responder la pregunta.
|
||||
|
||||
<ejemplo>
|
||||
<pregunta>
|
||||
Cómo puedo sacar un adelanto de nómina?
|
||||
</pregunta>
|
||||
<respuesta>
|
||||
¡Hola! 👋 Sacar un Adelanto de Nómina con Banorte es muy fácil y \
|
||||
puede ayudarte con liquidez al instante. Aquí te explico cómo \
|
||||
funciona:
|
||||
|
||||
Es un monto de hasta $10,000 MXN que puedes usar para lo que \
|
||||
necesites, sin intereses y con una comisión fija del 7%. Lo puedes \
|
||||
contratar directamente desde la aplicación móvil de Banorte. Los \
|
||||
pagos se ajustan a la frecuencia de tu nómina y se cargan \
|
||||
automáticamente a tu cuenta.
|
||||
|
||||
Los principales requisitos son:
|
||||
* Recibir tu nómina en Banorte y no tener otro adelanto vigente.
|
||||
* Tener un ingreso neto mensual mayor a $2,000 MXN.
|
||||
* Tener entre 18 y 74 años con 11 meses.
|
||||
* Contar con un buen historial en Buró de Crédito.
|
||||
|
||||
¡Espero que esta información te sea muy útil! 😊
|
||||
</respuesta>
|
||||
<puntos clave>
|
||||
[
|
||||
"Recibir tu nómina en Banorte",
|
||||
"No tener otro adelanto vigente",
|
||||
"Tener entre 18 y 74 años con 11 meses",
|
||||
"Contar con buen historial en Buró de Crédito",
|
||||
]
|
||||
</puntos clave>
|
||||
</ejemplo>
|
||||
|
||||
<real>
|
||||
<pregunta>
|
||||
{question}
|
||||
</pregunta>
|
||||
<respuesta>
|
||||
{ground_truth}
|
||||
</respuesta>
|
||||
</real>
|
||||
"""
|
||||
|
||||
EVALUATE_KEYPOINTS_PROMPT = """En esta tarea, recibirás una respuesta real y múltiples puntos clave \
|
||||
extraídos de una respuesta ideal. Tu objetivo es evaluar la calidad y concisión de la respuesta generada.
|
||||
|
||||
Para cada punto clave, proporciona un breve análisis y concluye con una de las siguientes clasificaciones:
|
||||
|
||||
[[[ Relevante ]]] - La respuesta generada aborda el punto clave de manera precisa, correcta y directa. \
|
||||
La información es fácil de encontrar y no está oculta por un exceso de texto innecesario o "fluff" \
|
||||
conversacional (saludos, despedidas, jerga, etc.).
|
||||
|
||||
[[[ Irrelevante ]]] - La respuesta generada omite por completo el punto clave o no contiene ninguna \
|
||||
información relacionada con él. También se considera Irrelevante si la información del punto clave está \
|
||||
presente, pero tan oculta por el "fluff" que un usuario tendría dificultades para encontrarla.
|
||||
|
||||
[[[ Incorrecto ]]] - La respuesta generada contiene información relacionada con el punto clave pero es \
|
||||
incorrecta, contradice el punto clave, o podría confundir o desinformar al usuario.
|
||||
|
||||
**Criterio de Evaluación:**
|
||||
Sé estricto con el "fluff". Una respuesta ideal es tanto correcta como concisa. El exceso de texto \
|
||||
conversacional que no aporta valor a la respuesta debe penalizarse. Si la información clave está presente \
|
||||
pero la respuesta es innecesariamente larga y verbosa, considera rebajar su clasificación de Relevante \
|
||||
a Irrelevante.
|
||||
|
||||
Respuesta Generada: {generated_answer}
|
||||
|
||||
Puntos Clave de la Respuesta ideal:
|
||||
{keypoints_list}
|
||||
"""
|
||||
|
||||
CONCISENESS_PROMPT = """Evaluate the conciseness of the following generated answer in response to the user's query.
|
||||
The score should be a single float from 0.0 to 1.0, where 1.0 is perfectly concise and direct, \
|
||||
and 0.0 is extremely verbose and full of conversational fluff.
|
||||
Only consider the conciseness, not the correctness of the answer.
|
||||
|
||||
User Query: {query}
|
||||
Generated Answer: {answer}
|
||||
"""
|
||||
|
||||
|
||||
# --- Evaluator ---
|
||||
|
||||
|
||||
class KeypointRAGEvaluator:
|
||||
def __init__(
|
||||
self, client: genai.Client, console: Console, model: str = "gemini-2.0-flash"
|
||||
):
|
||||
self.client = client
|
||||
self.console = console
|
||||
self.model = model
|
||||
|
||||
def _structured_generation(
|
||||
self,
|
||||
prompt: str,
|
||||
response_model: type[BaseModel],
|
||||
system_prompt: str,
|
||||
) -> BaseModel:
|
||||
response = self.client.models.generate_content(
|
||||
model=self.model,
|
||||
contents=prompt,
|
||||
config=types.GenerateContentConfig(
|
||||
response_mime_type="application/json",
|
||||
response_schema=response_model,
|
||||
system_instruction=system_prompt,
|
||||
),
|
||||
)
|
||||
return response_model.model_validate_json(response.text)
|
||||
|
||||
def extract_keypoints(self, question: str, ground_truth: str) -> list[str]:
|
||||
prompt = EXTRACT_KEYPOINTS_PROMPT.format(
|
||||
question=question, ground_truth=ground_truth
|
||||
)
|
||||
response = self._structured_generation(
|
||||
prompt=prompt,
|
||||
response_model=KeyPointResponse,
|
||||
system_prompt="Eres un asistente experto en extraer puntos clave informativos de respuestas.",
|
||||
)
|
||||
return response.keypoints
|
||||
|
||||
def evaluate_keypoints(
|
||||
self,
|
||||
generated_answer: str,
|
||||
keypoints: list[str],
|
||||
) -> tuple[dict[str, float], list[dict]]:
|
||||
keypoints_list = "\n".join(
|
||||
[f"{i + 1}. {kp}" for i, kp in enumerate(keypoints)]
|
||||
)
|
||||
prompt = EVALUATE_KEYPOINTS_PROMPT.format(
|
||||
generated_answer=generated_answer,
|
||||
keypoints_list=keypoints_list,
|
||||
)
|
||||
response = self._structured_generation(
|
||||
prompt=prompt,
|
||||
response_model=KeyPointEvalList,
|
||||
system_prompt=(
|
||||
"Eres un evaluador experto de respuestas basadas en puntos clave, "
|
||||
"capaz de detectar si la información es relevante, irrelevante o incorrecta. "
|
||||
"Adoptas una postura favorable cuando evalúas la utilidad de las respuestas "
|
||||
"para los usuarios."
|
||||
),
|
||||
)
|
||||
|
||||
total_keypoints = len(keypoints)
|
||||
metrics = {
|
||||
"completeness": (
|
||||
response.count_relevant() / total_keypoints
|
||||
if total_keypoints > 0
|
||||
else 0
|
||||
),
|
||||
"hallucination": (
|
||||
response.count_incorrect() / total_keypoints
|
||||
if total_keypoints > 0
|
||||
else 0
|
||||
),
|
||||
"irrelevance": (
|
||||
response.count_irrelevant() / total_keypoints
|
||||
if total_keypoints > 0
|
||||
else 0
|
||||
),
|
||||
}
|
||||
return metrics, response.keypoint_details()
|
||||
|
||||
def evaluate_conciseness(self, query: str, answer: str) -> float:
|
||||
prompt = CONCISENESS_PROMPT.format(query=query, answer=answer)
|
||||
try:
|
||||
response = self._structured_generation(
|
||||
prompt=prompt,
|
||||
response_model=ConcisenessScore,
|
||||
system_prompt=(
|
||||
"You are an expert evaluator focused on the conciseness and "
|
||||
"directness of answers. You output a single float score and nothing else."
|
||||
),
|
||||
)
|
||||
return response.score
|
||||
except Exception as e:
|
||||
self.console.print(
|
||||
f"[bold red]Error during conciseness evaluation: {str(e)}[/bold red]"
|
||||
)
|
||||
return 0.0
|
||||
|
||||
def evaluate_response(
|
||||
self,
|
||||
query: str,
|
||||
response: str,
|
||||
ground_truth: str,
|
||||
) -> dict:
|
||||
keypoints = self.extract_keypoints(query, ground_truth)
|
||||
metrics, keypoint_details = self.evaluate_keypoints(response, keypoints)
|
||||
conciseness = self.evaluate_conciseness(query, response)
|
||||
|
||||
return {
|
||||
"query": query,
|
||||
"response": response,
|
||||
"ground_truth": ground_truth,
|
||||
"completeness": metrics["completeness"],
|
||||
"hallucination": metrics["hallucination"],
|
||||
"irrelevance": metrics["irrelevance"],
|
||||
"conciseness": conciseness,
|
||||
"keypoints": keypoints,
|
||||
"keypoint_details": keypoint_details,
|
||||
"timestamp": datetime.now(),
|
||||
}
|
||||
|
||||
|
||||
# --- Data Loading ---
|
||||
|
||||
|
||||
def load_data_from_local_file(
|
||||
file_path: str, console: Console, run_id: str | None = None
|
||||
) -> pd.DataFrame:
|
||||
"""Loads evaluation data from a local CSV or SQLite file."""
|
||||
console.print(f"[bold green]Loading data from {file_path}...[/bold green]")
|
||||
path = pathlib.Path(file_path)
|
||||
if not path.exists():
|
||||
console.print(f"[bold red]Error: File not found at {file_path}[/bold red]")
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
if path.suffix == ".csv":
|
||||
df = pd.read_csv(path)
|
||||
elif path.suffix in [".db", ".sqlite"]:
|
||||
con = sqlite3.connect(path)
|
||||
df = pd.read_sql(f"SELECT * FROM {path.stem}", con)
|
||||
con.close()
|
||||
else:
|
||||
console.print(
|
||||
f"[bold red]Unsupported file type: {path.suffix}. Please use .csv or .db/.sqlite[/bold red]"
|
||||
)
|
||||
raise ValueError(f"Unsupported file type: {path.suffix}")
|
||||
|
||||
required_cols = {"input", "expected_output", "response"}
|
||||
if not required_cols.issubset(df.columns):
|
||||
missing = required_cols - set(df.columns)
|
||||
console.print(
|
||||
f"[bold red]Error: Missing required columns: {missing}[/bold red]"
|
||||
)
|
||||
raise ValueError(f"Missing required columns: {missing}")
|
||||
|
||||
if run_id:
|
||||
if "run_id" in df.columns:
|
||||
df = df[df["run_id"] == run_id].copy()
|
||||
console.print(f"Filtered data for run_id: [bold cyan]{run_id}[/bold cyan]")
|
||||
if df.empty:
|
||||
console.print(
|
||||
f"[bold yellow]Warning: No data found for run_id '{run_id}' in {file_path}.[/bold yellow]"
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
f"[bold yellow]Warning: --run-id provided, but 'run_id' column not found in {file_path}. Using all data.[/bold yellow]"
|
||||
)
|
||||
|
||||
if "type" in df.columns:
|
||||
df = df[df["type"] != "Unanswerable"].copy()
|
||||
|
||||
df.dropna(subset=["input", "expected_output", "response"], inplace=True)
|
||||
console.print(f"Loaded {len(df)} rows for evaluation.")
|
||||
return df
|
||||
|
||||
|
||||
def load_data_from_bigquery(
|
||||
settings: Settings, console: Console, run_id: str | None = None
|
||||
) -> pd.DataFrame:
|
||||
"""Loads evaluation data from the BigQuery table."""
|
||||
console.print("[bold green]Loading data from BigQuery...[/bold green]")
|
||||
bq_project_id = settings.bigquery.project_id or settings.project_id
|
||||
client = bigquery.Client(project=bq_project_id)
|
||||
table_ref = f"{bq_project_id}.{settings.bigquery.require_dataset_id}.{settings.bigquery.synth_gen_table}"
|
||||
|
||||
console.print(f"Querying table: [bold cyan]{table_ref}[/bold cyan]")
|
||||
query = f"""
|
||||
SELECT *
|
||||
FROM `{table_ref}`
|
||||
WHERE `type` != 'Unanswerable'
|
||||
"""
|
||||
if run_id:
|
||||
console.print(f"Filtering for run_id: [bold cyan]{run_id}[/bold cyan]")
|
||||
query += f" AND run_id = '{run_id}'"
|
||||
|
||||
df = client.query(query).to_dataframe()
|
||||
df.dropna(subset=["input", "expected_output", "response"], inplace=True)
|
||||
console.print(f"Loaded {len(df)} rows for evaluation.")
|
||||
if df.empty:
|
||||
console.print(
|
||||
"[bold yellow]Warning: No data found in BigQuery.[/bold yellow]"
|
||||
)
|
||||
return df
|
||||
|
||||
|
||||
# --- Core Logic ---
|
||||
|
||||
def evaluate(
|
||||
input_file: str | None = None,
|
||||
output_file: str | None = None,
|
||||
run_id: str | None = None,
|
||||
):
|
||||
"""Core logic for running keypoint-based evaluation."""
|
||||
console = Console()
|
||||
settings = Settings()
|
||||
|
||||
if input_file:
|
||||
df = load_data_from_local_file(input_file, console, run_id)
|
||||
else:
|
||||
df = load_data_from_bigquery(settings, console, run_id)
|
||||
|
||||
if df.empty:
|
||||
console.print("[bold red]No data to evaluate.[/bold red]")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
console.print(
|
||||
f"[bold blue]Running keypoint evaluation for agent: {settings.agent.name}[/bold blue]"
|
||||
)
|
||||
|
||||
client = genai.Client(
|
||||
vertexai=True,
|
||||
project=settings.project_id,
|
||||
location=settings.location,
|
||||
)
|
||||
evaluator = KeypointRAGEvaluator(
|
||||
client=client,
|
||||
console=console,
|
||||
model=settings.agent.language_model,
|
||||
)
|
||||
|
||||
is_conversational = "conversation_id" in df.columns and "turn" in df.columns
|
||||
if is_conversational:
|
||||
df.sort_values(by=["conversation_id", "turn"], inplace=True)
|
||||
|
||||
all_results = []
|
||||
total_skipped = 0
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||
console=console,
|
||||
) as progress:
|
||||
task = progress.add_task(
|
||||
"[green]Processing evaluations...[/green]",
|
||||
total=len(df),
|
||||
)
|
||||
|
||||
for _, row in df.iterrows():
|
||||
try:
|
||||
result = evaluator.evaluate_response(
|
||||
query=row["input"],
|
||||
response=row["response"],
|
||||
ground_truth=row["expected_output"],
|
||||
)
|
||||
result["agent"] = settings.agent.name
|
||||
if is_conversational:
|
||||
result["conversation_id"] = row.get("conversation_id")
|
||||
result["turn"] = row.get("turn")
|
||||
|
||||
all_results.append(result)
|
||||
|
||||
except Exception as e:
|
||||
if "Token limit exceeded" in str(e):
|
||||
total_skipped += 1
|
||||
console.print(
|
||||
Panel(
|
||||
f"[bold]Query:[/bold]\n[white]{row['input']}[/white]",
|
||||
title="[yellow]Skipping Question (Token Limit Exceeded)[/yellow]",
|
||||
expand=False,
|
||||
border_style="yellow",
|
||||
)
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
f"[bold red]Error evaluating row: {e}[/bold red]"
|
||||
)
|
||||
all_results.append(
|
||||
{
|
||||
"query": row["input"],
|
||||
"response": row["response"],
|
||||
"ground_truth": row["expected_output"],
|
||||
"completeness": 0.0,
|
||||
"hallucination": 0.0,
|
||||
"irrelevance": 0.0,
|
||||
"conciseness": 0.0,
|
||||
"keypoints": [],
|
||||
"keypoint_details": [],
|
||||
"timestamp": datetime.now(),
|
||||
"agent": settings.agent.name,
|
||||
"error": str(e),
|
||||
}
|
||||
)
|
||||
finally:
|
||||
progress.advance(task)
|
||||
|
||||
if not all_results:
|
||||
console.print("[bold red]No evaluation results were generated.[/bold red]")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
final_df = pd.DataFrame(all_results)
|
||||
|
||||
# --- Summary Table ---
|
||||
summary_df = (
|
||||
final_df.groupby("agent")[
|
||||
["completeness", "hallucination", "irrelevance", "conciseness"]
|
||||
]
|
||||
.mean()
|
||||
.reset_index()
|
||||
)
|
||||
|
||||
table = Table(
|
||||
title="[bold green]Keypoint Evaluation Summary[/bold green]",
|
||||
show_header=True,
|
||||
header_style="bold magenta",
|
||||
)
|
||||
table.add_column("Agent", justify="left", style="cyan", no_wrap=True)
|
||||
table.add_column("Completeness", justify="right", style="magenta")
|
||||
table.add_column("Hallucination", justify="right", style="green")
|
||||
table.add_column("Irrelevance", justify="right", style="yellow")
|
||||
table.add_column("Conciseness", justify="right", style="cyan")
|
||||
|
||||
for _, row in summary_df.iterrows():
|
||||
table.add_row(
|
||||
row["agent"],
|
||||
f"{row['completeness']:.4f}",
|
||||
f"{row['hallucination']:.4f}",
|
||||
f"{row['irrelevance']:.4f}",
|
||||
f"{row['conciseness']:.4f}",
|
||||
)
|
||||
console.print(table)
|
||||
|
||||
if total_skipped > 0:
|
||||
console.print(
|
||||
Panel(
|
||||
f"[bold yellow]Total questions skipped due to token limit: {total_skipped}[/bold yellow]",
|
||||
title="[bold]Skipped Questions[/bold]",
|
||||
expand=False,
|
||||
border_style="yellow",
|
||||
)
|
||||
)
|
||||
|
||||
if "timestamp" in final_df.columns:
|
||||
final_df["timestamp"] = pd.to_datetime(final_df["timestamp"]).dt.tz_localize(
|
||||
None
|
||||
)
|
||||
|
||||
if output_file:
|
||||
for col in ["keypoints", "keypoint_details"]:
|
||||
if col in final_df.columns:
|
||||
final_df[col] = final_df[col].apply(json.dumps)
|
||||
|
||||
console.print(
|
||||
f"Saving detailed results to CSV file: [bold cyan]{output_file}[/bold cyan]"
|
||||
)
|
||||
final_df.to_csv(output_file, index=False, encoding="utf-8-sig")
|
||||
console.print(
|
||||
f"Successfully saved {len(final_df)} rows to [bold green]{output_file}[/bold green]"
|
||||
)
|
||||
else:
|
||||
project_id = settings.bigquery.project_id or settings.project_id
|
||||
dataset_id = settings.bigquery.require_dataset_id
|
||||
table_name = settings.bigquery.keypoint_eval_table
|
||||
table_id = f"{project_id}.{dataset_id}.{table_name}"
|
||||
|
||||
console.print(
|
||||
f"Saving detailed results to BigQuery table: [bold cyan]{table_id}[/bold cyan]"
|
||||
)
|
||||
|
||||
if "run_id" not in final_df.columns:
|
||||
final_df["run_id"] = run_id
|
||||
if "error" not in final_df.columns:
|
||||
final_df["error"] = ""
|
||||
|
||||
final_df["completeness"] = final_df["completeness"].fillna(0.0)
|
||||
final_df["hallucination"] = final_df["hallucination"].fillna(0.0)
|
||||
final_df["irrelevance"] = final_df["irrelevance"].fillna(0.0)
|
||||
final_df["error"] = final_df["error"].fillna("")
|
||||
|
||||
for col_name in ["keypoints", "keypoint_details"]:
|
||||
if col_name in final_df.columns:
|
||||
final_df[col_name] = [
|
||||
item if isinstance(item, list) else []
|
||||
for item in final_df[col_name]
|
||||
]
|
||||
|
||||
try:
|
||||
final_df.to_gbq(
|
||||
destination_table=f"{dataset_id}.{table_name}",
|
||||
project_id=project_id,
|
||||
if_exists="append",
|
||||
)
|
||||
console.print(
|
||||
f"Successfully saved {len(final_df)} rows to [bold green]{table_id}[/bold green]"
|
||||
)
|
||||
except Exception as e:
|
||||
console.print(
|
||||
f"[bold red]An error occurred while saving to BigQuery: {e}[/bold red]"
|
||||
)
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
256
src/va_evaluator/search_metrics_evaluator.py
Normal file
256
src/va_evaluator/search_metrics_evaluator.py
Normal file
@@ -0,0 +1,256 @@
|
||||
import pathlib
|
||||
import sqlite3
|
||||
|
||||
import pandas as pd
|
||||
import typer
|
||||
import vertexai
|
||||
from google.cloud import bigquery
|
||||
from google.cloud.aiplatform.matching_engine import MatchingEngineIndexEndpoint
|
||||
from ranx import Qrels, Run
|
||||
from ranx import evaluate as ranx_evaluate
|
||||
from rich.console import Console
|
||||
from rich.progress import track
|
||||
from rich.table import Table
|
||||
from vertexai.language_models import TextEmbeddingModel
|
||||
|
||||
from va_evaluator.config import Settings
|
||||
|
||||
|
||||
# --- Core Logic ---
|
||||
|
||||
|
||||
def load_data_from_local_file(
    file_path: str, console: Console, run_id: str | None = None
) -> pd.DataFrame:
    """Load evaluation data from a local CSV or SQLite file.

    Args:
        file_path: Path to a ``.csv`` or ``.db``/``.sqlite`` file that
            contains at least the columns ``input`` and ``source``.
        console: Rich console used for status output.
        run_id: Optional run identifier. When given and the file has a
            ``run_id`` column, only matching rows are kept; otherwise a
            warning is printed and all rows are used.

    Returns:
        DataFrame with ``question``, ``document_path`` and string ``id``
        columns. Rows containing NaN values are dropped.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If the file type is unsupported or the required
            columns are missing.
    """
    console.print(f"[bold green]Loading data from {file_path}...[/bold green]")
    path = pathlib.Path(file_path)
    if not path.exists():
        console.print(f"[bold red]Error: File not found at {file_path}[/bold red]")
        raise FileNotFoundError(f"File not found: {file_path}")

    if path.suffix == ".csv":
        df = pd.read_csv(path)
    elif path.suffix in [".db", ".sqlite"]:
        # Close the connection even if the query fails (the original leaked
        # it on error because close() was only reached on success).
        con = sqlite3.connect(path)
        try:
            df = pd.read_sql("SELECT * FROM evaluation_data", con)
        finally:
            con.close()
    else:
        console.print(
            f"[bold red]Unsupported file type: {path.suffix}. Please use .csv or .db/.sqlite[/bold red]"
        )
        raise ValueError(f"Unsupported file type: {path.suffix}")

    if "input" in df.columns and "source" in df.columns:
        df = df.rename(columns={"input": "question", "source": "document_path"})
        # Stable 1-based string ids; these become the qrels/run keys later.
        df["id"] = (df.index + 1).astype(str)
    else:
        console.print(
            "[bold red]Error: The input file must contain 'input' and 'source' columns.[/bold red]"
        )
        raise ValueError("Input file must contain 'input' and 'source' columns.")

    if run_id:
        if "run_id" in df.columns:
            df = df[df["run_id"] == run_id].copy()
            console.print(f"Filtered data for run_id: [bold cyan]{run_id}[/bold cyan]")
            if df.empty:
                console.print(
                    f"[bold yellow]Warning: No data found for run_id '{run_id}' in {file_path}.[/bold yellow]"
                )
        else:
            console.print(
                f"[bold yellow]Warning: --run-id provided, but 'run_id' column not found in {file_path}. Using all data.[/bold yellow]"
            )

    df.dropna(inplace=True)
    console.print(f"Loaded {len(df)} questions for evaluation.")
    return df
|
||||
|
||||
|
||||
def load_data_from_bigquery(
    settings: Settings, console: Console, run_id: str | None = None
) -> pd.DataFrame:
    """Load evaluation data from the configured BigQuery table.

    Args:
        settings: Application settings providing the BigQuery project,
            dataset and synthetic-generation table names.
        console: Rich console used for status output.
        run_id: Optional run identifier used to filter rows.

    Returns:
        DataFrame with ``question``, ``document_path`` and ``id`` columns;
        rows containing NaN values are dropped. 'Unanswerable' questions
        are excluded by the query.
    """
    console.print("[bold green]Loading data from BigQuery...[/bold green]")
    bq_project_id = settings.bigquery.project_id or settings.project_id
    client = bigquery.Client(project=bq_project_id)
    table_ref = f"{bq_project_id}.{settings.bigquery.require_dataset_id}.{settings.bigquery.synth_gen_table}"

    console.print(f"Querying table: [bold cyan]{table_ref}[/bold cyan]")
    query = f"""
        SELECT
            input AS question,
            source AS document_path,
            ROW_NUMBER() OVER() as id
        FROM
            `{table_ref}`
        WHERE
            `type` != 'Unanswerable'
    """
    job_config = None
    if run_id:
        console.print(f"Filtering for run_id: [bold cyan]{run_id}[/bold cyan]")
        # Bind run_id as a query parameter instead of interpolating the
        # user-supplied value into the SQL string (SQL-injection safe).
        query += " AND run_id = @run_id"
        job_config = bigquery.QueryJobConfig(
            query_parameters=[
                bigquery.ScalarQueryParameter("run_id", "STRING", run_id)
            ]
        )

    df = client.query(query, job_config=job_config).to_dataframe()
    df.dropna(inplace=True)
    console.print(f"Loaded {len(df)} questions for evaluation.")
    if df.empty:
        console.print(
            f"[bold yellow]Warning: No data found for run_id '{run_id}' in BigQuery.[/bold yellow]"
        )
    return df
|
||||
|
||||
|
||||
def run_evaluation(
    df: pd.DataFrame, settings: Settings, console: Console
) -> pd.DataFrame:
    """Runs the search evaluation on the given dataframe.

    For every question: embeds it, queries the Matching Engine index for the
    top-10 neighbours, and scores retrieval against the expected document
    with ranx (precision/recall/F1/nDCG/MRR at k in {1, 3, 5, 10}). A summary
    table is printed to the console.

    Args:
        df: Must contain ``id``, ``question`` and ``document_path`` columns.
        settings: Provides project/location, embedding model name and the
            index endpoint / deployment identifiers.
        console: Rich console for progress and the results table.

    Returns:
        One row per question with the retrieved documents, their distances,
        and whether the expected document appeared in the results.
    """
    console.print(
        f"Embedding Model: [bold cyan]{settings.agent.embedding_model}[/bold cyan]"
    )
    console.print(f"Index Name: [bold cyan]{settings.index.name}[/bold cyan]")

    vertexai.init(project=settings.project_id, location=settings.location)
    embedding_model = TextEmbeddingModel.from_pretrained(settings.agent.embedding_model)
    index_endpoint = MatchingEngineIndexEndpoint(settings.index.require_endpoint)

    # Prepare qrels: each question id maps to its single relevant document
    # (the basename of document_path) with relevance 1.
    qrels_data = {}
    for _, row in track(df.iterrows(), total=len(df), description="Preparing qrels..."):
        doc_path = str(row["document_path"]).split("/")[-1].strip()
        qrels_data[str(row["id"])] = {doc_path: 1}
    qrels = Qrels(qrels_data)

    # Prepare run: one embedding + vector-search call per question.
    run_data = {}
    detailed_results_list = []
    for _, row in track(df.iterrows(), total=len(df), description="Preparing run..."):
        embeddings = embedding_model.get_embeddings([row["question"]])
        question_embedding = embeddings[0].values

        results = index_endpoint.find_neighbors(
            deployed_index_id=settings.index.require_deployment,
            queries=[question_embedding],
            num_neighbors=10,
        )

        # Only one query was sent, so only the first result list is used.
        neighbors = results[0]
        # NOTE(review): ranx treats higher scores as better; these are raw
        # neighbor distances — confirm the endpoint returns similarity-like
        # scores (e.g. dot product), not smaller-is-better distances.
        run_data[str(row["id"])] = {
            neighbor.id: neighbor.distance for neighbor in neighbors
        }

        retrieved_docs = [neighbor.id for neighbor in neighbors]
        retrieved_distances = [neighbor.distance for neighbor in neighbors]
        expected_doc = str(row["document_path"]).split("/")[-1].strip()

        detailed_results_list.append(
            {
                "agent": settings.agent.name,
                "id": row["id"],
                "input": row["question"],
                "expected_document": expected_doc,
                "retrieved_documents": retrieved_docs,
                "retrieved_distances": retrieved_distances,
                "is_expected_in_results": expected_doc in retrieved_docs,
            }
        )
    run = Run(run_data)

    # Evaluate all five metrics at every cutoff in one ranx call.
    k_values = [1, 3, 5, 10]
    metrics = []
    for k in k_values:
        metrics.extend(
            [f"precision@{k}", f"recall@{k}", f"f1@{k}", f"ndcg@{k}", f"mrr@{k}"]
        )

    with console.status("[bold green]Running evaluation..."):
        results = ranx_evaluate(qrels, run, metrics)

    # Render one table row per k; missing metrics show as "N/A".
    table = Table(title=f"Search Metrics @k for Agent: {settings.agent.name}")
    table.add_column("k", justify="right", style="cyan")
    table.add_column("Precision@k", justify="right")
    table.add_column("Recall@k", justify="right")
    table.add_column("F1@k", justify="right")
    table.add_column("nDCG@k", justify="right")
    table.add_column("MRR@k", justify="right")

    for k in k_values:
        precision = results.get(f"precision@{k}")
        recall = results.get(f"recall@{k}")
        f1 = results.get(f"f1@{k}")
        ndcg = results.get(f"ndcg@{k}")
        mrr = results.get(f"mrr@{k}")
        table.add_row(
            str(k),
            f"{precision:.4f}" if precision is not None else "N/A",
            f"{recall:.4f}" if recall is not None else "N/A",
            f"{f1:.4f}" if f1 is not None else "N/A",
            f"{ndcg:.4f}" if ndcg is not None else "N/A",
            f"{mrr:.4f}" if mrr is not None else "N/A",
        )
    console.print(table)

    return pd.DataFrame(detailed_results_list)
|
||||
|
||||
|
||||
def evaluate(
    input_file: str | None = None,
    output_file: str | None = None,
    run_id: str | None = None,
):
    """Core logic for evaluating search metrics.

    Loads questions (local file or BigQuery), runs the retrieval evaluation,
    and saves the detailed per-question results to a CSV file when
    ``output_file`` is given, otherwise to the configured BigQuery table.

    Args:
        input_file: Optional local CSV/SQLite source; BigQuery is used when
            omitted.
        output_file: Optional CSV destination; BigQuery is used when omitted.
        run_id: Optional run identifier used to filter the input data.

    Raises:
        typer.Exit: On empty input data, missing index configuration, or a
            BigQuery write failure.
    """
    console = Console()
    settings = Settings()

    if input_file:
        df = load_data_from_local_file(input_file, console, run_id)
    else:
        df = load_data_from_bigquery(settings, console, run_id)

    if df.empty:
        console.print("[bold red]No data to evaluate.[/bold red]")
        raise typer.Exit(code=1)

    if not settings.index.name:
        console.print("[yellow]Skipping as no index is configured.[/yellow]")
        raise typer.Exit(code=1)

    console.print(
        f"[bold blue]Running evaluation for agent: {settings.agent.name}[/bold blue]"
    )
    results_df = run_evaluation(df, settings, console)

    if output_file:
        console.print(
            f"Saving detailed results to CSV file: [bold cyan]{output_file}[/bold cyan]"
        )
        # utf-8-sig (BOM) for consistency with the other CSV save paths in
        # this project, so Excel opens accented text correctly.
        results_df.to_csv(output_file, index=False, encoding="utf-8-sig")
        console.print(
            f"Successfully saved {len(results_df)} rows to [bold green]{output_file}[/bold green]"
        )
    else:
        project_id = settings.bigquery.project_id or settings.project_id
        dataset_id = settings.bigquery.require_dataset_id
        table_name = settings.bigquery.search_eval_table
        table_id = f"{project_id}.{dataset_id}.{table_name}"

        console.print(
            f"Saving detailed results to BigQuery table: [bold cyan]{table_id}[/bold cyan]"
        )
        try:
            results_df.to_gbq(
                destination_table=f"{dataset_id}.{table_name}",
                project_id=project_id,
                if_exists="append",
            )
            console.print(
                f"Successfully saved {len(results_df)} rows to [bold green]{table_id}[/bold green]"
            )
        except Exception as e:
            console.print(
                f"[bold red]An error occurred while saving to BigQuery: {e}[/bold red]"
            )
            raise typer.Exit(code=1)
|
||||
|
||||
298
src/va_evaluator/synthetic_question_generator.py
Normal file
298
src/va_evaluator/synthetic_question_generator.py
Normal file
@@ -0,0 +1,298 @@
|
||||
import datetime
|
||||
import os
|
||||
import random
|
||||
|
||||
import pandas as pd
|
||||
import typer
|
||||
import vertexai
|
||||
from google.cloud import storage
|
||||
from pydantic import BaseModel
|
||||
from rich.console import Console
|
||||
from rich.progress import track
|
||||
from vertexai.generative_models import GenerationConfig, GenerativeModel
|
||||
|
||||
from va_evaluator.config import Settings
|
||||
|
||||
|
||||
# --- Prompts & Schemas ---

# Single-turn question-generation prompt (Spanish). Placeholders: {qtype},
# {qtype_def}, {context}. Note: the format() call also passes id=file_path,
# which this template currently does not reference.
PROMPT_TEMPLATE = """
Eres un experto en generación de preguntas sintéticas. Tu tarea es crear preguntas sintéticas en español basadas en documentos de referencia proporcionados.

## INSTRUCCIONES:

### Requisitos obligatorios:
1. **Idioma**: La pregunta DEBE estar completamente en español
2. **Basada en documentos**: La pregunta DEBE poder responderse ÚNICAMENTE con la información contenida en los documentos proporcionados
3. **Tipo de pregunta**: Sigue estrictamente la definición del tipo de pregunta especificado
4. **Identificación de fuentes**: Incluye el ID de fuente de todos los documentos necesarios para responder la pregunta
5. **Salida esperada**: Incluye la respuesta perfecta basada en los documentos necesarios para responder la pregunta

### Tono de pregunta:
La pregunta debe ser similar a la que haría un usuario sin contexto sobre el sistema o la información disponible. Ingenuo y curioso.

### Tipo de pregunta solicitado:
**Tipo**: {qtype}
**Definición**: {qtype_def}

### Documentos de referencia:
{context}

Por favor, genera una pregunta siguiendo estas instrucciones.
""".strip()

# Multi-turn conversation-generation prompt (Spanish). Placeholders:
# {num_turns}, {context}.
MULTI_STEP_PROMPT_TEMPLATE = """
Eres un experto en la generación de conversaciones sintéticas. Tu tarea es crear una conversación en español con múltiples turnos basada en los documentos de referencia proporcionados.

## INSTRUCCIONES:

### Requisitos obligatorios:
1. **Idioma**: La conversación DEBE estar completamente en español.
2. **Basada en documentos**: Todas las respuestas DEBEN poder responderse ÚNICAMENTE con la información contenida en los documentos de referencia.
3. **Número de turnos**: La conversación debe tener exactamente {num_turns} turnos. Un turno consiste en una pregunta del usuario y una respuesta del asistente.
4. **Flujo conversacional**: Las preguntas deben seguir un orden lógico, como si un usuario estuviera explorando un tema paso a paso. La segunda pregunta debe ser una continuación de la primera, y así sucesivamente.
5. **Salida esperada**: Proporciona la respuesta perfecta para cada pregunta, basada en los documentos de referencia.

### Tono de las preguntas:
Las preguntas deben ser similares a las que haría un usuario sin contexto sobre el sistema o la información disponible. Deben ser ingenuas y curiosas.

### Documentos de referencia:
{context}

Por favor, genera una conversación de {num_turns} turnos siguiendo estas instrucciones.
""".strip()

# Question categories and their definitions, injected into PROMPT_TEMPLATE
# as {qtype}/{qtype_def}. One is chosen at random per generated question.
QUESTION_TYPE_MAP = {
    "Factual": "Questions targeting specific details within a reference (e.g., a company's profit in a report, a verdict in a legal case, or symptoms in a medical record) to test RAG's retrieval accuracy.",
    "Summarization": "Questions that require comprehensive answers, covering all relevant information, to mainly evaluate the recall rate of RAG retrieval.",
    "Multi-hop Reasoning": "Questions involve logical relationships among events and details within a document, forming a reasoning chain to assess RAG's logical reasoning ability.",
    "Unanswerable": "Questions arise from potential information loss during the schema-to-article generation, where no corresponding information fragment exists, or the information is insufficient for an answer.",
}
|
||||
|
||||
|
||||
# Structured-output schema for a single generated question. Used as the
# response_schema for the model's JSON mode and to validate its reply.
class ResponseSchema(BaseModel):
    # The generated question (Spanish).
    pregunta: str
    # The ideal answer, grounded only in the reference documents.
    expected_output: str
    # Source ids of the documents needed to answer the question.
    ids: list[str]


# One user/assistant exchange within a generated conversation.
class Turn(BaseModel):
    # The user's question for this turn (Spanish).
    pregunta: str
    # The ideal assistant answer for this turn.
    expected_output: str


# Structured-output schema for a multi-turn conversation: an ordered list
# of turns (the generator retries until len(conversation) == num_turns).
class MultiStepResponseSchema(BaseModel):
    conversation: list[Turn]
|
||||
|
||||
|
||||
# --- Core Logic ---
|
||||
|
||||
|
||||
def generate_structured(
    model: GenerativeModel,
    prompt: str,
    response_model: type[BaseModel],
) -> BaseModel:
    """Call *model* in JSON mode and parse the reply into *response_model*.

    The response schema is derived from the pydantic model, so the model is
    constrained to emit JSON that validates against it.
    """
    config = GenerationConfig(
        response_mime_type="application/json",
        response_schema=response_model,
    )
    raw = model.generate_content(prompt, generation_config=config)
    return response_model.model_validate_json(raw.text)
|
||||
|
||||
|
||||
def generate_synthetic_question(
    model: GenerativeModel,
    file_content: str,
    file_path: str,
    q_type: str,
    q_def: str,
) -> ResponseSchema:
    """Generate one synthetic question of *q_type* grounded in *file_content*."""
    filled_prompt = PROMPT_TEMPLATE.format(
        context=file_content, id=file_path, qtype=q_type, qtype_def=q_def
    )
    return generate_structured(model, filled_prompt, ResponseSchema)
|
||||
|
||||
|
||||
def generate_synthetic_conversation(
    model: GenerativeModel,
    file_content: str,
    file_path: str,
    num_turns: int,
) -> MultiStepResponseSchema:
    """Generate a synthetic *num_turns*-turn conversation from *file_content*."""
    filled_prompt = MULTI_STEP_PROMPT_TEMPLATE.format(
        context=file_content, num_turns=num_turns
    )
    return generate_structured(model, filled_prompt, MultiStepResponseSchema)
|
||||
|
||||
|
||||
def generate(
    num_questions: int,
    output_csv: str | None = None,
    num_turns: int = 1,
) -> str:
    """Generate synthetic evaluation questions from documents stored in GCS.

    Samples up to ``num_questions`` files from the index's GCS contents
    folder and, for each, generates either a single question (``num_turns
    == 1``) or a multi-turn conversation. Results are saved to
    ``output_csv`` when given, otherwise appended to the configured
    BigQuery table.

    Args:
        num_questions: Maximum number of source files to sample.
        output_csv: Optional CSV destination; BigQuery is used when omitted.
        num_turns: Turns per item; values > 1 produce conversations.

    Returns:
        The generated run id, or "" when nothing was generated or the
        index/files are unavailable.

    Raises:
        typer.Exit: When the BigQuery write fails.
    """
    console = Console()
    settings = Settings()

    vertexai.init(project=settings.project_id, location=settings.location)
    model = GenerativeModel(settings.agent.language_model)
    gcs_client = storage.Client(project=settings.project_id)
    bucket = gcs_client.bucket(settings.require_bucket)

    # Timestamp-based run id (UTC) tags every generated row for later filtering.
    run_id = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d-%H%M%S")
    console.print(f"[bold yellow]Generated Run ID: {run_id}[/bold yellow]")

    all_rows = []
    if not settings.index.name:
        console.print("[yellow]Skipping as no index is configured.[/yellow]")
        return ""

    gcs_path = f"{settings.index.name}/contents/"
    console.print(f"[green]Fetching files from GCS path: {gcs_path}[/green]")

    try:
        # Skip "directory" placeholder blobs (names ending in '/').
        all_files = [
            blob.name
            for blob in bucket.list_blobs(prefix=gcs_path)
            if not blob.name.endswith("/")
        ]
        console.print(f"Found {len(all_files)} total files to process.")
    except Exception as e:
        console.print(f"[bold red]Error listing files: {e}[/bold red]")
        return ""

    if not all_files:
        console.print("[yellow]No files found. Skipping.[/yellow]")
        return ""

    # Sample without replacement; capped at the number of available files.
    files_to_process = random.sample(
        all_files, k=min(num_questions, len(all_files))
    )
    console.print(
        f"Randomly selected {len(files_to_process)} files to generate questions from."
    )

    for file_path in track(files_to_process, description="Generating questions..."):
        try:
            blob = bucket.blob(file_path)
            file_content = blob.download_as_text(encoding="utf-8-sig")
            # Random question category; only used by the single-turn branch.
            q_type, q_def = random.choice(list(QUESTION_TYPE_MAP.items()))

            if num_turns > 1:
                # Multi-turn: retry up to 3 times until the model returns a
                # conversation with exactly num_turns turns.
                conversation_data = None
                for attempt in range(3):
                    conversation_data = generate_synthetic_conversation(
                        model, file_content, file_path, num_turns
                    )
                    if (
                        conversation_data
                        and conversation_data.conversation
                        and len(conversation_data.conversation) == num_turns
                    ):
                        break
                    console.print(
                        f"[yellow]Failed to generate valid conversation for {os.path.basename(file_path)}. Retrying ({attempt + 1}/3)...[/yellow]"
                    )
                    conversation_data = None

                if not conversation_data:
                    console.print(
                        f"[bold red]Failed to generate valid conversation for {os.path.basename(file_path)} after 3 attempts. Skipping.[/bold red]"
                    )
                    continue

                # Random 5-digit id grouping the turns of this conversation.
                # NOTE(review): collisions are possible across runs/files.
                conversation_id = str(random.randint(10000, 99999))
                for i, turn in enumerate(conversation_data.conversation):
                    row = {
                        "input": turn.pregunta,
                        "expected_output": turn.expected_output,
                        # Source document basename without extension.
                        "source": os.path.splitext(os.path.basename(file_path))[0],
                        "type": "Multi-turn",
                        "agent": settings.agent.name,
                        "run_id": run_id,
                        "conversation_id": conversation_id,
                        "turn": i + 1,
                    }
                    all_rows.append(row)

            else:
                # Single-turn: retry up to 3 times until the model returns a
                # non-empty expected answer.
                generated_data = None
                for attempt in range(3):
                    generated_data = generate_synthetic_question(
                        model, file_content, file_path, q_type, q_def
                    )
                    if (
                        generated_data
                        and generated_data.expected_output
                        and generated_data.expected_output.strip()
                    ):
                        break
                    console.print(
                        f"[yellow]Empty answer for {q_type} on {os.path.basename(file_path)}. Retrying ({attempt + 1}/3)...[/yellow]"
                    )
                    generated_data = None

                if not generated_data:
                    console.print(
                        f"[bold red]Failed to generate valid answer for {q_type} on {os.path.basename(file_path)} after 3 attempts. Skipping.[/bold red]"
                    )
                    continue

                row = {
                    "input": generated_data.pregunta,
                    "expected_output": generated_data.expected_output,
                    "source": os.path.splitext(os.path.basename(file_path))[0],
                    "type": q_type,
                    "agent": settings.agent.name,
                    "run_id": run_id,
                }
                all_rows.append(row)

        except Exception as e:
            # Best-effort per file: log and move on to the next one.
            console.print(f"[bold red]Error processing file {file_path}: {e}[/bold red]")

    if not all_rows:
        console.print("[bold yellow]No questions were generated.[/bold yellow]")
        return ""

    df = pd.DataFrame(all_rows)

    if output_csv:
        console.print(
            f"\n[bold green]Saving {len(df)} generated questions to {output_csv}...[/bold green]"
        )
        # utf-8-sig (BOM) so Excel opens accented Spanish text correctly.
        df.to_csv(output_csv, index=False, encoding="utf-8-sig")
        console.print("[bold green]Synthetic question generation complete.[/bold green]")
    else:
        console.print(
            f"\n[bold green]Saving {len(df)} generated questions to BigQuery...[/bold green]"
        )
        project_id = settings.bigquery.project_id or settings.project_id
        dataset_id = settings.bigquery.require_dataset_id
        table_name = settings.bigquery.synth_gen_table
        table_id = f"{project_id}.{dataset_id}.{table_name}"

        console.print(f"Saving to BigQuery table: [bold cyan]{table_id}[/bold cyan]")
        try:
            # Single-turn runs lack these columns; add them so the schema
            # matches the BigQuery table.
            if "conversation_id" not in df.columns:
                df["conversation_id"] = None
            if "turn" not in df.columns:
                df["turn"] = None

            df.to_gbq(
                destination_table=f"{dataset_id}.{table_name}",
                project_id=project_id,
                if_exists="append",
            )
            console.print(
                f"Successfully saved {len(df)} rows to [bold green]{table_id}[/bold green]"
            )
        except Exception as e:
            console.print(
                f"[bold red]An error occurred while saving to BigQuery: {e}[/bold red]"
            )
            raise typer.Exit(code=1)

    console.print(f"[bold yellow]Finished run with ID: {run_id}[/bold yellow]")
    return run_id
|
||||
|
||||
Reference in New Issue
Block a user