# apps/keypoint-eval/src/keypoint_eval/loaders.py
import pathlib
import sqlite3

import pandas as pd
from google.cloud import bigquery
from rich.console import Console

from rag_eval.config import settings as config

def load_data_from_local_file(
|
||||
file_path: str, console: Console, run_id: str = None
|
||||
) -> pd.DataFrame:
|
||||
"""Loads evaluation data from a local CSV or SQLite file and returns a DataFrame."""
|
||||
console.print(f"Loading data from {file_path}...")
|
||||
path = pathlib.Path(file_path)
|
||||
if not path.exists():
|
||||
raise Exception(f"Error: File not found at {file_path}")
|
||||
|
||||
if path.suffix == ".csv":
|
||||
try:
|
||||
df = pd.read_csv(path)
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while reading the CSV file: {e}")
|
||||
|
||||
elif path.suffix in [".db", ".sqlite"]:
|
||||
try:
|
||||
con = sqlite3.connect(path)
|
||||
# Assuming table name is the file stem
|
||||
table_name = path.stem
|
||||
df = pd.read_sql(f"SELECT * FROM {table_name}", con)
|
||||
con.close()
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while reading the SQLite DB: {e}")
|
||||
else:
|
||||
raise Exception(
|
||||
f"Unsupported file type: {path.suffix}. Please use .csv or .db/.sqlite"
|
||||
)
|
||||
|
||||
# Check for required columns
|
||||
if (
|
||||
"input" not in df.columns
|
||||
or "expected_output" not in df.columns
|
||||
):
|
||||
raise Exception(
|
||||
"Error: The input file must contain 'input' and 'expected_output' columns."
|
||||
)
|
||||
df["agent"] = config.agent.name
|
||||
|
||||
print(f"{run_id=}")
|
||||
if run_id:
|
||||
if "run_id" in df.columns:
|
||||
df = df[df["run_id"] == run_id].copy()
|
||||
console.print(f"Filtered data for run_id: {run_id}")
|
||||
if df.empty:
|
||||
console.print(
|
||||
f"[yellow]Warning: No data found for run_id '{run_id}' in {file_path}.[/yellow]"
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
f"[yellow]Warning: --run-id provided, but 'run_id' column not found in {file_path}. Using all data.[/yellow]"
|
||||
)
|
||||
|
||||
# Filter out unanswerable questions if 'type' column exists
|
||||
if "type" in df.columns:
|
||||
df = df[df["type"] != "Unanswerable"].copy()
|
||||
|
||||
df.dropna(subset=["input", "expected_output"], inplace=True)
|
||||
|
||||
console.print(f"Loaded {len(df)} questions for evaluation from {file_path}.")
|
||||
return df
|
||||
|
||||
|
||||
def load_data_from_bigquery(console: Console, run_id: str | None = None) -> pd.DataFrame:
    """Load evaluation data from the configured BigQuery table.

    The table reference is built from ``config.bigquery`` settings
    (``synth_gen`` table). Rows with ``type = 'Unanswerable'`` are excluded
    when that column exists, and results may be filtered to a single run.

    Args:
        console: Rich console used for progress and warning output.
        run_id: Optional run identifier; when provided and a ``run_id``
            column exists in the table, rows are filtered to that run.

    Returns:
        DataFrame with ``input``, ``expected_output`` (and ``category`` when
        present in the table) plus an ``agent`` column; rows with missing
        required values are dropped.

    Raises:
        Exception: Re-raised from the BigQuery client after logging, e.g.
            when the table does not exist or the query fails.
    """
    console.print("Loading data from BigQuery...")
    bq_project_id = config.bigquery.project_id or config.project_id
    client = bigquery.Client(project=bq_project_id)
    table_ref = f"{bq_project_id}.{config.bigquery.dataset_id}.{config.bigquery.table_ids['synth_gen']}"

    console.print(f"Querying table: {table_ref}")
    try:
        table = client.get_table(table_ref)
        all_columns = [schema.name for schema in table.schema]

        select_cols = ["input", "expected_output"]
        if "category" in all_columns:
            select_cols.append("category")

        query_parts = [f"SELECT {', '.join(select_cols)}", f"FROM `{table_ref}`"]

        # Build WHERE clauses
        where_clauses = []
        query_params = []
        if "type" in all_columns:
            where_clauses.append("type != 'Unanswerable'")
        if run_id:
            if "run_id" in all_columns:
                # Bind run_id as a named query parameter instead of
                # interpolating it into the SQL string — prevents injection
                # and quoting bugs for arbitrary run_id values.
                where_clauses.append("run_id = @run_id")
                query_params.append(
                    bigquery.ScalarQueryParameter("run_id", "STRING", run_id)
                )
                console.print(f"Filtering data for run_id: {run_id}")
            else:
                console.print(
                    "[yellow]Warning: --run-id provided, but 'run_id' column not found in BigQuery table. Using all data.[/yellow]"
                )

        if where_clauses:
            query_parts.append("WHERE " + " AND ".join(where_clauses))

        query = "\n".join(query_parts)
        job_config = bigquery.QueryJobConfig(query_parameters=query_params)
        df = client.query(query, job_config=job_config).to_dataframe()

    except Exception as e:
        if "Not found" in str(e):
            console.print(f"[bold red]Error: Table {table_ref} not found.[/bold red]")
            console.print(
                "Please ensure the table exists and the configuration in 'config.yaml' is correct."
            )
            raise
        else:
            console.print(
                f"[bold red]An error occurred while querying BigQuery: {e}[/bold red]"
            )
            raise

    df.dropna(subset=["input", "expected_output"], inplace=True)
    df["agent"] = config.agent.name

    console.print(f"Loaded {len(df)} questions for evaluation.")
    if run_id and df.empty:
        console.print(
            f"[yellow]Warning: No data found for run_id '{run_id}' in BigQuery.[/yellow]"
        )
    return df