forked from innovacion/Mayacontigo
ic
This commit is contained in:
189
notebooks/search-evaluator/main.py
Normal file
189
notebooks/search-evaluator/main.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import marimo
|
||||
|
||||
__generated_with = "0.13.15"
|
||||
app = marimo.App(width="medium")
|
||||
|
||||
with app.setup:
|
||||
import marimo as mo
|
||||
|
||||
from banortegpt.embedding.azure_ada import Ada
|
||||
from banortegpt.vector.qdrant import Qdrant
|
||||
|
||||
ada = Ada.from_vault("banortegpt")
|
||||
qdrant = Qdrant.from_vault("banortegpt")
|
||||
|
||||
collections = qdrant.list_collections()
|
||||
|
||||
|
||||
@app.cell
|
||||
def _():
|
||||
import os
|
||||
|
||||
settings = (
|
||||
mo.md(
|
||||
"""
|
||||
Content Field: {campo_texto}\n
|
||||
Embedding Model: {embedding_model}\n
|
||||
Collection: {collection}\n
|
||||
Score Threshold: {threshold}\n
|
||||
Synthetic Questions: {synthetic_questions}
|
||||
"""
|
||||
)
|
||||
.batch(
|
||||
campo_texto=mo.ui.text(value="page_content"),
|
||||
embedding_model=mo.ui.text(value="text-embedding-3-large"),
|
||||
collection=mo.ui.dropdown(collections, searchable=True),
|
||||
threshold=mo.ui.number(value=0.5, step=0.1),
|
||||
synthetic_questions=mo.ui.file(filetypes=[".json"]),
|
||||
)
|
||||
.form(bordered=True)
|
||||
)
|
||||
|
||||
settings
|
||||
return (settings,)
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(settings):
|
||||
import json
|
||||
|
||||
mo.stop(not settings.value)
|
||||
|
||||
stg = settings.value
|
||||
|
||||
EMBEDDING_MODEL = stg["embedding_model"]
|
||||
COLLECTION = stg["collection"]
|
||||
THRESHOLD = stg["threshold"]
|
||||
QUESTIONS = json.loads(stg["synthetic_questions"][0].contents)
|
||||
|
||||
ada.model = EMBEDDING_MODEL
|
||||
return COLLECTION, QUESTIONS, THRESHOLD
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(COLLECTION, THRESHOLD):
|
||||
import ranx
|
||||
|
||||
def create_qrels(questions):
|
||||
qrels_dict = {}
|
||||
|
||||
for q in questions:
|
||||
question = q["pregunta"]
|
||||
source_ids = q["ids"]
|
||||
|
||||
qrels_dict[question] = {}
|
||||
for id in source_ids:
|
||||
qrels_dict[question][id] = 1
|
||||
|
||||
return ranx.Qrels(qrels_dict)
|
||||
|
||||
def create_run(questions):
|
||||
run_dict = {}
|
||||
|
||||
for q in questions:
|
||||
question = q["pregunta"]
|
||||
|
||||
embedding = ada.embed(question)
|
||||
|
||||
query_response = qdrant.client.query_points(
|
||||
collection_name=COLLECTION,
|
||||
query=embedding,
|
||||
limit=100,
|
||||
score_threshold=THRESHOLD,
|
||||
)
|
||||
|
||||
run_dict[question] = {}
|
||||
for point in query_response.points:
|
||||
run_dict[question][point.id] = point.score
|
||||
|
||||
return ranx.Run(run_dict)
|
||||
|
||||
return create_qrels, create_run, ranx
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(create_qrels, create_run, ranx):
|
||||
def create_evals(questions, ks):
|
||||
qrels = create_qrels(questions)
|
||||
run = create_run(questions)
|
||||
|
||||
return [
|
||||
ranx.evaluate(qrels, run, [f"precision@{k}", f"recall@{k}", f"ndcg@{k}"])
|
||||
for k in ks
|
||||
]
|
||||
|
||||
return (create_evals,)
|
||||
|
||||
|
||||
@app.cell
|
||||
def _():
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def plot_retrieval_metrics(results):
|
||||
# Extract k values and metrics
|
||||
k_values = [int(list(result.keys())[0].split("@")[1]) for result in results]
|
||||
|
||||
# Prepare data for plotting
|
||||
precision_values = [
|
||||
list(result.values())[0]
|
||||
for result in results
|
||||
if "precision" in list(result.keys())[0]
|
||||
]
|
||||
recall_values = [
|
||||
list(result.values())[1]
|
||||
for result in results
|
||||
if "recall" in list(result.keys())[1]
|
||||
]
|
||||
ndcg_values = [
|
||||
list(result.values())[2]
|
||||
for result in results
|
||||
if "ndcg" in list(result.keys())[2]
|
||||
]
|
||||
|
||||
# Create a figure with three subplots
|
||||
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
|
||||
|
||||
# Precision Plot
|
||||
ax1.plot(k_values, precision_values, marker="o", linestyle="-", color="blue")
|
||||
ax1.set_title("Precision @ K")
|
||||
ax1.set_xlabel("Number of Retrieved Documents (K)")
|
||||
ax1.set_ylabel("Precision")
|
||||
ax1.set_xticks(k_values)
|
||||
|
||||
# Recall Plot
|
||||
ax2.plot(k_values, recall_values, marker="o", linestyle="-", color="green")
|
||||
ax2.set_title("Recall @ K")
|
||||
ax2.set_xlabel("Number of Retrieved Documents (K)")
|
||||
ax2.set_ylabel("Recall")
|
||||
ax2.set_xticks(k_values)
|
||||
|
||||
# NDCG Plot
|
||||
ax3.plot(k_values, ndcg_values, marker="o", linestyle="-", color="red")
|
||||
ax3.set_title("NDCG @ K")
|
||||
ax3.set_xlabel("Number of Retrieved Documents (K)")
|
||||
ax3.set_ylabel("NDCG")
|
||||
ax3.set_xticks(k_values)
|
||||
|
||||
# Add value labels
|
||||
for ax, values in zip(
|
||||
[ax1, ax2, ax3], [precision_values, recall_values, ndcg_values]
|
||||
):
|
||||
for i, v in enumerate(values):
|
||||
ax.text(k_values[i], v, f"{v:.2f}", ha="center", va="bottom")
|
||||
|
||||
plt.tight_layout()
|
||||
return plt.gca()
|
||||
|
||||
return (plot_retrieval_metrics,)
|
||||
|
||||
|
||||
@app.cell
|
||||
def _(QUESTIONS, create_evals, plot_retrieval_metrics):
|
||||
results = create_evals(QUESTIONS, [1, 3, 5, 10, 20])
|
||||
|
||||
plot_retrieval_metrics(results)
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run()
|
||||
Reference in New Issue
Block a user