Migrate search evaluator from vertex sdk to genai

This commit is contained in:
Anibal Angulo
2026-02-23 17:54:30 +00:00
parent 099f6a50d1
commit ddd13805c9

View File

@@ -3,7 +3,7 @@ import sqlite3
import pandas as pd import pandas as pd
import typer import typer
import vertexai from google import genai
from google.cloud import bigquery from google.cloud import bigquery
from google.cloud.aiplatform.matching_engine import MatchingEngineIndexEndpoint from google.cloud.aiplatform.matching_engine import MatchingEngineIndexEndpoint
from ranx import Qrels, Run from ranx import Qrels, Run
@@ -11,7 +11,6 @@ from ranx import evaluate as ranx_evaluate
from rich.console import Console from rich.console import Console
from rich.progress import track from rich.progress import track
from rich.table import Table from rich.table import Table
from vertexai.language_models import TextEmbeddingModel
from va_evaluator.config import Settings from va_evaluator.config import Settings
@@ -111,8 +110,11 @@ def run_evaluation(
) )
console.print(f"Index Name: [bold cyan]{settings.index.name}[/bold cyan]") console.print(f"Index Name: [bold cyan]{settings.index.name}[/bold cyan]")
vertexai.init(project=settings.project_id, location=settings.location) client = genai.Client(
embedding_model = TextEmbeddingModel.from_pretrained(settings.agent.embedding_model) vertexai=True,
project=settings.project_id,
location=settings.location,
)
index_endpoint = MatchingEngineIndexEndpoint(settings.index.require_endpoint) index_endpoint = MatchingEngineIndexEndpoint(settings.index.require_endpoint)
# Prepare qrels # Prepare qrels
@@ -126,8 +128,11 @@ def run_evaluation(
run_data = {} run_data = {}
detailed_results_list = [] detailed_results_list = []
for _, row in track(df.iterrows(), total=len(df), description="Preparing run..."): for _, row in track(df.iterrows(), total=len(df), description="Preparing run..."):
embeddings = embedding_model.get_embeddings([row["question"]]) result = client.models.embed_content(
question_embedding = embeddings[0].values model=settings.agent.embedding_model,
contents=row["question"],
)
question_embedding = result.embeddings[0].values
results = index_endpoint.find_neighbors( results = index_endpoint.find_neighbors(
deployed_index_id=settings.index.require_deployment, deployed_index_id=settings.index.require_deployment,