Merge pull request 'Code shred' (#4) from v2 into master
Reviewed-on: #4
This commit was merged in pull request #4.
This commit is contained in:
31
README.md
31
README.md
@@ -260,30 +260,18 @@ embeddings = embedder.generate_embeddings_batch(texts, batch_size=10)
|
|||||||
### **Opción 4: Almacenar en GCS**
|
### **Opción 4: Almacenar en GCS**
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from file_storage.google_cloud import GoogleCloudFileStorage
|
import gcsfs
|
||||||
|
|
||||||
storage = GoogleCloudFileStorage(bucket="mi-bucket")
|
fs = gcsfs.GCSFileSystem()
|
||||||
|
|
||||||
# Subir archivo
|
# Subir archivo
|
||||||
storage.upload_file(
|
fs.put("local_file.md", "mi-bucket/chunks/documento_0.md")
|
||||||
file_path="local_file.md",
|
|
||||||
destination_blob_name="chunks/documento_0.md",
|
|
||||||
content_type="text/markdown"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Listar archivos
|
# Listar archivos
|
||||||
files = storage.list_files(path="chunks/")
|
files = fs.ls("mi-bucket/chunks/")
|
||||||
|
|
||||||
# Descargar archivo
|
# Descargar archivo
|
||||||
file_stream = storage.get_file_stream("chunks/documento_0.md")
|
content = fs.cat_file("mi-bucket/chunks/documento_0.md").decode("utf-8")
|
||||||
content = file_stream.read().decode("utf-8")
|
|
||||||
```
|
|
||||||
|
|
||||||
**CLI:**
|
|
||||||
```bash
|
|
||||||
file-storage upload local_file.md chunks/documento_0.md
|
|
||||||
file-storage list chunks/
|
|
||||||
file-storage download chunks/documento_0.md
|
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -340,10 +328,10 @@ vector-search delete mi-indice
|
|||||||
## 🔄 Flujo Completo de Ejemplo
|
## 🔄 Flujo Completo de Ejemplo
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
import gcsfs
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from chunker.contextual_chunker import ContextualChunker
|
from chunker.contextual_chunker import ContextualChunker
|
||||||
from embedder.vertex_ai import VertexAIEmbedder
|
from embedder.vertex_ai import VertexAIEmbedder
|
||||||
from file_storage.google_cloud import GoogleCloudFileStorage
|
|
||||||
from llm.vertex_ai import VertexAILLM
|
from llm.vertex_ai import VertexAILLM
|
||||||
|
|
||||||
# 1. Setup
|
# 1. Setup
|
||||||
@@ -354,7 +342,7 @@ embedder = VertexAIEmbedder(
|
|||||||
project="mi-proyecto",
|
project="mi-proyecto",
|
||||||
location="us-central1"
|
location="us-central1"
|
||||||
)
|
)
|
||||||
storage = GoogleCloudFileStorage(bucket="mi-bucket")
|
fs = gcsfs.GCSFileSystem()
|
||||||
|
|
||||||
# 2. Chunking
|
# 2. Chunking
|
||||||
documents = chunker.process_path(Path("documento.pdf"))
|
documents = chunker.process_path(Path("documento.pdf"))
|
||||||
@@ -368,10 +356,7 @@ for i, doc in enumerate(documents):
|
|||||||
embedding = embedder.generate_embedding(doc["page_content"])
|
embedding = embedder.generate_embedding(doc["page_content"])
|
||||||
|
|
||||||
# Guardar contenido en GCS
|
# Guardar contenido en GCS
|
||||||
storage.upload_file(
|
fs.put(f"temp_{chunk_id}.md", f"mi-bucket/contents/{chunk_id}.md")
|
||||||
file_path=f"temp_{chunk_id}.md",
|
|
||||||
destination_blob_name=f"contents/{chunk_id}.md"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Guardar vector (escribir a JSONL localmente, luego subir)
|
# Guardar vector (escribir a JSONL localmente, luego subir)
|
||||||
print(f"Chunk {chunk_id}: {len(embedding)} dimensiones")
|
print(f"Chunk {chunk_id}: {len(embedding)} dimensiones")
|
||||||
|
|||||||
@@ -1,34 +0,0 @@
|
|||||||
[project]
|
|
||||||
name = "index-gen"
|
|
||||||
version = "0.1.0"
|
|
||||||
description = "Add your description here"
|
|
||||||
readme = "README.md"
|
|
||||||
authors = [
|
|
||||||
{ name = "Anibal Angulo", email = "a8065384@banorte.com" }
|
|
||||||
]
|
|
||||||
requires-python = ">=3.12"
|
|
||||||
dependencies = [
|
|
||||||
"chunker",
|
|
||||||
"document-converter",
|
|
||||||
"embedder",
|
|
||||||
"file-storage",
|
|
||||||
"llm",
|
|
||||||
"utils",
|
|
||||||
"vector-search",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.scripts]
|
|
||||||
index-gen = "index_gen.cli:app"
|
|
||||||
|
|
||||||
[build-system]
|
|
||||||
requires = ["uv_build>=0.8.12,<0.9.0"]
|
|
||||||
build-backend = "uv_build"
|
|
||||||
|
|
||||||
[tool.uv.sources]
|
|
||||||
file-storage = { workspace = true }
|
|
||||||
vector-search = { workspace = true }
|
|
||||||
utils = { workspace = true }
|
|
||||||
embedder = { workspace = true }
|
|
||||||
chunker = { workspace = true }
|
|
||||||
document-converter = { workspace = true }
|
|
||||||
llm = { workspace = true }
|
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
def main() -> None:
|
|
||||||
print("Hello from index-gen!")
|
|
||||||
@@ -1,68 +0,0 @@
|
|||||||
import logging
|
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import typer
|
|
||||||
|
|
||||||
from index_gen.main import (
|
|
||||||
aggregate_vectors,
|
|
||||||
build_gcs_path,
|
|
||||||
create_vector_index,
|
|
||||||
gather_files,
|
|
||||||
process_file,
|
|
||||||
)
|
|
||||||
from rag_eval.config import settings
|
|
||||||
|
|
||||||
app = typer.Typer()
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def run_ingestion():
|
|
||||||
"""Main function for the CLI script."""
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
|
||||||
agent_config = settings.agent
|
|
||||||
index_config = settings.index
|
|
||||||
|
|
||||||
if not agent_config or not index_config:
|
|
||||||
raise ValueError("Agent or index configuration not found in config.yaml")
|
|
||||||
|
|
||||||
# Gather files
|
|
||||||
files = gather_files(index_config.origin)
|
|
||||||
|
|
||||||
# Build output paths
|
|
||||||
contents_output_dir = build_gcs_path(index_config.data, "/contents")
|
|
||||||
vectors_output_dir = build_gcs_path(index_config.data, "/vectors")
|
|
||||||
aggregated_vectors_gcs_path = build_gcs_path(
|
|
||||||
index_config.data, "/vectors/vectors.json"
|
|
||||||
)
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
|
||||||
temp_dir_path = Path(temp_dir)
|
|
||||||
vector_artifact_paths = []
|
|
||||||
|
|
||||||
# Process files and create local artifacts
|
|
||||||
for i, file in enumerate(files):
|
|
||||||
artifact_path = temp_dir_path / f"vectors_{i}.jsonl"
|
|
||||||
vector_artifact_paths.append(artifact_path)
|
|
||||||
|
|
||||||
process_file(
|
|
||||||
file,
|
|
||||||
agent_config.embedding_model,
|
|
||||||
contents_output_dir,
|
|
||||||
artifact_path, # Pass the local path
|
|
||||||
index_config.chunk_limit,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Aggregate the local artifacts into one file in GCS
|
|
||||||
aggregate_vectors(
|
|
||||||
vector_artifacts=vector_artifact_paths,
|
|
||||||
output_gcs_path=aggregated_vectors_gcs_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create vector index
|
|
||||||
create_vector_index(vectors_output_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
app()
|
|
||||||
@@ -1,238 +0,0 @@
|
|||||||
"""
|
|
||||||
This script defines a Kubeflow Pipeline (KFP) for ingesting and processing documents.
|
|
||||||
|
|
||||||
The pipeline is designed to run on Vertex AI Pipelines and consists of the following steps:
|
|
||||||
1. **Gather Files**: Scans a GCS directory for PDF files to process.
|
|
||||||
2. **Process Files (in parallel)**: For each PDF file found, this step:
|
|
||||||
a. Converts the PDF to Markdown text.
|
|
||||||
b. Chunks the text if it's too long.
|
|
||||||
c. Generates a vector embedding for each chunk using a Vertex AI embedding model.
|
|
||||||
d. Saves the markdown content and the vector embedding to separate GCS output paths.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from rag_eval.config import settings
|
|
||||||
|
|
||||||
|
|
||||||
def build_gcs_path(base_path: str, suffix: str) -> str:
|
|
||||||
"""Builds a GCS path by appending a suffix."""
|
|
||||||
return f"{base_path}{suffix}"
|
|
||||||
|
|
||||||
|
|
||||||
def gather_files(
|
|
||||||
input_dir: str,
|
|
||||||
) -> list:
|
|
||||||
"""Gathers all PDF file paths from a GCS directory."""
|
|
||||||
from google.cloud import storage
|
|
||||||
|
|
||||||
logging.getLogger().setLevel(logging.INFO)
|
|
||||||
|
|
||||||
gcs_client = storage.Client()
|
|
||||||
bucket_name, prefix = input_dir.replace("gs://", "").split("/", 1)
|
|
||||||
bucket = gcs_client.bucket(bucket_name)
|
|
||||||
blob_list = bucket.list_blobs(prefix=prefix)
|
|
||||||
|
|
||||||
pdf_files = [
|
|
||||||
f"gs://{bucket_name}/{blob.name}"
|
|
||||||
for blob in blob_list
|
|
||||||
if blob.name.endswith(".pdf")
|
|
||||||
]
|
|
||||||
logging.info(f"Found {len(pdf_files)} PDF files in {input_dir}")
|
|
||||||
return pdf_files
|
|
||||||
|
|
||||||
|
|
||||||
def process_file(
|
|
||||||
file_path: str,
|
|
||||||
model_name: str,
|
|
||||||
contents_output_dir: str,
|
|
||||||
vectors_output_file: Path,
|
|
||||||
chunk_limit: int,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Processes a single PDF file: converts to markdown, chunks, and generates embeddings.
|
|
||||||
The vector embeddings are written to a local JSONL file.
|
|
||||||
"""
|
|
||||||
# Imports are inside the function as KFP serializes this function
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from chunker.contextual_chunker import ContextualChunker
|
|
||||||
from document_converter.markdown import MarkdownConverter
|
|
||||||
from embedder.vertex_ai import VertexAIEmbedder
|
|
||||||
from google.cloud import storage
|
|
||||||
from llm.vertex_ai import VertexAILLM
|
|
||||||
from utils.normalize_filenames import normalize_string
|
|
||||||
|
|
||||||
logging.getLogger().setLevel(logging.INFO)
|
|
||||||
|
|
||||||
# Initialize converters and embedders
|
|
||||||
converter = MarkdownConverter()
|
|
||||||
embedder = VertexAIEmbedder(model_name=model_name, project=settings.project_id, location=settings.location)
|
|
||||||
llm = VertexAILLM(project=settings.project_id, location=settings.location)
|
|
||||||
chunker = ContextualChunker(llm_client=llm, max_chunk_size=chunk_limit)
|
|
||||||
gcs_client = storage.Client()
|
|
||||||
|
|
||||||
file_id = normalize_string(Path(file_path).stem)
|
|
||||||
local_path = Path(f"/tmp/{Path(file_path).name}")
|
|
||||||
|
|
||||||
with open(vectors_output_file, "w", encoding="utf-8") as f:
|
|
||||||
try:
|
|
||||||
# Download file from GCS
|
|
||||||
bucket_name, blob_name = file_path.replace("gs://", "").split("/", 1)
|
|
||||||
bucket = gcs_client.bucket(bucket_name)
|
|
||||||
blob = bucket.blob(blob_name)
|
|
||||||
blob.download_to_filename(local_path)
|
|
||||||
logging.info(f"Processing file: {file_path}")
|
|
||||||
|
|
||||||
# Process the downloaded file
|
|
||||||
markdown_content = converter.process_file(local_path)
|
|
||||||
|
|
||||||
def upload_to_gcs(bucket_name, blob_name, data):
|
|
||||||
bucket = gcs_client.bucket(bucket_name)
|
|
||||||
blob = bucket.blob(blob_name)
|
|
||||||
blob.upload_from_string(data, content_type="text/markdown; charset=utf-8")
|
|
||||||
|
|
||||||
# Determine output bucket and paths for markdown
|
|
||||||
contents_bucket_name, contents_prefix = contents_output_dir.replace(
|
|
||||||
"gs://", ""
|
|
||||||
).split("/", 1)
|
|
||||||
|
|
||||||
# Extract source folder from file path
|
|
||||||
source_folder = Path(blob_name).parent.as_posix() if blob_name else ""
|
|
||||||
|
|
||||||
if len(markdown_content) > chunk_limit:
|
|
||||||
chunks = chunker.process_text(markdown_content)
|
|
||||||
for i, chunk in enumerate(chunks):
|
|
||||||
chunk_id = f"{file_id}_{i}"
|
|
||||||
embedding = embedder.generate_embedding(chunk["page_content"])
|
|
||||||
|
|
||||||
# Upload markdown chunk
|
|
||||||
md_blob_name = f"{contents_prefix}/{chunk_id}.md"
|
|
||||||
upload_to_gcs(
|
|
||||||
contents_bucket_name, md_blob_name, chunk["page_content"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Write vector to local JSONL file with source folder
|
|
||||||
vector_data = {
|
|
||||||
"id": chunk_id,
|
|
||||||
"embedding": embedding,
|
|
||||||
"source_folder": source_folder
|
|
||||||
}
|
|
||||||
json_line = json.dumps(vector_data)
|
|
||||||
f.write(json_line + '\n')
|
|
||||||
else:
|
|
||||||
embedding = embedder.generate_embedding(markdown_content)
|
|
||||||
|
|
||||||
# Upload markdown
|
|
||||||
md_blob_name = f"{contents_prefix}/{file_id}.md"
|
|
||||||
upload_to_gcs(contents_bucket_name, md_blob_name, markdown_content)
|
|
||||||
|
|
||||||
# Write vector to local JSONL file with source folder
|
|
||||||
vector_data = {
|
|
||||||
"id": file_id,
|
|
||||||
"embedding": embedding,
|
|
||||||
"source_folder": source_folder
|
|
||||||
}
|
|
||||||
json_line = json.dumps(vector_data)
|
|
||||||
f.write(json_line + '\n')
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Failed to process file {file_path}: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# Clean up the downloaded file
|
|
||||||
if os.path.exists(local_path):
|
|
||||||
os.remove(local_path)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def aggregate_vectors(
|
|
||||||
vector_artifacts: list, # This will be a list of paths to the artifact files
|
|
||||||
output_gcs_path: str,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Aggregates multiple JSONL artifact files into a single JSONL file in GCS.
|
|
||||||
"""
|
|
||||||
from google.cloud import storage
|
|
||||||
|
|
||||||
logging.getLogger().setLevel(logging.INFO)
|
|
||||||
|
|
||||||
# Create a temporary file to aggregate all vector data
|
|
||||||
with tempfile.NamedTemporaryFile(
|
|
||||||
mode="w", delete=False, encoding="utf-8"
|
|
||||||
) as temp_agg_file:
|
|
||||||
logging.info(f"Aggregating vectors into temporary file: {temp_agg_file.name}")
|
|
||||||
for artifact_path in vector_artifacts:
|
|
||||||
with open(artifact_path, "r", encoding="utf-8") as f:
|
|
||||||
# Each line is a complete JSON object
|
|
||||||
for line in f:
|
|
||||||
temp_agg_file.write(line) # line already includes newline
|
|
||||||
|
|
||||||
temp_file_path = temp_agg_file.name
|
|
||||||
|
|
||||||
logging.info("Uploading aggregated file to GCS...")
|
|
||||||
gcs_client = storage.Client()
|
|
||||||
bucket_name, blob_name = output_gcs_path.replace("gs://", "").split("/", 1)
|
|
||||||
bucket = gcs_client.bucket(bucket_name)
|
|
||||||
blob = bucket.blob(blob_name)
|
|
||||||
blob.upload_from_filename(temp_file_path, content_type="application/json; charset=utf-8")
|
|
||||||
|
|
||||||
logging.info(f"Successfully uploaded aggregated vectors to {output_gcs_path}")
|
|
||||||
|
|
||||||
# Clean up the temporary file
|
|
||||||
import os
|
|
||||||
|
|
||||||
os.remove(temp_file_path)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def create_vector_index(
|
|
||||||
vectors_dir: str,
|
|
||||||
):
|
|
||||||
"""Creates and deploys a Vertex AI Vector Search Index."""
|
|
||||||
from vector_search.vertex_ai import GoogleCloudVectorSearch
|
|
||||||
|
|
||||||
from rag_eval.config import settings as config
|
|
||||||
|
|
||||||
logging.getLogger().setLevel(logging.INFO)
|
|
||||||
|
|
||||||
try:
|
|
||||||
index_config = config.index
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
f"Initializing Vertex AI client for project '{config.project_id}' in '{config.location}'..."
|
|
||||||
)
|
|
||||||
vector_search = GoogleCloudVectorSearch(
|
|
||||||
project_id=config.project_id,
|
|
||||||
location=config.location,
|
|
||||||
bucket=config.bucket,
|
|
||||||
index_name=index_config.name,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info(f"Starting creation of index '{index_config.name}'...")
|
|
||||||
vector_search.create_index(
|
|
||||||
name=index_config.name,
|
|
||||||
content_path=vectors_dir,
|
|
||||||
dimensions=index_config.dimensions,
|
|
||||||
)
|
|
||||||
logging.info(f"Index '{index_config.name}' created successfully.")
|
|
||||||
|
|
||||||
logging.info("Deploying index to a new endpoint...")
|
|
||||||
vector_search.deploy_index(
|
|
||||||
index_name=index_config.name, machine_type=index_config.machine_type
|
|
||||||
)
|
|
||||||
logging.info("Index deployed successfully!")
|
|
||||||
logging.info(f"Endpoint name: {vector_search.index_endpoint.display_name}")
|
|
||||||
logging.info(
|
|
||||||
f"Endpoint resource name: {vector_search.index_endpoint.resource_name}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"An error occurred during index creation or deployment: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
@@ -1,29 +1,15 @@
|
|||||||
# Configuración de Google Cloud Platform
|
# Google Cloud Platform
|
||||||
project_id: "tu-proyecto-gcp"
|
project_id: "tu-proyecto-gcp"
|
||||||
location: "us-central1" # o us-east1, europe-west1, etc.
|
location: "us-central1"
|
||||||
bucket: "tu-bucket-nombre"
|
|
||||||
|
|
||||||
# Configuración del índice vectorial
|
# Embedding model
|
||||||
index:
|
agent_embedding_model: "text-embedding-005"
|
||||||
name: "mi-indice-rag"
|
|
||||||
dimensions: 768 # Para text-embedding-005 usa 768
|
|
||||||
machine_type: "e2-standard-2" # Tipo de máquina para el endpoint
|
|
||||||
approximate_neighbors_count: 150
|
|
||||||
distance_measure_type: "DOT_PRODUCT_DISTANCE" # O "COSINE_DISTANCE", "EUCLIDEAN_DISTANCE"
|
|
||||||
|
|
||||||
# Configuración de embeddings
|
# Vector index
|
||||||
embedder:
|
index_name: "mi-indice-rag"
|
||||||
model_name: "text-embedding-005"
|
index_dimensions: 768
|
||||||
task: "RETRIEVAL_DOCUMENT" # O "RETRIEVAL_QUERY" para queries
|
index_machine_type: "e2-standard-16"
|
||||||
|
index_origin: "gs://tu-bucket/input/"
|
||||||
# Configuración de LLM para chunking
|
index_destination: "gs://tu-bucket/output/"
|
||||||
llm:
|
index_chunk_limit: 800
|
||||||
model: "gemini-2.0-flash" # O "gemini-1.5-pro", "gemini-1.5-flash"
|
index_distance_measure_type: "DOT_PRODUCT_DISTANCE"
|
||||||
|
|
||||||
# Configuración de chunking
|
|
||||||
chunking:
|
|
||||||
strategy: "contextual" # "recursive", "contextual", "llm"
|
|
||||||
max_chunk_size: 800
|
|
||||||
chunk_overlap: 200 # Solo para LLMChunker
|
|
||||||
merge_related: true # Solo para LLMChunker
|
|
||||||
extract_images: true # Solo para LLMChunker
|
|
||||||
|
|||||||
@@ -1,23 +0,0 @@
|
|||||||
[project]
|
|
||||||
name = "chunker"
|
|
||||||
version = "0.1.0"
|
|
||||||
description = "Add your description here"
|
|
||||||
readme = "README.md"
|
|
||||||
authors = [
|
|
||||||
{ name = "Anibal Angulo", email = "a8065384@banorte.com" }
|
|
||||||
]
|
|
||||||
requires-python = ">=3.12"
|
|
||||||
dependencies = [
|
|
||||||
"chonkie>=1.1.2",
|
|
||||||
"pdf2image>=1.17.0",
|
|
||||||
"pypdf>=6.0.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.scripts]
|
|
||||||
llm-chunker = "chunker.llm_chunker:app"
|
|
||||||
recursive-chunker = "chunker.recursive_chunker:app"
|
|
||||||
contextual-chunker = "chunker.contextual_chunker:app"
|
|
||||||
|
|
||||||
[build-system]
|
|
||||||
requires = ["uv_build>=0.8.3,<0.9.0"]
|
|
||||||
build-backend = "uv_build"
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
3.10
|
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
[project]
|
|
||||||
name = "document-converter"
|
|
||||||
version = "0.1.0"
|
|
||||||
description = "Add your description here"
|
|
||||||
readme = "README.md"
|
|
||||||
authors = [
|
|
||||||
{ name = "Anibal Angulo", email = "a8065384@banorte.com" }
|
|
||||||
]
|
|
||||||
requires-python = ">=3.12"
|
|
||||||
dependencies = [
|
|
||||||
"markitdown[pdf]>=0.1.2",
|
|
||||||
"pypdf>=6.1.2",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.scripts]
|
|
||||||
convert-md = "document_converter.markdown:app"
|
|
||||||
|
|
||||||
[build-system]
|
|
||||||
requires = ["uv_build>=0.8.3,<0.9.0"]
|
|
||||||
build-backend = "uv_build"
|
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
def hello() -> str:
|
|
||||||
return "Hello from document-converter!"
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
|
|
||||||
class BaseConverter(ABC):
|
|
||||||
"""
|
|
||||||
Abstract base class for a remote file processor.
|
|
||||||
|
|
||||||
This class defines the interface for listing and processing files from a remote source.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def process_file(self, file: str) -> str:
|
|
||||||
"""
|
|
||||||
Processes a single file from a remote source and returns the result.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file: The path to the file to be processed from the remote source.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string containing the processing result for the file.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def process_files(self, files: List[str]) -> List[str]:
|
|
||||||
"""
|
|
||||||
Processes a list of files from a remote source and returns the results.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
files: A list of file paths to be processed from the remote source.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of strings containing the processing results for each file.
|
|
||||||
"""
|
|
||||||
return [self.process_file(file) for file in files]
|
|
||||||
@@ -1,131 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
from typing import Annotated, BinaryIO, Union
|
|
||||||
|
|
||||||
import typer
|
|
||||||
from markitdown import MarkItDown
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.progress import Progress
|
|
||||||
|
|
||||||
from .base import BaseConverter
|
|
||||||
|
|
||||||
|
|
||||||
class MarkdownConverter(BaseConverter):
|
|
||||||
"""Converts PDF documents to Markdown format."""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
"""Initializes the MarkItDown converter."""
|
|
||||||
self.markitdown = MarkItDown(enable_plugins=False)
|
|
||||||
|
|
||||||
def process_file(self, file_stream: Union[str, Path, BinaryIO]) -> str:
|
|
||||||
"""
|
|
||||||
Processes a single file and returns the result as a markdown string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_stream: A file path (string or Path) or a binary file stream.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The converted markdown content as a string.
|
|
||||||
"""
|
|
||||||
result = self.markitdown.convert(file_stream)
|
|
||||||
return result.text_content
|
|
||||||
|
|
||||||
|
|
||||||
# --- CLI Application ---
|
|
||||||
|
|
||||||
app = typer.Typer()
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def main(
|
|
||||||
input_path: Annotated[
|
|
||||||
Path,
|
|
||||||
typer.Argument(
|
|
||||||
help="Path to the input PDF file or directory.",
|
|
||||||
exists=True,
|
|
||||||
file_okay=True,
|
|
||||||
dir_okay=True,
|
|
||||||
readable=True,
|
|
||||||
resolve_path=True,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
output_path: Annotated[
|
|
||||||
Path,
|
|
||||||
typer.Argument(
|
|
||||||
help="Path for the output Markdown file or directory.",
|
|
||||||
file_okay=True,
|
|
||||||
dir_okay=True,
|
|
||||||
writable=True,
|
|
||||||
resolve_path=True,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Converts a PDF file or a directory of PDF files into Markdown.
|
|
||||||
"""
|
|
||||||
console = Console()
|
|
||||||
converter = MarkdownConverter()
|
|
||||||
|
|
||||||
if input_path.is_dir():
|
|
||||||
# --- Directory Processing ---
|
|
||||||
console.print(f"[bold green]Processing directory:[/bold green] {input_path}")
|
|
||||||
output_dir = output_path
|
|
||||||
|
|
||||||
if output_dir.exists() and not output_dir.is_dir():
|
|
||||||
console.print(
|
|
||||||
f"[bold red]Error:[/bold red] Input is a directory, but output path '{output_dir}' is an existing file."
|
|
||||||
)
|
|
||||||
raise typer.Exit(code=1)
|
|
||||||
|
|
||||||
pdf_files = sorted(list(input_path.rglob("*.pdf")))
|
|
||||||
if not pdf_files:
|
|
||||||
console.print("[yellow]No PDF files found in the input directory.[/yellow]")
|
|
||||||
return
|
|
||||||
|
|
||||||
console.print(f"Found {len(pdf_files)} PDF files to convert.")
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
with Progress(console=console) as progress:
|
|
||||||
task = progress.add_task("[cyan]Converting...", total=len(pdf_files))
|
|
||||||
for pdf_file in pdf_files:
|
|
||||||
relative_path = pdf_file.relative_to(input_path)
|
|
||||||
output_md_path = output_dir.joinpath(relative_path).with_suffix(".md")
|
|
||||||
output_md_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
progress.update(task, description=f"Processing {pdf_file.name}")
|
|
||||||
try:
|
|
||||||
markdown_content = converter.process_file(pdf_file)
|
|
||||||
output_md_path.write_text(markdown_content, encoding="utf-8")
|
|
||||||
except Exception as e:
|
|
||||||
console.print(
|
|
||||||
f"\n[bold red]Failed to process {pdf_file.name}:[/bold red] {e}"
|
|
||||||
)
|
|
||||||
progress.advance(task)
|
|
||||||
|
|
||||||
console.print(
|
|
||||||
f"[bold green]Conversion complete.[/bold green] Output directory: {output_dir}"
|
|
||||||
)
|
|
||||||
|
|
||||||
elif input_path.is_file():
|
|
||||||
# --- Single File Processing ---
|
|
||||||
console.print(f"[bold green]Processing file:[/bold green] {input_path.name}")
|
|
||||||
final_output_path = output_path
|
|
||||||
|
|
||||||
# If output path is a directory, create a file inside it
|
|
||||||
if output_path.is_dir():
|
|
||||||
final_output_path = output_path / input_path.with_suffix(".md").name
|
|
||||||
|
|
||||||
final_output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
try:
|
|
||||||
markdown_content = converter.process_file(input_path)
|
|
||||||
final_output_path.write_text(markdown_content, encoding="utf-8")
|
|
||||||
console.print(
|
|
||||||
f"[bold green]Successfully converted file to:[/bold green] {final_output_path}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
console.print(f"[bold red]Error processing file:[/bold red] {e}")
|
|
||||||
raise typer.Exit(code=1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
app()
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
3.10
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
[project]
|
|
||||||
name = "embedder"
|
|
||||||
version = "0.1.0"
|
|
||||||
description = "Add your description here"
|
|
||||||
readme = "README.md"
|
|
||||||
authors = [
|
|
||||||
{ name = "Anibal Angulo", email = "a8065384@banorte.com" }
|
|
||||||
]
|
|
||||||
requires-python = ">=3.12"
|
|
||||||
dependencies = [
|
|
||||||
"google-cloud-aiplatform>=1.106.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
[build-system]
|
|
||||||
requires = ["uv_build>=0.8.3,<0.9.0"]
|
|
||||||
build-backend = "uv_build"
|
|
||||||
@@ -1,79 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
class BaseEmbedder(ABC):
|
|
||||||
"""Base class for all embedding models."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def generate_embedding(self, text: str) -> List[float]:
|
|
||||||
"""
|
|
||||||
Generate embeddings for text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: Single text string or list of texts
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Single embedding vector or list of embedding vectors
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def generate_embeddings_batch(self, texts: List[str]) -> List[List[float]]:
|
|
||||||
"""
|
|
||||||
Generate embeddings for a batch of texts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: List of text strings
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of embedding vectors
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def preprocess_text(
|
|
||||||
self,
|
|
||||||
text: str,
|
|
||||||
*,
|
|
||||||
into_lowercase: bool = False,
|
|
||||||
normalize_whitespace: bool = True,
|
|
||||||
remove_punctuation: bool = False,
|
|
||||||
) -> str:
|
|
||||||
"""Preprocess text before embedding."""
|
|
||||||
# Basic preprocessing
|
|
||||||
text = text.strip()
|
|
||||||
|
|
||||||
if into_lowercase:
|
|
||||||
text = text.lower()
|
|
||||||
|
|
||||||
if normalize_whitespace:
|
|
||||||
text = " ".join(text.split())
|
|
||||||
|
|
||||||
if remove_punctuation:
|
|
||||||
import string
|
|
||||||
|
|
||||||
text = text.translate(str.maketrans("", "", string.punctuation))
|
|
||||||
|
|
||||||
return text
|
|
||||||
|
|
||||||
def normalize_embedding(self, embedding: List[float]) -> List[float]:
|
|
||||||
"""Normalize embedding vector to unit length."""
|
|
||||||
norm = np.linalg.norm(embedding)
|
|
||||||
if norm > 0:
|
|
||||||
return (np.array(embedding) / norm).tolist()
|
|
||||||
return embedding
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def async_generate_embedding(self, text: str) -> List[float]:
|
|
||||||
"""
|
|
||||||
Generate embeddings for text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: Single text string or list of texts
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Single embedding vector or list of embedding vectors
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
import logging
|
|
||||||
import time
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from google import genai
|
|
||||||
from google.genai import types
|
|
||||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
|
||||||
|
|
||||||
from .base import BaseEmbedder
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class VertexAIEmbedder(BaseEmbedder):
|
|
||||||
"""Embedder using Vertex AI text embedding models."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, model_name: str, project: str, location: str, task: str = "RETRIEVAL_DOCUMENT"
|
|
||||||
) -> None:
|
|
||||||
self.model_name = model_name
|
|
||||||
self.client = genai.Client(
|
|
||||||
vertexai=True,
|
|
||||||
project=project,
|
|
||||||
location=location,
|
|
||||||
)
|
|
||||||
self.task = task
|
|
||||||
|
|
||||||
# @retry(
|
|
||||||
# stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=30)
|
|
||||||
# )
|
|
||||||
def generate_embedding(self, text: str) -> List[float]:
|
|
||||||
preprocessed_text = self.preprocess_text(text)
|
|
||||||
result = self.client.models.embed_content(
|
|
||||||
model=self.model_name, contents=preprocessed_text, config=types.EmbedContentConfig(task_type=self.task)
|
|
||||||
)
|
|
||||||
return result.embeddings[0].values
|
|
||||||
|
|
||||||
# @retry(
|
|
||||||
# stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=30)
|
|
||||||
# )
|
|
||||||
def generate_embeddings_batch(
|
|
||||||
self, texts: List[str], batch_size: int = 10
|
|
||||||
) -> List[List[float]]:
|
|
||||||
"""Generate embeddings for a batch of texts."""
|
|
||||||
if not texts:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Preprocess texts
|
|
||||||
preprocessed_texts = [self.preprocess_text(text) for text in texts]
|
|
||||||
|
|
||||||
# Process in batches if necessary
|
|
||||||
all_embeddings = []
|
|
||||||
|
|
||||||
for i in range(0, len(preprocessed_texts), batch_size):
|
|
||||||
batch = preprocessed_texts[i : i + batch_size]
|
|
||||||
|
|
||||||
# Generate embeddings for batch
|
|
||||||
result = self.client.models.embed_content(
|
|
||||||
model=self.model_name, contents=batch, config=types.EmbedContentConfig(task_type=self.task)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract values
|
|
||||||
batch_embeddings = [emb.values for emb in result.embeddings]
|
|
||||||
all_embeddings.extend(batch_embeddings)
|
|
||||||
|
|
||||||
# Rate limiting
|
|
||||||
if i + batch_size < len(preprocessed_texts):
|
|
||||||
time.sleep(0.1) # Small delay between batches
|
|
||||||
|
|
||||||
return all_embeddings
|
|
||||||
|
|
||||||
async def async_generate_embedding(self, text: str) -> List[float]:
|
|
||||||
preprocessed_text = self.preprocess_text(text)
|
|
||||||
result = await self.client.aio.models.embed_content(
|
|
||||||
model=self.model_name, contents=preprocessed_text, config=types.EmbedContentConfig(task_type=self.task)
|
|
||||||
)
|
|
||||||
return result.embeddings[0].values
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
3.10
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
[project]
|
|
||||||
name = "file-storage"
|
|
||||||
version = "0.1.0"
|
|
||||||
description = "File storage helpers and CLI for Google Cloud Storage"
|
|
||||||
readme = "README.md"
|
|
||||||
authors = [
|
|
||||||
{ name = "Anibal Angulo", email = "a8065384@banorte.com" }
|
|
||||||
]
|
|
||||||
requires-python = ">=3.12"
|
|
||||||
dependencies = [
|
|
||||||
"gcloud-aio-storage>=9.6.1",
|
|
||||||
"google-cloud-storage>=2.19.0",
|
|
||||||
"aiohttp>=3.10.11,<4",
|
|
||||||
"typer>=0.12.3",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.scripts]
|
|
||||||
file-storage = "file_storage.cli:app"
|
|
||||||
|
|
||||||
[build-system]
|
|
||||||
requires = ["uv_build>=0.8.3,<0.9.0"]
|
|
||||||
build-backend = "uv_build"
|
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
def hello() -> str:
    """Return the package greeting string."""
    message = "Hello from file-storage!"
    return message
|
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import BinaryIO, List, Optional
|
|
||||||
|
|
||||||
|
|
||||||
class BaseFileStorage(ABC):
    """Abstract interface for a remote file store.

    Concrete subclasses implement uploading, listing, and streaming of
    files held by a remote source (e.g. a cloud bucket).
    """

    @abstractmethod
    def upload_file(
        self,
        file_path: str,
        destination_blob_name: str,
        content_type: Optional[str] = None,
    ) -> None:
        """Upload a local file to the remote source.

        Args:
            file_path: Local path of the file to upload.
            destination_blob_name: Name the file will have in the remote source.
            content_type: Optional MIME type of the file.
        """
        ...

    @abstractmethod
    def list_files(self, path: Optional[str] = None) -> List[str]:
        """List files stored at a remote location.

        Args:
            path: Path to a specific file or directory in the remote bucket.
                When None, every file in the bucket is listed recursively.

        Returns:
            The matching file paths.
        """
        ...

    @abstractmethod
    def get_file_stream(self, file_name: str) -> BinaryIO:
        """Fetch a remote file and return it as a file-like object."""
        ...
|
|
||||||
@@ -1,89 +0,0 @@
|
|||||||
import os
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
import rich
|
|
||||||
import typer
|
|
||||||
|
|
||||||
from rag_eval.config import settings
|
|
||||||
|
|
||||||
from .google_cloud import GoogleCloudFileStorage
|
|
||||||
|
|
||||||
app = typer.Typer()
|
|
||||||
|
|
||||||
|
|
||||||
def get_storage_client() -> GoogleCloudFileStorage:
    """Build a storage client bound to the bucket configured in settings."""
    bucket_name = settings.bucket
    return GoogleCloudFileStorage(bucket=bucket_name)
|
|
||||||
|
|
||||||
|
|
||||||
@app.command("upload")
def upload(
    file_path: str,
    destination_blob_name: str,
    # Fixed annotation: the default is None, so the type must be str | None.
    content_type: Annotated[str | None, typer.Option()] = None,
):
    """
    Uploads a file or directory to the remote source.

    When file_path is a directory, every file under it is uploaded,
    preserving the directory structure relative to file_path.
    """
    storage_client = get_storage_client()
    if os.path.isdir(file_path):
        for root, _, files in os.walk(file_path):
            for file in files:
                local_file_path = os.path.join(root, file)
                # Preserve the directory structure and use forward slashes
                # for the blob name regardless of the host OS separator.
                dest_blob_name = os.path.join(
                    destination_blob_name, os.path.relpath(local_file_path, file_path)
                ).replace(os.sep, "/")
                storage_client.upload_file(
                    local_file_path, dest_blob_name, content_type
                )
                rich.print(
                    f"[green]File {local_file_path} uploaded to {dest_blob_name}.[/green]"
                )
        rich.print(
            f"[bold green]Directory {file_path} uploaded to {destination_blob_name}.[/bold green]"
        )
    else:
        storage_client.upload_file(file_path, destination_blob_name, content_type)
        rich.print(
            f"[green]File {file_path} uploaded to {destination_blob_name}.[/green]"
        )
|
|
||||||
|
|
||||||
|
|
||||||
@app.command("list")
# Fixed annotation: the default is None, so the type must be str | None.
def list_items(path: Annotated[str | None, typer.Option()] = None):
    """
    Obtain a list of all files at the given location inside the remote bucket.

    If path is None, recursively shows all files in the remote bucket.
    """
    storage_client = get_storage_client()
    files = storage_client.list_files(path)
    for file in files:
        rich.print(f"[blue]{file}[/blue]")
|
|
||||||
|
|
||||||
|
|
||||||
@app.command("download")
def download(file_name: str, destination_path: str):
    """
    Downloads a file from the remote source and saves it to a local path.
    """
    # (Docstring fixed: the old one was copy-pasted from get_file_stream and
    # described returning a file-like object, which this command does not do.)
    storage_client = get_storage_client()
    file_stream = storage_client.get_file_stream(file_name)
    with open(destination_path, "wb") as f:
        f.write(file_stream.read())
    rich.print(f"[green]File {file_name} downloaded to {destination_path}[/green]")
|
|
||||||
|
|
||||||
|
|
||||||
@app.command("delete")
def delete(path: str):
    """
    Deletes all files at the given location inside the remote bucket.

    If path is a single file, only that file is deleted; if it is a
    directory, every file under it is deleted.
    """
    client = get_storage_client()
    client.delete_files(path)
    rich.print(f"[bold red]Files at {path} deleted.[/bold red]")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
app()
|
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import io
|
|
||||||
import logging
|
|
||||||
from typing import BinaryIO, List, Optional
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
from gcloud.aio.storage import Storage
|
|
||||||
from google.cloud import storage
|
|
||||||
|
|
||||||
from .base import BaseFileStorage
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleCloudFileStorage(BaseFileStorage):
    """Google Cloud Storage implementation of BaseFileStorage.

    Wraps a synchronous google-cloud-storage client plus a lazily created
    gcloud-aio Storage client for async downloads. Downloaded file bytes
    are cached in memory by blob name.
    """

    def __init__(self, bucket: str) -> None:
        # Name of the GCS bucket this instance operates on.
        self.bucket_name = bucket

        self.storage_client = storage.Client()
        self.bucket_client = self.storage_client.bucket(self.bucket_name)
        # Async session/storage are created on first use (see _get_aio_*).
        self._aio_session: aiohttp.ClientSession | None = None
        self._aio_storage: Storage | None = None
        # In-memory blob cache. NOTE(review): unbounded — entries are only
        # evicted on upload/delete of the same blob, never by size.
        self._cache: dict[str, bytes] = {}

    def upload_file(
        self,
        file_path: str,
        destination_blob_name: str,
        content_type: Optional[str] = None,
    ) -> None:
        """
        Uploads a file to the remote source.

        Args:
            file_path: The local path to the file to upload.
            destination_blob_name: The name of the file in the remote source.
            content_type: The content type of the file.
        """
        blob = self.bucket_client.blob(destination_blob_name)
        # NOTE(review): if_generation_match=0 makes the upload succeed only
        # when the object does not already exist — confirm that failing on
        # overwrite is the intended behavior.
        blob.upload_from_filename(
            file_path,
            content_type=content_type,
            if_generation_match=0,
        )
        # Drop any stale cached copy of this blob.
        self._cache.pop(destination_blob_name, None)

    def list_files(self, path: Optional[str] = None) -> List[str]:
        """
        Obtain a list of all files at the given location inside the remote bucket
        If path is none, recursively shows all files in the remote bucket.
        """
        blobs = self.storage_client.list_blobs(self.bucket_name, prefix=path)
        return [blob.name for blob in blobs]

    def get_file_stream(self, file_name: str) -> BinaryIO:
        """
        Gets a file from the remote source and returns it as a file-like object.
        """
        # Download once; later reads are served from the in-memory cache.
        if file_name not in self._cache:
            blob = self.bucket_client.blob(file_name)
            self._cache[file_name] = blob.download_as_bytes()
        file_stream = io.BytesIO(self._cache[file_name])
        # Expose the blob name on the stream, mirroring a real file object.
        file_stream.name = file_name
        return file_stream

    def _get_aio_session(self) -> aiohttp.ClientSession:
        # Lazily (re)create the shared aiohttp session; a closed session is
        # replaced rather than reused.
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(limit=300, limit_per_host=50)
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout, connector=connector
            )
        return self._aio_session

    def _get_aio_storage(self) -> Storage:
        # Lazily create the async storage client bound to our session.
        if self._aio_storage is None:
            self._aio_storage = Storage(session=self._get_aio_session())
        return self._aio_storage

    async def async_get_file_stream(
        self, file_name: str, max_retries: int = 3
    ) -> BinaryIO:
        """
        Gets a file from the remote source asynchronously and returns it as a file-like object.
        Retries on transient errors (429, 5xx, timeouts) with exponential backoff.
        """
        # Serve from cache when the blob was downloaded before.
        if file_name in self._cache:
            file_stream = io.BytesIO(self._cache[file_name])
            file_stream.name = file_name
            return file_stream

        storage_client = self._get_aio_storage()
        last_exception: Exception | None = None

        for attempt in range(max_retries):
            try:
                self._cache[file_name] = await storage_client.download(
                    self.bucket_name, file_name
                )
                file_stream = io.BytesIO(self._cache[file_name])
                file_stream.name = file_name
                return file_stream
            except asyncio.TimeoutError as exc:
                last_exception = exc
                logger.warning(
                    "Timeout downloading gs://%s/%s (attempt %d/%d)",
                    self.bucket_name, file_name, attempt + 1, max_retries,
                )
            except aiohttp.ClientResponseError as exc:
                last_exception = exc
                # Only rate limiting (429) and server errors (5xx) are treated
                # as transient; anything else propagates immediately.
                if exc.status == 429 or exc.status >= 500:
                    logger.warning(
                        "HTTP %d downloading gs://%s/%s (attempt %d/%d)",
                        exc.status, self.bucket_name, file_name,
                        attempt + 1, max_retries,
                    )
                else:
                    raise

            # Exponential backoff: 0.5s, 1s, 2s, ... between attempts.
            if attempt < max_retries - 1:
                delay = 0.5 * (2 ** attempt)
                await asyncio.sleep(delay)

        raise TimeoutError(
            f"Failed to download gs://{self.bucket_name}/{file_name} "
            f"after {max_retries} attempts"
        ) from last_exception

    def delete_files(self, path: str) -> None:
        """
        Deletes all files at the given location inside the remote bucket.
        If path is a single file, it will delete only that file.
        If path is a directory, it will delete all files in that directory.
        """
        blobs = self.storage_client.list_blobs(self.bucket_name, prefix=path)
        for blob in blobs:
            blob.delete()
            # Keep the cache consistent with the remote state.
            self._cache.pop(blob.name, None)
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
3.10
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
[project]
|
|
||||||
name = "llm"
|
|
||||||
version = "0.1.0"
|
|
||||||
description = "LLM client abstractions for Vertex AI (Gemini) models"
|
|
||||||
readme = "README.md"
|
|
||||||
authors = [
|
|
||||||
{ name = "Anibal Angulo", email = "a8065384@banorte.com" }
|
|
||||||
]
|
|
||||||
requires-python = ">=3.12"
|
|
||||||
dependencies = [
|
|
||||||
"google-genai>=1.20.0",
|
|
||||||
"pydantic>=2.11.7",
|
|
||||||
"tenacity>=9.1.2",
|
|
||||||
]
|
|
||||||
|
|
||||||
[build-system]
|
|
||||||
requires = ["uv_build>=0.8.3,<0.9.0"]
|
|
||||||
build-backend = "uv_build"
|
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
def hello() -> str:
    """Return the package greeting string."""
    message = "Hello from llm!"
    return message
|
|
||||||
@@ -1,128 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, Type, TypeVar
|
|
||||||
|
|
||||||
from pydantic import BaseModel, field_validator
|
|
||||||
|
|
||||||
|
|
||||||
class ToolCall(BaseModel):
    """A single function call requested by the model."""

    # Name of the tool/function to invoke.
    name: str
    # Arguments for the call, as provided by the model.
    arguments: dict
|
|
||||||
|
|
||||||
class Usage(BaseModel):
    """Token usage counters for one or more model calls."""

    prompt_tokens: int | None = 0
    thought_tokens: int | None = 0
    response_tokens: int | None = 0

    @field_validator("prompt_tokens", "thought_tokens", "response_tokens", mode="before")
    @classmethod
    def _validate_tokens(cls, v: int | None) -> int:
        # The API may report missing counts as None; normalize to 0.
        return v or 0

    def __add__(self, other):
        """Return a new Usage with the counters of both operands summed."""
        return Usage(
            prompt_tokens=self.prompt_tokens + other.prompt_tokens,
            thought_tokens=self.thought_tokens + other.thought_tokens,
            response_tokens=self.response_tokens + other.response_tokens,
        )

    def get_cost(self, name: str) -> float:
        """Estimate the cost of this usage, in MXN, for the given model.

        Args:
            name: Model name ("gemini-2.5-pro" or "gemini-2.5-flash").

        Returns:
            Estimated cost in MXN (USD list prices converted at a fixed rate).
            (Return annotation fixed: the computation yields a float, not int.)

        Raises:
            ValueError: If the model name is not recognized.
        """
        million = 1_000_000
        # NOTE(review): fixed USD->MXN exchange rate — confirm it is kept current.
        usd_to_mxn = 18.65
        output_tokens = self.thought_tokens + self.response_tokens
        if name == "gemini-2.5-pro":
            # Pro pricing has a higher tier above 200k prompt tokens.
            if self.prompt_tokens > 200000:
                input_cost = self.prompt_tokens * (2.5 / million)
                output_cost = output_tokens * (15 / million)
            else:
                input_cost = self.prompt_tokens * (1.25 / million)
                output_cost = output_tokens * (10 / million)
            return (input_cost + output_cost) * usd_to_mxn
        if name == "gemini-2.5-flash":
            input_cost = self.prompt_tokens * (0.30 / million)
            output_cost = output_tokens * (2.5 / million)
            return (input_cost + output_cost) * usd_to_mxn
        # Narrowed from a bare Exception; ValueError is still caught by any
        # existing `except Exception` handlers.
        raise ValueError(f"Invalid model: {name!r}")
|
|
||||||
|
|
||||||
|
|
||||||
class Generation(BaseModel):
    """A class to represent a single generation from a model.

    Attributes:
        text: The generated text, when the model returned plain text.
        tool_calls: Tool calls requested by the model, if any.
        usage: Token usage metadata for this generation (a Usage model).
        extra: Provider-specific extras (e.g. the original response content).
    """

    text: str | None = None
    tool_calls: list[ToolCall] | None = None
    # NOTE(review): these mutable defaults rely on pydantic copying field
    # defaults per instance rather than sharing them — confirm pydantic v2
    # semantics are assumed here.
    usage: Usage = Usage()
    extra: dict = {}
|
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseLLM(ABC):
    """An abstract base class for all LLMs."""

    @abstractmethod
    def generate(
        self,
        model: str,
        prompt: Any,
        tools: list | None = None,
        system_prompt: str | None = None,
        # Added with a default for consistency with async_generate and the
        # concrete implementations; existing implementers/callers are unaffected.
        tool_mode: str = "AUTO",
    ) -> Generation:
        """Generates text from a prompt.

        Args:
            model: The model to use for generation.
            prompt: The prompt to generate text from.
            tools: An optional list of tools to use for generation.
            system_prompt: An optional system prompt to guide the model's behavior.
            tool_mode: Function-calling mode passed to the provider.

        Returns:
            A Generation object containing the generated text and usage metadata.
        """
        ...

    @abstractmethod
    def structured_generation(
        self,
        model: str,
        prompt: Any,
        response_model: Type[T],
        tools: list | None = None,
        # Added with a default for consistency with the concrete implementations.
        system_prompt: str | None = None,
    ) -> T:
        """Generates structured data from a prompt.

        Args:
            model: The model to use for generation.
            prompt: The prompt to generate text from.
            response_model: The pydantic model to parse the response into.
            tools: An optional list of tools to use for generation.
            system_prompt: An optional system prompt to guide the model's behavior.

        Returns:
            An instance of the provided pydantic model.
        """
        ...

    @abstractmethod
    async def async_generate(
        self,
        model: str,
        prompt: Any,
        tools: list | None = None,
        system_prompt: str | None = None,
        tool_mode: str = "AUTO",
    ) -> Generation:
        """Asynchronously generates text from a prompt.

        Args:
            model: The model to use for generation.
            prompt: The prompt to generate text from.
            tools: An optional list of tools to use for generation.
            system_prompt: An optional system prompt to guide the model's behavior.
            tool_mode: Function-calling mode passed to the provider.

        Returns:
            A Generation object containing the generated text and usage metadata.
        """
        ...
|
|
||||||
@@ -1,181 +0,0 @@
|
|||||||
import logging
|
|
||||||
from typing import Any, Type
|
|
||||||
|
|
||||||
from google import genai
|
|
||||||
from google.genai import types
|
|
||||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
|
||||||
|
|
||||||
from rag_eval.config import settings
|
|
||||||
|
|
||||||
from .base import BaseLLM, Generation, T, ToolCall, Usage
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class VertexAILLM(BaseLLM):
    """A class for interacting with the Vertex AI API."""

    def __init__(
        self, project: str | None = None, location: str | None = None, thinking: int = 0
    ) -> None:
        """Initializes the VertexAILLM client.

        Args:
            project: The Google Cloud project ID (defaults to settings.project_id).
            location: The Google Cloud location (defaults to settings.location).
            thinking: Thinking-token budget applied to every generation request.
        """
        self.client = genai.Client(
            vertexai=True,
            project=project or settings.project_id,
            location=location or settings.location,
        )
        self.thinking_budget = thinking

    def _generation_config(
        self, tools: list, system_prompt: str | None, tool_mode: str
    ) -> types.GenerateContentConfig:
        """Build the GenerateContentConfig shared by generate/async_generate."""
        return types.GenerateContentConfig(
            tools=tools,
            system_instruction=system_prompt,
            thinking_config=genai.types.ThinkingConfig(
                thinking_budget=self.thinking_budget
            ),
            tool_config=types.ToolConfig(
                function_calling_config=types.FunctionCallingConfig(
                    mode=tool_mode
                )
            ),
        )

    def generate(
        self,
        model: str,
        prompt: Any,
        # Mutable default `[]` replaced (and now matches BaseLLM's signature).
        tools: list | None = None,
        system_prompt: str | None = None,
        tool_mode: str = "AUTO",
    ) -> Generation:
        """Generates text using the specified model and prompt.

        Args:
            model: The name of the model to use for generation.
            prompt: The prompt to use for generation.
            tools: A list of tools to use for generation (None means no tools).
            system_prompt: An optional system prompt to guide the model's behavior.
            tool_mode: Function-calling mode for the tool config.

        Returns:
            A Generation object containing the generated text and usage metadata.
        """
        # Preserve the previous behavior of passing [] to the API by default.
        tools = [] if tools is None else tools
        logger.debug("Entering VertexAILLM.generate")
        logger.debug(f"Model: {model}, Tool Mode: {tool_mode}")
        logger.debug(f"System prompt: {system_prompt}")
        logger.debug("Calling Vertex AI API: models.generate_content...")
        response = self.client.models.generate_content(
            model=model,
            contents=prompt,
            config=self._generation_config(tools, system_prompt, tool_mode),
        )
        logger.debug("Received response from Vertex AI API.")
        logger.debug(f"API Response: {response}")

        return self._create_generation(response)

    def structured_generation(
        self,
        model: str,
        prompt: Any,
        response_model: Type[T],
        system_prompt: str | None = None,
        tools: list | None = None,
    ) -> T:
        """Generates structured data from a prompt.

        Args:
            model: The model to use for generation.
            prompt: The prompt to generate text from.
            response_model: The pydantic model to parse the response into.
            system_prompt: An optional system prompt to guide the model's behavior.
            tools: An optional list of tools to use for generation.

        Returns:
            An instance of the provided pydantic model.
        """
        config = genai.types.GenerateContentConfig(
            response_mime_type="application/json",
            response_schema=response_model,
            system_instruction=system_prompt,
            tools=tools,
        )

        response: genai.types.GenerateContentResponse = (
            self.client.models.generate_content(
                model=model, contents=prompt, config=config
            )
        )

        return response_model.model_validate_json(response.text)

    async def async_generate(
        self,
        model: str,
        prompt: Any,
        tools: list | None = None,
        system_prompt: str | None = None,
        tool_mode: str = "AUTO",
    ) -> Generation:
        """Asynchronously generates text using the specified model and prompt.

        See generate() for argument semantics.
        """
        tools = [] if tools is None else tools
        response = await self.client.aio.models.generate_content(
            model=model,
            contents=prompt,
            config=self._generation_config(tools, system_prompt, tool_mode),
        )

        return self._create_generation(response)

    def _create_generation(self, response) -> Generation:
        """Convert a raw API response into a Generation (text or tool calls)."""
        logger.debug("Creating Generation object from API response.")
        metadata = response.usage_metadata
        usage = Usage(
            prompt_tokens=metadata.prompt_token_count,
            thought_tokens=metadata.thoughts_token_count or 0,
            response_tokens=metadata.candidates_token_count,
        )

        logger.debug(f"{usage=}")
        logger.debug(f"{response=}")

        candidate = response.candidates[0]

        # Collect any function calls the model requested.
        tool_calls = [
            ToolCall(name=fn.name, arguments=fn.args)
            for part in candidate.content.parts
            if (fn := part.function_call)
        ]

        if len(tool_calls) > 0:
            logger.debug(f"Found {len(tool_calls)} tool calls.")
            return Generation(
                tool_calls=tool_calls,
                usage=usage,
                extra={"original_content": candidate.content},
            )

        logger.debug("No tool calls found, returning text response.")
        text = candidate.content.parts[0].text
        return Generation(text=text, usage=usage)
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
[project]
|
|
||||||
name = "utils"
|
|
||||||
version = "0.1.0"
|
|
||||||
description = "Utility scripts, including a filename normalization CLI"
|
|
||||||
readme = "README.md"
|
|
||||||
authors = [
|
|
||||||
{ name = "Anibal Angulo", email = "a8065384@banorte.com" }
|
|
||||||
]
|
|
||||||
requires-python = ">=3.12"
|
|
||||||
dependencies = []
|
|
||||||
|
|
||||||
[project.scripts]
|
|
||||||
normalize-filenames = "utils.normalize_filenames:app"
|
|
||||||
|
|
||||||
[build-system]
|
|
||||||
requires = ["uv_build>=0.8.3,<0.9.0"]
|
|
||||||
build-backend = "uv_build"
|
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
def hello() -> str:
    """Return the package greeting string."""
    message = "Hello from utils!"
    return message
|
|
||||||
@@ -1,115 +0,0 @@
|
|||||||
"""Normalize filenames in a directory."""
|
|
||||||
|
|
||||||
import pathlib
|
|
||||||
import re
|
|
||||||
import unicodedata
|
|
||||||
|
|
||||||
import typer
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.table import Table
|
|
||||||
|
|
||||||
app = typer.Typer()
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_string(s: str) -> str:
    """Normalizes a string to be a valid filename.

    Diacritics are stripped, the result is lowercased, whitespace runs
    become underscores, and any remaining character outside [a-z0-9_.-]
    is removed.
    """
    # Decompose Unicode so diacritics become separate combining marks,
    # then drop the marks, keeping only base characters.
    decomposed = unicodedata.normalize("NFKD", s)
    base_chars = "".join(c for c in decomposed if not unicodedata.combining(c))
    lowered = base_chars.lower()
    underscored = re.sub(r"\s+", "_", lowered)
    return re.sub(r"[^a-z0-9_.-]", "", underscored)
|
|
||||||
|
|
||||||
|
|
||||||
def truncate_string(s: str) -> str:
    """Return only the final path component (the text after the last /)."""
    as_path = pathlib.Path(s)
    return as_path.name
|
|
||||||
|
|
||||||
|
|
||||||
def remove_extension(s: str) -> str:
    """Strip a trailing extension (e.g. .pdf) from a path-like string."""
    without_suffix = pathlib.Path(s).with_suffix("")
    return str(without_suffix)
|
|
||||||
|
|
||||||
|
|
||||||
def remove_duplicate_vowels(s: str) -> str:
    """Collapse consecutive repeats of the same vowel into one occurrence."""
    repeated_vowel = r"([aeiou])\1+"
    return re.sub(repeated_vowel, r"\1", s, flags=re.IGNORECASE)
|
|
||||||
|
|
||||||
|
|
||||||
@app.callback(invoke_without_command=True)
def normalize_filenames(
    directory: str = typer.Argument(
        ..., help="The path to the directory containing files to normalize."
    ),
):
    """Normalizes all filenames in a directory."""
    console = Console()
    console.print(
        Panel(
            f"Normalizing filenames in directory: [bold cyan]{directory}[/bold cyan]",
            title="[bold green]Filename Normalizer[/bold green]",
            expand=False,
        )
    )

    source_path = pathlib.Path(directory)
    if not source_path.is_dir():
        console.print(f"[bold red]Error: Directory not found at {directory}[/bold red]")
        raise typer.Exit(code=1)

    # Collect every regular file, recursively.
    files_to_rename = [p for p in source_path.rglob("*") if p.is_file()]

    if not files_to_rename:
        console.print(
            f"[bold yellow]No files found in {directory} to normalize.[/bold yellow]"
        )
        return

    table = Table(title="File Renaming Summary")
    table.add_column("Original Name", style="cyan", no_wrap=True)
    table.add_column("New Name", style="magenta", no_wrap=True)
    table.add_column("Status", style="green")

    for file_path in files_to_rename:
        original_name = file_path.name
        file_stem = file_path.stem
        file_suffix = file_path.suffix

        # Only the stem is normalized; the extension is kept verbatim.
        normalized_stem = normalize_string(file_stem)
        new_name = f"{normalized_stem}{file_suffix}"

        if new_name == original_name:
            table.add_row(
                original_name, new_name, "[yellow]Skipped (No change)[/yellow]"
            )
            continue

        new_path = file_path.with_name(new_name)

        # Handle potential name collisions by appending _1, _2, ... until the
        # target name is free; the table shows the final name actually used.
        counter = 1
        while new_path.exists():
            new_name = f"{normalized_stem}_{counter}{file_suffix}"
            new_path = file_path.with_name(new_name)
            counter += 1

        try:
            file_path.rename(new_path)
            table.add_row(original_name, new_name, "[green]Renamed[/green]")
        except OSError as e:
            table.add_row(original_name, new_name, f"[bold red]Error: {e}[/bold red]")

    console.print(table)
    console.print(
        Panel(
            f"[bold]Normalization complete.[/bold] Processed [bold blue]{len(files_to_rename)}[/bold blue] files.",
            title="[bold green]Complete[/bold green]",
            expand=False,
        )
    )
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
3.10
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
[project]
|
|
||||||
name = "vector-search"
|
|
||||||
version = "0.1.0"
|
|
||||||
description = "Vertex AI Vector Search index management and query CLI"
|
|
||||||
readme = "README.md"
|
|
||||||
authors = [
|
|
||||||
{ name = "Anibal Angulo", email = "a8065384@banorte.com" }
|
|
||||||
]
|
|
||||||
requires-python = ">=3.12"
|
|
||||||
dependencies = [
|
|
||||||
"embedder",
|
|
||||||
"file-storage",
|
|
||||||
"google-cloud-aiplatform>=1.106.0",
|
|
||||||
"aiohttp>=3.10.11,<4",
|
|
||||||
"gcloud-aio-auth>=5.3.0",
|
|
||||||
"google-auth==2.29.0",
|
|
||||||
"typer>=0.16.1",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.scripts]
|
|
||||||
vector-search = "vector_search.cli:app"
|
|
||||||
|
|
||||||
[build-system]
|
|
||||||
requires = ["uv_build>=0.8.3,<0.9.0"]
|
|
||||||
build-backend = "uv_build"
|
|
||||||
|
|
||||||
[tool.uv.sources]
|
|
||||||
file-storage = { workspace = true }
|
|
||||||
embedder = { workspace = true }
|
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
def hello() -> str:
    """Return the package greeting string."""
    message = "Hello from vector-search!"
    return message
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import List, TypedDict
|
|
||||||
|
|
||||||
|
|
||||||
class SearchResult(TypedDict):
    """One nearest-neighbor match returned by a vector search query."""

    id: str
    distance: float
    content: str


class BaseVectorSearch(ABC):
    """
    Abstract base class for a vector search provider.

    Defines the standard interface for creating and updating a vector
    search index and running similarity queries against it.
    """

    @abstractmethod
    def create_index(self, name: str, content_path: str, **kwargs) -> None:
        """
        Creates a new vector search index and populates it with the provided content.

        Args:
            name: The desired name for the new index.
            content_path: Local path to the data used to populate the index —
                expected to be a JSON file containing a list of objects, each
                with an 'id', 'name', and 'embedding' key.
            **kwargs: Additional provider-specific arguments for index creation.
        """
        ...

    @abstractmethod
    def update_index(self, index_name: str, content_path: str, **kwargs) -> None:
        """
        Updates an existing vector search index with new content.

        Args:
            index_name: The name of the index to update.
            content_path: Local path to the data used to populate the index.
            **kwargs: Additional provider-specific arguments for index update.
        """
        ...

    @abstractmethod
    def run_query(
        self, index: str, query: List[float], limit: int
    ) -> List[SearchResult]:
        """
        Runs a similarity search query against the index.

        Args:
            index: The name of the index to query.
            query: The embedding vector to use for the search query.
            limit: The maximum number of nearest neighbors to return.

        Returns:
            Matched items, each containing at least the item's 'id' and the
            search 'distance'.
        """
        ...
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
from typer import Typer

from .create import app as create_callback
from .delete import app as delete_callback
from .query import app as query_callback

# Top-level CLI that aggregates the create/delete/query sub-commands.
app = Typer()
for _name, _sub_app in (
    ("create", create_callback),
    ("delete", delete_callback),
    ("query", query_callback),
):
    app.add_typer(_sub_app, name=_name)
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
"""Create and deploy a Vertex AI Vector Search index."""
|
|
||||||
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
import typer
|
|
||||||
from rich.console import Console
|
|
||||||
|
|
||||||
from rag_eval.config import settings as config
|
|
||||||
from vector_search.vertex_ai import GoogleCloudVectorSearch
|
|
||||||
|
|
||||||
app = typer.Typer()
|
|
||||||
|
|
||||||
|
|
||||||
@app.callback(invoke_without_command=True)
def create(
    path: Annotated[
        str,
        typer.Option(
            "--path",
            "-p",
            help="The GCS URI (gs://...) to the directory containing your embedding JSON file(s).",
        ),
    ],
    agent_name: Annotated[
        str,
        typer.Option(
            "--agent",
            "-a",
            help="The name of the agent to create the index for.",
        ),
    ],
):
    """Create and deploy a Vertex AI Vector Search index for a specific agent."""
    console = Console()

    try:
        console.print(
            f"[bold green]Looking up configuration for agent '{agent_name}'...[/bold green]"
        )
        agent_config = config.agents.get(agent_name)
        if not agent_config:
            console.print(
                f"[bold red]Agent '{agent_name}' not found in settings.[/bold red]"
            )
            raise typer.Exit(code=1)

        if not agent_config.index:
            console.print(
                f"[bold red]Index configuration not found for agent '{agent_name}'.[/bold red]"
            )
            raise typer.Exit(code=1)

        index_config = agent_config.index

        console.print(
            f"[bold green]Initializing Vertex AI client for project '{config.project_id}' in '{config.location}'...[/bold green]"
        )
        vector_search = GoogleCloudVectorSearch(
            project_id=config.project_id,
            location=config.location,
            bucket=config.bucket,
            index_name=index_config.name,
        )

        # The help text advertises a full GCS URI; only prepend the configured
        # bucket when the caller passed a bucket-relative path, so a full
        # "gs://..." argument is no longer double-prefixed.
        content_path = (
            path if path.startswith("gs://") else f"gs://{config.bucket}/{path}"
        )

        console.print(
            f"[bold green]Starting creation of index '{index_config.name}'...[/bold green]"
        )
        console.print("This may take a while.")
        vector_search.create_index(
            name=index_config.name,
            content_path=content_path,
            dimensions=index_config.dimensions,
        )
        console.print(
            f"[bold green]Index '{index_config.name}' created successfully.[/bold green]"
        )

        console.print("[bold green]Deploying index to a new endpoint...[/bold green]")
        console.print("This will also take some time.")
        vector_search.deploy_index(
            index_name=index_config.name, machine_type=index_config.machine_type
        )
        console.print("[bold green]Index deployed successfully![/bold green]")
        console.print(f"Endpoint name: {vector_search.index_endpoint.display_name}")
        console.print(
            f"Endpoint resource name: {vector_search.index_endpoint.resource_name}"
        )

    except typer.Exit:
        # typer.Exit subclasses Exception (via click's Exit/RuntimeError), so the
        # deliberate early exits above would otherwise be caught by the generic
        # handler and misreported as "An error occurred". Re-raise unchanged.
        raise
    except Exception as e:
        console.print(f"[bold red]An error occurred: {e}[/bold red]")
        raise typer.Exit(code=1)
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
"""Delete a vector index or endpoint."""
|
|
||||||
|
|
||||||
import typer
|
|
||||||
from rich.console import Console
|
|
||||||
|
|
||||||
from rag_eval.config import settings as config
|
|
||||||
from vector_search.vertex_ai import GoogleCloudVectorSearch
|
|
||||||
|
|
||||||
app = typer.Typer()
|
|
||||||
|
|
||||||
|
|
||||||
@app.callback(invoke_without_command=True)
def delete(
    id: str = typer.Argument(..., help="The ID of the index or endpoint to delete."),
    endpoint: bool = typer.Option(
        False, "--endpoint", help="Delete an endpoint instead of an index."
    ),
):
    """Delete a vector index or endpoint."""
    console = Console()
    vector_search = GoogleCloudVectorSearch(
        project_id=config.project_id, location=config.location, bucket=config.bucket
    )

    # Select the operation and wording once; the two branches are otherwise identical.
    resource_kind = "endpoint" if endpoint else "index"
    try:
        console.print(f"[bold red]Deleting {resource_kind} {id}...[/bold red]")
        if endpoint:
            vector_search.delete_index_endpoint(id)
        else:
            vector_search.delete_index(id)
        console.print(
            f"[bold green]{resource_kind.capitalize()} {id} deleted successfully.[/bold green]"
        )
    except Exception as e:
        console.print(f"[bold red]An error occurred: {e}[/bold red]")
        raise typer.Exit(code=1)
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
"""Generate embeddings for documents and save them to a JSON file."""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import typer
|
|
||||||
from embedder.vertex_ai import VertexAIEmbedder
|
|
||||||
from file_storage.google_cloud import GoogleCloudFileStorage
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.progress import Progress
|
|
||||||
|
|
||||||
from rag_eval.config import Settings
|
|
||||||
|
|
||||||
app = typer.Typer()
|
|
||||||
|
|
||||||
|
|
||||||
@app.callback(invoke_without_command=True)
def generate(
    path: str = typer.Argument(..., help="The path to the markdown files."),
    output_file: str = typer.Option(
        ...,
        "--output-file",
        "-o",
        help="The local path to save the output JSON file.",
    ),
    batch_size: int = typer.Option(
        10,
        "--batch-size",
        "-b",
        help="The batch size for processing files.",
    ),
    jsonl: bool = typer.Option(
        False,
        "--jsonl",
        help="Output in JSONL format instead of JSON.",
    ),
):
    """Generate embeddings for documents and save them to a JSON file."""
    config = Settings()
    console = Console()

    console.print("[bold green]Starting vector generation...[/bold green]")

    try:
        storage = GoogleCloudFileStorage(bucket=config.bucket)
        embedder = VertexAIEmbedder(model_name=config.embedding_model)

        remote_files = storage.list_files(path=path)
        results = []

        with Progress(console=console) as progress:
            task = progress.add_task(
                "[cyan]Generating embeddings...", total=len(remote_files)
            )

            # Process files in fixed-size batches to bound memory and API calls.
            for start in range(0, len(remote_files), batch_size):
                batch_files = remote_files[start : start + batch_size]
                batch_contents = [
                    storage.get_file_stream(remote_file)
                    .read()
                    .decode("utf-8-sig", errors="replace")
                    for remote_file in batch_files
                ]

                batch_embeddings = embedder.generate_embeddings_batch(batch_contents)

                for remote_file, embedding in zip(batch_files, batch_embeddings):
                    results.append({"id": remote_file, "embedding": embedding})
                    progress.update(task, advance=1)

    except Exception as e:
        console.print(
            f"[bold red]An error occurred during vector generation: {e}[/bold red]"
        )
        raise typer.Exit(code=1)

    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        if jsonl:
            for record in results:
                f.write(json.dumps(record) + "\n")
        else:
            json.dump(results, f, indent=2)

    console.print(
        f"[bold green]Embedding generation complete. {len(results)} vectors saved to '{output_path.resolve()}'[/bold green]"
    )
|
|
||||||
@@ -1,55 +0,0 @@
|
|||||||
"""Query the vector search index."""
|
|
||||||
|
|
||||||
import typer
|
|
||||||
from embedder.vertex_ai import VertexAIEmbedder
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.table import Table
|
|
||||||
from typer import Argument, Option
|
|
||||||
|
|
||||||
from rag_eval.config import settings as config
|
|
||||||
from vector_search.vertex_ai import GoogleCloudVectorSearch
|
|
||||||
|
|
||||||
app = typer.Typer()
|
|
||||||
|
|
||||||
|
|
||||||
@app.callback(invoke_without_command=True)
def query(
    query: str = Argument(..., help="The text query to search for."),
    limit: int = Option(5, "--limit", "-l", help="The number of results to return."),
):
    """Queries the vector search index."""
    console = Console()

    try:
        console.print("[bold green]Initializing clients...[/bold green]")
        embedder = VertexAIEmbedder(model_name=config.embedding_model)
        vector_search = GoogleCloudVectorSearch(
            project_id=config.project_id, location=config.location, bucket=config.bucket
        )

        console.print("[bold green]Loading index endpoint...[/bold green]")
        vector_search.load_index_endpoint(config.index.endpoint)

        console.print("[bold green]Generating embedding for query...[/bold green]")
        query_embedding = embedder.generate_embedding(query)

        console.print("[bold green]Running search query...[/bold green]")
        search_results = vector_search.run_query(
            deployed_index_id=config.index.deployment,
            query=query_embedding,
            limit=limit,
        )

        # Render matches as a three-column table: id, distance, chunk content.
        table = Table(title="Search Results")
        for heading, style in (
            ("ID", "cyan"),
            ("Distance", "magenta"),
            ("Content", "green"),
        ):
            table.add_column(heading, justify="left", style=style)

        for result in search_results:
            table.add_row(result["id"], str(result["distance"]), result["content"])

        console.print(table)

    except Exception as e:
        console.print(f"[bold red]An error occurred: {e}[/bold red]")
        raise typer.Exit(code=1)
|
|
||||||
@@ -1,255 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
from typing import List
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
import google.auth
|
|
||||||
import google.auth.transport.requests
|
|
||||||
from file_storage.google_cloud import GoogleCloudFileStorage
|
|
||||||
from gcloud.aio.auth import Token
|
|
||||||
from google.cloud import aiplatform
|
|
||||||
|
|
||||||
from .base import BaseVectorSearch, SearchResult
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleCloudVectorSearch(BaseVectorSearch):
    """
    A vector search provider that uses Google Cloud's Vertex AI Vector Search.

    Index management goes through the `aiplatform` SDK; queries can run either
    synchronously via the SDK or asynchronously via the public REST endpoint.
    Matched chunk contents are fetched from GCS through the file storage layer.
    """

    def __init__(
        self, project_id: str, location: str, bucket: str, index_name: str | None = None
    ):
        """
        Initializes the GoogleCloudVectorSearch client.

        Args:
            project_id: The Google Cloud project ID.
            location: The Google Cloud location (e.g., 'us-central1').
            bucket: The GCS bucket to use for file storage.
            index_name: The name of the index. If None, it will be taken from settings.
        """
        aiplatform.init(project=project_id, location=location)
        self.project_id = project_id
        self.location = location
        self.storage = GoogleCloudFileStorage(bucket=bucket)
        self.index_name = index_name
        # Sync credentials and the async HTTP session/token are created lazily
        # on first use so construction stays cheap and side-effect free.
        self._credentials = None
        self._aio_session: aiohttp.ClientSession | None = None
        self._async_token: Token | None = None

    def _get_auth_headers(self) -> dict:
        """Return Bearer-token headers for synchronous REST calls, refreshing the token when expired."""
        if self._credentials is None:
            self._credentials, _ = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"]
            )
        if not self._credentials.token or self._credentials.expired:
            self._credentials.refresh(google.auth.transport.requests.Request())
        return {
            "Authorization": f"Bearer {self._credentials.token}",
            "Content-Type": "application/json",
        }

    async def _async_get_auth_headers(self) -> dict:
        """Return Bearer-token headers for async REST calls, creating the token lazily."""
        if self._async_token is None:
            self._async_token = Token(
                session=self._get_aio_session(),
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
        access_token = await self._async_token.get()
        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def _get_aio_session(self) -> aiohttp.ClientSession:
        """Return a shared aiohttp session, recreating it if it was closed.

        NOTE(review): the session is never explicitly closed by this class;
        presumably the owning process lifetime is the intended scope — confirm.
        """
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(limit=300, limit_per_host=50)
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout, connector=connector
            )
        return self._aio_session

    def create_index(
        self,
        name: str,
        content_path: str,
        dimensions: int,
        approximate_neighbors_count: int = 150,
        distance_measure_type: str = "DOT_PRODUCT_DISTANCE",
        **kwargs,
    ) -> None:
        """
        Creates a new Vertex AI Vector Search index.

        The created index is stored on `self.index` for a subsequent
        `deploy_index` call.

        Args:
            name: The display name for the new index.
            content_path: The GCS URI to the JSON file containing the embeddings.
            dimensions: The number of dimensions in the embedding vectors.
            approximate_neighbors_count: The number of neighbors to find for each vector.
            distance_measure_type: The distance measure to use (e.g., 'DOT_PRODUCT_DISTANCE').
        """
        index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
            display_name=name,
            contents_delta_uri=content_path,
            dimensions=dimensions,
            approximate_neighbors_count=approximate_neighbors_count,
            distance_measure_type=distance_measure_type,
            leaf_node_embedding_count=1000,
            leaf_nodes_to_search_percent=10,
        )
        self.index = index

    def update_index(self, index_name: str, content_path: str, **kwargs) -> None:
        """
        Updates an existing Vertex AI Vector Search index.

        Args:
            index_name: The resource name of the index to update.
            content_path: The GCS URI to the JSON file containing the new embeddings.
        """
        index = aiplatform.MatchingEngineIndex(index_name=index_name)
        index.update_embeddings(
            contents_delta_uri=content_path,
        )
        self.index = index

    def deploy_index(
        self, index_name: str, machine_type: str = "e2-standard-2"
    ) -> None:
        """
        Deploys a Vertex AI Vector Search index to a newly created endpoint.

        Requires `create_index` or `update_index` to have been called first,
        since the deployed index object is taken from `self.index`.

        Args:
            index_name: The name of the index to deploy (used for display names).
            machine_type: The type of machine to use for the endpoint.

        Raises:
            ValueError: If no index has been created or updated on this instance.
        """
        # Guard against the previous cryptic AttributeError when deploy_index
        # was called before create_index/update_index.
        if getattr(self, "index", None) is None:
            raise ValueError(
                "No index available; call create_index() or update_index() first."
            )
        index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
            display_name=f"{index_name}-endpoint",
            public_endpoint_enabled=True,
        )
        index_endpoint.deploy_index(
            index=self.index,
            # Deployed index IDs must be valid identifiers; a uuid suffix keeps
            # redeployments unique.
            deployed_index_id=f"{index_name.replace('-', '_')}_deployed_{uuid4().hex}",
            machine_type=machine_type,
        )
        self.index_endpoint = index_endpoint

    def load_index_endpoint(self, endpoint_name: str) -> None:
        """
        Loads an existing Vertex AI Vector Search index endpoint.

        Args:
            endpoint_name: The resource name of the index endpoint.

        Raises:
            ValueError: If the endpoint has no public domain (required by
                `async_run_query`, which calls the public REST endpoint).
        """
        self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(endpoint_name)
        if not self.index_endpoint.public_endpoint_domain_name:
            raise ValueError(
                "The index endpoint does not have a public endpoint. "
                "Please ensure that the endpoint is configured for public access."
            )

    def run_query(
        self, deployed_index_id: str, query: List[float], limit: int
    ) -> List[SearchResult]:
        """
        Runs a similarity search query against the deployed index.

        NOTE(review): this signature takes `deployed_index_id` where the base
        class declares `index`; callers use keyword arguments, so renaming
        would break them — confirm and reconcile with BaseVectorSearch.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector to use for the search query.
            limit: The maximum number of nearest neighbors to return.

        Returns:
            A list of dictionaries with 'id', 'distance', and the chunk
            'content' loaded from GCS.
        """
        response = self.index_endpoint.find_neighbors(
            deployed_index_id=deployed_index_id, queries=[query], num_neighbors=limit
        )
        results = []
        for neighbor in response[0]:
            # Chunk contents live at <index_name>/contents/<id>.md in the bucket.
            file_path = self.index_name + "/contents/" + neighbor.id + ".md"
            content = self.storage.get_file_stream(file_path).read().decode("utf-8")
            results.append(
                {"id": neighbor.id, "distance": neighbor.distance, "content": content}
            )
        return results

    async def async_run_query(
        self, deployed_index_id: str, query: List[float], limit: int
    ) -> List[SearchResult]:
        """
        Runs a non-blocking similarity search query against the deployed index
        using the REST API directly with an async HTTP client.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector to use for the search query.
            limit: The maximum number of nearest neighbors to return.

        Returns:
            A list of dictionaries with 'id', 'distance', and the chunk
            'content' loaded from GCS.
        """
        domain = self.index_endpoint.public_endpoint_domain_name
        endpoint_id = self.index_endpoint.name.split("/")[-1]
        url = (
            f"https://{domain}/v1/projects/{self.project_id}"
            f"/locations/{self.location}"
            f"/indexEndpoints/{endpoint_id}:findNeighbors"
        )
        payload = {
            "deployed_index_id": deployed_index_id,
            "queries": [
                {
                    "datapoint": {"feature_vector": query},
                    "neighbor_count": limit,
                }
            ],
        }

        headers = await self._async_get_auth_headers()
        session = self._get_aio_session()
        async with session.post(url, json=payload, headers=headers) as response:
            response.raise_for_status()
            data = await response.json()

        neighbors = data.get("nearestNeighbors", [{}])[0].get("neighbors", [])
        # Fetch all chunk contents concurrently, preserving neighbor order.
        content_tasks = []
        for neighbor in neighbors:
            datapoint_id = neighbor["datapoint"]["datapointId"]
            file_path = f"{self.index_name}/contents/{datapoint_id}.md"
            content_tasks.append(self.storage.async_get_file_stream(file_path))

        file_streams = await asyncio.gather(*content_tasks)
        results: List[SearchResult] = []
        for neighbor, stream in zip(neighbors, file_streams):
            results.append(
                {
                    "id": neighbor["datapoint"]["datapointId"],
                    "distance": neighbor["distance"],
                    "content": stream.read().decode("utf-8"),
                }
            )
        return results

    def delete_index(self, index_name: str) -> None:
        """
        Deletes a Vertex AI Vector Search index.

        Args:
            index_name: The resource name of the index.
        """
        index = aiplatform.MatchingEngineIndex(index_name)
        index.delete()

    def delete_index_endpoint(self, index_endpoint_name: str) -> None:
        """
        Deletes a Vertex AI Vector Search index endpoint after undeploying
        all indexes from it.

        Args:
            index_endpoint_name: The resource name of the index endpoint.
        """
        index_endpoint = aiplatform.MatchingEngineIndexEndpoint(index_endpoint_name)
        index_endpoint.undeploy_all()
        index_endpoint.delete(force=True)
|
|
||||||
@@ -1,91 +1,57 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "rag-pipeline"
|
name = "knowledge-pipeline"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
description = "RAG Pipeline for document chunking, embedding, and vector search"
|
description = "RAG Pipeline for document chunking, embedding, and vector search"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Pipeline Team" }
|
{ name = "Anibal Angulo", email = "A8065384@banorte.com" }
|
||||||
]
|
]
|
||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
# Core dependencies
|
# Core dependencies
|
||||||
"google-genai>=1.45.0",
|
|
||||||
"google-cloud-aiplatform>=1.106.0",
|
"google-cloud-aiplatform>=1.106.0",
|
||||||
"google-cloud-storage>=2.19.0",
|
"google-cloud-storage>=2.19.0",
|
||||||
"google-auth>=2.29.0",
|
|
||||||
"pydantic>=2.11.7",
|
"pydantic>=2.11.7",
|
||||||
"pydantic-settings[yaml]>=2.10.1",
|
"pydantic-settings[yaml]>=2.10.1",
|
||||||
"python-dotenv>=1.0.0",
|
|
||||||
|
|
||||||
# Chunking
|
# Chunking
|
||||||
"chonkie>=1.1.2",
|
"chonkie>=1.1.2",
|
||||||
"tiktoken>=0.7.0",
|
"tiktoken>=0.7.0",
|
||||||
"langchain>=0.3.0",
|
"langchain>=0.3.0",
|
||||||
"langchain-core>=0.3.0",
|
"langchain-core>=0.3.0",
|
||||||
|
|
||||||
# Document processing
|
# Document processing
|
||||||
"markitdown[pdf]>=0.1.2",
|
"markitdown[pdf]>=0.1.2",
|
||||||
"pypdf>=6.1.2",
|
"pypdf>=6.1.2",
|
||||||
"pdf2image>=1.17.0",
|
"pdf2image>=1.17.0",
|
||||||
|
|
||||||
# Storage & networking
|
|
||||||
"gcloud-aio-storage>=9.6.1",
|
|
||||||
"gcloud-aio-auth>=5.3.0",
|
|
||||||
"aiohttp>=3.10.11,<4",
|
|
||||||
|
|
||||||
# Utils
|
# Utils
|
||||||
"tenacity>=9.1.2",
|
|
||||||
"typer>=0.16.1",
|
"typer>=0.16.1",
|
||||||
|
"pydantic-ai>=0.0.5",
|
||||||
# Pipeline orchestration (optional)
|
|
||||||
"kfp>=2.15.2",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
# Chunkers
|
knowledge-pipeline = "knowledge_pipeline.cli:app"
|
||||||
llm-chunker = "chunker.llm_chunker:app"
|
|
||||||
recursive-chunker = "chunker.recursive_chunker:app"
|
|
||||||
contextual-chunker = "chunker.contextual_chunker:app"
|
|
||||||
|
|
||||||
# Converters
|
|
||||||
convert-md = "document_converter.markdown:app"
|
|
||||||
|
|
||||||
# Storage
|
|
||||||
file-storage = "file_storage.cli:app"
|
|
||||||
|
|
||||||
# Vector Search
|
|
||||||
vector-search = "vector_search.cli:app"
|
|
||||||
|
|
||||||
# Utils
|
|
||||||
normalize-filenames = "utils.normalize_filenames:app"
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["uv_build>=0.8.3,<0.9.0"]
|
requires = ["uv_build>=0.8.3,<0.9.0"]
|
||||||
build-backend = "uv_build"
|
build-backend = "uv_build"
|
||||||
|
|
||||||
[tool.uv.workspace]
|
|
||||||
members = [
|
|
||||||
"apps/*",
|
|
||||||
"packages/*",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.uv.sources]
|
|
||||||
chunker = { workspace = true }
|
|
||||||
document-converter = { workspace = true }
|
|
||||||
embedder = { workspace = true }
|
|
||||||
file-storage = { workspace = true }
|
|
||||||
llm = { workspace = true }
|
|
||||||
utils = { workspace = true }
|
|
||||||
vector-search = { workspace = true }
|
|
||||||
index-gen = { workspace = true }
|
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
"pytest>=8.4.1",
|
"pytest>=8.4.1",
|
||||||
"mypy>=1.17.1",
|
|
||||||
"ruff>=0.12.10",
|
"ruff>=0.12.10",
|
||||||
|
"ty>=0.0.18",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
extend-select = ["I", "F"]
|
select = ["I", "F"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
addopts = [
|
||||||
|
"--strict-markers",
|
||||||
|
"--tb=short",
|
||||||
|
"--disable-warnings",
|
||||||
|
]
|
||||||
|
markers = [
|
||||||
|
"unit: Unit tests",
|
||||||
|
"integration: Integration tests",
|
||||||
|
"slow: Slow running tests",
|
||||||
|
]
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ class Document(TypedDict):
|
|||||||
class BaseChunker(ABC):
|
class BaseChunker(ABC):
|
||||||
"""Abstract base class for chunker implementations."""
|
"""Abstract base class for chunker implementations."""
|
||||||
|
|
||||||
|
max_chunk_size: int
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def process_text(self, text: str) -> List[Document]:
|
def process_text(self, text: str) -> List[Document]:
|
||||||
"""
|
"""
|
||||||
@@ -1,11 +1,3 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Annotated, List
|
|
||||||
|
|
||||||
import typer
|
|
||||||
from llm.vertex_ai import VertexAILLM
|
|
||||||
|
|
||||||
from .base_chunker import BaseChunker, Document
|
from .base_chunker import BaseChunker, Document
|
||||||
|
|
||||||
|
|
||||||
@@ -16,23 +8,13 @@ class ContextualChunker(BaseChunker):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
llm_client: VertexAILLM,
|
model: str = "google-vertex:gemini-2.0-flash",
|
||||||
max_chunk_size: int = 800,
|
max_chunk_size: int = 800,
|
||||||
model: str = "gemini-2.0-flash",
|
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
Initializes the ContextualChunker.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
max_chunk_size: The maximum length of a chunk in characters.
|
|
||||||
model: The name of the language model to use.
|
|
||||||
llm_client: An optional instance of a language model client.
|
|
||||||
"""
|
|
||||||
self.max_chunk_size = max_chunk_size
|
self.max_chunk_size = max_chunk_size
|
||||||
self.model = model
|
self.model = model
|
||||||
self.llm_client = llm_client
|
|
||||||
|
|
||||||
def _split_text(self, text: str) -> List[str]:
|
def _split_text(self, text: str) -> list[str]:
|
||||||
"""Splits text into evenly sized chunks of a maximum size, trying to respect sentence and paragraph boundaries."""
|
"""Splits text into evenly sized chunks of a maximum size, trying to respect sentence and paragraph boundaries."""
|
||||||
import math
|
import math
|
||||||
|
|
||||||
@@ -67,7 +49,7 @@ class ContextualChunker(BaseChunker):
|
|||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
def process_text(self, text: str) -> List[Document]:
|
def process_text(self, text: str) -> list[Document]:
|
||||||
"""
|
"""
|
||||||
Processes a string of text into a list of context-aware Document chunks.
|
Processes a string of text into a list of context-aware Document chunks.
|
||||||
"""
|
"""
|
||||||
@@ -75,7 +57,7 @@ class ContextualChunker(BaseChunker):
|
|||||||
return [{"page_content": text, "metadata": {}}]
|
return [{"page_content": text, "metadata": {}}]
|
||||||
|
|
||||||
chunks = self._split_text(text)
|
chunks = self._split_text(text)
|
||||||
processed_chunks: List[Document] = []
|
processed_chunks: list[Document] = []
|
||||||
|
|
||||||
for i, chunk_content in enumerate(chunks):
|
for i, chunk_content in enumerate(chunks):
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
@@ -93,7 +75,14 @@ class ContextualChunker(BaseChunker):
|
|||||||
Genera un resumen conciso del "Documento Original" que proporcione el contexto necesario para entender el "Fragmento Actual". El resumen debe ser un solo párrafo en español.
|
Genera un resumen conciso del "Documento Original" que proporcione el contexto necesario para entender el "Fragmento Actual". El resumen debe ser un solo párrafo en español.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
summary = self.llm_client.generate(self.model, prompt).text
|
from pydantic_ai import ModelRequest
|
||||||
|
from pydantic_ai.direct import model_request_sync
|
||||||
|
|
||||||
|
response = model_request_sync(
|
||||||
|
self.model,
|
||||||
|
[ModelRequest.user_text_prompt(prompt)],
|
||||||
|
)
|
||||||
|
summary = next(p.content for p in response.parts if p.part_kind == "text")
|
||||||
contextualized_chunk = (
|
contextualized_chunk = (
|
||||||
f"> **Contexto del documento original:**\n> {summary}\n\n---\n\n"
|
f"> **Contexto del documento original:**\n> {summary}\n\n---\n\n"
|
||||||
+ chunk_content
|
+ chunk_content
|
||||||
@@ -107,49 +96,3 @@ class ContextualChunker(BaseChunker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return processed_chunks
|
return processed_chunks
|
||||||
|
|
||||||
|
|
||||||
app = typer.Typer()
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
def main(
    input_file_path: Annotated[
        str, typer.Argument(help="Path to the input text file.")
    ],
    output_dir: Annotated[
        str, typer.Argument(help="Directory to save the output file.")
    ],
    max_chunk_size: Annotated[
        int, typer.Option(help="Maximum chunk size in characters.")
    ] = 800,
    model: Annotated[
        str, typer.Option(help="Model to use for the processing")
    ] = "gemini-2.0-flash",
):
    """
    Processes a text file using ContextualChunker and saves the output to a JSONL file.
    """
    print(f"Starting to process {input_file_path}...")

    chunker = ContextualChunker(max_chunk_size=max_chunk_size, model=model)
    documents = chunker.process_path(Path(input_file_path))

    print(f"Successfully created {len(documents)} chunks.")

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"Created output directory: {output_dir}")

    output_file_path = os.path.join(output_dir, "chunked_documents.jsonl")

    # Tag every chunk with its source file before writing one JSON object per line.
    source_name = os.path.basename(input_file_path)
    with open(output_file_path, "w", encoding="utf-8") as f:
        for doc in documents:
            doc["metadata"]["source_file"] = source_name
            f.write(json.dumps(doc, ensure_ascii=False) + "\n")

    print(f"Successfully saved {len(documents)} chunks to {output_file_path}")


if __name__ == "__main__":
    app()
|
|
||||||
@@ -13,7 +13,6 @@ from langchain_core.documents import Document as LangchainDocument
|
|||||||
from llm.vertex_ai import VertexAILLM
|
from llm.vertex_ai import VertexAILLM
|
||||||
from pdf2image import convert_from_path
|
from pdf2image import convert_from_path
|
||||||
from pypdf import PdfReader
|
from pypdf import PdfReader
|
||||||
|
|
||||||
from rag_eval.config import Settings
|
from rag_eval.config import Settings
|
||||||
|
|
||||||
from .base_chunker import BaseChunker, Document
|
from .base_chunker import BaseChunker, Document
|
||||||
20
src/knowledge_pipeline/cli.py
Normal file
20
src/knowledge_pipeline/cli.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import typer
|
||||||
|
|
||||||
|
from .config import Settings
|
||||||
|
from .pipeline import run_pipeline
|
||||||
|
|
||||||
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def run_ingestion():
|
||||||
|
"""Main function for the CLI script."""
|
||||||
|
settings = Settings.model_validate({})
|
||||||
|
logging.getLogger("knowledge_pipeline").setLevel(getattr(logging, settings.log_level))
|
||||||
|
run_pipeline(settings)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app()
|
||||||
101
src/knowledge_pipeline/config.py
Normal file
101
src/knowledge_pipeline/config.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
import os
|
||||||
|
from functools import cached_property
|
||||||
|
|
||||||
|
from google.cloud.aiplatform.matching_engine.matching_engine_index_config import (
|
||||||
|
DistanceMeasureType,
|
||||||
|
)
|
||||||
|
from pydantic_settings import (
|
||||||
|
BaseSettings,
|
||||||
|
PydanticBaseSettingsSource,
|
||||||
|
SettingsConfigDict,
|
||||||
|
YamlConfigSettingsSource,
|
||||||
|
)
|
||||||
|
|
||||||
|
CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
project_id: str
|
||||||
|
location: str
|
||||||
|
log_level: str = "INFO"
|
||||||
|
|
||||||
|
agent_embedding_model: str
|
||||||
|
|
||||||
|
index_name: str
|
||||||
|
index_dimensions: int
|
||||||
|
index_machine_type: str = "e2-standard-16"
|
||||||
|
index_origin: str
|
||||||
|
index_destination: str
|
||||||
|
index_chunk_limit: int
|
||||||
|
index_distance_measure_type: DistanceMeasureType = (
|
||||||
|
DistanceMeasureType.DOT_PRODUCT_DISTANCE
|
||||||
|
)
|
||||||
|
index_approximate_neighbors_count: int = 150
|
||||||
|
index_leaf_node_embedding_count: int = 1000
|
||||||
|
index_leaf_nodes_to_search_percent: int = 10
|
||||||
|
index_public_endpoint_enabled: bool = True
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(yaml_file=CONFIG_FILE_PATH)
|
||||||
|
|
||||||
|
def model_post_init(self, _):
|
||||||
|
from google.cloud import aiplatform
|
||||||
|
|
||||||
|
aiplatform.init(project=self.project_id, location=self.location)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def index_deployment(self) -> str:
|
||||||
|
return self.index_name.replace("-", "_") + "_deployed"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def index_data(self) -> str:
|
||||||
|
return f"{self.index_destination}/{self.index_name}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def index_contents_dir(self) -> str:
|
||||||
|
return f"{self.index_data}/contents"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def index_vectors_dir(self) -> str:
|
||||||
|
return f"{self.index_data}/vectors"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def index_vectors_jsonl_path(self) -> str:
|
||||||
|
return f"{self.index_vectors_dir}/vectors.json"
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def gcs_client(self):
|
||||||
|
from google.cloud import storage
|
||||||
|
|
||||||
|
return storage.Client()
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def converter(self):
|
||||||
|
from markitdown import MarkItDown
|
||||||
|
|
||||||
|
return MarkItDown(enable_plugins=False)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def embedder(self):
|
||||||
|
from pydantic_ai import Embedder
|
||||||
|
|
||||||
|
return Embedder(f"google-vertex:{self.agent_embedding_model}")
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def chunker(self):
|
||||||
|
from .chunker.contextual_chunker import ContextualChunker
|
||||||
|
|
||||||
|
return ContextualChunker(max_chunk_size=self.index_chunk_limit)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def settings_customise_sources(
|
||||||
|
cls,
|
||||||
|
settings_cls: type[BaseSettings],
|
||||||
|
init_settings: PydanticBaseSettingsSource,
|
||||||
|
env_settings: PydanticBaseSettingsSource,
|
||||||
|
dotenv_settings: PydanticBaseSettingsSource,
|
||||||
|
file_secret_settings: PydanticBaseSettingsSource,
|
||||||
|
) -> tuple[PydanticBaseSettingsSource, ...]:
|
||||||
|
return (
|
||||||
|
env_settings,
|
||||||
|
YamlConfigSettingsSource(settings_cls),
|
||||||
|
)
|
||||||
209
src/knowledge_pipeline/pipeline.py
Normal file
209
src/knowledge_pipeline/pipeline.py
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import tempfile
|
||||||
|
import unicodedata
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from google.cloud import aiplatform
|
||||||
|
from google.cloud.aiplatform.matching_engine.matching_engine_index_config import (
|
||||||
|
DistanceMeasureType,
|
||||||
|
)
|
||||||
|
from google.cloud.storage import Client as StorageClient
|
||||||
|
from markitdown import MarkItDown
|
||||||
|
from pydantic_ai import Embedder
|
||||||
|
|
||||||
|
from .chunker.base_chunker import BaseChunker, Document
|
||||||
|
from .config import Settings
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_gcs_uri(uri: str) -> tuple[str, str]:
|
||||||
|
"""Parse a 'gs://bucket/path' URI into (bucket_name, object_path)."""
|
||||||
|
bucket, _, path = uri.removeprefix("gs://").partition("/")
|
||||||
|
return bucket, path
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_string(s: str) -> str:
|
||||||
|
"""Normalizes a string to be a valid filename."""
|
||||||
|
nfkd_form = unicodedata.normalize("NFKD", s)
|
||||||
|
only_ascii = "".join([c for c in nfkd_form if not unicodedata.combining(c)])
|
||||||
|
only_ascii = only_ascii.lower()
|
||||||
|
only_ascii = re.sub(r"\s+", "_", only_ascii)
|
||||||
|
only_ascii = re.sub(r"[^a-z0-9_.-]", "", only_ascii)
|
||||||
|
return only_ascii
|
||||||
|
|
||||||
|
|
||||||
|
def gather_pdfs(index_origin: str, gcs_client: StorageClient) -> list[str]:
|
||||||
|
"""Lists all PDF file URIs in a GCS directory."""
|
||||||
|
bucket, prefix = _parse_gcs_uri(index_origin)
|
||||||
|
blobs = gcs_client.bucket(bucket).list_blobs(prefix=prefix)
|
||||||
|
pdf_files = [
|
||||||
|
f"gs://{bucket}/{blob.name}" for blob in blobs if blob.name.endswith(".pdf")
|
||||||
|
]
|
||||||
|
log.info("Found %d PDF files in %s", len(pdf_files), index_origin)
|
||||||
|
return pdf_files
|
||||||
|
|
||||||
|
|
||||||
|
def split_into_chunks(text: str, file_id: str, chunker: BaseChunker) -> list[Document]:
|
||||||
|
"""Splits text into chunks, or returns a single chunk if small enough."""
|
||||||
|
if len(text) <= chunker.max_chunk_size:
|
||||||
|
return [{"page_content": text, "metadata": {"id": file_id}}]
|
||||||
|
|
||||||
|
chunks = chunker.process_text(text)
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
chunk["metadata"]["id"] = f"{file_id}_{i}"
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def upload_to_gcs(
|
||||||
|
chunks: list[Document],
|
||||||
|
vectors: list[dict],
|
||||||
|
index_contents_dir: str,
|
||||||
|
index_vectors_jsonl_path: str,
|
||||||
|
gcs_client: StorageClient,
|
||||||
|
) -> None:
|
||||||
|
"""Uploads chunk contents and vectors to GCS."""
|
||||||
|
bucket, prefix = _parse_gcs_uri(index_contents_dir)
|
||||||
|
gcs_bucket = gcs_client.bucket(bucket)
|
||||||
|
for chunk in chunks:
|
||||||
|
chunk_id = chunk["metadata"]["id"]
|
||||||
|
gcs_bucket.blob(f"{prefix}/{chunk_id}.md").upload_from_string(
|
||||||
|
chunk["page_content"], content_type="text/markdown; charset=utf-8"
|
||||||
|
)
|
||||||
|
|
||||||
|
vectors_jsonl = "\n".join(json.dumps(v) for v in vectors) + "\n"
|
||||||
|
bucket, obj_path = _parse_gcs_uri(index_vectors_jsonl_path)
|
||||||
|
gcs_client.bucket(bucket).blob(obj_path).upload_from_string(
|
||||||
|
vectors_jsonl, content_type="application/x-ndjson; charset=utf-8"
|
||||||
|
)
|
||||||
|
log.info("Uploaded %d chunks and %d vectors to GCS", len(chunks), len(vectors))
|
||||||
|
|
||||||
|
|
||||||
|
def build_vectors(
|
||||||
|
chunks: list[Document],
|
||||||
|
embeddings: Sequence[Sequence[float]],
|
||||||
|
source_folder: str,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Builds vector records from chunks and their embeddings."""
|
||||||
|
source = Path(source_folder).parts[0] if source_folder else ""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": chunk["metadata"]["id"],
|
||||||
|
"embedding": list(embedding),
|
||||||
|
"restricts": [{"namespace": "source", "allow": [source]}],
|
||||||
|
}
|
||||||
|
for chunk, embedding in zip(chunks, embeddings)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def create_vector_index(
|
||||||
|
index_name: str,
|
||||||
|
index_vectors_dir: str,
|
||||||
|
index_dimensions: int,
|
||||||
|
index_distance_measure_type: DistanceMeasureType,
|
||||||
|
index_deployment: str,
|
||||||
|
index_machine_type: str,
|
||||||
|
approximate_neighbors_count: int,
|
||||||
|
leaf_node_embedding_count: int,
|
||||||
|
leaf_nodes_to_search_percent: int,
|
||||||
|
public_endpoint_enabled: bool,
|
||||||
|
):
|
||||||
|
"""Creates and deploys a Vertex AI Vector Search Index."""
|
||||||
|
log.info("Creating index '%s'...", index_name)
|
||||||
|
index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
|
||||||
|
display_name=index_name,
|
||||||
|
contents_delta_uri=index_vectors_dir,
|
||||||
|
dimensions=index_dimensions,
|
||||||
|
approximate_neighbors_count=approximate_neighbors_count,
|
||||||
|
distance_measure_type=index_distance_measure_type,
|
||||||
|
leaf_node_embedding_count=leaf_node_embedding_count,
|
||||||
|
leaf_nodes_to_search_percent=leaf_nodes_to_search_percent,
|
||||||
|
)
|
||||||
|
log.info("Index '%s' created successfully.", index_name)
|
||||||
|
|
||||||
|
log.info("Deploying index to a new endpoint...")
|
||||||
|
endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
|
||||||
|
display_name=f"{index_name}-endpoint",
|
||||||
|
public_endpoint_enabled=public_endpoint_enabled,
|
||||||
|
)
|
||||||
|
endpoint.deploy_index(
|
||||||
|
index=index,
|
||||||
|
deployed_index_id=index_deployment,
|
||||||
|
machine_type=index_machine_type,
|
||||||
|
)
|
||||||
|
log.info("Index deployed: %s", endpoint.display_name)
|
||||||
|
|
||||||
|
|
||||||
|
def process_file(
|
||||||
|
file_uri: str,
|
||||||
|
temp_dir: Path,
|
||||||
|
gcs_client: StorageClient,
|
||||||
|
converter: MarkItDown,
|
||||||
|
embedder: Embedder,
|
||||||
|
chunker: BaseChunker,
|
||||||
|
) -> tuple[list[Document], list[dict]]:
|
||||||
|
"""Downloads a PDF from GCS, converts to markdown, chunks, and embeds."""
|
||||||
|
bucket, obj_path = _parse_gcs_uri(file_uri)
|
||||||
|
local_path = temp_dir / Path(file_uri).name
|
||||||
|
gcs_client.bucket(bucket).blob(obj_path).download_to_filename(local_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
markdown = converter.convert(local_path).text_content
|
||||||
|
file_id = normalize_string(Path(file_uri).stem)
|
||||||
|
source_folder = Path(obj_path).parent.as_posix()
|
||||||
|
|
||||||
|
chunks = split_into_chunks(markdown, file_id, chunker)
|
||||||
|
texts = [c["page_content"] for c in chunks]
|
||||||
|
embeddings = embedder.embed_documents_sync(texts).embeddings
|
||||||
|
|
||||||
|
vectors = build_vectors(chunks, embeddings, source_folder)
|
||||||
|
return chunks, vectors
|
||||||
|
finally:
|
||||||
|
if local_path.exists():
|
||||||
|
local_path.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def run_pipeline(settings: Settings):
|
||||||
|
"""Runs the full ingestion pipeline: gather → process → aggregate → index."""
|
||||||
|
files = gather_pdfs(settings.index_origin, settings.gcs_client)
|
||||||
|
|
||||||
|
all_chunks: list[Document] = []
|
||||||
|
all_vectors: list[dict] = []
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
for file_uri in files:
|
||||||
|
log.info("Processing file: %s", file_uri)
|
||||||
|
chunks, vectors = process_file(
|
||||||
|
file_uri,
|
||||||
|
Path(temp_dir),
|
||||||
|
settings.gcs_client,
|
||||||
|
settings.converter,
|
||||||
|
settings.embedder,
|
||||||
|
settings.chunker,
|
||||||
|
)
|
||||||
|
all_chunks.extend(chunks)
|
||||||
|
all_vectors.extend(vectors)
|
||||||
|
|
||||||
|
upload_to_gcs(
|
||||||
|
all_chunks,
|
||||||
|
all_vectors,
|
||||||
|
settings.index_contents_dir,
|
||||||
|
settings.index_vectors_jsonl_path,
|
||||||
|
settings.gcs_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
create_vector_index(
|
||||||
|
settings.index_name,
|
||||||
|
settings.index_vectors_dir,
|
||||||
|
settings.index_dimensions,
|
||||||
|
settings.index_distance_measure_type,
|
||||||
|
settings.index_deployment,
|
||||||
|
settings.index_machine_type,
|
||||||
|
settings.index_approximate_neighbors_count,
|
||||||
|
settings.index_leaf_node_embedding_count,
|
||||||
|
settings.index_leaf_nodes_to_search_percent,
|
||||||
|
settings.index_public_endpoint_enabled,
|
||||||
|
)
|
||||||
@@ -1,121 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from pydantic_settings import (
|
|
||||||
BaseSettings,
|
|
||||||
PydanticBaseSettingsSource,
|
|
||||||
SettingsConfigDict,
|
|
||||||
YamlConfigSettingsSource,
|
|
||||||
)
|
|
||||||
|
|
||||||
CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")
|
|
||||||
|
|
||||||
|
|
||||||
class IndexConfig(BaseModel):
|
|
||||||
name: str
|
|
||||||
endpoint: str
|
|
||||||
dimensions: int
|
|
||||||
machine_type: str = "e2-standard-16"
|
|
||||||
origin: str
|
|
||||||
destination: str
|
|
||||||
chunk_limit: int
|
|
||||||
|
|
||||||
@property
|
|
||||||
def deployment(self) -> str:
|
|
||||||
return self.name.replace("-", "_") + "_deployed"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def data(self) -> str:
|
|
||||||
return self.destination + self.name
|
|
||||||
|
|
||||||
class AgentConfig(BaseModel):
|
|
||||||
name: str
|
|
||||||
instructions: str
|
|
||||||
language_model: str
|
|
||||||
embedding_model: str
|
|
||||||
thinking: int
|
|
||||||
|
|
||||||
|
|
||||||
class BigQueryConfig(BaseModel):
|
|
||||||
dataset_id: str
|
|
||||||
project_id: str | None = None
|
|
||||||
table_ids: dict[str, str]
|
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
|
||||||
project_id: str
|
|
||||||
location: str
|
|
||||||
service_account: str
|
|
||||||
|
|
||||||
# Flattened fields from nested models
|
|
||||||
agent_name: str
|
|
||||||
agent_instructions: str
|
|
||||||
agent_language_model: str
|
|
||||||
agent_embedding_model: str
|
|
||||||
agent_thinking: int
|
|
||||||
|
|
||||||
index_name: str
|
|
||||||
index_endpoint: str
|
|
||||||
index_dimensions: int
|
|
||||||
index_machine_type: str = "e2-standard-16"
|
|
||||||
index_origin: str
|
|
||||||
index_destination: str
|
|
||||||
index_chunk_limit: int
|
|
||||||
|
|
||||||
bigquery_dataset_id: str
|
|
||||||
bigquery_project_id: str | None = None
|
|
||||||
bigquery_table_ids: dict[str, str]
|
|
||||||
|
|
||||||
bucket: str
|
|
||||||
base_image: str
|
|
||||||
dialogflow_agent_id: str
|
|
||||||
processing_image: str
|
|
||||||
|
|
||||||
model_config = SettingsConfigDict(yaml_file=CONFIG_FILE_PATH)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def agent(self) -> AgentConfig:
|
|
||||||
return AgentConfig(
|
|
||||||
name=self.agent_name,
|
|
||||||
instructions=self.agent_instructions,
|
|
||||||
language_model=self.agent_language_model,
|
|
||||||
embedding_model=self.agent_embedding_model,
|
|
||||||
thinking=self.agent_thinking,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def index(self) -> IndexConfig:
|
|
||||||
return IndexConfig(
|
|
||||||
name=self.index_name,
|
|
||||||
endpoint=self.index_endpoint,
|
|
||||||
dimensions=self.index_dimensions,
|
|
||||||
machine_type=self.index_machine_type,
|
|
||||||
origin=self.index_origin,
|
|
||||||
destination=self.index_destination,
|
|
||||||
chunk_limit=self.index_chunk_limit,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bigquery(self) -> BigQueryConfig:
|
|
||||||
return BigQueryConfig(
|
|
||||||
dataset_id=self.bigquery_dataset_id,
|
|
||||||
project_id=self.bigquery_project_id,
|
|
||||||
table_ids=self.bigquery_table_ids,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def settings_customise_sources(
|
|
||||||
cls,
|
|
||||||
settings_cls: type[BaseSettings],
|
|
||||||
init_settings: PydanticBaseSettingsSource,
|
|
||||||
env_settings: PydanticBaseSettingsSource,
|
|
||||||
dotenv_settings: PydanticBaseSettingsSource,
|
|
||||||
file_secret_settings: PydanticBaseSettingsSource,
|
|
||||||
) -> tuple[PydanticBaseSettingsSource, ...]:
|
|
||||||
return (
|
|
||||||
env_settings,
|
|
||||||
YamlConfigSettingsSource(settings_cls),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
|
||||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# Tests package
|
||||||
89
tests/conftest.py
Normal file
89
tests/conftest.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
"""Shared pytest fixtures for knowledge_pipeline tests."""
|
||||||
|
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from knowledge_pipeline.chunker.base_chunker import BaseChunker, Document
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_gcs_client():
|
||||||
|
"""Mock Google Cloud Storage client."""
|
||||||
|
client = Mock()
|
||||||
|
bucket = Mock()
|
||||||
|
blob = Mock()
|
||||||
|
|
||||||
|
client.bucket.return_value = bucket
|
||||||
|
bucket.blob.return_value = blob
|
||||||
|
bucket.list_blobs.return_value = []
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_chunker():
|
||||||
|
"""Mock BaseChunker implementation."""
|
||||||
|
chunker = Mock(spec=BaseChunker)
|
||||||
|
chunker.max_chunk_size = 1000
|
||||||
|
chunker.process_text.return_value = [
|
||||||
|
{"page_content": "Test chunk content", "metadata": {"id": "test_chunk"}}
|
||||||
|
]
|
||||||
|
return chunker
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_embedder():
|
||||||
|
"""Mock pydantic_ai Embedder."""
|
||||||
|
embedder = Mock()
|
||||||
|
embeddings_result = Mock()
|
||||||
|
embeddings_result.embeddings = [[0.1, 0.2, 0.3]]
|
||||||
|
embedder.embed_documents_sync.return_value = embeddings_result
|
||||||
|
return embedder
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_converter():
|
||||||
|
"""Mock MarkItDown converter."""
|
||||||
|
converter = Mock()
|
||||||
|
result = Mock()
|
||||||
|
result.text_content = "# Markdown Content\n\nTest content here."
|
||||||
|
converter.convert.return_value = result
|
||||||
|
return converter
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_chunks() -> list[Document]:
|
||||||
|
"""Sample document chunks for testing."""
|
||||||
|
return [
|
||||||
|
{"page_content": "First chunk content", "metadata": {"id": "doc_1_0"}},
|
||||||
|
{"page_content": "Second chunk content", "metadata": {"id": "doc_1_1"}},
|
||||||
|
{"page_content": "Third chunk content", "metadata": {"id": "doc_1_2"}},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_embeddings():
|
||||||
|
"""Sample embeddings for testing."""
|
||||||
|
return [
|
||||||
|
[0.1, 0.2, 0.3, 0.4, 0.5],
|
||||||
|
[0.6, 0.7, 0.8, 0.9, 1.0],
|
||||||
|
[0.2, 0.3, 0.4, 0.5, 0.6],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_vectors():
|
||||||
|
"""Sample vector records for testing."""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": "doc_1_0",
|
||||||
|
"embedding": [0.1, 0.2, 0.3],
|
||||||
|
"restricts": [{"namespace": "source", "allow": ["documents"]}],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "doc_1_1",
|
||||||
|
"embedding": [0.4, 0.5, 0.6],
|
||||||
|
"restricts": [{"namespace": "source", "allow": ["documents"]}],
|
||||||
|
},
|
||||||
|
]
|
||||||
553
tests/test_pipeline.py
Normal file
553
tests/test_pipeline.py
Normal file
@@ -0,0 +1,553 @@
|
|||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from google.cloud.aiplatform.matching_engine.matching_engine_index_config import (
|
||||||
|
DistanceMeasureType,
|
||||||
|
)
|
||||||
|
|
||||||
|
from knowledge_pipeline.chunker.base_chunker import BaseChunker
|
||||||
|
from knowledge_pipeline.pipeline import (
|
||||||
|
_parse_gcs_uri,
|
||||||
|
build_vectors,
|
||||||
|
create_vector_index,
|
||||||
|
gather_pdfs,
|
||||||
|
normalize_string,
|
||||||
|
process_file,
|
||||||
|
run_pipeline,
|
||||||
|
split_into_chunks,
|
||||||
|
upload_to_gcs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseGcsUri:
|
||||||
|
"""Tests for _parse_gcs_uri function."""
|
||||||
|
|
||||||
|
def test_basic_gcs_uri(self):
|
||||||
|
bucket, path = _parse_gcs_uri("gs://my-bucket/path/to/file.pdf")
|
||||||
|
assert bucket == "my-bucket"
|
||||||
|
assert path == "path/to/file.pdf"
|
||||||
|
|
||||||
|
def test_gcs_uri_with_nested_path(self):
|
||||||
|
bucket, path = _parse_gcs_uri("gs://test-bucket/deep/nested/path/file.txt")
|
||||||
|
assert bucket == "test-bucket"
|
||||||
|
assert path == "deep/nested/path/file.txt"
|
||||||
|
|
||||||
|
def test_gcs_uri_bucket_only(self):
|
||||||
|
bucket, path = _parse_gcs_uri("gs://my-bucket/")
|
||||||
|
assert bucket == "my-bucket"
|
||||||
|
assert path == ""
|
||||||
|
|
||||||
|
def test_gcs_uri_no_trailing_slash(self):
|
||||||
|
bucket, path = _parse_gcs_uri("gs://bucket-name")
|
||||||
|
assert bucket == "bucket-name"
|
||||||
|
assert path == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestNormalizeString:
|
||||||
|
"""Tests for normalize_string function."""
|
||||||
|
|
||||||
|
def test_normalize_basic_string(self):
|
||||||
|
result = normalize_string("Hello World")
|
||||||
|
assert result == "hello_world"
|
||||||
|
|
||||||
|
def test_normalize_special_characters(self):
|
||||||
|
result = normalize_string("File#Name@2024!.pdf")
|
||||||
|
assert result == "filename2024.pdf"
|
||||||
|
|
||||||
|
def test_normalize_unicode(self):
|
||||||
|
result = normalize_string("Café Münchën")
|
||||||
|
assert result == "cafe_munchen"
|
||||||
|
|
||||||
|
def test_normalize_multiple_spaces(self):
|
||||||
|
result = normalize_string("Multiple Spaces Here")
|
||||||
|
assert result == "multiple_spaces_here"
|
||||||
|
|
||||||
|
def test_normalize_with_hyphens_and_periods(self):
|
||||||
|
result = normalize_string("valid-filename.2024")
|
||||||
|
assert result == "valid-filename.2024"
|
||||||
|
|
||||||
|
def test_normalize_empty_string(self):
|
||||||
|
result = normalize_string("")
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
def test_normalize_only_special_chars(self):
|
||||||
|
result = normalize_string("@#$%^&*()")
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestGatherFiles:
|
||||||
|
"""Tests for gather_files function."""
|
||||||
|
|
||||||
|
def test_gather_files_finds_pdfs(self):
|
||||||
|
mock_client = Mock()
|
||||||
|
mock_bucket = Mock()
|
||||||
|
mock_client.bucket.return_value = mock_bucket
|
||||||
|
|
||||||
|
# Create mock blobs
|
||||||
|
mock_blob1 = Mock()
|
||||||
|
mock_blob1.name = "docs/file1.pdf"
|
||||||
|
mock_blob2 = Mock()
|
||||||
|
mock_blob2.name = "docs/file2.pdf"
|
||||||
|
mock_blob3 = Mock()
|
||||||
|
mock_blob3.name = "docs/readme.txt"
|
||||||
|
|
||||||
|
mock_bucket.list_blobs.return_value = [mock_blob1, mock_blob2, mock_blob3]
|
||||||
|
|
||||||
|
files = gather_pdfs("gs://my-bucket/docs", mock_client)
|
||||||
|
|
||||||
|
assert len(files) == 2
|
||||||
|
assert "gs://my-bucket/docs/file1.pdf" in files
|
||||||
|
assert "gs://my-bucket/docs/file2.pdf" in files
|
||||||
|
assert "gs://my-bucket/docs/readme.txt" not in files
|
||||||
|
|
||||||
|
def test_gather_files_no_pdfs(self):
|
||||||
|
mock_client = Mock()
|
||||||
|
mock_bucket = Mock()
|
||||||
|
mock_client.bucket.return_value = mock_bucket
|
||||||
|
|
||||||
|
mock_blob = Mock()
|
||||||
|
mock_blob.name = "docs/readme.txt"
|
||||||
|
mock_bucket.list_blobs.return_value = [mock_blob]
|
||||||
|
|
||||||
|
files = gather_pdfs("gs://my-bucket/docs", mock_client)
|
||||||
|
|
||||||
|
assert len(files) == 0
|
||||||
|
|
||||||
|
def test_gather_files_empty_bucket(self):
|
||||||
|
mock_client = Mock()
|
||||||
|
mock_bucket = Mock()
|
||||||
|
mock_client.bucket.return_value = mock_bucket
|
||||||
|
mock_bucket.list_blobs.return_value = []
|
||||||
|
|
||||||
|
files = gather_pdfs("gs://my-bucket/docs", mock_client)
|
||||||
|
|
||||||
|
assert len(files) == 0
|
||||||
|
|
||||||
|
def test_gather_files_correct_prefix(self):
|
||||||
|
mock_client = Mock()
|
||||||
|
mock_bucket = Mock()
|
||||||
|
mock_client.bucket.return_value = mock_bucket
|
||||||
|
mock_bucket.list_blobs.return_value = []
|
||||||
|
|
||||||
|
gather_pdfs("gs://my-bucket/docs/subfolder", mock_client)
|
||||||
|
|
||||||
|
mock_client.bucket.assert_called_once_with("my-bucket")
|
||||||
|
mock_bucket.list_blobs.assert_called_once_with(prefix="docs/subfolder")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSplitIntoChunks:
|
||||||
|
"""Tests for split_into_chunks function."""
|
||||||
|
|
||||||
|
def test_split_small_text_single_chunk(self):
|
||||||
|
mock_chunker = Mock(spec=BaseChunker)
|
||||||
|
mock_chunker.max_chunk_size = 1000
|
||||||
|
|
||||||
|
text = "Small text"
|
||||||
|
file_id = "test_file"
|
||||||
|
|
||||||
|
chunks = split_into_chunks(text, file_id, mock_chunker)
|
||||||
|
|
||||||
|
assert len(chunks) == 1
|
||||||
|
assert chunks[0]["page_content"] == "Small text"
|
||||||
|
assert chunks[0]["metadata"]["id"] == "test_file"
|
||||||
|
mock_chunker.process_text.assert_not_called()
|
||||||
|
|
||||||
|
def test_split_large_text_multiple_chunks(self):
|
||||||
|
mock_chunker = Mock(spec=BaseChunker)
|
||||||
|
mock_chunker.max_chunk_size = 10
|
||||||
|
|
||||||
|
# Create text larger than max_chunk_size
|
||||||
|
text = "This is a very long text that needs to be split into chunks"
|
||||||
|
file_id = "test_file"
|
||||||
|
|
||||||
|
# Mock the chunker to return multiple chunks
|
||||||
|
mock_chunker.process_text.return_value = [
|
||||||
|
{"page_content": "This is a very", "metadata": {}},
|
||||||
|
{"page_content": "long text that", "metadata": {}},
|
||||||
|
{"page_content": "needs to be split", "metadata": {}},
|
||||||
|
]
|
||||||
|
|
||||||
|
chunks = split_into_chunks(text, file_id, mock_chunker)
|
||||||
|
|
||||||
|
assert len(chunks) == 3
|
||||||
|
assert chunks[0]["metadata"]["id"] == "test_file_0"
|
||||||
|
assert chunks[1]["metadata"]["id"] == "test_file_1"
|
||||||
|
assert chunks[2]["metadata"]["id"] == "test_file_2"
|
||||||
|
mock_chunker.process_text.assert_called_once_with(text)
|
||||||
|
|
||||||
|
def test_split_exactly_max_size(self):
|
||||||
|
mock_chunker = Mock(spec=BaseChunker)
|
||||||
|
mock_chunker.max_chunk_size = 10
|
||||||
|
|
||||||
|
text = "0123456789" # Exactly 10 characters
|
||||||
|
file_id = "test_file"
|
||||||
|
|
||||||
|
chunks = split_into_chunks(text, file_id, mock_chunker)
|
||||||
|
|
||||||
|
assert len(chunks) == 1
|
||||||
|
assert chunks[0]["page_content"] == text
|
||||||
|
mock_chunker.process_text.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestUploadToGcs:
|
||||||
|
"""Tests for upload_to_gcs function."""
|
||||||
|
|
||||||
|
def test_upload_single_chunk_and_vectors(self):
|
||||||
|
mock_client = Mock()
|
||||||
|
mock_bucket = Mock()
|
||||||
|
mock_blob = Mock()
|
||||||
|
|
||||||
|
mock_client.bucket.return_value = mock_bucket
|
||||||
|
mock_bucket.blob.return_value = mock_blob
|
||||||
|
|
||||||
|
chunks = [
|
||||||
|
{
|
||||||
|
"page_content": "Test content",
|
||||||
|
"metadata": {"id": "chunk_1"},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
vectors = [{"id": "chunk_1", "embedding": [0.1, 0.2]}]
|
||||||
|
|
||||||
|
upload_to_gcs(
|
||||||
|
chunks,
|
||||||
|
vectors,
|
||||||
|
"gs://my-bucket/contents",
|
||||||
|
"gs://my-bucket/vectors/vectors.jsonl",
|
||||||
|
mock_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
blob_calls = [call[0][0] for call in mock_bucket.blob.call_args_list]
|
||||||
|
assert "contents/chunk_1.md" in blob_calls
|
||||||
|
assert "vectors/vectors.jsonl" in blob_calls
|
||||||
|
|
||||||
|
def test_upload_multiple_chunks(self):
|
||||||
|
mock_client = Mock()
|
||||||
|
mock_bucket = Mock()
|
||||||
|
mock_blob = Mock()
|
||||||
|
|
||||||
|
mock_client.bucket.return_value = mock_bucket
|
||||||
|
mock_bucket.blob.return_value = mock_blob
|
||||||
|
|
||||||
|
chunks = [
|
||||||
|
{"page_content": "Content 1", "metadata": {"id": "chunk_1"}},
|
||||||
|
{"page_content": "Content 2", "metadata": {"id": "chunk_2"}},
|
||||||
|
{"page_content": "Content 3", "metadata": {"id": "chunk_3"}},
|
||||||
|
]
|
||||||
|
vectors = [{"id": "chunk_1", "embedding": [0.1]}]
|
||||||
|
|
||||||
|
upload_to_gcs(
|
||||||
|
chunks,
|
||||||
|
vectors,
|
||||||
|
"gs://my-bucket/contents",
|
||||||
|
"gs://my-bucket/vectors/vectors.jsonl",
|
||||||
|
mock_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3 chunk blobs + 1 vectors blob
|
||||||
|
assert mock_bucket.blob.call_count == 4
|
||||||
|
|
||||||
|
blob_calls = [call[0][0] for call in mock_bucket.blob.call_args_list]
|
||||||
|
assert blob_calls == [
|
||||||
|
"contents/chunk_1.md",
|
||||||
|
"contents/chunk_2.md",
|
||||||
|
"contents/chunk_3.md",
|
||||||
|
"vectors/vectors.jsonl",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildVectors:
|
||||||
|
"""Tests for build_vectors function."""
|
||||||
|
|
||||||
|
def test_build_vectors_basic(self):
|
||||||
|
chunks = [
|
||||||
|
{"metadata": {"id": "doc_1"}, "page_content": "content 1"},
|
||||||
|
{"metadata": {"id": "doc_2"}, "page_content": "content 2"},
|
||||||
|
]
|
||||||
|
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||||
|
source_folder = "documents/reports"
|
||||||
|
|
||||||
|
vectors = build_vectors(chunks, embeddings, source_folder)
|
||||||
|
|
||||||
|
assert len(vectors) == 2
|
||||||
|
assert vectors[0]["id"] == "doc_1"
|
||||||
|
assert vectors[0]["embedding"] == [0.1, 0.2, 0.3]
|
||||||
|
assert vectors[0]["restricts"] == [
|
||||||
|
{"namespace": "source", "allow": ["documents"]}
|
||||||
|
]
|
||||||
|
assert vectors[1]["id"] == "doc_2"
|
||||||
|
assert vectors[1]["embedding"] == [0.4, 0.5, 0.6]
|
||||||
|
|
||||||
|
def test_build_vectors_empty_source(self):
|
||||||
|
chunks = [{"metadata": {"id": "doc_1"}, "page_content": "content"}]
|
||||||
|
embeddings = [[0.1, 0.2]]
|
||||||
|
source_folder = ""
|
||||||
|
|
||||||
|
vectors = build_vectors(chunks, embeddings, source_folder)
|
||||||
|
|
||||||
|
assert len(vectors) == 1
|
||||||
|
assert vectors[0]["restricts"] == [{"namespace": "source", "allow": [""]}]
|
||||||
|
|
||||||
|
def test_build_vectors_nested_path(self):
|
||||||
|
chunks = [{"metadata": {"id": "doc_1"}, "page_content": "content"}]
|
||||||
|
embeddings = [[0.1]]
|
||||||
|
source_folder = "a/b/c/d"
|
||||||
|
|
||||||
|
vectors = build_vectors(chunks, embeddings, source_folder)
|
||||||
|
|
||||||
|
assert vectors[0]["restricts"] == [{"namespace": "source", "allow": ["a"]}]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateVectorIndex:
    """Unit tests for the create_vector_index function."""

    @patch("knowledge_pipeline.pipeline.aiplatform.MatchingEngineIndexEndpoint")
    @patch("knowledge_pipeline.pipeline.aiplatform.MatchingEngineIndex")
    def test_create_vector_index(self, index_cls, endpoint_cls):
        created_index = Mock()
        created_endpoint = Mock()
        index_cls.create_tree_ah_index.return_value = created_index
        endpoint_cls.create.return_value = created_endpoint

        create_vector_index(
            index_name="test-index",
            index_vectors_dir="gs://bucket/vectors",
            index_dimensions=768,
            index_distance_measure_type=DistanceMeasureType.DOT_PRODUCT_DISTANCE,
            index_deployment="test_index_deployed",
            index_machine_type="e2-standard-16",
        )

        # The tree-AH index is built from the vectors directory.
        index_cls.create_tree_ah_index.assert_called_once_with(
            display_name="test-index",
            contents_delta_uri="gs://bucket/vectors",
            dimensions=768,
            approximate_neighbors_count=150,
            distance_measure_type=DistanceMeasureType.DOT_PRODUCT_DISTANCE,
            leaf_node_embedding_count=1000,
            leaf_nodes_to_search_percent=10,
        )

        # A public endpoint named after the index is created.
        endpoint_cls.create.assert_called_once_with(
            display_name="test-index-endpoint",
            public_endpoint_enabled=True,
        )

        # The index is deployed asynchronously to that endpoint.
        created_endpoint.deploy_index.assert_called_once_with(
            index=created_index,
            deployed_index_id="test_index_deployed",
            machine_type="e2-standard-16",
            sync=False,
        )
|
class TestProcessFile:
    """Unit tests for the process_file function."""

    @staticmethod
    def _gcs_mocks():
        """Builds a mocked GCS client -> bucket -> blob chain."""
        client, bucket, blob = Mock(), Mock(), Mock()
        client.bucket.return_value = bucket
        bucket.blob.return_value = blob
        return client, bucket, blob

    def test_process_file_success(self):
        with tempfile.TemporaryDirectory() as temp_dir:
            workdir = Path(temp_dir)
            client, bucket, blob = self._gcs_mocks()

            # Converter yields markdown; embedder yields a single embedding.
            converter = Mock()
            converter.convert.return_value = Mock(
                text_content="Converted markdown content"
            )
            embedder = Mock()
            embedder.embed_documents_sync.return_value = Mock(
                embeddings=[[0.1, 0.2, 0.3]]
            )
            chunker = Mock(spec=BaseChunker)
            chunker.max_chunk_size = 1000

            chunks, vectors = process_file(
                "gs://my-bucket/docs/test-file.pdf",
                workdir,
                client,
                converter,
                embedder,
                chunker,
            )

            # The file was downloaded from the parsed bucket/object path.
            client.bucket.assert_called_with("my-bucket")
            bucket.blob.assert_called_with("docs/test-file.pdf")
            assert blob.download_to_filename.called

            # Conversion and embedding both ran.
            assert converter.convert.called
            embedder.embed_documents_sync.assert_called_once()

            # One chunk and one vector come back, carrying the converted text.
            assert len(chunks) == 1
            assert chunks[0]["page_content"] == "Converted markdown content"
            assert len(vectors) == 1
            assert vectors[0]["embedding"] == [0.1, 0.2, 0.3]

    def test_process_file_cleans_up_temp_file(self):
        with tempfile.TemporaryDirectory() as temp_dir:
            workdir = Path(temp_dir)
            client, _, _ = self._gcs_mocks()

            converter = Mock()
            converter.convert.side_effect = Exception("Conversion failed")

            # The conversion failure must propagate to the caller...
            with pytest.raises(Exception, match="Conversion failed"):
                process_file(
                    "gs://my-bucket/docs/test.pdf",
                    workdir,
                    client,
                    converter,
                    Mock(),
                    Mock(spec=BaseChunker),
                )

            # ...but the downloaded temp file must still be removed.
            assert not (workdir / "test.pdf").exists()
|
class TestRunPipeline:
    """Unit tests for the run_pipeline function."""

    @staticmethod
    def _make_settings():
        """Builds a mocked settings object with a wired GCS client chain."""
        settings = Mock()
        settings.index_origin = "gs://bucket/input"
        settings.index_contents_dir = "gs://bucket/contents"
        settings.index_vectors_jsonl_path = "gs://bucket/vectors/vectors.jsonl"
        settings.index_name = "test-index"
        settings.index_vectors_dir = "gs://bucket/vectors"
        settings.index_dimensions = 768
        settings.index_distance_measure_type = (
            DistanceMeasureType.DOT_PRODUCT_DISTANCE
        )
        settings.index_deployment = "test_index_deployed"
        settings.index_machine_type = "e2-standard-16"

        gcs_client = Mock()
        bucket = Mock()
        gcs_client.bucket.return_value = bucket
        bucket.blob.return_value = Mock()
        settings.gcs_client = gcs_client

        settings.converter = Mock()
        settings.embedder = Mock()
        settings.chunker = Mock()
        return settings

    @patch("knowledge_pipeline.pipeline.create_vector_index")
    @patch("knowledge_pipeline.pipeline.upload_to_gcs")
    @patch("knowledge_pipeline.pipeline.process_file")
    @patch("knowledge_pipeline.pipeline.gather_pdfs")
    def test_run_pipeline_integration(
        self,
        mock_gather,
        mock_process,
        mock_upload,
        mock_create_index,
    ):
        settings = self._make_settings()

        # One input file producing one chunk/vector pair.
        mock_gather.return_value = ["gs://bucket/input/file1.pdf"]
        chunks = [{"page_content": "content", "metadata": {"id": "chunk_1"}}]
        vectors = [
            {
                "id": "chunk_1",
                "embedding": [0.1, 0.2],
                "restricts": [{"namespace": "source", "allow": ["input"]}],
            }
        ]
        mock_process.return_value = (chunks, vectors)

        run_pipeline(settings)

        # Every pipeline stage ran with the expected arguments.
        mock_gather.assert_called_once_with("gs://bucket/input", settings.gcs_client)
        mock_process.assert_called_once()
        mock_upload.assert_called_once_with(
            chunks,
            vectors,
            "gs://bucket/contents",
            "gs://bucket/vectors/vectors.jsonl",
            settings.gcs_client,
        )
        mock_create_index.assert_called_once()

    @patch("knowledge_pipeline.pipeline.create_vector_index")
    @patch("knowledge_pipeline.pipeline.upload_to_gcs")
    @patch("knowledge_pipeline.pipeline.process_file")
    @patch("knowledge_pipeline.pipeline.gather_pdfs")
    def test_run_pipeline_multiple_files(
        self,
        mock_gather,
        mock_process,
        mock_upload,
        mock_create_index,
    ):
        settings = self._make_settings()

        # Two input files.
        mock_gather.return_value = [
            "gs://bucket/input/file1.pdf",
            "gs://bucket/input/file2.pdf",
        ]
        mock_process.return_value = (
            [{"page_content": "content", "metadata": {"id": "chunk_1"}}],
            [{"id": "chunk_1", "embedding": [0.1], "restricts": []}],
        )

        run_pipeline(settings)

        # process_file runs once per file; upload happens a single time with
        # all accumulated chunks and vectors.
        assert mock_process.call_count == 2
        mock_upload.assert_called_once()
40
utils/delete_endpoint.py
Normal file
40
utils/delete_endpoint.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""Delete a GCP Vector Search endpoint by ID.
|
||||||
|
|
||||||
|
Undeploys any deployed indexes before deleting the endpoint.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
uv run python utils/delete_endpoint.py <endpoint_id> [--project PROJECT] [--location LOCATION]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from google.cloud import aiplatform
|
||||||
|
|
||||||
|
|
||||||
|
def delete_endpoint(endpoint_id: str, project: str, location: str) -> None:
    """Undeploys every index from the endpoint, then deletes the endpoint.

    Args:
        endpoint_id: Resource ID of the Matching Engine index endpoint.
        project: GCP project the endpoint belongs to.
        location: GCP region of the endpoint.
    """
    aiplatform.init(project=project, location=location)
    endpoint = aiplatform.MatchingEngineIndexEndpoint(endpoint_id)

    print(f"Endpoint: {endpoint.display_name}")

    # Deployed indexes are undeployed one by one before deletion.
    for deployment in endpoint.deployed_indexes:
        print(f"Undeploying index: {deployment.id}")
        endpoint.undeploy_index(deployed_index_id=deployment.id)
        print(f"Undeployed: {deployment.id}")

    endpoint.delete()
    print(f"Endpoint {endpoint_id} deleted successfully.")
||||||
|
|
||||||
|
def main():
    """CLI entry point: parses arguments and deletes the endpoint."""
    arg_parser = argparse.ArgumentParser(
        description="Delete a GCP Vector Search endpoint."
    )
    arg_parser.add_argument("endpoint_id", help="The endpoint ID to delete.")
    arg_parser.add_argument("--project", default="bnt-orquestador-cognitivo-dev")
    arg_parser.add_argument("--location", default="us-central1")
    parsed = arg_parser.parse_args()

    delete_endpoint(
        endpoint_id=parsed.endpoint_id,
        project=parsed.project,
        location=parsed.location,
    )


if __name__ == "__main__":
    main()
132
utils/search_index.py
Normal file
132
utils/search_index.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
"""Search a deployed Vertex AI Vector Search index.
|
||||||
|
|
||||||
|
Embeds a query, finds nearest neighbors, and retrieves chunk contents from GCS.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
uv run python utils/search_index.py "your search query" <endpoint_id> <index_deployment_id> \
|
||||||
|
[--source SOURCE] [--top-k 5] [--project PROJECT] [--location LOCATION]
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
# Basic search
|
||||||
|
uv run python utils/search_index.py "¿Cómo funciona el proceso?" 123456 blue_ivy_deployed
|
||||||
|
|
||||||
|
# Filter by source folder
|
||||||
|
uv run python utils/search_index.py "requisitos" 123456 blue_ivy_deployed --source "manuales"
|
||||||
|
|
||||||
|
# Return more results
|
||||||
|
uv run python utils/search_index.py "políticas" 123456 blue_ivy_deployed --top-k 10
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from google.cloud import aiplatform, storage
|
||||||
|
from pydantic_ai import Embedder
|
||||||
|
|
||||||
|
|
||||||
|
def search_index(
    query: str,
    endpoint_id: str,
    deployed_index_id: str,
    project: str,
    location: str,
    embedding_model: str,
    contents_dir: str,
    top_k: int,
    source: str | None,
) -> None:
    """Embeds *query*, searches the index, and prints matching chunk contents.

    Args:
        query: Free-text search query.
        endpoint_id: Resource ID of the index endpoint to query.
        deployed_index_id: ID of the index deployment on that endpoint.
        project: GCP project to run against.
        location: GCP region of the endpoint.
        embedding_model: Vertex embedding model name used to embed the query.
        contents_dir: GCS URI of the directory holding chunk markdown files.
        top_k: Number of nearest neighbors to request.
        source: Optional source-folder restrict token; None disables filtering.
    """
    aiplatform.init(project=project, location=location)

    # Embed the query text with the configured Vertex embedding model.
    embedder = Embedder(f"google-vertex:{embedding_model}")
    query_vector = embedder.embed_documents_sync([query]).embeddings[0]

    endpoint = aiplatform.MatchingEngineIndexEndpoint(endpoint_id)

    # Optional "source" namespace filter, mirroring the restricts written
    # into the index at build time.
    filters = None
    if source:
        filters = [
            aiplatform.matching_engine.matching_engine_index_endpoint.Namespace(
                name="source",
                allow_tokens=[source],
            )
        ]

    matches = endpoint.find_neighbors(
        deployed_index_id=deployed_index_id,
        queries=[list(query_vector)],
        num_neighbors=top_k,
        filter=filters,
    )

    if not matches or not matches[0]:
        print("No results found.")
        return

    gcs = storage.Client()
    neighbors = matches[0]

    print(f"Found {len(neighbors)} results for: {query!r}\n")
    for rank, hit in enumerate(neighbors, 1):
        # Resolve each neighbor ID back to its markdown chunk in GCS.
        body = _fetch_chunk_content(gcs, contents_dir, hit.id)

        print(f"--- Result {rank} (id={hit.id}, distance={hit.distance:.4f}) ---")
        print(body)
        print()
||||||
|
|
||||||
|
def _fetch_chunk_content(
    gcs_client: storage.Client, contents_dir: str, chunk_id: str
) -> str:
    """Fetches a chunk's markdown content from GCS.

    Returns a placeholder message when the object does not exist.
    """
    object_uri = f"{contents_dir}/{chunk_id}.md"
    bucket_name, _, object_path = object_uri.removeprefix("gs://").partition("/")
    blob = gcs_client.bucket(bucket_name).blob(object_path)
    if blob.exists():
        return blob.download_as_text()
    return f"[content not found: {object_uri}]"
|
||||||
|
def main():
    """CLI entry point: parses arguments and runs the search."""
    arg_parser = argparse.ArgumentParser(
        description="Search a deployed Vertex AI Vector Search index."
    )
    arg_parser.add_argument("query", help="The search query text.")
    arg_parser.add_argument("endpoint_id", help="The deployed endpoint ID.")
    arg_parser.add_argument("deployed_index_id", help="The deployed index ID.")
    arg_parser.add_argument(
        "--source",
        default=None,
        help="Filter results by source folder (metadata namespace).",
    )
    arg_parser.add_argument(
        "--top-k", type=int, default=5, help="Number of results to return (default: 5)."
    )
    arg_parser.add_argument("--project", default="bnt-orquestador-cognitivo-dev")
    arg_parser.add_argument("--location", default="us-central1")
    arg_parser.add_argument(
        "--embedding-model", default="gemini-embedding-001", help="Embedding model name."
    )
    arg_parser.add_argument(
        "--contents-dir",
        default="gs://bnt_orquestador_cognitivo_gcs_configs_dev/blue-ivy/contents",
        help="GCS URI of the contents directory.",
    )
    parsed = arg_parser.parse_args()

    # argparse's dest names (top_k, embedding_model, contents_dir, ...) match
    # search_index's keyword parameters exactly.
    search_index(**vars(parsed))


if __name__ == "__main__":
    main()
Reference in New Issue
Block a user