First commit
.dockerignore (new file)
@@ -0,0 +1,3 @@
.venv
tmp
GEMINI.md
.gitignore (new file, vendored)
@@ -0,0 +1,14 @@
.env
.ipynb_checkpoints
tmp/*

# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv
.python-version (new file)
@@ -0,0 +1 @@
3.12
AGENTS.md (new file)
@@ -0,0 +1,3 @@
Use 'uv' for project management.
- `uv add`
- `uv run`
DockerfileConnector (new file)
@@ -0,0 +1,13 @@
FROM quay.ocp.banorte.com/golden/python-312:latest

COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

WORKDIR /app

COPY . .

RUN uv sync --group rag -U

ENV PATH="/app/.venv/bin:$PATH"

CMD ["uv", "run", "uvicorn", "rag_eval.server:app", "--host", "0.0.0.0"]
DockerfileEval (new file)
@@ -0,0 +1,11 @@
FROM quay.ocp.banorte.com/golden/python-312:latest

COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

WORKDIR /app

COPY . .

RUN uv sync --group evals

ENV PATH="/app/.venv/bin:$PATH"
DockerfileProcessor (new file)
@@ -0,0 +1,9 @@
FROM quay.ocp.banorte.com/golden/python-312:latest

COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

WORKDIR /app

COPY . .

RUN uv sync --no-managed-python --group processor -U
README.md (new file)
@@ -0,0 +1,92 @@
# rag-edu-fin


## Getting started

To make it easy for you to get started with GitLab, here's a list of recommended next steps.

Already a pro? Just edit this README.md and make it your own. Want to make it easy? [Use the template at the bottom](#editing-this-readme)!

## Add your files

- [ ] [Create](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#create-a-file) or [upload](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#upload-a-file) files
- [ ] [Add files using the command line](https://docs.gitlab.com/topics/git/add_files/#add-files-to-a-git-repository) or push an existing Git repository with the following command:

```
cd existing_repo
git remote add origin https://lnxocpgit1.dev.ocp.banorte.com:5443/desarrollo/evoluci-n-tecnol-gica/ap01194-orq-cog/autoservicio/rag-edu-fin.git
git branch -M main
git push -uf origin main
```

## Integrate with your tools

- [ ] [Set up project integrations](https://lnxocpgit1.dev.ocp.banorte.com:5443/desarrollo/evoluci-n-tecnol-gica/ap01194-orq-cog/autoservicio/rag-edu-fin/-/settings/integrations)

## Collaborate with your team

- [ ] [Invite team members and collaborators](https://docs.gitlab.com/ee/user/project/members/)
- [ ] [Create a new merge request](https://docs.gitlab.com/ee/user/project/merge_requests/creating_merge_requests.html)
- [ ] [Automatically close issues from merge requests](https://docs.gitlab.com/ee/user/project/issues/managing_issues.html#closing-issues-automatically)
- [ ] [Enable merge request approvals](https://docs.gitlab.com/ee/user/project/merge_requests/approvals/)
- [ ] [Set auto-merge](https://docs.gitlab.com/user/project/merge_requests/auto_merge/)

## Test and Deploy

Use the built-in continuous integration in GitLab.

- [ ] [Get started with GitLab CI/CD](https://docs.gitlab.com/ee/ci/quick_start/)
- [ ] [Analyze your code for known vulnerabilities with Static Application Security Testing (SAST)](https://docs.gitlab.com/ee/user/application_security/sast/)
- [ ] [Deploy to Kubernetes, Amazon EC2, or Amazon ECS using Auto Deploy](https://docs.gitlab.com/ee/topics/autodevops/requirements.html)
- [ ] [Use pull-based deployments for improved Kubernetes management](https://docs.gitlab.com/ee/user/clusters/agent/)
- [ ] [Set up protected environments](https://docs.gitlab.com/ee/ci/environments/protected_environments.html)

***

# Editing this README

When you're ready to make this README your own, just edit this file and use the handy template below (or feel free to structure it however you want - this is just a starting point!). Thanks to [makeareadme.com](https://www.makeareadme.com/) for this template.

## Suggestions for a good README

Every project is different, so consider which of these sections apply to yours. The sections used in the template are suggestions for most open source projects. Also keep in mind that while a README can be too long and detailed, too long is better than too short. If you think your README is too long, consider utilizing another form of documentation rather than cutting out information.

## Name
Choose a self-explaining name for your project.

## Description
Let people know what your project can do specifically. Provide context and add a link to any reference visitors might be unfamiliar with. A list of Features or a Background subsection can also be added here. If there are alternatives to your project, this is a good place to list differentiating factors.

## Badges
On some READMEs, you may see small images that convey metadata, such as whether or not all the tests are passing for the project. You can use Shields to add some to your README. Many services also have instructions for adding a badge.

## Visuals
Depending on what you are making, it can be a good idea to include screenshots or even a video (you'll frequently see GIFs rather than actual videos). Tools like ttygif can help, but check out Asciinema for a more sophisticated method.

## Installation
Within a particular ecosystem, there may be a common way of installing things, such as using Yarn, NuGet, or Homebrew. However, consider the possibility that whoever is reading your README is a novice and would like more guidance. Listing specific steps helps remove ambiguity and gets people using your project as quickly as possible. If it only runs in a specific context like a particular programming language version or operating system or has dependencies that have to be installed manually, also add a Requirements subsection.

## Usage
Use examples liberally, and show the expected output if you can. It's helpful to have inline the smallest example of usage that you can demonstrate, while providing links to more sophisticated examples if they are too long to reasonably include in the README.

## Support
Tell people where they can go for help. It can be any combination of an issue tracker, a chat room, an email address, etc.

## Roadmap
If you have ideas for releases in the future, it is a good idea to list them in the README.

## Contributing
State if you are open to contributions and what your requirements are for accepting them.

For people who want to make changes to your project, it's helpful to have some documentation on how to get started. Perhaps there is a script that they should run or some environment variables that they need to set. Make these steps explicit. These instructions could also be useful to your future self.

You can also document commands to lint the code or run tests. These steps help to ensure high code quality and reduce the likelihood that the changes inadvertently break something. Having instructions for running tests is especially helpful if it requires external setup, such as starting a Selenium server for testing in a browser.

## Authors and acknowledgment
Show your appreciation to those who have contributed to the project.

## License
For open source projects, say how it is licensed.

## Project status
If you have run out of energy or time for your project, put a note at the top of the README saying that development has slowed down or stopped completely. Someone may choose to fork your project or volunteer to step in as a maintainer or owner, allowing your project to keep going. You can also make an explicit request for maintainers.
apps/index-gen/README.md (new empty file)
apps/index-gen/pyproject.toml (new file)
@@ -0,0 +1,34 @@
[project]
name = "index-gen"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
authors = [
    { name = "Anibal Angulo", email = "a8065384@banorte.com" }
]
requires-python = ">=3.12"
dependencies = [
    "chunker",
    "document-converter",
    "embedder",
    "file-storage",
    "llm",
    "utils",
    "vector-search",
]

[project.scripts]
index-gen = "index_gen.cli:app"

[build-system]
requires = ["uv_build>=0.8.12,<0.9.0"]
build-backend = "uv_build"

[tool.uv.sources]
file-storage = { workspace = true }
vector-search = { workspace = true }
utils = { workspace = true }
embedder = { workspace = true }
chunker = { workspace = true }
document-converter = { workspace = true }
llm = { workspace = true }
apps/index-gen/src/index_gen/__init__.py (new file)
@@ -0,0 +1,2 @@
def main() -> None:
    print("Hello from index-gen!")
apps/index-gen/src/index_gen/cli.py (new file)
@@ -0,0 +1,68 @@
import logging
import tempfile
from pathlib import Path

import typer

from index_gen.main import (
    aggregate_vectors,
    build_gcs_path,
    create_vector_index,
    gather_files,
    process_file,
)
from rag_eval.config import settings

app = typer.Typer()


@app.command()
def run_ingestion():
    """Main function for the CLI script."""
    logging.basicConfig(level=logging.INFO)

    agent_config = settings.agent
    index_config = settings.index

    if not agent_config or not index_config:
        raise ValueError("Agent or index configuration not found in config.yaml")

    # Gather files
    files = gather_files(index_config.origin)

    # Build output paths
    contents_output_dir = build_gcs_path(index_config.data, "/contents")
    vectors_output_dir = build_gcs_path(index_config.data, "/vectors")
    aggregated_vectors_gcs_path = build_gcs_path(
        index_config.data, "/vectors/vectors.json"
    )

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir_path = Path(temp_dir)
        vector_artifact_paths = []

        # Process files and create local artifacts
        for i, file in enumerate(files):
            artifact_path = temp_dir_path / f"vectors_{i}.jsonl"
            vector_artifact_paths.append(artifact_path)

            process_file(
                file,
                agent_config.embedding_model,
                contents_output_dir,
                artifact_path,  # Pass the local path
                index_config.chunk_limit,
            )

        # Aggregate the local artifacts into one file in GCS
        aggregate_vectors(
            vector_artifacts=vector_artifact_paths,
            output_gcs_path=aggregated_vectors_gcs_path,
        )

    # Create vector index
    create_vector_index(vectors_output_dir)


if __name__ == "__main__":
    app()
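The driver above writes one JSONL artifact per input file into a temporary directory and then merges them. A minimal local sketch of that artifact flow, using only the standard library (the helper names `write_artifact` and `aggregate` are illustrative, not part of the package):

```python
import json
import tempfile
from pathlib import Path


def write_artifact(path: Path, records: list[dict]) -> None:
    # One JSON object per line, mirroring the vectors_{i}.jsonl artifacts.
    with open(path, "w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec) + "\n")


def aggregate(paths: list[Path], out: Path) -> int:
    # Concatenate the per-file artifacts into one JSONL file; returns line count.
    n = 0
    with open(out, "w", encoding="utf-8") as agg:
        for p in paths:
            with open(p, encoding="utf-8") as f:
                for line in f:
                    agg.write(line)
                    n += 1
    return n


with tempfile.TemporaryDirectory() as d:
    tmp = Path(d)
    a = tmp / "vectors_0.jsonl"
    b = tmp / "vectors_1.jsonl"
    write_artifact(a, [{"id": "doc_a_0", "embedding": [0.1, 0.2]}])
    write_artifact(b, [{"id": "doc_b_0", "embedding": [0.3, 0.4]}])
    total = aggregate([a, b], tmp / "vectors.json")
    print(total)  # 2
```

The real pipeline differs only in the final step, which uploads the merged file to GCS instead of keeping it local.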
apps/index-gen/src/index_gen/main.py (new file)
@@ -0,0 +1,225 @@
"""
This script defines a Kubeflow Pipeline (KFP) for ingesting and processing documents.

The pipeline is designed to run on Vertex AI Pipelines and consists of the following steps:
1. **Gather Files**: Scans a GCS directory for PDF files to process.
2. **Process Files (in parallel)**: For each PDF file found, this step:
   a. Converts the PDF to Markdown text.
   b. Chunks the text if it's too long.
   c. Generates a vector embedding for each chunk using a Vertex AI embedding model.
   d. Saves the markdown content and the vector embedding to separate GCS output paths.
"""

import json
import logging
import os
import tempfile
from pathlib import Path

from rag_eval.config import settings


def build_gcs_path(base_path: str, suffix: str) -> str:
    """Builds a GCS path by appending a suffix."""
    return f"{base_path}{suffix}"


def gather_files(
    input_dir: str,
) -> list:
    """Gathers all PDF file paths from a GCS directory."""
    from google.cloud import storage

    logging.getLogger().setLevel(logging.INFO)

    gcs_client = storage.Client()
    bucket_name, prefix = input_dir.replace("gs://", "").split("/", 1)
    bucket = gcs_client.bucket(bucket_name)
    blob_list = bucket.list_blobs(prefix=prefix)

    pdf_files = [
        f"gs://{bucket_name}/{blob.name}"
        for blob in blob_list
        if blob.name.endswith(".pdf")
    ]
    logging.info(f"Found {len(pdf_files)} PDF files in {input_dir}")
    return pdf_files


def process_file(
    file_path: str,
    model_name: str,
    contents_output_dir: str,
    vectors_output_file: Path,
    chunk_limit: int,
):
    """
    Processes a single PDF file: converts to markdown, chunks, and generates embeddings.
    The vector embeddings are written to a local JSONL file.
    """
    # Imports are inside the function because KFP serializes this function
    from pathlib import Path

    from chunker.contextual_chunker import ContextualChunker
    from document_converter.markdown import MarkdownConverter
    from embedder.vertex_ai import VertexAIEmbedder
    from google.cloud import storage
    from llm.vertex_ai import VertexAILLM
    from utils.normalize_filenames import normalize_string

    logging.getLogger().setLevel(logging.INFO)

    # Initialize converters and embedders
    converter = MarkdownConverter()
    embedder = VertexAIEmbedder(
        model_name=model_name, project=settings.project_id, location=settings.location
    )
    llm = VertexAILLM(project=settings.project_id, location=settings.location)
    chunker = ContextualChunker(llm_client=llm, max_chunk_size=chunk_limit)
    gcs_client = storage.Client()

    file_id = normalize_string(Path(file_path).stem)
    local_path = Path(f"/tmp/{Path(file_path).name}")

    with open(vectors_output_file, "w", encoding="utf-8") as f:
        try:
            # Download file from GCS
            bucket_name, blob_name = file_path.replace("gs://", "").split("/", 1)
            bucket = gcs_client.bucket(bucket_name)
            blob = bucket.blob(blob_name)
            blob.download_to_filename(local_path)
            logging.info(f"Processing file: {file_path}")

            # Process the downloaded file
            markdown_content = converter.process_file(local_path)

            def upload_to_gcs(bucket_name, blob_name, data):
                bucket = gcs_client.bucket(bucket_name)
                blob = bucket.blob(blob_name)
                blob.upload_from_string(
                    data, content_type="text/markdown; charset=utf-8"
                )

            # Determine output bucket and paths for markdown
            contents_bucket_name, contents_prefix = contents_output_dir.replace(
                "gs://", ""
            ).split("/", 1)

            if len(markdown_content) > chunk_limit:
                chunks = chunker.process_text(markdown_content)
                for i, chunk in enumerate(chunks):
                    chunk_id = f"{file_id}_{i}"
                    embedding = embedder.generate_embedding(chunk["page_content"])

                    # Upload markdown chunk
                    md_blob_name = f"{contents_prefix}/{chunk_id}.md"
                    upload_to_gcs(
                        contents_bucket_name, md_blob_name, chunk["page_content"]
                    )

                    # Write vector to local JSONL file
                    json_line = json.dumps({"id": chunk_id, "embedding": embedding})
                    f.write(json_line + "\n")
            else:
                embedding = embedder.generate_embedding(markdown_content)

                # Upload markdown
                md_blob_name = f"{contents_prefix}/{file_id}.md"
                upload_to_gcs(contents_bucket_name, md_blob_name, markdown_content)

                # Write vector to local JSONL file
                json_line = json.dumps({"id": file_id, "embedding": embedding})
                f.write(json_line + "\n")

        except Exception as e:
            logging.error(f"Failed to process file {file_path}: {e}", exc_info=True)
            raise

        finally:
            # Clean up the downloaded file
            if os.path.exists(local_path):
                os.remove(local_path)


def aggregate_vectors(
    vector_artifacts: list,  # A list of paths to the artifact files
    output_gcs_path: str,
):
    """
    Aggregates multiple JSONL artifact files into a single JSONL file in GCS.
    """
    from google.cloud import storage

    logging.getLogger().setLevel(logging.INFO)

    # Create a temporary file to aggregate all vector data
    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, encoding="utf-8"
    ) as temp_agg_file:
        logging.info(f"Aggregating vectors into temporary file: {temp_agg_file.name}")
        for artifact_path in vector_artifacts:
            with open(artifact_path, "r", encoding="utf-8") as f:
                # Each line is a complete JSON object
                for line in f:
                    temp_agg_file.write(line)  # line already includes newline

        temp_file_path = temp_agg_file.name

    logging.info("Uploading aggregated file to GCS...")
    gcs_client = storage.Client()
    bucket_name, blob_name = output_gcs_path.replace("gs://", "").split("/", 1)
    bucket = gcs_client.bucket(bucket_name)
    blob = bucket.blob(blob_name)
    blob.upload_from_filename(
        temp_file_path, content_type="application/json; charset=utf-8"
    )

    logging.info(f"Successfully uploaded aggregated vectors to {output_gcs_path}")

    # Clean up the temporary file (os is already imported at module level)
    os.remove(temp_file_path)


def create_vector_index(
    vectors_dir: str,
):
    """Creates and deploys a Vertex AI Vector Search Index."""
    from vector_search.vertex_ai import GoogleCloudVectorSearch

    from rag_eval.config import settings as config

    logging.getLogger().setLevel(logging.INFO)

    try:
        index_config = config.index

        logging.info(
            f"Initializing Vertex AI client for project '{config.project_id}' in '{config.location}'..."
        )
        vector_search = GoogleCloudVectorSearch(
            project_id=config.project_id,
            location=config.location,
            bucket=config.bucket,
            index_name=index_config.name,
        )

        logging.info(f"Starting creation of index '{index_config.name}'...")
        vector_search.create_index(
            name=index_config.name,
            content_path=vectors_dir,
            dimensions=index_config.dimensions,
        )
        logging.info(f"Index '{index_config.name}' created successfully.")

        logging.info("Deploying index to a new endpoint...")
        vector_search.deploy_index(
            index_name=index_config.name, machine_type=index_config.machine_type
        )
        logging.info("Index deployed successfully!")
        logging.info(f"Endpoint name: {vector_search.index_endpoint.display_name}")
        logging.info(
            f"Endpoint resource name: {vector_search.index_endpoint.resource_name}"
        )
    except Exception as e:
        logging.error(
            f"An error occurred during index creation or deployment: {e}", exc_info=True
        )
        raise
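Several helpers above parse `gs://` URIs with the pattern `uri.replace("gs://", "").split("/", 1)`. A standalone sketch of that parsing, with a guard so a non-GCS path fails loudly instead of mis-parsing (the helper name `split_gcs_uri` is illustrative, not part of the module):

```python
def split_gcs_uri(uri: str) -> tuple[str, str]:
    # Split a gs:// URI into (bucket_name, blob_name), mirroring the
    # parsing used throughout main.py.
    if not uri.startswith("gs://"):
        raise ValueError(f"Not a GCS URI: {uri}")
    bucket, _, blob = uri.removeprefix("gs://").partition("/")
    return bucket, blob


print(split_gcs_uri("gs://my-bucket/data/contents/doc_0.md"))
# ('my-bucket', 'data/contents/doc_0.md')
```

Note that `split("/", 1)` in the original raises `ValueError` when the URI has no object path at all (e.g. `gs://bucket`), so callers are expected to always pass a full object or prefix path.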
apps/integration-layer/README.md (new file)
@@ -0,0 +1,31 @@
# Integration Layer CLI

This package provides a command-line interface (CLI) to interact with the integration layer API deployed on Cloud Run.

## Installation

Install the package and its dependencies using `uv`:

```bash
uv pip install -e .
```

## Usage

The CLI provides two main commands: `send` and `chat`.

### `send`

Sends a single message to the API.

```bash
int-layer send "My message" --telefono "1234567890"
```

### `chat`

Starts an interactive chat session.

```bash
int-layer chat --telefono "1234567890"
```
apps/integration-layer/pyproject.toml (new file)
@@ -0,0 +1,21 @@
[project]
name = "integration-layer"
version = "0.1.0"
description = "A CLI to interact with the integration layer API."
readme = "README.md"
authors = [
    { name = "Anibal Angulo", email = "a8065384@banorte.com" }
]
requires-python = ">=3.12"
dependencies = [
    "requests",
    "typer",
    "rich"
]

[project.scripts]
int-layer = "integration_layer.cli:app"

[build-system]
requires = ["uv_build>=0.8.12,<0.9.0"]
build-backend = "uv_build"
apps/integration-layer/src/integration_layer/cli.py (new file)
@@ -0,0 +1,79 @@
import random
import string
import subprocess

import typer
from rich import print
from rich.prompt import Prompt

from .main import IntegrationLayerClient

app = typer.Typer()


def get_auth_token() -> str:
    """Gets the gcloud auth token."""
    try:
        result = subprocess.run(
            ["gcloud", "auth", "print-identity-token"],
            capture_output=True,
            text=True,
            check=True,
        )
        return result.stdout.strip()
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        print(f"[bold red]Error getting gcloud token:[/bold red] {e}")
        print("Please ensure 'gcloud' is installed and you are authenticated.")
        raise typer.Exit(code=1)


@app.command()
def send(
    message: str = typer.Argument(..., help="The message to send."),
    telefono: str = typer.Option(..., "--telefono", "-t", help="User's phone number (session ID)."),
    nickname: str = typer.Option("User", "--nickname", "-n", help="User's nickname."),
    canal: str = typer.Option("sigma", "--canal", "-c", help="Channel for the request."),
):
    """
    Sends a single message to the Integration Layer.
    """
    try:
        client = IntegrationLayerClient()
        token = get_auth_token()
        response = client.call(
            token=token, mensaje=message, telefono=telefono, nickname=nickname, canal=canal
        )
        print(response)
    except Exception as e:
        print(f"[bold red]Error:[/bold red] {e}")
        raise typer.Exit(code=1)


@app.command()
def chat(
    telefono: str = typer.Option(
        None,
        "--telefono",
        "-t",
        help="User's phone number to start the session. If not provided, a random one will be generated.",
    ),
    nickname: str = typer.Option("User", "--nickname", "-n", help="User's nickname."),
    canal: str = typer.Option("sigma", "--canal", "-c", help="Channel for the request."),
):
    """
    Starts an interactive chat with the Integration Layer.
    """
    if not telefono:
        telefono = "".join(random.choices(string.digits, k=10))
        print(f"[bold yellow]No phone number provided. Using random session ID:[/] {telefono}")

    try:
        client = IntegrationLayerClient()
        print("[bold green]Starting a new chat session. Type 'exit' or 'quit' to end.[/bold green]")

        while True:
            message = Prompt.ask("You")
            if message.lower() in ["exit", "quit"]:
                print("[bold yellow]Ending chat session.[/bold yellow]")
                break

            token = get_auth_token()
            response = client.call(
                token=token, mensaje=message, telefono=telefono, nickname=nickname, canal=canal
            )
            print(f"Agent: {response}")

    except Exception as e:
        print(f"[bold red]Error:[/bold red] {e}")
        raise typer.Exit(code=1)


if __name__ == "__main__":
    app()
apps/integration-layer/src/integration_layer/main.py (new file)
@@ -0,0 +1,43 @@
import requests


class IntegrationLayerClient:
    """A class to interact with the Integration Layer API."""

    def __init__(self):
        """Initializes the IntegrationLayerClient."""
        self.endpoint = "https://34.111.169.196/api/v1/dialogflow/detect-intent"

    def call(self, token: str, mensaje: str, telefono: str, nickname: str, canal: str) -> dict:
        """
        Sends a message to the Integration Layer.

        Args:
            token: The gcloud auth token.
            mensaje: The message to send.
            telefono: The user's phone number (acts as session ID).
            nickname: The user's nickname.
            canal: The channel (e.g., 'sigma').

        Returns:
            A dictionary containing the server's response.
        """
        headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }
        data = {
            "mensaje": mensaje,
            "usuario": {
                "telefono": telefono,
                "nickname": nickname,
            },
            "canal": canal,
        }

        try:
            response = requests.post(self.endpoint, headers=headers, json=data, timeout=60)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            raise RuntimeError(f"Failed to connect to Integration Layer: {e}") from e
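The request body assembled in `call` has a fixed shape: the message, a nested `usuario` object, and the channel. A small sketch that builds the same payload in isolation (the `build_payload` helper is illustrative, not part of the client):

```python
def build_payload(mensaje: str, telefono: str, nickname: str, canal: str) -> dict:
    # Same JSON body that IntegrationLayerClient.call posts to the API.
    return {
        "mensaje": mensaje,
        "usuario": {
            "telefono": telefono,
            "nickname": nickname,
        },
        "canal": canal,
    }


payload = build_payload("Hola", "1234567890", "User", "sigma")
print(payload["usuario"]["telefono"])  # 1234567890
```

Keeping payload construction in one place like this makes it easy to unit-test the request shape without hitting the live endpoint.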
apps/keypoint-eval/README.md (new file)
@@ -0,0 +1,30 @@
# Keypoint Evaluator

This application evaluates a RAG (Retrieval-Augmented Generation) system based on the keypoint methodology from the RAGEval paper.

## How to use

To run the evaluation, execute the following command from the root directory of the project:

```bash
python -m keypoint_eval.main --evaluation-name <EVALUATION_NAME> --matriz-eval <PATH_TO_EVALUATION_MATRIX_FILE>
```

### Arguments

* `--evaluation-name`: The name of the evaluation.
* `--matriz-eval`: The path to the evaluation matrix file.

The application reads the evaluation matrix from the specified file and generates a CSV and a JSON file with the evaluation results.

## Input File Structure

The input file can be a CSV, Excel, or JSON file.

The file must contain the following columns:

* `input`: The user's question.
* `expected_output`: The ground truth or expected answer.
* `category` (optional): The category of the question.

If the `input` column is not found, the application will look for columns containing "pregunta" or "question".
apps/keypoint-eval/pyproject.toml (new file)
@@ -0,0 +1,17 @@
[project]
name = "keypoint-eval"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
authors = [
    { name = "Anibal Angulo", email = "a8065384@banorte.com" }
]
requires-python = ">=3.12"
dependencies = []

[project.scripts]
keypoint-eval = "keypoint_eval.cli:app"

[build-system]
requires = ["uv_build>=0.8.3,<0.9.0"]
build-backend = "uv_build"
apps/keypoint-eval/src/keypoint_eval/__init__.py (new file)
@@ -0,0 +1,2 @@
def main() -> None:
    print("Hello from keypoint-eval!")
apps/keypoint-eval/src/keypoint_eval/cli.py (new file)
@@ -0,0 +1,58 @@
import warnings
from typing import Annotated

import typer

from .main import run_keypoint_evaluation

warnings.filterwarnings("ignore")

app = typer.Typer(name="keypoint-eval")


@app.command()
def main(
    input_file: Annotated[
        str,
        typer.Option(
            "--input-file",
            "-i",
            help="Path to a local CSV or SQLite file for evaluation data. "
            "If not provided, data will be loaded from BigQuery.",
        ),
    ] = None,
    output_file: Annotated[
        str,
        typer.Option(
            "--output-file",
            "-o",
            help="Optional: Path to save the output CSV file. "
            "If not provided, results will be saved to BigQuery.",
        ),
    ] = None,
    run_id: Annotated[
        str | None,
        typer.Option(
            help="Optional: The specific run_id to filter the evaluation data by."
        ),
    ] = None,
    agent_name: Annotated[
        str | None,
        typer.Option(
            "-a",
            "--agent-name",
            help="Optional: The name of a specific agent to run. Use 'dialogflow' to run the Dialogflow agent.",
        ),
    ] = None,
):
    """CLI for running keypoint-based evaluation."""
    run_keypoint_evaluation(
        input_file=input_file,
        output_file=output_file,
        run_id=run_id,
        agent_name=agent_name,
    )


if __name__ == "__main__":
    app()
330
apps/keypoint-eval/src/keypoint_eval/evaluator.py
Normal file
330
apps/keypoint-eval/src/keypoint_eval/evaluator.py
Normal file
@@ -0,0 +1,330 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from llm.vertex_ai import VertexAILLM
|
||||
from pydantic import BaseModel, Field
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn
|
||||
from rich.text import Text
|
||||
|
||||
from rag_eval.config import settings
|
||||
|
||||
|
||||
class KeypointMetricPrompt(BaseModel):
    name: str
    description: str
    template: str


class KeyPointResponse(BaseModel):
    keypoints: list[str]


class KeyPointEval(BaseModel):
    keypoint: str
    analysis: str
    category: Literal["relevant", "irrelevant", "incorrect"]


class KeyPointEvalList(BaseModel):
    evals: list[KeyPointEval]

    def _count(self, category: str) -> int:
        return sum(1 for e in self.evals if e.category == category)

    def count_relevant(self) -> int:
        return self._count("relevant")

    def count_irrelevant(self) -> int:
        return self._count("irrelevant")

    def count_incorrect(self) -> int:
        return self._count("incorrect")

    def keypoint_details(self) -> list[dict]:
        return [e.model_dump() for e in self.evals]


class ConcisenessScore(BaseModel):
    score: float = Field(
        description="A score from 0.0 to 1.0 evaluating the conciseness of the answer."
    )

class KeypointRAGEvaluator:
    """
    RAG system evaluator based on the keypoint methodology from the RAGEval paper.

    It focuses on three main metrics:
    - Completeness: how well the answer captures the key points of the ideal answer
    - Hallucination: identification of content that contradicts the key points
    - Irrelevance: proportion of key points that are neither covered nor contradicted
    """

    def __init__(self, console: Console, model: str = "gemini-2.0-flash"):
        self.metrics_results = []
        self.console = console
        self.llm = VertexAILLM(project=settings.project_id, location=settings.location)
        self.model = model

    def evaluate_conciseness(self, query: str, answer: str) -> float:
        """Evaluates the conciseness of a generated answer."""
        prompt = f"""Evaluate the conciseness of the following generated answer in response to the user's query.
The score should be a single float from 0.0 to 1.0, where 1.0 is perfectly concise and direct, and 0.0 is extremely verbose and full of conversational fluff.
Only consider the conciseness, not the correctness of the answer.

User Query: {query}
Generated Answer: {answer}
"""
        try:
            response = self.llm.structured_generation(
                model=self.model,
                prompt=prompt,
                response_model=ConcisenessScore,
                system_prompt="You are an expert evaluator focused on the conciseness and directness of answers. You output a single float score and nothing else.",
            )
            return response.score
        except Exception as e:
            self.console.print(
                f"[bold red]Error during conciseness evaluation: {str(e)}[/bold red]"
            )
            return 0.0  # Return a neutral score in case of error

    def extract_keypoints(self, question: str, ground_truth: str) -> list[str]:
        """
        Extracts key points (keypoints) from the reference answer.

        Args:
            question: The user's question
            ground_truth: The ideal or reference answer

        Returns:
            List of key points extracted from the reference answer
        """
        prompt = f"""En esta tarea, se te dará una pregunta y una respuesta ideal. Basado en la respuesta ideal,
necesitas resumir los puntos clave necesarios para responder la pregunta.

<ejemplo>
<pregunta>
Cómo puedo sacar un adelanto de nómina?
</pregunta>
<respuesta>
¡Hola! 👋 Sacar un Adelanto de Nómina con Banorte es muy fácil y
puede ayudarte con liquidez al instante. Aquí te explico cómo
funciona:

Es un monto de hasta $10,000 MXN que puedes usar para lo que
necesites, sin intereses y con una comisión fija del 7%. Lo puedes
contratar directamente desde la aplicación móvil de Banorte. Los
pagos se ajustan a la frecuencia de tu nómina y se cargan
automáticamente a tu cuenta.

Los principales requisitos son:
* Recibir tu nómina en Banorte y no tener otro adelanto vigente.
* Tener un ingreso neto mensual mayor a $2,000 MXN.
* Tener entre 18 y 74 años con 11 meses.
* Contar con un buen historial en Buró de Crédito.

¡Espero que esta información te sea muy útil! 😊
</respuesta>
<puntos clave>
[
    "Recibir tu nómina en Banorte",
    "No tener otro adelanto vigente",
    "Tener entre 18 y 74 años con 11 meses",
    "Contar con buen historial en Buró de Crédito"
]
</puntos clave>
</ejemplo>


<real>
<pregunta>
{question}
</pregunta>
<respuesta>
{ground_truth}
</respuesta>
</real>
"""
        try:
            response = self.llm.structured_generation(
                model=self.model,
                prompt=prompt,
                response_model=KeyPointResponse,
                system_prompt="Eres un asistente experto en extraer puntos clave informativos de respuestas.",
            )

            return response.keypoints

        except Exception as e:
            self.console.print(
                f"[bold red]Error al extraer keypoints: {str(e)}[/bold red]"
            )
            raise

    def evaluate_keypoints(
        self,
        generated_answer: str,
        keypoints: list[str],
    ) -> tuple[dict[str, float], list[dict]]:
        """
        Evaluates a generated answer against the extracted key points.

        Args:
            generated_answer: Answer generated by the RAG system
            keypoints: List of key points from the ideal answer

        Returns:
            A tuple of (metric scores, per-keypoint classification details)
        """
        prompt = f"""En esta tarea, recibirás una respuesta real y múltiples puntos clave
extraídos de una respuesta ideal. Tu objetivo es evaluar la calidad y concisión de la respuesta generada.

Para cada punto clave, proporciona un breve análisis y concluye con una de las siguientes clasificaciones:

[[[ Relevante ]]] - La respuesta generada aborda el punto clave de manera precisa, correcta y directa. La información es fácil de encontrar y no está oculta por un exceso de texto innecesario o "fluff" conversacional (saludos, despedidas, jerga, etc.).

[[[ Irrelevante ]]] - La respuesta generada omite por completo el punto clave o no contiene ninguna información relacionada con él. También se considera Irrelevante si la información del punto clave está presente, pero tan oculta por el "fluff" que un usuario tendría dificultades para encontrarla.

[[[ Incorrecto ]]] - La respuesta generada contiene información relacionada con el punto clave pero es incorrecta, contradice el punto clave, o podría confundir o desinformar al usuario.

**Criterio de Evaluación:**
Sé estricto con el "fluff". Una respuesta ideal es tanto correcta como concisa. El exceso de texto conversacional que no aporta valor a la respuesta debe penalizarse. Si la información clave está presente pero la respuesta es innecesariamente larga y verbosa, considera rebajar su clasificación de Relevante a Irrelevante.

Respuesta Generada: {generated_answer}

Puntos Clave de la Respuesta ideal:
{"\n".join([f"{i + 1}. {kp}" for i, kp in enumerate(keypoints)])}
"""

        try:
            response = self.llm.structured_generation(
                model=self.model,
                prompt=prompt,
                response_model=KeyPointEvalList,
                system_prompt="Eres un evaluador experto de respuestas basadas en puntos clave, capaz de detectar si la información es relevante, irrelevante o incorrecta. Adoptas una postura favorable cuando evalúas la utilidad de las respuestas para los usuarios.",
            )

            relevant_count = response.count_relevant()
            irrelevant_count = response.count_irrelevant()
            incorrect_count = response.count_incorrect()

            total_keypoints = len(keypoints)

            completeness = (
                relevant_count / total_keypoints if total_keypoints > 0 else 0
            )
            hallucination = (
                incorrect_count / total_keypoints if total_keypoints > 0 else 0
            )
            irrelevance = (
                irrelevant_count / total_keypoints if total_keypoints > 0 else 0
            )

            keypoint_details = response.keypoint_details()

            metrics = {
                "completeness": completeness,
                "hallucination": hallucination,
                "irrelevance": irrelevance,
            }

            return metrics, keypoint_details

        except Exception as e:
            self.console.print(
                f"[bold red]Error al evaluar keypoints: {str(e)}[/bold red]"
            )
            raise

    def evaluate_rag_pipeline(
        self,
        query: str,
        response: str,
        ground_truth: str,
        retrieved_contexts: list[str],
        verbose: bool = True,
    ) -> dict:
        """
        Evaluates a RAG pipeline using the keypoint methodology.

        Args:
            query: The user's question
            response: Answer generated by the RAG system
            ground_truth: The ideal or reference answer
            retrieved_contexts: Contexts retrieved to generate the answer
            verbose: Whether to print evaluation details

        Returns:
            Dictionary with the evaluation results
        """
        try:
            if verbose:
                self.console.print(
                    Panel(
                        Text(
                            f"Question: {query}\n\nAnswer: {response}", justify="left"
                        ),
                        title="[bold blue]Evaluating[/bold blue]",
                        border_style="blue",
                    )
                )

            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                transient=True,
                console=self.console,
                disable=not verbose,
            ) as progress:
                task = progress.add_task("Evaluation", total=2)

                progress.update(task, description="Extracting keypoints...")
                keypoints = self.extract_keypoints(query, ground_truth)
                progress.advance(task)

                if verbose:
                    self.console.print(
                        f"\nSe han extraído {len(keypoints)} puntos clave:"
                    )
                    for i, kp in enumerate(keypoints):
                        self.console.print(f"{i + 1}. {kp}")

                progress.update(task, description="Evaluating keypoints...")
                metrics, keypoint_details = self.evaluate_keypoints(response, keypoints)
                progress.advance(task)

            results = {
                "query": query,
                "response": response,
                "ground_truth": ground_truth,
                "retrieved_contexts": retrieved_contexts,
                "completeness": metrics["completeness"],
                "hallucination": metrics["hallucination"],
                "irrelevance": metrics["irrelevance"],
                "keypoints": keypoints,
                "keypoint_details": keypoint_details,
                "timestamp": datetime.now(),
            }

            if verbose:
                self.console.print("\nResultados de la evaluación:")
                self.console.print(f"Completeness: {metrics['completeness']:.3f}")
                self.console.print(f"Hallucination: {metrics['hallucination']:.3f}")
                self.console.print(f"Irrelevance: {metrics['irrelevance']:.3f}")

                self.console.print("\nDetalles de la evaluación por punto clave:")
                for i, detail in enumerate(keypoint_details):
                    self.console.print(f"\nKeypoint {i + 1}: {detail['keypoint']}")
                    self.console.print(f"Categoría: {detail['category']}")

            self.metrics_results.append(results)
            return results

        except Exception as e:
            self.console.print(f"[bold red]Error en la evaluación: {str(e)}[/bold red]")
            raise
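As a standalone illustration (not part of the committed file), the three ratio metrics computed in `evaluate_keypoints` reduce to simple counts over the per-keypoint judgements; the helper below is a minimal sketch of that arithmetic, with the category labels taken from `KeyPointEval`:

```python
from collections import Counter


def keypoint_metrics(categories: list[str]) -> dict[str, float]:
    """Compute the RAGEval-style ratios from per-keypoint judgements.

    Each entry in `categories` is one of "relevant", "irrelevant",
    or "incorrect", as produced by the LLM judge.
    """
    total = len(categories)
    if total == 0:
        # Mirrors the `if total_keypoints > 0 else 0` guards above.
        return {"completeness": 0.0, "hallucination": 0.0, "irrelevance": 0.0}
    counts = Counter(categories)
    return {
        "completeness": counts["relevant"] / total,
        "hallucination": counts["incorrect"] / total,
        "irrelevance": counts["irrelevant"] / total,
    }


m = keypoint_metrics(["relevant", "relevant", "incorrect", "irrelevant"])
# completeness=0.5, hallucination=0.25, irrelevance=0.25
```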
132
apps/keypoint-eval/src/keypoint_eval/loaders.py
Normal file
@@ -0,0 +1,132 @@
import pathlib
import sqlite3

import pandas as pd
from google.cloud import bigquery
from rich.console import Console

from rag_eval.config import settings as config


def load_data_from_local_file(
    file_path: str, console: Console, run_id: str = None
) -> pd.DataFrame:
    """Loads evaluation data from a local CSV or SQLite file and returns a DataFrame."""
    console.print(f"Loading data from {file_path}...")
    path = pathlib.Path(file_path)
    if not path.exists():
        raise Exception(f"Error: File not found at {file_path}")

    if path.suffix == ".csv":
        try:
            df = pd.read_csv(path)
        except Exception as e:
            raise Exception(f"An error occurred while reading the CSV file: {e}")

    elif path.suffix in [".db", ".sqlite"]:
        try:
            con = sqlite3.connect(path)
            # Assuming the table name is the file stem
            table_name = path.stem
            df = pd.read_sql(f"SELECT * FROM {table_name}", con)
            con.close()
        except Exception as e:
            raise Exception(f"An error occurred while reading the SQLite DB: {e}")
    else:
        raise Exception(
            f"Unsupported file type: {path.suffix}. Please use .csv or .db/.sqlite"
        )

    # Check for required columns
    if (
        "input" not in df.columns
        or "expected_output" not in df.columns
    ):
        raise Exception(
            "Error: The input file must contain 'input' and 'expected_output' columns."
        )
    df["agent"] = config.agent.name

    console.print(f"{run_id=}")
    if run_id:
        if "run_id" in df.columns:
            df = df[df["run_id"] == run_id].copy()
            console.print(f"Filtered data for run_id: {run_id}")
            if df.empty:
                console.print(
                    f"[yellow]Warning: No data found for run_id '{run_id}' in {file_path}.[/yellow]"
                )
        else:
            console.print(
                f"[yellow]Warning: --run-id provided, but 'run_id' column not found in {file_path}. Using all data.[/yellow]"
            )

    # Filter out unanswerable questions if 'type' column exists
    if "type" in df.columns:
        df = df[df["type"] != "Unanswerable"].copy()

    df.dropna(subset=["input", "expected_output"], inplace=True)

    console.print(f"Loaded {len(df)} questions for evaluation from {file_path}.")
    return df


def load_data_from_bigquery(console: Console, run_id: str = None) -> pd.DataFrame:
    """Loads evaluation data from the BigQuery table and returns a DataFrame."""
    console.print("Loading data from BigQuery...")
    bq_project_id = config.bigquery.project_id or config.project_id
    client = bigquery.Client(project=bq_project_id)
    table_ref = f"{bq_project_id}.{config.bigquery.dataset_id}.{config.bigquery.table_ids['synth_gen']}"

    console.print(f"Querying table: {table_ref}")
    try:
        table = client.get_table(table_ref)
        all_columns = [schema.name for schema in table.schema]

        select_cols = ["input", "expected_output"]
        if "category" in all_columns:
            select_cols.append("category")

        query_parts = [f"SELECT {', '.join(select_cols)}", f"FROM `{table_ref}`"]

        # Build WHERE clauses
        where_clauses = []
        if "type" in all_columns:
            where_clauses.append("type != 'Unanswerable'")
        if run_id:
            if "run_id" in all_columns:
                # NOTE: run_id comes from the CLI and is interpolated directly
                # into the query string rather than passed as a query parameter.
                where_clauses.append(f"run_id = '{run_id}'")
                console.print(f"Filtering data for run_id: {run_id}")
            else:
                console.print(
                    "[yellow]Warning: --run-id provided, but 'run_id' column not found in BigQuery table. Using all data.[/yellow]"
                )

        if where_clauses:
            query_parts.append("WHERE " + " AND ".join(where_clauses))

        query = "\n".join(query_parts)
        df = client.query(query).to_dataframe()

    except Exception as e:
        if "Not found" in str(e):
            console.print(f"[bold red]Error: Table {table_ref} not found.[/bold red]")
            console.print(
                "Please ensure the table exists and the configuration in 'config.yaml' is correct."
            )
            raise
        else:
            console.print(
                f"[bold red]An error occurred while querying BigQuery: {e}[/bold red]"
            )
            raise

    df.dropna(subset=["input", "expected_output"], inplace=True)
    df["agent"] = config.agent.name

    console.print(f"Loaded {len(df)} questions for evaluation.")
    if run_id and df.empty:
        console.print(
            f"[yellow]Warning: No data found for run_id '{run_id}' in BigQuery.[/yellow]"
        )
    return df
347
apps/keypoint-eval/src/keypoint_eval/main.py
Normal file
@@ -0,0 +1,347 @@
import json
import uuid
from datetime import datetime

import pandas as pd
from dialogflow.main import DialogflowAgent as OriginalDialogflowAgent
from google.api_core import exceptions as google_exceptions
from rich.console import Console
from rich.panel import Panel
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn
from rich.table import Table

from rag_eval.agent import Agent
from rag_eval.config import settings as config

from . import loaders
from .evaluator import KeypointRAGEvaluator


class DialogflowEvalAgent:
    """Adapter for DialogflowAgent to be used in evaluation."""

    def __init__(self, session_id: str = None):
        self.agent = OriginalDialogflowAgent()
        self.session_id = session_id or str(uuid.uuid4())

    def call(self, query: str) -> str:
        """Calls the Dialogflow agent with the session ID and returns only the response text."""
        response = self.agent.call(query, session_id=self.session_id)
        return response.get("response_text", "")


def run_keypoint_evaluation(
    input_file: str = None,
    output_file: str = None,
    run_id: str = None,
    agent_name: str = None,
):
    """
    Runs keypoint-based evaluation for each agent found in the input data.
    Handles both single-turn and multi-turn conversational data.
    """
    console = Console()

    # --- Introduction Panel ---
    intro_panel = Panel(
        f"""
[bold]Input File:[/bold] [cyan]{input_file or 'BigQuery'}[/cyan]
[bold]Output File:[/bold] [cyan]{output_file or 'BigQuery'}[/cyan]
[bold]Run ID:[/bold] [cyan]{run_id or 'Not specified'}[/cyan]
[bold]Agent Name:[/bold] [cyan]{agent_name or 'All'}[/cyan]
""",
        title="[bold magenta]Keypoint Evaluation Run[/bold magenta]",
        expand=False,
        border_style="magenta",
    )
    console.print(intro_panel)

    try:
        if input_file:
            df = loaders.load_data_from_local_file(input_file, console, run_id=run_id)
        else:
            df = loaders.load_data_from_bigquery(console, run_id=run_id)
    except Exception as e:
        console.print(
            f"[bold red]An unexpected error occurred during data loading: {e}[/bold red]"
        )
        raise

    if run_id is None:
        run_id = f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    if df.empty:
        console.print("[bold red]No data loaded, exiting.[/bold red]")
        return

    # --- Set up agents to evaluate ---
    evaluables = []
    if agent_name:
        if agent_name == "dialogflow":
            evaluables.append(
                {"name": "dialogflow", "agent_class": DialogflowEvalAgent, "is_special": True}
            )
            console.print("[bold green]Agent 'dialogflow' selected for evaluation.[/bold green]")
        elif agent_name == config.agent.name:
            evaluables.append(
                {"name": config.agent.name, "agent_class": Agent, "is_special": False}
            )
        else:
            console.print(
                f"[bold red]Error: Agent '{agent_name}' not found in the configuration.[/bold red]"
            )
            raise ValueError(f"Agent '{agent_name}' not found in the configuration")
    else:
        evaluables.append(
            {"name": config.agent.name, "agent_class": Agent, "is_special": False}
        )

    all_agents_results = []
    total_skipped_questions = 0

    # --- Check for conversational data ---
    is_conversational = "conversation_id" in df.columns and "turn" in df.columns

    if is_conversational:
        df.sort_values(by=["conversation_id", "turn"], inplace=True)
        conversations = df.groupby("conversation_id")
        console.print(f"Found [bold cyan]{len(conversations)}[/bold cyan] conversations to evaluate.")
        progress_total = len(df)
    else:
        console.print(f"Found [bold cyan]{len(df)}[/bold cyan] single questions to evaluate.")
        conversations = [(None, df)]  # Treat all rows as one big group
        progress_total = len(df)

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        console=console,
    ) as progress:
        task = progress.add_task(
            "[green]Processing evaluations...[/green]",
            total=progress_total,
        )

        for conversation_id, conversation_df in conversations:
            if is_conversational:
                console.print(
                    Panel(
                        f"Evaluating conversation: [bold blue]{conversation_id}[/bold blue]",
                        expand=False,
                        border_style="blue",
                    )
                )

            for evaluable in evaluables:
                agent_name_for_results = evaluable["name"]

                # Initialize agent and history for each conversation
                if evaluable["is_special"]:
                    rag_agent = evaluable["agent_class"](session_id=str(uuid.uuid4()))
                else:
                    rag_agent = evaluable["agent_class"]()

                history = []
                evaluator = KeypointRAGEvaluator(console)

                for _, row in conversation_df.iterrows():
                    query = row["input"]
                    ground_truth = row["expected_output"]

                    progress.update(
                        task, description=f"Agent: {agent_name_for_results}, Conv: {conversation_id or 'N/A'}"
                    )

                    try:
                        # Step 1: Call agent to get the response
                        if is_conversational and not evaluable["is_special"]:
                            # For standard agent in conversational mode, manage history
                            history.append({"role": "user", "content": query})
                            response = rag_agent.call(history)
                            history.append({"role": "assistant", "content": response})
                        else:
                            # For special agents or single-turn mode
                            response = rag_agent.call(query)

                        # Step 2: Evaluate the response
                        eval_result = evaluator.evaluate_rag_pipeline(
                            query=query,
                            response=response,
                            ground_truth=ground_truth,
                            retrieved_contexts=[],
                            verbose=False,
                        )

                        # Step 3: Evaluate conciseness
                        conciseness_score = evaluator.evaluate_conciseness(query, response)
                        eval_result["conciseness"] = conciseness_score

                        eval_result["agent"] = agent_name_for_results
                        # Add conversational info if present
                        if is_conversational:
                            eval_result["conversation_id"] = conversation_id
                            eval_result["turn"] = row["turn"]

                        all_agents_results.append(eval_result)

                    except google_exceptions.FailedPrecondition as e:
                        if "Token limit exceeded" in str(e):
                            total_skipped_questions += 1
                            console.print(
                                Panel(
                                    f"[bold]Query:[/bold]\n[white]{query}[/white]",
                                    title="[yellow]Skipping Question (Token Limit Exceeded)[/yellow]",
                                    expand=False,
                                    border_style="yellow",
                                )
                            )
                        else:
                            raise
                    finally:
                        progress.advance(task)

    if not all_agents_results:
        console.print("[bold red]No evaluation results were generated.[/bold red]")
        return

    final_df = pd.DataFrame(all_agents_results)

    # --- Summary Table ---
    summary_df = (
        final_df.groupby("agent")[["completeness", "hallucination", "irrelevance", "conciseness"]]
        .mean()
        .reset_index()
    )

    table = Table(
        title="[bold green]Keypoint Evaluation Summary[/bold green]",
        show_header=True,
        header_style="bold magenta",
    )
    table.add_column("Agent", justify="left", style="cyan", no_wrap=True)
    table.add_column("Completeness", justify="right", style="magenta")
    table.add_column("Hallucination", justify="right", style="green")
    table.add_column("Irrelevance", justify="right", style="yellow")
    table.add_column("Conciseness", justify="right", style="cyan")

    for _, row in summary_df.iterrows():
        table.add_row(
            row["agent"],
            f"{row['completeness']:.4f}",
            f"{row['hallucination']:.4f}",
            f"{row['irrelevance']:.4f}",
            f"{row['conciseness']:.4f}",
        )

    console.print(table)

    # --- Skipped Questions Summary ---
    if total_skipped_questions > 0:
        console.print(
            Panel(
                f"[bold yellow]Total questions skipped due to token limit: {total_skipped_questions}[/bold yellow]",
                title="[bold]Skipped Questions[/bold]",
                expand=False,
                border_style="yellow",
            )
        )

    if "timestamp" in final_df.columns:
        final_df["timestamp"] = pd.to_datetime(final_df["timestamp"]).dt.tz_localize(
            None
        )

    if output_file:
        for col in ["keypoints", "keypoint_details", "retrieved_contexts"]:
            if col in final_df.columns:
                final_df[col] = final_df[col].apply(json.dumps)

        output_panel = Panel(
            f"Saving results to CSV file: [bold cyan]{output_file}[/bold cyan]\n"
            f"Successfully saved {len(final_df)} rows to [bold green]{output_file}[/bold green]",
            title="[bold green]Output[/bold green]",
            expand=False,
            border_style="green",
        )
        console.print(output_panel)
        final_df.to_csv(output_file, index=False, encoding="utf-8-sig")
    else:
        project_id = config.bigquery.project_id or config.project_id
        dataset_id = config.bigquery.dataset_id
        table_name = config.bigquery.table_ids["keypoint_eval"]
        table_id = f"{project_id}.{dataset_id}.{table_name}"

        bq_schema = [
            {"name": "run_id", "type": "STRING"},
            {"name": "query", "type": "STRING"},
            {"name": "response", "type": "STRING"},
            {"name": "ground_truth", "type": "STRING"},
            {"name": "retrieved_contexts", "type": "STRING", "mode": "REPEATED"},
            {"name": "completeness", "type": "FLOAT"},
            {"name": "hallucination", "type": "FLOAT"},
            {"name": "irrelevance", "type": "FLOAT"},
            {"name": "conciseness", "type": "FLOAT"},
            {"name": "keypoints", "type": "STRING", "mode": "REPEATED"},
            {
                "name": "keypoint_details",
                "type": "RECORD",
                "mode": "REPEATED",
                "fields": [
                    {"name": "keypoint", "type": "STRING"},
                    {"name": "analysis", "type": "STRING"},
                    {"name": "category", "type": "STRING"},
                ],
            },
            {"name": "timestamp", "type": "TIMESTAMP"},
            {"name": "agent", "type": "STRING"},
            {"name": "error", "type": "STRING"},
            {"name": "conversation_id", "type": "STRING"},
            {"name": "turn", "type": "INTEGER"},
        ]

        final_df["run_id"] = run_id
        bq_column_names = [col["name"] for col in bq_schema]

        for col_name in bq_column_names:
            if col_name not in final_df.columns:
                final_df[col_name] = None

        final_df["completeness"] = final_df["completeness"].fillna(0.0)
        final_df["hallucination"] = final_df["hallucination"].fillna(0.0)
        final_df["irrelevance"] = final_df["irrelevance"].fillna(0.0)
        final_df["error"] = final_df["error"].fillna("")

        for col_name in ["retrieved_contexts", "keypoints", "keypoint_details"]:
            if col_name in final_df.columns:
                # Ensure any non-list items (like NaN or None) become an empty list
                final_df[col_name] = [
                    item if isinstance(item, list) else [] for item in final_df[col_name]
                ]

        final_df_for_bq = final_df[bq_column_names].copy()

        output_panel = Panel(
            f"Saving results to BigQuery table: [bold cyan]{table_id}[/bold cyan]\n"
            f"Successfully saved {len(final_df_for_bq)} rows to [bold green]{table_id}[/bold green]",
            title="[bold green]Output[/bold green]",
            expand=False,
            border_style="green",
        )
        console.print(output_panel)

        try:
            final_df_for_bq.to_gbq(
                destination_table=f"{dataset_id}.{table_name}",
                project_id=project_id,
                if_exists="append",
                table_schema=bq_schema,
            )
        except Exception as e:
            console.print(
                f"[bold red]An error occurred while saving to BigQuery: {e}[/bold red]"
            )
            console.print("DataFrame schema used for upload:")
            console.print(final_df_for_bq.info())
            raise
95
apps/search-eval/README.md
Normal file
@@ -0,0 +1,95 @@
# Search Evaluation

This package contains scripts to evaluate the performance of the vector search component.

## Evaluation

The `search-eval` script evaluates search performance. It can source data from either BigQuery or local files.

### Local File Evaluation

To run the evaluation using a local file, use the `--input-file` option.

```bash
uv run search-eval -- --input-file /path/to/your/data.csv
```

Or for a SQLite database:

```bash
uv run search-eval -- --input-file /path/to/your/data.db
```

#### Input File Structures

**CSV File**

The CSV file must contain the following columns:

| Column   | Description                                   |
|----------|-----------------------------------------------|
| `input`  | The question to be used for the search query. |
| `source` | The expected document path for the question.  |

**SQLite Database**

The SQLite database must contain a table named `evaluation_data` with the following columns:

| Column   | Description                                   |
|----------|-----------------------------------------------|
| `input`  | The question to be used for the search query. |
| `source` | The expected document path for the question.  |
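As an illustration, a valid CSV input file can be produced with pandas; the question and path values below are made up:

```python
import pandas as pd

# Hypothetical rows; only the 'input' and 'source' columns are required.
rows = [
    {"input": "¿Cómo abro una cuenta de nómina?", "source": "docs/cuentas/nomina.md"},
    {"input": "¿Qué comisión cobra el adelanto?", "source": "docs/credito/adelanto.md"},
]
pd.DataFrame(rows).to_csv("data.csv", index=False)
```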
|
||||
### BigQuery Evaluation
|
||||
|
||||
The `search-eval-bq` script evaluates search performance using data sourced from and written to BigQuery.
|
||||
|
||||
### BigQuery Table Structures
|
||||
|
||||
#### Input Table
|
||||
|
||||
The input table must contain the following columns:
|
||||
|
||||
| Column | Type | Description |
|
||||
| --------------- | ------- | --------------------------------------------------------------------------- |
|
||||
| `id` | STRING | A unique identifier for each question. |
|
||||
| `question` | STRING | The question to be used for the search query. |
|
||||
| `document_path` | STRING | The expected document path for the given question. |
|
||||
| `question_type` | STRING | The type of question. Rows where `question_type` is 'Unanswerable' are ignored. |
|
||||
|
||||
#### Output Table

The output table will be created by the script if it doesn't exist, or appended to if it does. It has the following structure:

| Column                   | Type      | Description                                                                |
| ------------------------ | --------- | -------------------------------------------------------------------------- |
| `id`                     | STRING    | The unique identifier for the question from the input table.               |
| `question`               | STRING    | The question used for the search query.                                    |
| `expected_document`      | STRING    | The expected document for the given question.                              |
| `retrieved_documents`    | STRING[]  | An array of document IDs retrieved from the vector search.                 |
| `retrieved_distances`    | FLOAT64[] | An array of distance scores for the retrieved documents.                   |
| `is_expected_in_results` | BOOLEAN   | A flag indicating whether the expected document was in the search results. |
| `evaluation_timestamp`   | TIMESTAMP | The timestamp of when the evaluation was run.                              |
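`is_expected_in_results` is a per-question hit indicator; averaging it over the output rows gives a hit rate, and truncating `retrieved_documents` gives a hit rate@k. A small illustrative sketch (field names from the output table above, data values invented):

```python
def hit_rate_at_k(rows: list[dict], k: int = 10) -> float:
    """Fraction of questions whose expected document appears in the top-k results."""
    hits = sum(
        1 for row in rows
        if row["expected_document"] in row["retrieved_documents"][:k]
    )
    return hits / len(rows)


rows = [
    {"expected_document": "doc_a", "retrieved_documents": ["doc_a", "doc_b"]},
    {"expected_document": "doc_c", "retrieved_documents": ["doc_d", "doc_e"]},
]
print(hit_rate_at_k(rows, k=2))  # 0.5
```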
### Usage

To run the BigQuery evaluation script, use the `uv run search-eval-bq` command with the following options:

```bash
uv run search-eval-bq -- --input-table <project.dataset.table> --output-table <project.dataset.table> [--project-id <gcp-project-id>]
```

**Arguments:**

* `--input-table`: **(Required)** The full BigQuery table name for the input data (e.g., `my-gcp-project.my_dataset.questions`).
* `--output-table`: **(Required)** The full BigQuery table name for the output results (e.g., `my-gcp-project.my_dataset.eval_results`).
* `--project-id`: (Optional) The Google Cloud project ID. If not provided, the `project_id` from `config.yaml` is used.

**Example:**

```bash
uv run search-eval-bq -- \
    --input-table "my-gcp-project.search_eval.synthetic_questions" \
    --output-table "my-gcp-project.search_eval.results" \
    --project-id "my-gcp-project"
```
27
apps/search-eval/pyproject.toml
Normal file
@@ -0,0 +1,27 @@
[project]
name = "search-eval"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
authors = [
    { name = "Anibal Angulo", email = "a8065384@banorte.com" }
]
requires-python = ">=3.12"
dependencies = [
    "embedder",
    "ranx>=0.3.21",
    "google-cloud-bigquery",
    "pandas-gbq",
    "kfp>=1.4.0",
    "requests-toolbelt>=1.0.0",
]

[project.scripts]
search-eval = "search_eval.cli:app"

[build-system]
requires = ["uv_build>=0.8.3,<0.9.0"]
build-backend = "uv_build"

[tool.uv.sources]
embedder = { workspace = true }
0
apps/search-eval/src/search_eval/__init__.py
Normal file
46
apps/search-eval/src/search_eval/cli.py
Normal file
@@ -0,0 +1,46 @@
from typing import Annotated

import typer

from .main import evaluate

app = typer.Typer()


@app.command()
def main(
    input_file: Annotated[
        str | None,
        typer.Option(
            "-i",
            "--input-file",
            help="Path to a local CSV or SQLite file for evaluation data. "
            "If not provided, data will be loaded from BigQuery.",
        ),
    ] = None,
    output_file: Annotated[
        str | None,
        typer.Option(
            "-o",
            "--output-file",
            help="Path to save the detailed results as a CSV file. "
            "If not provided, results will be saved to BigQuery.",
        ),
    ] = None,
    run_id: Annotated[
        str | None,
        typer.Option(
            help="Optional: The specific run_id to filter the evaluation data by."
        ),
    ] = None,
):
    """Evaluates the search metrics by loading data from BigQuery or a local file."""
    evaluate(
        input_file=input_file,
        output_file=output_file,
        run_id=run_id,
    )


if __name__ == "__main__":
    app()
305
apps/search-eval/src/search_eval/main.py
Normal file
@@ -0,0 +1,305 @@
import pathlib
import sqlite3

import pandas as pd
from embedder.vertex_ai import VertexAIEmbedder
from google.cloud import bigquery
from ranx import Qrels, Run
from ranx import evaluate as ranx
from rich.console import Console
from rich.progress import track
from rich.table import Table
from vector_search.vertex_ai import GoogleCloudVectorSearch

from rag_eval.config import settings as config


def load_data_from_local_file(
    file_path: str, console: Console, run_id: str | None = None
) -> pd.DataFrame:
    """Loads evaluation data from a local CSV or SQLite file."""
    console.print(f"[bold green]Loading data from {file_path}...[/bold green]")
    path = pathlib.Path(file_path)
    if not path.exists():
        console.print(f"[bold red]Error: File not found at {file_path}[/bold red]")
        raise FileNotFoundError(file_path)

    if path.suffix == ".csv":
        try:
            df = pd.read_csv(path)
        except Exception as e:
            console.print(
                f"[bold red]An error occurred while reading the CSV file: {e}[/bold red]"
            )
            raise
    elif path.suffix in [".db", ".sqlite"]:
        try:
            con = sqlite3.connect(path)
            # Assuming table name is 'evaluation_data'
            df = pd.read_sql("SELECT * FROM evaluation_data", con)
            con.close()
        except Exception as e:
            console.print(
                f"[bold red]An error occurred while reading the SQLite DB: {e}[/bold red]"
            )
            raise
    else:
        console.print(
            f"[bold red]Unsupported file type: {path.suffix}. Please use .csv or .db/.sqlite[/bold red]"
        )
        raise ValueError(f"Unsupported file type: {path.suffix}")

    # Standardize column names and add ID
    if "input" in df.columns and "source" in df.columns:
        df = df.rename(columns={"input": "question", "source": "document_path"})
        df["id"] = df.index + 1
        df["id"] = df["id"].astype(str)
    else:
        console.print(
            "[bold red]Error: The input file must contain 'input' and 'source' columns.[/bold red]"
        )
        raise ValueError("The input file must contain 'input' and 'source' columns.")

    if run_id:
        if "run_id" in df.columns:
            df = df[df["run_id"] == run_id].copy()
            console.print(f"Filtered data for run_id: [bold cyan]{run_id}[/bold cyan]")
            if df.empty:
                console.print(
                    f"[bold yellow]Warning: No data found for run_id '{run_id}' in {file_path}.[/bold yellow]"
                )
        else:
            console.print(
                f"[bold yellow]Warning: --run-id provided, but 'run_id' column not found in {file_path}. Using all data.[/bold yellow]"
            )

    df.dropna(inplace=True)
    console.print(f"Loaded {len(df)} questions for evaluation.")
    return df


def load_data_from_bigquery(
    console: Console, run_id: str | None = None
) -> pd.DataFrame:
    """Loads evaluation data from the BigQuery table."""
    console.print("[bold green]Loading data from BigQuery...[/bold green]")
    bq_project_id = config.bigquery.project_id or config.project_id
    client = bigquery.Client(project=bq_project_id)
    table_ref = f"{bq_project_id}.{config.bigquery.dataset_id}.{config.bigquery.table_ids['synth_gen']}"

    console.print(f"Querying table: [bold cyan]{table_ref}[/bold cyan]")
    query = f"""
        SELECT
            input AS question,
            source AS document_path,
            ROW_NUMBER() OVER() as id
        FROM
            `{table_ref}`
        WHERE
            `type` != 'Unanswerable'
    """
    if run_id:
        console.print(f"Filtering for run_id: [bold cyan]{run_id}[/bold cyan]")
        query += f" AND run_id = '{run_id}'"

    try:
        df = client.query(query).to_dataframe()
    except Exception as e:
        if "Not found" in str(e):
            console.print(f"[bold red]Error: Table {table_ref} not found.[/bold red]")
            console.print(
                "Please ensure the table exists and the configuration in 'config.yaml' is correct."
            )
            raise
        elif "unrecognized name: run_id" in str(e).lower():
            console.print(
                "[bold red]Error: The BigQuery table must contain a 'run_id' column when using the --run-id flag.[/bold red]"
            )
            raise
        else:
            console.print(
                f"[bold red]An error occurred while querying BigQuery: {e}[/bold red]"
            )
            raise

    df.dropna(inplace=True)
    console.print(f"Loaded {len(df)} questions for evaluation.")
    if df.empty:
        console.print(
            f"[bold yellow]Warning: No data found for run_id '{run_id}' in BigQuery.[/bold yellow]"
        )
    return df


def run_evaluation(df: pd.DataFrame, console: Console) -> pd.DataFrame:
    """Runs the search evaluation on the given dataframe."""
    agent_config = config.agent
    index_config = config.index
    console.print(
        f"Embedding Model: [bold cyan]{agent_config.embedding_model}[/bold cyan]"
    )
    console.print(f"Index Name: [bold cyan]{index_config.name}[/bold cyan]")

    # Initialize the embedder and vector search
    embedder = VertexAIEmbedder(
        project=config.project_id,
        location=config.location,
        model_name=agent_config.embedding_model,
    )
    vector_search = GoogleCloudVectorSearch(
        project_id=config.project_id,
        location=config.location,
        bucket=config.bucket,
        index_name=index_config.name,
    )
    vector_search.load_index_endpoint(index_config.endpoint)

    # Prepare qrels: each question maps to its expected document with relevance 1
    qrels_data = {}
    for _, row in track(df.iterrows(), total=len(df), description="Preparing qrels..."):
        doc_path = str(row["document_path"]).split("/")[-1].strip()
        qrels_data[str(row["id"])] = {doc_path: 1}
    qrels = Qrels(qrels_data)

    # Prepare run: embed each question and collect the top-10 search results
    run_data = {}
    detailed_results_list = []
    for _, row in track(df.iterrows(), total=len(df), description="Preparing run..."):
        question_embedding = embedder.generate_embedding(row["question"])
        results = vector_search.run_query(
            deployed_index_id=index_config.deployment,
            query=question_embedding,
            limit=10,
        )
        run_data[str(row["id"])] = {
            result["id"]: result["distance"] for result in results
        }

        retrieved_docs = [result["id"] for result in results]
        retrieved_distances = [result["distance"] for result in results]
        expected_doc = str(row["document_path"]).split("/")[-1].strip()

        detailed_results_list.append(
            {
                "agent": agent_config.name,
                "id": row["id"],
                "input": row["question"],
                "expected_document": expected_doc,
                "retrieved_documents": retrieved_docs,
                "retrieved_distances": retrieved_distances,
                "is_expected_in_results": expected_doc in retrieved_docs,
            }
        )
    run = Run(run_data)

    # Evaluate
    k_values = [1, 3, 5, 10]
    metrics = []
    for k in k_values:
        metrics.extend(
            [f"precision@{k}", f"recall@{k}", f"f1@{k}", f"ndcg@{k}", f"mrr@{k}"]
        )

    with console.status("[bold green]Running evaluation..."):
        results = ranx(qrels, run, metrics)

    # Create tables
    table = Table(title=f"Search Metrics @k for Agent: {agent_config.name}")
    table.add_column("k", justify="right", style="cyan")
    table.add_column("Precision@k", justify="right")
    table.add_column("Recall@k", justify="right")
    table.add_column("F1@k", justify="right")
    table.add_column("nDCG@k", justify="right")
    table.add_column("MRR@k", justify="right")

    for k in k_values:
        precision = results.get(f"precision@{k}")
        recall = results.get(f"recall@{k}")
        f1 = results.get(f"f1@{k}")
        ndcg = results.get(f"ndcg@{k}")
        mrr = results.get(f"mrr@{k}")
        table.add_row(
            str(k),
            f"{precision:.4f}" if precision is not None else "N/A",
            f"{recall:.4f}" if recall is not None else "N/A",
            f"{f1:.4f}" if f1 is not None else "N/A",
            f"{ndcg:.4f}" if ndcg is not None else "N/A",
            f"{mrr:.4f}" if mrr is not None else "N/A",
        )
    console.print(table)

    return pd.DataFrame(detailed_results_list)


def evaluate(
    input_file: str | None = None,
    output_file: str | None = None,
    run_id: str | None = None,
):
    """Core logic for evaluating search metrics."""
    console = Console()
    if input_file:
        df = load_data_from_local_file(input_file, console, run_id)
    else:
        df = load_data_from_bigquery(console, run_id)

    if df.empty:
        raise ValueError("Dataframe is empty")

    if config.index:
        console.print(
            f"[bold blue]Running evaluation for agent: {config.agent.name}[/bold blue]"
        )
        results_df = run_evaluation(df, console)
    else:
        console.print(
            f"[yellow]Skipping agent '{config.agent.name}' as it has no index configured.[/yellow]"
        )
        raise ValueError(f"Agent '{config.agent.name}' has no index configured.")

    final_results_df = results_df

    if output_file:
        console.print(
            f"Saving detailed results to CSV file: [bold cyan]{output_file}[/bold cyan]"
        )
        try:
            final_results_df.to_csv(output_file, index=False)
            console.print(
                f"Successfully saved {len(final_results_df)} rows to [bold green]{output_file}[/bold green]"
            )
        except Exception as e:
            console.print(
                f"[bold red]An error occurred while saving to CSV: {e}[/bold red]"
            )
            raise
    else:
        # Save detailed results to BigQuery
        project_id = config.bigquery.project_id or config.project_id
        dataset_id = config.bigquery.dataset_id
        table_name = config.bigquery.table_ids["search_eval"]
        table_id = f"{project_id}.{dataset_id}.{table_name}"

        console.print(
            f"Saving detailed results to BigQuery table: [bold cyan]{table_id}[/bold cyan]"
        )
        try:
            final_results_df.to_gbq(
                destination_table=f"{dataset_id}.{table_name}",
                project_id=project_id,
                if_exists="append",
            )
            console.print(
                f"Successfully saved {len(final_results_df)} rows to [bold green]{table_id}[/bold green]"
            )
        except Exception as e:
            console.print(
                f"[bold red]An error occurred while saving to BigQuery: {e}[/bold red]"
            )
            raise
28
apps/synth-gen/README.md
Normal file
@@ -0,0 +1,28 @@
# Synthetic Question Generator

This application generates a set of synthetic questions from documents stored in Google Cloud Storage (GCS) and saves them to a local CSV file. For each document, it generates one question for each predefined question type (Factual, Summarization, etc.).

The output CSV is structured for easy uploading to a BigQuery table with the following schema: `input` (STRING), `expected_output` (STRING), `source` (STRING), `type` (STRING).

## Usage

The script is run from the command line. You need to provide the path to the source documents within your GCS bucket and a path for the output CSV file.

### Command

```bash
uv run python -m synth_gen.main [OPTIONS] GCS_PATH
```

### Arguments

* `GCS_PATH`: (Required) The path to the directory in your GCS bucket where the source markdown files are located (e.g., `documents/markdown/`).
* `--output-csv, -o`: (Required) The local file path where the generated questions will be saved in CSV format.

### Example

```bash
uv run python -m synth_gen.main documents/processed/ --output-csv synthetic_questions.csv
```

This command fetches all documents from the `gs://<your-bucket-name>/documents/processed/` directory, generates questions for each, and saves them to a file named `synthetic_questions.csv` in the current directory.
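A sketch of the output row shape using only the standard library (the sample question and answer are invented; the column names are the schema above):

```python
import csv
import io

# Columns matching the documented output schema.
fieldnames = ["input", "expected_output", "source", "type"]

buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=fieldnames)
writer.writeheader()
writer.writerow(
    {
        "input": "¿Qué es el ahorro?",
        "expected_output": "El ahorro es la parte del ingreso que no se gasta.",
        "source": "ahorro",
        "type": "Factual",
    }
)
print(buf.getvalue().splitlines()[0])  # input,expected_output,source,type
```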
22
apps/synth-gen/pyproject.toml
Normal file
@@ -0,0 +1,22 @@
[project]
name = "synth-gen"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
authors = [
    { name = "Anibal Angulo", email = "a8065384@banorte.com" }
]
requires-python = ">=3.12"
dependencies = [
    "llm",
]

[project.scripts]
synth-gen = "synth_gen.main:app"

[build-system]
requires = ["uv_build>=0.8.3,<0.9.0"]
build-backend = "uv_build"

[tool.uv.sources]
llm = { workspace = true }
2
apps/synth-gen/src/synth_gen/__init__.py
Normal file
@@ -0,0 +1,2 @@
def main() -> None:
    print("Hello from synth-gen!")
349
apps/synth-gen/src/synth_gen/main.py
Normal file
@@ -0,0 +1,349 @@
import datetime
import os
import random
from typing import Annotated, List

import pandas as pd
import typer
from file_storage.google_cloud import GoogleCloudFileStorage
from llm.vertex_ai import VertexAILLM
from pydantic import BaseModel
from rich.console import Console
from rich.progress import track

from rag_eval.config import Settings

# --- Configuration ---
PROMPT_TEMPLATE = """
Eres un experto en generación de preguntas sintéticas. Tu tarea es crear preguntas sintéticas en español basadas en documentos de referencia proporcionados.

## INSTRUCCIONES:

### Requisitos obligatorios:
1. **Idioma**: La pregunta DEBE estar completamente en español
2. **Basada en documentos**: La pregunta DEBE poder responderse ÚNICAMENTE con la información contenida en los documentos proporcionados
3. **Tipo de pregunta**: Sigue estrictamente la definición del tipo de pregunta especificado
4. **Identificación de fuentes**: Incluye el ID de fuente de todos los documentos necesarios para responder la pregunta
5. **Salida esperada**: Incluye la respuesta perfecta basada en los documentos necesarios para responder la pregunta

### Tono de pregunta:
La pregunta debe ser similar a la que haría un usuario sin contexto sobre el sistema o la información disponible. Ingenuo y curioso.

### Tipo de pregunta solicitado:
**Tipo**: {qtype}
**Definición**: {qtype_def}

### Documentos de referencia:
{context}

Por favor, genera una pregunta siguiendo estas instrucciones.
""".strip()

RESPONSE_SCHEMA = {
    "type": "object",
    "properties": {
        "pregunta": {
            "type": "string",
        },
        "expected_output": {
            "type": "string",
        },
        "ids": {"type": "array", "items": {"type": "string"}},
    },
    "required": ["pregunta", "expected_output", "ids"],
}


class ResponseSchema(BaseModel):
    pregunta: str
    expected_output: str
    ids: List[str]


class Turn(BaseModel):
    pregunta: str
    expected_output: str


class MultiStepResponseSchema(BaseModel):
    conversation: List[Turn]


MULTI_STEP_PROMPT_TEMPLATE = """
Eres un experto en la generación de conversaciones sintéticas. Tu tarea es crear una conversación en español con múltiples turnos basada en los documentos de referencia proporcionados.

## INSTRUCCIONES:

### Requisitos obligatorios:
1. **Idioma**: La conversación DEBE estar completamente en español.
2. **Basada en documentos**: Todas las respuestas DEBEN poder responderse ÚNICAMENTE con la información contenida en los documentos de referencia.
3. **Número de turnos**: La conversación debe tener exactamente {num_turns} turnos. Un turno consiste en una pregunta del usuario y una respuesta del asistente.
4. **Flujo conversacional**: Las preguntas deben seguir un orden lógico, como si un usuario estuviera explorando un tema paso a paso. La segunda pregunta debe ser una continuación de la primera, y así sucesivamente.
5. **Salida esperada**: Proporciona la respuesta perfecta para cada pregunta, basada en los documentos de referencia.

### Tono de las preguntas:
Las preguntas deben ser similares a las que haría un usuario sin contexto sobre el sistema o la información disponible. Deben ser ingenuas y curiosas.

### Documentos de referencia:
{context}

Por favor, genera una conversación de {num_turns} turnos siguiendo estas instrucciones.
""".strip()


QUESTION_TYPE_MAP = {
    "Factual": "Questions targeting specific details within a reference (e.g., a company's profit in a report, a verdict in a legal case, or symptoms in a medical record) to test RAG's retrieval accuracy.",
    "Summarization": "Questions that require comprehensive answers, covering all relevant information, to mainly evaluate the recall rate of RAG retrieval.",
    "Multi-hop Reasoning": "Questions involving logical relationships among events and details within a document, forming a reasoning chain to assess RAG's logical reasoning ability.",
    "Unanswerable": "Questions arising from potential information loss during the schema-to-article generation, where no corresponding information fragment exists, or the information is insufficient for an answer.",
}


# --- Core Logic ---
def generate_synthetic_question(
    llm: VertexAILLM,
    file_content: str,
    file_path: str,
    q_type: str,
    q_def: str,
    language_model: str,
) -> ResponseSchema:
    """Generates a single synthetic question using the LLM."""
    prompt = PROMPT_TEMPLATE.format(
        context=file_content, id=file_path, qtype=q_type, qtype_def=q_def
    )
    response = llm.structured_generation(
        model=language_model,
        prompt=prompt,
        response_model=ResponseSchema,
    )
    return response


def generate_synthetic_conversation(
    llm: VertexAILLM,
    file_content: str,
    file_path: str,
    num_turns: int,
    language_model: str,
) -> MultiStepResponseSchema:
    """Generates a synthetic conversation with multiple turns using the LLM."""
    prompt = MULTI_STEP_PROMPT_TEMPLATE.format(
        context=file_content, num_turns=num_turns
    )
    response = llm.structured_generation(
        model=language_model,
        prompt=prompt,
        response_model=MultiStepResponseSchema,
    )
    return response


app = typer.Typer()


def generate(
    num_questions: int,
    output_csv: str | None = None,
    num_turns: int = 1,
) -> str:
    """
    Core logic for generating a specified number of synthetic questions.
    """
    console = Console()
    settings = Settings()
    llm = VertexAILLM(project=settings.project_id, location=settings.location)
    storage = GoogleCloudFileStorage(bucket=settings.bucket)

    run_id = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d-%H%M%S")
    console.print(f"[bold yellow]Generated Run ID: {run_id}[/bold yellow]")

    all_rows = []
    if not settings.index:
        console.print("[yellow]Skipping as no index is configured.[/yellow]")
        return ""

    gcs_path = f"{settings.index.name}/contents/"
    console.print(f"[green]Fetching files from GCS path: {gcs_path}[/green]")

    try:
        all_files = storage.list_files(path=gcs_path)
        console.print(f"Found {len(all_files)} total files to process.")
    except Exception as e:
        console.print(f"[bold red]Error listing files: {e}[/bold red]")
        return ""

    if not all_files:
        console.print("[yellow]No files found. Skipping.[/yellow]")
        return ""

    files_to_process = random.sample(
        all_files, k=min(num_questions, len(all_files))
    )
    console.print(
        f"Randomly selected {len(files_to_process)} files to generate questions from."
    )

    for file_path in track(files_to_process, description="Generating questions..."):
        try:
            file_content = storage.get_file_stream(file_path).read().decode("utf-8-sig")
            q_type, q_def = random.choice(list(QUESTION_TYPE_MAP.items()))

            if num_turns > 1:
                conversation_data = None
                for attempt in range(3):  # Retry up to 3 times
                    conversation_data = generate_synthetic_conversation(
                        llm,
                        file_content,
                        file_path,
                        num_turns,
                        settings.agent.language_model,
                    )
                    if (
                        conversation_data
                        and conversation_data.conversation
                        and len(conversation_data.conversation) == num_turns
                    ):
                        break  # Success
                    console.print(
                        f"[yellow]Failed to generate valid conversation for {os.path.basename(file_path)}. Retrying ({attempt + 1}/3)...[/yellow]"
                    )
                    conversation_data = None

                if not conversation_data:
                    console.print(
                        f"[bold red]Failed to generate valid conversation for {os.path.basename(file_path)} after 3 attempts. Skipping.[/bold red]"
                    )
                    continue

                conversation_id = str(random.randint(10000, 99999))
                for i, turn in enumerate(conversation_data.conversation):
                    row = {
                        "input": turn.pregunta,
                        "expected_output": turn.expected_output,
                        "source": os.path.splitext(os.path.basename(file_path))[0],
                        "type": "Multi-turn",
                        "agent": settings.agent.name,
                        "run_id": run_id,
                        "conversation_id": conversation_id,
                        "turn": i + 1,
                    }
                    all_rows.append(row)

            else:  # Single turn generation
                generated_data = None
                for attempt in range(3):  # Retry up to 3 times
                    generated_data = generate_synthetic_question(
                        llm,
                        file_content,
                        file_path,
                        q_type,
                        q_def,
                        settings.agent.language_model,
                    )
                    if (
                        generated_data
                        and generated_data.expected_output
                        and generated_data.expected_output.strip()
                    ):
                        break  # Success, exit retry loop
                    console.print(
                        f"[yellow]Empty answer for {q_type} on {os.path.basename(file_path)}. Retrying ({attempt + 1}/3)...[/yellow]"
                    )
                    generated_data = None  # Reset to indicate failure

                if not generated_data:
                    console.print(
                        f"[bold red]Failed to generate valid answer for {q_type} on {os.path.basename(file_path)} after 3 attempts. Skipping.[/bold red]"
                    )
                    continue

                row = {
                    "input": generated_data.pregunta,
                    "expected_output": generated_data.expected_output,
                    "source": os.path.splitext(os.path.basename(file_path))[0],
                    "type": q_type,
                    "agent": settings.agent.name,
                    "run_id": run_id,
                }
                all_rows.append(row)

        except Exception as e:
            console.print(f"[bold red]Error processing file {file_path}: {e}[/bold red]")

    if not all_rows:
        console.print("[bold yellow]No questions were generated.[/bold yellow]")
        return ""

    df = pd.DataFrame(all_rows)

    if output_csv:
        console.print(
            f"\n[bold green]Saving {len(df)} generated questions to {output_csv}...[/bold green]"
        )
        df.to_csv(output_csv, index=False, encoding="utf-8-sig")
        console.print("[bold green]Synthetic question generation complete.[/bold green]")
    else:
        console.print(
            f"\n[bold green]Saving {len(df)} generated questions to BigQuery...[/bold green]"
        )
        project_id = settings.bigquery.project_id or settings.project_id
        dataset_id = settings.bigquery.dataset_id
        table_name = settings.bigquery.table_ids["synth_gen"]
        table_id = f"{project_id}.{dataset_id}.{table_name}"

        console.print(f"Saving to BigQuery table: [bold cyan]{table_id}[/bold cyan]")
        try:
            # Ensure new columns exist for all rows before upload
            if "conversation_id" not in df.columns:
                df["conversation_id"] = None
            if "turn" not in df.columns:
                df["turn"] = None

            df.to_gbq(
                destination_table=f"{dataset_id}.{table_name}",
                project_id=project_id,
                if_exists="append",
            )
            console.print(
                f"Successfully saved {len(df)} rows to [bold green]{table_id}[/bold green]"
            )
        except Exception as e:
            console.print(
                f"[bold red]An error occurred while saving to BigQuery: {e}[/bold red]"
            )
            raise typer.Exit(code=1)

    console.print(f"[bold yellow]Finished run with ID: {run_id}[/bold yellow]")
    return run_id


@app.command()
def main(
    num_questions: Annotated[
        int,
        typer.Option(
            "--num-questions", "-n", help="Number of questions to generate."
        ),
    ] = 10,
    output_csv: Annotated[
        str | None,
        typer.Option(
            "--output-csv", "-o", help="Optional: Path to save the output CSV file."
        ),
    ] = None,
    num_turns: Annotated[
        int,
        typer.Option(
            "--num-turns",
            "-t",
            help="Number of conversational turns to generate.",
        ),
    ] = 1,
):
    """
    Generates a specified number of synthetic questions and saves them to BigQuery (default) or a local CSV file.
    """
    generate(
        num_questions=num_questions, output_csv=output_csv, num_turns=num_turns
    )


if __name__ == "__main__":
    app()
173
config.yaml
Normal file
@@ -0,0 +1,173 @@
|
||||
project_id: bnt-orquestador-cognitivo-dev
|
||||
location: us-central1
|
||||
service_account: sa-cicd-gitlab@bnt-orquestador-cognitivo-dev.iam.gserviceaccount.com
|
||||
|
||||
bucket: bnt_orquestador_cognitivo_gcs_configs_dev
|
||||
base_image: us-central1-docker.pkg.dev/bnt-orquestador-cognitivo-dev/dfcx-pipelines/ap01194-orq-cog-dfcx-agent-eval-in:v1.0.6-ft-anibal
|
||||
dialogflow_agent_id: 5590ff1d-1f66-4777-93f5-1a608f1900ac
|
||||
processing_image: us-central1-docker.pkg.dev/bnt-orquestador-cognitivo-dev/dfcx-pipelines/ap01194-orq-cog-rag-ops:v1.3.7
|
||||
|
||||
agent_name: sigma
|
||||
agent_instructions: |
|
||||
  Eres VAia, un agente experto de Sigma especializado en educación financiera y los productos/servicios de la compañía. Tu único objetivo es dar respuestas directas, precisas y amigables a las preguntas de los usuarios en WhatsApp.

  *Principio fundamental: Ve siempre directo al grano. Las respuestas deben ser concisas y comenzar inmediatamente con la información solicitada, sin frases introductorias.*

  Utiliza exclusivamente la herramienta 'conocimiento' para basar tus respuestas. No confíes en tu conocimiento previo. Si la herramienta no arroja resultados relevantes, informa al usuario que no tienes la información necesaria.

  ---
  *REGLAS DE RESPUESTA CRÍTICAS:*
  1. *CERO INTRODUCCIONES:* Nunca inicies tus respuestas con saludos o frases de cortesía como "¡Hola!", "¡Claro!", "Por supuesto", "¡Desde luego!", etc. La primera palabra de tu respuesta debe ser parte de la respuesta directa.
  - _Ejemplo INCORRECTO:_ "¡Claro que sí! El interés compuesto es..."
  - _Ejemplo CORRECTO:_ "El interés compuesto es..."
  2. *TONO AMIGABLE Y DIRECTO:* Aunque no usas saludos, tu tono debe ser siempre cálido, servicial y fácil de entender. Usa un lenguaje claro y positivo. ¡Imagina que estás ayudando a un amigo a entender finanzas!
  3. *FORMATO WHATSAPP:* Utiliza el formato de WhatsApp para resaltar información importante: *negritas* para énfasis, _cursivas_ para términos específicos y bullet points (`- `) para listas.
  4. *SIEMPRE USA LA HERRAMIENTA:* Utiliza la herramienta 'conocimiento' para cada pregunta del usuario. Es tu única fuente de verdad.
  5. *RESPUESTAS BASADAS EN HECHOS:* Basa tus respuestas únicamente en la información obtenida de la herramienta 'conocimiento'.
  6. *RESPONDE EN ESPAÑOL LATINO:* Todas tus respuestas deben ser en español latinoamericano.
  7. *USA EMOJIS PARA SER AMIGABLE:* Utiliza emojis de forma natural para añadir un toque de calidez y dinamismo a tus respuestas. No temas usar emojis relevantes para hacer la conversación más amena. Algunos emojis que puedes usar son: 💡, ✅, 📈, 💰, 😊, 👍, ✨, 🚀, 😉, 🎉, 🤩, 🫡, 👏, 💸, 🛍️, 💪, 📊.

  *Flujo de Interacción:*
  1. El usuario hace una pregunta.
  2. Tú, VAia, utilizas la herramienta 'conocimiento' para buscar la información más relevante.
  3. Tú, VAia, construyes una respuesta directa, concisa y amigable usando solo los resultados de la búsqueda y la envías al usuario.

  ---
  *CONTEXTO BASE:*

  Esta información es complementaria y sirve para dar a VAia contexto sobre su propósito, capacidades y limitaciones, así como sobre Sigma y sus productos.

  *1. Acerca de VAia*

  *VAia* es un asistente virtual (chatbot) de la institución financiera Sigma, diseñado para ser el primer punto de contacto para resolver las dudas de los usuarios de forma automatizada.

  - _Propósito principal:_ Proporcionar información clara, precisa y al instante sobre los productos y servicios del banco, las funcionalidades de la aplicación y temas de educación financiera.
  - _Fuente de conocimiento:_ Las respuestas de VAia se basan exclusivamente en la base de conocimiento oficial y curada de Sigma. Esto garantiza que la información sea fiable, consistente y esté actualizada.

  *2. Capacidades y Alcance Informativo*

  *Formulación de Preguntas y Ejemplos*

  Para una interacción efectiva, el bot entiende mejor las *preguntas directas, específicas y formuladas con claridad*. Se recomienda usar palabras clave relevantes para el tema de interés.

  * _Forma más efectiva:_ Realizar preguntas cortas y enfocadas en un solo tema a la vez. Por ejemplo, en lugar de preguntar _"necesito dinero y no sé qué hacer"_, es mejor preguntar _"¿qué créditos ofrece Sigma?"_ o _"¿cómo solicito un adelanto de nómina?"_.
  * _Tipos de dudas que entiende mejor:_ Preguntas que empiezan con "¿Qué es...?", "¿Cómo puedo...?", "¿Cuáles son los beneficios de...?", o que solicitan información sobre un producto específico.

  _Ejemplos de preguntas bien formuladas:_

  * _¿Qué es el Costo Anual Total (CAT)?_
  * _¿Cómo puedo activar mi nueva tarjeta de crédito desde la app?_
  * _¿Cuáles son los beneficios de la Tarjeta de Crédito Platinum?_
  * _¿Qué necesito para solicitar un Adelanto de Nómina?_
  * _Guíame para crear una Cápsula de ahorro._
  * _¿Cómo puedo consultar mi estado de cuenta?_

  *Temas y Servicios Soportados*

  VAia puede proporcionar información detallada sobre las siguientes áreas:

  1. *Educación Financiera:*
  - Conceptos: Ahorro, presupuesto, inversiones, Buró de Crédito, CAT, CETES, tasas de interés, inflación.
  - Productos: Tarjetas de crédito y débito, fondos de inversión, seguros.

  2. *Funcionalidades de la App Móvil (Servicios Digitales):*
  - _Consultas:_ Saldos, movimientos, estados de cuenta, detalles de tarjetas y créditos.
  - _Transferencias:_ SPEI, Dimo, entre cuentas propias, alta de nuevos contactos.
  - _Pagos:_ Pago de servicios (luz, agua, etc.), impuestos (SAT), y pagos con CoDi.
  - _Gestión de Tarjetas:_ Activación, reporte de robo/extravío, cambio de NIP, configuración de límites de gasto, encendido y apagado de tarjetas.
  - _Ahorro e Inversión:_ Creación y gestión de "Cápsulas" de ahorro, compra-venta en fondos de inversión.
  - _Solicitudes y Aclaraciones:_ Portabilidad de nómina, reposición de tarjetas, inicio de aclaraciones por cargos no reconocidos.

  3. *Productos y Servicios del Banco:*
  - _Cuentas:_ Cuenta Digital, Cuenta Digital Ilimitada.
  - _Créditos:_ Crédito de Nómina, Adelanto de Nómina.
  - _Tarjetas:_ Tarjeta de Crédito Clásica, Platinum, Garantizada.
  - _Inversiones:_ Fondo Digital, Fondo Sustentable.
  - _Seguros:_ Seguro de Gadgets, Seguro de Mascotas.

  *3. Limitaciones y Canales de Soporte*

  *¿Qué NO puede hacer VAia?*

  - _No realiza transacciones:_ No puede ejecutar operaciones como transferencias, pagos o inversiones en nombre del usuario. Su función es guiar al usuario para que él mismo las realice de forma segura.
  - _No tiene acceso a datos personales o de cuentas:_ No puede consultar saldos, movimientos, o cualquier información sensible del usuario.
  - _No ofrece asesoría financiera personalizada:_ No puede dar recomendaciones de inversión o productos basadas en la situación particular del usuario.
  - _No gestiona quejas o aclaraciones complejas:_ Puede guiar sobre cómo iniciar una aclaración, pero el seguimiento y la resolución corresponden a un ejecutivo humano.
  - _No posee información de otras instituciones bancarias_.

  *Preguntas que VAia no entiende bien*

  El bot puede tener dificultades con preguntas que son:

  - _Ambiguas o muy generales:_ _"Ayuda"_, _"Tengo un problema"_.
  - _Emocionales o subjetivas:_ _"Estoy muy molesto con el servicio"_.
  - _Fuera de su dominio de conocimiento:_ Preguntas sobre temas no financieros o sobre productos de otros bancos.

  *Diferencia clave con un Asesor Humano*

  *VAia:*
  - _Disponibilidad:_ 24/7, respuesta inmediata.
  - _Tipo de Ayuda:_ Informativa y procedimental (basada en la base de conocimiento).
  - _Acceso a Datos:_ Nulo.
  - _Casos de Uso:_ Dudas generales, guías "cómo hacer", definiciones de productos.

  *Asesor Humano:*
  - _Disponibilidad:_ Horario de oficina.
  - _Tipo de Ayuda:_ Personalizada, resolutiva y transaccional.
  - _Acceso a Datos:_ Acceso seguro al perfil y datos del cliente.
  - _Casos de Uso:_ Problemas específicos con la cuenta, errores en transacciones, quejas, asesoría financiera.

  *4. Escalación y Contacto con Asesores Humanos*

  *¿Cuándo buscar a un Asesor Humano?*

  El usuario debe solicitar la ayuda de un asesor humano cuando:

  - La consulta requiere acceso a información personal de la cuenta.
  - Se presenta un problema técnico, un error en una transacción o un cargo no reconocido.
  - Se necesita levantar una queja formal o dar seguimiento a una aclaración.

  *Proceso de Escalación*

  Si VAia no puede resolver una duda, está programado para ofrecer proactivamente al usuario instrucciones para *contactar a un asesor humano*, a través de la aplicación móvil o número telefónico.

  *5. Seguridad y Privacidad de la Información*

  - _Protección de Datos del Usuario:_ La interacción con VAia es segura, ya que el asistente *no solicita ni almacena datos personales*, números de cuenta, contraseñas o cualquier otra información sensible. Se instruye a los usuarios a no compartir este tipo de datos en la conversación.
  - _Información sobre Seguridad de la App:_ VAia puede dar detalles sobre _cómo funcionan_ las herramientas de seguridad de la aplicación (ej. activación de biometría, cambio de contraseña, apagado de tarjetas) para que el usuario las gestione. Sin embargo, no tiene acceso a la configuración de seguridad específica de la cuenta del usuario ni puede modificarla.

  *6. Temas prohibidos*

  VAia no puede compartir información ni contestar preguntas sobre los siguientes temas:

  - Criptomonedas
  - ETFs

  ---
  *NOTAS DE SIGMA:*

  Esta es una sección con información rápida de Sigma. Puedes profundizar en esta información con la herramienta 'conocimiento'.

  - Retiros en cajeros automáticos:
  a. Tarjetas de Crédito: 6.5% de interés, con 4 retiros gratuitos al mes.
  b. Tarjetas de Débito: Sin interés
agent_language_model: gemini-2.5-flash
agent_embedding_model: gemini-embedding-001
agent_thinking: 0

index_name: si1
index_endpoint: projects/1007577023101/locations/us-central1/indexEndpoints/76334694269976576
index_dimensions: 3072
index_machine_type: e2-standard-16
index_origin: gs://bnt_orquestador_cognitivo_gcs_kb_dev/
index_destination: gs://bnt_orquestador_cognitivo_gcs_configs_dev/
index_chunk_limit: 3000

bigquery_project_id: "bnt-lakehouse-innovacion-des"
bigquery_dataset_id: "ds_orquestador_des"
bigquery_table_ids:
  synth_gen: "eval_questions"
  search_eval: "search_results"
  keypoint_eval: "keypoints_results"
  red_team: "redteam_results"
0
packages/chunker/README.md
Normal file
23
packages/chunker/pyproject.toml
Normal file
@@ -0,0 +1,23 @@
[project]
name = "chunker"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
authors = [
    { name = "Anibal Angulo", email = "a8065384@banorte.com" }
]
requires-python = ">=3.12"
dependencies = [
    "chonkie>=1.1.2",
    "pdf2image>=1.17.0",
    "pypdf>=6.0.0",
]

[project.scripts]
llm-chunker = "chunker.llm_chunker:app"
recursive-chunker = "chunker.recursive_chunker:app"
contextual-chunker = "chunker.contextual_chunker:app"

[build-system]
requires = ["uv_build>=0.8.3,<0.9.0"]
build-backend = "uv_build"
2
packages/chunker/src/chunker/__init__.py
Normal file
@@ -0,0 +1,2 @@
def hello() -> str:
    return "Hello from chunker!"
66
packages/chunker/src/chunker/base_chunker.py
Normal file
@@ -0,0 +1,66 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, TypedDict


class Document(TypedDict):
    """A dictionary representing a processed document chunk."""

    page_content: str
    metadata: dict


class BaseChunker(ABC):
    """Abstract base class for chunker implementations."""

    @abstractmethod
    def process_text(self, text: str) -> List[Document]:
        """
        Processes a string of text into a list of Document chunks.

        Args:
            text: The input string to process.

        Returns:
            A list of Document objects.
        """
        ...

    def process_path(self, path: Path) -> List[Document]:
        """
        Reads a file from a Path object and processes its content.

        It attempts to read the file with UTF-8 encoding and falls back to
        latin-1 if a UnicodeDecodeError occurs.

        Args:
            path: The Path object pointing to the file.

        Returns:
            A list of Document objects from the file's content.
        """
        try:
            text = path.read_text(encoding="utf-8")
        except UnicodeDecodeError:
            text = path.read_text(encoding="latin-1")
        return self.process_text(text)

    def process_bytes(self, b: bytes) -> List[Document]:
        """
        Decodes a byte string and processes its content.

        It first attempts to decode the bytes as UTF-8. If that fails,
        it falls back to latin-1.

        Args:
            b: The input byte string.

        Returns:
            A list of Document objects from the byte string's content.
        """
        try:
            text = b.decode("utf-8")
        except UnicodeDecodeError:
            # Fallback for bytes that are not UTF-8 encoded; latin-1 maps
            # every byte, so this cannot raise (matches process_path).
            text = b.decode("latin-1")
        return self.process_text(text)
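The `BaseChunker` contract above only obliges subclasses to implement `process_text`; `process_path` and `process_bytes` come along for free. A minimal sketch of a conforming chunker (self-contained, with a stand-in `Document` definition instead of the package import; `ParagraphChunker` is hypothetical, not part of the repo):

```python
from typing import List, TypedDict


class Document(TypedDict):
    """Stand-in for chunker.base_chunker.Document."""

    page_content: str
    metadata: dict


class ParagraphChunker:
    """Toy chunker: one Document per blank-line-separated paragraph."""

    def process_text(self, text: str) -> List[Document]:
        paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
        return [
            {"page_content": p, "metadata": {"chunk_index": i}}
            for i, p in enumerate(paragraphs)
        ]


docs = ParagraphChunker().process_text("Primer párrafo.\n\nSegundo párrafo.")
print(len(docs))  # → 2
print(docs[1]["metadata"])  # → {'chunk_index': 1}
```

Downstream code (the CLI entry points, the JSONL writers) only depends on this `page_content`/`metadata` shape, which is what makes the chunkers interchangeable.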
155
packages/chunker/src/chunker/contextual_chunker.py
Normal file
@@ -0,0 +1,155 @@
import json
import math
import os
from pathlib import Path
from typing import Annotated, List

import typer
from llm.vertex_ai import VertexAILLM

from .base_chunker import BaseChunker, Document


class ContextualChunker(BaseChunker):
    """
    A chunker that uses a large language model to create context-aware chunks.
    """

    def __init__(
        self,
        llm_client: VertexAILLM,
        max_chunk_size: int = 800,
        model: str = "gemini-2.0-flash",
    ):
        """
        Initializes the ContextualChunker.

        Args:
            llm_client: The language model client used to summarize context.
            max_chunk_size: The maximum length of a chunk in characters.
            model: The name of the language model to use.
        """
        self.max_chunk_size = max_chunk_size
        self.model = model
        self.llm_client = llm_client

    def _split_text(self, text: str) -> List[str]:
        """Splits text into evenly sized chunks of a maximum size, trying to respect sentence and paragraph boundaries."""
        num_chunks = math.ceil(len(text) / self.max_chunk_size)
        if num_chunks == 1:
            return [text]

        ideal_chunk_size = math.ceil(len(text) / num_chunks)

        chunks = []
        current_pos = 0
        while current_pos < len(text):
            end_pos = min(current_pos + ideal_chunk_size, len(text))

            # Find a good split point around the end_pos.
            split_point = -1
            if end_pos < len(text):
                paragraph_break = text.rfind("\n\n", current_pos, end_pos)
                if paragraph_break != -1:
                    split_point = paragraph_break + 2
                else:
                    sentence_break = text.rfind(". ", current_pos, end_pos)
                    if sentence_break != -1:
                        split_point = sentence_break + 1
                    else:
                        split_point = end_pos
            else:
                split_point = end_pos

            chunks.append(text[current_pos:split_point])
            current_pos = split_point

        return chunks

    def process_text(self, text: str) -> List[Document]:
        """
        Processes a string of text into a list of context-aware Document chunks.
        """
        if len(text) <= self.max_chunk_size:
            return [{"page_content": text, "metadata": {}}]

        chunks = self._split_text(text)
        processed_chunks: List[Document] = []

        for i, chunk_content in enumerate(chunks):
            prompt = f"""
Documento Original:
---
{text}
---

Fragmento Actual:
---
{chunk_content}
---

Tarea:
Genera un resumen conciso del "Documento Original" que proporcione el contexto necesario para entender el "Fragmento Actual". El resumen debe ser un solo párrafo en español.
"""

            summary = self.llm_client.generate(self.model, prompt).text
            contextualized_chunk = (
                f"> **Contexto del documento original:**\n> {summary}\n\n---\n\n"
                + chunk_content
            )

            processed_chunks.append(
                {
                    "page_content": contextualized_chunk,
                    "metadata": {"chunk_index": i},
                }
            )

        return processed_chunks


app = typer.Typer()


@app.command()
def main(
    input_file_path: Annotated[
        str, typer.Argument(help="Path to the input text file.")
    ],
    output_dir: Annotated[
        str, typer.Argument(help="Directory to save the output file.")
    ],
    max_chunk_size: Annotated[
        int, typer.Option(help="Maximum chunk size in characters.")
    ] = 800,
    model: Annotated[
        str, typer.Option(help="Model to use for the processing")
    ] = "gemini-2.0-flash",
):
    """
    Processes a text file using ContextualChunker and saves the output to a JSONL file.
    """
    print(f"Starting to process {input_file_path}...")

    # llm_client is required by ContextualChunker; construct the project's
    # Vertex AI client here (assumes VertexAILLM builds with its defaults).
    chunker = ContextualChunker(
        llm_client=VertexAILLM(), max_chunk_size=max_chunk_size, model=model
    )
    documents = chunker.process_path(Path(input_file_path))

    print(f"Successfully created {len(documents)} chunks.")

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"Created output directory: {output_dir}")

    output_file_path = os.path.join(output_dir, "chunked_documents.jsonl")

    with open(output_file_path, "w", encoding="utf-8") as f:
        for doc in documents:
            doc["metadata"]["source_file"] = os.path.basename(input_file_path)
            f.write(json.dumps(doc, ensure_ascii=False) + "\n")

    print(f"Successfully saved {len(documents)} chunks to {output_file_path}")


if __name__ == "__main__":
    app()
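The `_split_text` strategy above divides the text into the minimum number of chunks and spreads the length evenly across them, rather than greedily filling each chunk to `max_chunk_size` and leaving a small remainder at the end. The sizing math can be sketched in isolation (hypothetical standalone helper, mirroring only the arithmetic, not the sentence-boundary logic):

```python
import math


def even_chunk_sizes(text_len: int, max_chunk_size: int) -> list[int]:
    """Return target chunk sizes that spread text_len evenly across the
    minimum number of chunks that respects max_chunk_size."""
    num_chunks = math.ceil(text_len / max_chunk_size)
    ideal = math.ceil(text_len / num_chunks)
    sizes = [ideal] * (num_chunks - 1)
    sizes.append(text_len - ideal * (num_chunks - 1))  # remainder chunk
    return sizes


# 2000 chars with an 800-char cap become three chunks of ~667 chars,
# not 800 + 800 + 400.
print(even_chunk_sizes(2000, 800))  # → [667, 667, 666]
```

The payoff is that no chunk ends up much shorter than the others, which keeps the LLM-generated context summaries attached to comparably sized fragments.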
577
packages/chunker/src/chunker/llm_chunker.py
Normal file
@@ -0,0 +1,577 @@
import hashlib
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Annotated, List

import tiktoken
import typer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document as LangchainDocument
from llm.vertex_ai import VertexAILLM
from pdf2image import convert_from_path
from pypdf import PdfReader

from rag_eval.config import Settings

from .base_chunker import BaseChunker, Document


class TokenManager:
    """Manages token counting and truncation."""

    def __init__(self, model_name: str = "gpt-3.5-turbo"):
        try:
            self.encoding = tiktoken.encoding_for_model(model_name)
        except KeyError:
            self.encoding = tiktoken.get_encoding("cl100k_base")

    def count_tokens(self, text: str) -> int:
        return len(self.encoding.encode(text))

    def truncate_to_tokens(
        self, text: str, max_tokens: int, preserve_sentences: bool = True
    ) -> str:
        tokens = self.encoding.encode(text)

        if len(tokens) <= max_tokens:
            return text

        truncated_tokens = tokens[:max_tokens]
        truncated_text = self.encoding.decode(truncated_tokens)

        if preserve_sentences:
            last_period = truncated_text.rfind(".")
            if last_period > len(truncated_text) * 0.7:
                return truncated_text[: last_period + 1]

        return truncated_text


class OptimizedChunkProcessor:
    """Uses an LLM to merge and enhance text chunks."""

    def __init__(
        self,
        model: str,
        max_tokens: int = 1000,
        target_tokens: int = 800,
        chunks_per_batch: int = 5,
        gemini_client: VertexAILLM | None = None,
        model_name: str = "gpt-3.5-turbo",
        custom_instructions: str = "",
    ):
        self.model = model
        self.client = gemini_client
        self.chunks_per_batch = chunks_per_batch
        self.max_tokens = max_tokens
        self.target_tokens = target_tokens
        self.token_manager = TokenManager(model_name)
        self.custom_instructions = custom_instructions
        self._merge_cache = {}
        self._enhance_cache = {}

    def _get_cache_key(self, text: str) -> str:
        combined = text + self.custom_instructions
        return hashlib.md5(combined.encode()).hexdigest()[:16]

    def should_merge_chunks(self, chunk1: str, chunk2: str) -> bool:
        cache_key = f"{self._get_cache_key(chunk1)}_{self._get_cache_key(chunk2)}"
        if cache_key in self._merge_cache:
            return self._merge_cache[cache_key]

        try:
            combined_text = f"{chunk1}\n\n{chunk2}"
            combined_tokens = self.token_manager.count_tokens(combined_text)

            if combined_tokens > self.max_tokens:
                self._merge_cache[cache_key] = False
                return False

            if self.client:
                base_prompt = f"""Analiza estos dos fragmentos de texto y determina si deben unirse.

LÍMITES ESTRICTOS:
- Tokens combinados: {combined_tokens}/{self.max_tokens}
- Solo unir si hay continuidad semántica clara

Criterios de unión:
1. El primer fragmento termina abruptamente
2. El segundo fragmento continúa la misma idea/concepto
3. La unión mejora la coherencia del contenido
4. Exceder {self.max_tokens} tokens, SOLAMENTE si es necesario para mantener el contexto"""

                base_prompt += f"""

Responde SOLO 'SI' o 'NO'.

Fragmento 1 ({self.token_manager.count_tokens(chunk1)} tokens):
{chunk1[:500]}...

Fragmento 2 ({self.token_manager.count_tokens(chunk2)} tokens):
{chunk2[:500]}..."""

                response = self.client.generate(self.model, base_prompt).text
                result = response.strip().upper() == "SI"
                self._merge_cache[cache_key] = result
                return result

            # Heuristic fallback when no LLM client is configured.
            result = (
                not chunk1.rstrip().endswith((".", "!", "?"))
                and combined_tokens <= self.target_tokens
            )
            self._merge_cache[cache_key] = result
            return result

        except Exception as e:
            print(f"Error analizando chunks para merge: {e}")
            self._merge_cache[cache_key] = False
            return False

    def enhance_chunk(self, chunk_text: str) -> str:
        cache_key = self._get_cache_key(chunk_text)
        if cache_key in self._enhance_cache:
            return self._enhance_cache[cache_key]

        current_tokens = self.token_manager.count_tokens(chunk_text)

        try:
            if self.client and current_tokens < self.max_tokens:
                base_prompt = f"""Optimiza este texto siguiendo estas reglas ESTRICTAS:

LÍMITES DE TOKENS:
- Actual: {current_tokens} tokens
- Máximo permitido: {self.max_tokens} tokens
- Objetivo: {self.target_tokens} tokens

REGLAS FUNDAMENTALES:
NO exceder {self.max_tokens} tokens bajo ninguna circunstancia
Mantener TODA la información esencial y metadatos
NO cambiar términos técnicos o palabras clave
Asegurar oraciones completas y coherentes
Optimizar claridad y estructura sin añadir contenido
SOLO devuelve el texto no agregues conclusiones NUNCA

Si el texto está cerca del límite, NO expandir. Solo mejorar estructura."""

                if self.custom_instructions.strip():
                    base_prompt += (
                        f"\n\nINSTRUCCIONES ADICIONALES:\n{self.custom_instructions}"
                    )

                base_prompt += f"\n\nTexto a optimizar:\n{chunk_text}"

                response = self.client.generate(self.model, base_prompt).text
                enhanced_text = response.strip()

                enhanced_tokens = self.token_manager.count_tokens(enhanced_text)
                if enhanced_tokens > self.max_tokens:
                    print(
                        f"Advertencia: Texto optimizado excede límite ({enhanced_tokens} > {self.max_tokens})"
                    )
                    enhanced_text = self.token_manager.truncate_to_tokens(
                        enhanced_text, self.max_tokens
                    )

                self._enhance_cache[cache_key] = enhanced_text
                return enhanced_text
            else:
                if current_tokens > self.max_tokens:
                    truncated = self.token_manager.truncate_to_tokens(
                        chunk_text, self.max_tokens
                    )
                    self._enhance_cache[cache_key] = truncated
                    return truncated

                self._enhance_cache[cache_key] = chunk_text
                return chunk_text

        except Exception as e:
            print(f"Error procesando chunk: {e}")
            if current_tokens > self.max_tokens:
                truncated = self.token_manager.truncate_to_tokens(
                    chunk_text, self.max_tokens
                )
                self._enhance_cache[cache_key] = truncated
                return truncated

            self._enhance_cache[cache_key] = chunk_text
            return chunk_text

    def process_chunks_batch(
        self, chunks: List[LangchainDocument], merge_related: bool = False
    ) -> List[LangchainDocument]:
        processed_chunks = []
        total_chunks = len(chunks)

        print(f"Procesando {total_chunks} chunks en lotes de {self.chunks_per_batch}")
        if self.custom_instructions:
            print(
                f"Con instrucciones personalizadas: {self.custom_instructions[:100]}..."
            )

        i = 0
        while i < len(chunks):
            batch_start = time.time()
            current_chunk = chunks[i]
            merged_content = current_chunk.page_content
            original_tokens = self.token_manager.count_tokens(merged_content)

            if merge_related and i < len(chunks) - 1:
                merge_count = 0
                while i + merge_count < len(chunks) - 1 and self.should_merge_chunks(
                    merged_content, chunks[i + merge_count + 1].page_content
                ):
                    merge_count += 1
                    merged_content += "\n\n" + chunks[i + merge_count].page_content
                    print(f" Uniendo chunk {i + 1} con chunk {i + merge_count + 1}")

                i += merge_count

            print(f"\nProcesando chunk {i + 1}/{total_chunks}")
            print(f" Tokens originales: {original_tokens}")

            enhanced_content = self.enhance_chunk(merged_content)
            final_tokens = self.token_manager.count_tokens(enhanced_content)

            processed_chunks.append(
                LangchainDocument(
                    page_content=enhanced_content,
                    metadata={
                        **current_chunk.metadata,
                        "final_tokens": final_tokens,
                    },
                )
            )

            print(f" Tokens finales: {final_tokens}")
            print(f" Tiempo de procesamiento: {time.time() - batch_start:.2f}s")

            i += 1

            if i % self.chunks_per_batch == 0 and i < len(chunks):
                print(f"\nCompletados {i}/{total_chunks} chunks")
                time.sleep(0.1)

        return processed_chunks


class LLMChunker(BaseChunker):
    """Implements a chunker that uses an LLM to optimize PDF and text content."""

    def __init__(
        self,
        output_dir: str,
        model: str,
        max_tokens: int = 1000,
        target_tokens: int = 800,
        gemini_client: VertexAILLM | None = None,
        custom_instructions: str = "",
        extract_images: bool = True,
        max_workers: int = 4,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        merge_related: bool = True,
    ):
        self.output_dir = output_dir
        self.model = model
        self.client = gemini_client
        self.max_workers = max_workers
        self.token_manager = TokenManager()
        self.custom_instructions = custom_instructions
        self.extract_images = extract_images
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.merge_related = merge_related
        self._format_cache = {}

        self.chunk_processor = OptimizedChunkProcessor(
            model=self.model,
            max_tokens=max_tokens,
            target_tokens=target_tokens,
            gemini_client=gemini_client,
            custom_instructions=custom_instructions,
        )

    def process_text(self, text: str) -> List[Document]:
        """Processes raw text using the LLM optimizer."""
        print("\n=== Iniciando procesamiento de texto ===")
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=self.token_manager.count_tokens,
            separators=["\n\n", "\n", ". ", " ", ""],
        )
        # Create LangchainDocuments for compatibility with process_chunks_batch.
        langchain_docs = text_splitter.create_documents([text])

        processed_docs = self.chunk_processor.process_chunks_batch(
            langchain_docs, self.merge_related
        )

        # Convert from LangchainDocument to our Document TypedDict.
        final_documents: List[Document] = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in processed_docs
        ]
        print(
            f"\n=== Procesamiento de texto completado: {len(final_documents)} chunks creados ==="
        )
        return final_documents

    def process_path(self, path: Path) -> List[Document]:
        """Processes a PDF file, extracts text and images, and optimizes chunks."""
        overall_start = time.time()
        print(f"\n=== Iniciando procesamiento optimizado de PDF: {path.name} ===")
        # ... (rest of the logic from process_pdf_optimized)
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        print("\n1. Creando chunks del PDF...")
        chunks = self._create_optimized_chunks(
            str(path), self.chunk_size, self.chunk_overlap
        )
        print(f" Total chunks creados: {len(chunks)}")

        pages_to_extract = set()
        if self.extract_images:
            print("\n2. Detectando formatos especiales...")
            format_results = self.detect_special_format_batch(chunks)
            for i, has_special_format in format_results.items():
                if has_special_format:
                    page_number = chunks[i].metadata.get("page")
                    if page_number:
                        pages_to_extract.add(page_number)
            print(f" Páginas con formato especial: {sorted(pages_to_extract)}")

        if self.extract_images and pages_to_extract:
            print(f"\n3. Extrayendo {len(pages_to_extract)} páginas como imágenes...")
            self._extract_pages_parallel(str(path), self.output_dir, pages_to_extract)

        print("\n4. Procesando y optimizando chunks...")
        processed_chunks = self.chunk_processor.process_chunks_batch(
            chunks, self.merge_related
        )

        if self.extract_images:
            final_chunks = self._add_image_references(
                processed_chunks, pages_to_extract, str(path), self.output_dir
            )
        else:
            final_chunks = processed_chunks

        total_time = time.time() - overall_start
        print(f"\n=== Procesamiento completado en {total_time:.2f}s ===")

        # Convert from LangchainDocument to our Document TypedDict.
        final_documents: List[Document] = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in final_chunks
        ]
        return final_documents

    def detect_special_format_batch(
        self, chunks: List[LangchainDocument]
    ) -> dict[int, bool]:
        results = {}
        chunks_to_process = []
        for i, chunk in enumerate(chunks):
            cache_key = hashlib.md5(chunk.page_content.encode()).hexdigest()[:16]
            if cache_key in self._format_cache:
                results[i] = self._format_cache[cache_key]
            else:
                chunks_to_process.append((i, chunk, cache_key))

        if not chunks_to_process:
            return results

        if self.client and len(chunks_to_process) > 1:
            with ThreadPoolExecutor(
                max_workers=min(self.max_workers, len(chunks_to_process))
            ) as executor:
                futures = {
                    executor.submit(self._detect_single_format, chunk): (i, cache_key)
                    for i, chunk, cache_key in chunks_to_process
                }
                for future in futures:
                    i, cache_key = futures[future]
                    try:
                        result = future.result()
                        results[i] = result
                        self._format_cache[cache_key] = result
                    except Exception as e:
                        print(f"Error procesando chunk {i}: {e}")
                        results[i] = False
        else:
            for i, chunk, cache_key in chunks_to_process:
                result = self._detect_single_format(chunk)
                results[i] = result
                self._format_cache[cache_key] = result
        return results

    def _detect_single_format(self, chunk: LangchainDocument) -> bool:
        if not self.client:
            content = chunk.page_content
            table_indicators = ["│", "├", "┼", "┤", "┬", "┴", "|", "+", "-"]
            has_table_chars = any(char in content for char in table_indicators)
            has_multiple_columns = content.count("\t") > 10 or content.count("  ") > 20
            return has_table_chars or has_multiple_columns
        try:
            prompt = f"""¿Contiene este texto tablas estructuradas, diagramas ASCII, o elementos que requieren formato especial?

Responde SOLO 'SI' o 'NO'.

Texto:
{chunk.page_content[:1000]}"""
            response = self.client.generate(self.model, prompt).text
            return response.strip().upper() == "SI"
        except Exception as e:
            print(f"Error detectando formato: {e}")
            return False

    def _create_optimized_chunks(
        self, pdf_path: str, chunk_size: int, chunk_overlap: int
    ) -> List[LangchainDocument]:
        pdf = PdfReader(pdf_path)
        chunks = []
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=self.token_manager.count_tokens,
            separators=["\n\n", "\n", ". ", " ", ""],
        )
        for page_num, page in enumerate(pdf.pages, 1):
            text = page.extract_text()
            if text.strip():
                page_chunks = text_splitter.create_documents(
                    [text],
                    metadatas=[
                        {
                            "page": page_num,
                            "file_name": os.path.basename(pdf_path),
                        }
                    ],
                )
                chunks.extend(page_chunks)
        return chunks

    def _extract_pages_parallel(self, pdf_path: str, output_dir: str, pages: set):
        def extract_single_page(page_number):
|
||||
try:
|
||||
pdf_filename = os.path.basename(pdf_path)
|
||||
image_path = os.path.join(
|
||||
output_dir, f"{page_number}_{pdf_filename}.png"
|
||||
)
|
||||
images = convert_from_path(
|
||||
pdf_path,
|
||||
first_page=page_number,
|
||||
last_page=page_number,
|
||||
dpi=150,
|
||||
thread_count=1,
|
||||
grayscale=False,
|
||||
)
|
||||
if images:
|
||||
images[0].save(image_path, "PNG", optimize=True)
|
||||
except Exception as e:
|
||||
print(f" Error extrayendo página {page_number}: {e}")
|
||||
|
||||
with ThreadPoolExecutor(
|
||||
max_workers=min(self.max_workers, len(pages))
|
||||
) as executor:
|
||||
futures = [executor.submit(extract_single_page, page) for page in pages]
|
||||
for future in futures:
|
||||
future.result() # Wait for completion
|
||||
|
||||
def _add_image_references(
|
||||
self,
|
||||
chunks: List[LangchainDocument],
|
||||
pages_to_extract: set,
|
||||
pdf_path: str,
|
||||
output_dir: str,
|
||||
) -> List[LangchainDocument]:
|
||||
pdf_filename = os.path.basename(pdf_path)
|
||||
for chunk in chunks:
|
||||
page_number = chunk.metadata.get("page")
|
||||
if page_number in pages_to_extract:
|
||||
image_path = os.path.join(
|
||||
output_dir, f"page_{page_number}_{pdf_filename}.png"
|
||||
)
|
||||
if os.path.exists(image_path):
|
||||
image_reference = (
|
||||
f"\n[IMAGEN DISPONIBLE - Página {page_number}: {image_path}]\n"
|
||||
)
|
||||
chunk.page_content = image_reference + chunk.page_content
|
||||
chunk.metadata["has_image"] = True
|
||||
chunk.metadata["image_path"] = image_path
|
||||
return chunks
|
||||
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
pdf_path: Annotated[str, typer.Argument(help="Ruta al archivo PDF")],
|
||||
output_dir: Annotated[
|
||||
str, typer.Argument(help="Directorio de salida para imágenes y chunks")
|
||||
],
|
||||
model: Annotated[
|
||||
str, typer.Option(help="Modelo a usar para el procesamiento")
|
||||
] = "gemini-2.0-flash",
|
||||
max_tokens: Annotated[
|
||||
int, typer.Option(help="Límite máximo de tokens por chunk")
|
||||
] = 950,
|
||||
target_tokens: Annotated[
|
||||
int, typer.Option(help="Tokens objetivo para optimización")
|
||||
] = 800,
|
||||
chunk_size: Annotated[int, typer.Option(help="Tamaño base de chunks")] = 1000,
|
||||
chunk_overlap: Annotated[int, typer.Option(help="Solapamiento entre chunks")] = 200,
|
||||
merge_related: Annotated[
|
||||
bool, typer.Option(help="Si unir chunks relacionados")
|
||||
] = True,
|
||||
custom_instructions: Annotated[
|
||||
str, typer.Option(help="Instrucciones adicionales para optimización")
|
||||
] = "",
|
||||
extract_images: Annotated[
|
||||
bool,
|
||||
typer.Option(help="Si True, extrae páginas con formato especial como imágenes"),
|
||||
] = True,
|
||||
):
|
||||
"""
|
||||
Función principal para procesar PDFs con control completo de tokens.
|
||||
"""
|
||||
settings = Settings()
|
||||
llm = VertexAILLM(
|
||||
project=settings.project_id,
|
||||
location=settings.location,
|
||||
)
|
||||
|
||||
chunker = LLMChunker(
|
||||
output_dir=output_dir,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
target_tokens=target_tokens,
|
||||
gemini_client=llm,
|
||||
custom_instructions=custom_instructions,
|
||||
extract_images=extract_images,
|
||||
max_workers=4,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
merge_related=merge_related,
|
||||
)
|
||||
|
||||
documents = chunker.process_path(Path(pdf_path))
|
||||
print(f"Processed {len(documents)} documents.")
|
||||
|
||||
output_file_path = os.path.join(output_dir, "chunked_documents.jsonl")
|
||||
with open(output_file_path, "w", encoding="utf-8") as f:
|
||||
for doc in documents:
|
||||
f.write(json.dumps(doc, ensure_ascii=False) + "\n")
|
||||
|
||||
print(f"Saved {len(documents)} documents to {output_file_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
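The offline fallback in `_detect_single_format` can be exercised on its own. This is a minimal restatement of that heuristic (the standalone function name is mine, and treating runs of two spaces as the column signal is an assumption about the intended check):

```python
def looks_like_special_format(content: str) -> bool:
    # Box-drawing / table characters suggest an ASCII table or diagram
    table_indicators = ["│", "├", "┼", "┤", "┬", "┴", "|", "+", "-"]
    has_table_chars = any(char in content for char in table_indicators)
    # Many tabs or double spaces suggest multi-column layout
    has_multiple_columns = content.count("\t") > 10 or content.count("  ") > 20
    return has_table_chars or has_multiple_columns
```

A chunk containing a Markdown-style table trips the indicator check; plain prose does not.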
0
packages/chunker/src/chunker/py.typed
Normal file
80
packages/chunker/src/chunker/recursive_chunker.py
Normal file
@@ -0,0 +1,80 @@
import json
import os
from pathlib import Path
from typing import Annotated, List

import chonkie
import typer

from .base_chunker import BaseChunker, Document


class RecursiveChunker(BaseChunker):
    """A chunker that uses the chonkie RecursiveChunker."""

    def __init__(self) -> None:
        """Initializes the RecursiveChunker."""
        self.processor = chonkie.RecursiveChunker()

    def process_text(self, text: str) -> List[Document]:
        """
        Processes a string of text into a list of Document chunks.

        Args:
            text: The input string to process.

        Returns:
            A list of Document objects.
        """
        chunks = self.processor(text)
        documents: List[Document] = []
        for i, chunk in enumerate(chunks):
            doc: Document = {
                "page_content": chunk.text,
                "metadata": {"chunk_index": i},
            }
            documents.append(doc)
        return documents


app = typer.Typer()


@app.command()
def main(
    input_file_path: Annotated[
        str, typer.Argument(help="Path to the input text file.")
    ],
    output_dir: Annotated[
        str, typer.Argument(help="Directory to save the output file.")
    ],
):
    """
    Processes a text file using RecursiveChunker and saves the output to a JSONL file.
    """
    print(f"Starting to process {input_file_path}...")

    # 1. Instantiate chunker and process the file using the inherited method
    chunker = RecursiveChunker()
    documents = chunker.process_path(Path(input_file_path))

    print(f"Successfully created {len(documents)} chunks.")

    # 2. Prepare and save the output
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"Created output directory: {output_dir}")

    output_file_path = os.path.join(output_dir, "chunked_documents.jsonl")

    with open(output_file_path, "w", encoding="utf-8") as f:
        for doc in documents:
            # Add source file info to metadata before writing
            doc["metadata"]["source_file"] = os.path.basename(input_file_path)
            f.write(json.dumps(doc, ensure_ascii=False) + "\n")

    print(f"Successfully saved {len(documents)} chunks to {output_file_path}")


if __name__ == "__main__":
    app()
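Both chunker CLIs emit the same `chunked_documents.jsonl` shape: one JSON object per line with `page_content` and `metadata` keys. A hypothetical round-trip, using only the standard library, shows how a consumer can read the file back:

```python
import json
import tempfile
from pathlib import Path

# Toy documents in the same shape the chunkers write
docs = [
    {"page_content": "hello", "metadata": {"chunk_index": 0}},
    {"page_content": "world", "metadata": {"chunk_index": 1}},
]

with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "chunked_documents.jsonl"
    # Write one JSON object per line, keeping non-ASCII characters as-is
    with out.open("w", encoding="utf-8") as f:
        for doc in docs:
            f.write(json.dumps(doc, ensure_ascii=False) + "\n")

    # Read the file back line by line
    loaded = [
        json.loads(line)
        for line in out.read_text(encoding="utf-8").splitlines()
    ]
```

The round-trip preserves the documents exactly.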
0
packages/dialogflow/README.md
Normal file
21
packages/dialogflow/pyproject.toml
Normal file
@@ -0,0 +1,21 @@
[project]
name = "dialogflow"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
authors = [
    { name = "Anibal Angulo", email = "a8065384@banorte.com" }
]
requires-python = ">=3.12"
dependencies = [
    "google-cloud-dialogflow-cx",
    "typer",
    "rich"
]

[project.scripts]
conv-agents = "dialogflow.cli:app"

[build-system]
requires = ["uv_build>=0.8.12,<0.9.0"]
build-backend = "uv_build"
2
packages/dialogflow/src/dialogflow/__init__.py
Normal file
@@ -0,0 +1,2 @@
def hello() -> str:
    return "Hello from dialogflow!"
54
packages/dialogflow/src/dialogflow/cli.py
Normal file
@@ -0,0 +1,54 @@
import typer
from rich import print

from .main import DialogflowAgent

app = typer.Typer()


@app.command()
def send(
    message: str = typer.Argument(..., help="The message to send to the agent."),
    flow_id: str = typer.Option(None, "--flow-id", "-f", help="The specific flow ID to target."),
):
    """
    Sends a message to the Dialogflow CX agent and prints the response.
    """
    try:
        agent = DialogflowAgent()
        response = agent.call(query=message, flow_id=flow_id)
        print(f"Agent: {response['response_text']}")
        print(f"[dim]Match Type: {response['match_type']} | Details: {response['details']}[/dim]")
    except Exception as e:
        print(f"[bold red]Error:[/bold red] {e}")
        raise typer.Exit(code=1)


@app.command()
def chat(
    flow_id: str = typer.Option(None, "--flow-id", "-f", help="The specific flow ID to start the conversation with."),
):
    """
    Starts a multi-turn, stateful conversation with the Dialogflow CX agent.
    """
    try:
        agent = DialogflowAgent()
        print("[bold green]Starting a new conversation. Type 'exit' or 'quit' to end.[/bold green]")

        is_first_message = True
        while True:
            message = input("You: ")
            if message.lower() in ["exit", "quit"]:
                print("[bold yellow]Ending conversation.[/bold yellow]")
                break

            # Only use the flow_id for the very first message in the chat session
            response = agent.call(query=message, flow_id=flow_id if is_first_message else None)
            is_first_message = False

            print(f"Agent: {response['response_text']}")
            print(f"[dim]Match Type: {response['match_type']} | Details: {response['details']}[/dim]")
    except Exception as e:
        print(f"[bold red]Error:[/bold red] {e}")
        raise typer.Exit(code=1)


if __name__ == "__main__":
    app()
80
packages/dialogflow/src/dialogflow/main.py
Normal file
@@ -0,0 +1,80 @@
import uuid

from google.api_core import client_options
from google.cloud import dialogflowcx_v3beta1 as dialogflow

from rag_eval.config import settings


class DialogflowAgent:
    """A class to interact with a Dialogflow CX agent."""

    def __init__(self):
        """
        Initializes the DialogflowAgent.
        """
        self.agent_path = (
            f"projects/{settings.project_id}/locations/{settings.location}"
            f"/agents/{settings.dialogflow_agent_id}"
        )

        api_endpoint = f"{settings.location}-dialogflow.googleapis.com"
        client_opts = client_options.ClientOptions(api_endpoint=api_endpoint)

        self.sessions_client = dialogflow.SessionsClient(client_options=client_opts)
        self.session_id = str(uuid.uuid4())

    def call(self, query: str, flow_id: str | None = None, session_id: str | None = None) -> dict:
        """
        Sends a message to the Dialogflow CX agent and gets the response.

        Args:
            query: The message to send to the agent.
            flow_id: The specific flow to target within the agent.
            session_id: The session ID to use for the conversation. If not
                provided, the instance's default session ID is used.

        Returns:
            A dictionary containing the agent's response and other details.
        """
        current_session_id = session_id or self.session_id
        session_path = f"{self.agent_path}/sessions/{current_session_id}"

        text_input = dialogflow.TextInput(text=query)
        query_input = dialogflow.QueryInput(text=text_input, language_code="ES")

        request = dialogflow.DetectIntentRequest(
            session=session_path,
            query_input=query_input,
        )

        if flow_id:
            flow_path = f"{self.agent_path}/flows/{flow_id}"
            request.query_params = dialogflow.QueryParameters(flow=flow_path)

        response = self.sessions_client.detect_intent(request=request)
        query_result = response.query_result

        response_messages = [
            " ".join(msg.text.text) for msg in query_result.response_messages if msg.text
        ]
        response_text = " ".join(response_messages)

        match_type = query_result.match.match_type

        details = {
            "response_text": response_text,
            "match_type": match_type.name,
            "details": "N/A",
        }

        if match_type == dialogflow.Match.MatchType.PLAYBOOK:
            playbook = (
                query_result.generative_info.current_playbooks[0].split("/")[-1]
                if query_result.generative_info.current_playbooks
                else "N/A"
            )
            details["details"] = f"Playbook: {playbook}"
        elif match_type == dialogflow.Match.MatchType.INTENT:
            flow = query_result.current_flow.display_name if query_result.current_flow else "N/A"
            page = query_result.current_page.display_name if query_result.current_page else "N/A"
            intent = query_result.intent.display_name if query_result.intent else "N/A"
            details["details"] = f"Flow: {flow} | Page: {page} | Intent: {intent}"

        return details
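`DialogflowAgent` hinges on building the CX resource paths correctly. A standalone sketch of that string construction (the function name and IDs below are placeholders, not real project values):

```python
def build_session_path(project: str, location: str, agent_id: str, session_id: str) -> str:
    # Mirrors agent_path + "/sessions/<id>" as assembled in DialogflowAgent
    agent_path = f"projects/{project}/locations/{location}/agents/{agent_id}"
    return f"{agent_path}/sessions/{session_id}"

path = build_session_path("my-project", "us-central1", "agent-123", "sess-1")
# → "projects/my-project/locations/us-central1/agents/agent-123/sessions/sess-1"
```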
0
packages/dialogflow/src/dialogflow/py.typed
Normal file
1
packages/document-converter/.python-version
Normal file
@@ -0,0 +1 @@
3.10
0
packages/document-converter/README.md
Normal file
20
packages/document-converter/pyproject.toml
Normal file
@@ -0,0 +1,20 @@
[project]
name = "document-converter"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
authors = [
    { name = "Anibal Angulo", email = "a8065384@banorte.com" }
]
requires-python = ">=3.12"
dependencies = [
    "markitdown[pdf]>=0.1.2",
    "pypdf>=6.1.2",
]

[project.scripts]
convert-md = "document_converter.markdown:app"

[build-system]
requires = ["uv_build>=0.8.3,<0.9.0"]
build-backend = "uv_build"
@@ -0,0 +1,2 @@
def hello() -> str:
    return "Hello from document-converter!"
35
packages/document-converter/src/document_converter/base.py
Normal file
@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod
from typing import List


class BaseConverter(ABC):
    """
    Abstract base class for a document converter.

    This class defines the interface for converting individual files and
    batches of files into another format.
    """

    @abstractmethod
    def process_file(self, file: str) -> str:
        """
        Processes a single file and returns the result.

        Args:
            file: The path to the file to be processed.

        Returns:
            A string containing the processing result for the file.
        """
        ...

    def process_files(self, files: List[str]) -> List[str]:
        """
        Processes a list of files and returns the results.

        Args:
            files: A list of file paths to be processed.

        Returns:
            A list of strings containing the processing results for each file.
        """
        return [self.process_file(file) for file in files]
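The `BaseConverter` pattern is small enough to demonstrate end to end. A self-contained mirror of the contract with a toy subclass (`UpperCaseConverter` is illustrative only, not part of the package):

```python
from abc import ABC, abstractmethod
from typing import List


class BaseConverter(ABC):
    @abstractmethod
    def process_file(self, file: str) -> str:
        """Process one file; subclasses define the conversion."""
        ...

    def process_files(self, files: List[str]) -> List[str]:
        # The batch method is just a map over the single-file method
        return [self.process_file(file) for file in files]


class UpperCaseConverter(BaseConverter):
    def process_file(self, file: str) -> str:
        return file.upper()


results = UpperCaseConverter().process_files(["a.pdf", "b.pdf"])
# → ["A.PDF", "B.PDF"]
```

Only `process_file` needs to be overridden; the batch behavior comes for free.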
131
packages/document-converter/src/document_converter/markdown.py
Normal file
@@ -0,0 +1,131 @@
from pathlib import Path
from typing import Annotated, BinaryIO, Union

import typer
from markitdown import MarkItDown
from rich.console import Console
from rich.progress import Progress

from .base import BaseConverter


class MarkdownConverter(BaseConverter):
    """Converts PDF documents to Markdown format."""

    def __init__(self) -> None:
        """Initializes the MarkItDown converter."""
        self.markitdown = MarkItDown(enable_plugins=False)

    def process_file(self, file_stream: Union[str, Path, BinaryIO]) -> str:
        """
        Processes a single file and returns the result as a markdown string.

        Args:
            file_stream: A file path (string or Path) or a binary file stream.

        Returns:
            The converted markdown content as a string.
        """
        result = self.markitdown.convert(file_stream)
        return result.text_content


# --- CLI Application ---

app = typer.Typer()


@app.command()
def main(
    input_path: Annotated[
        Path,
        typer.Argument(
            help="Path to the input PDF file or directory.",
            exists=True,
            file_okay=True,
            dir_okay=True,
            readable=True,
            resolve_path=True,
        ),
    ],
    output_path: Annotated[
        Path,
        typer.Argument(
            help="Path for the output Markdown file or directory.",
            file_okay=True,
            dir_okay=True,
            writable=True,
            resolve_path=True,
        ),
    ],
):
    """
    Converts a PDF file or a directory of PDF files into Markdown.
    """
    console = Console()
    converter = MarkdownConverter()

    if input_path.is_dir():
        # --- Directory Processing ---
        console.print(f"[bold green]Processing directory:[/bold green] {input_path}")
        output_dir = output_path

        if output_dir.exists() and not output_dir.is_dir():
            console.print(
                f"[bold red]Error:[/bold red] Input is a directory, but output path '{output_dir}' is an existing file."
            )
            raise typer.Exit(code=1)

        pdf_files = sorted(list(input_path.rglob("*.pdf")))
        if not pdf_files:
            console.print("[yellow]No PDF files found in the input directory.[/yellow]")
            return

        console.print(f"Found {len(pdf_files)} PDF files to convert.")
        output_dir.mkdir(parents=True, exist_ok=True)

        with Progress(console=console) as progress:
            task = progress.add_task("[cyan]Converting...", total=len(pdf_files))
            for pdf_file in pdf_files:
                relative_path = pdf_file.relative_to(input_path)
                output_md_path = output_dir.joinpath(relative_path).with_suffix(".md")
                output_md_path.parent.mkdir(parents=True, exist_ok=True)

                progress.update(task, description=f"Processing {pdf_file.name}")
                try:
                    markdown_content = converter.process_file(pdf_file)
                    output_md_path.write_text(markdown_content, encoding="utf-8")
                except Exception as e:
                    console.print(
                        f"\n[bold red]Failed to process {pdf_file.name}:[/bold red] {e}"
                    )
                progress.advance(task)

        console.print(
            f"[bold green]Conversion complete.[/bold green] Output directory: {output_dir}"
        )

    elif input_path.is_file():
        # --- Single File Processing ---
        console.print(f"[bold green]Processing file:[/bold green] {input_path.name}")
        final_output_path = output_path

        # If output path is a directory, create a file inside it
        if output_path.is_dir():
            final_output_path = output_path / input_path.with_suffix(".md").name

        final_output_path.parent.mkdir(parents=True, exist_ok=True)

        try:
            markdown_content = converter.process_file(input_path)
            final_output_path.write_text(markdown_content, encoding="utf-8")
            console.print(
                f"[bold green]Successfully converted file to:[/bold green] {final_output_path}"
            )
        except Exception as e:
            console.print(f"[bold red]Error processing file:[/bold red] {e}")
            raise typer.Exit(code=1)


if __name__ == "__main__":
    app()
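The directory branch of the converter CLI maps each PDF to a Markdown destination while preserving the relative layout. That path arithmetic, isolated as a sketch (the helper name and the `/in`, `/out` paths are examples of mine):

```python
from pathlib import Path


def output_md_path(input_dir: Path, output_dir: Path, pdf_file: Path) -> Path:
    # Keep the PDF's position relative to the input root,
    # then re-root it under the output directory with a .md suffix
    relative = pdf_file.relative_to(input_dir)
    return output_dir.joinpath(relative).with_suffix(".md")


dest = output_md_path(Path("/in"), Path("/out"), Path("/in/reports/q1.pdf"))
```

`/in/reports/q1.pdf` lands at `/out/reports/q1.md`, so nested directories survive the conversion.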
1
packages/embedder/.python-version
Normal file
@@ -0,0 +1 @@
3.10
0
packages/embedder/README.md
Normal file
16
packages/embedder/pyproject.toml
Normal file
@@ -0,0 +1,16 @@
[project]
name = "embedder"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
authors = [
    { name = "Anibal Angulo", email = "a8065384@banorte.com" }
]
requires-python = ">=3.12"
dependencies = [
    "google-cloud-aiplatform>=1.106.0",
]

[build-system]
requires = ["uv_build>=0.8.3,<0.9.0"]
build-backend = "uv_build"
0
packages/embedder/src/embedder/__init__.py
Normal file
79
packages/embedder/src/embedder/base.py
Normal file
@@ -0,0 +1,79 @@
from abc import ABC, abstractmethod
from typing import List

import numpy as np


class BaseEmbedder(ABC):
    """Base class for all embedding models."""

    @abstractmethod
    def generate_embedding(self, text: str) -> List[float]:
        """
        Generate an embedding for a single text.

        Args:
            text: The text to embed.

        Returns:
            The embedding vector.
        """
        pass

    @abstractmethod
    def generate_embeddings_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for a batch of texts.

        Args:
            texts: List of text strings.

        Returns:
            List of embedding vectors.
        """
        pass

    def preprocess_text(
        self,
        text: str,
        *,
        into_lowercase: bool = False,
        normalize_whitespace: bool = True,
        remove_punctuation: bool = False,
    ) -> str:
        """Preprocess text before embedding."""
        # Basic preprocessing
        text = text.strip()

        if into_lowercase:
            text = text.lower()

        if normalize_whitespace:
            text = " ".join(text.split())

        if remove_punctuation:
            import string

            text = text.translate(str.maketrans("", "", string.punctuation))

        return text

    def normalize_embedding(self, embedding: List[float]) -> List[float]:
        """Normalize embedding vector to unit length."""
        norm = np.linalg.norm(embedding)
        if norm > 0:
            return (np.array(embedding) / norm).tolist()
        return embedding

    @abstractmethod
    async def async_generate_embedding(self, text: str) -> List[float]:
        """
        Asynchronously generate an embedding for a single text.

        Args:
            text: The text to embed.

        Returns:
            The embedding vector.
        """
        pass
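The two concrete helpers on `BaseEmbedder` are easy to check in isolation. A pure-Python restatement (the real class uses numpy for the norm; the standalone function names are mine):

```python
import math


def preprocess_text(text: str, *, into_lowercase: bool = False,
                    normalize_whitespace: bool = True) -> str:
    # Strip, optionally lowercase, and collapse runs of whitespace
    text = text.strip()
    if into_lowercase:
        text = text.lower()
    if normalize_whitespace:
        text = " ".join(text.split())
    return text


def normalize_embedding(embedding: list[float]) -> list[float]:
    # Scale the vector to unit length; leave the zero vector untouched
    norm = math.sqrt(sum(v * v for v in embedding))
    return [v / norm for v in embedding] if norm > 0 else embedding


print(preprocess_text("  Hello   World ", into_lowercase=True))
# → "hello world"
```

Normalizing `[3.0, 4.0]` (norm 5) yields `[0.6, 0.8]`, and the zero vector passes through unchanged.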
0
packages/embedder/src/embedder/py.typed
Normal file
77
packages/embedder/src/embedder/vertex_ai.py
Normal file
@@ -0,0 +1,77 @@
import logging
import time
from typing import List

from google import genai
from google.genai import types
from tenacity import retry, stop_after_attempt, wait_exponential

from .base import BaseEmbedder

logger = logging.getLogger(__name__)


class VertexAIEmbedder(BaseEmbedder):
    """Embedder using Vertex AI text embedding models."""

    def __init__(
        self, model_name: str, project: str, location: str, task: str = "RETRIEVAL_DOCUMENT"
    ) -> None:
        self.model_name = model_name
        self.client = genai.Client(
            vertexai=True,
            project=project,
            location=location,
        )
        self.task = task

    # @retry(
    #     stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=30)
    # )
    def generate_embedding(self, text: str) -> List[float]:
        preprocessed_text = self.preprocess_text(text)
        result = self.client.models.embed_content(
            model=self.model_name,
            contents=preprocessed_text,
            config=types.EmbedContentConfig(task_type=self.task),
        )
        return result.embeddings[0].values

    # @retry(
    #     stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=30)
    # )
    def generate_embeddings_batch(
        self, texts: List[str], batch_size: int = 10
    ) -> List[List[float]]:
        """Generate embeddings for a batch of texts."""
        if not texts:
            return []

        # Preprocess texts
        preprocessed_texts = [self.preprocess_text(text) for text in texts]

        # Process in batches if necessary
        all_embeddings = []

        for i in range(0, len(preprocessed_texts), batch_size):
            batch = preprocessed_texts[i : i + batch_size]

            # Generate embeddings for batch
            result = self.client.models.embed_content(
                model=self.model_name,
                contents=batch,
                config=types.EmbedContentConfig(task_type=self.task),
            )

            # Extract values
            batch_embeddings = [emb.values for emb in result.embeddings]
            all_embeddings.extend(batch_embeddings)

            # Rate limiting
            if i + batch_size < len(preprocessed_texts):
                time.sleep(0.1)  # Small delay between batches

        return all_embeddings

    async def async_generate_embedding(self, text: str) -> List[float]:
        preprocessed_text = self.preprocess_text(text)
        result = await self.client.aio.models.embed_content(
            model=self.model_name,
            contents=preprocessed_text,
            config=types.EmbedContentConfig(task_type=self.task),
        )
        return result.embeddings[0].values
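The batch loop in `generate_embeddings_batch` is a standard fixed-size slicing pattern. Isolated as a sketch (the generator name is mine, not part of the embedder API):

```python
from typing import Iterator, List


def batched(items: List[str], batch_size: int) -> Iterator[List[str]]:
    # Yield consecutive slices of at most batch_size items
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]


batches = list(batched(["a", "b", "c", "d", "e"], 2))
# → [["a", "b"], ["c", "d"], ["e"]]
```

The final batch may be short, which is why the code extracts `result.embeddings` per batch rather than assuming a fixed count.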
1
packages/file-storage/.python-version
Normal file
@@ -0,0 +1 @@
3.10
0
packages/file-storage/README.md
Normal file
22
packages/file-storage/pyproject.toml
Normal file
@@ -0,0 +1,22 @@
[project]
name = "file-storage"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
authors = [
    { name = "Anibal Angulo", email = "a8065384@banorte.com" }
]
requires-python = ">=3.12"
dependencies = [
    "gcloud-aio-storage>=9.6.1",
    "google-cloud-storage>=2.19.0",
    "aiohttp>=3.10.11,<4",
    "typer>=0.12.3",
]

[project.scripts]
file-storage = "file_storage.cli:app"

[build-system]
requires = ["uv_build>=0.8.3,<0.9.0"]
build-backend = "uv_build"
2
packages/file-storage/src/file_storage/__init__.py
Normal file
@@ -0,0 +1,2 @@
def hello() -> str:
    return "Hello from file-storage!"
48
packages/file-storage/src/file_storage/base.py
Normal file
@@ -0,0 +1,48 @@
from abc import ABC, abstractmethod
from typing import BinaryIO, List, Optional


class BaseFileStorage(ABC):
    """
    Abstract base class for a remote file store.

    This class defines the interface for uploading, listing, and retrieving
    files from a remote source.
    """

    @abstractmethod
    def upload_file(
        self,
        file_path: str,
        destination_blob_name: str,
        content_type: Optional[str] = None,
    ) -> None:
        """
        Uploads a file to the remote source.

        Args:
            file_path: The local path to the file to upload.
            destination_blob_name: The name of the file in the remote source.
            content_type: The content type of the file.
        """
        ...

    @abstractmethod
    def list_files(self, path: Optional[str] = None) -> List[str]:
        """
        Lists files from a remote location.

        Args:
            path: The path to a specific file or directory in the remote bucket.
                If None, it recursively lists all files in the bucket.

        Returns:
            A list of file paths.
        """
        ...

    @abstractmethod
    def get_file_stream(self, file_name: str) -> BinaryIO:
        """
        Gets a file from the remote source and returns it as a file-like object.
        """
        ...
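A toy in-memory implementation of the `BaseFileStorage` interface is useful for tests without a real bucket. This sketch is not part of the package; it just satisfies the same three methods:

```python
import io
import os
import tempfile
from typing import BinaryIO, Dict, List, Optional


class InMemoryFileStorage:
    def __init__(self) -> None:
        self._blobs: Dict[str, bytes] = {}

    def upload_file(self, file_path: str, destination_blob_name: str,
                    content_type: Optional[str] = None) -> None:
        # Read the local file and store its bytes under the blob name
        with open(file_path, "rb") as f:
            self._blobs[destination_blob_name] = f.read()

    def list_files(self, path: Optional[str] = None) -> List[str]:
        # Treat path as a prefix; None lists everything
        prefix = path or ""
        return sorted(name for name in self._blobs if name.startswith(prefix))

    def get_file_stream(self, file_name: str) -> BinaryIO:
        return io.BytesIO(self._blobs[file_name])


# Usage: upload a temporary local file as a blob
storage = InMemoryFileStorage()
with tempfile.NamedTemporaryFile("wb", delete=False, suffix=".txt") as tmp:
    tmp.write(b"hello")
    local_path = tmp.name
storage.upload_file(local_path, "docs/a.txt")
os.unlink(local_path)
```

The prefix-based `list_files` mirrors how object stores expose "directories" as name prefixes.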
89
packages/file-storage/src/file_storage/cli.py
Normal file
@@ -0,0 +1,89 @@
import os
from typing import Annotated, Optional

import rich
import typer

from rag_eval.config import settings

from .google_cloud import GoogleCloudFileStorage

app = typer.Typer()


def get_storage_client() -> GoogleCloudFileStorage:
    return GoogleCloudFileStorage(bucket=settings.bucket)


@app.command("upload")
def upload(
    file_path: str,
    destination_blob_name: str,
    content_type: Annotated[Optional[str], typer.Option()] = None,
):
    """
    Uploads a file or directory to the remote source.
    """
    storage_client = get_storage_client()
    if os.path.isdir(file_path):
        for root, _, files in os.walk(file_path):
            for file in files:
                local_file_path = os.path.join(root, file)
                # Preserve the directory structure and use forward slashes for the blob name.
                dest_blob_name = os.path.join(
                    destination_blob_name, os.path.relpath(local_file_path, file_path)
                ).replace(os.sep, "/")
                storage_client.upload_file(
                    local_file_path, dest_blob_name, content_type
                )
                rich.print(
                    f"[green]File {local_file_path} uploaded to {dest_blob_name}.[/green]"
                )
        rich.print(
            f"[bold green]Directory {file_path} uploaded to {destination_blob_name}.[/bold green]"
        )
    else:
        storage_client.upload_file(file_path, destination_blob_name, content_type)
        rich.print(
            f"[green]File {file_path} uploaded to {destination_blob_name}.[/green]"
        )


@app.command("list")
def list_items(path: Annotated[Optional[str], typer.Option()] = None):
    """
    Obtains a list of all files at the given location inside the remote bucket.
    If path is None, recursively shows all files in the remote bucket.
    """
    storage_client = get_storage_client()
    files = storage_client.list_files(path)
    for file in files:
        rich.print(f"[blue]{file}[/blue]")


@app.command("download")
def download(file_name: str, destination_path: str):
    """
    Downloads a file from the remote source and writes it to a local path.
    """
    storage_client = get_storage_client()
    file_stream = storage_client.get_file_stream(file_name)
    with open(destination_path, "wb") as f:
        f.write(file_stream.read())
    rich.print(f"[green]File {file_name} downloaded to {destination_path}[/green]")


@app.command("delete")
def delete(path: str):
    """
    Deletes all files at the given location inside the remote bucket.
    If path is a single file, it will delete only that file.
    If path is a directory, it will delete all files in that directory.
    """
    storage_client = get_storage_client()
    storage_client.delete_files(path)
    rich.print(f"[bold red]Files at {path} deleted.[/bold red]")


if __name__ == "__main__":
    app()
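As a quick sanity check on the directory-upload command above, the blob-name construction (`os.path.join` plus `os.path.relpath`, with separators normalized to forward slashes) behaves as sketched below; the file paths are hypothetical, not taken from the repository:

```python
import os

# Hypothetical local layout: uploading directory "data" under remote prefix "docs".
local_file_path = os.path.join("data", "a", "b.txt")
dest_blob_name = os.path.join(
    "docs", os.path.relpath(local_file_path, "data")
).replace(os.sep, "/")
print(dest_blob_name)  # docs/a/b.txt
```

The `.replace(os.sep, "/")` step matters on Windows, where `os.path.join` would otherwise produce backslash-separated blob names.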
138
packages/file-storage/src/file_storage/google_cloud.py
Normal file
@@ -0,0 +1,138 @@
import asyncio
import io
import logging
from typing import BinaryIO, List, Optional

import aiohttp
from gcloud.aio.storage import Storage
from google.cloud import storage

from .base import BaseFileStorage

logger = logging.getLogger(__name__)


class GoogleCloudFileStorage(BaseFileStorage):
    def __init__(self, bucket: str) -> None:
        self.bucket_name = bucket

        self.storage_client = storage.Client()
        self.bucket_client = self.storage_client.bucket(self.bucket_name)
        self._aio_session: aiohttp.ClientSession | None = None
        self._aio_storage: Storage | None = None
        self._cache: dict[str, bytes] = {}

    def upload_file(
        self,
        file_path: str,
        destination_blob_name: str,
        content_type: Optional[str] = None,
    ) -> None:
        """
        Uploads a file to the remote source.

        Args:
            file_path: The local path to the file to upload.
            destination_blob_name: The name of the file in the remote source.
            content_type: The content type of the file.
        """
        blob = self.bucket_client.blob(destination_blob_name)
        blob.upload_from_filename(
            file_path,
            content_type=content_type,
            if_generation_match=0,
        )
        self._cache.pop(destination_blob_name, None)

    def list_files(self, path: Optional[str] = None) -> List[str]:
        """
        Obtains a list of all files at the given location inside the remote bucket.
        If path is None, recursively lists all files in the remote bucket.
        """
        blobs = self.storage_client.list_blobs(self.bucket_name, prefix=path)
        return [blob.name for blob in blobs]

    def get_file_stream(self, file_name: str) -> BinaryIO:
        """
        Gets a file from the remote source and returns it as a file-like object.
        """
        if file_name not in self._cache:
            blob = self.bucket_client.blob(file_name)
            self._cache[file_name] = blob.download_as_bytes()
        file_stream = io.BytesIO(self._cache[file_name])
        file_stream.name = file_name
        return file_stream

    def _get_aio_session(self) -> aiohttp.ClientSession:
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(limit=300, limit_per_host=50)
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout, connector=connector
            )
        return self._aio_session

    def _get_aio_storage(self) -> Storage:
        if self._aio_storage is None:
            self._aio_storage = Storage(session=self._get_aio_session())
        return self._aio_storage

    async def async_get_file_stream(
        self, file_name: str, max_retries: int = 3
    ) -> BinaryIO:
        """
        Gets a file from the remote source asynchronously and returns it as a file-like object.
        Retries on transient errors (429, 5xx, timeouts) with exponential backoff.
        """
        if file_name in self._cache:
            file_stream = io.BytesIO(self._cache[file_name])
            file_stream.name = file_name
            return file_stream

        storage_client = self._get_aio_storage()
        last_exception: Exception | None = None

        for attempt in range(max_retries):
            try:
                self._cache[file_name] = await storage_client.download(
                    self.bucket_name, file_name
                )
                file_stream = io.BytesIO(self._cache[file_name])
                file_stream.name = file_name
                return file_stream
            except asyncio.TimeoutError as exc:
                last_exception = exc
                logger.warning(
                    "Timeout downloading gs://%s/%s (attempt %d/%d)",
                    self.bucket_name, file_name, attempt + 1, max_retries,
                )
            except aiohttp.ClientResponseError as exc:
                last_exception = exc
                if exc.status == 429 or exc.status >= 500:
                    logger.warning(
                        "HTTP %d downloading gs://%s/%s (attempt %d/%d)",
                        exc.status, self.bucket_name, file_name,
                        attempt + 1, max_retries,
                    )
                else:
                    raise

            if attempt < max_retries - 1:
                delay = 0.5 * (2 ** attempt)
                await asyncio.sleep(delay)

        raise TimeoutError(
            f"Failed to download gs://{self.bucket_name}/{file_name} "
            f"after {max_retries} attempts"
        ) from last_exception

    def delete_files(self, path: str) -> None:
        """
        Deletes all files at the given location inside the remote bucket.
        If path is a single file, it will delete only that file.
        If path is a directory, it will delete all files in that directory.
        """
        blobs = self.storage_client.list_blobs(self.bucket_name, prefix=path)
        for blob in blobs:
            blob.delete()
            self._cache.pop(blob.name, None)
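The retry loop in `async_get_file_stream` sleeps `0.5 * (2 ** attempt)` seconds between attempts, and only between them (the final failed attempt raises without sleeping). With the default `max_retries=3`, the delay schedule works out as this standalone sketch shows:

```python
max_retries = 3  # same default as async_get_file_stream

# Sleeps occur only after non-final attempts, i.e. after attempts 0 and 1.
delays = [0.5 * (2 ** attempt) for attempt in range(max_retries - 1)]
print(delays)  # [0.5, 1.0]
```

So a file that never downloads costs roughly 1.5 seconds of backoff on top of the three 60-second request timeouts configured on the aiohttp session.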
0
packages/file-storage/src/file_storage/py.typed
Normal file
1
packages/llm/.python-version
Normal file
@@ -0,0 +1 @@
3.10
0
packages/llm/README.md
Normal file
18
packages/llm/pyproject.toml
Normal file
@@ -0,0 +1,18 @@
[project]
name = "llm"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
authors = [
    { name = "Anibal Angulo", email = "a8065384@banorte.com" }
]
requires-python = ">=3.12"
dependencies = [
    "google-genai>=1.20.0",
    "pydantic>=2.11.7",
    "tenacity>=9.1.2",
]

[build-system]
requires = ["uv_build>=0.8.3,<0.9.0"]
build-backend = "uv_build"
2
packages/llm/src/llm/__init__.py
Normal file
@@ -0,0 +1,2 @@
def hello() -> str:
    return "Hello from llm!"
128
packages/llm/src/llm/base.py
Normal file
@@ -0,0 +1,128 @@
from abc import ABC, abstractmethod
from typing import Any, Type, TypeVar

from pydantic import BaseModel, field_validator


class ToolCall(BaseModel):
    name: str
    arguments: dict


class Usage(BaseModel):
    prompt_tokens: int | None = 0
    thought_tokens: int | None = 0
    response_tokens: int | None = 0

    @field_validator("prompt_tokens", "thought_tokens", "response_tokens", mode="before")
    @classmethod
    def _validate_tokens(cls, v: int | None) -> int:
        return v or 0

    def __add__(self, other: "Usage") -> "Usage":
        return Usage(
            prompt_tokens=self.prompt_tokens + other.prompt_tokens,
            thought_tokens=self.thought_tokens + other.thought_tokens,
            response_tokens=self.response_tokens + other.response_tokens,
        )

    def get_cost(self, name: str) -> float:
        million = 1_000_000
        if name == "gemini-2.5-pro":
            if self.prompt_tokens > 200_000:
                input_cost = self.prompt_tokens * (2.5 / million)
                output_cost = (self.thought_tokens + self.response_tokens) * (15 / million)
            else:
                input_cost = self.prompt_tokens * (1.25 / million)
                output_cost = (self.thought_tokens + self.response_tokens) * (10 / million)
            return (input_cost + output_cost) * 18.65
        if name == "gemini-2.5-flash":
            input_cost = self.prompt_tokens * (0.30 / million)
            output_cost = (self.thought_tokens + self.response_tokens) * (2.5 / million)
            return (input_cost + output_cost) * 18.65
        raise ValueError(f"Invalid model: {name}")


class Generation(BaseModel):
    """A class to represent a single generation from a model.

    Attributes:
        text: The generated text.
        tool_calls: Any tool calls requested by the model.
        usage: Token usage metadata for the generation.
    """

    text: str | None = None
    tool_calls: list[ToolCall] | None = None
    usage: Usage = Usage()
    extra: dict = {}


T = TypeVar("T", bound=BaseModel)


class BaseLLM(ABC):
    """An abstract base class for all LLMs."""

    @abstractmethod
    def generate(
        self,
        model: str,
        prompt: Any,
        tools: list | None = None,
        system_prompt: str | None = None,
    ) -> Generation:
        """Generates text from a prompt.

        Args:
            model: The model to use for generation.
            prompt: The prompt to generate text from.
            tools: An optional list of tools to use for generation.
            system_prompt: An optional system prompt to guide the model's behavior.

        Returns:
            A Generation object containing the generated text and usage metadata.
        """
        ...

    @abstractmethod
    def structured_generation(
        self,
        model: str,
        prompt: Any,
        response_model: Type[T],
        tools: list | None = None,
    ) -> T:
        """Generates structured data from a prompt.

        Args:
            model: The model to use for generation.
            prompt: The prompt to generate text from.
            response_model: The pydantic model to parse the response into.
            tools: An optional list of tools to use for generation.

        Returns:
            An instance of the provided pydantic model.
        """
        ...

    @abstractmethod
    async def async_generate(
        self,
        model: str,
        prompt: Any,
        tools: list | None = None,
        system_prompt: str | None = None,
        tool_mode: str = "AUTO",
    ) -> Generation:
        """Asynchronously generates text from a prompt.

        Args:
            model: The model to use for generation.
            prompt: The prompt to generate text from.
            tools: An optional list of tools to use for generation.
            system_prompt: An optional system prompt to guide the model's behavior.
            tool_mode: The function-calling mode to use.

        Returns:
            A Generation object containing the generated text and usage metadata.
        """
        ...
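A minimal, self-contained sketch of how `Usage` objects are meant to be priced, re-implemented here without pydantic so it runs standalone; the per-million rates and the 18.65 multiplier are taken verbatim from `get_cost` above, while the token counts are made up:

```python
million = 1_000_000

def flash_cost(prompt_tokens: int, thought_tokens: int, response_tokens: int) -> float:
    # Mirrors the "gemini-2.5-flash" branch of Usage.get_cost.
    input_cost = prompt_tokens * (0.30 / million)
    output_cost = (thought_tokens + response_tokens) * (2.5 / million)
    return (input_cost + output_cost) * 18.65

# One million prompt tokens plus one million response tokens:
total = flash_cost(1_000_000, 0, 1_000_000)
print(round(total, 2))  # 52.22
```

Because `Usage.__add__` sums field-wise, pricing the sum of two usages is equivalent to summing their individual costs for a fixed model.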
0
packages/llm/src/llm/py.typed
Normal file
181
packages/llm/src/llm/vertex_ai.py
Normal file
@@ -0,0 +1,181 @@
import logging
from typing import Any, Type

from google import genai
from google.genai import types
from tenacity import retry, stop_after_attempt, wait_exponential

from rag_eval.config import settings

from .base import BaseLLM, Generation, T, ToolCall, Usage

logger = logging.getLogger(__name__)


class VertexAILLM(BaseLLM):
    """A class for interacting with the Vertex AI API."""

    def __init__(
        self, project: str | None = None, location: str | None = None, thinking: int = 0
    ) -> None:
        """Initializes the VertexAILLM client.

        Args:
            project: The Google Cloud project ID.
            location: The Google Cloud location.
            thinking: The thinking-token budget to request from the model.
        """
        self.client = genai.Client(
            vertexai=True,
            project=project or settings.project_id,
            location=location or settings.location,
        )
        self.thinking_budget = thinking

    # @retry(
    #     wait=wait_exponential(multiplier=1, min=2, max=60),
    #     stop=stop_after_attempt(3),
    #     reraise=True,
    # )
    def generate(
        self,
        model: str,
        prompt: Any,
        tools: list | None = None,
        system_prompt: str | None = None,
        tool_mode: str = "AUTO",
    ) -> Generation:
        """Generates text using the specified model and prompt.

        Args:
            model: The name of the model to use for generation.
            prompt: The prompt to use for generation.
            tools: An optional list of tools to use for generation.
            system_prompt: An optional system prompt to guide the model's behavior.
            tool_mode: The function-calling mode to use.

        Returns:
            A Generation object containing the generated text and usage metadata.
        """
        logger.debug("Entering VertexAILLM.generate")
        logger.debug(f"Model: {model}, Tool Mode: {tool_mode}")
        logger.debug(f"System prompt: {system_prompt}")
        logger.debug("Calling Vertex AI API: models.generate_content...")
        response = self.client.models.generate_content(
            model=model,
            contents=prompt,
            config=types.GenerateContentConfig(
                tools=tools,
                system_instruction=system_prompt,
                thinking_config=types.ThinkingConfig(
                    thinking_budget=self.thinking_budget
                ),
                tool_config=types.ToolConfig(
                    function_calling_config=types.FunctionCallingConfig(
                        mode=tool_mode
                    )
                ),
            ),
        )
        logger.debug("Received response from Vertex AI API.")
        logger.debug(f"API Response: {response}")

        return self._create_generation(response)

    # @retry(
    #     wait=wait_exponential(multiplier=1, min=2, max=60),
    #     stop=stop_after_attempt(3),
    #     reraise=True,
    # )
    def structured_generation(
        self,
        model: str,
        prompt: Any,
        response_model: Type[T],
        system_prompt: str | None = None,
        tools: list | None = None,
    ) -> T:
        """Generates structured data from a prompt.

        Args:
            model: The model to use for generation.
            prompt: The prompt to generate text from.
            response_model: The pydantic model to parse the response into.
            system_prompt: An optional system prompt to guide the model's behavior.
            tools: An optional list of tools to use for generation.

        Returns:
            An instance of the provided pydantic model.
        """
        config = types.GenerateContentConfig(
            response_mime_type="application/json",
            response_schema=response_model,
            system_instruction=system_prompt,
            tools=tools,
        )

        response: types.GenerateContentResponse = (
            self.client.models.generate_content(
                model=model, contents=prompt, config=config
            )
        )

        return response_model.model_validate_json(response.text)

    # @retry(
    #     wait=wait_exponential(multiplier=1, min=2, max=60),
    #     stop=stop_after_attempt(3),
    #     reraise=True,
    # )
    async def async_generate(
        self,
        model: str,
        prompt: Any,
        tools: list | None = None,
        system_prompt: str | None = None,
        tool_mode: str = "AUTO",
    ) -> Generation:
        response = await self.client.aio.models.generate_content(
            model=model,
            contents=prompt,
            config=types.GenerateContentConfig(
                tools=tools,
                system_instruction=system_prompt,
                thinking_config=types.ThinkingConfig(
                    thinking_budget=self.thinking_budget
                ),
                tool_config=types.ToolConfig(
                    function_calling_config=types.FunctionCallingConfig(
                        mode=tool_mode
                    )
                ),
            ),
        )

        return self._create_generation(response)

    def _create_generation(self, response) -> Generation:
        logger.debug("Creating Generation object from API response.")
        m = response.usage_metadata
        usage = Usage(
            prompt_tokens=m.prompt_token_count,
            thought_tokens=m.thoughts_token_count or 0,
            response_tokens=m.candidates_token_count,
        )

        logger.debug(f"{usage=}")
        logger.debug(f"{response=}")

        candidate = response.candidates[0]

        tool_calls = []

        for part in candidate.content.parts:
            if fn := part.function_call:
                tool_calls.append(ToolCall(name=fn.name, arguments=fn.args))

        if len(tool_calls) > 0:
            logger.debug(f"Found {len(tool_calls)} tool calls.")
            return Generation(
                tool_calls=tool_calls,
                usage=usage,
                extra={"original_content": candidate.content},
            )

        logger.debug("No tool calls found, returning text response.")
        text = candidate.content.parts[0].text
        return Generation(text=text, usage=usage)
0
packages/utils/README.md
Normal file
17
packages/utils/pyproject.toml
Normal file
@@ -0,0 +1,17 @@
[project]
name = "utils"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
authors = [
    { name = "Anibal Angulo", email = "a8065384@banorte.com" }
]
requires-python = ">=3.12"
dependencies = []

[project.scripts]
normalize-filenames = "utils.normalize_filenames:app"

[build-system]
requires = ["uv_build>=0.8.3,<0.9.0"]
build-backend = "uv_build"
2
packages/utils/src/utils/__init__.py
Normal file
@@ -0,0 +1,2 @@
def hello() -> str:
    return "Hello from utils!"
115
packages/utils/src/utils/normalize_filenames.py
Normal file
@@ -0,0 +1,115 @@
"""Normalize filenames in a directory."""

import pathlib
import re
import unicodedata

import typer
from rich.console import Console
from rich.panel import Panel
from rich.table import Table

app = typer.Typer()


def normalize_string(s: str) -> str:
    """Normalizes a string to be a valid filename."""
    # 1. Decompose Unicode characters into base characters and diacritics
    nfkd_form = unicodedata.normalize("NFKD", s)
    # 2. Keep only the base characters (non-diacritics)
    only_ascii = "".join([c for c in nfkd_form if not unicodedata.combining(c)])
    # 3. Convert to lowercase
    only_ascii = only_ascii.lower()
    # 4. Replace spaces with underscores
    only_ascii = re.sub(r"\s+", "_", only_ascii)
    # 5. Remove any characters that are not alphanumeric, underscores, dots, or hyphens
    only_ascii = re.sub(r"[^a-z0-9_.-]", "", only_ascii)
    return only_ascii


def truncate_string(s: str) -> str:
    """Given a path-like string, returns only the text after the last slash."""
    return pathlib.Path(s).name


def remove_extension(s: str) -> str:
    """Given a string, removes a trailing extension like .pdf, if present."""
    return str(pathlib.Path(s).with_suffix(""))


def remove_duplicate_vowels(s: str) -> str:
    """Removes consecutive duplicate vowels (a, e, i, o, u) from a string."""
    return re.sub(r"([aeiou])\1+", r"\1", s, flags=re.IGNORECASE)


@app.callback(invoke_without_command=True)
def normalize_filenames(
    directory: str = typer.Argument(
        ..., help="The path to the directory containing files to normalize."
    ),
):
    """Normalizes all filenames in a directory."""
    console = Console()
    console.print(
        Panel(
            f"Normalizing filenames in directory: [bold cyan]{directory}[/bold cyan]",
            title="[bold green]Filename Normalizer[/bold green]",
            expand=False,
        )
    )

    source_path = pathlib.Path(directory)
    if not source_path.is_dir():
        console.print(f"[bold red]Error: Directory not found at {directory}[/bold red]")
        raise typer.Exit(code=1)

    files_to_rename = [p for p in source_path.rglob("*") if p.is_file()]

    if not files_to_rename:
        console.print(
            f"[bold yellow]No files found in {directory} to normalize.[/bold yellow]"
        )
        return

    table = Table(title="File Renaming Summary")
    table.add_column("Original Name", style="cyan", no_wrap=True)
    table.add_column("New Name", style="magenta", no_wrap=True)
    table.add_column("Status", style="green")

    for file_path in files_to_rename:
        original_name = file_path.name
        file_stem = file_path.stem
        file_suffix = file_path.suffix

        normalized_stem = normalize_string(file_stem)
        new_name = f"{normalized_stem}{file_suffix}"

        if new_name == original_name:
            table.add_row(
                original_name, new_name, "[yellow]Skipped (No change)[/yellow]"
            )
            continue

        new_path = file_path.with_name(new_name)

        # Handle potential name collisions
        counter = 1
        while new_path.exists():
            new_name = f"{normalized_stem}_{counter}{file_suffix}"
            new_path = file_path.with_name(new_name)
            counter += 1

        try:
            file_path.rename(new_path)
            table.add_row(original_name, new_name, "[green]Renamed[/green]")
        except OSError as e:
            table.add_row(original_name, new_name, f"[bold red]Error: {e}[/bold red]")

    console.print(table)
    console.print(
        Panel(
            f"[bold]Normalization complete.[/bold] Processed [bold blue]{len(files_to_rename)}[/bold blue] files.",
            title="[bold green]Complete[/bold green]",
            expand=False,
        )
    )
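The normalization pipeline above (NFKD decomposition, diacritic stripping, lowercasing, space-to-underscore, character whitelist) can be sanity-checked in isolation; the sample filename below is made up for illustration:

```python
import re
import unicodedata

def normalize_string(s: str) -> str:
    # Same five steps as utils.normalize_filenames.normalize_string.
    nfkd_form = unicodedata.normalize("NFKD", s)
    only_ascii = "".join(c for c in nfkd_form if not unicodedata.combining(c))
    only_ascii = only_ascii.lower()
    only_ascii = re.sub(r"\s+", "_", only_ascii)
    return re.sub(r"[^a-z0-9_.-]", "", only_ascii)

print(normalize_string("Educación Financiera (v2)"))  # educacion_financiera_v2
```

Accented characters survive as their base letters ("ó" becomes "o") because NFKD splits them into a base character plus a combining mark, and only the mark is discarded.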
0
packages/utils/src/utils/py.typed
Normal file
1
packages/vector-search/.python-version
Normal file
@@ -0,0 +1 @@
3.10
0
packages/vector-search/README.md
Normal file
29
packages/vector-search/pyproject.toml
Normal file
@@ -0,0 +1,29 @@
[project]
name = "vector-search"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
authors = [
    { name = "Anibal Angulo", email = "a8065384@banorte.com" }
]
requires-python = ">=3.12"
dependencies = [
    "embedder",
    "file-storage",
    "google-cloud-aiplatform>=1.106.0",
    "aiohttp>=3.10.11,<4",
    "gcloud-aio-auth>=5.3.0",
    "google-auth==2.29.0",
    "typer>=0.16.1",
]

[project.scripts]
vector-search = "vector_search.cli:app"

[build-system]
requires = ["uv_build>=0.8.3,<0.9.0"]
build-backend = "uv_build"

[tool.uv.sources]
file-storage = { workspace = true }
embedder = { workspace = true }
2
packages/vector-search/src/vector_search/__init__.py
Normal file
@@ -0,0 +1,2 @@
def hello() -> str:
    return "Hello from vector-search!"
62
packages/vector-search/src/vector_search/base.py
Normal file
@@ -0,0 +1,62 @@
from abc import ABC, abstractmethod
from typing import List, TypedDict


class SearchResult(TypedDict):
    id: str
    distance: float
    content: str


class BaseVectorSearch(ABC):
    """
    Abstract base class for a vector search provider.

    This class defines the standard interface for creating a vector search index
    and running queries against it.
    """

    @abstractmethod
    def create_index(self, name: str, content_path: str, **kwargs) -> None:
        """
        Creates a new vector search index and populates it with the provided content.

        Args:
            name: The desired name for the new index.
            content_path: The local file system path to the data that will be used to
                populate the index. This is expected to be a JSON file
                containing a list of objects, each with an 'id', 'name',
                and 'embedding' key.
            **kwargs: Additional provider-specific arguments for index creation.
        """
        ...

    @abstractmethod
    def update_index(self, index_name: str, content_path: str, **kwargs) -> None:
        """
        Updates an existing vector search index with new content.

        Args:
            index_name: The name of the index to update.
            content_path: The local file system path to the data that will be used to
                populate the index.
            **kwargs: Additional provider-specific arguments for index update.
        """
        ...

    @abstractmethod
    def run_query(
        self, index: str, query: List[float], limit: int
    ) -> List[SearchResult]:
        """
        Runs a similarity search query against the index.

        Args:
            index: The name of the index to query.
            query: The embedding vector to use for the search query.
            limit: The maximum number of nearest neighbors to return.

        Returns:
            A list of dictionaries, where each dictionary represents a matched item
            and contains at least the item's 'id' and the search 'distance'.
        """
        ...
10
packages/vector-search/src/vector_search/cli/__init__.py
Normal file
@@ -0,0 +1,10 @@
from typer import Typer

from .create import app as create_callback
from .delete import app as delete_callback
from .query import app as query_callback

app = Typer()
app.add_typer(create_callback, name="create")
app.add_typer(delete_callback, name="delete")
app.add_typer(query_callback, name="query")
91
packages/vector-search/src/vector_search/cli/create.py
Normal file
@@ -0,0 +1,91 @@
"""Create and deploy a Vertex AI Vector Search index."""

from typing import Annotated

import typer
from rich.console import Console

from rag_eval.config import settings as config
from vector_search.vertex_ai import GoogleCloudVectorSearch

app = typer.Typer()


@app.callback(invoke_without_command=True)
def create(
    path: Annotated[
        str,
        typer.Option(
            "--path",
            "-p",
            help="The path within the configured bucket to the directory containing your embedding JSON file(s).",
        ),
    ],
    agent_name: Annotated[
        str,
        typer.Option(
            "--agent",
            "-a",
            help="The name of the agent to create the index for.",
        ),
    ],
):
    """Create and deploy a Vertex AI Vector Search index for a specific agent."""
    console = Console()

    try:
        console.print(
            f"[bold green]Looking up configuration for agent '{agent_name}'...[/bold green]"
        )
        agent_config = config.agents.get(agent_name)
        if not agent_config:
            console.print(
                f"[bold red]Agent '{agent_name}' not found in settings.[/bold red]"
            )
            raise typer.Exit(code=1)

        if not agent_config.index:
            console.print(
                f"[bold red]Index configuration not found for agent '{agent_name}'.[/bold red]"
            )
            raise typer.Exit(code=1)

        index_config = agent_config.index

        console.print(
            f"[bold green]Initializing Vertex AI client for project '{config.project_id}' in '{config.location}'...[/bold green]"
        )
        vector_search = GoogleCloudVectorSearch(
            project_id=config.project_id,
            location=config.location,
            bucket=config.bucket,
            index_name=index_config.name,
        )

        console.print(
            f"[bold green]Starting creation of index '{index_config.name}'...[/bold green]"
        )
        console.print("This may take a while.")
        vector_search.create_index(
            name=index_config.name,
            content_path=f"gs://{config.bucket}/{path}",
            dimensions=index_config.dimensions,
        )
        console.print(
            f"[bold green]Index '{index_config.name}' created successfully.[/bold green]"
        )

        console.print("[bold green]Deploying index to a new endpoint...[/bold green]")
        console.print("This will also take some time.")
        vector_search.deploy_index(
            index_name=index_config.name, machine_type=index_config.machine_type
        )
        console.print("[bold green]Index deployed successfully![/bold green]")
        console.print(f"Endpoint name: {vector_search.index_endpoint.display_name}")
        console.print(
            f"Endpoint resource name: {vector_search.index_endpoint.resource_name}"
        )

    except typer.Exit:
        # Re-raise clean exits so they are not reported as unexpected errors.
        raise
    except Exception as e:
        console.print(f"[bold red]An error occurred: {e}[/bold red]")
        raise typer.Exit(code=1)
38
packages/vector-search/src/vector_search/cli/delete.py
Normal file
@@ -0,0 +1,38 @@
"""Delete a vector index or endpoint."""

import typer
from rich.console import Console

from rag_eval.config import settings as config
from vector_search.vertex_ai import GoogleCloudVectorSearch

app = typer.Typer()


@app.callback(invoke_without_command=True)
def delete(
    id: str = typer.Argument(..., help="The ID of the index or endpoint to delete."),
    endpoint: bool = typer.Option(
        False, "--endpoint", help="Delete an endpoint instead of an index."
    ),
):
    """Delete a vector index or endpoint."""
    console = Console()
    vector_search = GoogleCloudVectorSearch(
        project_id=config.project_id, location=config.location, bucket=config.bucket
    )

    try:
        if endpoint:
            console.print(f"[bold red]Deleting endpoint {id}...[/bold red]")
            vector_search.delete_index_endpoint(id)
            console.print(
                f"[bold green]Endpoint {id} deleted successfully.[/bold green]"
            )
        else:
            console.print(f"[bold red]Deleting index {id}...[/bold red]")
            vector_search.delete_index(id)
            console.print(f"[bold green]Index {id} deleted successfully.[/bold green]")
    except Exception as e:
        console.print(f"[bold red]An error occurred: {e}[/bold red]")
        raise typer.Exit(code=1)
91
packages/vector-search/src/vector_search/cli/generate.py
Normal file
@@ -0,0 +1,91 @@
"""Generate embeddings for documents and save them to a JSON file."""

import json
from pathlib import Path

import typer
from embedder.vertex_ai import VertexAIEmbedder
from file_storage.google_cloud import GoogleCloudFileStorage
from rich.console import Console
from rich.progress import Progress

from rag_eval.config import Settings

app = typer.Typer()


@app.callback(invoke_without_command=True)
def generate(
    path: str = typer.Argument(..., help="The path to the markdown files."),
    output_file: str = typer.Option(
        ...,
        "--output-file",
        "-o",
        help="The local path to save the output JSON file.",
    ),
    batch_size: int = typer.Option(
        10,
        "--batch-size",
        "-b",
        help="The batch size for processing files.",
    ),
    jsonl: bool = typer.Option(
        False,
        "--jsonl",
        help="Output in JSONL format instead of JSON.",
    ),
):
    """Generate embeddings for documents and save them to a JSON file."""
    config = Settings()
    console = Console()

    console.print("[bold green]Starting vector generation...[/bold green]")

    try:
        storage = GoogleCloudFileStorage(bucket=config.bucket)
        embedder = VertexAIEmbedder(model_name=config.embedding_model)

        remote_files = storage.list_files(path=path)
        results = []

        with Progress(console=console) as progress:
            task = progress.add_task(
                "[cyan]Generating embeddings...", total=len(remote_files)
            )

            for i in range(0, len(remote_files), batch_size):
                batch_files = remote_files[i : i + batch_size]
                batch_contents = []

                for remote_file in batch_files:
                    file_stream = storage.get_file_stream(remote_file)
                    batch_contents.append(
                        file_stream.read().decode("utf-8-sig", errors="replace")
                    )

                batch_embeddings = embedder.generate_embeddings_batch(batch_contents)

                for j, remote_file in enumerate(batch_files):
                    results.append(
                        {"id": remote_file, "embedding": batch_embeddings[j]}
                    )
                    progress.update(task, advance=1)

    except Exception as e:
        console.print(
            f"[bold red]An error occurred during vector generation: {e}[/bold red]"
        )
        raise typer.Exit(code=1)

    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        if jsonl:
            for record in results:
                f.write(json.dumps(record) + "\n")
        else:
            json.dump(results, f, indent=2)

    console.print(
        f"[bold green]Embedding generation complete. {len(results)} vectors saved to '{output_path.resolve()}'[/bold green]"
    )
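The records this command writes are simple `{"id": ..., "embedding": [...]}` objects, and the `--jsonl` flag switches the file layout from one JSON array to one object per line. A minimal, self-contained sketch of that serialization step (the sample ids and vectors below are invented for illustration):

```python
import json

# Hypothetical records shaped like the ones `generate` emits:
# a file id plus its embedding vector.
records = [
    {"id": "guides/credit-cards-01", "embedding": [0.12, -0.03, 0.98]},
    {"id": "guides/savings-02", "embedding": [0.44, 0.10, -0.27]},
]


def to_jsonl(records: list) -> str:
    """Serialize records as JSONL: one JSON object per line."""
    return "".join(json.dumps(record) + "\n" for record in records)


jsonl_text = to_jsonl(records)
print(jsonl_text)
```

Each line round-trips back to the original record with `json.loads`, which is why JSONL is convenient for streaming large embedding dumps.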
55
packages/vector-search/src/vector_search/cli/query.py
Normal file
@@ -0,0 +1,55 @@
"""Query the vector search index."""

import typer
from embedder.vertex_ai import VertexAIEmbedder
from rich.console import Console
from rich.table import Table
from typer import Argument, Option

from rag_eval.config import settings as config
from vector_search.vertex_ai import GoogleCloudVectorSearch

app = typer.Typer()


@app.callback(invoke_without_command=True)
def query(
    query: str = Argument(..., help="The text query to search for."),
    limit: int = Option(5, "--limit", "-l", help="The number of results to return."),
):
    """Queries the vector search index."""
    console = Console()

    try:
        console.print("[bold green]Initializing clients...[/bold green]")
        embedder = VertexAIEmbedder(model_name=config.embedding_model)
        vector_search = GoogleCloudVectorSearch(
            project_id=config.project_id, location=config.location, bucket=config.bucket
        )

        console.print("[bold green]Loading index endpoint...[/bold green]")
        vector_search.load_index_endpoint(config.index.endpoint)

        console.print("[bold green]Generating embedding for query...[/bold green]")
        query_embedding = embedder.generate_embedding(query)

        console.print("[bold green]Running search query...[/bold green]")
        search_results = vector_search.run_query(
            deployed_index_id=config.index.deployment,
            query=query_embedding,
            limit=limit,
        )

        table = Table(title="Search Results")
        table.add_column("ID", justify="left", style="cyan")
        table.add_column("Distance", justify="left", style="magenta")
        table.add_column("Content", justify="left", style="green")

        for result in search_results:
            table.add_row(result["id"], str(result["distance"]), result["content"])

        console.print(table)

    except Exception as e:
        console.print(f"[bold red]An error occurred: {e}[/bold red]")
        raise typer.Exit(code=1)
0
packages/vector-search/src/vector_search/py.typed
Normal file
255
packages/vector-search/src/vector_search/vertex_ai.py
Normal file
@@ -0,0 +1,255 @@
import asyncio
from typing import List
from uuid import uuid4

import aiohttp
import google.auth
import google.auth.transport.requests
from file_storage.google_cloud import GoogleCloudFileStorage
from gcloud.aio.auth import Token
from google.cloud import aiplatform

from .base import BaseVectorSearch, SearchResult


class GoogleCloudVectorSearch(BaseVectorSearch):
    """
    A vector search provider that uses Google Cloud's Vertex AI Vector Search.
    """

    def __init__(
        self, project_id: str, location: str, bucket: str, index_name: str | None = None
    ):
        """
        Initializes the GoogleCloudVectorSearch client.

        Args:
            project_id: The Google Cloud project ID.
            location: The Google Cloud location (e.g., 'us-central1').
            bucket: The GCS bucket to use for file storage.
            index_name: The name of the index. If None, it will be taken from settings.
        """
        aiplatform.init(project=project_id, location=location)
        self.project_id = project_id
        self.location = location
        self.storage = GoogleCloudFileStorage(bucket=bucket)
        self.index_name = index_name
        self._credentials = None
        self._aio_session: aiohttp.ClientSession | None = None
        self._async_token: Token | None = None

    def _get_auth_headers(self) -> dict:
        if self._credentials is None:
            self._credentials, _ = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"]
            )
        if not self._credentials.token or self._credentials.expired:
            self._credentials.refresh(google.auth.transport.requests.Request())
        return {
            "Authorization": f"Bearer {self._credentials.token}",
            "Content-Type": "application/json",
        }

    async def _async_get_auth_headers(self) -> dict:
        if self._async_token is None:
            self._async_token = Token(
                session=self._get_aio_session(),
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
        access_token = await self._async_token.get()
        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def _get_aio_session(self) -> aiohttp.ClientSession:
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(limit=300, limit_per_host=50)
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout, connector=connector
            )
        return self._aio_session

    def create_index(
        self,
        name: str,
        content_path: str,
        dimensions: int,
        approximate_neighbors_count: int = 150,
        distance_measure_type: str = "DOT_PRODUCT_DISTANCE",
        **kwargs,
    ) -> None:
        """
        Creates a new Vertex AI Vector Search index.

        Args:
            name: The display name for the new index.
            content_path: The GCS URI to the JSON file containing the embeddings.
            dimensions: The number of dimensions in the embedding vectors.
            approximate_neighbors_count: The number of neighbors to find for each vector.
            distance_measure_type: The distance measure to use (e.g., 'DOT_PRODUCT_DISTANCE').
        """
        index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
            display_name=name,
            contents_delta_uri=content_path,
            dimensions=dimensions,
            approximate_neighbors_count=approximate_neighbors_count,
            distance_measure_type=distance_measure_type,
            leaf_node_embedding_count=1000,
            leaf_nodes_to_search_percent=10,
        )
        self.index = index

    def update_index(self, index_name: str, content_path: str, **kwargs) -> None:
        """
        Updates an existing Vertex AI Vector Search index.

        Args:
            index_name: The resource name of the index to update.
            content_path: The GCS URI to the JSON file containing the new embeddings.
        """
        index = aiplatform.MatchingEngineIndex(index_name=index_name)
        index.update_embeddings(
            contents_delta_uri=content_path,
        )
        self.index = index

    def deploy_index(
        self, index_name: str, machine_type: str = "e2-standard-2"
    ) -> None:
        """
        Deploys a Vertex AI Vector Search index to an endpoint.

        Args:
            index_name: The name of the index to deploy.
            machine_type: The type of machine to use for the endpoint.
        """
        index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
            display_name=f"{index_name}-endpoint",
            public_endpoint_enabled=True,
        )
        index_endpoint.deploy_index(
            index=self.index,
            deployed_index_id=f"{index_name.replace('-', '_')}_deployed_{uuid4().hex}",
            machine_type=machine_type,
        )
        self.index_endpoint = index_endpoint

    def load_index_endpoint(self, endpoint_name: str) -> None:
        """
        Loads an existing Vertex AI Vector Search index endpoint.

        Args:
            endpoint_name: The resource name of the index endpoint.
        """
        self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(endpoint_name)
        if not self.index_endpoint.public_endpoint_domain_name:
            raise ValueError(
                "The index endpoint does not have a public endpoint. "
                "Please ensure that the endpoint is configured for public access."
            )

    def run_query(
        self, deployed_index_id: str, query: List[float], limit: int
    ) -> List[SearchResult]:
        """
        Runs a similarity search query against the deployed index.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector to use for the search query.
            limit: The maximum number of nearest neighbors to return.

        Returns:
            A list of dictionaries representing the matched items.
        """
        response = self.index_endpoint.find_neighbors(
            deployed_index_id=deployed_index_id, queries=[query], num_neighbors=limit
        )
        results = []
        for neighbor in response[0]:
            file_path = self.index_name + "/contents/" + neighbor.id + ".md"
            content = self.storage.get_file_stream(file_path).read().decode("utf-8")
            results.append(
                {"id": neighbor.id, "distance": neighbor.distance, "content": content}
            )
        return results

    async def async_run_query(
        self, deployed_index_id: str, query: List[float], limit: int
    ) -> List[SearchResult]:
        """
        Runs a non-blocking similarity search query against the deployed index
        using the REST API directly with an async HTTP client.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector to use for the search query.
            limit: The maximum number of nearest neighbors to return.

        Returns:
            A list of dictionaries representing the matched items.
        """
        domain = self.index_endpoint.public_endpoint_domain_name
        endpoint_id = self.index_endpoint.name.split("/")[-1]
        url = (
            f"https://{domain}/v1/projects/{self.project_id}"
            f"/locations/{self.location}"
            f"/indexEndpoints/{endpoint_id}:findNeighbors"
        )
        payload = {
            "deployed_index_id": deployed_index_id,
            "queries": [
                {
                    "datapoint": {"feature_vector": query},
                    "neighbor_count": limit,
                }
            ],
        }

        headers = await self._async_get_auth_headers()
        session = self._get_aio_session()
        async with session.post(url, json=payload, headers=headers) as response:
            response.raise_for_status()
            data = await response.json()

        neighbors = data.get("nearestNeighbors", [{}])[0].get("neighbors", [])
        content_tasks = []
        for neighbor in neighbors:
            datapoint_id = neighbor["datapoint"]["datapointId"]
            file_path = f"{self.index_name}/contents/{datapoint_id}.md"
            content_tasks.append(self.storage.async_get_file_stream(file_path))

        file_streams = await asyncio.gather(*content_tasks)
        results: List[SearchResult] = []
        for neighbor, stream in zip(neighbors, file_streams):
            results.append(
                {
                    "id": neighbor["datapoint"]["datapointId"],
                    "distance": neighbor["distance"],
                    "content": stream.read().decode("utf-8"),
                }
            )
        return results

    def delete_index(self, index_name: str) -> None:
        """
        Deletes a Vertex AI Vector Search index.

        Args:
            index_name: The resource name of the index.
        """
        index = aiplatform.MatchingEngineIndex(index_name)
        index.delete()

    def delete_index_endpoint(self, index_endpoint_name: str) -> None:
        """
        Deletes a Vertex AI Vector Search index endpoint.

        Args:
            index_endpoint_name: The resource name of the index endpoint.
        """
        index_endpoint = aiplatform.MatchingEngineIndexEndpoint(index_endpoint_name)
        index_endpoint.undeploy_all()
        index_endpoint.delete(force=True)
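The REST call inside `async_run_query` can be factored into a pure request builder, which makes the `:findNeighbors` URL and payload shape easy to check without touching GCP. A sketch mirroring the code above (the sample domain and resource names below are invented):

```python
from typing import List


def build_find_neighbors_request(
    domain: str,
    project_id: str,
    location: str,
    endpoint_resource_name: str,
    deployed_index_id: str,
    query: List[float],
    limit: int,
):
    """Build the (url, payload) pair that async_run_query POSTs."""
    # The endpoint resource name ends in the numeric endpoint ID.
    endpoint_id = endpoint_resource_name.split("/")[-1]
    url = (
        f"https://{domain}/v1/projects/{project_id}"
        f"/locations/{location}"
        f"/indexEndpoints/{endpoint_id}:findNeighbors"
    )
    payload = {
        "deployed_index_id": deployed_index_id,
        "queries": [
            {"datapoint": {"feature_vector": query}, "neighbor_count": limit}
        ],
    }
    return url, payload


url, payload = build_find_neighbors_request(
    domain="1234.us-central1-987.vdb.vertexai.goog",  # invented sample values
    project_id="my-project",
    location="us-central1",
    endpoint_resource_name="projects/987/locations/us-central1/indexEndpoints/555",
    deployed_index_id="my_index_deployed_abc",
    query=[0.1, 0.2],
    limit=5,
)
print(url)
```

Keeping the builder separate from the `aiohttp` call would also let unit tests cover the request shape while mocking only the HTTP layer.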
79
pyproject.toml
Normal file
@@ -0,0 +1,79 @@
[project]
name = "rag-eval"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
authors = [
    { name = "Anibal Angulo", email = "a8065384@banorte.com" }
]
requires-python = "~=3.12.0"
dependencies = [
    "google-genai==1.45.0",
    "pip>=25.3",
    "pydantic-settings[yaml]>=2.10.1",
    "python-dotenv>=1.0.0",
]

[project.scripts]
ragops = "rag_eval.cli:app"

[build-system]
requires = ["uv_build>=0.8.3,<0.9.0"]
build-backend = "uv_build"

[tool.uv.workspace]
members = [
    "apps/*",
    "packages/*",
]

[tool.uv.sources]
document-converter = { workspace = true }
search-eval = { workspace = true }
pdf-ingest = { workspace = true }
llm = { workspace = true }
embedder = { workspace = true }
file-storage = { workspace = true }
utils = { workspace = true }
vector-search = { workspace = true }
synth-gen = { workspace = true }
keypoint-eval = { workspace = true }
chunker = { workspace = true }
index-gen = { workspace = true }
dialogflow = { workspace = true }
integration-layer = { workspace = true }

[dependency-groups]
dev = [
    "dialogflow",
    "integration-layer",
    "ipykernel>=6.30.1",
    "mypy>=1.17.1",
    "pytest>=8.4.1",
    "ruff>=0.12.10",
    "ty>=0.0.1a19",
    "vector-search",
]
processor = [
    "index-gen",
]
evals = [
    "keypoint-eval",
    "search-eval",
    "synth-gen",
]
rag = [
    "embedder",
    "fastapi[standard]>=0.116.1",
    "file-storage",
    "llm",
    "numpy>=2.3.5",
    "structlog>=25.5.0",
    "vector-search",
]
pipeline = [
    "kfp>=2.15.2",
]

[tool.ruff.lint]
extend-select = ["I", "F"]
80
scripts/diagnose_embeddings.py
Normal file
@@ -0,0 +1,80 @@
import asyncio
import logging
import os
import random

import typer
from dotenv import load_dotenv
from embedder.vertex_ai import VertexAIEmbedder

load_dotenv()
project = os.getenv("GOOGLE_CLOUD_PROJECT")
location = os.getenv("GOOGLE_CLOUD_LOCATION")

MODEL_NAME = "gemini-embedding-001"
CONTENT_LIST = [
    "¿Cuáles son los beneficios de una tarjeta de crédito?",
    "¿Cómo puedo abrir una cuenta de ahorros?",
    "¿Qué es una hipoteca y cómo funciona?",
    "¿Cuáles son las tasas de interés para un préstamo personal?",
    "¿Cómo puedo solicitar un préstamo para un coche?",
    "¿Qué es la banca en línea y cómo me registro?",
    "¿Cómo puedo reportar una tarjeta de crédito perdida o robada?",
    "¿Qué es el phishing y cómo puedo protegerme?",
    "¿Cuáles son los diferentes tipos de cuentas corrientes que ofrecen?",
    "¿Cómo puedo transferir dinero a una cuenta internacional?",
]
TASK_TYPE = "RETRIEVAL_DOCUMENT"

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

app = typer.Typer()

logger.info(f"Initializing GenAI Client for project '{project}' in '{location}'")
embedder = VertexAIEmbedder(MODEL_NAME, project, location)


async def embed_content_task():
    """A single task to send one embedding request using the global client."""
    content_to_embed = random.choice(CONTENT_LIST)
    await embedder.async_generate_embedding(content_to_embed)


async def run_test(concurrency: int):
    """Continuously calls the embedding API and tracks requests."""
    total_requests = 0

    logger.info(
        f"Starting diagnostic test with {concurrency} concurrent requests "
        f"on model '{MODEL_NAME}'."
    )
    logger.info("Press Ctrl+C to stop.")

    while True:
        tasks = [embed_content_task() for _ in range(concurrency)]

        try:
            await asyncio.gather(*tasks)
            total_requests += concurrency
            logger.info(
                f"Successfully completed batch. Total requests so far: {total_requests}"
            )
        except Exception as e:
            logger.error("Caught an error. Stopping test.")
            print("\n--- STATS ---")
            print(f"Total successful requests: {total_requests}")
            print(f"Concurrent requests during failure: {concurrency}")
            print(f"Error Type: {e.__class__.__name__}")
            print(f"Error Details: {e}")
            print("-------------")
            break


@app.command()
def main(
    concurrency: int = typer.Option(
        10,
        "--concurrency",
        "-c",
        help="Number of concurrent requests to send in each batch.",
    ),
):
    try:
        asyncio.run(run_test(concurrency))
    except KeyboardInterrupt:
        logger.info("Keyboard interrupt received. Exiting.")


if __name__ == "__main__":
    app()
98
scripts/diagnose_rag_endpoint.py
Normal file
@@ -0,0 +1,98 @@
import asyncio
import logging
import random

import httpx
import typer

CONTENT_LIST = [
    "¿Cuáles son los beneficios de una tarjeta de crédito?",
    "¿Cómo puedo abrir una cuenta de ahorros?",
    "¿Qué es una hipoteca y cómo funciona?",
    "¿Cuáles son las tasas de interés para un préstamo personal?",
    "¿Cómo puedo solicitar un préstamo para un coche?",
    "¿Qué es la banca en línea y cómo me registro?",
    "¿Cómo puedo reportar una tarjeta de crédito perdida o robada?",
    "¿Qué es el phishing y cómo puedo protegerme?",
    "¿Cuáles son los diferentes tipos de cuentas corrientes que ofrecen?",
    "¿Cómo puedo transferir dinero a una cuenta internacional?",
]

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

app = typer.Typer()


async def call_rag_endpoint_task(client: httpx.AsyncClient, url: str):
    """A single task to send one request to the RAG endpoint."""
    question = random.choice(CONTENT_LIST)
    json_payload = {"sessionInfo": {"parameters": {"query": question}}}
    response = await client.post(url, json=json_payload)
    response.raise_for_status()  # Raise an exception for bad status codes
    response_data = response.json()
    response_text = response_data["sessionInfo"]["parameters"]["response"]
    logger.info(f"Question: {question[:50]}... Response: {response_text[:100]}...")


async def run_test(concurrency: int, url: str, timeout_seconds: float):
    """Continuously calls the RAG endpoint and tracks requests."""
    total_requests = 0

    logger.info(
        f"Starting diagnostic test with {concurrency} concurrent requests "
        f"on endpoint '{url}'."
    )
    logger.info(f"Request timeout is set to {timeout_seconds} seconds.")
    logger.info("Press Ctrl+C to stop.")

    timeout = httpx.Timeout(timeout_seconds)
    async with httpx.AsyncClient(timeout=timeout) as client:
        while True:
            tasks = [call_rag_endpoint_task(client, url) for _ in range(concurrency)]

            try:
                await asyncio.gather(*tasks)
                total_requests += concurrency
                logger.info(
                    f"Successfully completed batch. Total requests so far: {total_requests}"
                )
            except httpx.TimeoutException as e:
                logger.error(f"A request timed out: {e.request.method} {e.request.url}")
                logger.error("Consider increasing the timeout with the --timeout option.")
                break
            except httpx.HTTPStatusError as e:
                logger.error(
                    f"An HTTP error occurred: {e.response.status_code} - "
                    f"{e.request.method} {e.request.url}"
                )
                logger.error(f"Response body: {e.response.text}")
                break
            except httpx.RequestError as e:
                logger.error(f"A request error occurred: {e.request.method} {e.request.url}")
                logger.error(f"Error details: {e}")
                break
            except Exception as e:
                logger.error("Caught an unexpected error. Stopping test.")
                print("\n--- STATS ---")
                print(f"Total successful requests: {total_requests}")
                print(f"Concurrent requests during failure: {concurrency}")
                print(f"Error Type: {e.__class__.__name__}")
                print(f"Error Details: {e}")
                print("-------------")
                break


@app.command()
def main(
    concurrency: int = typer.Option(
        10,
        "--concurrency",
        "-c",
        help="Number of concurrent requests to send in each batch.",
    ),
    url: str = typer.Option(
        "http://127.0.0.1:8000/sigma-rag",
        "--url",
        "-u",
        help="The URL of the RAG endpoint to test.",
    ),
    timeout_seconds: float = typer.Option(
        30.0, "--timeout", "-t", help="Request timeout in seconds."
    ),
):
    try:
        asyncio.run(run_test(concurrency, url, timeout_seconds))
    except KeyboardInterrupt:
        logger.info("Keyboard interrupt received. Exiting.")


if __name__ == "__main__":
    app()
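The diagnostic wraps each question in a Dialogflow-webhook-style body (`sessionInfo.parameters.query` in, `sessionInfo.parameters.response` out). That shape can be captured as two small helpers, sketched here with an invented question and a fake response body shaped like the one the script parses:

```python
def build_webhook_payload(question: str) -> dict:
    """Wrap a question the way the diagnostic script sends it."""
    return {"sessionInfo": {"parameters": {"query": question}}}


def extract_response_text(response_data: dict) -> str:
    """Pull the answer text out of the webhook-style response body."""
    return response_data["sessionInfo"]["parameters"]["response"]


payload = build_webhook_payload("¿Qué es una hipoteca?")
# A fake response body with the same shape the script expects back:
fake_response = {"sessionInfo": {"parameters": {"response": "Una hipoteca es..."}}}
print(extract_response_text(fake_response))
```

Centralizing the payload shape like this keeps the request builder, the stress test, and the server handler from drifting apart.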
91
scripts/stress_test.py
Normal file
@@ -0,0 +1,91 @@
import concurrent.futures
import random
import threading

import requests

# URL for the endpoint
url = "http://localhost:8000/sigma-rag"

# List of Spanish banking questions
spanish_questions = [
    "¿Cuáles son los beneficios de una tarjeta de crédito?",
    "¿Cómo puedo abrir una cuenta de ahorros?",
    "¿Qué es una hipoteca y cómo funciona?",
    "¿Cuáles son las tasas de interés para un préstamo personal?",
    "¿Cómo puedo solicitar un préstamo para un coche?",
    "¿Qué es la banca en línea y cómo me registro?",
    "¿Cómo puedo reportar una tarjeta de crédito perdida o robada?",
    "¿Qué es el phishing y cómo puedo protegerme?",
    "¿Cuáles son los diferentes tipos de cuentas corrientes que ofrecen?",
    "¿Cómo puedo transferir dinero a una cuenta internacional?",
]

# A threading Event to signal all threads to stop
stop_event = threading.Event()


def send_request(question, request_id):
    """Sends a single request and handles the response."""
    if stop_event.is_set():
        return

    data = {"sessionInfo": {"parameters": {"query": question}}}
    try:
        response = requests.post(url, json=data)

        if stop_event.is_set():
            return

        if response.status_code == 500:
            print(f"Request {request_id}: Received 500 error with question: '{question}'.")
            print("Stopping stress test.")
            stop_event.set()
        else:
            print(f"Request {request_id}: Successful with status code {response.status_code}.")

    except requests.exceptions.RequestException as e:
        if not stop_event.is_set():
            print(f"Request {request_id}: An error occurred: {e}")
            stop_event.set()


def main():
    """Runs the stress test with parallel requests."""
    num_workers = 30  # Number of parallel requests
    print(f"Starting stress test with {num_workers} parallel workers. Press Ctrl+C to stop.")

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {
            executor.submit(send_request, random.choice(spanish_questions), i)
            for i in range(1, num_workers + 1)
        }
        request_id_counter = num_workers + 1

        try:
            while not stop_event.is_set():
                # Wait for any future to complete
                done, _ = concurrent.futures.wait(
                    futures, return_when=concurrent.futures.FIRST_COMPLETED
                )

                for future in done:
                    # Remove the completed future
                    futures.remove(future)

                    # If we are not stopping, submit a new one
                    if not stop_event.is_set():
                        futures.add(
                            executor.submit(
                                send_request,
                                random.choice(spanish_questions),
                                request_id_counter,
                            )
                        )
                        request_id_counter += 1
        except KeyboardInterrupt:
            print("\nKeyboard interrupt received. Stopping threads.")
            stop_event.set()

    print("Stress test finished.")


if __name__ == "__main__":
    main()
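The loop in `main()` above keeps a fixed window of futures in flight, replacing each one as it completes. The same rolling-window pattern can be shown with a trivial local task in place of the HTTP call (`fake_request` and the budget of 10 tasks are invented for the sketch):

```python
import concurrent.futures


def fake_request(request_id: int) -> int:
    """Stand-in for send_request: trivial local work instead of an HTTP call."""
    return request_id * 2


def run_rolling_window(total: int, num_workers: int = 4) -> list:
    """Keep num_workers tasks in flight until `total` tasks have finished."""
    results = []
    next_id = num_workers + 1
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {executor.submit(fake_request, i) for i in range(1, num_workers + 1)}
        while futures:
            done, _ = concurrent.futures.wait(
                futures, return_when=concurrent.futures.FIRST_COMPLETED
            )
            for future in done:
                futures.remove(future)
                results.append(future.result())
                if next_id <= total:  # refill the window until the budget is spent
                    futures.add(executor.submit(fake_request, next_id))
                    next_id += 1
    return results


results = run_rolling_window(total=10, num_workers=4)
print(sorted(results))  # → [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
```

Unlike `executor.map`, this shape gives the stress test a place to inspect each completed future and stop refilling the window as soon as a failure is observed.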
84
scripts/submit_pipeline.py
Normal file
@@ -0,0 +1,84 @@
import typer
from google.cloud import aiplatform
from typing_extensions import Annotated

from rag_eval.config import settings

app = typer.Typer()


@app.command()
def main(
    pipeline_spec_path: Annotated[
        str,
        typer.Option(
            "--pipeline-spec-path",
            "-p",
            help="Path to the compiled pipeline YAML file.",
        ),
    ],
    input_table: Annotated[
        str,
        typer.Option(
            "--input-table",
            "-i",
            help="Full BigQuery table name for input (e.g., 'project.dataset.table')",
        ),
    ],
    output_table: Annotated[
        str,
        typer.Option(
            "--output-table",
            "-o",
            help="Full BigQuery table name for output (e.g., 'project.dataset.table')",
        ),
    ],
    project_id: Annotated[
        str,
        typer.Option(
            "--project-id",
            help="Google Cloud project ID.",
        ),
    ] = settings.project_id,
    location: Annotated[
        str,
        typer.Option(
            "--location",
            help="Google Cloud location for the pipeline job.",
        ),
    ] = settings.location,
    display_name: Annotated[
        str,
        typer.Option(
            "--display-name",
            help="Display name for the pipeline job.",
        ),
    ] = "search-eval-pipeline-job",
):
    """Submits a Vertex AI pipeline job."""

    parameter_values = {
        "project_id": project_id,
        "location": location,
        "input_table": input_table,
        "output_table": output_table,
    }

    job = aiplatform.PipelineJob(
        display_name=display_name,
        template_path=pipeline_spec_path,
        pipeline_root=f"gs://{settings.bucket}/pipeline_root",
        parameter_values=parameter_values,
        project=project_id,
        location=location,
    )

    print(f"Submitting pipeline job with parameters: {parameter_values}")
    job.submit(
        service_account="sa-cicd-gitlab@bnt-orquestador-cognitivo-dev.iam.gserviceaccount.com"
    )
    print(f"Pipeline job submitted. You can view it at: {job._dashboard_uri()}")


if __name__ == "__main__":
    app()
Some files were not shown because too many files have changed in this diff