diff --git a/src/va_evaluator/search_metrics_evaluator.py b/src/va_evaluator/search_metrics_evaluator.py index 060486c..adf0a05 100644 --- a/src/va_evaluator/search_metrics_evaluator.py +++ b/src/va_evaluator/search_metrics_evaluator.py @@ -3,7 +3,7 @@ import sqlite3 import pandas as pd import typer -import vertexai +from google import genai from google.cloud import bigquery from google.cloud.aiplatform.matching_engine import MatchingEngineIndexEndpoint from ranx import Qrels, Run @@ -11,7 +11,6 @@ from ranx import evaluate as ranx_evaluate from rich.console import Console from rich.progress import track from rich.table import Table -from vertexai.language_models import TextEmbeddingModel from va_evaluator.config import Settings @@ -111,8 +110,11 @@ def run_evaluation( ) console.print(f"Index Name: [bold cyan]{settings.index.name}[/bold cyan]") - vertexai.init(project=settings.project_id, location=settings.location) - embedding_model = TextEmbeddingModel.from_pretrained(settings.agent.embedding_model) + client = genai.Client( + vertexai=True, + project=settings.project_id, + location=settings.location, + ) index_endpoint = MatchingEngineIndexEndpoint(settings.index.require_endpoint) # Prepare qrels @@ -126,8 +128,11 @@ def run_evaluation( run_data = {} detailed_results_list = [] for _, row in track(df.iterrows(), total=len(df), description="Preparing run..."): - embeddings = embedding_model.get_embeddings([row["question"]]) - question_embedding = embeddings[0].values + result = client.models.embed_content( + model=settings.agent.embedding_model, + contents=row["question"], + ) + question_embedding = result.embeddings[0].values results = index_endpoint.find_neighbors( deployed_index_id=settings.index.require_deployment,