Switch to agent arch

This commit is contained in:
2026-02-20 08:59:43 +00:00
parent a53f8fcf62
commit 259a8528e3
113 changed files with 788 additions and 7820 deletions

View File

@@ -1,2 +1 @@
def main() -> None:
print("Hello from rag-eval!")
"""RAG evaluation agent package."""

View File

@@ -1,196 +1,84 @@
import asyncio
"""Pydantic AI agent with RAG tool for vector search."""
import time
import structlog
from embedder.vertex_ai import VertexAIEmbedder
from google.genai import types
from llm.vertex_ai import VertexAILLM
from vector_search.vertex_ai import GoogleCloudVectorSearch
from pydantic import BaseModel
from pydantic_ai import Agent, Embedder, RunContext
from pydantic_ai.models.google import GoogleModel
from rag_eval.config import settings
from rag_eval.vector_search.vertex_ai import GoogleCloudVectorSearch
logger = structlog.get_logger(__name__)
MAX_CONCURRENT_LLM_CALLS = 20
_llm_semaphore = asyncio.Semaphore(MAX_CONCURRENT_LLM_CALLS)
class Agent:
"""A class to handle the RAG workflow."""
class Deps(BaseModel):
"""Dependencies injected into the agent at runtime."""
def __init__(self):
"""Initializes the Agent class."""
self.settings = settings.agent
self.index_settings = settings.index
self.model = self.settings.language_model
self.system_prompt = self.settings.instructions
self.llm = VertexAILLM(
project=settings.project_id,
location=settings.location,
thinking=self.settings.thinking,
)
self.vector_search = GoogleCloudVectorSearch(
project_id=settings.project_id,
location=settings.location,
bucket=settings.bucket,
index_name=self.index_settings.name,
)
self.vector_search.load_index_endpoint(self.index_settings.endpoint)
self.embedder = VertexAIEmbedder(
project=settings.project_id,
location=settings.location,
model_name=self.settings.embedding_model, task="RETRIEVAL_QUERY"
)
self.min_sim = 0.60
vector_search: GoogleCloudVectorSearch
embedder: Embedder
def call(self, query: str | list[dict[str, str]]) -> str:
"""Calls the LLM with the provided query and tools.
model_config = {"arbitrary_types_allowed": True}
Args:
query: The user's query.
Returns:
The response from the LLM.
"""
if isinstance(query, str):
search_query = query
else:
search_query = query[-1]["content"]
model = GoogleModel(
settings.agent_language_model,
provider=settings.provider,
)
agent = Agent(
model,
deps_type=Deps,
system_prompt=settings.agent_instructions,
)
context = self.search(search_query)
user_prompt = f"{search_query}\n\n{context}"
contents = []
if isinstance(query, str):
contents.append(types.Content(role="user", parts=[types.Part(text=user_prompt)]))
else:
for turn in query[:-1]:
role = "model" if turn["role"] == "assistant" else "user"
contents.append(
types.Content(role=role, parts=[types.Part(text=turn["content"])])
)
contents.append(types.Content(role="user", parts=[types.Part(text=user_prompt)]))
@agent.tool
async def conocimiento(ctx: RunContext[Deps], query: str) -> str:
"""Search the vector index for the given query.
generation = self.llm.generate(
model=self.model,
prompt=contents,
system_prompt=self.system_prompt,
)
Args:
ctx: The run context containing dependencies.
query: The query to search for.
logger.info(f"total usage={generation.usage}")
logger.info(f"costo ${generation.usage.get_cost(self.model)} MXN")
Returns:
A formatted string containing the search results.
return generation.text
"""
t0 = time.perf_counter()
min_sim = 0.6
def search(self, query: str ) -> str:
"""Searches the vector index for the given query.
query_embedding = await ctx.deps.embedder.embed_query(query)
t_embed = time.perf_counter()
Args:
query: The query to search for.
search_results = await ctx.deps.vector_search.async_run_query(
deployed_index_id=settings.index_name,
query=list(query_embedding.embeddings[0]),
limit=5,
)
t_search = time.perf_counter()
Returns:
A formatted string containing the search results.
"""
logger.debug(f"Search term: {query}")
query_embedding = self.embedder.generate_embedding(query)
search_results = self.vector_search.run_query(
deployed_index_id=self.index_settings.deployment,
query=query_embedding,
limit=5,
)
max_sim = max(search_results, key=lambda x: x["distance"])["distance"]
cutoff= max_sim * 0.9
search_results = [s for s in search_results if s["distance"] > cutoff and s["distance"] > self.min_sim]
logger.debug(f"{max_sim=}")
logger.debug(f"{cutoff=}")
logger.debug(f"chunks={[s['id'] for s in search_results]}")
logger.debug(f"distancias={[s['distance'] for s in search_results]}")
return self._format_results(search_results)
async def async_search(self, query: str) -> str:
"""Searches the vector index for the given query.
Args:
query: The query to search for.
Returns:
A formatted string containing the search results.
"""
t0 = time.perf_counter()
query_embedding = await self.embedder.async_generate_embedding(query)
t_embed = time.perf_counter()
search_results = await self.vector_search.async_run_query(
deployed_index_id=self.index_settings.deployment,
query=query_embedding,
limit=5,
)
t_search = time.perf_counter()
max_sim = max(search_results, key=lambda x: x["distance"])["distance"]
if search_results:
max_sim = max(r["distance"] for r in search_results)
cutoff = max_sim * 0.9
search_results = [s for s in search_results if s["distance"] > cutoff and s["distance"] > self.min_sim]
logger.info(
"async_search.timing",
embedding_ms=round((t_embed - t0) * 1000, 1),
vector_search_ms=round((t_search - t_embed) * 1000, 1),
total_ms=round((t_search - t0) * 1000, 1),
chunks=[s["id"] for s in search_results],
)
return self._format_results(search_results)
def _format_results(self, search_results):
formatted_results = [
f"<document {i} name={result['id']}>\n{result['content']}\n</document {i}>"
for i, result in enumerate(search_results, start=1)
search_results = [
s
for s in search_results
if s["distance"] > cutoff and s["distance"] > min_sim
]
return "\n".join(formatted_results)
async def async_call(self, query: str) -> str:
"""Calls the LLM with the provided query and tools.
logger.info(
"conocimiento.timing",
embedding_ms=round((t_embed - t0) * 1000, 1),
vector_search_ms=round((t_search - t_embed) * 1000, 1),
total_ms=round((t_search - t0) * 1000, 1),
chunks=[s["id"] for s in search_results],
)
Args:
query: The user's query.
Returns:
The response from the LLM.
"""
t_start = time.perf_counter()
t0 = time.perf_counter()
context = await self.async_search(query)
t_search = time.perf_counter()
contents = [types.Content(role="user", parts=[types.Part(text=f"{query}\n\n{context}")])]
t1 = time.perf_counter()
async with _llm_semaphore:
generation = await self.llm.async_generate(
model=self.model,
prompt=contents,
system_prompt=self.system_prompt,
)
t_llm = time.perf_counter()
t_end = time.perf_counter()
logger.info(
"async_call.timing",
total_ms=round((t_end - t_start) * 1000, 1),
stages=[
{"stage": "search", "ms": round((t_search - t0) * 1000, 1)},
{"stage": "llm", "ms": round((t_llm - t1) * 1000, 1)},
],
llm_iterations=1,
prompt_tokens=generation.usage.prompt_tokens,
response_tokens=generation.usage.response_tokens,
cost_mxn=generation.usage.get_cost(self.model),
)
return generation.text
formatted_results = [
f"<document {i} name={result['id']}>\n"
f"{result['content']}\n"
f"</document {i}>"
for i, result in enumerate(search_results, start=1)
]
return "\n".join(formatted_results)

View File

@@ -1,140 +0,0 @@
"""Main CLI for the RAG evaluation tool."""
import logging
import warnings
from typing import Annotated

import typer
from rich.console import Console
from rich.logging import RichHandler
from rich.prompt import Prompt

from rag_eval.agent import Agent
from rag_eval.pipelines.submit import app as pipelines_app

warnings.filterwarnings("ignore")

app = typer.Typer()
console = Console()

# Shared -v/--verbose option definition for both commands.
_VerboseOption = Annotated[
    int,
    typer.Option(
        "--verbose",
        "-v",
        help="Set verbosity. -v: INFO, -vv: DEBUG, -vvv: DEBUG with LLM logs.",
        count=True,
    ),
]


def _configure_logging(verbose: int) -> None:
    """Configure Rich log handlers for the given -v count.

    0 -> WARNING, 1 -> INFO, 2+ -> DEBUG; 3+ additionally enables the
    "llm" logger. Extracted so `ask` and `chat` share one implementation.
    """
    log_levels = [logging.WARNING, logging.INFO, logging.DEBUG, logging.DEBUG]
    log_level = log_levels[min(verbose, len(log_levels) - 1)]
    # Determine which loggers to show
    loggers_to_configure = ["rag_eval"]
    if verbose >= 3:
        loggers_to_configure.append("llm")
    # Create a single handler for all loggers so formatting stays consistent.
    handler = RichHandler(rich_tracebacks=True, show_path=False)
    for logger_name in loggers_to_configure:
        logger = logging.getLogger(logger_name)
        logger.setLevel(log_level)
        logger.propagate = False
        if not logger.handlers:
            logger.addHandler(handler)


@app.command()
def ask(
    message: Annotated[str, typer.Argument(help="The message to send to the agent.")],
    verbose: _VerboseOption = 0,
):
    """Sends a single message to a specified agent and prints the response."""
    _configure_logging(verbose)
    try:
        console.print("[bold blue]Initializing agent...[/bold blue]")
        agent = Agent()
        console.print("[bold blue]Sending message...[/bold blue]")
        response = agent.call(query=message)
        console.print("\n[bold green]Agent Response:[/bold green]")
        console.print(response)
    except ValueError as e:
        console.print(f"[bold red]Error: {e}[/bold red]")
        raise typer.Exit(code=1) from e
    except Exception as e:
        console.print(f"[bold red]An unexpected error occurred: {e}[/bold red]")
        raise typer.Exit(code=1) from e


@app.command()
def chat(verbose: _VerboseOption = 0):
    """Starts an interactive chat session with a specified agent."""
    _configure_logging(verbose)
    try:
        console.print("[bold blue]Initializing agent...[/bold blue]")
        agent = Agent()
        console.print(
            "[bold green]Agent initialized. Start chatting! (type 'exit' or 'quit' to end)[/bold green]"
        )
        history = []
        while True:
            user_input = Prompt.ask("[bold yellow]You[/bold yellow]")
            if user_input.lower() in ["exit", "quit"]:
                console.print("[bold blue]Ending chat. Goodbye![/bold blue]")
                break
            history.append({"role": "user", "content": user_input})
            # The agent receives the full history so it can use prior turns.
            response = agent.call(query=history)
            console.print(f"[bold green]Agent:[/bold green] {response}")
            history.append({"role": "assistant", "content": response})
    except ValueError as e:
        console.print(f"[bold red]Error: {e}[/bold red]")
        raise typer.Exit(code=1) from e
    except Exception as e:
        console.print(f"[bold red]An unexpected error occurred: {e}[/bold red]")
        raise typer.Exit(code=1) from e


app.add_typer(pipelines_app, name="pipelines")

if __name__ == "__main__":
    app()

View File

@@ -1,6 +1,10 @@
import os
"""Application settings loaded from YAML and environment variables."""
from pydantic import BaseModel
import os
from functools import cached_property
from pydantic_ai import Embedder
from pydantic_ai.providers.google import GoogleProvider
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
@@ -8,46 +12,18 @@ from pydantic_settings import (
YamlConfigSettingsSource,
)
from rag_eval.vector_search.vertex_ai import GoogleCloudVectorSearch
CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")
class IndexConfig(BaseModel):
name: str
endpoint: str
dimensions: int
machine_type: str = "e2-standard-16"
origin: str
destination: str
chunk_limit: int
@property
def deployment(self) -> str:
return self.name.replace("-", "_") + "_deployed"
@property
def data(self) -> str:
return self.destination + self.name
class AgentConfig(BaseModel):
name: str
instructions: str
language_model: str
embedding_model: str
thinking: int
class BigQueryConfig(BaseModel):
dataset_id: str
project_id: str | None = None
table_ids: dict[str, str]
class Settings(BaseSettings):
"""Application settings loaded from config.yaml and env vars."""
project_id: str
location: str
service_account: str
# Flattened fields from nested models
agent_name: str
agent_instructions: str
agent_language_model: str
@@ -70,52 +46,52 @@ class Settings(BaseSettings):
base_image: str
dialogflow_agent_id: str
processing_image: str
model_config = SettingsConfigDict(yaml_file=CONFIG_FILE_PATH)
@property
def agent(self) -> AgentConfig:
return AgentConfig(
name=self.agent_name,
instructions=self.agent_instructions,
language_model=self.agent_language_model,
embedding_model=self.agent_embedding_model,
thinking=self.agent_thinking,
)
@property
def index(self) -> IndexConfig:
return IndexConfig(
name=self.index_name,
endpoint=self.index_endpoint,
dimensions=self.index_dimensions,
machine_type=self.index_machine_type,
origin=self.index_origin,
destination=self.index_destination,
chunk_limit=self.index_chunk_limit,
)
@property
def bigquery(self) -> BigQueryConfig:
return BigQueryConfig(
dataset_id=self.bigquery_dataset_id,
project_id=self.bigquery_project_id,
table_ids=self.bigquery_table_ids,
)
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
init_settings: PydanticBaseSettingsSource, # noqa: ARG003
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource, # noqa: ARG003
file_secret_settings: PydanticBaseSettingsSource, # noqa: ARG003
) -> tuple[PydanticBaseSettingsSource, ...]:
"""Use env vars and YAML as settings sources."""
return (
env_settings,
YamlConfigSettingsSource(settings_cls),
)
@cached_property
def provider(self) -> GoogleProvider:
"""Return a Google provider configured for Vertex AI."""
return GoogleProvider(
project=self.project_id,
location=self.location,
)
settings = Settings()
@cached_property
def vector_search(self) -> GoogleCloudVectorSearch:
"""Return a configured vector search client."""
return GoogleCloudVectorSearch(
project_id=self.project_id,
location=self.location,
bucket=self.bucket,
index_name=self.index_name,
)
@cached_property
def embedder(self) -> Embedder:
"""Return an embedder configured for the agent's embedding model."""
from pydantic_ai.embeddings.google import GoogleEmbeddingModel # noqa: PLC0415
model = GoogleEmbeddingModel(
self.agent_embedding_model,
provider=self.provider,
)
return Embedder(model)
settings = Settings.model_validate({})

