133 lines
4.6 KiB
Python
133 lines
4.6 KiB
Python
import pathlib
|
|
import sqlite3
|
|
|
|
import pandas as pd
|
|
from google.cloud import bigquery
|
|
from rich.console import Console
|
|
|
|
from rag_eval.config import settings as config
|
|
|
|
|
|
def load_data_from_local_file(
    file_path: str, console: Console, run_id: str = None
) -> pd.DataFrame:
    """Load evaluation data from a local CSV or SQLite file.

    Args:
        file_path: Path to a ``.csv``, ``.db``, or ``.sqlite`` file. For
            SQLite files the table name is assumed to be the file stem.
        console: Rich console used for status/warning output.
        run_id: Optional run identifier; when given and a ``run_id`` column
            exists, rows are filtered down to that run.

    Returns:
        DataFrame containing at least ``input``, ``expected_output``, and
        ``agent`` columns. Rows with ``type == 'Unanswerable'`` and rows
        missing ``input``/``expected_output`` are dropped.

    Raises:
        Exception: If the file is missing, unreadable, of an unsupported
            type, or lacks the required columns.
    """
    console.print(f"Loading data from {file_path}...")
    path = pathlib.Path(file_path)
    if not path.exists():
        raise Exception(f"Error: File not found at {file_path}")

    if path.suffix == ".csv":
        try:
            df = pd.read_csv(path)
        except Exception as e:
            # Chain the original error so the root cause stays visible.
            raise Exception(f"An error occurred while reading the CSV file: {e}") from e
    elif path.suffix in [".db", ".sqlite"]:
        try:
            con = sqlite3.connect(path)
            try:
                # Assuming table name is the file stem; quote the identifier
                # (doubling any embedded quotes) so stems containing spaces
                # or reserved words still work.
                table_name = path.stem.replace('"', '""')
                df = pd.read_sql(f'SELECT * FROM "{table_name}"', con)
            finally:
                # Close even when the query fails — avoids leaking the handle.
                con.close()
        except Exception as e:
            raise Exception(f"An error occurred while reading the SQLite DB: {e}") from e
    else:
        raise Exception(
            f"Unsupported file type: {path.suffix}. Please use .csv or .db/.sqlite"
        )

    # Check for required columns
    if (
        "input" not in df.columns
        or "expected_output" not in df.columns
    ):
        raise Exception(
            "Error: The input file must contain 'input' and 'expected_output' columns."
        )

    df["agent"] = config.agent.name

    if run_id:
        if "run_id" in df.columns:
            df = df[df["run_id"] == run_id].copy()
            console.print(f"Filtered data for run_id: {run_id}")
            if df.empty:
                console.print(
                    f"[yellow]Warning: No data found for run_id '{run_id}' in {file_path}.[/yellow]"
                )
        else:
            console.print(
                f"[yellow]Warning: --run-id provided, but 'run_id' column not found in {file_path}. Using all data.[/yellow]"
            )

    # Filter out unanswerable questions if 'type' column exists
    if "type" in df.columns:
        df = df[df["type"] != "Unanswerable"].copy()

    df.dropna(subset=["input", "expected_output"], inplace=True)

    console.print(f"Loaded {len(df)} questions for evaluation from {file_path}.")
    return df
|
|
|
|
|
|
def load_data_from_bigquery(console: Console, run_id: str = None) -> pd.DataFrame:
    """Load evaluation data from the configured BigQuery table.

    Args:
        console: Rich console used for status/warning output.
        run_id: Optional run identifier; when given and a ``run_id`` column
            exists in the table, rows are filtered to that run.

    Returns:
        DataFrame with ``input``, ``expected_output`` (plus ``category``
        when the table has that column) and an added ``agent`` column.
        Rows with ``type == 'Unanswerable'`` and rows missing
        ``input``/``expected_output`` are excluded.

    Raises:
        Exception: Re-raises any BigQuery error after printing a
            human-readable message.
    """
    console.print("Loading data from BigQuery...")
    bq_project_id = config.bigquery.project_id or config.project_id
    client = bigquery.Client(project=bq_project_id)
    table_ref = f"{bq_project_id}.{config.bigquery.dataset_id}.{config.bigquery.table_ids['synth_gen']}"

    console.print(f"Querying table: {table_ref}")
    try:
        # Inspect the schema first so optional columns/filters are only
        # referenced when they actually exist.
        table = client.get_table(table_ref)
        all_columns = [schema.name for schema in table.schema]

        select_cols = ["input", "expected_output"]
        if "category" in all_columns:
            select_cols.append("category")

        query_parts = [f"SELECT {', '.join(select_cols)}", f"FROM `{table_ref}`"]

        # Build WHERE clauses
        where_clauses = []
        query_params = []
        if "type" in all_columns:
            where_clauses.append("type != 'Unanswerable'")
        if run_id:
            if "run_id" in all_columns:
                # Bind run_id as a query parameter instead of interpolating
                # it into the SQL string — avoids injection and quoting bugs.
                where_clauses.append("run_id = @run_id")
                query_params.append(
                    bigquery.ScalarQueryParameter("run_id", "STRING", run_id)
                )
                console.print(f"Filtering data for run_id: {run_id}")
            else:
                console.print(
                    "[yellow]Warning: --run-id provided, but 'run_id' column not found in BigQuery table. Using all data.[/yellow]"
                )

        if where_clauses:
            query_parts.append("WHERE " + " AND ".join(where_clauses))

        query = "\n".join(query_parts)
        job_config = bigquery.QueryJobConfig(query_parameters=query_params)
        df = client.query(query, job_config=job_config).to_dataframe()

    except Exception as e:
        if "Not found" in str(e):
            console.print(f"[bold red]Error: Table {table_ref} not found.[/bold red]")
            console.print(
                "Please ensure the table exists and the configuration in 'config.yaml' is correct."
            )
            raise
        else:
            console.print(
                f"[bold red]An error occurred while querying BigQuery: {e}[/bold red]"
            )
            raise

    df.dropna(subset=["input", "expected_output"], inplace=True)
    df["agent"] = config.agent.name

    console.print(f"Loaded {len(df)} questions for evaluation.")
    if run_id and df.empty:
        console.print(
            f"[yellow]Warning: No data found for run_id '{run_id}' in BigQuery.[/yellow]"
        )
    return df
|