import marimo

__generated_with = "0.13.15"
app = marimo.App(width="medium")

with app.setup:
    import marimo as mo
    from banortegpt.embedding.azure_ada import Ada
    from banortegpt.vector.qdrant import Qdrant

    # Shared clients; credentials are loaded from the "banortegpt" vault entry.
    ada = Ada.from_vault("banortegpt")
    qdrant = Qdrant.from_vault("banortegpt")
    collections = qdrant.list_collections()


@app.cell
def _():
    # Settings form: nothing downstream runs until it is submitted.
    settings = (
        mo.md(
            """
            Content Field: {campo_texto}\n
            Embedding Model: {embedding_model}\n
            Collection: {collection}\n
            Score Threshold: {threshold}\n
            Synthetic Questions: {synthetic_questions}
            """
        )
        .batch(
            campo_texto=mo.ui.text(value="page_content"),
            embedding_model=mo.ui.text(value="text-embedding-3-large"),
            collection=mo.ui.dropdown(collections, searchable=True),
            threshold=mo.ui.number(value=0.5, step=0.1),
            synthetic_questions=mo.ui.file(filetypes=[".json"]),
        )
        .form(bordered=True)
    )
    settings
    return (settings,)


@app.cell
def _(settings):
    import json

    # The form value is falsy until the user submits it.
    mo.stop(not settings.value)
    stg = settings.value

    # Guard against a submitted-but-incomplete form: without these checks an
    # empty upload raises IndexError and a missing collection reaches Qdrant
    # as None.
    mo.stop(
        not stg["collection"],
        mo.md("Select a collection and submit the form again."),
    )
    mo.stop(
        not stg["synthetic_questions"],
        mo.md("Upload a synthetic-questions JSON file and submit the form again."),
    )

    EMBEDDING_MODEL = stg["embedding_model"]
    COLLECTION = stg["collection"]
    THRESHOLD = stg["threshold"]
    QUESTIONS = json.loads(stg["synthetic_questions"][0].contents)

    # Point the shared embedding client at the model chosen in the form.
    ada.model = EMBEDDING_MODEL
    return COLLECTION, QUESTIONS, THRESHOLD


@app.cell
def _(COLLECTION, THRESHOLD):
    import ranx

    def create_qrels(questions):
        """Build ranx relevance judgements from the synthetic questions.

        Each item of ``questions`` is read as a mapping with a ``"pregunta"``
        key (the question text, used as the query id) and an ``"ids"`` key
        (the ids of the relevant points). Every listed id gets relevance 1.

        Ids are converted to ``str`` so they compare equal to the ids built by
        ``create_run``, whether the JSON stores them as numbers or strings.
        """
        qrels_dict = {}
        for q in questions:
            # setdefault: a repeated question text accumulates its relevant ids
            # instead of discarding the earlier entry.
            judgements = qrels_dict.setdefault(q["pregunta"], {})
            for doc_id in q["ids"]:
                judgements[str(doc_id)] = 1
        return ranx.Qrels(qrels_dict)

    def create_run(questions):
        """Query Qdrant for every question and collect a ranx run.

        Each question is embedded with ``ada``. The collection selected in the
        form is then searched for up to 100 points scoring at least the
        configured threshold.

        Returns a ``ranx.Run`` mapping question text -> {str(point id): score}.
        """
        run_dict = {}
        for q in questions:
            question = q["pregunta"]
            if question in run_dict:
                # Same query id as an earlier item; re-querying would only
                # overwrite the same key.
                continue
            embedding = ada.embed(question)
            response = qdrant.client.query_points(
                collection_name=COLLECTION,
                query=embedding,
                limit=100,
                score_threshold=THRESHOLD,
            )
            # Qdrant point ids may be ints or UUID strings; normalise to str
            # to match the qrels.
            run_dict[question] = {
                str(point.id): point.score for point in response.points
            }
        return ranx.Run(run_dict)

    return create_qrels, create_run, ranx


@app.cell
def _(create_qrels, create_run, ranx):
    def create_evals(questions, ks):
        """Evaluate precision, recall and nDCG at each cut-off in ``ks``.

        Retrieval runs once; each element of the returned list is the ranx
        result dict for one k, keyed "precision@k", "recall@k" and "ndcg@k".
        """
        qrels = create_qrels(questions)
        run = create_run(questions)
        return [
            ranx.evaluate(qrels, run, [f"precision@{k}", f"recall@{k}", f"ndcg@{k}"])
            for k in ks
        ]

    return (create_evals,)


@app.cell
def _():
    import matplotlib.pyplot as plt

    def plot_retrieval_metrics(results):
        """Plot precision, recall and nDCG against k in three side-by-side panels.

        ``results`` is the list returned by ``create_evals``: one dict per k,
        keyed "precision@k", "recall@k" and "ndcg@k".

        Returns the whole matplotlib Figure, so marimo renders all three panels.
        """
        # Every key in a result dict shares the same k; read it from the first.
        k_values = [int(next(iter(result)).split("@")[1]) for result in results]

        panels = (
            ("precision", "Precision", "blue"),
            ("recall", "Recall", "green"),
            ("ndcg", "NDCG", "red"),
        )

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, (metric, label, color) in zip(axes, panels):
            # Look metrics up by name rather than by dict position.
            values = [result[f"{metric}@{k}"] for result, k in zip(results, k_values)]
            ax.plot(k_values, values, marker="o", linestyle="-", color=color)
            ax.set_title(f"{label} @ K")
            ax.set_xlabel("Number of Retrieved Documents (K)")
            ax.set_ylabel(label)
            ax.set_xticks(k_values)
            # Value labels above each point.
            for k, v in zip(k_values, values):
                ax.text(k, v, f"{v:.2f}", ha="center", va="bottom")

        fig.tight_layout()
        return fig

    return (plot_retrieval_metrics,)


@app.cell
def _(QUESTIONS, create_evals, plot_retrieval_metrics):
    results = create_evals(QUESTIONS, [1, 3, 5, 10, 20])
    plot_retrieval_metrics(results)
    return


if __name__ == "__main__":
    app.run()