View File

@@ -0,0 +1 @@
"""File storage provider implementations."""

View File

@@ -0,0 +1,56 @@
"""Abstract base class for file storage providers."""
from abc import ABC, abstractmethod
from typing import BinaryIO
class BaseFileStorage(ABC):
"""Abstract base class for a remote file processor.
Defines the interface for listing and processing files from
a remote source.
"""
@abstractmethod
def upload_file(
self,
file_path: str,
destination_blob_name: str,
content_type: str | None = None,
) -> None:
"""Upload a file to the remote source.
Args:
file_path: The local path to the file to upload.
destination_blob_name: Name of the file in remote storage.
content_type: The content type of the file.
"""
...
@abstractmethod
def list_files(self, path: str | None = None) -> list[str]:
"""List files from a remote location.
Args:
path: Path to a specific file or directory. If None,
recursively lists all files in the bucket.
Returns:
A list of file paths.
"""
...
@abstractmethod
def get_file_stream(self, file_name: str) -> BinaryIO:
"""Get a file from the remote source as a file-like object.
Args:
file_name: The name of the file to retrieve.
Returns:
A file-like object containing the file data.
"""
...

View File

@@ -0,0 +1,188 @@
"""Google Cloud Storage file storage implementation."""
import asyncio
import io
import logging
from typing import BinaryIO
import aiohttp
from gcloud.aio.storage import Storage
from google.cloud import storage
from rag_eval.file_storage.base import BaseFileStorage
logger = logging.getLogger(__name__)
HTTP_TOO_MANY_REQUESTS = 429
HTTP_SERVER_ERROR = 500
class GoogleCloudFileStorage(BaseFileStorage):
    """File storage backed by Google Cloud Storage.

    Uses the blocking ``google.cloud.storage`` client for sync methods and
    ``gcloud.aio.storage`` for the async download path, with a shared
    in-memory byte cache keyed by blob name.
    """

    def __init__(self, bucket: str) -> None:  # noqa: D107
        self.bucket_name = bucket
        # Blocking client/bucket handle used by the synchronous methods.
        self.storage_client = storage.Client()
        self.bucket_client = self.storage_client.bucket(self.bucket_name)
        # aiohttp session and async storage client are created lazily.
        self._aio_session: aiohttp.ClientSession | None = None
        self._aio_storage: Storage | None = None
        # Downloaded blob bytes keyed by blob name. NOTE(review): this cache
        # is unbounded — consider eviction if many large blobs are fetched.
        self._cache: dict[str, bytes] = {}

    def upload_file(
        self,
        file_path: str,
        destination_blob_name: str,
        content_type: str | None = None,
    ) -> None:
        """Upload a file to Cloud Storage.

        Args:
            file_path: The local path to the file to upload.
            destination_blob_name: Name of the blob in the bucket.
            content_type: The content type of the file.
        """
        blob = self.bucket_client.blob(destination_blob_name)
        # if_generation_match=0 makes the upload a create-only precondition:
        # it fails if an object with this name already exists.
        blob.upload_from_filename(
            file_path,
            content_type=content_type,
            if_generation_match=0,
        )
        # Drop any stale cached copy of this blob.
        self._cache.pop(destination_blob_name, None)

    def list_files(self, path: str | None = None) -> list[str]:
        """List all files at the given path in the bucket.

        If path is None, recursively lists all files.

        Args:
            path: Prefix to filter files by.

        Returns:
            A list of blob names.
        """
        blobs = self.storage_client.list_blobs(
            self.bucket_name, prefix=path,
        )
        return [blob.name for blob in blobs]

    def get_file_stream(self, file_name: str) -> BinaryIO:
        """Get a file as a file-like object, using cache.

        Args:
            file_name: The blob name to retrieve.

        Returns:
            A BytesIO stream with the file contents.
        """
        if file_name not in self._cache:
            blob = self.bucket_client.blob(file_name)
            self._cache[file_name] = blob.download_as_bytes()
        # Each caller gets a fresh stream over the cached bytes.
        file_stream = io.BytesIO(self._cache[file_name])
        file_stream.name = file_name
        return file_stream

    def _get_aio_session(self) -> aiohttp.ClientSession:
        # Recreate the session if it was never built or has been closed.
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(
                limit=300, limit_per_host=50,
            )
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout, connector=connector,
            )
        return self._aio_session

    def _get_aio_storage(self) -> Storage:
        # Lazily build the async storage client on first use.
        if self._aio_storage is None:
            self._aio_storage = Storage(
                session=self._get_aio_session(),
            )
        return self._aio_storage

    async def async_get_file_stream(
        self, file_name: str, max_retries: int = 3,
    ) -> BinaryIO:
        """Get a file asynchronously with retry on transient errors.

        Args:
            file_name: The blob name to retrieve.
            max_retries: Maximum number of retry attempts.

        Returns:
            A BytesIO stream with the file contents.

        Raises:
            TimeoutError: If all retry attempts fail.
        """
        # Serve from cache without touching the network.
        if file_name in self._cache:
            file_stream = io.BytesIO(self._cache[file_name])
            file_stream.name = file_name
            return file_stream
        storage_client = self._get_aio_storage()
        last_exception: Exception | None = None
        for attempt in range(max_retries):
            try:
                self._cache[file_name] = await storage_client.download(
                    self.bucket_name, file_name,
                )
                file_stream = io.BytesIO(self._cache[file_name])
                file_stream.name = file_name
            except TimeoutError as exc:
                last_exception = exc
                logger.warning(
                    "Timeout downloading gs://%s/%s (attempt %d/%d)",
                    self.bucket_name,
                    file_name,
                    attempt + 1,
                    max_retries,
                )
            except aiohttp.ClientResponseError as exc:
                last_exception = exc
                # Only 429 (rate limit) and 5xx responses are retried;
                # other HTTP errors propagate immediately.
                if (
                    exc.status == HTTP_TOO_MANY_REQUESTS
                    or exc.status >= HTTP_SERVER_ERROR
                ):
                    logger.warning(
                        "HTTP %d downloading gs://%s/%s "
                        "(attempt %d/%d)",
                        exc.status,
                        self.bucket_name,
                        file_name,
                        attempt + 1,
                        max_retries,
                    )
                else:
                    raise
            else:
                # Success path: return the stream built above.
                return file_stream
            if attempt < max_retries - 1:
                # Exponential backoff: 0.5s, 1s, 2s, ...
                delay = 0.5 * (2**attempt)
                await asyncio.sleep(delay)
        msg = (
            f"Failed to download gs://{self.bucket_name}/{file_name} "
            f"after {max_retries} attempts"
        )
        raise TimeoutError(msg) from last_exception

    def delete_files(self, path: str) -> None:
        """Delete all files at the given path in the bucket.

        Args:
            path: Prefix of blobs to delete.
        """
        blobs = self.storage_client.list_blobs(
            self.bucket_name, prefix=path,
        )
        for blob in blobs:
            blob.delete()
            # Keep the local cache consistent with the remote state.
            self._cache.pop(blob.name, None)

