# Forked from innovacion/Mayacontigo (190 lines, 4.9 KiB, Python)
import marimo
|
|
|
|
# Version of marimo that generated this notebook file.
__generated_with = "0.13.15"
# The notebook application object; cells below register themselves on it via `@app.cell`.
app = marimo.App(width="medium")
|
|
|
|
# Setup block: imports and shared objects that the cells below reference
# directly by name (`mo`, `ada`, `qdrant`, `collections`).
with app.setup:
    import marimo as mo

    from banortegpt.embedding.azure_ada import Ada
    from banortegpt.vector.qdrant import Qdrant

    # Embedding client and vector-store client.
    # NOTE(review): `from_vault("banortegpt")` presumably loads credentials
    # from a secrets vault entry of that name - confirm against banortegpt.
    ada = Ada.from_vault("banortegpt")
    qdrant = Qdrant.from_vault("banortegpt")

    # Options offered by the "Collection" dropdown in the settings form.
    collections = qdrant.list_collections()
|
|
|
|
|
|
@app.cell
def _():
    # Settings form for the retrieval evaluation. Each `{placeholder}` in the
    # markdown template is filled by the UI element of the same name passed
    # to `.batch(...)`; `.form(...)` wraps them so the values are only
    # published when the user submits.
    settings = (
        mo.md(
            """
            Content Field: {campo_texto}\n
            Embedding Model: {embedding_model}\n
            Collection: {collection}\n
            Score Threshold: {threshold}\n
            Synthetic Questions: {synthetic_questions}
            """
        )
        .batch(
            campo_texto=mo.ui.text(value="page_content"),
            embedding_model=mo.ui.text(value="text-embedding-3-large"),
            collection=mo.ui.dropdown(collections, searchable=True),
            threshold=mo.ui.number(value=0.5, step=0.1),
            synthetic_questions=mo.ui.file(filetypes=[".json"]),
        )
        .form(bordered=True)
    )

    # Last expression of the cell: rendered as the cell's output.
    settings
    return (settings,)
|
|
|
|
|
|
@app.cell
def _(settings):
    import json

    # Halt this cell (and its dependents) while the form has no value yet.
    mo.stop(not settings.value)

    stg = settings.value

    # A submitted form can still lack a collection or an uploaded file.
    # Stop with a readable message instead of failing later with an
    # IndexError on the empty upload list or a query against a `None`
    # collection.
    mo.stop(
        not stg["collection"],
        mo.md("**Select a collection before submitting the form.**"),
    )
    mo.stop(
        not stg["synthetic_questions"],
        mo.md("**Upload a synthetic questions JSON file before submitting the form.**"),
    )

    EMBEDDING_MODEL = stg["embedding_model"]
    COLLECTION = stg["collection"]
    THRESHOLD = stg["threshold"]
    # Only the first uploaded file is read; its contents are parsed as JSON.
    QUESTIONS = json.loads(stg["synthetic_questions"][0].contents)

    # Point the shared embedding client at the model chosen in the form.
    ada.model = EMBEDDING_MODEL
    return COLLECTION, QUESTIONS, THRESHOLD
|
|
|
|
|
|
@app.cell
def _(COLLECTION, THRESHOLD):
    import ranx

    def create_qrels(questions):
        """Build the ground-truth relevance judgements.

        Each item of ``questions`` must provide ``"pregunta"`` (the question
        text, used as the query id) and ``"ids"`` (the ids of its source
        documents). Every listed id is given a relevance of 1.

        Returns a ``ranx.Qrels`` built from that mapping.
        """
        relevance = {
            entry["pregunta"]: {doc_id: 1 for doc_id in entry["ids"]}
            for entry in questions
        }
        return ranx.Qrels(relevance)

    def create_run(questions):
        """Query the vector store with every question and collect the scores.

        Each question text is embedded with ``ada`` and sent to the
        ``COLLECTION`` collection, asking for at most 100 points whose score
        passes ``THRESHOLD``.

        Returns a ``ranx.Run`` mapping question text -> {point id: score}.
        """
        scores_by_question = {}
        for entry in questions:
            text = entry["pregunta"]
            response = qdrant.client.query_points(
                collection_name=COLLECTION,
                query=ada.embed(text),
                limit=100,
                score_threshold=THRESHOLD,
            )
            # NOTE(review): ranx documents string document ids; confirm that
            # the point ids stored in the collection match the "ids" strings.
            scores_by_question[text] = {
                point.id: point.score for point in response.points
            }
        return ranx.Run(scores_by_question)

    return create_qrels, create_run, ranx
|
|
|
|
|
|
@app.cell
def _(create_qrels, create_run, ranx):
    def create_evals(questions, ks):
        """Evaluate retrieval quality at several cutoffs.

        The qrels and the run are built once from ``questions``; then, for
        every ``k`` in ``ks``, precision@k, recall@k and ndcg@k are computed
        with ``ranx.evaluate``.

        Returns a list with one evaluation result per ``k``, in the order of
        ``ks``.
        """
        qrels = create_qrels(questions)
        run = create_run(questions)

        evaluations = []
        for k in ks:
            metric_names = [f"precision@{k}", f"recall@{k}", f"ndcg@{k}"]
            evaluations.append(ranx.evaluate(qrels, run, metric_names))
        return evaluations

    return (create_evals,)
|
|
|
|
|
|
@app.cell
def _():
    import matplotlib.pyplot as plt

    def plot_retrieval_metrics(results):
        """Plot precision, recall and NDCG against the cutoff K.

        ``results`` is a list of mappings, one per cutoff, whose entries are
        read by position: first precision, second recall, third NDCG, with
        keys shaped like ``"<metric>@<k>"``.

        Returns the current matplotlib axes after drawing the three panels.
        """

        def metric_column(position, metric):
            # Take the value at `position` from every result whose key at
            # that same position mentions `metric`.
            column = []
            for result in results:
                if metric in list(result.keys())[position]:
                    column.append(list(result.values())[position])
            return column

        # K is parsed from the first key of each result ("precision@5" -> 5).
        k_values = [int(list(result.keys())[0].split("@")[1]) for result in results]

        precision_values = metric_column(0, "precision")
        recall_values = metric_column(1, "recall")
        ndcg_values = metric_column(2, "ndcg")

        # One row of three panels, one per metric.
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = [
            ("Precision @ K", "Precision", precision_values, "blue"),
            ("Recall @ K", "Recall", recall_values, "green"),
            ("NDCG @ K", "NDCG", ndcg_values, "red"),
        ]

        for ax, (title, y_label, values, color) in zip(axes, panels):
            ax.plot(k_values, values, marker="o", linestyle="-", color=color)
            ax.set_title(title)
            ax.set_xlabel("Number of Retrieved Documents (K)")
            ax.set_ylabel(y_label)
            ax.set_xticks(k_values)

            # Label each data point with its value.
            for k, value in zip(k_values, values):
                ax.text(k, value, f"{value:.2f}", ha="center", va="bottom")

        plt.tight_layout()
        return plt.gca()

    return (plot_retrieval_metrics,)
|
|
|
|
|
|
@app.cell
def _(QUESTIONS, create_evals, plot_retrieval_metrics):
    # Cutoffs (K) at which precision, recall and NDCG are evaluated.
    _cutoffs = [1, 3, 5, 10, 20]
    results = create_evals(QUESTIONS, _cutoffs)

    # Last expression of the cell: the plot is rendered as the cell's output.
    plot_retrieval_metrics(results)
    return
|
|
|
|
|
|
# Run the notebook app when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()
|