"""
This script defines a Kubeflow Pipeline (KFP) for ingesting and processing documents.

The pipeline is designed to run on Vertex AI Pipelines and consists of the following steps:

1. **Gather Files**: Scans a GCS directory for PDF files to process.
2. **Process Files (in parallel)**: For each PDF file found, this step:
   a. Converts the PDF to Markdown text.
   b. Chunks the text if it's too long.
   c. Generates a vector embedding for each chunk using a Vertex AI embedding model.
   d. Saves the markdown content and the vector embedding to separate GCS output paths.
"""

import json
import logging
import os
import tempfile
from pathlib import Path

from rag_eval.config import settings


def build_gcs_path(base_path: str, suffix: str) -> str:
    """Return *base_path* with *suffix* appended, forming a GCS path."""
    return base_path + suffix


def gather_files(
    input_dir: str,
) -> list:
    """Gathers all PDF file paths from a GCS directory.

    Args:
        input_dir: GCS directory URI, e.g. ``gs://bucket/prefix``.

    Returns:
        A list of fully-qualified ``gs://...`` paths for every PDF found
        under the prefix.
    """
    # Import inside the function as KFP serializes this function.
    from google.cloud import storage

    logging.getLogger().setLevel(logging.INFO)

    gcs_client = storage.Client()
    bucket_name, prefix = input_dir.replace("gs://", "").split("/", 1)
    bucket = gcs_client.bucket(bucket_name)
    blob_list = bucket.list_blobs(prefix=prefix)

    # Case-insensitive extension check so files like "REPORT.PDF" are not
    # silently skipped (GCS object names are case-sensitive).
    pdf_files = [
        f"gs://{bucket_name}/{blob.name}"
        for blob in blob_list
        if blob.name.lower().endswith(".pdf")
    ]
    logging.info(f"Found {len(pdf_files)} PDF files in {input_dir}")
    return pdf_files


def process_file(
    file_path: str,
    model_name: str,
    contents_output_dir: str,
    vectors_output_file: Path,
    chunk_limit: int,
):
    """Processes a single PDF file: converts to markdown, chunks, and generates embeddings.

    Downloads ``file_path`` from GCS, converts it to markdown, splits it into
    chunks when it exceeds ``chunk_limit`` characters, uploads each markdown
    piece to ``contents_output_dir``, and appends one JSON line per embedding
    (``{"id": ..., "embedding": ...}``) to the local ``vectors_output_file``.

    Args:
        file_path: Full ``gs://...`` path of the PDF to process.
        model_name: Name of the Vertex AI embedding model to use.
        contents_output_dir: ``gs://...`` directory for the markdown output.
        vectors_output_file: Local JSONL file the embeddings are written to.
        chunk_limit: Character count above which the markdown is chunked.

    Raises:
        Exception: Any failure is logged with traceback and re-raised so the
            pipeline task is marked as failed.
    """
    # Imports are inside the function as KFP serializes this function
    from pathlib import Path

    from chunker.contextual_chunker import ContextualChunker
    from document_converter.markdown import MarkdownConverter
    from embedder.vertex_ai import VertexAIEmbedder
    from google.cloud import storage
    from llm.vertex_ai import VertexAILLM
    from utils.normalize_filenames import normalize_string

    logging.getLogger().setLevel(logging.INFO)

    # Initialize converters and embedders
    converter = MarkdownConverter()
    embedder = VertexAIEmbedder(model_name=model_name, project=settings.project_id, location=settings.location)
    llm = VertexAILLM(project=settings.project_id, location=settings.location)
    chunker = ContextualChunker(llm_client=llm, max_chunk_size=chunk_limit)
    gcs_client = storage.Client()

    file_id = normalize_string(Path(file_path).stem)
    local_path = Path(f"/tmp/{Path(file_path).name}")

    with open(vectors_output_file, "w", encoding="utf-8") as f:
        try:
            # Download file from GCS
            bucket_name, blob_name = file_path.replace("gs://", "").split("/", 1)
            gcs_client.bucket(bucket_name).blob(blob_name).download_to_filename(
                local_path
            )
            logging.info(f"Processing file: {file_path}")

            # Process the downloaded file
            markdown_content = converter.process_file(local_path)

            # Determine output bucket and paths for markdown
            contents_bucket_name, contents_prefix = contents_output_dir.replace(
                "gs://", ""
            ).split("/", 1)

            def _emit(doc_id, text):
                # Shared path for both branches below: upload the markdown to
                # GCS, embed it, and append one JSONL record locally.
                blob = gcs_client.bucket(contents_bucket_name).blob(
                    f"{contents_prefix}/{doc_id}.md"
                )
                blob.upload_from_string(
                    text, content_type="text/markdown; charset=utf-8"
                )
                embedding = embedder.generate_embedding(text)
                f.write(json.dumps({"id": doc_id, "embedding": embedding}) + "\n")

            if len(markdown_content) > chunk_limit:
                # Long document: one markdown object + embedding per chunk,
                # ids suffixed with the chunk index.
                for i, chunk in enumerate(chunker.process_text(markdown_content)):
                    _emit(f"{file_id}_{i}", chunk["page_content"])
            else:
                # Short document: a single markdown object and embedding.
                _emit(file_id, markdown_content)

        except Exception as e:
            logging.error(f"Failed to process file {file_path}: {e}", exc_info=True)
            raise

        finally:
            # Clean up the downloaded file
            if os.path.exists(local_path):
                os.remove(local_path)