View File

@@ -1,10 +1,13 @@
"""Structured logging configuration using structlog."""
import logging
import sys
import structlog
def setup_logging(json: bool = True, level: int = logging.INFO) -> None:
def setup_logging(*, json: bool = True, level: int = logging.INFO) -> None:
"""Configure structlog with JSON or console output."""
shared_processors: list[structlog.types.Processor] = [
structlog.contextvars.merge_contextvars,
structlog.stdlib.add_log_level,

View File

@@ -1,76 +0,0 @@
"""
KFP pipeline definition for the full RAG evaluation pipeline.
"""
from kfp import dsl
from rag_eval.config import settings as config
@dsl.component(base_image=config.base_image)
def run_synth_gen_op(num_questions: int, config_yaml_content: str) -> str:
    """KFP component that runs synthetic question generation.

    Args:
        num_questions: Number of synthetic questions to generate.
        config_yaml_content: Full YAML config; written to ``config.yaml``
            inside the container when non-empty.

    Returns:
        The generated run_id.
    """
    CONFIG_FILE_PATH = "config.yaml"
    if config_yaml_content:
        with open(CONFIG_FILE_PATH, "w") as f:
            f.write(config_yaml_content)
    # Imported inside the component so it resolves in the container image.
    from synth_gen.main import generate as synth_gen_generate

    return synth_gen_generate(num_questions=num_questions)
@dsl.component(base_image=config.base_image)
def run_search_eval_op(config_yaml_content: str, run_id: str = None):
    """KFP component that runs the search evaluation by direct import.

    Args:
        config_yaml_content: Full YAML config; written to ``config.yaml``
            inside the container when non-empty.
        run_id: Optional run id to filter the evaluation by.
    """
    CONFIG_FILE_PATH = "config.yaml"
    if config_yaml_content:
        with open(CONFIG_FILE_PATH, "w") as f:
            f.write(config_yaml_content)
    # Imported inside the component so it resolves in the container image.
    from search_eval.main import evaluate as search_eval_evaluate

    search_eval_evaluate(run_id=run_id)
@dsl.component(base_image=config.base_image)
def run_keypoint_eval_op(config_yaml_content: str, run_id: str = None):
    """KFP component that runs the keypoint evaluation by direct import.

    Args:
        config_yaml_content: Full YAML config; written to ``config.yaml``
            inside the container when non-empty.
        run_id: Optional run id to filter the evaluation by.
    """
    CONFIG_FILE_PATH = "config.yaml"
    if config_yaml_content:
        with open(CONFIG_FILE_PATH, "w") as f:
            f.write(config_yaml_content)
    # Imported inside the component so it resolves in the container image.
    from keypoint_eval.main import run_keypoint_evaluation

    run_keypoint_evaluation(run_id=run_id)
@dsl.pipeline(
    name="generative-evaluation-pipeline",
    description="A pipeline that generates synthetic questions and then runs search evaluation on them.",
)
def generative_evaluation_pipeline(num_questions: int = 20, config_yaml_content: str = ""):
    """Define the generative evaluation pipeline structure.

    1. Generates synthetic questions.
    2. Runs search evaluation on the generated questions.
    3. Runs keypoint evaluation on the generated questions.

    Args:
        num_questions: Number of synthetic questions to generate.
        config_yaml_content: YAML config forwarded to every component.
    """
    synth_gen_task = run_synth_gen_op(
        num_questions=num_questions, config_yaml_content=config_yaml_content
    )
    # Both evaluations consume the run_id produced by synthesis, which also
    # sequences them after the generation step.
    run_search_eval_op(
        run_id=synth_gen_task.output, config_yaml_content=config_yaml_content
    )
    run_keypoint_eval_op(
        run_id=synth_gen_task.output, config_yaml_content=config_yaml_content
    )

View File

@@ -1,53 +0,0 @@
"""
This script defines a simplified Kubeflow Pipeline (KFP) for ingesting and
processing documents for a single agent by calling the index-gen CLI.
"""
from kfp import dsl
from rag_eval.config import settings
# --- KFP Components ---------------------------------------------------------
@dsl.component(base_image=settings.processing_image)
def run_index_gen_cli(config_yaml_content: str):
    """Run the index-gen CLI inside the processing image.

    Args:
        config_yaml_content: Full YAML config; written to ``config.yaml``
            inside the container when non-empty.

    Raises:
        RuntimeError: If the index-gen CLI exits with a non-zero code.
    """
    CONFIG_FILE_PATH = "config.yaml"
    if config_yaml_content:
        with open(CONFIG_FILE_PATH, "w") as f:
            f.write(config_yaml_content)
    # Imports live inside the component body so they resolve in-container.
    import logging
    import subprocess

    # The command needs to be installed in the environment.
    # Assuming the processing_image has the project installed.
    command = ["uv", "run", "index-gen"]
    logging.info(f"Running command: {' '.join(command)}")
    # check=False so stdout/stderr can be logged before raising a clear error.
    result = subprocess.run(command, capture_output=True, text=True, check=False)
    logging.info(result.stdout)
    if result.stderr:
        logging.error(result.stderr)
    if result.returncode != 0:
        raise RuntimeError(f"index-gen CLI failed with return code {result.returncode}")
# --- KFP Pipeline Definition ------------------------------------------------
@dsl.pipeline(
    name="rag-ingestion-pipeline",
    description="A pipeline to run index-gen.",
)
def ingestion_pipeline(
    config_yaml_content: str = "",
):
    """Define the KFP pipeline structure: a single index-gen step.

    Args:
        config_yaml_content: YAML config forwarded to the component.
    """
    run_index_gen_cli(config_yaml_content=config_yaml_content)

View File

@@ -1,146 +0,0 @@
"""
Command-line interface for submitting pipelines.
"""
import logging
import tempfile
from typing import Annotated
import typer
from google.cloud import aiplatform
from kfp import compiler
from rag_eval.config import CONFIG_FILE_PATH, settings
app = typer.Typer()


@app.command()
def ingestion(
    compile_only: Annotated[
        bool,
        typer.Option(
            "--compile-only",
            help="If set, compiles the pipeline and exits without running it.",
        ),
    ] = False,
):
    """Compiles and/or runs the KFP ingestion pipeline.

    Compiles to a temporary spec file, optionally submits it to Vertex AI,
    waits for completion, then prints the deployed index endpoint details.
    """
    # Imported lazily so other commands don't pay the KFP import cost.
    from rag_eval.pipelines.ingestion import ingestion_pipeline

    logging.getLogger().setLevel(logging.INFO)
    pipeline_root = f"gs://{settings.bucket}"
    # The temp spec file lives only for the duration of this block.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as temp_pipeline_file:
        pipeline_spec_path = temp_pipeline_file.name
        compiler.Compiler().compile(
            pipeline_func=ingestion_pipeline, package_path=pipeline_spec_path
        )
        logging.info(f"Pipeline compiled to temporary file: {pipeline_spec_path}")
        if compile_only:
            logging.info(
                "Compilation successful. Temporary pipeline file will be deleted."
            )
            return
        logging.info(f"Reading config from {CONFIG_FILE_PATH}")
        with open(CONFIG_FILE_PATH, "r") as f:
            config_yaml_content = f.read()
        aiplatform.init(project=settings.project_id, location=settings.location)
        job = aiplatform.PipelineJob(
            display_name=f"rag-ingestion-run-{settings.agent.name}",
            template_path=pipeline_spec_path,
            pipeline_root=pipeline_root,
            parameter_values={
                "config_yaml_content": config_yaml_content,
            },
            enable_caching=False,
        )
        logging.info(
            "Submitting pipeline job to Vertex AI... This will wait for the job to complete."
        )
        # sync=True blocks until the pipeline finishes.
        job.run(sync=True, service_account=settings.service_account)
        logging.info("Pipeline finished. Fetching endpoint details...")
        # Without an index config there is no endpoint to look up.
        if not settings.index or not settings.index.name:
            logging.warning(
                "No index configuration found. Cannot fetch endpoint."
            )
            return
        expected_endpoint_name = f"{settings.index.name}-endpoint"
        logging.info(f"Looking for endpoint with display name: {expected_endpoint_name}")
        endpoints = aiplatform.MatchingEngineIndexEndpoint.list()
        found_endpoint = next(
            (e for e in endpoints if e.display_name == expected_endpoint_name), None
        )
        if not found_endpoint:
            logging.warning(
                "Could not find a matching deployed endpoint."
            )
        else:
            print("\n--- Deployed Index Endpoint ---")
            print(f"Display Name: {found_endpoint.display_name}")
            print(f" Resource Name: {found_endpoint.resource_name}")
            if found_endpoint.public_endpoint_domain_name:
                print(f" Public URI: {found_endpoint.public_endpoint_domain_name}")
            else:
                print(" Public URI: Not available")
            print("-" * 20)
@app.command()
def evaluation():
    """
    Compiles and runs the RAG evaluation Vertex AI pipeline without saving the compiled file.
    """
    # Imported lazily so other commands don't pay the KFP import cost.
    from rag_eval.pipelines.evaluation import generative_evaluation_pipeline

    project_id = settings.project_id
    location = settings.location
    pipeline_root = settings.pipeline_root
    pipeline_name = "rag-evaluation-pipeline"
    # Create a temporary file to store the compiled pipeline
    with tempfile.NamedTemporaryFile(suffix=".json") as temp_file:
        pipeline_path = temp_file.name
        print(f"Compiling pipeline to temporary file: {pipeline_path}")
        compiler.Compiler().compile(
            pipeline_func=generative_evaluation_pipeline,
            package_path=pipeline_path,
        )
        # Read the config.yaml content to pass to the pipeline
        print(f"Reading config from {CONFIG_FILE_PATH}")
        with open(CONFIG_FILE_PATH, "r") as f:
            config_yaml_content = f.read()
        print("Submitting pipeline job to Vertex AI...")
        aiplatform.init(project=project_id, location=location)
        job = aiplatform.PipelineJob(
            display_name=pipeline_name,
            template_path=pipeline_path,
            pipeline_root=pipeline_root,
            parameter_values={
                "config_yaml_content": config_yaml_content,
            },
            enable_caching=False,
        )
        job.run(service_account=settings.service_account)
        print(
            f"Pipeline job submitted. View it in the Vertex AI console: {job.gca_resource.name}"
        )
    # The temp spec file was deleted when the with-block exited.
    print(f"Temporary compiled file {pipeline_path} has been removed.")

61
src/rag_eval/server.py Normal file
View File

@@ -0,0 +1,61 @@
"""FastAPI server exposing the RAG agent endpoint."""
import time
from typing import Literal
from uuid import uuid4
import structlog
from fastapi import FastAPI
from pydantic import BaseModel
from rag_eval.agent import Deps, agent
from rag_eval.config import settings
from rag_eval.logging import setup_logging
logger = structlog.get_logger(__name__)
setup_logging()
app = FastAPI(title="RAG Agent")
class Message(BaseModel):
    """A single chat message."""

    # Speaker of this message.
    role: Literal["system", "user", "assistant"]
    # Message text.
    content: str
class AgentRequest(BaseModel):
    """Request body for the agent endpoint."""

    # Full chat history; the endpoint uses the last message as the prompt.
    messages: list[Message]
class AgentResponse(BaseModel):
    """Response body from the agent endpoint."""

    # The agent's answer text.
    response: str
@app.post("/agent")
async def run_agent(request: AgentRequest) -> AgentResponse:
    """Run the RAG agent with the provided messages.

    Args:
        request: Chat history; only the last message's content is used
            as the prompt.

    Returns:
        The agent's response text.
    """
    request_id = uuid4().hex[:8]
    # Fresh per-request logging context, tagged with a short request id.
    structlog.contextvars.clear_contextvars()
    structlog.contextvars.bind_contextvars(request_id=request_id)
    # NOTE(review): raises IndexError if `messages` is empty — consider a
    # 422 guard; confirm upstream always sends at least one message.
    prompt = request.messages[-1].content
    logger.info("request.start", prompt_length=len(prompt))
    t0 = time.perf_counter()
    deps = Deps(
        vector_search=settings.vector_search,
        embedder=settings.embedder,
    )
    result = await agent.run(prompt, deps=deps)
    elapsed = round((time.perf_counter() - t0) * 1000, 1)
    logger.info("request.end", elapsed_ms=elapsed)
    return AgentResponse(response=result.output)

View File

@@ -1,10 +0,0 @@
from fastapi import FastAPI

from rag_eval.logging import setup_logging

# Configure logging before importing modules that create loggers at import
# time (hence the deliberately late routes import below).
setup_logging()

from .routes import rag_router  # noqa: E402

app = FastAPI()
app.include_router(rag_router)

View File

@@ -1,52 +0,0 @@
from uuid import uuid4
from pydantic import BaseModel, Field
# ---------------------------------- #
# Maps Gemini response
class Answer(BaseModel):
    """Maps the Gemini response payload to a single answer string."""

    answer: str
# ---------------------------------- #
# Dialogflow request parameters
class RequestParameters(BaseModel):
    """Dialogflow request session parameters."""

    # The user's query text.
    query: str
class RequestSessionInfo(BaseModel):
    """sessionInfo object of a Dialogflow webhook request."""

    parameters: RequestParameters
class Request(BaseModel):
    """Top-level Dialogflow webhook request body."""

    # Field name matches Dialogflow's camelCase JSON key.
    sessionInfo: RequestSessionInfo
# ---------------------------------- #
# Dialogflow response parameters
class ResponseParameters(BaseModel):
    """Session parameters returned to Dialogflow."""

    webhook_success: bool
    response: str
    # Defaults below mirror the Dialogflow flow's expectations — presumably
    # flags for "new question" / "turn finished" / "answer delivered";
    # confirm against the agent's flow definition.
    pregunta_nueva: str = "NO"
    fin_turno: bool = True
    respuesta_entregada: bool = True
class ResponseSessionInfo(BaseModel):
    """sessionInfo object of a Dialogflow webhook response."""

    parameters: ResponseParameters
class Response(BaseModel):
    """Top-level Dialogflow webhook response body."""

    # Field name matches Dialogflow's camelCase JSON key.
    sessionInfo: ResponseSessionInfo
# ---------------------------------- #
# Dialogflow proxy models
class DialogflowRequest(BaseModel):
    """Body for the Dialogflow proxy endpoint."""

    prompt: str
    # A fresh session id is generated when the caller does not supply one.
    session_id: str = Field(default_factory=lambda: str(uuid4()))
class DialogflowResponse(BaseModel):
    """Raw response returned by the Dialogflow agent."""

    response: dict

View File

@@ -1,75 +0,0 @@
import time
from uuid import uuid4
import structlog
from fastapi import APIRouter, HTTPException, status
from ..agent import Agent
from .models import (
DialogflowRequest,
DialogflowResponse,
Request,
Response,
ResponseParameters,
ResponseSessionInfo,
)
try:
from dialogflow.main import DialogflowAgent
except ImportError:
DialogflowAgent = None
logger = structlog.get_logger(__name__)
rag_router = APIRouter()
sigma_agent = Agent()
@rag_router.post("/sigma-rag", response_model=Response)
async def generate_sigma_agent(request: Request):
    """Answer a Dialogflow webhook request via the RAG agent.

    Args:
        request: Dialogflow webhook payload; the user query is read from
            ``sessionInfo.parameters.query``.

    Returns:
        A Dialogflow-shaped Response carrying the agent's answer.
    """
    request_id = uuid4().hex[:8]
    # Fresh per-request logging context, tagged with a short request id.
    structlog.contextvars.clear_contextvars()
    structlog.contextvars.bind_contextvars(request_id=request_id)
    prompt = request.sessionInfo.parameters.query
    logger.info("request.start", prompt_length=len(prompt))
    t0 = time.perf_counter()
    answer = await sigma_agent.async_call(prompt)
    elapsed = round((time.perf_counter() - t0) * 1000, 1)
    logger.info("request.end", elapsed_ms=elapsed)
    response = Response(
        sessionInfo=ResponseSessionInfo(
            parameters=ResponseParameters(
                webhook_success=True,
                response=answer,
            )
        )
    )
    return response
if DialogflowAgent:

    @rag_router.post("/dialogflow-proxy", response_model=DialogflowResponse)
    async def dialogflow_proxy(request: DialogflowRequest):
        """Proxy a message to a Dialogflow agent.

        This endpoint is only available if the 'dialogflow' package is
        installed.

        Args:
            request: Prompt and session id to forward to Dialogflow.

        Returns:
            The raw Dialogflow response wrapped in DialogflowResponse.

        Raises:
            HTTPException: 500 when the Dialogflow call fails.
        """
        try:
            # Structured debug log instead of a stray print() to stdout.
            logger.debug("dialogflow_proxy.request", request=str(request))
            agent = DialogflowAgent()
            response = agent.call(query=request.prompt, session_id=request.session_id)
            return DialogflowResponse(response=response)
        except Exception as e:
            logger.error("Error calling Dialogflow agent", error=str(e))
            # Chain the cause so the original traceback is preserved.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Error communicating with Dialogflow agent.",
            ) from e

