First commit
This commit is contained in:
2
src/rag_eval/__init__.py
Normal file
2
src/rag_eval/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
def main() -> None:
    """Console entry point: print a hello message."""
    print("Hello from rag-eval!")
|
||||
196
src/rag_eval/agent.py
Normal file
196
src/rag_eval/agent.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import structlog
|
||||
from embedder.vertex_ai import VertexAIEmbedder
|
||||
from google.genai import types
|
||||
from llm.vertex_ai import VertexAILLM
|
||||
from vector_search.vertex_ai import GoogleCloudVectorSearch
|
||||
|
||||
from rag_eval.config import settings
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
# Cap on simultaneous LLM requests for this process — presumably to stay
# under API rate limits; confirm against the provider's quota.
MAX_CONCURRENT_LLM_CALLS = 20
# Module-level semaphore shared by every Agent.async_call in this process.
_llm_semaphore = asyncio.Semaphore(MAX_CONCURRENT_LLM_CALLS)
|
||||
|
||||
class Agent:
|
||||
"""A class to handle the RAG workflow."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes the Agent class."""
|
||||
self.settings = settings.agent
|
||||
self.index_settings = settings.index
|
||||
self.model = self.settings.language_model
|
||||
self.system_prompt = self.settings.instructions
|
||||
self.llm = VertexAILLM(
|
||||
project=settings.project_id,
|
||||
location=settings.location,
|
||||
thinking=self.settings.thinking,
|
||||
)
|
||||
self.vector_search = GoogleCloudVectorSearch(
|
||||
project_id=settings.project_id,
|
||||
location=settings.location,
|
||||
bucket=settings.bucket,
|
||||
index_name=self.index_settings.name,
|
||||
)
|
||||
self.vector_search.load_index_endpoint(self.index_settings.endpoint)
|
||||
self.embedder = VertexAIEmbedder(
|
||||
project=settings.project_id,
|
||||
location=settings.location,
|
||||
model_name=self.settings.embedding_model, task="RETRIEVAL_QUERY"
|
||||
)
|
||||
self.min_sim = 0.60
|
||||
|
||||
def call(self, query: str | list[dict[str, str]]) -> str:
|
||||
"""Calls the LLM with the provided query and tools.
|
||||
|
||||
Args:
|
||||
query: The user's query.
|
||||
|
||||
Returns:
|
||||
The response from the LLM.
|
||||
"""
|
||||
if isinstance(query, str):
|
||||
search_query = query
|
||||
else:
|
||||
search_query = query[-1]["content"]
|
||||
|
||||
context = self.search(search_query)
|
||||
user_prompt = f"{search_query}\n\n{context}"
|
||||
|
||||
contents = []
|
||||
if isinstance(query, str):
|
||||
contents.append(types.Content(role="user", parts=[types.Part(text=user_prompt)]))
|
||||
else:
|
||||
for turn in query[:-1]:
|
||||
role = "model" if turn["role"] == "assistant" else "user"
|
||||
contents.append(
|
||||
types.Content(role=role, parts=[types.Part(text=turn["content"])])
|
||||
)
|
||||
contents.append(types.Content(role="user", parts=[types.Part(text=user_prompt)]))
|
||||
|
||||
generation = self.llm.generate(
|
||||
model=self.model,
|
||||
prompt=contents,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
|
||||
logger.info(f"total usage={generation.usage}")
|
||||
logger.info(f"costo ${generation.usage.get_cost(self.model)} MXN")
|
||||
|
||||
return generation.text
|
||||
|
||||
def search(self, query: str ) -> str:
|
||||
"""Searches the vector index for the given query.
|
||||
|
||||
Args:
|
||||
query: The query to search for.
|
||||
|
||||
Returns:
|
||||
A formatted string containing the search results.
|
||||
"""
|
||||
logger.debug(f"Search term: {query}")
|
||||
|
||||
query_embedding = self.embedder.generate_embedding(query)
|
||||
search_results = self.vector_search.run_query(
|
||||
deployed_index_id=self.index_settings.deployment,
|
||||
query=query_embedding,
|
||||
limit=5,
|
||||
)
|
||||
|
||||
max_sim = max(search_results, key=lambda x: x["distance"])["distance"]
|
||||
cutoff= max_sim * 0.9
|
||||
|
||||
search_results = [s for s in search_results if s["distance"] > cutoff and s["distance"] > self.min_sim]
|
||||
|
||||
logger.debug(f"{max_sim=}")
|
||||
logger.debug(f"{cutoff=}")
|
||||
logger.debug(f"chunks={[s['id'] for s in search_results]}")
|
||||
logger.debug(f"distancias={[s['distance'] for s in search_results]}")
|
||||
|
||||
return self._format_results(search_results)
|
||||
|
||||
async def async_search(self, query: str) -> str:
|
||||
"""Searches the vector index for the given query.
|
||||
|
||||
Args:
|
||||
query: The query to search for.
|
||||
|
||||
Returns:
|
||||
A formatted string containing the search results.
|
||||
"""
|
||||
t0 = time.perf_counter()
|
||||
|
||||
query_embedding = await self.embedder.async_generate_embedding(query)
|
||||
t_embed = time.perf_counter()
|
||||
|
||||
search_results = await self.vector_search.async_run_query(
|
||||
deployed_index_id=self.index_settings.deployment,
|
||||
query=query_embedding,
|
||||
limit=5,
|
||||
)
|
||||
t_search = time.perf_counter()
|
||||
|
||||
max_sim = max(search_results, key=lambda x: x["distance"])["distance"]
|
||||
cutoff = max_sim * 0.9
|
||||
search_results = [s for s in search_results if s["distance"] > cutoff and s["distance"] > self.min_sim]
|
||||
|
||||
logger.info(
|
||||
"async_search.timing",
|
||||
embedding_ms=round((t_embed - t0) * 1000, 1),
|
||||
vector_search_ms=round((t_search - t_embed) * 1000, 1),
|
||||
total_ms=round((t_search - t0) * 1000, 1),
|
||||
chunks=[s["id"] for s in search_results],
|
||||
)
|
||||
|
||||
return self._format_results(search_results)
|
||||
|
||||
def _format_results(self, search_results):
|
||||
formatted_results = [
|
||||
f"<document {i} name={result['id']}>\n{result['content']}\n</document {i}>"
|
||||
for i, result in enumerate(search_results, start=1)
|
||||
]
|
||||
return "\n".join(formatted_results)
|
||||
|
||||
async def async_call(self, query: str) -> str:
|
||||
"""Calls the LLM with the provided query and tools.
|
||||
|
||||
Args:
|
||||
query: The user's query.
|
||||
|
||||
Returns:
|
||||
The response from the LLM.
|
||||
"""
|
||||
t_start = time.perf_counter()
|
||||
|
||||
t0 = time.perf_counter()
|
||||
context = await self.async_search(query)
|
||||
t_search = time.perf_counter()
|
||||
|
||||
contents = [types.Content(role="user", parts=[types.Part(text=f"{query}\n\n{context}")])]
|
||||
|
||||
t1 = time.perf_counter()
|
||||
async with _llm_semaphore:
|
||||
generation = await self.llm.async_generate(
|
||||
model=self.model,
|
||||
prompt=contents,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
t_llm = time.perf_counter()
|
||||
|
||||
t_end = time.perf_counter()
|
||||
logger.info(
|
||||
"async_call.timing",
|
||||
total_ms=round((t_end - t_start) * 1000, 1),
|
||||
stages=[
|
||||
{"stage": "search", "ms": round((t_search - t0) * 1000, 1)},
|
||||
{"stage": "llm", "ms": round((t_llm - t1) * 1000, 1)},
|
||||
],
|
||||
llm_iterations=1,
|
||||
prompt_tokens=generation.usage.prompt_tokens,
|
||||
response_tokens=generation.usage.response_tokens,
|
||||
cost_mxn=generation.usage.get_cost(self.model),
|
||||
)
|
||||
|
||||
return generation.text
|
||||
140
src/rag_eval/cli.py
Normal file
140
src/rag_eval/cli.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""Main CLI for the RAG evaluation tool."""
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
from rich.prompt import Prompt
|
||||
|
||||
from rag_eval.agent import Agent
|
||||
from rag_eval.pipelines.submit import app as pipelines_app
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
app = typer.Typer()
|
||||
console = Console()
|
||||
|
||||
|
||||
@app.command()
def ask(
    message: Annotated[str, typer.Argument(help="The message to send to the agent.")],
    verbose: Annotated[
        int,
        typer.Option(
            "--verbose",
            "-v",
            help="Set verbosity. -v: INFO, -vv: DEBUG, -vvv: DEBUG with LLM logs.",
            count=True,
        ),
    ] = 0,
):
    """Sends a single message to a specified agent and prints the response."""
    # Map the -v count to a log level; counts beyond the list clamp to DEBUG.
    log_levels = [logging.WARNING, logging.INFO, logging.DEBUG, logging.DEBUG]
    level_index = min(verbose, len(log_levels) - 1)
    log_level = log_levels[level_index]

    # Determine which loggers to show; -vvv also surfaces the llm package.
    loggers_to_configure = ["rag_eval"]
    if verbose >= 3:
        loggers_to_configure.append("llm")

    # Create a single handler shared by all configured loggers.
    handler = RichHandler(rich_tracebacks=True, show_path=False)

    for logger_name in loggers_to_configure:
        logger = logging.getLogger(logger_name)
        logger.setLevel(log_level)
        logger.propagate = False
        # Avoid stacking duplicate handlers when the command runs repeatedly.
        if not logger.handlers:
            logger.addHandler(handler)

    try:
        console.print("[bold blue]Initializing agent...[/bold blue]")

        agent = Agent()

        console.print("[bold blue]Sending message...[/bold blue]")
        response = agent.call(query=message)

        console.print("\n[bold green]Agent Response:[/bold green]")
        console.print(response)

    except ValueError as e:
        console.print(f"[bold red]Error: {e}[/bold red]")
        # Chain the original exception so tracebacks stay informative (B904).
        raise typer.Exit(code=1) from e
    except Exception as e:
        console.print(f"[bold red]An unexpected error occurred: {e}[/bold red]")
        raise typer.Exit(code=1) from e
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@app.command()
def chat(
    verbose: Annotated[
        int,
        typer.Option(
            "--verbose",
            "-v",
            help="Set verbosity. -v: INFO, -vv: DEBUG, -vvv: DEBUG with LLM logs.",
            count=True,
        ),
    ] = 0,
):
    """Starts an interactive chat session with a specified agent."""
    # Map the -v count to a log level; counts beyond the list clamp to DEBUG.
    log_levels = [logging.WARNING, logging.INFO, logging.DEBUG, logging.DEBUG]
    level_index = min(verbose, len(log_levels) - 1)
    log_level = log_levels[level_index]

    # Determine which loggers to show; -vvv also surfaces the llm package.
    loggers_to_configure = ["rag_eval"]
    if verbose >= 3:
        loggers_to_configure.append("llm")

    # Create a single handler shared by all configured loggers.
    handler = RichHandler(rich_tracebacks=True, show_path=False)

    for logger_name in loggers_to_configure:
        logger = logging.getLogger(logger_name)
        logger.setLevel(log_level)
        logger.propagate = False
        # Avoid stacking duplicate handlers when the command runs repeatedly.
        if not logger.handlers:
            logger.addHandler(handler)

    try:
        console.print("[bold blue]Initializing agent...[/bold blue]")
        agent = Agent()
        console.print(
            "[bold green]Agent initialized. Start chatting! (type 'exit' or 'quit' to end)[/bold green]"
        )

        # Full conversation history; the agent receives it on every turn.
        history = []
        while True:
            user_input = Prompt.ask("[bold yellow]You[/bold yellow]")
            if user_input.lower() in ["exit", "quit"]:
                console.print("[bold blue]Ending chat. Goodbye![/bold blue]")
                break

            history.append({"role": "user", "content": user_input})

            response = agent.call(query=history)

            console.print(f"[bold green]Agent:[/bold green] {response}")
            history.append({"role": "assistant", "content": response})

    except ValueError as e:
        console.print(f"[bold red]Error: {e}[/bold red]")
        # Chain the original exception so tracebacks stay informative (B904).
        raise typer.Exit(code=1) from e
    except Exception as e:
        console.print(f"[bold red]An unexpected error occurred: {e}[/bold red]")
        raise typer.Exit(code=1) from e
|
||||
|
||||
|
||||
# Mount the pipeline-submission subcommands under `pipelines`.
app.add_typer(pipelines_app, name="pipelines")

if __name__ == "__main__":
    app()
|
||||
121
src/rag_eval/config.py
Normal file
121
src/rag_eval/config.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import os
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
PydanticBaseSettingsSource,
|
||||
SettingsConfigDict,
|
||||
YamlConfigSettingsSource,
|
||||
)
|
||||
|
||||
CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")
|
||||
|
||||
|
||||
class IndexConfig(BaseModel):
    """Vector-search index configuration (rebuilt from flattened Settings fields)."""

    name: str
    endpoint: str
    dimensions: int
    machine_type: str = "e2-standard-16"
    origin: str
    destination: str
    chunk_limit: int

    @property
    def deployment(self) -> str:
        # Deployed-index id derived from the name; dashes are replaced because
        # the deployed id presumably may not contain them — confirm.
        return self.name.replace("-", "_") + "_deployed"

    @property
    def data(self) -> str:
        # NOTE(review): plain concatenation — assumes `destination` already
        # ends with a separator (e.g. a trailing "/"); verify configuration.
        return self.destination + self.name
|
||||
|
||||
class AgentConfig(BaseModel):
    """Agent configuration (rebuilt from flattened Settings fields)."""

    name: str
    instructions: str
    language_model: str
    embedding_model: str
    # Thinking budget passed through to VertexAILLM; exact semantics are
    # defined by that client — confirm there.
    thinking: int
|
||||
|
||||
|
||||
class BigQueryConfig(BaseModel):
    """BigQuery configuration (rebuilt from flattened Settings fields)."""

    dataset_id: str
    # Falls back to None; presumably the default project is used then — confirm.
    project_id: str | None = None
    # Logical table name -> actual table id.
    table_ids: dict[str, str]
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application settings loaded from environment variables and a YAML file.

    Nested configuration is stored flattened (agent_*, index_*, bigquery_*) so
    each value can be overridden by a single environment variable; the
    `agent`, `index` and `bigquery` properties rebuild the nested models.
    """

    project_id: str
    location: str
    service_account: str

    # Flattened fields from nested models
    agent_name: str
    agent_instructions: str
    agent_language_model: str
    agent_embedding_model: str
    agent_thinking: int

    index_name: str
    index_endpoint: str
    index_dimensions: int
    index_machine_type: str = "e2-standard-16"
    index_origin: str
    index_destination: str
    index_chunk_limit: int

    bigquery_dataset_id: str
    bigquery_project_id: str | None = None
    bigquery_table_ids: dict[str, str]

    bucket: str
    base_image: str
    dialogflow_agent_id: str
    processing_image: str

    # YAML source path is resolved once at import time from CONFIG_YAML.
    model_config = SettingsConfigDict(yaml_file=CONFIG_FILE_PATH)

    @property
    def agent(self) -> AgentConfig:
        """Agent settings rebuilt as a nested model (fresh instance per access)."""
        return AgentConfig(
            name=self.agent_name,
            instructions=self.agent_instructions,
            language_model=self.agent_language_model,
            embedding_model=self.agent_embedding_model,
            thinking=self.agent_thinking,
        )

    @property
    def index(self) -> IndexConfig:
        """Index settings rebuilt as a nested model (fresh instance per access)."""
        return IndexConfig(
            name=self.index_name,
            endpoint=self.index_endpoint,
            dimensions=self.index_dimensions,
            machine_type=self.index_machine_type,
            origin=self.index_origin,
            destination=self.index_destination,
            chunk_limit=self.index_chunk_limit,
        )

    @property
    def bigquery(self) -> BigQueryConfig:
        """BigQuery settings rebuilt as a nested model (fresh instance per access)."""
        return BigQueryConfig(
            dataset_id=self.bigquery_dataset_id,
            project_id=self.bigquery_project_id,
            table_ids=self.bigquery_table_ids,
        )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        # Environment variables take precedence over YAML values; init, dotenv
        # and secret-file sources are intentionally dropped.
        return (
            env_settings,
            YamlConfigSettingsSource(settings_cls),
        )
|
||||
|
||||
|
||||
# Module-level singleton; other modules use `from rag_eval.config import settings`.
settings = Settings()
|
||||
44
src/rag_eval/logging.py
Normal file
44
src/rag_eval/logging.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import structlog
|
||||
|
||||
|
||||
def setup_logging(json: bool = True, level: int = logging.INFO) -> None:
    """Configure stdlib logging and structlog to emit via one stdout handler.

    Args:
        json: When True, render records as JSON lines; otherwise use the
            human-friendly console renderer.
        level: Root logger level (default logging.INFO).
    """
    shared_processors: list[structlog.types.Processor] = [
        structlog.contextvars.merge_contextvars,
        structlog.stdlib.add_log_level,
        structlog.stdlib.add_logger_name,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
    ]

    # The two original branches duplicated the formatter construction and
    # differed only in the final renderer; pick the renderer once instead.
    renderer: structlog.types.Processor = (
        structlog.processors.JSONRenderer() if json else structlog.dev.ConsoleRenderer()
    )
    formatter = structlog.stdlib.ProcessorFormatter(
        processors=[
            structlog.stdlib.ProcessorFormatter.remove_processors_meta,
            renderer,
        ],
    )

    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)

    # Replace existing handlers so repeated setup does not duplicate output.
    root = logging.getLogger()
    root.handlers.clear()
    root.addHandler(handler)
    root.setLevel(level)

    structlog.configure(
        processors=shared_processors,
        logger_factory=structlog.stdlib.LoggerFactory(),
        wrapper_class=structlog.stdlib.BoundLogger,
        cache_logger_on_first_use=True,
    )
|
||||
0
src/rag_eval/pipelines/__init__.py
Normal file
0
src/rag_eval/pipelines/__init__.py
Normal file
76
src/rag_eval/pipelines/evaluation.py
Normal file
76
src/rag_eval/pipelines/evaluation.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
KFP pipeline definition for the full RAG evaluation pipeline.
|
||||
"""
|
||||
|
||||
from kfp import dsl
|
||||
|
||||
from rag_eval.config import settings as config
|
||||
|
||||
|
||||
@dsl.component(base_image=config.base_image)
def run_synth_gen_op(num_questions: int, config_yaml_content: str) -> str:
    """
    KFP component to run the synthetic question generation script.
    Returns the generated run_id.
    """
    # Materialize the YAML config inside the container's working directory
    # before importing the module — presumably it reads config.yaml at import
    # time; confirm against synth_gen.
    CONFIG_FILE_PATH = "config.yaml"
    if config_yaml_content:
        with open(CONFIG_FILE_PATH, "w") as f:
            f.write(config_yaml_content)
    from synth_gen.main import generate as synth_gen_generate

    return synth_gen_generate(num_questions=num_questions)
|
||||
|
||||
|
||||
@dsl.component(base_image=config.base_image)
def run_search_eval_op(config_yaml_content: str, run_id: str = None):
    """
    KFP component to run the search evaluation script by direct import.
    Optionally filters by run_id.
    """
    # NOTE(review): `run_id` defaults to None but is annotated plain `str`;
    # KFP reads this annotation — consider Optional[str] after verifying kfp's
    # handling.
    # Materialize the YAML config before importing the evaluation module —
    # presumably it reads config.yaml at import time; confirm.
    CONFIG_FILE_PATH = "config.yaml"
    if config_yaml_content:
        with open(CONFIG_FILE_PATH, "w") as f:
            f.write(config_yaml_content)
    from search_eval.main import evaluate as search_eval_evaluate

    search_eval_evaluate(run_id=run_id)
|
||||
|
||||
|
||||
@dsl.component(base_image=config.base_image)
def run_keypoint_eval_op(config_yaml_content: str, run_id: str = None):
    """
    KFP component to run the keypoint evaluation script by direct import.
    Optionally filters by run_id.
    """
    # NOTE(review): `run_id` defaults to None but is annotated plain `str`;
    # see run_search_eval_op for the same caveat.
    # Materialize the YAML config before importing the evaluation module —
    # presumably it reads config.yaml at import time; confirm.
    CONFIG_FILE_PATH = "config.yaml"
    if config_yaml_content:
        with open(CONFIG_FILE_PATH, "w") as f:
            f.write(config_yaml_content)
    from keypoint_eval.main import run_keypoint_evaluation

    run_keypoint_evaluation(run_id=run_id)
|
||||
|
||||
|
||||
|
||||
@dsl.pipeline(
    name="generative-evaluation-pipeline",
    description="A pipeline that generates synthetic questions and then runs search evaluation on them.",
)
def generative_evaluation_pipeline(num_questions: int = 20, config_yaml_content: str = ""):
    """
    Defines the generative evaluation pipeline structure.
    1. Generates synthetic questions.
    2. Runs search evaluation on the generated questions.
    3. Runs keypoint evaluation on the generated questions.
    """
    synth_gen_task = run_synth_gen_op(
        num_questions=num_questions, config_yaml_content=config_yaml_content
    )

    # Both evaluation steps consume synth_gen_task.output, which sequences
    # them after generation; they are otherwise independent of each other.
    run_search_eval_op(
        run_id=synth_gen_task.output, config_yaml_content=config_yaml_content
    )
    run_keypoint_eval_op(
        run_id=synth_gen_task.output, config_yaml_content=config_yaml_content
    )
|
||||
53
src/rag_eval/pipelines/ingestion.py
Normal file
53
src/rag_eval/pipelines/ingestion.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
This script defines a simplified Kubeflow Pipeline (KFP) for ingesting and
|
||||
processing documents for a single agent by calling the index-gen CLI.
|
||||
"""
|
||||
|
||||
from kfp import dsl
|
||||
|
||||
from rag_eval.config import settings
|
||||
|
||||
# --- KFP Components ---------------------------------------------------------
|
||||
|
||||
|
||||
@dsl.component(base_image=settings.processing_image)
def run_index_gen_cli(config_yaml_content: str):
    """Runs the index-gen CLI."""

    # Write the config that index-gen expects into the container's CWD.
    CONFIG_FILE_PATH = "config.yaml"
    if config_yaml_content:
        with open(CONFIG_FILE_PATH, "w") as f:
            f.write(config_yaml_content)

    # Imports live inside the component body so they execute in the container.
    import logging
    import subprocess

    # The command needs to be installed in the environment.
    # Assuming the processing_image has the project installed.
    command = ["uv", "run", "index-gen"]
    logging.info(f"Running command: {' '.join(command)}")

    # Using subprocess.run to capture output and check for errors
    result = subprocess.run(command, capture_output=True, text=True, check=False)

    logging.info(result.stdout)
    if result.stderr:
        logging.error(result.stderr)

    # check=False above, so the failure is surfaced explicitly here.
    if result.returncode != 0:
        raise RuntimeError(f"index-gen CLI failed with return code {result.returncode}")
|
||||
|
||||
|
||||
# --- KFP Pipeline Definition ------------------------------------------------
|
||||
|
||||
|
||||
@dsl.pipeline(
    name="rag-ingestion-pipeline",
    description="A pipeline to run index-gen.",
)
def ingestion_pipeline(
    config_yaml_content: str = "",
):
    """Defines the KFP pipeline structure: a single index-gen step."""
    run_index_gen_cli(config_yaml_content=config_yaml_content)
|
||||
146
src/rag_eval/pipelines/submit.py
Normal file
146
src/rag_eval/pipelines/submit.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Command-line interface for submitting pipelines.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
from google.cloud import aiplatform
|
||||
from kfp import compiler
|
||||
|
||||
from rag_eval.config import CONFIG_FILE_PATH, settings
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def ingestion(
    compile_only: Annotated[
        bool,
        typer.Option(
            "--compile-only",
            help="If set, compiles the pipeline and exits without running it.",
        ),
    ] = False,
):
    """Compiles and/or runs the KFP ingestion pipeline."""
    # Imported lazily so the CLI loads fast when this command is not used.
    from rag_eval.pipelines.ingestion import ingestion_pipeline

    logging.getLogger().setLevel(logging.INFO)

    pipeline_root = f"gs://{settings.bucket}"

    # The compiled spec only needs to live for the duration of the submission,
    # so it goes into a NamedTemporaryFile that is deleted on context exit.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as temp_pipeline_file:
        pipeline_spec_path = temp_pipeline_file.name

        compiler.Compiler().compile(
            pipeline_func=ingestion_pipeline, package_path=pipeline_spec_path
        )
        logging.info(f"Pipeline compiled to temporary file: {pipeline_spec_path}")

        if compile_only:
            logging.info(
                "Compilation successful. Temporary pipeline file will be deleted."
            )
            return

        # Ship the YAML config to the pipeline as a string parameter so the
        # remote container can reconstruct config.yaml.
        logging.info(f"Reading config from {CONFIG_FILE_PATH}")
        with open(CONFIG_FILE_PATH, "r") as f:
            config_yaml_content = f.read()

        aiplatform.init(project=settings.project_id, location=settings.location)

        job = aiplatform.PipelineJob(
            display_name=f"rag-ingestion-run-{settings.agent.name}",
            template_path=pipeline_spec_path,
            pipeline_root=pipeline_root,
            parameter_values={
                "config_yaml_content": config_yaml_content,
            },
            enable_caching=False,
        )

        logging.info(
            "Submitting pipeline job to Vertex AI... This will wait for the job to complete."
        )
        # sync=True blocks until completion, so the temp spec stays alive.
        job.run(sync=True, service_account=settings.service_account)

    logging.info("Pipeline finished. Fetching endpoint details...")

    # NOTE(review): `settings.index` is a property that always constructs an
    # IndexConfig, so this guard only trips if construction fails — confirm
    # the intended semantics.
    if not settings.index or not settings.index.name:
        logging.warning(
            "No index configuration found. Cannot fetch endpoint."
        )
        return

    expected_endpoint_name = f"{settings.index.name}-endpoint"
    logging.info(f"Looking for endpoint with display name: {expected_endpoint_name}")

    endpoints = aiplatform.MatchingEngineIndexEndpoint.list()
    found_endpoint = next(
        (e for e in endpoints if e.display_name == expected_endpoint_name), None
    )

    if not found_endpoint:
        logging.warning(
            "Could not find a matching deployed endpoint."
        )
    else:
        print("\n--- Deployed Index Endpoint ---")
        print(f"Display Name: {found_endpoint.display_name}")
        print(f" Resource Name: {found_endpoint.resource_name}")
        if found_endpoint.public_endpoint_domain_name:
            print(f" Public URI: {found_endpoint.public_endpoint_domain_name}")
        else:
            print(" Public URI: Not available")
        print("-" * 20)
|
||||
|
||||
|
||||
@app.command()
def evaluation():
    """
    Compiles and runs the RAG evaluation Vertex AI pipeline without saving the compiled file.
    """
    # Imported lazily so the CLI loads fast when this command is not used.
    from rag_eval.pipelines.evaluation import generative_evaluation_pipeline

    project_id = settings.project_id
    location = settings.location
    # BUG FIX: Settings defines no `pipeline_root` attribute, so the original
    # raised AttributeError here. Derive the root from the configured bucket,
    # matching the `ingestion` command.
    pipeline_root = f"gs://{settings.bucket}"
    pipeline_name = "rag-evaluation-pipeline"

    # Create a temporary file to store the compiled pipeline; it is deleted on
    # context exit, so submission happens while the context is open.
    with tempfile.NamedTemporaryFile(suffix=".json") as temp_file:
        pipeline_path = temp_file.name

        print(f"Compiling pipeline to temporary file: {pipeline_path}")
        compiler.Compiler().compile(
            pipeline_func=generative_evaluation_pipeline,
            package_path=pipeline_path,
        )

        # Read the config.yaml content to pass to the pipeline
        print(f"Reading config from {CONFIG_FILE_PATH}")
        with open(CONFIG_FILE_PATH, "r") as f:
            config_yaml_content = f.read()

        print("Submitting pipeline job to Vertex AI...")
        aiplatform.init(project=project_id, location=location)

        job = aiplatform.PipelineJob(
            display_name=pipeline_name,
            template_path=pipeline_path,
            pipeline_root=pipeline_root,
            parameter_values={
                "config_yaml_content": config_yaml_content,
            },
            enable_caching=False,
        )

        job.run(service_account=settings.service_account)
        print(
            f"Pipeline job submitted. View it in the Vertex AI console: {job.gca_resource.name}"
        )

    print(f"Temporary compiled file {pipeline_path} has been removed.")
|
||||
|
||||
|
||||
10
src/rag_eval/server/__init__.py
Normal file
10
src/rag_eval/server/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from fastapi import FastAPI

from rag_eval.logging import setup_logging

# Configure logging before importing modules that create loggers at import
# time (routes.py calls structlog.get_logger when imported).
setup_logging()

from .routes import rag_router  # noqa: E402

app = FastAPI()
app.include_router(rag_router)
|
||||
52
src/rag_eval/server/models.py
Normal file
52
src/rag_eval/server/models.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ---------------------------------- #
# Maps Gemini response
class Answer(BaseModel):
    # Plain-text answer produced by the model.
    answer: str
|
||||
|
||||
|
||||
# ---------------------------------- #
# Dialogflow request parameters
class RequestParameters(BaseModel):
    # The end-user query forwarded by Dialogflow.
    query: str


class RequestSessionInfo(BaseModel):
    parameters: RequestParameters


class Request(BaseModel):
    # camelCase field name matches Dialogflow's webhook JSON payload.
    sessionInfo: RequestSessionInfo
|
||||
|
||||
|
||||
# ---------------------------------- #
# Dialogflow response parameters
class ResponseParameters(BaseModel):
    webhook_success: bool
    response: str
    # Spanish-named flags — presumably consumed by the Dialogflow flow; the
    # defaults mark the turn as finished and the answer as delivered. Confirm
    # against the flow definition.
    pregunta_nueva: str = "NO"
    fin_turno: bool = True
    respuesta_entregada: bool = True


class ResponseSessionInfo(BaseModel):
    parameters: ResponseParameters


class Response(BaseModel):
    # camelCase field name matches Dialogflow's webhook JSON payload.
    sessionInfo: ResponseSessionInfo
|
||||
|
||||
|
||||
# ---------------------------------- #
# Dialogflow proxy models
class DialogflowRequest(BaseModel):
    prompt: str
    # Generates a fresh session id per request unless the caller supplies one.
    session_id: str = Field(default_factory=lambda: str(uuid4()))


class DialogflowResponse(BaseModel):
    # Raw response payload returned by the Dialogflow agent.
    response: dict
|
||||
75
src/rag_eval/server/routes.py
Normal file
75
src/rag_eval/server/routes.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
|
||||
from ..agent import Agent
|
||||
from .models import (
|
||||
DialogflowRequest,
|
||||
DialogflowResponse,
|
||||
Request,
|
||||
Response,
|
||||
ResponseParameters,
|
||||
ResponseSessionInfo,
|
||||
)
|
||||
|
||||
# Dialogflow support is optional: the proxy endpoint below is registered only
# when the `dialogflow` package is importable.
try:
    from dialogflow.main import DialogflowAgent
except ImportError:
    DialogflowAgent = None

logger = structlog.get_logger(__name__)

rag_router = APIRouter()

# Single module-level Agent instance, created at import time and shared
# across all requests.
sigma_agent = Agent()
|
||||
|
||||
|
||||
@rag_router.post("/sigma-rag", response_model=Response)
async def generate_sigma_agent(request: Request):
    """Answer a Dialogflow webhook call using the shared RAG agent."""
    # Tag every log line of this request with a short correlation id.
    rid = uuid4().hex[:8]
    structlog.contextvars.clear_contextvars()
    structlog.contextvars.bind_contextvars(request_id=rid)

    prompt = request.sessionInfo.parameters.query
    logger.info("request.start", prompt_length=len(prompt))
    started = time.perf_counter()

    answer = await sigma_agent.async_call(prompt)

    elapsed = round((time.perf_counter() - started) * 1000, 1)
    logger.info("request.end", elapsed_ms=elapsed)

    # Wrap the answer in the nested structure Dialogflow expects.
    return Response(
        sessionInfo=ResponseSessionInfo(
            parameters=ResponseParameters(
                webhook_success=True,
                response=answer,
            )
        )
    )
|
||||
|
||||
|
||||
if DialogflowAgent:

    @rag_router.post("/dialogflow-proxy", response_model=DialogflowResponse)
    async def dialogflow_proxy(request: DialogflowRequest):
        """
        Proxies a message to a Dialogflow agent.

        This endpoint is only available if the 'dialogflow' package is installed.
        """
        try:
            # Leftover debug print replaced with a structured debug log.
            logger.debug("dialogflow_proxy.request", request=str(request))
            agent = DialogflowAgent()
            response = agent.call(query=request.prompt, session_id=request.session_id)
            return DialogflowResponse(response=response)
        except Exception as e:
            logger.error("Error calling Dialogflow agent", error=str(e))
            # Chain the original exception so tracebacks stay informative.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Error communicating with Dialogflow agent.",
            ) from e
|
||||
Reference in New Issue
Block a user