def aggregate_vectors(
    vector_artifacts: list,  # This will be a list of paths to the artifact files
    output_gcs_path: str,
):
    """
    Aggregates multiple JSONL artifact files into a single JSONL file in GCS.

    Args:
        vector_artifacts: Local paths of the per-file JSONL vector artifacts.
        output_gcs_path: Full ``gs://...`` path of the merged JSONL object.
    """
    # Import inside the function as KFP serializes this function.
    from google.cloud import storage

    logging.getLogger().setLevel(logging.INFO)

    # Create a temporary file to aggregate all vector data
    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, encoding="utf-8"
    ) as temp_agg_file:
        logging.info(f"Aggregating vectors into temporary file: {temp_agg_file.name}")
        for artifact_path in vector_artifacts:
            with open(artifact_path, "r", encoding="utf-8") as f:
                # Each line is a complete JSON object
                for line in f:
                    temp_agg_file.write(line)  # line already includes newline
        temp_file_path = temp_agg_file.name

    try:
        logging.info("Uploading aggregated file to GCS...")
        gcs_client = storage.Client()
        bucket_name, blob_name = output_gcs_path.replace("gs://", "").split("/", 1)
        bucket = gcs_client.bucket(bucket_name)
        blob = bucket.blob(blob_name)
        blob.upload_from_filename(temp_file_path, content_type="application/json; charset=utf-8")
        logging.info(f"Successfully uploaded aggregated vectors to {output_gcs_path}")
    finally:
        # Clean up the temporary file even if the upload fails — with
        # delete=False it would otherwise leak on error.
        os.remove(temp_file_path)


def create_vector_index(
    vectors_dir: str,
):
    """Creates and deploys a Vertex AI Vector Search Index."""
    # Imports are local because KFP serializes this function.
    from vector_search.vertex_ai import GoogleCloudVectorSearch

    from rag_eval.config import settings as config

    logging.getLogger().setLevel(logging.INFO)

    try:
        idx_cfg = config.index

        logging.info(
            f"Initializing Vertex AI client for project '{config.project_id}' in '{config.location}'..."
        )
        search_client = GoogleCloudVectorSearch(
            project_id=config.project_id,
            location=config.location,
            bucket=config.bucket,
            index_name=idx_cfg.name,
        )

        # Build the index from the vectors previously written to GCS.
        logging.info(f"Starting creation of index '{idx_cfg.name}'...")
        search_client.create_index(
            name=idx_cfg.name,
            content_path=vectors_dir,
            dimensions=idx_cfg.dimensions,
        )
        logging.info(f"Index '{idx_cfg.name}' created successfully.")

        # Deploy the freshly created index so it can serve queries.
        logging.info("Deploying index to a new endpoint...")
        search_client.deploy_index(
            index_name=idx_cfg.name, machine_type=idx_cfg.machine_type
        )
        logging.info("Index deployed successfully!")

        endpoint = search_client.index_endpoint
        logging.info(f"Endpoint name: {endpoint.display_name}")
        logging.info(f"Endpoint resource name: {endpoint.resource_name}")
    except Exception as e:
        logging.error(f"An error occurred during index creation or deployment: {e}", exc_info=True)
        raise