View File

@@ -0,0 +1 @@
"""Vector search provider implementations."""

View File

@@ -0,0 +1,68 @@
"""Abstract base class for vector search providers."""
from abc import ABC, abstractmethod
from typing import Any, TypedDict
class SearchResult(TypedDict):
    """One hit returned from a similarity search."""

    # Datapoint identifier within the index.
    id: str
    # Distance score reported by the search backend.
    distance: float
    # Document text associated with the datapoint.
    content: str


class BaseVectorSearch(ABC):
    """Interface every vector search provider must implement.

    Concrete providers supply index creation, index refresh, and
    nearest-neighbour querying against a deployed index.
    """

    @abstractmethod
    def create_index(
        self, name: str, content_path: str, **kwargs: Any  # noqa: ANN401
    ) -> None:
        """Build a brand-new index populated from the given content.

        Args:
            name: Display name for the index to create.
            content_path: Location of the data used to seed the index.
            **kwargs: Provider-specific options.
        """
        ...

    @abstractmethod
    def update_index(
        self, index_name: str, content_path: str, **kwargs: Any  # noqa: ANN401
    ) -> None:
        """Refresh an existing index with new content.

        Args:
            index_name: Name of the index to refresh.
            content_path: Location of the data to load into the index.
            **kwargs: Provider-specific options.
        """
        ...

    @abstractmethod
    def run_query(
        self,
        deployed_index_id: str,
        query: list[float],
        limit: int,
    ) -> list[SearchResult]:
        """Find the nearest neighbours of an embedding vector.

        Args:
            deployed_index_id: Identifier of the deployed index to query.
            query: Embedding vector to search with.
            limit: Maximum number of neighbours to return.

        Returns:
            Up to ``limit`` matches, each with id, distance, and content.
        """
        ...

