This commit is contained in:
Rogelio
2025-10-13 18:16:25 +00:00
parent 739f087cef
commit 325f1ef439
415 changed files with 46870 additions and 0 deletions

View File

@@ -0,0 +1,189 @@
import marimo
__generated_with = "0.13.15"
app = marimo.App(width="medium")
with app.setup:
import marimo as mo
from banortegpt.embedding.azure_ada import Ada
from banortegpt.vector.qdrant import Qdrant
ada = Ada.from_vault("banortegpt")
qdrant = Qdrant.from_vault("banortegpt")
collections = qdrant.list_collections()
@app.cell
def _():
import os
settings = (
mo.md(
"""
Content Field: {campo_texto}\n
Embedding Model: {embedding_model}\n
Collection: {collection}\n
Score Threshold: {threshold}\n
Synthetic Questions: {synthetic_questions}
"""
)
.batch(
campo_texto=mo.ui.text(value="page_content"),
embedding_model=mo.ui.text(value="text-embedding-3-large"),
collection=mo.ui.dropdown(collections, searchable=True),
threshold=mo.ui.number(value=0.5, step=0.1),
synthetic_questions=mo.ui.file(filetypes=[".json"]),
)
.form(bordered=True)
)
settings
return (settings,)
@app.cell
def _(settings):
import json
mo.stop(not settings.value)
stg = settings.value
EMBEDDING_MODEL = stg["embedding_model"]
COLLECTION = stg["collection"]
THRESHOLD = stg["threshold"]
QUESTIONS = json.loads(stg["synthetic_questions"][0].contents)
ada.model = EMBEDDING_MODEL
return COLLECTION, QUESTIONS, THRESHOLD
@app.cell
def _(COLLECTION, THRESHOLD):
import ranx
def create_qrels(questions):
qrels_dict = {}
for q in questions:
question = q["pregunta"]
source_ids = q["ids"]
qrels_dict[question] = {}
for id in source_ids:
qrels_dict[question][id] = 1
return ranx.Qrels(qrels_dict)
def create_run(questions):
run_dict = {}
for q in questions:
question = q["pregunta"]
embedding = ada.embed(question)
query_response = qdrant.client.query_points(
collection_name=COLLECTION,
query=embedding,
limit=100,
score_threshold=THRESHOLD,
)
run_dict[question] = {}
for point in query_response.points:
run_dict[question][point.id] = point.score
return ranx.Run(run_dict)
return create_qrels, create_run, ranx
@app.cell
def _(create_qrels, create_run, ranx):
def create_evals(questions, ks):
qrels = create_qrels(questions)
run = create_run(questions)
return [
ranx.evaluate(qrels, run, [f"precision@{k}", f"recall@{k}", f"ndcg@{k}"])
for k in ks
]
return (create_evals,)
@app.cell
def _():
import matplotlib.pyplot as plt
def plot_retrieval_metrics(results):
# Extract k values and metrics
k_values = [int(list(result.keys())[0].split("@")[1]) for result in results]
# Prepare data for plotting
precision_values = [
list(result.values())[0]
for result in results
if "precision" in list(result.keys())[0]
]
recall_values = [
list(result.values())[1]
for result in results
if "recall" in list(result.keys())[1]
]
ndcg_values = [
list(result.values())[2]
for result in results
if "ndcg" in list(result.keys())[2]
]
# Create a figure with three subplots
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
# Precision Plot
ax1.plot(k_values, precision_values, marker="o", linestyle="-", color="blue")
ax1.set_title("Precision @ K")
ax1.set_xlabel("Number of Retrieved Documents (K)")
ax1.set_ylabel("Precision")
ax1.set_xticks(k_values)
# Recall Plot
ax2.plot(k_values, recall_values, marker="o", linestyle="-", color="green")
ax2.set_title("Recall @ K")
ax2.set_xlabel("Number of Retrieved Documents (K)")
ax2.set_ylabel("Recall")
ax2.set_xticks(k_values)
# NDCG Plot
ax3.plot(k_values, ndcg_values, marker="o", linestyle="-", color="red")
ax3.set_title("NDCG @ K")
ax3.set_xlabel("Number of Retrieved Documents (K)")
ax3.set_ylabel("NDCG")
ax3.set_xticks(k_values)
# Add value labels
for ax, values in zip(
[ax1, ax2, ax3], [precision_values, recall_values, ndcg_values]
):
for i, v in enumerate(values):
ax.text(k_values[i], v, f"{v:.2f}", ha="center", va="bottom")
plt.tight_layout()
return plt.gca()
return (plot_retrieval_metrics,)
@app.cell
def _(QUESTIONS, create_evals, plot_retrieval_metrics):
results = create_evals(QUESTIONS, [1, 3, 5, 10, 20])
plot_retrieval_metrics(results)
return
if __name__ == "__main__":
app.run()