Switch to agent arch
This commit is contained in:
@@ -1,2 +1 @@
|
||||
def main() -> None:
|
||||
print("Hello from rag-eval!")
|
||||
"""RAG evaluation agent package."""
|
||||
|
||||
@@ -1,196 +1,84 @@
|
||||
import asyncio
|
||||
"""Pydantic AI agent with RAG tool for vector search."""
|
||||
|
||||
import time
|
||||
|
||||
import structlog
|
||||
from embedder.vertex_ai import VertexAIEmbedder
|
||||
from google.genai import types
|
||||
from llm.vertex_ai import VertexAILLM
|
||||
from vector_search.vertex_ai import GoogleCloudVectorSearch
|
||||
from pydantic import BaseModel
|
||||
from pydantic_ai import Agent, Embedder, RunContext
|
||||
from pydantic_ai.models.google import GoogleModel
|
||||
|
||||
from rag_eval.config import settings
|
||||
from rag_eval.vector_search.vertex_ai import GoogleCloudVectorSearch
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
MAX_CONCURRENT_LLM_CALLS = 20
|
||||
_llm_semaphore = asyncio.Semaphore(MAX_CONCURRENT_LLM_CALLS)
|
||||
|
||||
class Agent:
|
||||
"""A class to handle the RAG workflow."""
|
||||
class Deps(BaseModel):
|
||||
"""Dependencies injected into the agent at runtime."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes the Agent class."""
|
||||
self.settings = settings.agent
|
||||
self.index_settings = settings.index
|
||||
self.model = self.settings.language_model
|
||||
self.system_prompt = self.settings.instructions
|
||||
self.llm = VertexAILLM(
|
||||
project=settings.project_id,
|
||||
location=settings.location,
|
||||
thinking=self.settings.thinking,
|
||||
)
|
||||
self.vector_search = GoogleCloudVectorSearch(
|
||||
project_id=settings.project_id,
|
||||
location=settings.location,
|
||||
bucket=settings.bucket,
|
||||
index_name=self.index_settings.name,
|
||||
)
|
||||
self.vector_search.load_index_endpoint(self.index_settings.endpoint)
|
||||
self.embedder = VertexAIEmbedder(
|
||||
project=settings.project_id,
|
||||
location=settings.location,
|
||||
model_name=self.settings.embedding_model, task="RETRIEVAL_QUERY"
|
||||
)
|
||||
self.min_sim = 0.60
|
||||
vector_search: GoogleCloudVectorSearch
|
||||
embedder: Embedder
|
||||
|
||||
def call(self, query: str | list[dict[str, str]]) -> str:
|
||||
"""Calls the LLM with the provided query and tools.
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
Args:
|
||||
query: The user's query.
|
||||
|
||||
Returns:
|
||||
The response from the LLM.
|
||||
"""
|
||||
if isinstance(query, str):
|
||||
search_query = query
|
||||
else:
|
||||
search_query = query[-1]["content"]
|
||||
model = GoogleModel(
|
||||
settings.agent_language_model,
|
||||
provider=settings.provider,
|
||||
)
|
||||
agent = Agent(
|
||||
model,
|
||||
deps_type=Deps,
|
||||
system_prompt=settings.agent_instructions,
|
||||
)
|
||||
|
||||
context = self.search(search_query)
|
||||
user_prompt = f"{search_query}\n\n{context}"
|
||||
|
||||
contents = []
|
||||
if isinstance(query, str):
|
||||
contents.append(types.Content(role="user", parts=[types.Part(text=user_prompt)]))
|
||||
else:
|
||||
for turn in query[:-1]:
|
||||
role = "model" if turn["role"] == "assistant" else "user"
|
||||
contents.append(
|
||||
types.Content(role=role, parts=[types.Part(text=turn["content"])])
|
||||
)
|
||||
contents.append(types.Content(role="user", parts=[types.Part(text=user_prompt)]))
|
||||
@agent.tool
|
||||
async def conocimiento(ctx: RunContext[Deps], query: str) -> str:
|
||||
"""Search the vector index for the given query.
|
||||
|
||||
generation = self.llm.generate(
|
||||
model=self.model,
|
||||
prompt=contents,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
Args:
|
||||
ctx: The run context containing dependencies.
|
||||
query: The query to search for.
|
||||
|
||||
logger.info(f"total usage={generation.usage}")
|
||||
logger.info(f"costo ${generation.usage.get_cost(self.model)} MXN")
|
||||
Returns:
|
||||
A formatted string containing the search results.
|
||||
|
||||
return generation.text
|
||||
"""
|
||||
t0 = time.perf_counter()
|
||||
min_sim = 0.6
|
||||
|
||||
def search(self, query: str ) -> str:
|
||||
"""Searches the vector index for the given query.
|
||||
query_embedding = await ctx.deps.embedder.embed_query(query)
|
||||
t_embed = time.perf_counter()
|
||||
|
||||
Args:
|
||||
query: The query to search for.
|
||||
search_results = await ctx.deps.vector_search.async_run_query(
|
||||
deployed_index_id=settings.index_name,
|
||||
query=list(query_embedding.embeddings[0]),
|
||||
limit=5,
|
||||
)
|
||||
t_search = time.perf_counter()
|
||||
|
||||
Returns:
|
||||
A formatted string containing the search results.
|
||||
"""
|
||||
logger.debug(f"Search term: {query}")
|
||||
|
||||
query_embedding = self.embedder.generate_embedding(query)
|
||||
search_results = self.vector_search.run_query(
|
||||
deployed_index_id=self.index_settings.deployment,
|
||||
query=query_embedding,
|
||||
limit=5,
|
||||
)
|
||||
|
||||
max_sim = max(search_results, key=lambda x: x["distance"])["distance"]
|
||||
cutoff= max_sim * 0.9
|
||||
|
||||
search_results = [s for s in search_results if s["distance"] > cutoff and s["distance"] > self.min_sim]
|
||||
|
||||
logger.debug(f"{max_sim=}")
|
||||
logger.debug(f"{cutoff=}")
|
||||
logger.debug(f"chunks={[s['id'] for s in search_results]}")
|
||||
logger.debug(f"distancias={[s['distance'] for s in search_results]}")
|
||||
|
||||
return self._format_results(search_results)
|
||||
|
||||
async def async_search(self, query: str) -> str:
|
||||
"""Searches the vector index for the given query.
|
||||
|
||||
Args:
|
||||
query: The query to search for.
|
||||
|
||||
Returns:
|
||||
A formatted string containing the search results.
|
||||
"""
|
||||
t0 = time.perf_counter()
|
||||
|
||||
query_embedding = await self.embedder.async_generate_embedding(query)
|
||||
t_embed = time.perf_counter()
|
||||
|
||||
search_results = await self.vector_search.async_run_query(
|
||||
deployed_index_id=self.index_settings.deployment,
|
||||
query=query_embedding,
|
||||
limit=5,
|
||||
)
|
||||
t_search = time.perf_counter()
|
||||
|
||||
max_sim = max(search_results, key=lambda x: x["distance"])["distance"]
|
||||
if search_results:
|
||||
max_sim = max(r["distance"] for r in search_results)
|
||||
cutoff = max_sim * 0.9
|
||||
search_results = [s for s in search_results if s["distance"] > cutoff and s["distance"] > self.min_sim]
|
||||
|
||||
logger.info(
|
||||
"async_search.timing",
|
||||
embedding_ms=round((t_embed - t0) * 1000, 1),
|
||||
vector_search_ms=round((t_search - t_embed) * 1000, 1),
|
||||
total_ms=round((t_search - t0) * 1000, 1),
|
||||
chunks=[s["id"] for s in search_results],
|
||||
)
|
||||
|
||||
return self._format_results(search_results)
|
||||
|
||||
def _format_results(self, search_results):
|
||||
formatted_results = [
|
||||
f"<document {i} name={result['id']}>\n{result['content']}\n</document {i}>"
|
||||
for i, result in enumerate(search_results, start=1)
|
||||
search_results = [
|
||||
s
|
||||
for s in search_results
|
||||
if s["distance"] > cutoff and s["distance"] > min_sim
|
||||
]
|
||||
return "\n".join(formatted_results)
|
||||
|
||||
async def async_call(self, query: str) -> str:
|
||||
"""Calls the LLM with the provided query and tools.
|
||||
logger.info(
|
||||
"conocimiento.timing",
|
||||
embedding_ms=round((t_embed - t0) * 1000, 1),
|
||||
vector_search_ms=round((t_search - t_embed) * 1000, 1),
|
||||
total_ms=round((t_search - t0) * 1000, 1),
|
||||
chunks=[s["id"] for s in search_results],
|
||||
)
|
||||
|
||||
Args:
|
||||
query: The user's query.
|
||||
|
||||
Returns:
|
||||
The response from the LLM.
|
||||
"""
|
||||
t_start = time.perf_counter()
|
||||
|
||||
t0 = time.perf_counter()
|
||||
context = await self.async_search(query)
|
||||
t_search = time.perf_counter()
|
||||
|
||||
contents = [types.Content(role="user", parts=[types.Part(text=f"{query}\n\n{context}")])]
|
||||
|
||||
t1 = time.perf_counter()
|
||||
async with _llm_semaphore:
|
||||
generation = await self.llm.async_generate(
|
||||
model=self.model,
|
||||
prompt=contents,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
t_llm = time.perf_counter()
|
||||
|
||||
t_end = time.perf_counter()
|
||||
logger.info(
|
||||
"async_call.timing",
|
||||
total_ms=round((t_end - t_start) * 1000, 1),
|
||||
stages=[
|
||||
{"stage": "search", "ms": round((t_search - t0) * 1000, 1)},
|
||||
{"stage": "llm", "ms": round((t_llm - t1) * 1000, 1)},
|
||||
],
|
||||
llm_iterations=1,
|
||||
prompt_tokens=generation.usage.prompt_tokens,
|
||||
response_tokens=generation.usage.response_tokens,
|
||||
cost_mxn=generation.usage.get_cost(self.model),
|
||||
)
|
||||
|
||||
return generation.text
|
||||
formatted_results = [
|
||||
f"<document {i} name={result['id']}>\n"
|
||||
f"{result['content']}\n"
|
||||
f"</document {i}>"
|
||||
for i, result in enumerate(search_results, start=1)
|
||||
]
|
||||
return "\n".join(formatted_results)
|
||||
|
||||
@@ -1,140 +0,0 @@
|
||||
"""Main CLI for the RAG evaluation tool."""
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
from rich.prompt import Prompt
|
||||
|
||||
from rag_eval.agent import Agent
|
||||
from rag_eval.pipelines.submit import app as pipelines_app
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
app = typer.Typer()
|
||||
console = Console()
|
||||
|
||||
|
||||
@app.command()
|
||||
def ask(
|
||||
message: Annotated[str, typer.Argument(help="The message to send to the agent.")],
|
||||
verbose: Annotated[
|
||||
int,
|
||||
typer.Option(
|
||||
"--verbose",
|
||||
"-v",
|
||||
help="Set verbosity. -v: INFO, -vv: DEBUG, -vvv: DEBUG with LLM logs.",
|
||||
count=True,
|
||||
),
|
||||
] = 0,
|
||||
):
|
||||
"""Sends a single message to a specified agent and prints the response."""
|
||||
log_levels = [logging.WARNING, logging.INFO, logging.DEBUG, logging.DEBUG]
|
||||
level_index = min(verbose, len(log_levels) - 1)
|
||||
log_level = log_levels[level_index]
|
||||
|
||||
# Determine which loggers to show
|
||||
loggers_to_configure = ["rag_eval"]
|
||||
if verbose >= 3:
|
||||
loggers_to_configure.append("llm")
|
||||
|
||||
# Create a single handler for all loggers
|
||||
handler = RichHandler(rich_tracebacks=True, show_path=False)
|
||||
|
||||
for logger_name in loggers_to_configure:
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.setLevel(log_level)
|
||||
logger.propagate = False
|
||||
if not logger.handlers:
|
||||
logger.addHandler(handler)
|
||||
|
||||
try:
|
||||
console.print("[bold blue]Initializing agent...[/bold blue]")
|
||||
|
||||
agent = Agent()
|
||||
|
||||
console.print("[bold blue]Sending message...[/bold blue]")
|
||||
response = agent.call(query=message)
|
||||
|
||||
console.print("\n[bold green]Agent Response:[/bold green]")
|
||||
console.print(response)
|
||||
|
||||
except ValueError as e:
|
||||
console.print(f"[bold red]Error: {e}[/bold red]")
|
||||
raise typer.Exit(code=1)
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]An unexpected error occurred: {e}[/bold red]")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@app.command()
|
||||
def chat(
|
||||
verbose: Annotated[
|
||||
int,
|
||||
typer.Option(
|
||||
"--verbose",
|
||||
"-v",
|
||||
help="Set verbosity. -v: INFO, -vv: DEBUG, -vvv: DEBUG with LLM logs.",
|
||||
count=True,
|
||||
),
|
||||
] = 0,
|
||||
):
|
||||
"""Starts an interactive chat session with a specified agent."""
|
||||
log_levels = [logging.WARNING, logging.INFO, logging.DEBUG, logging.DEBUG]
|
||||
level_index = min(verbose, len(log_levels) - 1)
|
||||
log_level = log_levels[level_index]
|
||||
|
||||
# Determine which loggers to show
|
||||
loggers_to_configure = ["rag_eval"]
|
||||
if verbose >= 3:
|
||||
loggers_to_configure.append("llm")
|
||||
|
||||
# Create a single handler for all loggers
|
||||
handler = RichHandler(rich_tracebacks=True, show_path=False)
|
||||
|
||||
for logger_name in loggers_to_configure:
|
||||
logger = logging.getLogger(logger_name)
|
||||
logger.setLevel(log_level)
|
||||
logger.propagate = False
|
||||
if not logger.handlers:
|
||||
logger.addHandler(handler)
|
||||
|
||||
try:
|
||||
console.print("[bold blue]Initializing agent...[/bold blue]")
|
||||
agent = Agent()
|
||||
console.print(
|
||||
"[bold green]Agent initialized. Start chatting! (type 'exit' or 'quit' to end)[/bold green]"
|
||||
)
|
||||
|
||||
history = []
|
||||
while True:
|
||||
user_input = Prompt.ask("[bold yellow]You[/bold yellow]")
|
||||
if user_input.lower() in ["exit", "quit"]:
|
||||
console.print("[bold blue]Ending chat. Goodbye![/bold blue]")
|
||||
break
|
||||
|
||||
history.append({"role": "user", "content": user_input})
|
||||
|
||||
response = agent.call(query=history)
|
||||
|
||||
console.print(f"[bold green]Agent:[/bold green] {response}")
|
||||
history.append({"role": "assistant", "content": response})
|
||||
|
||||
except ValueError as e:
|
||||
console.print(f"[bold red]Error: {e}[/bold red]")
|
||||
raise typer.Exit(code=1)
|
||||
except Exception as e:
|
||||
console.print(f"[bold red]An unexpected error occurred: {e}[/bold red]")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
|
||||
app.add_typer(pipelines_app, name="pipelines")
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
@@ -1,6 +1,10 @@
|
||||
import os
|
||||
"""Application settings loaded from YAML and environment variables."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
import os
|
||||
from functools import cached_property
|
||||
|
||||
from pydantic_ai import Embedder
|
||||
from pydantic_ai.providers.google import GoogleProvider
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
PydanticBaseSettingsSource,
|
||||
@@ -8,46 +12,18 @@ from pydantic_settings import (
|
||||
YamlConfigSettingsSource,
|
||||
)
|
||||
|
||||
from rag_eval.vector_search.vertex_ai import GoogleCloudVectorSearch
|
||||
|
||||
CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")
|
||||
|
||||
|
||||
class IndexConfig(BaseModel):
|
||||
name: str
|
||||
endpoint: str
|
||||
dimensions: int
|
||||
machine_type: str = "e2-standard-16"
|
||||
origin: str
|
||||
destination: str
|
||||
chunk_limit: int
|
||||
|
||||
@property
|
||||
def deployment(self) -> str:
|
||||
return self.name.replace("-", "_") + "_deployed"
|
||||
|
||||
@property
|
||||
def data(self) -> str:
|
||||
return self.destination + self.name
|
||||
|
||||
class AgentConfig(BaseModel):
|
||||
name: str
|
||||
instructions: str
|
||||
language_model: str
|
||||
embedding_model: str
|
||||
thinking: int
|
||||
|
||||
|
||||
class BigQueryConfig(BaseModel):
|
||||
dataset_id: str
|
||||
project_id: str | None = None
|
||||
table_ids: dict[str, str]
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from config.yaml and env vars."""
|
||||
|
||||
project_id: str
|
||||
location: str
|
||||
service_account: str
|
||||
|
||||
# Flattened fields from nested models
|
||||
|
||||
agent_name: str
|
||||
agent_instructions: str
|
||||
agent_language_model: str
|
||||
@@ -70,52 +46,52 @@ class Settings(BaseSettings):
|
||||
base_image: str
|
||||
dialogflow_agent_id: str
|
||||
processing_image: str
|
||||
|
||||
|
||||
model_config = SettingsConfigDict(yaml_file=CONFIG_FILE_PATH)
|
||||
|
||||
@property
|
||||
def agent(self) -> AgentConfig:
|
||||
return AgentConfig(
|
||||
name=self.agent_name,
|
||||
instructions=self.agent_instructions,
|
||||
language_model=self.agent_language_model,
|
||||
embedding_model=self.agent_embedding_model,
|
||||
thinking=self.agent_thinking,
|
||||
)
|
||||
|
||||
@property
|
||||
def index(self) -> IndexConfig:
|
||||
return IndexConfig(
|
||||
name=self.index_name,
|
||||
endpoint=self.index_endpoint,
|
||||
dimensions=self.index_dimensions,
|
||||
machine_type=self.index_machine_type,
|
||||
origin=self.index_origin,
|
||||
destination=self.index_destination,
|
||||
chunk_limit=self.index_chunk_limit,
|
||||
)
|
||||
|
||||
@property
|
||||
def bigquery(self) -> BigQueryConfig:
|
||||
return BigQueryConfig(
|
||||
dataset_id=self.bigquery_dataset_id,
|
||||
project_id=self.bigquery_project_id,
|
||||
table_ids=self.bigquery_table_ids,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: type[BaseSettings],
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
init_settings: PydanticBaseSettingsSource, # noqa: ARG003
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource, # noqa: ARG003
|
||||
file_secret_settings: PydanticBaseSettingsSource, # noqa: ARG003
|
||||
) -> tuple[PydanticBaseSettingsSource, ...]:
|
||||
"""Use env vars and YAML as settings sources."""
|
||||
return (
|
||||
env_settings,
|
||||
YamlConfigSettingsSource(settings_cls),
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def provider(self) -> GoogleProvider:
|
||||
"""Return a Google provider configured for Vertex AI."""
|
||||
return GoogleProvider(
|
||||
project=self.project_id,
|
||||
location=self.location,
|
||||
)
|
||||
|
||||
settings = Settings()
|
||||
@cached_property
|
||||
def vector_search(self) -> GoogleCloudVectorSearch:
|
||||
"""Return a configured vector search client."""
|
||||
return GoogleCloudVectorSearch(
|
||||
project_id=self.project_id,
|
||||
location=self.location,
|
||||
bucket=self.bucket,
|
||||
index_name=self.index_name,
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def embedder(self) -> Embedder:
|
||||
"""Return an embedder configured for the agent's embedding model."""
|
||||
from pydantic_ai.embeddings.google import GoogleEmbeddingModel # noqa: PLC0415
|
||||
|
||||
model = GoogleEmbeddingModel(
|
||||
self.agent_embedding_model,
|
||||
provider=self.provider,
|
||||
)
|
||||
return Embedder(model)
|
||||
|
||||
|
||||
settings = Settings.model_validate({})
|
||||
|
||||
1
src/rag_eval/file_storage/__init__.py
Normal file
1
src/rag_eval/file_storage/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""File storage provider implementations."""
|
||||
56
src/rag_eval/file_storage/base.py
Normal file
56
src/rag_eval/file_storage/base.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Abstract base class for file storage providers."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import BinaryIO
|
||||
|
||||
|
||||
class BaseFileStorage(ABC):
|
||||
"""Abstract base class for a remote file processor.
|
||||
|
||||
Defines the interface for listing and processing files from
|
||||
a remote source.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def upload_file(
|
||||
self,
|
||||
file_path: str,
|
||||
destination_blob_name: str,
|
||||
content_type: str | None = None,
|
||||
) -> None:
|
||||
"""Upload a file to the remote source.
|
||||
|
||||
Args:
|
||||
file_path: The local path to the file to upload.
|
||||
destination_blob_name: Name of the file in remote storage.
|
||||
content_type: The content type of the file.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def list_files(self, path: str | None = None) -> list[str]:
|
||||
"""List files from a remote location.
|
||||
|
||||
Args:
|
||||
path: Path to a specific file or directory. If None,
|
||||
recursively lists all files in the bucket.
|
||||
|
||||
Returns:
|
||||
A list of file paths.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_file_stream(self, file_name: str) -> BinaryIO:
|
||||
"""Get a file from the remote source as a file-like object.
|
||||
|
||||
Args:
|
||||
file_name: The name of the file to retrieve.
|
||||
|
||||
Returns:
|
||||
A file-like object containing the file data.
|
||||
|
||||
"""
|
||||
...
|
||||
188
src/rag_eval/file_storage/google_cloud.py
Normal file
188
src/rag_eval/file_storage/google_cloud.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""Google Cloud Storage file storage implementation."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
from typing import BinaryIO
|
||||
|
||||
import aiohttp
|
||||
from gcloud.aio.storage import Storage
|
||||
from google.cloud import storage
|
||||
|
||||
from rag_eval.file_storage.base import BaseFileStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HTTP_TOO_MANY_REQUESTS = 429
|
||||
HTTP_SERVER_ERROR = 500
|
||||
|
||||
|
||||
class GoogleCloudFileStorage(BaseFileStorage):
|
||||
"""File storage backed by Google Cloud Storage."""
|
||||
|
||||
def __init__(self, bucket: str) -> None: # noqa: D107
|
||||
self.bucket_name = bucket
|
||||
|
||||
self.storage_client = storage.Client()
|
||||
self.bucket_client = self.storage_client.bucket(self.bucket_name)
|
||||
self._aio_session: aiohttp.ClientSession | None = None
|
||||
self._aio_storage: Storage | None = None
|
||||
self._cache: dict[str, bytes] = {}
|
||||
|
||||
def upload_file(
|
||||
self,
|
||||
file_path: str,
|
||||
destination_blob_name: str,
|
||||
content_type: str | None = None,
|
||||
) -> None:
|
||||
"""Upload a file to Cloud Storage.
|
||||
|
||||
Args:
|
||||
file_path: The local path to the file to upload.
|
||||
destination_blob_name: Name of the blob in the bucket.
|
||||
content_type: The content type of the file.
|
||||
|
||||
"""
|
||||
blob = self.bucket_client.blob(destination_blob_name)
|
||||
blob.upload_from_filename(
|
||||
file_path,
|
||||
content_type=content_type,
|
||||
if_generation_match=0,
|
||||
)
|
||||
self._cache.pop(destination_blob_name, None)
|
||||
|
||||
def list_files(self, path: str | None = None) -> list[str]:
|
||||
"""List all files at the given path in the bucket.
|
||||
|
||||
If path is None, recursively lists all files.
|
||||
|
||||
Args:
|
||||
path: Prefix to filter files by.
|
||||
|
||||
Returns:
|
||||
A list of blob names.
|
||||
|
||||
"""
|
||||
blobs = self.storage_client.list_blobs(
|
||||
self.bucket_name, prefix=path,
|
||||
)
|
||||
return [blob.name for blob in blobs]
|
||||
|
||||
def get_file_stream(self, file_name: str) -> BinaryIO:
|
||||
"""Get a file as a file-like object, using cache.
|
||||
|
||||
Args:
|
||||
file_name: The blob name to retrieve.
|
||||
|
||||
Returns:
|
||||
A BytesIO stream with the file contents.
|
||||
|
||||
"""
|
||||
if file_name not in self._cache:
|
||||
blob = self.bucket_client.blob(file_name)
|
||||
self._cache[file_name] = blob.download_as_bytes()
|
||||
file_stream = io.BytesIO(self._cache[file_name])
|
||||
file_stream.name = file_name
|
||||
return file_stream
|
||||
|
||||
def _get_aio_session(self) -> aiohttp.ClientSession:
|
||||
if self._aio_session is None or self._aio_session.closed:
|
||||
connector = aiohttp.TCPConnector(
|
||||
limit=300, limit_per_host=50,
|
||||
)
|
||||
timeout = aiohttp.ClientTimeout(total=60)
|
||||
self._aio_session = aiohttp.ClientSession(
|
||||
timeout=timeout, connector=connector,
|
||||
)
|
||||
return self._aio_session
|
||||
|
||||
def _get_aio_storage(self) -> Storage:
|
||||
if self._aio_storage is None:
|
||||
self._aio_storage = Storage(
|
||||
session=self._get_aio_session(),
|
||||
)
|
||||
return self._aio_storage
|
||||
|
||||
async def async_get_file_stream(
|
||||
self, file_name: str, max_retries: int = 3,
|
||||
) -> BinaryIO:
|
||||
"""Get a file asynchronously with retry on transient errors.
|
||||
|
||||
Args:
|
||||
file_name: The blob name to retrieve.
|
||||
max_retries: Maximum number of retry attempts.
|
||||
|
||||
Returns:
|
||||
A BytesIO stream with the file contents.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If all retry attempts fail.
|
||||
|
||||
"""
|
||||
if file_name in self._cache:
|
||||
file_stream = io.BytesIO(self._cache[file_name])
|
||||
file_stream.name = file_name
|
||||
return file_stream
|
||||
|
||||
storage_client = self._get_aio_storage()
|
||||
last_exception: Exception | None = None
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
self._cache[file_name] = await storage_client.download(
|
||||
self.bucket_name, file_name,
|
||||
)
|
||||
file_stream = io.BytesIO(self._cache[file_name])
|
||||
file_stream.name = file_name
|
||||
except TimeoutError as exc:
|
||||
last_exception = exc
|
||||
logger.warning(
|
||||
"Timeout downloading gs://%s/%s (attempt %d/%d)",
|
||||
self.bucket_name,
|
||||
file_name,
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
)
|
||||
except aiohttp.ClientResponseError as exc:
|
||||
last_exception = exc
|
||||
if (
|
||||
exc.status == HTTP_TOO_MANY_REQUESTS
|
||||
or exc.status >= HTTP_SERVER_ERROR
|
||||
):
|
||||
logger.warning(
|
||||
"HTTP %d downloading gs://%s/%s "
|
||||
"(attempt %d/%d)",
|
||||
exc.status,
|
||||
self.bucket_name,
|
||||
file_name,
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
return file_stream
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
delay = 0.5 * (2**attempt)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
msg = (
|
||||
f"Failed to download gs://{self.bucket_name}/{file_name} "
|
||||
f"after {max_retries} attempts"
|
||||
)
|
||||
raise TimeoutError(msg) from last_exception
|
||||
|
||||
def delete_files(self, path: str) -> None:
|
||||
"""Delete all files at the given path in the bucket.
|
||||
|
||||
Args:
|
||||
path: Prefix of blobs to delete.
|
||||
|
||||
"""
|
||||
blobs = self.storage_client.list_blobs(
|
||||
self.bucket_name, prefix=path,
|
||||
)
|
||||
for blob in blobs:
|
||||
blob.delete()
|
||||
self._cache.pop(blob.name, None)
|
||||
@@ -1,10 +1,13 @@
|
||||
"""Structured logging configuration using structlog."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import structlog
|
||||
|
||||
|
||||
def setup_logging(json: bool = True, level: int = logging.INFO) -> None:
|
||||
def setup_logging(*, json: bool = True, level: int = logging.INFO) -> None:
|
||||
"""Configure structlog with JSON or console output."""
|
||||
shared_processors: list[structlog.types.Processor] = [
|
||||
structlog.contextvars.merge_contextvars,
|
||||
structlog.stdlib.add_log_level,
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
"""
|
||||
KFP pipeline definition for the full RAG evaluation pipeline.
|
||||
"""
|
||||
|
||||
from kfp import dsl
|
||||
|
||||
from rag_eval.config import settings as config
|
||||
|
||||
|
||||
@dsl.component(base_image=config.base_image)
|
||||
def run_synth_gen_op(num_questions: int, config_yaml_content: str) -> str:
|
||||
"""
|
||||
KFP component to run the synthetic question generation script.
|
||||
Returns the generated run_id.
|
||||
"""
|
||||
CONFIG_FILE_PATH = "config.yaml"
|
||||
if config_yaml_content:
|
||||
with open(CONFIG_FILE_PATH, "w") as f:
|
||||
f.write(config_yaml_content)
|
||||
from synth_gen.main import generate as synth_gen_generate
|
||||
|
||||
return synth_gen_generate(num_questions=num_questions)
|
||||
|
||||
|
||||
@dsl.component(base_image=config.base_image)
|
||||
def run_search_eval_op(config_yaml_content: str, run_id: str = None):
|
||||
"""
|
||||
KFP component to run the search evaluation script by direct import.
|
||||
Optionally filters by run_id.
|
||||
"""
|
||||
CONFIG_FILE_PATH = "config.yaml"
|
||||
if config_yaml_content:
|
||||
with open(CONFIG_FILE_PATH, "w") as f:
|
||||
f.write(config_yaml_content)
|
||||
from search_eval.main import evaluate as search_eval_evaluate
|
||||
|
||||
search_eval_evaluate(run_id=run_id)
|
||||
|
||||
|
||||
@dsl.component(base_image=config.base_image)
|
||||
def run_keypoint_eval_op(config_yaml_content: str, run_id: str = None):
|
||||
"""
|
||||
KFP component to run the keypoint evaluation script by direct import.
|
||||
Optionally filters by run_id.
|
||||
"""
|
||||
CONFIG_FILE_PATH = "config.yaml"
|
||||
if config_yaml_content:
|
||||
with open(CONFIG_FILE_PATH, "w") as f:
|
||||
f.write(config_yaml_content)
|
||||
from keypoint_eval.main import run_keypoint_evaluation
|
||||
|
||||
run_keypoint_evaluation(run_id=run_id)
|
||||
|
||||
|
||||
|
||||
@dsl.pipeline(
|
||||
name="generative-evaluation-pipeline",
|
||||
description="A pipeline that generates synthetic questions and then runs search evaluation on them.",
|
||||
)
|
||||
def generative_evaluation_pipeline(num_questions: int = 20, config_yaml_content: str = ""):
|
||||
"""
|
||||
Defines the generative evaluation pipeline structure.
|
||||
1. Generates synthetic questions.
|
||||
2. Runs search evaluation on the generated questions.
|
||||
3. Runs keypoint evaluation on the generated questions.
|
||||
"""
|
||||
synth_gen_task = run_synth_gen_op(
|
||||
num_questions=num_questions, config_yaml_content=config_yaml_content
|
||||
)
|
||||
|
||||
run_search_eval_op(
|
||||
run_id=synth_gen_task.output, config_yaml_content=config_yaml_content
|
||||
)
|
||||
run_keypoint_eval_op(
|
||||
run_id=synth_gen_task.output, config_yaml_content=config_yaml_content
|
||||
)
|
||||
@@ -1,53 +0,0 @@
|
||||
"""
|
||||
This script defines a simplified Kubeflow Pipeline (KFP) for ingesting and
|
||||
processing documents for a single agent by calling the index-gen CLI.
|
||||
"""
|
||||
|
||||
from kfp import dsl
|
||||
|
||||
from rag_eval.config import settings
|
||||
|
||||
# --- KFP Components ---------------------------------------------------------
|
||||
|
||||
|
||||
@dsl.component(base_image=settings.processing_image)
|
||||
def run_index_gen_cli(config_yaml_content: str):
|
||||
"""Runs the index-gen CLI."""
|
||||
|
||||
CONFIG_FILE_PATH = "config.yaml"
|
||||
#CONFIG_FILE_PATH = os.getenv("CONFIG_FILE_PATH", "config.yaml")
|
||||
if config_yaml_content:
|
||||
with open(CONFIG_FILE_PATH, "w") as f:
|
||||
f.write(config_yaml_content)
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
|
||||
# The command needs to be installed in the environment.
|
||||
# Assuming the processing_image has the project installed.
|
||||
command = ["uv", "run", "index-gen"]
|
||||
logging.info(f"Running command: {' '.join(command)}")
|
||||
|
||||
# Using subprocess.run to capture output and check for errors
|
||||
result = subprocess.run(command, capture_output=True, text=True, check=False)
|
||||
|
||||
logging.info(result.stdout)
|
||||
if result.stderr:
|
||||
logging.error(result.stderr)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"index-gen CLI failed with return code {result.returncode}")
|
||||
|
||||
|
||||
# --- KFP Pipeline Definition ------------------------------------------------
|
||||
|
||||
|
||||
@dsl.pipeline(
|
||||
name="rag-ingestion-pipeline",
|
||||
description="A pipeline to run index-gen.",
|
||||
)
|
||||
def ingestion_pipeline(
|
||||
config_yaml_content: str = "",
|
||||
):
|
||||
"""Defines the KFP pipeline structure."""
|
||||
run_index_gen_cli(config_yaml_content=config_yaml_content)
|
||||
@@ -1,146 +0,0 @@
|
||||
"""
|
||||
Command-line interface for submitting pipelines.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
from google.cloud import aiplatform
|
||||
from kfp import compiler
|
||||
|
||||
from rag_eval.config import CONFIG_FILE_PATH, settings
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def ingestion(
    compile_only: Annotated[
        bool,
        typer.Option(
            "--compile-only",
            help="If set, compiles the pipeline and exits without running it.",
        ),
    ] = False,
):
    """Compiles and/or runs the KFP ingestion pipeline.

    Compiles the ingestion pipeline to a temporary YAML file, then (unless
    --compile-only is given) submits it as a Vertex AI PipelineJob, waits
    for completion, and prints the deployed index endpoint details.
    """
    # Imported lazily so `--help` and other commands don't pay the KFP import cost.
    from rag_eval.pipelines.ingestion import ingestion_pipeline
    logging.getLogger().setLevel(logging.INFO)

    pipeline_root = f"gs://{settings.bucket}"

    # The compiled spec only needs to exist while the job is submitted;
    # the context manager deletes it afterwards.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as temp_pipeline_file:
        pipeline_spec_path = temp_pipeline_file.name

        compiler.Compiler().compile(
            pipeline_func=ingestion_pipeline, package_path=pipeline_spec_path
        )
        logging.info(f"Pipeline compiled to temporary file: {pipeline_spec_path}")

        if compile_only:
            logging.info(
                "Compilation successful. Temporary pipeline file will be deleted."
            )
            return

        # The local config file content is passed as a pipeline parameter so the
        # remote component can reconstruct it inside its container.
        logging.info(f"Reading config from {CONFIG_FILE_PATH}")
        with open(CONFIG_FILE_PATH, "r") as f:
            config_yaml_content = f.read()

        aiplatform.init(project=settings.project_id, location=settings.location)

        job = aiplatform.PipelineJob(
            display_name=f"rag-ingestion-run-{settings.agent.name}",
            template_path=pipeline_spec_path,
            pipeline_root=pipeline_root,
            parameter_values={
                "config_yaml_content": config_yaml_content,
            },
            # Caching disabled: ingestion should re-run even with identical inputs.
            enable_caching=False,
        )

        logging.info(
            "Submitting pipeline job to Vertex AI... This will wait for the job to complete."
        )
        # sync=True blocks until the pipeline finishes (or fails).
        job.run(sync=True, service_account=settings.service_account)

        logging.info("Pipeline finished. Fetching endpoint details...")

        if not settings.index or not settings.index.name:
            logging.warning(
                "No index configuration found. Cannot fetch endpoint."
            )
            return

        # Endpoint display name follows the "<index-name>-endpoint" convention
        # used when the index is deployed.
        expected_endpoint_name = f"{settings.index.name}-endpoint"
        logging.info(f"Looking for endpoint with display name: {expected_endpoint_name}")

        endpoints = aiplatform.MatchingEngineIndexEndpoint.list()
        found_endpoint = next(
            (e for e in endpoints if e.display_name == expected_endpoint_name), None
        )

        if not found_endpoint:
            logging.warning(
                "Could not find a matching deployed endpoint."
            )
        else:
            # User-facing summary goes to stdout deliberately (not the log).
            print("\n--- Deployed Index Endpoint ---")
            print(f"Display Name: {found_endpoint.display_name}")
            print(f" Resource Name: {found_endpoint.resource_name}")
            if found_endpoint.public_endpoint_domain_name:
                print(f" Public URI: {found_endpoint.public_endpoint_domain_name}")
            else:
                print(" Public URI: Not available")
            print("-" * 20)
|
||||
|
||||
|
||||
@app.command()
def evaluation():
    """
    Compiles and runs the RAG evaluation Vertex AI pipeline without saving the compiled file.
    """
    # Imported lazily so other CLI commands don't pay the KFP import cost.
    from rag_eval.pipelines.evaluation import generative_evaluation_pipeline
    project_id = settings.project_id
    location = settings.location
    pipeline_root = settings.pipeline_root
    pipeline_name = "rag-evaluation-pipeline"

    # Create a temporary file to store the compiled pipeline
    with tempfile.NamedTemporaryFile(suffix=".json") as temp_file:
        pipeline_path = temp_file.name

        print(f"Compiling pipeline to temporary file: {pipeline_path}")
        compiler.Compiler().compile(
            pipeline_func=generative_evaluation_pipeline,
            package_path=pipeline_path,
        )

        # Read the config.yaml content to pass to the pipeline
        print(f"Reading config from {CONFIG_FILE_PATH}")
        with open(CONFIG_FILE_PATH, "r") as f:
            config_yaml_content = f.read()

        print("Submitting pipeline job to Vertex AI...")
        aiplatform.init(project=project_id, location=location)

        job = aiplatform.PipelineJob(
            display_name=pipeline_name,
            template_path=pipeline_path,
            pipeline_root=pipeline_root,
            parameter_values={
                "config_yaml_content": config_yaml_content,
            },
            # Caching disabled: evaluation should re-run even with identical inputs.
            enable_caching=False,
        )

        # job.run blocks until completion by default, so the temp spec file
        # stays alive for the whole submission.
        job.run(service_account=settings.service_account)
        print(
            f"Pipeline job submitted. View it in the Vertex AI console: {job.gca_resource.name}"
        )

    # The `with` block above has closed (and thereby deleted) the temp file here.
    print(f"Temporary compiled file {pipeline_path} has been removed.")
|
||||
|
||||
|
||||
61
src/rag_eval/server.py
Normal file
61
src/rag_eval/server.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""FastAPI server exposing the RAG agent endpoint."""
|
||||
|
||||
import time
from typing import Literal
from uuid import uuid4

import structlog
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel

from rag_eval.agent import Deps, agent
from rag_eval.config import settings
from rag_eval.logging import setup_logging
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
setup_logging()
|
||||
|
||||
app = FastAPI(title="RAG Agent")
|
||||
|
||||
|
||||
class Message(BaseModel):
    """A single chat message."""

    # Sender of the message; mirrors the common chat-completaccording roles.
    role: Literal["system", "user", "assistant"]
    # Plain-text message body.
    content: str
|
||||
|
||||
|
||||
class AgentRequest(BaseModel):
    """Request body for the agent endpoint."""

    # Full chat history; the endpoint uses only the last message as the prompt.
    messages: list[Message]
|
||||
|
||||
|
||||
class AgentResponse(BaseModel):
    """Response body from the agent endpoint."""

    # The agent's final text output.
    response: str
|
||||
|
||||
|
||||
@app.post("/agent")
|
||||
async def run_agent(request: AgentRequest) -> AgentResponse:
|
||||
"""Run the RAG agent with the provided messages."""
|
||||
request_id = uuid4().hex[:8]
|
||||
structlog.contextvars.clear_contextvars()
|
||||
structlog.contextvars.bind_contextvars(request_id=request_id)
|
||||
|
||||
prompt = request.messages[-1].content
|
||||
logger.info("request.start", prompt_length=len(prompt))
|
||||
t0 = time.perf_counter()
|
||||
|
||||
deps = Deps(
|
||||
vector_search=settings.vector_search,
|
||||
embedder=settings.embedder,
|
||||
)
|
||||
result = await agent.run(prompt, deps=deps)
|
||||
|
||||
elapsed = round((time.perf_counter() - t0) * 1000, 1)
|
||||
logger.info("request.end", elapsed_ms=elapsed)
|
||||
|
||||
return AgentResponse(response=result.output)
|
||||
@@ -1,10 +0,0 @@
|
||||
from fastapi import FastAPI

from rag_eval.logging import setup_logging

# Logging must be configured before importing the routes module, because
# importing it constructs module-level objects (router, agent) that log.
setup_logging()

# Deliberately placed after setup_logging(); hence the E402 suppression.
from .routes import rag_router  # noqa: E402

app = FastAPI()
app.include_router(rag_router)
|
||||
@@ -1,52 +0,0 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ---------------------------------- #
# Maps Gemini response
class Answer(BaseModel):
    """Structured answer parsed from the model response."""

    # Final answer text.
    answer: str
|
||||
|
||||
|
||||
# ---------------------------------- #
# Dialogflow request parameters
class RequestParameters(BaseModel):
    """Session parameters sent by Dialogflow in the webhook request."""

    # The end-user's query text.
    query: str
|
||||
|
||||
|
||||
class RequestSessionInfo(BaseModel):
    """Dialogflow `sessionInfo` wrapper carrying the request parameters."""

    parameters: RequestParameters
|
||||
|
||||
|
||||
class Request(BaseModel):
    """Top-level Dialogflow webhook request body (camelCase is wire format)."""

    sessionInfo: RequestSessionInfo
|
||||
|
||||
|
||||
# ---------------------------------- #
# Dialogflow response parameters
class ResponseParameters(BaseModel):
    """Session parameters returned to Dialogflow.

    Field names are part of the Dialogflow flow's contract and must not be
    renamed. The Spanish names are flow-side flags:
    pregunta_nueva ("new question"), fin_turno ("end of turn"),
    respuesta_entregada ("answer delivered") — semantics inferred from the
    names; confirm against the Dialogflow flow definition.
    """

    # Whether the webhook handled the request successfully.
    webhook_success: bool
    # The answer text to present to the user.
    response: str
    pregunta_nueva: str = "NO"
    fin_turno: bool = True
    respuesta_entregada: bool = True
|
||||
|
||||
|
||||
class ResponseSessionInfo(BaseModel):
    """Dialogflow `sessionInfo` wrapper carrying the response parameters."""

    parameters: ResponseParameters
|
||||
|
||||
|
||||
class Response(BaseModel):
    """Top-level Dialogflow webhook response body (camelCase is wire format)."""

    sessionInfo: ResponseSessionInfo
|
||||
|
||||
|
||||
# ---------------------------------- #
# Dialogflow proxy models
class DialogflowRequest(BaseModel):
    """Request body for the Dialogflow proxy endpoint."""

    # The message to forward to the Dialogflow agent.
    prompt: str
    # Conversation session id; a fresh UUID is generated when omitted.
    session_id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
|
||||
|
||||
class DialogflowResponse(BaseModel):
    """Raw response payload returned by the Dialogflow agent."""

    response: dict
|
||||
@@ -1,75 +0,0 @@
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
|
||||
from ..agent import Agent
|
||||
from .models import (
|
||||
DialogflowRequest,
|
||||
DialogflowResponse,
|
||||
Request,
|
||||
Response,
|
||||
ResponseParameters,
|
||||
ResponseSessionInfo,
|
||||
)
|
||||
|
||||
# The dialogflow integration is an optional extra; when it is not installed,
# the proxy endpoint below is simply not registered.
try:
    from dialogflow.main import DialogflowAgent
except ImportError:
    DialogflowAgent = None

logger = structlog.get_logger(__name__)

rag_router = APIRouter()

# Module-level singleton: the RAG agent is constructed once at import time
# and shared across all requests.
sigma_agent = Agent()
|
||||
|
||||
|
||||
@rag_router.post("/sigma-rag", response_model=Response)
|
||||
async def generate_sigma_agent(request: Request):
|
||||
request_id = uuid4().hex[:8]
|
||||
structlog.contextvars.clear_contextvars()
|
||||
structlog.contextvars.bind_contextvars(request_id=request_id)
|
||||
|
||||
prompt = request.sessionInfo.parameters.query
|
||||
logger.info("request.start", prompt_length=len(prompt))
|
||||
t0 = time.perf_counter()
|
||||
|
||||
answer = await sigma_agent.async_call(prompt)
|
||||
|
||||
elapsed = round((time.perf_counter() - t0) * 1000, 1)
|
||||
logger.info("request.end", elapsed_ms=elapsed)
|
||||
|
||||
response = Response(
|
||||
sessionInfo=ResponseSessionInfo(
|
||||
parameters=ResponseParameters(
|
||||
webhook_success=True,
|
||||
response=answer,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
# Register the proxy route only when the optional dialogflow extra is present.
if DialogflowAgent:

    @rag_router.post("/dialogflow-proxy", response_model=DialogflowResponse)
    async def dialogflow_proxy(request: DialogflowRequest):
        """
        Proxies a message to a Dialogflow agent.

        This endpoint is only available if the 'dialogflow' package is installed.
        """
        try:
            # NOTE(review): a new DialogflowAgent is built per request —
            # assumes construction is cheap; confirm if it opens connections.
            agent = DialogflowAgent()
            response = agent.call(query=request.prompt, session_id=request.session_id)
            return DialogflowResponse(response=response)
        except Exception as e:
            # Boundary handler: log the failure, then map any error to a
            # generic 500 without leaking internals to the caller. Chain the
            # original exception so tracebacks keep the root cause.
            logger.error("Error calling Dialogflow agent", error=str(e))
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Error communicating with Dialogflow agent.",
            ) from e
|
||||
1
src/rag_eval/vector_search/__init__.py
Normal file
1
src/rag_eval/vector_search/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Vector search provider implementations."""
|
||||
68
src/rag_eval/vector_search/base.py
Normal file
68
src/rag_eval/vector_search/base.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Abstract base class for vector search providers."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, TypedDict
|
||||
|
||||
|
||||
class SearchResult(TypedDict):
    """A single vector search result."""

    # Datapoint id of the matched document.
    id: str
    # Similarity distance reported by the index (measure depends on provider).
    distance: float
    # Full text content of the matched document.
    content: str
||||
|
||||
|
||||
class BaseVectorSearch(ABC):
    """Abstract base class for a vector search provider.

    This class defines the standard interface for creating a vector search
    index and running queries against it.
    """

    @abstractmethod
    def create_index(
        self, name: str, content_path: str, **kwargs: Any  # noqa: ANN401
    ) -> None:
        """Create a new vector search index with the provided content.

        Args:
            name: The desired name for the new index.
            content_path: Path to the data used to populate the index.
            **kwargs: Additional provider-specific arguments.

        """
        ...

    @abstractmethod
    def update_index(
        self, index_name: str, content_path: str, **kwargs: Any  # noqa: ANN401
    ) -> None:
        """Update an existing vector search index with new content.

        Args:
            index_name: The name of the index to update.
            content_path: Path to the data used to populate the index.
            **kwargs: Additional provider-specific arguments.

        """
        ...

    @abstractmethod
    def run_query(
        self,
        deployed_index_id: str,
        query: list[float],
        limit: int,
    ) -> list[SearchResult]:
        """Run a similarity search query against the index.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.

        Returns:
            A list of matched items with id, distance, and content.

        """
        ...
|
||||
310
src/rag_eval/vector_search/vertex_ai.py
Normal file
310
src/rag_eval/vector_search/vertex_ai.py
Normal file
@@ -0,0 +1,310 @@
|
||||
"""Google Cloud Vertex AI Vector Search implementation."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import aiohttp
|
||||
import google.auth
|
||||
import google.auth.credentials
|
||||
import google.auth.transport.requests
|
||||
from gcloud.aio.auth import Token
|
||||
from google.cloud import aiplatform
|
||||
|
||||
from rag_eval.file_storage.google_cloud import GoogleCloudFileStorage
|
||||
from rag_eval.vector_search.base import BaseVectorSearch, SearchResult
|
||||
|
||||
|
||||
class GoogleCloudVectorSearch(BaseVectorSearch):
    """A vector search provider using Vertex AI Vector Search."""

    def __init__(
        self,
        project_id: str,
        location: str,
        bucket: str,
        index_name: str | None = None,
    ) -> None:
        """Initialize the GoogleCloudVectorSearch client.

        Args:
            project_id: The Google Cloud project ID.
            location: The Google Cloud location (e.g., 'us-central1').
            bucket: The GCS bucket to use for file storage.
            index_name: The name of the index.

        """
        aiplatform.init(project=project_id, location=location)
        self.project_id = project_id
        self.location = location
        self.storage = GoogleCloudFileStorage(bucket=bucket)
        self.index_name = index_name
        # Auth/session state is created lazily on first use (see helpers below).
        self._credentials: google.auth.credentials.Credentials | None = None
        self._aio_session: aiohttp.ClientSession | None = None
        self._async_token: Token | None = None

    def _get_auth_headers(self) -> dict[str, str]:
        # Lazily resolve Application Default Credentials, refreshing the
        # bearer token whenever it is missing or expired.
        if self._credentials is None:
            self._credentials, _ = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
        if not self._credentials.token or self._credentials.expired:
            self._credentials.refresh(
                google.auth.transport.requests.Request(),
            )
        return {
            "Authorization": f"Bearer {self._credentials.token}",
            "Content-Type": "application/json",
        }

    async def _async_get_auth_headers(self) -> dict[str, str]:
        # Async counterpart of _get_auth_headers, using gcloud-aio's Token,
        # which handles its own caching/refresh internally.
        if self._async_token is None:
            self._async_token = Token(
                session=self._get_aio_session(),
                scopes=[
                    "https://www.googleapis.com/auth/cloud-platform",
                ],
            )
        access_token = await self._async_token.get()
        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def _get_aio_session(self) -> aiohttp.ClientSession:
        # Reuse one shared ClientSession; recreate it if it was closed.
        # NOTE(review): the session is never explicitly closed here — assumed
        # to live for the process lifetime; confirm shutdown handling.
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(
                limit=300, limit_per_host=50,
            )
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout, connector=connector,
            )
        return self._aio_session

    def create_index(
        self,
        name: str,
        content_path: str,
        *,
        dimensions: int = 3072,
        approximate_neighbors_count: int = 150,
        distance_measure_type: str = "DOT_PRODUCT_DISTANCE",
        **kwargs: Any,  # noqa: ANN401, ARG002
    ) -> None:
        """Create a new Vertex AI Vector Search index.

        Args:
            name: The display name for the new index.
            content_path: GCS URI to the embeddings JSON file.
            dimensions: Number of dimensions in embedding vectors.
            approximate_neighbors_count: Neighbors to find per vector.
            distance_measure_type: The distance measure to use.
            **kwargs: Additional arguments.

        """
        # Blocking call; the created index is kept for a later deploy_index().
        index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
            display_name=name,
            contents_delta_uri=content_path,
            dimensions=dimensions,
            approximate_neighbors_count=approximate_neighbors_count,
            distance_measure_type=distance_measure_type,  # type: ignore[arg-type]
            leaf_node_embedding_count=1000,
            leaf_nodes_to_search_percent=10,
        )
        self.index = index

    def update_index(
        self, index_name: str, content_path: str, **kwargs: Any,  # noqa: ANN401, ARG002
    ) -> None:
        """Update an existing Vertex AI Vector Search index.

        Args:
            index_name: The resource name of the index to update.
            content_path: GCS URI to the new embeddings JSON file.
            **kwargs: Additional arguments.

        """
        index = aiplatform.MatchingEngineIndex(index_name=index_name)
        index.update_embeddings(
            contents_delta_uri=content_path,
        )
        self.index = index

    def deploy_index(
        self,
        index_name: str,
        machine_type: str = "e2-standard-2",
    ) -> None:
        """Deploy a Vertex AI Vector Search index to an endpoint.

        Requires `self.index` to have been set by a prior create_index()
        or update_index() call.

        Args:
            index_name: The name of the index to deploy.
            machine_type: The machine type for the endpoint.

        """
        index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
            display_name=f"{index_name}-endpoint",
            public_endpoint_enabled=True,
        )
        index_endpoint.deploy_index(
            index=self.index,
            # Deployed index ids must match [a-zA-Z0-9_]; the uuid suffix
            # keeps redeploys of the same index unique.
            deployed_index_id=(
                f"{index_name.replace('-', '_')}_deployed_{uuid4().hex}"
            ),
            machine_type=machine_type,
        )
        self.index_endpoint = index_endpoint

    def load_index_endpoint(self, endpoint_name: str) -> None:
        """Load an existing Vertex AI Vector Search index endpoint.

        Args:
            endpoint_name: The resource name of the index endpoint.

        Raises:
            ValueError: If the endpoint has no public domain, which the
                async REST query path (async_run_query) requires.

        """
        self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
            endpoint_name,
        )
        if not self.index_endpoint.public_endpoint_domain_name:
            msg = (
                "The index endpoint does not have a public endpoint. "
                "Ensure the endpoint is configured for public access."
            )
            raise ValueError(msg)

    def run_query(
        self,
        deployed_index_id: str,
        query: list[float],
        limit: int,
    ) -> list[SearchResult]:
        """Run a similarity search query against the deployed index.

        Requires `self.index_endpoint` (set by load_index_endpoint or
        deploy_index).

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.

        Returns:
            A list of matched items with id, distance, and content.

        """
        response = self.index_endpoint.find_neighbors(
            deployed_index_id=deployed_index_id,
            queries=[query],
            num_neighbors=limit,
        )
        results = []
        # response[0] holds the neighbors for the single query submitted.
        for neighbor in response[0]:
            # Document bodies are stored alongside the index in GCS under
            # "<index-name>/contents/<datapoint-id>.md".
            file_path = (
                f"{self.index_name}/contents/{neighbor.id}.md"
            )
            content = (
                self.storage.get_file_stream(file_path)
                .read()
                .decode("utf-8")
            )
            results.append(
                SearchResult(
                    id=neighbor.id,
                    # distance may be None from the API; coerce to 0.0.
                    distance=float(neighbor.distance or 0),
                    content=content,
                ),
            )
        return results

    async def async_run_query(
        self,
        deployed_index_id: str,
        query: Sequence[float],
        limit: int,
    ) -> list[SearchResult]:
        """Run an async similarity search via the REST API.

        Requires `self.index_endpoint` with a public domain (see
        load_index_endpoint). Content downloads for all neighbors are
        fetched concurrently.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.

        Returns:
            A list of matched items with id, distance, and content.

        """
        domain = self.index_endpoint.public_endpoint_domain_name
        # The endpoint resource name ends with ".../indexEndpoints/<id>".
        endpoint_id = self.index_endpoint.name.split("/")[-1]
        url = (
            f"https://{domain}/v1/projects/{self.project_id}"
            f"/locations/{self.location}"
            f"/indexEndpoints/{endpoint_id}:findNeighbors"
        )
        payload = {
            "deployed_index_id": deployed_index_id,
            "queries": [
                {
                    "datapoint": {"feature_vector": list(query)},
                    "neighbor_count": limit,
                },
            ],
        }

        headers = await self._async_get_auth_headers()
        session = self._get_aio_session()
        async with session.post(
            url, json=payload, headers=headers,
        ) as response:
            response.raise_for_status()
            data = await response.json()

        # First element corresponds to the single query we sent; defaults
        # guard against an empty/absent result set.
        neighbors = (
            data.get("nearestNeighbors", [{}])[0].get("neighbors", [])
        )
        # Kick off all GCS content fetches concurrently.
        content_tasks = []
        for neighbor in neighbors:
            datapoint_id = neighbor["datapoint"]["datapointId"]
            file_path = (
                f"{self.index_name}/contents/{datapoint_id}.md"
            )
            content_tasks.append(
                self.storage.async_get_file_stream(file_path),
            )

        file_streams = await asyncio.gather(*content_tasks)
        results: list[SearchResult] = []
        # strict=True: one stream per neighbor by construction.
        for neighbor, stream in zip(
            neighbors, file_streams, strict=True,
        ):
            results.append(
                SearchResult(
                    id=neighbor["datapoint"]["datapointId"],
                    distance=neighbor["distance"],
                    content=stream.read().decode("utf-8"),
                ),
            )
        return results

    def delete_index(self, index_name: str) -> None:
        """Delete a Vertex AI Vector Search index.

        Args:
            index_name: The resource name of the index.

        """
        index = aiplatform.MatchingEngineIndex(index_name)
        index.delete()

    def delete_index_endpoint(
        self, index_endpoint_name: str,
    ) -> None:
        """Delete a Vertex AI Vector Search index endpoint.

        Args:
            index_endpoint_name: The resource name of the endpoint.

        """
        index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
            index_endpoint_name,
        )
        # Endpoints with deployed indexes cannot be deleted; undeploy first.
        index_endpoint.undeploy_all()
        index_endpoint.delete(force=True)
|
||||
Reference in New Issue
Block a user