Migrate search evaluator from the Vertex AI SDK (vertexai) to the Google Gen AI SDK (google-genai)
This commit is contained in:
@@ -3,7 +3,7 @@ import sqlite3
|
|||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import typer
|
import typer
|
||||||
import vertexai
|
from google import genai
|
||||||
from google.cloud import bigquery
|
from google.cloud import bigquery
|
||||||
from google.cloud.aiplatform.matching_engine import MatchingEngineIndexEndpoint
|
from google.cloud.aiplatform.matching_engine import MatchingEngineIndexEndpoint
|
||||||
from ranx import Qrels, Run
|
from ranx import Qrels, Run
|
||||||
@@ -11,7 +11,6 @@ from ranx import evaluate as ranx_evaluate
|
|||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.progress import track
|
from rich.progress import track
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
from vertexai.language_models import TextEmbeddingModel
|
|
||||||
|
|
||||||
from va_evaluator.config import Settings
|
from va_evaluator.config import Settings
|
||||||
|
|
||||||
@@ -111,8 +110,11 @@ def run_evaluation(
|
|||||||
)
|
)
|
||||||
console.print(f"Index Name: [bold cyan]{settings.index.name}[/bold cyan]")
|
console.print(f"Index Name: [bold cyan]{settings.index.name}[/bold cyan]")
|
||||||
|
|
||||||
vertexai.init(project=settings.project_id, location=settings.location)
|
client = genai.Client(
|
||||||
embedding_model = TextEmbeddingModel.from_pretrained(settings.agent.embedding_model)
|
vertexai=True,
|
||||||
|
project=settings.project_id,
|
||||||
|
location=settings.location,
|
||||||
|
)
|
||||||
index_endpoint = MatchingEngineIndexEndpoint(settings.index.require_endpoint)
|
index_endpoint = MatchingEngineIndexEndpoint(settings.index.require_endpoint)
|
||||||
|
|
||||||
# Prepare qrels
|
# Prepare qrels
|
||||||
@@ -126,8 +128,11 @@ def run_evaluation(
|
|||||||
run_data = {}
|
run_data = {}
|
||||||
detailed_results_list = []
|
detailed_results_list = []
|
||||||
for _, row in track(df.iterrows(), total=len(df), description="Preparing run..."):
|
for _, row in track(df.iterrows(), total=len(df), description="Preparing run..."):
|
||||||
embeddings = embedding_model.get_embeddings([row["question"]])
|
result = client.models.embed_content(
|
||||||
question_embedding = embeddings[0].values
|
model=settings.agent.embedding_model,
|
||||||
|
contents=row["question"],
|
||||||
|
)
|
||||||
|
question_embedding = result.embeddings[0].values
|
||||||
|
|
||||||
results = index_endpoint.find_neighbors(
|
results = index_endpoint.find_neighbors(
|
||||||
deployed_index_id=settings.index.require_deployment,
|
deployed_index_id=settings.index.require_deployment,
|
||||||
|
|||||||
Reference in New Issue
Block a user