View File

@@ -0,0 +1,310 @@
"""Google Cloud Vertex AI Vector Search implementation."""
import asyncio
from collections.abc import Sequence
from typing import Any
from uuid import uuid4
import aiohttp
import google.auth
import google.auth.credentials
import google.auth.transport.requests
from gcloud.aio.auth import Token
from google.cloud import aiplatform
from rag_eval.file_storage.google_cloud import GoogleCloudFileStorage
from rag_eval.vector_search.base import BaseVectorSearch, SearchResult
class GoogleCloudVectorSearch(BaseVectorSearch):
    """A vector search provider using Vertex AI Vector Search.

    Embedding vectors live in a Vertex AI Matching Engine index; the raw
    document text for each datapoint is read from GCS at
    ``{index_name}/contents/{datapoint_id}.md``.
    """

    def __init__(
        self,
        project_id: str,
        location: str,
        bucket: str,
        index_name: str | None = None,
    ) -> None:
        """Initialize the GoogleCloudVectorSearch client.

        Args:
            project_id: The Google Cloud project ID.
            location: The Google Cloud location (e.g., 'us-central1').
            bucket: The GCS bucket to use for file storage.
            index_name: The name of the index (used to build content paths).
        """
        aiplatform.init(project=project_id, location=location)
        self.project_id = project_id
        self.location = location
        self.storage = GoogleCloudFileStorage(bucket=bucket)
        self.index_name = index_name
        # Lazily created auth and HTTP state, shared across calls.
        self._credentials: google.auth.credentials.Credentials | None = None
        self._aio_session: aiohttp.ClientSession | None = None
        self._async_token: Token | None = None

    def _get_auth_headers(self) -> dict[str, str]:
        """Return Bearer-token headers for synchronous REST calls.

        Credentials are resolved once via Application Default Credentials
        and refreshed whenever the cached token is missing or expired.
        """
        if self._credentials is None:
            self._credentials, _ = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
        if not self._credentials.token or self._credentials.expired:
            self._credentials.refresh(
                google.auth.transport.requests.Request(),
            )
        return {
            "Authorization": f"Bearer {self._credentials.token}",
            "Content-Type": "application/json",
        }

    async def _async_get_auth_headers(self) -> dict[str, str]:
        """Return Bearer-token headers for async REST calls."""
        # Touch the session first: _get_aio_session() invalidates the cached
        # token when it has to rebuild a closed session, so the token created
        # below is always bound to a live session.
        session = self._get_aio_session()
        if self._async_token is None:
            self._async_token = Token(
                session=session,
                scopes=[
                    "https://www.googleapis.com/auth/cloud-platform",
                ],
            )
        access_token = await self._async_token.get()
        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def _get_aio_session(self) -> aiohttp.ClientSession:
        """Return the shared aiohttp session, recreating it when closed."""
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(
                limit=300, limit_per_host=50,
            )
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout, connector=connector,
            )
            # Any cached token was bound to the previous (closed) session;
            # drop it so the next auth call rebuilds it against this one.
            self._async_token = None
        return self._aio_session

    async def aclose(self) -> None:
        """Close the shared aiohttp session (call once at shutdown).

        Without this, the connector leaks open connections and aiohttp
        warns about an unclosed ClientSession at interpreter exit.
        """
        if self._aio_session is not None and not self._aio_session.closed:
            await self._aio_session.close()
        self._aio_session = None
        self._async_token = None

    def create_index(
        self,
        name: str,
        content_path: str,
        *,
        dimensions: int = 3072,
        approximate_neighbors_count: int = 150,
        distance_measure_type: str = "DOT_PRODUCT_DISTANCE",
        **kwargs: Any,  # noqa: ANN401, ARG002
    ) -> None:
        """Create a new Vertex AI Vector Search (tree-AH) index.

        Args:
            name: The display name for the new index.
            content_path: GCS URI to the embeddings JSON file.
            dimensions: Number of dimensions in embedding vectors.
            approximate_neighbors_count: Neighbors to find per vector.
            distance_measure_type: The distance measure to use.
            **kwargs: Additional arguments (accepted for interface
                compatibility; currently unused).
        """
        index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
            display_name=name,
            contents_delta_uri=content_path,
            dimensions=dimensions,
            approximate_neighbors_count=approximate_neighbors_count,
            distance_measure_type=distance_measure_type,  # type: ignore[arg-type]
            leaf_node_embedding_count=1000,
            leaf_nodes_to_search_percent=10,
        )
        # Kept for a follow-up deploy_index() call.
        self.index = index

    def update_index(
        self, index_name: str, content_path: str, **kwargs: Any,  # noqa: ANN401, ARG002
    ) -> None:
        """Update an existing Vertex AI Vector Search index.

        Args:
            index_name: The resource name of the index to update.
            content_path: GCS URI to the new embeddings JSON file.
            **kwargs: Additional arguments (accepted for interface
                compatibility; currently unused).
        """
        index = aiplatform.MatchingEngineIndex(index_name=index_name)
        index.update_embeddings(
            contents_delta_uri=content_path,
        )
        self.index = index

    def deploy_index(
        self,
        index_name: str,
        machine_type: str = "e2-standard-2",
    ) -> None:
        """Deploy the most recently created/updated index to a new endpoint.

        Args:
            index_name: The name of the index to deploy.
            machine_type: The machine type for the endpoint.
        """
        index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
            display_name=f"{index_name}-endpoint",
            public_endpoint_enabled=True,
        )
        index_endpoint.deploy_index(
            index=self.index,
            # Random suffix keeps the deployed id unique across re-deploys.
            deployed_index_id=(
                f"{index_name.replace('-', '_')}_deployed_{uuid4().hex}"
            ),
            machine_type=machine_type,
        )
        self.index_endpoint = index_endpoint

    def load_index_endpoint(self, endpoint_name: str) -> None:
        """Load an existing Vertex AI Vector Search index endpoint.

        Args:
            endpoint_name: The resource name of the index endpoint.

        Raises:
            ValueError: If the endpoint has no public domain name (the
                async query path requires public REST access).
        """
        self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
            endpoint_name,
        )
        if not self.index_endpoint.public_endpoint_domain_name:
            msg = (
                "The index endpoint does not have a public endpoint. "
                "Ensure the endpoint is configured for public access."
            )
            raise ValueError(msg)

    def run_query(
        self,
        deployed_index_id: str,
        query: list[float],
        limit: int,
    ) -> list[SearchResult]:
        """Run a similarity search query against the deployed index.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.

        Returns:
            A list of matched items with id, distance, and content.
        """
        response = self.index_endpoint.find_neighbors(
            deployed_index_id=deployed_index_id,
            queries=[query],
            num_neighbors=limit,
        )
        results: list[SearchResult] = []
        for neighbor in response[0]:
            # Document text is stored next to the index in GCS, one
            # markdown file per datapoint id.
            file_path = (
                f"{self.index_name}/contents/{neighbor.id}.md"
            )
            content = (
                self.storage.get_file_stream(file_path)
                .read()
                .decode("utf-8")
            )
            results.append(
                SearchResult(
                    id=neighbor.id,
                    distance=float(neighbor.distance or 0),
                    content=content,
                ),
            )
        return results

    async def async_run_query(
        self,
        deployed_index_id: str,
        query: Sequence[float],
        limit: int,
    ) -> list[SearchResult]:
        """Run an async similarity search via the findNeighbors REST API.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.

        Returns:
            A list of matched items with id, distance, and content
            (empty when the index returns no neighbors).
        """
        domain = self.index_endpoint.public_endpoint_domain_name
        endpoint_id = self.index_endpoint.name.split("/")[-1]
        url = (
            f"https://{domain}/v1/projects/{self.project_id}"
            f"/locations/{self.location}"
            f"/indexEndpoints/{endpoint_id}:findNeighbors"
        )
        payload = {
            "deployed_index_id": deployed_index_id,
            "queries": [
                {
                    "datapoint": {"feature_vector": list(query)},
                    "neighbor_count": limit,
                },
            ],
        }
        headers = await self._async_get_auth_headers()
        session = self._get_aio_session()
        async with session.post(
            url, json=payload, headers=headers,
        ) as response:
            response.raise_for_status()
            data = await response.json()
        # The API may return an *empty* "nearestNeighbors" list when there
        # are no matches; `.get(..., [{}])` only defends against a missing
        # key, so guard explicitly before indexing element [0].
        nearest = data.get("nearestNeighbors") or []
        neighbors = nearest[0].get("neighbors", []) if nearest else []
        # Fetch all document bodies from GCS concurrently.
        content_tasks = [
            self.storage.async_get_file_stream(
                f"{self.index_name}/contents/"
                f"{neighbor['datapoint']['datapointId']}.md",
            )
            for neighbor in neighbors
        ]
        file_streams = await asyncio.gather(*content_tasks)
        results: list[SearchResult] = []
        for neighbor, stream in zip(
            neighbors, file_streams, strict=True,
        ):
            results.append(
                SearchResult(
                    id=neighbor["datapoint"]["datapointId"],
                    # Mirror the sync path: default to 0 when the backend
                    # omits the distance field.
                    distance=float(neighbor.get("distance", 0)),
                    content=stream.read().decode("utf-8"),
                ),
            )
        return results

    def delete_index(self, index_name: str) -> None:
        """Delete a Vertex AI Vector Search index.

        Args:
            index_name: The resource name of the index.
        """
        index = aiplatform.MatchingEngineIndex(index_name)
        index.delete()

    def delete_index_endpoint(
        self, index_endpoint_name: str,
    ) -> None:
        """Undeploy everything on an endpoint, then delete it.

        Args:
            index_endpoint_name: The resource name of the endpoint.
        """
        index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
            index_endpoint_name,
        )
        index_endpoint.undeploy_all()
        index_endpoint.delete(force=True)