Merge pull request 'Code shred' (#4) from v2 into master
Reviewed-on: #4
This commit was merged in pull request #4.
This commit is contained in:
31
README.md
31
README.md
@@ -260,30 +260,18 @@ embeddings = embedder.generate_embeddings_batch(texts, batch_size=10)
|
||||
### **Opción 4: Almacenar en GCS**

```python
from file_storage.google_cloud import GoogleCloudFileStorage
import gcsfs

storage = GoogleCloudFileStorage(bucket="mi-bucket")
fs = gcsfs.GCSFileSystem()

# Subir archivo
storage.upload_file(
    file_path="local_file.md",
    destination_blob_name="chunks/documento_0.md",
    content_type="text/markdown"
)
fs.put("local_file.md", "mi-bucket/chunks/documento_0.md")

# Listar archivos
files = storage.list_files(path="chunks/")
files = fs.ls("mi-bucket/chunks/")

# Descargar archivo
file_stream = storage.get_file_stream("chunks/documento_0.md")
content = file_stream.read().decode("utf-8")
```

**CLI:**

```bash
file-storage upload local_file.md chunks/documento_0.md
file-storage list chunks/
file-storage download chunks/documento_0.md
```

```python
content = fs.cat_file("mi-bucket/chunks/documento_0.md").decode("utf-8")
```

---
|
||||
@@ -340,10 +328,10 @@ vector-search delete mi-indice
|
||||
## 🔄 Flujo Completo de Ejemplo
|
||||
|
||||
```python
|
||||
import gcsfs
|
||||
from pathlib import Path
|
||||
from chunker.contextual_chunker import ContextualChunker
|
||||
from embedder.vertex_ai import VertexAIEmbedder
|
||||
from file_storage.google_cloud import GoogleCloudFileStorage
|
||||
from llm.vertex_ai import VertexAILLM
|
||||
|
||||
# 1. Setup
|
||||
@@ -354,7 +342,7 @@ embedder = VertexAIEmbedder(
|
||||
project="mi-proyecto",
|
||||
location="us-central1"
|
||||
)
|
||||
storage = GoogleCloudFileStorage(bucket="mi-bucket")
|
||||
fs = gcsfs.GCSFileSystem()
|
||||
|
||||
# 2. Chunking
|
||||
documents = chunker.process_path(Path("documento.pdf"))
|
||||
@@ -368,10 +356,7 @@ for i, doc in enumerate(documents):
|
||||
embedding = embedder.generate_embedding(doc["page_content"])
|
||||
|
||||
# Guardar contenido en GCS
|
||||
storage.upload_file(
|
||||
file_path=f"temp_{chunk_id}.md",
|
||||
destination_blob_name=f"contents/{chunk_id}.md"
|
||||
)
|
||||
fs.put(f"temp_{chunk_id}.md", f"mi-bucket/contents/{chunk_id}.md")
|
||||
|
||||
# Guardar vector (escribir a JSONL localmente, luego subir)
|
||||
print(f"Chunk {chunk_id}: {len(embedding)} dimensiones")
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
[project]
|
||||
name = "index-gen"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Anibal Angulo", email = "a8065384@banorte.com" }
|
||||
]
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"chunker",
|
||||
"document-converter",
|
||||
"embedder",
|
||||
"file-storage",
|
||||
"llm",
|
||||
"utils",
|
||||
"vector-search",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
index-gen = "index_gen.cli:app"
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.8.12,<0.9.0"]
|
||||
build-backend = "uv_build"
|
||||
|
||||
[tool.uv.sources]
|
||||
file-storage = { workspace = true }
|
||||
vector-search = { workspace = true }
|
||||
utils = { workspace = true }
|
||||
embedder = { workspace = true }
|
||||
chunker = { workspace = true }
|
||||
document-converter = { workspace = true }
|
||||
llm = { workspace = true }
|
||||
@@ -1,2 +0,0 @@
|
||||
def main() -> None:
    """Entry-point placeholder for the index-gen package."""
    print("Hello from index-gen!")
|
||||
@@ -1,68 +0,0 @@
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
|
||||
from index_gen.main import (
|
||||
aggregate_vectors,
|
||||
build_gcs_path,
|
||||
create_vector_index,
|
||||
gather_files,
|
||||
process_file,
|
||||
)
|
||||
from rag_eval.config import settings
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def run_ingestion():
    """Main function for the CLI script.

    Gathers the source files, processes each one into markdown content
    plus a local vector artifact, aggregates all artifacts into a single
    JSONL file in GCS, and finally creates the vector index.
    """
    logging.basicConfig(level=logging.INFO)

    agent_config = settings.agent
    index_config = settings.index
    if not agent_config or not index_config:
        raise ValueError("Agent or index configuration not found in config.yaml")

    # Collect input files and derive every GCS output location up front.
    files = gather_files(index_config.origin)
    contents_output_dir = build_gcs_path(index_config.data, "/contents")
    vectors_output_dir = build_gcs_path(index_config.data, "/vectors")
    aggregated_vectors_gcs_path = build_gcs_path(
        index_config.data, "/vectors/vectors.json"
    )

    with tempfile.TemporaryDirectory() as temp_dir:
        workdir = Path(temp_dir)

        # One local JSONL artifact per input file.
        artifact_paths = [
            workdir / f"vectors_{idx}.jsonl" for idx in range(len(files))
        ]

        for source_file, artifact in zip(files, artifact_paths):
            process_file(
                source_file,
                agent_config.embedding_model,
                contents_output_dir,
                artifact,  # local artifact path the step writes into
                index_config.chunk_limit,
            )

        # Merge the per-file artifacts into one JSONL object in GCS.
        aggregate_vectors(
            vector_artifacts=artifact_paths,
            output_gcs_path=aggregated_vectors_gcs_path,
        )

        # Build and deploy the index from the aggregated vectors.
        create_vector_index(vectors_output_dir)


if __name__ == "__main__":
    app()
|
||||
@@ -1,238 +0,0 @@
|
||||
"""
|
||||
This script defines a Kubeflow Pipeline (KFP) for ingesting and processing documents.
|
||||
|
||||
The pipeline is designed to run on Vertex AI Pipelines and consists of the following steps:
|
||||
1. **Gather Files**: Scans a GCS directory for PDF files to process.
|
||||
2. **Process Files (in parallel)**: For each PDF file found, this step:
|
||||
a. Converts the PDF to Markdown text.
|
||||
b. Chunks the text if it's too long.
|
||||
c. Generates a vector embedding for each chunk using a Vertex AI embedding model.
|
||||
d. Saves the markdown content and the vector embedding to separate GCS output paths.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from rag_eval.config import settings
|
||||
|
||||
|
||||
def build_gcs_path(base_path: str, suffix: str) -> str:
    """Return *base_path* with *suffix* appended (plain concatenation)."""
    return base_path + suffix
|
||||
|
||||
|
||||
def gather_files(
    input_dir: str,
) -> list:
    """Gathers all PDF file paths from a GCS directory.

    Args:
        input_dir: A ``gs://bucket/prefix`` style directory to scan.

    Returns:
        Fully-qualified ``gs://`` URIs of every ``.pdf`` blob under the prefix.
    """
    # Imported lazily: this module's functions are run as pipeline steps.
    from google.cloud import storage

    logging.getLogger().setLevel(logging.INFO)

    client = storage.Client()
    bucket_name, prefix = input_dir.replace("gs://", "").split("/", 1)
    blobs = client.bucket(bucket_name).list_blobs(prefix=prefix)

    pdf_files = [
        f"gs://{bucket_name}/{blob.name}"
        for blob in blobs
        if blob.name.endswith(".pdf")
    ]
    logging.info(f"Found {len(pdf_files)} PDF files in {input_dir}")
    return pdf_files
|
||||
|
||||
|
||||
def process_file(
    file_path: str,
    model_name: str,
    contents_output_dir: str,
    vectors_output_file: Path,
    chunk_limit: int,
):
    """
    Processes a single PDF file: converts to markdown, chunks, and generates embeddings.

    The markdown content is uploaded to GCS under ``contents_output_dir`` and the
    vector embeddings are written to the local JSONL file ``vectors_output_file``.

    Args:
        file_path: ``gs://`` URI of the source PDF.
        model_name: Vertex AI embedding model name.
        contents_output_dir: ``gs://`` directory for the markdown outputs.
        vectors_output_file: Local path of the JSONL artifact to write.
        chunk_limit: Maximum markdown length before chunking kicks in.

    Raises:
        Exception: Re-raises any failure after logging it.
    """
    # Imports are inside the function as KFP serializes this function
    from pathlib import Path

    from chunker.contextual_chunker import ContextualChunker
    from document_converter.markdown import MarkdownConverter
    from embedder.vertex_ai import VertexAIEmbedder
    from google.cloud import storage
    from llm.vertex_ai import VertexAILLM
    from utils.normalize_filenames import normalize_string

    logging.getLogger().setLevel(logging.INFO)

    # Initialize converters and embedders
    converter = MarkdownConverter()
    embedder = VertexAIEmbedder(
        model_name=model_name, project=settings.project_id, location=settings.location
    )
    llm = VertexAILLM(project=settings.project_id, location=settings.location)
    chunker = ContextualChunker(llm_client=llm, max_chunk_size=chunk_limit)
    gcs_client = storage.Client()

    file_id = normalize_string(Path(file_path).stem)
    local_path = Path(f"/tmp/{Path(file_path).name}")

    def write_vector(f, vector_id, embedding, source_folder):
        # One JSONL record per chunk/file; shared by both branches below.
        record = {
            "id": vector_id,
            "embedding": embedding,
            "source_folder": source_folder,
        }
        f.write(json.dumps(record) + "\n")

    def upload_to_gcs(bucket_name, blob_name, data):
        bucket = gcs_client.bucket(bucket_name)
        blob = bucket.blob(blob_name)
        blob.upload_from_string(data, content_type="text/markdown; charset=utf-8")

    with open(vectors_output_file, "w", encoding="utf-8") as f:
        try:
            # Download file from GCS
            bucket_name, blob_name = file_path.replace("gs://", "").split("/", 1)
            bucket = gcs_client.bucket(bucket_name)
            blob = bucket.blob(blob_name)
            blob.download_to_filename(local_path)
            logging.info(f"Processing file: {file_path}")

            # Process the downloaded file
            markdown_content = converter.process_file(local_path)

            # Determine output bucket and paths for markdown
            contents_bucket_name, contents_prefix = contents_output_dir.replace(
                "gs://", ""
            ).split("/", 1)

            # Extract source folder from file path
            source_folder = Path(blob_name).parent.as_posix() if blob_name else ""

            if len(markdown_content) > chunk_limit:
                # Oversized document: chunk it and emit one record per chunk.
                chunks = chunker.process_text(markdown_content)
                for i, chunk in enumerate(chunks):
                    chunk_id = f"{file_id}_{i}"
                    embedding = embedder.generate_embedding(chunk["page_content"])
                    upload_to_gcs(
                        contents_bucket_name,
                        f"{contents_prefix}/{chunk_id}.md",
                        chunk["page_content"],
                    )
                    write_vector(f, chunk_id, embedding, source_folder)
            else:
                # Small document: embed and store it whole.
                embedding = embedder.generate_embedding(markdown_content)
                upload_to_gcs(
                    contents_bucket_name, f"{contents_prefix}/{file_id}.md", markdown_content
                )
                write_vector(f, file_id, embedding, source_folder)

        except Exception as e:
            logging.error(f"Failed to process file {file_path}: {e}", exc_info=True)
            raise

        finally:
            # Clean up the downloaded file
            if os.path.exists(local_path):
                os.remove(local_path)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def aggregate_vectors(
    vector_artifacts: list,  # list of local paths to the JSONL artifact files
    output_gcs_path: str,
):
    """
    Aggregates multiple JSONL artifact files into a single JSONL file in GCS.

    Args:
        vector_artifacts: Local JSONL files to concatenate, in order.
        output_gcs_path: ``gs://bucket/path`` destination for the merged file.
    """
    from google.cloud import storage

    logging.getLogger().setLevel(logging.INFO)

    # Concatenate all artifacts into one temporary file before uploading.
    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, encoding="utf-8"
    ) as temp_agg_file:
        logging.info(f"Aggregating vectors into temporary file: {temp_agg_file.name}")
        for artifact_path in vector_artifacts:
            with open(artifact_path, "r", encoding="utf-8") as f:
                # Each line is a complete JSON object and already ends with a newline.
                for line in f:
                    temp_agg_file.write(line)
        temp_file_path = temp_agg_file.name

    # Ensure the delete=False temp file is removed even if the upload fails
    # (the original leaked it on error) — uses the module-level `os` import.
    try:
        logging.info("Uploading aggregated file to GCS...")
        gcs_client = storage.Client()
        bucket_name, blob_name = output_gcs_path.replace("gs://", "").split("/", 1)
        bucket = gcs_client.bucket(bucket_name)
        blob = bucket.blob(blob_name)
        blob.upload_from_filename(
            temp_file_path, content_type="application/json; charset=utf-8"
        )
        logging.info(f"Successfully uploaded aggregated vectors to {output_gcs_path}")
    finally:
        os.remove(temp_file_path)
|
||||
|
||||
|
||||
|
||||
def create_vector_index(
    vectors_dir: str,
):
    """Creates and deploys a Vertex AI Vector Search Index.

    Args:
        vectors_dir: ``gs://`` directory containing the aggregated vectors.

    Raises:
        Exception: Re-raises any failure after logging it.
    """
    from vector_search.vertex_ai import GoogleCloudVectorSearch

    from rag_eval.config import settings as config

    logging.getLogger().setLevel(logging.INFO)

    try:
        index_config = config.index

        logging.info(
            f"Initializing Vertex AI client for project '{config.project_id}' in '{config.location}'..."
        )
        searcher = GoogleCloudVectorSearch(
            project_id=config.project_id,
            location=config.location,
            bucket=config.bucket,
            index_name=index_config.name,
        )

        # Create the index from the vectors directory.
        logging.info(f"Starting creation of index '{index_config.name}'...")
        searcher.create_index(
            name=index_config.name,
            content_path=vectors_dir,
            dimensions=index_config.dimensions,
        )
        logging.info(f"Index '{index_config.name}' created successfully.")

        # Deploy it to a fresh endpoint.
        logging.info("Deploying index to a new endpoint...")
        searcher.deploy_index(
            index_name=index_config.name, machine_type=index_config.machine_type
        )
        logging.info("Index deployed successfully!")
        logging.info(f"Endpoint name: {searcher.index_endpoint.display_name}")
        logging.info(
            f"Endpoint resource name: {searcher.index_endpoint.resource_name}"
        )
    except Exception as e:
        logging.error(f"An error occurred during index creation or deployment: {e}", exc_info=True)
        raise
|
||||
@@ -1,29 +1,15 @@
|
||||
# Configuración de Google Cloud Platform
|
||||
# Google Cloud Platform
|
||||
project_id: "tu-proyecto-gcp"
|
||||
location: "us-central1" # o us-east1, europe-west1, etc.
|
||||
bucket: "tu-bucket-nombre"
|
||||
location: "us-central1"
|
||||
|
||||
# Configuración del índice vectorial
|
||||
index:
|
||||
name: "mi-indice-rag"
|
||||
dimensions: 768 # Para text-embedding-005 usa 768
|
||||
machine_type: "e2-standard-2" # Tipo de máquina para el endpoint
|
||||
approximate_neighbors_count: 150
|
||||
distance_measure_type: "DOT_PRODUCT_DISTANCE" # O "COSINE_DISTANCE", "EUCLIDEAN_DISTANCE"
|
||||
# Embedding model
|
||||
agent_embedding_model: "text-embedding-005"
|
||||
|
||||
# Configuración de embeddings
|
||||
embedder:
|
||||
model_name: "text-embedding-005"
|
||||
task: "RETRIEVAL_DOCUMENT" # O "RETRIEVAL_QUERY" para queries
|
||||
|
||||
# Configuración de LLM para chunking
|
||||
llm:
|
||||
model: "gemini-2.0-flash" # O "gemini-1.5-pro", "gemini-1.5-flash"
|
||||
|
||||
# Configuración de chunking
|
||||
chunking:
|
||||
strategy: "contextual" # "recursive", "contextual", "llm"
|
||||
max_chunk_size: 800
|
||||
chunk_overlap: 200 # Solo para LLMChunker
|
||||
merge_related: true # Solo para LLMChunker
|
||||
extract_images: true # Solo para LLMChunker
|
||||
# Vector index
|
||||
index_name: "mi-indice-rag"
|
||||
index_dimensions: 768
|
||||
index_machine_type: "e2-standard-16"
|
||||
index_origin: "gs://tu-bucket/input/"
|
||||
index_destination: "gs://tu-bucket/output/"
|
||||
index_chunk_limit: 800
|
||||
index_distance_measure_type: "DOT_PRODUCT_DISTANCE"
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
[project]
|
||||
name = "chunker"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Anibal Angulo", email = "a8065384@banorte.com" }
|
||||
]
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"chonkie>=1.1.2",
|
||||
"pdf2image>=1.17.0",
|
||||
"pypdf>=6.0.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
llm-chunker = "chunker.llm_chunker:app"
|
||||
recursive-chunker = "chunker.recursive_chunker:app"
|
||||
contextual-chunker = "chunker.contextual_chunker:app"
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.8.3,<0.9.0"]
|
||||
build-backend = "uv_build"
|
||||
@@ -1 +0,0 @@
|
||||
3.10
|
||||
@@ -1,20 +0,0 @@
|
||||
[project]
|
||||
name = "document-converter"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Anibal Angulo", email = "a8065384@banorte.com" }
|
||||
]
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"markitdown[pdf]>=0.1.2",
|
||||
"pypdf>=6.1.2",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
convert-md = "document_converter.markdown:app"
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.8.3,<0.9.0"]
|
||||
build-backend = "uv_build"
|
||||
@@ -1,2 +0,0 @@
|
||||
def hello() -> str:
    """Return the document-converter greeting string."""
    return "Hello from document-converter!"
|
||||
@@ -1,35 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
|
||||
class BaseConverter(ABC):
    """Abstract base class for a remote file processor.

    Defines the interface for listing and processing files from a remote source.
    """

    @abstractmethod
    def process_file(self, file: str) -> str:
        """Process a single file from a remote source and return the result.

        Args:
            file: The path to the file to be processed from the remote source.

        Returns:
            A string containing the processing result for the file.
        """
        ...

    def process_files(self, files: List[str]) -> List[str]:
        """Process every file in *files* and collect the results.

        Args:
            files: A list of file paths to be processed from the remote source.

        Returns:
            A list of strings containing the processing results for each file.
        """
        return [self.process_file(path) for path in files]
|
||||
@@ -1,131 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Annotated, BinaryIO, Union
|
||||
|
||||
import typer
|
||||
from markitdown import MarkItDown
|
||||
from rich.console import Console
|
||||
from rich.progress import Progress
|
||||
|
||||
from .base import BaseConverter
|
||||
|
||||
|
||||
class MarkdownConverter(BaseConverter):
    """Converts PDF documents to Markdown format."""

    def __init__(self) -> None:
        """Initializes the MarkItDown converter (plugins disabled)."""
        self.markitdown = MarkItDown(enable_plugins=False)

    def process_file(self, file_stream: Union[str, Path, BinaryIO]) -> str:
        """Convert one document and return it as a markdown string.

        Args:
            file_stream: A file path (string or Path) or a binary file stream.

        Returns:
            The converted markdown content as a string.
        """
        conversion = self.markitdown.convert(file_stream)
        return conversion.text_content
|
||||
|
||||
|
||||
# --- CLI Application ---
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def main(
    input_path: Annotated[
        Path,
        typer.Argument(
            help="Path to the input PDF file or directory.",
            exists=True,
            file_okay=True,
            dir_okay=True,
            readable=True,
            resolve_path=True,
        ),
    ],
    output_path: Annotated[
        Path,
        typer.Argument(
            help="Path for the output Markdown file or directory.",
            file_okay=True,
            dir_okay=True,
            writable=True,
            resolve_path=True,
        ),
    ],
):
    """
    Converts a PDF file or a directory of PDF files into Markdown.
    """
    console = Console()
    converter = MarkdownConverter()

    if input_path.is_dir():
        # --- Directory Processing ---
        console.print(f"[bold green]Processing directory:[/bold green] {input_path}")
        output_dir = output_path

        # Refuse to write a directory tree over an existing regular file.
        if output_dir.exists() and not output_dir.is_dir():
            console.print(
                f"[bold red]Error:[/bold red] Input is a directory, but output path '{output_dir}' is an existing file."
            )
            raise typer.Exit(code=1)

        pdf_files = sorted(input_path.rglob("*.pdf"))
        if not pdf_files:
            console.print("[yellow]No PDF files found in the input directory.[/yellow]")
            return

        console.print(f"Found {len(pdf_files)} PDF files to convert.")
        output_dir.mkdir(parents=True, exist_ok=True)

        with Progress(console=console) as progress:
            task = progress.add_task("[cyan]Converting...", total=len(pdf_files))
            for pdf_file in pdf_files:
                # Mirror the input tree under the output directory.
                target = output_dir.joinpath(
                    pdf_file.relative_to(input_path)
                ).with_suffix(".md")
                target.parent.mkdir(parents=True, exist_ok=True)

                progress.update(task, description=f"Processing {pdf_file.name}")
                try:
                    target.write_text(
                        converter.process_file(pdf_file), encoding="utf-8"
                    )
                except Exception as e:
                    # Report and continue with the remaining files.
                    console.print(
                        f"\n[bold red]Failed to process {pdf_file.name}:[/bold red] {e}"
                    )
                progress.advance(task)

        console.print(
            f"[bold green]Conversion complete.[/bold green] Output directory: {output_dir}"
        )

    elif input_path.is_file():
        # --- Single File Processing ---
        console.print(f"[bold green]Processing file:[/bold green] {input_path.name}")
        final_output_path = output_path

        # If output path is a directory, create a file inside it
        if output_path.is_dir():
            final_output_path = output_path / input_path.with_suffix(".md").name

        final_output_path.parent.mkdir(parents=True, exist_ok=True)

        try:
            final_output_path.write_text(
                converter.process_file(input_path), encoding="utf-8"
            )
            console.print(
                f"[bold green]Successfully converted file to:[/bold green] {final_output_path}"
            )
        except Exception as e:
            console.print(f"[bold red]Error processing file:[/bold red] {e}")
            raise typer.Exit(code=1)


if __name__ == "__main__":
    app()
|
||||
@@ -1 +0,0 @@
|
||||
3.10
|
||||
@@ -1,16 +0,0 @@
|
||||
[project]
|
||||
name = "embedder"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Anibal Angulo", email = "a8065384@banorte.com" }
|
||||
]
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"google-cloud-aiplatform>=1.106.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.8.3,<0.9.0"]
|
||||
build-backend = "uv_build"
|
||||
@@ -1,79 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BaseEmbedder(ABC):
    """Base class for all embedding models."""

    @abstractmethod
    def generate_embedding(self, text: str) -> List[float]:
        """Generate an embedding for a single text.

        Args:
            text: The text to embed.

        Returns:
            The embedding vector.
        """
        ...

    @abstractmethod
    def generate_embeddings_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for a batch of texts.

        Args:
            texts: List of text strings.

        Returns:
            List of embedding vectors, one per input text.
        """
        ...

    @abstractmethod
    async def async_generate_embedding(self, text: str) -> List[float]:
        """Asynchronously generate an embedding for a single text.

        Args:
            text: The text to embed.

        Returns:
            The embedding vector.
        """
        ...

    def preprocess_text(
        self,
        text: str,
        *,
        into_lowercase: bool = False,
        normalize_whitespace: bool = True,
        remove_punctuation: bool = False,
    ) -> str:
        """Preprocess text before embedding (strip, then optional transforms)."""
        cleaned = text.strip()

        if into_lowercase:
            cleaned = cleaned.lower()

        if normalize_whitespace:
            # Collapse any run of whitespace into a single space.
            cleaned = " ".join(cleaned.split())

        if remove_punctuation:
            import string

            cleaned = cleaned.translate(str.maketrans("", "", string.punctuation))

        return cleaned

    def normalize_embedding(self, embedding: List[float]) -> List[float]:
        """Normalize embedding vector to unit length (zero vectors unchanged)."""
        norm = np.linalg.norm(embedding)
        if norm > 0:
            return (np.array(embedding) / norm).tolist()
        return embedding
|
||||
@@ -1,77 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
from .base import BaseEmbedder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VertexAIEmbedder(BaseEmbedder):
    """Embedder using Vertex AI text embedding models."""

    def __init__(
        self, model_name: str, project: str, location: str, task: str = "RETRIEVAL_DOCUMENT"
    ) -> None:
        """Create a genai client bound to the given Vertex AI project/location."""
        self.model_name = model_name
        self.task = task
        self.client = genai.Client(
            vertexai=True,
            project=project,
            location=location,
        )

    # NOTE(review): a tenacity retry decorator was left commented out on these
    # methods in the original; re-enable it if transient API errors resurface.
    def generate_embedding(self, text: str) -> List[float]:
        """Embed a single text and return its vector values."""
        cleaned = self.preprocess_text(text)
        response = self.client.models.embed_content(
            model=self.model_name,
            contents=cleaned,
            config=types.EmbedContentConfig(task_type=self.task),
        )
        return response.embeddings[0].values

    def generate_embeddings_batch(
        self, texts: List[str], batch_size: int = 10
    ) -> List[List[float]]:
        """Generate embeddings for a batch of texts."""
        if not texts:
            return []

        cleaned_texts = [self.preprocess_text(t) for t in texts]
        all_embeddings: List[List[float]] = []

        for start in range(0, len(cleaned_texts), batch_size):
            batch = cleaned_texts[start : start + batch_size]

            response = self.client.models.embed_content(
                model=self.model_name,
                contents=batch,
                config=types.EmbedContentConfig(task_type=self.task),
            )
            all_embeddings.extend(item.values for item in response.embeddings)

            # Rate limiting: small delay between batches, skipped after the last.
            if start + batch_size < len(cleaned_texts):
                time.sleep(0.1)

        return all_embeddings

    async def async_generate_embedding(self, text: str) -> List[float]:
        """Asynchronously embed a single text and return its vector values."""
        cleaned = self.preprocess_text(text)
        response = await self.client.aio.models.embed_content(
            model=self.model_name,
            contents=cleaned,
            config=types.EmbedContentConfig(task_type=self.task),
        )
        return response.embeddings[0].values
|
||||
@@ -1 +0,0 @@
|
||||
3.10
|
||||
@@ -1,22 +0,0 @@
|
||||
[project]
|
||||
name = "file-storage"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Anibal Angulo", email = "a8065384@banorte.com" }
|
||||
]
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"gcloud-aio-storage>=9.6.1",
|
||||
"google-cloud-storage>=2.19.0",
|
||||
"aiohttp>=3.10.11,<4",
|
||||
"typer>=0.12.3",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
file-storage = "file_storage.cli:app"
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.8.3,<0.9.0"]
|
||||
build-backend = "uv_build"
|
||||
@@ -1,2 +0,0 @@
|
||||
def hello() -> str:
|
||||
return "Hello from file-storage!"
|
||||
@@ -1,48 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import BinaryIO, List, Optional
|
||||
|
||||
|
||||
class BaseFileStorage(ABC):
    """Abstract base class for a remote file store.

    Defines the interface for uploading, listing, and streaming files
    from a remote source.
    """

    @abstractmethod
    def upload_file(
        self,
        file_path: str,
        destination_blob_name: str,
        content_type: Optional[str] = None,
    ) -> None:
        """Upload a local file to the remote source.

        Args:
            file_path: The local path to the file to upload.
            destination_blob_name: The name of the file in the remote source.
            content_type: The content type of the file.
        """
        ...

    @abstractmethod
    def list_files(self, path: Optional[str] = None) -> List[str]:
        """List files from a remote location.

        Args:
            path: The path to a specific file or directory in the remote bucket.
                If None, it recursively lists all files in the bucket.

        Returns:
            A list of file paths.
        """
        ...

    @abstractmethod
    def get_file_stream(self, file_name: str) -> BinaryIO:
        """Fetch a file from the remote source as a file-like object."""
        ...
|
||||
@@ -1,89 +0,0 @@
|
||||
import os
|
||||
from typing import Annotated
|
||||
|
||||
import rich
|
||||
import typer
|
||||
|
||||
from rag_eval.config import settings
|
||||
|
||||
from .google_cloud import GoogleCloudFileStorage
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def get_storage_client() -> GoogleCloudFileStorage:
    """Return a storage client bound to the configured bucket."""
    return GoogleCloudFileStorage(bucket=settings.bucket)
|
||||
|
||||
|
||||
@app.command("upload")
def upload(
    file_path: str,
    destination_blob_name: str,
    # Annotation fixed: the default is None, so the type must be optional
    # (project requires Python >= 3.12, so `str | None` is available).
    content_type: Annotated[str | None, typer.Option()] = None,
):
    """
    Uploads a file or directory to the remote source.

    For a directory, the local tree structure is preserved under
    destination_blob_name; for a single file, it is uploaded as-is.
    """
    storage_client = get_storage_client()
    if os.path.isdir(file_path):
        for root, _, files in os.walk(file_path):
            for file in files:
                local_file_path = os.path.join(root, file)
                # preserve the directory structure and use forward slashes for blob name
                dest_blob_name = os.path.join(
                    destination_blob_name, os.path.relpath(local_file_path, file_path)
                ).replace(os.sep, "/")
                storage_client.upload_file(
                    local_file_path, dest_blob_name, content_type
                )
                rich.print(
                    f"[green]File {local_file_path} uploaded to {dest_blob_name}.[/green]"
                )
        rich.print(
            f"[bold green]Directory {file_path} uploaded to {destination_blob_name}.[/bold green]"
        )
    else:
        storage_client.upload_file(file_path, destination_blob_name, content_type)
        rich.print(
            f"[green]File {file_path} uploaded to {destination_blob_name}.[/green]"
        )
|
||||
|
||||
|
||||
@app.command("list")
|
||||
def list_items(path: Annotated[str, typer.Option()] = None):
|
||||
"""
|
||||
Obtain a list of all files at the given location inside the remote bucket
|
||||
If path is none, recursively shows all files in the remote bucket.
|
||||
"""
|
||||
storage_client = get_storage_client()
|
||||
files = storage_client.list_files(path)
|
||||
for file in files:
|
||||
rich.print(f"[blue]{file}[/blue]")
|
||||
|
||||
|
||||
@app.command("download")
|
||||
def download(file_name: str, destination_path: str):
|
||||
"""
|
||||
Gets a file from the remote source and returns it as a file-like object.
|
||||
"""
|
||||
storage_client = get_storage_client()
|
||||
file_stream = storage_client.get_file_stream(file_name)
|
||||
with open(destination_path, "wb") as f:
|
||||
f.write(file_stream.read())
|
||||
rich.print(f"[green]File {file_name} downloaded to {destination_path}[/green]")
|
||||
|
||||
|
||||
@app.command("delete")
|
||||
def delete(path: str):
|
||||
"""
|
||||
Deletes all files at the given location inside the remote bucket.
|
||||
If path is a single file, it will delete only that file.
|
||||
If path is a directory, it will delete all files in that directory.
|
||||
"""
|
||||
storage_client = get_storage_client()
|
||||
storage_client.delete_files(path)
|
||||
rich.print(f"[bold red]Files at {path} deleted.[/bold red]")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this CLI module directly, outside the installed entry point.
    app()
|
||||
@@ -1,138 +0,0 @@
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
from typing import BinaryIO, List, Optional
|
||||
|
||||
import aiohttp
|
||||
from gcloud.aio.storage import Storage
|
||||
from google.cloud import storage
|
||||
|
||||
from .base import BaseFileStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GoogleCloudFileStorage(BaseFileStorage):
    """Google Cloud Storage implementation of BaseFileStorage.

    Downloads are memoized in an in-process byte cache keyed by blob name;
    uploads and deletes invalidate the affected entries. Async downloads go
    through gcloud-aio with retry/backoff on transient errors.
    """

    def __init__(self, bucket: str) -> None:
        # Name of the GCS bucket this instance operates on.
        self.bucket_name = bucket

        self.storage_client = storage.Client()
        self.bucket_client = self.storage_client.bucket(self.bucket_name)
        # Lazily-created aiohttp session / async storage client (see _get_aio_*).
        self._aio_session: aiohttp.ClientSession | None = None
        self._aio_storage: Storage | None = None
        # blob name -> raw bytes; unbounded and lives for the client's lifetime.
        self._cache: dict[str, bytes] = {}

    def upload_file(
        self,
        file_path: str,
        destination_blob_name: str,
        content_type: Optional[str] = None,
    ) -> None:
        """
        Uploads a file to the remote source.

        Args:
            file_path: The local path to the file to upload.
            destination_blob_name: The name of the file in the remote source.
            content_type: The content type of the file.
        """
        blob = self.bucket_client.blob(destination_blob_name)
        # if_generation_match=0 is a GCS precondition: the upload only succeeds
        # when the object does not already exist (no silent overwrites).
        blob.upload_from_filename(
            file_path,
            content_type=content_type,
            if_generation_match=0,
        )
        # Drop any stale cached copy of this blob.
        self._cache.pop(destination_blob_name, None)

    def list_files(self, path: Optional[str] = None) -> List[str]:
        """
        Obtain a list of all files at the given location inside the remote bucket.

        If path is None, recursively lists all files in the remote bucket.
        """
        blobs = self.storage_client.list_blobs(self.bucket_name, prefix=path)
        return [blob.name for blob in blobs]

    def get_file_stream(self, file_name: str) -> BinaryIO:
        """
        Gets a file from the remote source and returns it as a file-like object.
        """
        # Download once, then serve subsequent reads from the byte cache.
        if file_name not in self._cache:
            blob = self.bucket_client.blob(file_name)
            self._cache[file_name] = blob.download_as_bytes()
        file_stream = io.BytesIO(self._cache[file_name])
        # Attach the blob name so callers can identify the stream's origin.
        file_stream.name = file_name
        return file_stream

    def _get_aio_session(self) -> aiohttp.ClientSession:
        # Create (or re-create, if closed) the shared aiohttp session.
        if self._aio_session is None or self._aio_session.closed:
            connector = aiohttp.TCPConnector(limit=300, limit_per_host=50)
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout, connector=connector
            )
        return self._aio_session

    def _get_aio_storage(self) -> Storage:
        # Lazily build the gcloud-aio storage client on first use.
        if self._aio_storage is None:
            self._aio_storage = Storage(session=self._get_aio_session())
        return self._aio_storage

    async def async_get_file_stream(
        self, file_name: str, max_retries: int = 3
    ) -> BinaryIO:
        """
        Gets a file from the remote source asynchronously and returns it as a file-like object.
        Retries on transient errors (429, 5xx, timeouts) with exponential backoff.
        """
        if file_name in self._cache:
            file_stream = io.BytesIO(self._cache[file_name])
            file_stream.name = file_name
            return file_stream

        storage_client = self._get_aio_storage()
        last_exception: Exception | None = None

        for attempt in range(max_retries):
            try:
                self._cache[file_name] = await storage_client.download(
                    self.bucket_name, file_name
                )
                file_stream = io.BytesIO(self._cache[file_name])
                file_stream.name = file_name
                return file_stream
            except asyncio.TimeoutError as exc:
                last_exception = exc
                logger.warning(
                    "Timeout downloading gs://%s/%s (attempt %d/%d)",
                    self.bucket_name, file_name, attempt + 1, max_retries,
                )
            except aiohttp.ClientResponseError as exc:
                last_exception = exc
                # Only rate limiting (429) and server errors (5xx) are treated
                # as transient; any other HTTP error is re-raised immediately.
                if exc.status == 429 or exc.status >= 500:
                    logger.warning(
                        "HTTP %d downloading gs://%s/%s (attempt %d/%d)",
                        exc.status, self.bucket_name, file_name,
                        attempt + 1, max_retries,
                    )
                else:
                    raise

            if attempt < max_retries - 1:
                # Exponential backoff: 0.5s, 1s, 2s, ...
                delay = 0.5 * (2 ** attempt)
                await asyncio.sleep(delay)

        raise TimeoutError(
            f"Failed to download gs://{self.bucket_name}/{file_name} "
            f"after {max_retries} attempts"
        ) from last_exception

    def delete_files(self, path: str) -> None:
        """
        Deletes all files at the given location inside the remote bucket.
        If path is a single file, it will delete only that file.
        If path is a directory, it will delete all files in that directory.
        """
        blobs = self.storage_client.list_blobs(self.bucket_name, prefix=path)
        for blob in blobs:
            blob.delete()
            # Keep the cache consistent with the bucket contents.
            self._cache.pop(blob.name, None)
|
||||
@@ -1 +0,0 @@
|
||||
3.12
|
||||
@@ -1,18 +0,0 @@
|
||||
[project]
|
||||
name = "llm"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Anibal Angulo", email = "a8065384@banorte.com" }
|
||||
]
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"google-genai>=1.20.0",
|
||||
"pydantic>=2.11.7",
|
||||
"tenacity>=9.1.2",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.8.3,<0.9.0"]
|
||||
build-backend = "uv_build"
|
||||
@@ -1,2 +0,0 @@
|
||||
def hello() -> str:
    """Return the canonical greeting for the llm package."""
    greeting = "Hello from llm!"
    return greeting
|
||||
@@ -1,128 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Type, TypeVar
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
    """A tool/function call requested by the model."""

    # Name of the tool/function the model wants to invoke.
    name: str
    # Keyword arguments for the call, as produced by the model.
    arguments: dict
|
||||
|
||||
class Usage(BaseModel):
    """Token usage for a single model call.

    Counts default to 0; ``None`` values coming from the API are coerced
    to 0 by the validator below.
    """

    prompt_tokens: int | None = 0
    thought_tokens: int | None = 0
    response_tokens: int | None = 0

    @field_validator("prompt_tokens", "thought_tokens", "response_tokens", mode="before")
    @classmethod
    def _validate_tokens(cls, v: int | None) -> int:
        # The API may report missing counts as None; normalize them to 0.
        return v or 0

    def __add__(self, other: "Usage") -> "Usage":
        """Return the element-wise sum of two usage records."""
        return Usage(
            prompt_tokens=self.prompt_tokens + other.prompt_tokens,
            thought_tokens=self.thought_tokens + other.thought_tokens,
            response_tokens=self.response_tokens + other.response_tokens,
        )

    def get_cost(self, name: str) -> float:
        """Estimate the cost of this usage for the given model.

        Args:
            name: Model name ("gemini-2.5-pro" or "gemini-2.5-flash").

        Returns:
            Estimated cost (USD list price converted at a fixed rate;
            presumably MXN — confirm with the pricing owner).

        Raises:
            ValueError: If the model name is not recognized.
        """
        # Return type fixed: the arithmetic below produces a float, not an int.
        million = 1_000_000
        # Fixed USD -> local-currency conversion rate used for reporting.
        exchange_rate = 18.65
        if name == "gemini-2.5-pro":
            # Pro pricing is tiered: prompts over 200k tokens cost more per token.
            if self.prompt_tokens > 200_000:
                input_cost = self.prompt_tokens * (2.5 / million)
                output_cost = self.thought_tokens * (15 / million) + self.response_tokens * (15 / million)
            else:
                input_cost = self.prompt_tokens * (1.25 / million)
                output_cost = self.thought_tokens * (10 / million) + self.response_tokens * (10 / million)
            return (input_cost + output_cost) * exchange_rate
        if name == "gemini-2.5-flash":
            input_cost = self.prompt_tokens * (0.30 / million)
            output_cost = self.thought_tokens * (2.5 / million) + self.response_tokens * (2.5 / million)
            return (input_cost + output_cost) * exchange_rate
        # ValueError (a subclass of Exception) replaces the bare Exception so
        # callers can catch the specific failure; existing handlers still work.
        raise ValueError(f"Invalid model: {name!r}")
|
||||
|
||||
|
||||
class Generation(BaseModel):
    """A class to represent a single generation from a model.

    Attributes:
        text: The generated text (None when the model returned tool calls).
        tool_calls: Tool calls requested by the model, if any.
        usage: Token usage metadata for this generation.
        extra: Provider-specific extras (e.g. the original response content).
    """

    text: str | None = None
    tool_calls: list[ToolCall] | None = None
    # pydantic deep-copies field defaults, so these mutable defaults are not
    # shared between instances.
    usage: Usage = Usage()
    extra: dict = {}
|
||||
|
||||
|
||||
# Any pydantic model; used as the return type of structured generation.
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class BaseLLM(ABC):
    """An abstract base class for all LLMs.

    Concrete providers implement synchronous, structured, and asynchronous
    generation with a shared Generation result type.
    """

    @abstractmethod
    def generate(
        self,
        model: str,
        prompt: Any,
        tools: list | None = None,
        system_prompt: str | None = None,
    ) -> Generation:
        """Generates text from a prompt.

        Args:
            model: The model to use for generation.
            prompt: The prompt to generate text from.
            tools: An optional list of tools to use for generation.
            system_prompt: An optional system prompt to guide the model's behavior.

        Returns:
            A Generation object containing the generated text and usage metadata.
        """
        ...

    @abstractmethod
    def structured_generation(
        self,
        model: str,
        prompt: Any,
        response_model: Type[T],
        tools: list | None = None,
    ) -> T:
        """Generates structured data from a prompt.

        Args:
            model: The model to use for generation.
            prompt: The prompt to generate text from.
            response_model: The pydantic model to parse the response into.
            tools: An optional list of tools to use for generation.

        Returns:
            An instance of the provided pydantic model.
        """
        ...

    @abstractmethod
    async def async_generate(
        self,
        model: str,
        prompt: Any,
        tools: list | None = None,
        system_prompt: str | None = None,
        tool_mode: str = "AUTO",
    ) -> Generation:
        """Generates text from a prompt asynchronously.

        Args:
            model: The model to use for generation.
            prompt: The prompt to generate text from.
            tools: An optional list of tools to use for generation.
            system_prompt: An optional system prompt to guide the model's behavior.
            tool_mode: Function-calling mode forwarded to the provider
                (e.g. "AUTO").

        Returns:
            A Generation object containing the generated text and usage metadata.
        """
        ...
|
||||
@@ -1,181 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Type
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
from rag_eval.config import settings
|
||||
|
||||
from .base import BaseLLM, Generation, T, ToolCall, Usage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VertexAILLM(BaseLLM):
    """A class for interacting with the Vertex AI API via the google-genai client.

    Supports plain text generation, structured (JSON-schema) generation, and
    async generation; the sync and async paths share one request-config builder.
    """

    def __init__(
        self, project: str | None = None, location: str | None = None, thinking: int = 0
    ) -> None:
        """Initializes the VertexAILLM client.

        Args:
            project: The Google Cloud project ID (defaults to settings.project_id).
            location: The Google Cloud location (defaults to settings.location).
            thinking: Thinking-token budget forwarded to ThinkingConfig.
        """
        self.client = genai.Client(
            vertexai=True,
            project=project or settings.project_id,
            location=location or settings.location,
        )
        self.thinking_budget = thinking

    def _generation_config(
        self,
        tools: list,
        system_prompt: str | None,
        tool_mode: str,
    ) -> types.GenerateContentConfig:
        # Shared request config for the sync and async generation paths,
        # extracted so the two stay consistent.
        return types.GenerateContentConfig(
            tools=tools,
            system_instruction=system_prompt,
            thinking_config=types.ThinkingConfig(
                thinking_budget=self.thinking_budget
            ),
            tool_config=types.ToolConfig(
                function_calling_config=types.FunctionCallingConfig(
                    mode=tool_mode
                )
            ),
        )

    def generate(
        self,
        model: str,
        prompt: Any,
        tools: list | None = None,
        system_prompt: str | None = None,
        tool_mode: str = "AUTO",
    ) -> Generation:
        """Generates text using the specified model and prompt.

        Args:
            model: The name of the model to use for generation.
            prompt: The prompt to use for generation.
            tools: A list of tools to use for generation.
            system_prompt: An optional system prompt to guide the model's behavior.
            tool_mode: Function-calling mode (e.g. "AUTO").

        Returns:
            A Generation object containing the generated text and usage metadata.
        """
        # Mutable default argument removed; None is normalized to a fresh list
        # so the request config matches the previous behavior.
        tools = tools if tools is not None else []
        logger.debug("Entering VertexAILLM.generate")
        logger.debug(f"Model: {model}, Tool Mode: {tool_mode}")
        logger.debug(f"System prompt: {system_prompt}")
        logger.debug("Calling Vertex AI API: models.generate_content...")
        response = self.client.models.generate_content(
            model=model,
            contents=prompt,
            config=self._generation_config(tools, system_prompt, tool_mode),
        )
        logger.debug("Received response from Vertex AI API.")
        logger.debug(f"API Response: {response}")

        return self._create_generation(response)

    def structured_generation(
        self,
        model: str,
        prompt: Any,
        response_model: Type[T],
        system_prompt: str | None = None,
        tools: list | None = None,
    ) -> T:
        """Generates structured data from a prompt.

        Args:
            model: The model to use for generation.
            prompt: The prompt to generate text from.
            response_model: The pydantic model to parse the response into.
            system_prompt: An optional system prompt to guide the model's behavior.
            tools: An optional list of tools to use for generation.

        Returns:
            An instance of the provided pydantic model.
        """
        config = types.GenerateContentConfig(
            response_mime_type="application/json",
            response_schema=response_model,
            system_instruction=system_prompt,
            tools=tools,
        )

        response: types.GenerateContentResponse = (
            self.client.models.generate_content(
                model=model, contents=prompt, config=config
            )
        )

        # The response is requested as JSON matching the schema; parse it
        # straight into the caller's model.
        return response_model.model_validate_json(response.text)

    async def async_generate(
        self,
        model: str,
        prompt: Any,
        tools: list | None = None,
        system_prompt: str | None = None,
        tool_mode: str = "AUTO",
    ) -> Generation:
        """Async variant of generate; see generate for argument semantics."""
        tools = tools if tools is not None else []
        response = await self.client.aio.models.generate_content(
            model=model,
            contents=prompt,
            config=self._generation_config(tools, system_prompt, tool_mode),
        )

        return self._create_generation(response)

    def _create_generation(self, response) -> Generation:
        """Convert a raw API response into a Generation.

        Returns a tool-call Generation when the first candidate contains
        function calls, otherwise a text Generation.
        """
        logger.debug("Creating Generation object from API response.")
        metadata = response.usage_metadata
        usage = Usage(
            prompt_tokens=metadata.prompt_token_count,
            thought_tokens=metadata.thoughts_token_count or 0,
            response_tokens=metadata.candidates_token_count,
        )

        logger.debug(f"{usage=}")
        logger.debug(f"{response=}")

        # Only the first candidate is inspected.
        candidate = response.candidates[0]

        tool_calls = []

        for part in candidate.content.parts:
            if fn := part.function_call:
                tool_calls.append(ToolCall(name=fn.name, arguments=fn.args))

        if tool_calls:
            logger.debug(f"Found {len(tool_calls)} tool calls.")
            return Generation(
                tool_calls=tool_calls,
                usage=usage,
                extra={"original_content": candidate.content},
            )

        logger.debug("No tool calls found, returning text response.")
        # NOTE(review): only the first part's text is returned — confirm
        # multi-part text responses are not expected here.
        text = candidate.content.parts[0].text
        return Generation(text=text, usage=usage)
|
||||
@@ -1,17 +0,0 @@
|
||||
[project]
|
||||
name = "utils"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Anibal Angulo", email = "a8065384@banorte.com" }
|
||||
]
|
||||
requires-python = ">=3.12"
|
||||
dependencies = []
|
||||
|
||||
[project.scripts]
|
||||
normalize-filenames = "utils.normalize_filenames:app"
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.8.3,<0.9.0"]
|
||||
build-backend = "uv_build"
|
||||
@@ -1,2 +0,0 @@
|
||||
def hello() -> str:
    """Return the canonical greeting for the utils package."""
    greeting = "Hello from utils!"
    return greeting
|
||||
@@ -1,115 +0,0 @@
|
||||
"""Normalize filenames in a directory."""
|
||||
|
||||
import pathlib
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def normalize_string(s: str) -> str:
    """Normalize *s* into a safe, lowercase ASCII filename fragment.

    Diacritics are stripped, whitespace runs become single underscores, and
    anything outside [a-z0-9_.-] is removed.
    """
    # Decompose accented characters into base char + combining mark,
    # then drop the combining marks to get plain ASCII letters.
    decomposed = unicodedata.normalize("NFKD", s)
    stripped = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
    # Lowercase, collapse whitespace to underscores, drop everything else.
    lowered = stripped.lower()
    underscored = re.sub(r"\s+", "_", lowered)
    return re.sub(r"[^a-z0-9_.-]", "", underscored)
|
||||
|
||||
|
||||
def truncate_string(s: str) -> str:
    """Return only the final path component of *s* (the text after the last /)."""
    as_path = pathlib.Path(s)
    return as_path.name
|
||||
|
||||
|
||||
def remove_extension(s: str) -> str:
    """Strip a trailing extension (e.g. ``.pdf``) from *s*, if one is present."""
    without_suffix = pathlib.Path(s).with_suffix("")
    return str(without_suffix)
|
||||
|
||||
|
||||
def remove_duplicate_vowels(s: str) -> str:
    """Collapse consecutive repeats of the same vowel (a, e, i, o, u) into one."""
    repeated_vowel = re.compile(r"([aeiou])\1+", re.IGNORECASE)
    return repeated_vowel.sub(r"\1", s)
|
||||
|
||||
|
||||
@app.callback(invoke_without_command=True)
def normalize_filenames(
    directory: str = typer.Argument(
        ..., help="The path to the directory containing files to normalize."
    ),
):
    """Normalizes all filenames in a directory.

    Walks *directory* recursively and renames every file so its stem is
    lowercase ASCII with underscores (see ``normalize_string``); extensions
    are preserved. Prints a per-file summary table.
    """
    console = Console()
    console.print(
        Panel(
            f"Normalizing filenames in directory: [bold cyan]{directory}[/bold cyan]",
            title="[bold green]Filename Normalizer[/bold green]",
            expand=False,
        )
    )

    source_path = pathlib.Path(directory)
    if not source_path.is_dir():
        console.print(f"[bold red]Error: Directory not found at {directory}[/bold red]")
        raise typer.Exit(code=1)

    # Recursively collect regular files; directory names themselves are not renamed.
    files_to_rename = [p for p in source_path.rglob("*") if p.is_file()]

    if not files_to_rename:
        console.print(
            f"[bold yellow]No files found in {directory} to normalize.[/bold yellow]"
        )
        return

    table = Table(title="File Renaming Summary")
    table.add_column("Original Name", style="cyan", no_wrap=True)
    table.add_column("New Name", style="magenta", no_wrap=True)
    table.add_column("Status", style="green")

    for file_path in files_to_rename:
        original_name = file_path.name
        file_stem = file_path.stem
        file_suffix = file_path.suffix

        # Only the stem is normalized; the extension is kept as-is.
        normalized_stem = normalize_string(file_stem)
        new_name = f"{normalized_stem}{file_suffix}"

        if new_name == original_name:
            table.add_row(
                original_name, new_name, "[yellow]Skipped (No change)[/yellow]"
            )
            continue

        new_path = file_path.with_name(new_name)

        # Handle potential name collisions by appending _1, _2, ... until free.
        counter = 1
        while new_path.exists():
            new_name = f"{normalized_stem}_{counter}{file_suffix}"
            new_path = file_path.with_name(new_name)
            counter += 1

        try:
            file_path.rename(new_path)
            table.add_row(original_name, new_name, "[green]Renamed[/green]")
        except OSError as e:
            # Report but keep going; one failed rename shouldn't stop the run.
            table.add_row(original_name, new_name, f"[bold red]Error: {e}[/bold red]")

    console.print(table)
    console.print(
        Panel(
            f"[bold]Normalization complete.[/bold] Processed [bold blue]{len(files_to_rename)}[/bold blue] files.",
            title="[bold green]Complete[/bold green]",
            expand=False,
        )
    )
|
||||
@@ -1 +0,0 @@
|
||||
3.12
|
||||
@@ -1,29 +0,0 @@
|
||||
[project]
|
||||
name = "vector-search"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Anibal Angulo", email = "a8065384@banorte.com" }
|
||||
]
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"embedder",
|
||||
"file-storage",
|
||||
"google-cloud-aiplatform>=1.106.0",
|
||||
"aiohttp>=3.10.11,<4",
|
||||
"gcloud-aio-auth>=5.3.0",
|
||||
"google-auth==2.29.0",
|
||||
"typer>=0.16.1",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
vector-search = "vector_search.cli:app"
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.8.3,<0.9.0"]
|
||||
build-backend = "uv_build"
|
||||
|
||||
[tool.uv.sources]
|
||||
file-storage = { workspace = true }
|
||||
embedder = { workspace = true }
|
||||
@@ -1,2 +0,0 @@
|
||||
def hello() -> str:
    """Return the canonical greeting for the vector-search package."""
    greeting = "Hello from vector-search!"
    return greeting
|
||||
@@ -1,62 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, TypedDict
|
||||
|
||||
|
||||
class SearchResult(TypedDict):
    """A single match returned by a vector search query."""

    # Identifier of the matched item.
    id: str
    # Similarity distance between the query vector and the matched item.
    distance: float
    # Text content associated with the matched item.
    content: str
|
||||
|
||||
|
||||
class BaseVectorSearch(ABC):
    """
    Abstract base class for a vector search provider.

    This class defines the standard interface for creating a vector search index
    and running queries against it.
    """

    @abstractmethod
    def create_index(self, name: str, content_path: str, **kwargs) -> None:
        """
        Creates a new vector search index and populates it with the provided content.

        Args:
            name: The desired name for the new index.
            content_path: The local file system path to the data that will be used to
                          populate the index. This is expected to be a JSON file
                          containing a list of objects, each with an 'id', 'name',
                          and 'embedding' key.
            **kwargs: Additional provider-specific arguments for index creation.
        """
        ...

    @abstractmethod
    def update_index(self, index_name: str, content_path: str, **kwargs) -> None:
        """
        Updates an existing vector search index with new content.

        Args:
            index_name: The name of the index to update.
            content_path: The local file system path to the data that will be used to
                          populate the index.
            **kwargs: Additional provider-specific arguments for index update.
        """
        ...

    @abstractmethod
    def run_query(
        self, index: str, query: List[float], limit: int
    ) -> List[SearchResult]:
        """
        Runs a similarity search query against the index.

        Args:
            index: The name or identifier of the index to query.
            query: The embedding vector to use for the search query.
            limit: The maximum number of nearest neighbors to return.

        Returns:
            A list of dictionaries, where each dictionary represents a matched item
            and contains at least the item's 'id' and the search 'distance'.
        """
        ...
|
||||
@@ -1,10 +0,0 @@
|
||||
from typer import Typer

from .create import app as create_callback
from .delete import app as delete_callback
from .query import app as query_callback

# Root CLI application: each sub-module's app is mounted as a named sub-command
# (vector-search create / delete / query).
app = Typer()
app.add_typer(create_callback, name="create")
app.add_typer(delete_callback, name="delete")
app.add_typer(query_callback, name="query")
|
||||
@@ -1,91 +0,0 @@
|
||||
"""Create and deploy a Vertex AI Vector Search index."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
|
||||
from rag_eval.config import settings as config
|
||||
from vector_search.vertex_ai import GoogleCloudVectorSearch
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.callback(invoke_without_command=True)
def create(
    path: Annotated[
        str,
        typer.Option(
            "--path",
            "-p",
            help="The GCS URI (gs://...) to the directory containing your embedding JSON file(s).",
        ),
    ],
    agent_name: Annotated[
        str,
        typer.Option(
            "--agent",
            "-a",
            help="The name of the agent to create the index for.",
        ),
    ],
):
    """Create and deploy a Vertex AI Vector Search index for a specific agent.

    Looks up the agent's index configuration in settings, creates the index
    from the embeddings under ``gs://<bucket>/<path>``, and deploys it to a
    new endpoint.
    """
    console = Console()

    try:
        console.print(
            f"[bold green]Looking up configuration for agent '{agent_name}'...[/bold green]"
        )
        agent_config = config.agents.get(agent_name)
        if not agent_config:
            console.print(
                f"[bold red]Agent '{agent_name}' not found in settings.[/bold red]"
            )
            raise typer.Exit(code=1)

        if not agent_config.index:
            console.print(
                f"[bold red]Index configuration not found for agent '{agent_name}'.[/bold red]"
            )
            raise typer.Exit(code=1)

        index_config = agent_config.index

        console.print(
            f"[bold green]Initializing Vertex AI client for project '{config.project_id}' in '{config.location}'...[/bold green]"
        )
        vector_search = GoogleCloudVectorSearch(
            project_id=config.project_id,
            location=config.location,
            bucket=config.bucket,
            index_name=index_config.name,
        )

        console.print(
            f"[bold green]Starting creation of index '{index_config.name}'...[/bold green]"
        )
        console.print("This may take a while.")
        vector_search.create_index(
            name=index_config.name,
            content_path=f"gs://{config.bucket}/{path}",
            dimensions=index_config.dimensions,
        )
        console.print(
            f"[bold green]Index '{index_config.name}' created successfully.[/bold green]"
        )

        console.print("[bold green]Deploying index to a new endpoint...[/bold green]")
        console.print("This will also take some time.")
        vector_search.deploy_index(
            index_name=index_config.name, machine_type=index_config.machine_type
        )
        console.print("[bold green]Index deployed successfully![/bold green]")
        console.print(f"Endpoint name: {vector_search.index_endpoint.display_name}")
        console.print(
            f"Endpoint resource name: {vector_search.index_endpoint.resource_name}"
        )

    except typer.Exit:
        # BUG FIX: typer.Exit subclasses Exception (via click's Exit/RuntimeError),
        # so the broad handler below used to swallow the deliberate exits raised
        # above and misreport them as errors. Re-raise them untouched.
        raise
    except Exception as e:
        console.print(f"[bold red]An error occurred: {e}[/bold red]")
        raise typer.Exit(code=1) from e
|
||||
@@ -1,38 +0,0 @@
|
||||
"""Delete a vector index or endpoint."""
|
||||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
|
||||
from rag_eval.config import settings as config
|
||||
from vector_search.vertex_ai import GoogleCloudVectorSearch
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.callback(invoke_without_command=True)
def delete(
    id: str = typer.Argument(..., help="The ID of the index or endpoint to delete."),
    endpoint: bool = typer.Option(
        False, "--endpoint", help="Delete an endpoint instead of an index."
    ),
):
    """Delete a vector index or endpoint."""
    console = Console()
    search_client = GoogleCloudVectorSearch(
        project_id=config.project_id, location=config.location, bucket=config.bucket
    )

    try:
        if not endpoint:
            console.print(f"[bold red]Deleting index {id}...[/bold red]")
            search_client.delete_index(id)
            console.print(f"[bold green]Index {id} deleted successfully.[/bold green]")
        else:
            console.print(f"[bold red]Deleting endpoint {id}...[/bold red]")
            search_client.delete_index_endpoint(id)
            console.print(
                f"[bold green]Endpoint {id} deleted successfully.[/bold green]"
            )
    except Exception as e:
        console.print(f"[bold red]An error occurred: {e}[/bold red]")
        raise typer.Exit(code=1)
|
||||
@@ -1,91 +0,0 @@
|
||||
"""Generate embeddings for documents and save them to a JSON file."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
from embedder.vertex_ai import VertexAIEmbedder
|
||||
from file_storage.google_cloud import GoogleCloudFileStorage
|
||||
from rich.console import Console
|
||||
from rich.progress import Progress
|
||||
|
||||
from rag_eval.config import Settings
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.callback(invoke_without_command=True)
def generate(
    path: str = typer.Argument(..., help="The path to the markdown files."),
    output_file: str = typer.Option(
        ...,
        "--output-file",
        "-o",
        help="The local path to save the output JSON file.",
    ),
    batch_size: int = typer.Option(
        10,
        "--batch-size",
        "-b",
        help="The batch size for processing files.",
    ),
    jsonl: bool = typer.Option(
        False,
        "--jsonl",
        help="Output in JSONL format instead of JSON.",
    ),
):
    """Generate embeddings for documents and save them to a JSON file.

    Reads every blob under *path* in the configured bucket, embeds the
    contents in batches, and writes ``{"id", "embedding"}`` records to
    *output_file* as JSON (or JSON Lines with --jsonl).
    """
    config = Settings()
    console = Console()

    console.print("[bold green]Starting vector generation...[/bold green]")

    try:
        storage = GoogleCloudFileStorage(bucket=config.bucket)
        embedder = VertexAIEmbedder(model_name=config.embedding_model)

        # Every file under the given prefix in the bucket is embedded.
        remote_files = storage.list_files(path=path)
        results = []

        with Progress(console=console) as progress:
            task = progress.add_task(
                "[cyan]Generating embeddings...", total=len(remote_files)
            )

            # Process files in batches to bound per-request payload size.
            for i in range(0, len(remote_files), batch_size):
                batch_files = remote_files[i : i + batch_size]
                batch_contents = []

                for remote_file in batch_files:
                    file_stream = storage.get_file_stream(remote_file)
                    # utf-8-sig strips a BOM if present; undecodable bytes are
                    # replaced rather than raising.
                    batch_contents.append(
                        file_stream.read().decode("utf-8-sig", errors="replace")
                    )

                batch_embeddings = embedder.generate_embeddings_batch(batch_contents)

                # Pair each blob name with its embedding, preserving order.
                for j, remote_file in enumerate(batch_files):
                    results.append(
                        {"id": remote_file, "embedding": batch_embeddings[j]}
                    )
                    progress.update(task, advance=1)

    except Exception as e:
        console.print(
            f"[bold red]An error occurred during vector generation: {e}[/bold red]"
        )
        raise typer.Exit(code=1)

    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        if jsonl:
            # One JSON object per line (JSON Lines).
            for record in results:
                f.write(json.dumps(record) + "\n")
        else:
            json.dump(results, f, indent=2)

    console.print(
        f"[bold green]Embedding generation complete. {len(results)} vectors saved to '{output_path.resolve()}'[/bold green]"
    )
|
||||
@@ -1,55 +0,0 @@
|
||||
"""Query the vector search index."""
|
||||
|
||||
import typer
|
||||
from embedder.vertex_ai import VertexAIEmbedder
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from typer import Argument, Option
|
||||
|
||||
from rag_eval.config import settings as config
|
||||
from vector_search.vertex_ai import GoogleCloudVectorSearch
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.callback(invoke_without_command=True)
def query(
    query: str = Argument(..., help="The text query to search for."),
    limit: int = Option(5, "--limit", "-l", help="The number of results to return."),
):
    """Queries the vector search index."""
    console = Console()

    try:
        console.print("[bold green]Initializing clients...[/bold green]")
        text_embedder = VertexAIEmbedder(model_name=config.embedding_model)
        searcher = GoogleCloudVectorSearch(
            project_id=config.project_id, location=config.location, bucket=config.bucket
        )

        console.print("[bold green]Loading index endpoint...[/bold green]")
        searcher.load_index_endpoint(config.index.endpoint)

        console.print("[bold green]Generating embedding for query...[/bold green]")
        query_vector = text_embedder.generate_embedding(query)

        console.print("[bold green]Running search query...[/bold green]")
        matches = searcher.run_query(
            deployed_index_id=config.index.deployment,
            query=query_vector,
            limit=limit,
        )

        # Render the matches as a rich table on stdout.
        results_table = Table(title="Search Results")
        for header, style in (
            ("ID", "cyan"),
            ("Distance", "magenta"),
            ("Content", "green"),
        ):
            results_table.add_column(header, justify="left", style=style)

        for match in matches:
            results_table.add_row(
                match["id"], str(match["distance"]), match["content"]
            )

        console.print(results_table)

    except Exception as e:
        # Surface any failure (auth, network, index state) and exit non-zero.
        console.print(f"[bold red]An error occurred: {e}[/bold red]")
        raise typer.Exit(code=1)
|
||||
@@ -1,255 +0,0 @@
|
||||
import asyncio
|
||||
from typing import List
|
||||
from uuid import uuid4
|
||||
|
||||
import aiohttp
|
||||
import google.auth
|
||||
import google.auth.transport.requests
|
||||
from file_storage.google_cloud import GoogleCloudFileStorage
|
||||
from gcloud.aio.auth import Token
|
||||
from google.cloud import aiplatform
|
||||
|
||||
from .base import BaseVectorSearch, SearchResult
|
||||
|
||||
|
||||
class GoogleCloudVectorSearch(BaseVectorSearch):
    """
    A vector search provider that uses Google Cloud's Vertex AI Vector Search.

    Covers index lifecycle (create/update/deploy/delete) and both synchronous
    and asynchronous nearest-neighbor queries. Matched chunk contents are
    fetched from GCS using the convention
    ``{index_name}/contents/{datapoint_id}.md``.
    """

    def __init__(
        self, project_id: str, location: str, bucket: str, index_name: str | None = None
    ):
        """
        Initializes the GoogleCloudVectorSearch client.

        Args:
            project_id: The Google Cloud project ID.
            location: The Google Cloud location (e.g., 'us-central1').
            bucket: The GCS bucket to use for file storage.
            index_name: The name of the index. If None, it will be taken from settings.
        """
        aiplatform.init(project=project_id, location=location)
        self.project_id = project_id
        self.location = location
        self.storage = GoogleCloudFileStorage(bucket=bucket)
        self.index_name = index_name
        # Auth and HTTP-session state is created lazily by the helpers below
        # and cached on the instance for reuse across calls.
        self._credentials = None
        self._aio_session: aiohttp.ClientSession | None = None
        self._async_token: Token | None = None

    def _get_auth_headers(self) -> dict:
        """Return Bearer-token headers for synchronous REST calls, refreshing the token as needed."""
        if self._credentials is None:
            # Application Default Credentials, resolved once and cached.
            self._credentials, _ = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"]
            )
        if not self._credentials.token or self._credentials.expired:
            self._credentials.refresh(google.auth.transport.requests.Request())
        return {
            "Authorization": f"Bearer {self._credentials.token}",
            "Content-Type": "application/json",
        }

    async def _async_get_auth_headers(self) -> dict:
        """Return Bearer-token headers for async REST calls using gcloud-aio's Token."""
        if self._async_token is None:
            self._async_token = Token(
                session=self._get_aio_session(),
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
        # Token.get() handles caching and refresh internally.
        access_token = await self._async_token.get()
        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def _get_aio_session(self) -> aiohttp.ClientSession:
        """Return the shared aiohttp session, recreating it if closed or not yet built."""
        if self._aio_session is None or self._aio_session.closed:
            # Generous connection pool for concurrent neighbor/content fetches.
            connector = aiohttp.TCPConnector(limit=300, limit_per_host=50)
            timeout = aiohttp.ClientTimeout(total=60)
            self._aio_session = aiohttp.ClientSession(
                timeout=timeout, connector=connector
            )
        return self._aio_session

    def create_index(
        self,
        name: str,
        content_path: str,
        dimensions: int,
        approximate_neighbors_count: int = 150,
        distance_measure_type: str = "DOT_PRODUCT_DISTANCE",
        **kwargs,
    ) -> None:
        """
        Creates a new Vertex AI Vector Search index.

        The created index is stored on ``self.index`` so that a subsequent
        ``deploy_index`` call can deploy it.

        Args:
            name: The display name for the new index.
            content_path: The GCS URI to the JSON file containing the embeddings.
            dimensions: The number of dimensions in the embedding vectors.
            approximate_neighbors_count: The number of neighbors to find for each vector.
            distance_measure_type: The distance measure to use (e.g., 'DOT_PRODUCT_DISTANCE').
        """
        # NOTE: **kwargs is accepted for interface compatibility but is not
        # forwarded to the SDK call.
        index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
            display_name=name,
            contents_delta_uri=content_path,
            dimensions=dimensions,
            approximate_neighbors_count=approximate_neighbors_count,
            distance_measure_type=distance_measure_type,
            leaf_node_embedding_count=1000,
            leaf_nodes_to_search_percent=10,
        )
        self.index = index

    def update_index(self, index_name: str, content_path: str, **kwargs) -> None:
        """
        Updates an existing Vertex AI Vector Search index.

        Args:
            index_name: The resource name of the index to update.
            content_path: The GCS URI to the JSON file containing the new embeddings.
        """
        index = aiplatform.MatchingEngineIndex(index_name=index_name)
        index.update_embeddings(
            contents_delta_uri=content_path,
        )
        # Keep a handle so deploy_index can deploy the updated index.
        self.index = index

    def deploy_index(
        self, index_name: str, machine_type: str = "e2-standard-2"
    ) -> None:
        """
        Deploys a Vertex AI Vector Search index to an endpoint.

        Requires ``self.index`` to have been set by a prior ``create_index``
        or ``update_index`` call.

        Args:
            index_name: The name of the index to deploy.
            machine_type: The type of machine to use for the endpoint.
        """
        index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
            display_name=f"{index_name}-endpoint",
            public_endpoint_enabled=True,
        )
        index_endpoint.deploy_index(
            index=self.index,
            # Deployed-index IDs cannot contain dashes; the uuid suffix keeps
            # repeated deployments unique.
            deployed_index_id=f"{index_name.replace('-', '_')}_deployed_{uuid4().hex}",
            machine_type=machine_type,
        )
        self.index_endpoint = index_endpoint

    def load_index_endpoint(self, endpoint_name: str) -> None:
        """
        Loads an existing Vertex AI Vector Search index endpoint.

        Raises:
            ValueError: If the endpoint has no public domain name (the async
                query path calls the public REST endpoint directly).

        Args:
            endpoint_name: The resource name of the index endpoint.
        """
        self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(endpoint_name)
        if not self.index_endpoint.public_endpoint_domain_name:
            raise ValueError(
                "The index endpoint does not have a public endpoint. "
                "Please ensure that the endpoint is configured for public access."
            )

    def run_query(
        self, deployed_index_id: str, query: List[float], limit: int
    ) -> List[SearchResult]:
        """
        Runs a similarity search query against the deployed index.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector to use for the search query.
            limit: The maximum number of nearest neighbors to return.

        Returns:
            A list of dictionaries representing the matched items.
        """
        response = self.index_endpoint.find_neighbors(
            deployed_index_id=deployed_index_id, queries=[query], num_neighbors=limit
        )
        results = []
        # response[0] holds the neighbors for our single query vector.
        for neighbor in response[0]:
            # Chunk contents live next to the index data in GCS.
            file_path = self.index_name + "/contents/" + neighbor.id + ".md"
            content = self.storage.get_file_stream(file_path).read().decode("utf-8")
            results.append(
                {"id": neighbor.id, "distance": neighbor.distance, "content": content}
            )
        return results

    async def async_run_query(
        self, deployed_index_id: str, query: List[float], limit: int
    ) -> List[SearchResult]:
        """
        Runs a non-blocking similarity search query against the deployed index
        using the REST API directly with an async HTTP client.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector to use for the search query.
            limit: The maximum number of nearest neighbors to return.

        Returns:
            A list of dictionaries representing the matched items.
        """
        domain = self.index_endpoint.public_endpoint_domain_name
        endpoint_id = self.index_endpoint.name.split("/")[-1]
        url = (
            f"https://{domain}/v1/projects/{self.project_id}"
            f"/locations/{self.location}"
            f"/indexEndpoints/{endpoint_id}:findNeighbors"
        )
        payload = {
            "deployed_index_id": deployed_index_id,
            "queries": [
                {
                    "datapoint": {"feature_vector": query},
                    "neighbor_count": limit,
                }
            ],
        }

        headers = await self._async_get_auth_headers()
        session = self._get_aio_session()
        async with session.post(url, json=payload, headers=headers) as response:
            response.raise_for_status()
            data = await response.json()

        # Single query vector -> first (and only) nearestNeighbors entry.
        neighbors = data.get("nearestNeighbors", [{}])[0].get("neighbors", [])
        # Fetch all chunk contents concurrently.
        content_tasks = []
        for neighbor in neighbors:
            datapoint_id = neighbor["datapoint"]["datapointId"]
            file_path = f"{self.index_name}/contents/{datapoint_id}.md"
            content_tasks.append(self.storage.async_get_file_stream(file_path))

        file_streams = await asyncio.gather(*content_tasks)
        results: List[SearchResult] = []
        for neighbor, stream in zip(neighbors, file_streams):
            results.append(
                {
                    "id": neighbor["datapoint"]["datapointId"],
                    "distance": neighbor["distance"],
                    "content": stream.read().decode("utf-8"),
                }
            )
        return results

    def delete_index(self, index_name: str) -> None:
        """
        Deletes a Vertex AI Vector Search index.

        Args:
            index_name: The resource name of the index.
        """
        index = aiplatform.MatchingEngineIndex(index_name)
        index.delete()

    def delete_index_endpoint(self, index_endpoint_name: str) -> None:
        """
        Deletes a Vertex AI Vector Search index endpoint.

        Undeploys all indexes first, since an endpoint with live deployments
        cannot be deleted.

        Args:
            index_endpoint_name: The resource name of the index endpoint.
        """
        index_endpoint = aiplatform.MatchingEngineIndexEndpoint(index_endpoint_name)
        index_endpoint.undeploy_all()
        index_endpoint.delete(force=True)
|
||||
@@ -1,91 +1,57 @@
|
||||
[project]
|
||||
name = "rag-pipeline"
|
||||
name = "knowledge-pipeline"
|
||||
version = "0.1.0"
|
||||
description = "RAG Pipeline for document chunking, embedding, and vector search"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
authors = [
    { name = "Anibal Angulo", email = "A8065384@banorte.com" }
]
|
||||
|
||||
dependencies = [
|
||||
# Core dependencies
|
||||
"google-genai>=1.45.0",
|
||||
"google-cloud-aiplatform>=1.106.0",
|
||||
"google-cloud-storage>=2.19.0",
|
||||
"google-auth>=2.29.0",
|
||||
"pydantic>=2.11.7",
|
||||
"pydantic-settings[yaml]>=2.10.1",
|
||||
"python-dotenv>=1.0.0",
|
||||
|
||||
# Chunking
|
||||
"chonkie>=1.1.2",
|
||||
"tiktoken>=0.7.0",
|
||||
"langchain>=0.3.0",
|
||||
"langchain-core>=0.3.0",
|
||||
|
||||
# Document processing
|
||||
"markitdown[pdf]>=0.1.2",
|
||||
"pypdf>=6.1.2",
|
||||
"pdf2image>=1.17.0",
|
||||
|
||||
# Storage & networking
|
||||
"gcloud-aio-storage>=9.6.1",
|
||||
"gcloud-aio-auth>=5.3.0",
|
||||
"aiohttp>=3.10.11,<4",
|
||||
|
||||
# Utils
|
||||
"tenacity>=9.1.2",
|
||||
"typer>=0.16.1",
|
||||
|
||||
# Pipeline orchestration (optional)
|
||||
"kfp>=2.15.2",
|
||||
"pydantic-ai>=0.0.5",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
# Chunkers
|
||||
llm-chunker = "chunker.llm_chunker:app"
|
||||
recursive-chunker = "chunker.recursive_chunker:app"
|
||||
contextual-chunker = "chunker.contextual_chunker:app"
|
||||
|
||||
# Converters
|
||||
convert-md = "document_converter.markdown:app"
|
||||
|
||||
# Storage
|
||||
file-storage = "file_storage.cli:app"
|
||||
|
||||
# Vector Search
|
||||
vector-search = "vector_search.cli:app"
|
||||
|
||||
# Utils
|
||||
normalize-filenames = "utils.normalize_filenames:app"
|
||||
knowledge-pipeline = "knowledge_pipeline.cli:app"
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.8.3,<0.9.0"]
|
||||
build-backend = "uv_build"
|
||||
|
||||
[tool.uv.workspace]
|
||||
members = [
|
||||
"apps/*",
|
||||
"packages/*",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
chunker = { workspace = true }
|
||||
document-converter = { workspace = true }
|
||||
embedder = { workspace = true }
|
||||
file-storage = { workspace = true }
|
||||
llm = { workspace = true }
|
||||
utils = { workspace = true }
|
||||
vector-search = { workspace = true }
|
||||
index-gen = { workspace = true }
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=8.4.1",
|
||||
"mypy>=1.17.1",
|
||||
"ruff>=0.12.10",
|
||||
"ty>=0.0.18",
|
||||
]
|
||||
|
||||
[tool.ruff.lint]
|
||||
extend-select = ["I", "F"]
|
||||
select = ["I", "F"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = [
|
||||
"--strict-markers",
|
||||
"--tb=short",
|
||||
"--disable-warnings",
|
||||
]
|
||||
markers = [
|
||||
"unit: Unit tests",
|
||||
"integration: Integration tests",
|
||||
"slow: Slow running tests",
|
||||
]
|
||||
|
||||
@@ -13,6 +13,8 @@ class Document(TypedDict):
|
||||
class BaseChunker(ABC):
|
||||
"""Abstract base class for chunker implementations."""
|
||||
|
||||
max_chunk_size: int
|
||||
|
||||
@abstractmethod
|
||||
def process_text(self, text: str) -> List[Document]:
|
||||
"""
|
||||
@@ -1,11 +1,3 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Annotated, List
|
||||
|
||||
import typer
|
||||
from llm.vertex_ai import VertexAILLM
|
||||
|
||||
from .base_chunker import BaseChunker, Document
|
||||
|
||||
|
||||
@@ -16,23 +8,13 @@ class ContextualChunker(BaseChunker):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_client: VertexAILLM,
|
||||
model: str = "google-vertex:gemini-2.0-flash",
|
||||
max_chunk_size: int = 800,
|
||||
model: str = "gemini-2.0-flash",
|
||||
):
|
||||
"""
|
||||
Initializes the ContextualChunker.
|
||||
|
||||
Args:
|
||||
max_chunk_size: The maximum length of a chunk in characters.
|
||||
model: The name of the language model to use.
|
||||
llm_client: An optional instance of a language model client.
|
||||
"""
|
||||
self.max_chunk_size = max_chunk_size
|
||||
self.model = model
|
||||
self.llm_client = llm_client
|
||||
|
||||
def _split_text(self, text: str) -> List[str]:
|
||||
def _split_text(self, text: str) -> list[str]:
|
||||
"""Splits text into evenly sized chunks of a maximum size, trying to respect sentence and paragraph boundaries."""
|
||||
import math
|
||||
|
||||
@@ -67,7 +49,7 @@ class ContextualChunker(BaseChunker):
|
||||
|
||||
return chunks
|
||||
|
||||
def process_text(self, text: str) -> List[Document]:
|
||||
def process_text(self, text: str) -> list[Document]:
|
||||
"""
|
||||
Processes a string of text into a list of context-aware Document chunks.
|
||||
"""
|
||||
@@ -75,7 +57,7 @@ class ContextualChunker(BaseChunker):
|
||||
return [{"page_content": text, "metadata": {}}]
|
||||
|
||||
chunks = self._split_text(text)
|
||||
processed_chunks: List[Document] = []
|
||||
processed_chunks: list[Document] = []
|
||||
|
||||
for i, chunk_content in enumerate(chunks):
|
||||
prompt = f"""
|
||||
@@ -93,7 +75,14 @@ class ContextualChunker(BaseChunker):
|
||||
Genera un resumen conciso del "Documento Original" que proporcione el contexto necesario para entender el "Fragmento Actual". El resumen debe ser un solo párrafo en español.
|
||||
"""
|
||||
|
||||
summary = self.llm_client.generate(self.model, prompt).text
|
||||
from pydantic_ai import ModelRequest
|
||||
from pydantic_ai.direct import model_request_sync
|
||||
|
||||
response = model_request_sync(
|
||||
self.model,
|
||||
[ModelRequest.user_text_prompt(prompt)],
|
||||
)
|
||||
summary = next(p.content for p in response.parts if p.part_kind == "text")
|
||||
contextualized_chunk = (
|
||||
f"> **Contexto del documento original:**\n> {summary}\n\n---\n\n"
|
||||
+ chunk_content
|
||||
@@ -107,49 +96,3 @@ class ContextualChunker(BaseChunker):
|
||||
)
|
||||
|
||||
return processed_chunks
|
||||
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def main(
    input_file_path: Annotated[
        str, typer.Argument(help="Path to the input text file.")
    ],
    output_dir: Annotated[
        str, typer.Argument(help="Directory to save the output file.")
    ],
    max_chunk_size: Annotated[
        int, typer.Option(help="Maximum chunk size in characters.")
    ] = 800,
    model: Annotated[
        str, typer.Option(help="Model to use for the processing")
    ] = "gemini-2.0-flash",
):
    """
    Processes a text file using ContextualChunker and saves the output to a JSONL file.
    """
    print(f"Starting to process {input_file_path}...")

    chunker = ContextualChunker(max_chunk_size=max_chunk_size, model=model)
    documents = chunker.process_path(Path(input_file_path))

    print(f"Successfully created {len(documents)} chunks.")

    destination = Path(output_dir)
    if not destination.exists():
        destination.mkdir(parents=True)
        print(f"Created output directory: {output_dir}")

    output_file_path = destination / "chunked_documents.jsonl"

    # Tag every chunk with its origin file, then write one JSON object per line.
    source_file = Path(input_file_path).name
    with open(output_file_path, "w", encoding="utf-8") as f:
        for doc in documents:
            doc["metadata"]["source_file"] = source_file
            f.write(json.dumps(doc, ensure_ascii=False) + "\n")

    print(f"Successfully saved {len(documents)} chunks to {output_file_path}")


if __name__ == "__main__":
    app()
|
||||
@@ -13,7 +13,6 @@ from langchain_core.documents import Document as LangchainDocument
|
||||
from llm.vertex_ai import VertexAILLM
|
||||
from pdf2image import convert_from_path
|
||||
from pypdf import PdfReader
|
||||
|
||||
from rag_eval.config import Settings
|
||||
|
||||
from .base_chunker import BaseChunker, Document
|
||||
20
src/knowledge_pipeline/cli.py
Normal file
20
src/knowledge_pipeline/cli.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import logging
|
||||
|
||||
import typer
|
||||
|
||||
from .config import Settings
|
||||
from .pipeline import run_pipeline
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command()
def run_ingestion():
    """Main function for the CLI script."""
    settings = Settings.model_validate({})
    # Apply the configured log level to the package logger before running.
    pipeline_logger = logging.getLogger("knowledge_pipeline")
    pipeline_logger.setLevel(getattr(logging, settings.log_level))
    run_pipeline(settings)


if __name__ == "__main__":
    app()
|
||||
101
src/knowledge_pipeline/config.py
Normal file
101
src/knowledge_pipeline/config.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import os
|
||||
from functools import cached_property
|
||||
|
||||
from google.cloud.aiplatform.matching_engine.matching_engine_index_config import (
|
||||
DistanceMeasureType,
|
||||
)
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
PydanticBaseSettingsSource,
|
||||
SettingsConfigDict,
|
||||
YamlConfigSettingsSource,
|
||||
)
|
||||
|
||||
CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """
    Pipeline configuration loaded from environment variables and a YAML file.

    Environment variables take precedence over the YAML file (see
    ``settings_customise_sources``). Also exposes the pipeline's heavyweight
    clients (GCS, converter, embedder, chunker) as lazily-built cached
    properties so importing this module stays cheap.
    """

    # Google Cloud project and region the pipeline runs against.
    project_id: str
    location: str
    # Level name applied to the "knowledge_pipeline" logger (e.g. "INFO").
    log_level: str = "INFO"

    # Embedding model id; used as "google-vertex:{agent_embedding_model}".
    agent_embedding_model: str

    # Vector Search index settings.
    index_name: str
    index_dimensions: int
    index_machine_type: str = "e2-standard-16"
    # GCS URI containing the source PDFs to ingest.
    index_origin: str
    # GCS URI prefix under which index data (contents + vectors) is written.
    index_destination: str
    # Maximum chunk size in characters, passed to the chunker.
    index_chunk_limit: int
    index_distance_measure_type: DistanceMeasureType = (
        DistanceMeasureType.DOT_PRODUCT_DISTANCE
    )
    index_approximate_neighbors_count: int = 150
    index_leaf_node_embedding_count: int = 1000
    index_leaf_nodes_to_search_percent: int = 10
    index_public_endpoint_enabled: bool = True

    model_config = SettingsConfigDict(yaml_file=CONFIG_FILE_PATH)

    def model_post_init(self, _):
        """Initialize the Vertex AI SDK as soon as settings are constructed."""
        from google.cloud import aiplatform

        aiplatform.init(project=self.project_id, location=self.location)

    @property
    def index_deployment(self) -> str:
        """Deployed-index ID derived from the index name (dashes replaced)."""
        return self.index_name.replace("-", "_") + "_deployed"

    @property
    def index_data(self) -> str:
        """Root GCS URI for this index's data."""
        return f"{self.index_destination}/{self.index_name}"

    @property
    def index_contents_dir(self) -> str:
        """GCS URI holding the per-chunk markdown contents."""
        return f"{self.index_data}/contents"

    @property
    def index_vectors_dir(self) -> str:
        """GCS URI of the vectors directory (used as the index contents_delta_uri)."""
        return f"{self.index_data}/vectors"

    @property
    def index_vectors_jsonl_path(self) -> str:
        """Full GCS URI of the newline-delimited vectors file."""
        return f"{self.index_vectors_dir}/vectors.json"

    @cached_property
    def gcs_client(self):
        """Shared google-cloud-storage client (imported lazily, built once)."""
        from google.cloud import storage

        return storage.Client()

    @cached_property
    def converter(self):
        """MarkItDown document-to-markdown converter with plugins disabled."""
        from markitdown import MarkItDown

        return MarkItDown(enable_plugins=False)

    @cached_property
    def embedder(self):
        """Embedding client targeting the configured Vertex AI model."""
        from pydantic_ai import Embedder

        return Embedder(f"google-vertex:{self.agent_embedding_model}")

    @cached_property
    def chunker(self):
        """Contextual chunker bound to the configured chunk size limit."""
        from .chunker.contextual_chunker import ContextualChunker

        return ContextualChunker(max_chunk_size=self.index_chunk_limit)

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Read environment variables first, then fall back to the YAML file."""
        return (
            env_settings,
            YamlConfigSettingsSource(settings_cls),
        )
|
||||
209
src/knowledge_pipeline/pipeline.py
Normal file
209
src/knowledge_pipeline/pipeline.py
Normal file
@@ -0,0 +1,209 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import tempfile
|
||||
import unicodedata
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
|
||||
from google.cloud import aiplatform
|
||||
from google.cloud.aiplatform.matching_engine.matching_engine_index_config import (
|
||||
DistanceMeasureType,
|
||||
)
|
||||
from google.cloud.storage import Client as StorageClient
|
||||
from markitdown import MarkItDown
|
||||
from pydantic_ai import Embedder
|
||||
|
||||
from .chunker.base_chunker import BaseChunker, Document
|
||||
from .config import Settings
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _parse_gcs_uri(uri: str) -> tuple[str, str]:
|
||||
"""Parse a 'gs://bucket/path' URI into (bucket_name, object_path)."""
|
||||
bucket, _, path = uri.removeprefix("gs://").partition("/")
|
||||
return bucket, path
|
||||
|
||||
|
||||
def normalize_string(s: str) -> str:
    """Normalizes a string to be a valid filename."""
    # Decompose accented characters, then drop the combining marks so that
    # e.g. "é" becomes plain "e".
    decomposed = unicodedata.normalize("NFKD", s)
    stripped = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
    lowered = stripped.lower()
    # Collapse any whitespace run into a single underscore, then drop every
    # character outside the filename-safe set.
    underscored = re.sub(r"\s+", "_", lowered)
    return re.sub(r"[^a-z0-9_.-]", "", underscored)
|
||||
|
||||
|
||||
def gather_pdfs(index_origin: str, gcs_client: StorageClient) -> list[str]:
    """Lists all PDF file URIs in a GCS directory."""
    bucket_name, prefix = _parse_gcs_uri(index_origin)
    pdf_files: list[str] = []
    # Walk every object under the prefix and keep only the PDFs.
    for blob in gcs_client.bucket(bucket_name).list_blobs(prefix=prefix):
        if blob.name.endswith(".pdf"):
            pdf_files.append(f"gs://{bucket_name}/{blob.name}")
    log.info("Found %d PDF files in %s", len(pdf_files), index_origin)
    return pdf_files
|
||||
|
||||
|
||||
def split_into_chunks(text: str, file_id: str, chunker: BaseChunker) -> list[Document]:
    """Splits text into chunks, or returns a single chunk if small enough."""
    if len(text) > chunker.max_chunk_size:
        # Delegate the real splitting to the chunker, then give every chunk a
        # stable, position-based identifier.
        chunks = chunker.process_text(text)
        for position, chunk in enumerate(chunks):
            chunk["metadata"]["id"] = f"{file_id}_{position}"
        return chunks
    # Text already fits in one chunk: keep the bare file id.
    return [{"page_content": text, "metadata": {"id": file_id}}]
|
||||
|
||||
|
||||
def upload_to_gcs(
    chunks: list[Document],
    vectors: list[dict],
    index_contents_dir: str,
    index_vectors_jsonl_path: str,
    gcs_client: StorageClient,
) -> None:
    """Uploads chunk contents and vectors to GCS."""
    # One markdown object per chunk under the contents directory.
    contents_bucket_name, contents_prefix = _parse_gcs_uri(index_contents_dir)
    contents_bucket = gcs_client.bucket(contents_bucket_name)
    for chunk in chunks:
        blob_name = f"{contents_prefix}/{chunk['metadata']['id']}.md"
        contents_bucket.blob(blob_name).upload_from_string(
            chunk["page_content"], content_type="text/markdown; charset=utf-8"
        )

    # All vectors go into one newline-delimited JSON object.
    jsonl_lines = [json.dumps(vector) for vector in vectors]
    vectors_bucket_name, vectors_blob_name = _parse_gcs_uri(index_vectors_jsonl_path)
    gcs_client.bucket(vectors_bucket_name).blob(vectors_blob_name).upload_from_string(
        "\n".join(jsonl_lines) + "\n",
        content_type="application/x-ndjson; charset=utf-8",
    )
    log.info("Uploaded %d chunks and %d vectors to GCS", len(chunks), len(vectors))
|
||||
|
||||
|
||||
def build_vectors(
    chunks: list[Document],
    embeddings: Sequence[Sequence[float]],
    source_folder: str,
) -> list[dict]:
    """Builds vector records from chunks and their embeddings."""
    # The top-level folder acts as the restrict namespace value; an empty
    # source_folder yields an empty tag.
    top_folder = Path(source_folder).parts[0] if source_folder else ""
    records: list[dict] = []
    for chunk, embedding in zip(chunks, embeddings):
        records.append(
            {
                "id": chunk["metadata"]["id"],
                "embedding": list(embedding),
                "restricts": [{"namespace": "source", "allow": [top_folder]}],
            }
        )
    return records
|
||||
|
||||
|
||||
def create_vector_index(
    index_name: str,
    index_vectors_dir: str,
    index_dimensions: int,
    index_distance_measure_type: DistanceMeasureType,
    index_deployment: str,
    index_machine_type: str,
    approximate_neighbors_count: int,
    leaf_node_embedding_count: int,
    leaf_nodes_to_search_percent: int,
    public_endpoint_enabled: bool,
):
    """Creates and deploys a Vertex AI Vector Search Index."""
    # Build the tree-AH index from the vectors previously written to GCS.
    log.info("Creating index '%s'...", index_name)
    vector_index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
        display_name=index_name,
        contents_delta_uri=index_vectors_dir,
        dimensions=index_dimensions,
        approximate_neighbors_count=approximate_neighbors_count,
        distance_measure_type=index_distance_measure_type,
        leaf_node_embedding_count=leaf_node_embedding_count,
        leaf_nodes_to_search_percent=leaf_nodes_to_search_percent,
    )
    log.info("Index '%s' created successfully.", index_name)

    # Stand up a fresh endpoint and deploy the new index onto it.
    log.info("Deploying index to a new endpoint...")
    index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
        display_name=f"{index_name}-endpoint",
        public_endpoint_enabled=public_endpoint_enabled,
    )
    index_endpoint.deploy_index(
        index=vector_index,
        deployed_index_id=index_deployment,
        machine_type=index_machine_type,
    )
    log.info("Index deployed: %s", index_endpoint.display_name)
|
||||
|
||||
|
||||
def process_file(
    file_uri: str,
    temp_dir: Path,
    gcs_client: StorageClient,
    converter: MarkItDown,
    embedder: Embedder,
    chunker: BaseChunker,
) -> tuple[list[Document], list[dict]]:
    """Downloads a PDF from GCS, converts to markdown, chunks, and embeds."""
    bucket_name, blob_name = _parse_gcs_uri(file_uri)
    download_path = temp_dir / Path(file_uri).name
    gcs_client.bucket(bucket_name).blob(blob_name).download_to_filename(download_path)

    try:
        markdown_text = converter.convert(download_path).text_content
        file_id = normalize_string(Path(file_uri).stem)
        source_folder = Path(blob_name).parent.as_posix()

        chunks = split_into_chunks(markdown_text, file_id, chunker)
        embeddings = embedder.embed_documents_sync(
            [chunk["page_content"] for chunk in chunks]
        ).embeddings

        return chunks, build_vectors(chunks, embeddings, source_folder)
    finally:
        # Always remove the temporary download, even if conversion or
        # embedding fails.
        if download_path.exists():
            download_path.unlink()
|
||||
|
||||
|
||||
def run_pipeline(settings: Settings):
    """Runs the full ingestion pipeline: gather → process → aggregate → index."""
    pdf_uris = gather_pdfs(settings.index_origin, settings.gcs_client)

    collected_chunks: list[Document] = []
    collected_vectors: list[dict] = []

    # Process each PDF inside a scratch directory that is removed afterwards.
    with tempfile.TemporaryDirectory() as temp_dir:
        workdir = Path(temp_dir)
        for pdf_uri in pdf_uris:
            log.info("Processing file: %s", pdf_uri)
            chunks, vectors = process_file(
                pdf_uri,
                workdir,
                settings.gcs_client,
                settings.converter,
                settings.embedder,
                settings.chunker,
            )
            collected_chunks.extend(chunks)
            collected_vectors.extend(vectors)

    upload_to_gcs(
        collected_chunks,
        collected_vectors,
        settings.index_contents_dir,
        settings.index_vectors_jsonl_path,
        settings.gcs_client,
    )

    create_vector_index(
        settings.index_name,
        settings.index_vectors_dir,
        settings.index_dimensions,
        settings.index_distance_measure_type,
        settings.index_deployment,
        settings.index_machine_type,
        settings.index_approximate_neighbors_count,
        settings.index_leaf_node_embedding_count,
        settings.index_leaf_nodes_to_search_percent,
        settings.index_public_endpoint_enabled,
    )
|
||||
@@ -1,121 +0,0 @@
|
||||
import os
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
PydanticBaseSettingsSource,
|
||||
SettingsConfigDict,
|
||||
YamlConfigSettingsSource,
|
||||
)
|
||||
|
||||
CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")
|
||||
|
||||
|
||||
class IndexConfig(BaseModel):
    """Vector Search index configuration, assembled from flattened Settings fields."""

    # Index display name; also used to derive the deployed-index ID.
    name: str
    # Resource name of the index endpoint to query.
    endpoint: str
    # Embedding dimensionality of the index.
    dimensions: int
    # Machine type for the endpoint deployment.
    machine_type: str = "e2-standard-16"
    # GCS location holding the source documents.
    origin: str
    # GCS location prefix under which index data is written.
    destination: str
    # Maximum chunk size in characters.
    chunk_limit: int

    @property
    def deployment(self) -> str:
        """Deployed-index ID derived from the name (dashes become underscores)."""
        return self.name.replace("-", "_") + "_deployed"

    @property
    def data(self) -> str:
        """Location of this index's data."""
        # NOTE(review): no "/" is inserted between destination and name —
        # this presumably relies on `destination` already ending with a
        # separator; confirm against how the setting is populated.
        return self.destination + self.name
|
||||
|
||||
class AgentConfig(BaseModel):
    """Conversational agent configuration, assembled from flattened Settings fields."""

    # Agent display name.
    name: str
    # System instructions / prompt given to the agent.
    instructions: str
    # Identifier of the language model used for generation.
    language_model: str
    # Identifier of the embedding model used for retrieval.
    embedding_model: str
    # Thinking budget — presumably a token count; confirm against the agent runtime.
    thinking: int
|
||||
|
||||
|
||||
class BigQueryConfig(BaseModel):
    """BigQuery destination settings: dataset, optional project override, table map."""

    dataset_id: str
    # When None, the client's default project is used.
    project_id: str | None = None
    # Logical table name -> physical table id.
    table_ids: dict[str, str]
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application settings resolved from environment variables and a YAML file.

    Nested agent/index/bigquery options are stored as flattened scalar fields
    (so they can be overridden individually via the environment) and are
    re-assembled on demand through the `agent`, `index` and `bigquery`
    properties.
    """

    project_id: str
    location: str
    service_account: str

    # --- agent_*: flattened AgentConfig fields ---
    agent_name: str
    agent_instructions: str
    agent_language_model: str
    agent_embedding_model: str
    agent_thinking: int

    # --- index_*: flattened IndexConfig fields ---
    index_name: str
    index_endpoint: str
    index_dimensions: int
    index_machine_type: str = "e2-standard-16"
    index_origin: str
    index_destination: str
    index_chunk_limit: int

    # --- bigquery_*: flattened BigQueryConfig fields ---
    bigquery_dataset_id: str
    bigquery_project_id: str | None = None
    bigquery_table_ids: dict[str, str]

    bucket: str
    base_image: str
    dialogflow_agent_id: str
    processing_image: str

    model_config = SettingsConfigDict(yaml_file=CONFIG_FILE_PATH)

    @property
    def agent(self) -> AgentConfig:
        """Re-assemble the flattened agent_* fields into an AgentConfig."""
        kwargs = {
            "name": self.agent_name,
            "instructions": self.agent_instructions,
            "language_model": self.agent_language_model,
            "embedding_model": self.agent_embedding_model,
            "thinking": self.agent_thinking,
        }
        return AgentConfig(**kwargs)

    @property
    def index(self) -> IndexConfig:
        """Re-assemble the flattened index_* fields into an IndexConfig."""
        kwargs = {
            "name": self.index_name,
            "endpoint": self.index_endpoint,
            "dimensions": self.index_dimensions,
            "machine_type": self.index_machine_type,
            "origin": self.index_origin,
            "destination": self.index_destination,
            "chunk_limit": self.index_chunk_limit,
        }
        return IndexConfig(**kwargs)

    @property
    def bigquery(self) -> BigQueryConfig:
        """Re-assemble the flattened bigquery_* fields into a BigQueryConfig."""
        kwargs = {
            "dataset_id": self.bigquery_dataset_id,
            "project_id": self.bigquery_project_id,
            "table_ids": self.bigquery_table_ids,
        }
        return BigQueryConfig(**kwargs)

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Resolve settings from environment variables first, then the YAML file.

        NOTE(review): init, dotenv and secret-file sources are deliberately
        dropped here — confirm that is intended.
        """
        yaml_source = YamlConfigSettingsSource(settings_cls)
        return (env_settings, yaml_source)
|
||||
|
||||
|
||||
# Module-level singleton; reads env vars and the YAML config at import time.
settings = Settings()
|
||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Tests package
|
||||
89
tests/conftest.py
Normal file
89
tests/conftest.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Shared pytest fixtures for knowledge_pipeline tests."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from knowledge_pipeline.chunker.base_chunker import BaseChunker, Document
|
||||
|
||||
|
||||
@pytest.fixture
def mock_gcs_client():
    """Mock GCS client: bucket()/blob() return Mocks, list_blobs() is empty."""
    blob = Mock()
    bucket = Mock()
    client = Mock()

    bucket.blob.return_value = blob
    bucket.list_blobs.return_value = []
    client.bucket.return_value = bucket

    return client
|
||||
|
||||
|
||||
@pytest.fixture
def mock_chunker():
    """BaseChunker mock that always yields one fixed chunk."""
    fake = Mock(spec=BaseChunker)
    fake.max_chunk_size = 1000
    single_chunk = {"page_content": "Test chunk content", "metadata": {"id": "test_chunk"}}
    fake.process_text.return_value = [single_chunk]
    return fake
|
||||
|
||||
|
||||
@pytest.fixture
def mock_embedder():
    """Embedder mock whose sync embedding call returns one fixed vector."""
    result = Mock()
    result.embeddings = [[0.1, 0.2, 0.3]]

    fake = Mock()
    fake.embed_documents_sync.return_value = result
    return fake
|
||||
|
||||
|
||||
@pytest.fixture
def mock_converter():
    """Document-converter mock returning fixed markdown text."""
    conversion = Mock()
    conversion.text_content = "# Markdown Content\n\nTest content here."

    fake = Mock()
    fake.convert.return_value = conversion
    return fake
|
||||
|
||||
|
||||
@pytest.fixture
def sample_chunks() -> list[Document]:
    """Three document chunks with sequential ids doc_1_0..doc_1_2."""
    contents = ["First chunk content", "Second chunk content", "Third chunk content"]
    return [
        {"page_content": text, "metadata": {"id": f"doc_1_{position}"}}
        for position, text in enumerate(contents)
    ]
|
||||
|
||||
|
||||
@pytest.fixture
def sample_embeddings():
    """Three fixed 5-dimensional embedding vectors."""
    vectors = [
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.6, 0.7, 0.8, 0.9, 1.0],
        [0.2, 0.3, 0.4, 0.5, 0.6],
    ]
    return vectors
|
||||
|
||||
|
||||
@pytest.fixture
def sample_vectors():
    """Two vector records in the Vertex AI JSONL upload shape."""

    def record(chunk_id, embedding):
        # Each record gets its own restricts list so tests may mutate freely.
        return {
            "id": chunk_id,
            "embedding": embedding,
            "restricts": [{"namespace": "source", "allow": ["documents"]}],
        }

    return [
        record("doc_1_0", [0.1, 0.2, 0.3]),
        record("doc_1_1", [0.4, 0.5, 0.6]),
    ]
|
||||
553
tests/test_pipeline.py
Normal file
553
tests/test_pipeline.py
Normal file
@@ -0,0 +1,553 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from google.cloud.aiplatform.matching_engine.matching_engine_index_config import (
|
||||
DistanceMeasureType,
|
||||
)
|
||||
|
||||
from knowledge_pipeline.chunker.base_chunker import BaseChunker
|
||||
from knowledge_pipeline.pipeline import (
|
||||
_parse_gcs_uri,
|
||||
build_vectors,
|
||||
create_vector_index,
|
||||
gather_pdfs,
|
||||
normalize_string,
|
||||
process_file,
|
||||
run_pipeline,
|
||||
split_into_chunks,
|
||||
upload_to_gcs,
|
||||
)
|
||||
|
||||
|
||||
class TestParseGcsUri:
    """Tests for _parse_gcs_uri function."""

    def test_basic_gcs_uri(self):
        result = _parse_gcs_uri("gs://my-bucket/path/to/file.pdf")
        assert result == ("my-bucket", "path/to/file.pdf")

    def test_gcs_uri_with_nested_path(self):
        result = _parse_gcs_uri("gs://test-bucket/deep/nested/path/file.txt")
        assert result == ("test-bucket", "deep/nested/path/file.txt")

    def test_gcs_uri_bucket_only(self):
        # Trailing slash yields an empty object path.
        result = _parse_gcs_uri("gs://my-bucket/")
        assert result == ("my-bucket", "")

    def test_gcs_uri_no_trailing_slash(self):
        result = _parse_gcs_uri("gs://bucket-name")
        assert result == ("bucket-name", "")
|
||||
|
||||
|
||||
class TestNormalizeString:
    """Tests for normalize_string function."""

    def test_normalize_basic_string(self):
        # Lowercases and replaces spaces with underscores.
        assert normalize_string("Hello World") == "hello_world"

    def test_normalize_special_characters(self):
        # Punctuation (other than '-' and '.') is stripped.
        assert normalize_string("File#Name@2024!.pdf") == "filename2024.pdf"

    def test_normalize_unicode(self):
        # Accented characters fold to their ASCII equivalents.
        assert normalize_string("Café Münchën") == "cafe_munchen"

    def test_normalize_multiple_spaces(self):
        assert normalize_string("Multiple Spaces Here") == "multiple_spaces_here"

    def test_normalize_with_hyphens_and_periods(self):
        assert normalize_string("valid-filename.2024") == "valid-filename.2024"

    def test_normalize_empty_string(self):
        assert normalize_string("") == ""

    def test_normalize_only_special_chars(self):
        assert normalize_string("@#$%^&*()") == ""
|
||||
|
||||
|
||||
class TestGatherFiles:
    """Tests for gather_files function."""

    @staticmethod
    def _client_with_blobs(names):
        """Build a mock GCS client whose bucket lists blobs with the given names."""
        client = Mock()
        bucket = Mock()
        client.bucket.return_value = bucket

        blobs = []
        for name in names:
            blob = Mock()
            blob.name = name
            blobs.append(blob)
        bucket.list_blobs.return_value = blobs

        return client, bucket

    def test_gather_files_finds_pdfs(self):
        client, _ = self._client_with_blobs(
            ["docs/file1.pdf", "docs/file2.pdf", "docs/readme.txt"]
        )

        found = gather_pdfs("gs://my-bucket/docs", client)

        # Only the two PDFs survive the filter.
        assert len(found) == 2
        assert "gs://my-bucket/docs/file1.pdf" in found
        assert "gs://my-bucket/docs/file2.pdf" in found
        assert "gs://my-bucket/docs/readme.txt" not in found

    def test_gather_files_no_pdfs(self):
        client, _ = self._client_with_blobs(["docs/readme.txt"])
        assert len(gather_pdfs("gs://my-bucket/docs", client)) == 0

    def test_gather_files_empty_bucket(self):
        client, _ = self._client_with_blobs([])
        assert len(gather_pdfs("gs://my-bucket/docs", client)) == 0

    def test_gather_files_correct_prefix(self):
        client, bucket = self._client_with_blobs([])

        gather_pdfs("gs://my-bucket/docs/subfolder", client)

        # The URI must be split into bucket name and listing prefix.
        client.bucket.assert_called_once_with("my-bucket")
        bucket.list_blobs.assert_called_once_with(prefix="docs/subfolder")
|
||||
|
||||
|
||||
class TestSplitIntoChunks:
    """Tests for split_into_chunks function."""

    @staticmethod
    def _chunker(max_size):
        """Build a BaseChunker mock with the given max chunk size."""
        fake = Mock(spec=BaseChunker)
        fake.max_chunk_size = max_size
        return fake

    def test_split_small_text_single_chunk(self):
        chunker = self._chunker(1000)

        chunks = split_into_chunks("Small text", "test_file", chunker)

        # Short text is passed through as one chunk without chunking.
        assert len(chunks) == 1
        assert chunks[0]["page_content"] == "Small text"
        assert chunks[0]["metadata"]["id"] == "test_file"
        chunker.process_text.assert_not_called()

    def test_split_large_text_multiple_chunks(self):
        chunker = self._chunker(10)
        long_text = "This is a very long text that needs to be split into chunks"

        # The chunker decides the split points; we stub three pieces.
        chunker.process_text.return_value = [
            {"page_content": part, "metadata": {}}
            for part in ("This is a very", "long text that", "needs to be split")
        ]

        chunks = split_into_chunks(long_text, "test_file", chunker)

        # Chunk ids are suffixed with their position.
        assert len(chunks) == 3
        for position, chunk in enumerate(chunks):
            assert chunk["metadata"]["id"] == f"test_file_{position}"
        chunker.process_text.assert_called_once_with(long_text)

    def test_split_exactly_max_size(self):
        chunker = self._chunker(10)
        boundary_text = "0123456789"  # exactly max_chunk_size characters

        chunks = split_into_chunks(boundary_text, "test_file", chunker)

        # Text at the boundary still counts as a single chunk.
        assert len(chunks) == 1
        assert chunks[0]["page_content"] == boundary_text
        chunker.process_text.assert_not_called()
|
||||
|
||||
|
||||
class TestUploadToGcs:
    """Tests for upload_to_gcs function."""

    @staticmethod
    def _client():
        """Build a mock GCS client and return it with its bucket mock."""
        client = Mock()
        bucket = Mock()
        client.bucket.return_value = bucket
        bucket.blob.return_value = Mock()
        return client, bucket

    @staticmethod
    def _blob_names(bucket):
        """Collect the blob paths requested on the bucket, in call order."""
        return [call[0][0] for call in bucket.blob.call_args_list]

    def test_upload_single_chunk_and_vectors(self):
        client, bucket = self._client()
        chunks = [{"page_content": "Test content", "metadata": {"id": "chunk_1"}}]
        vectors = [{"id": "chunk_1", "embedding": [0.1, 0.2]}]

        upload_to_gcs(
            chunks,
            vectors,
            "gs://my-bucket/contents",
            "gs://my-bucket/vectors/vectors.jsonl",
            client,
        )

        names = self._blob_names(bucket)
        assert "contents/chunk_1.md" in names
        assert "vectors/vectors.jsonl" in names

    def test_upload_multiple_chunks(self):
        client, bucket = self._client()
        chunks = [
            {"page_content": f"Content {i}", "metadata": {"id": f"chunk_{i}"}}
            for i in (1, 2, 3)
        ]
        vectors = [{"id": "chunk_1", "embedding": [0.1]}]

        upload_to_gcs(
            chunks,
            vectors,
            "gs://my-bucket/contents",
            "gs://my-bucket/vectors/vectors.jsonl",
            client,
        )

        # One blob per chunk plus one for the vectors JSONL.
        assert bucket.blob.call_count == 4
        assert self._blob_names(bucket) == [
            "contents/chunk_1.md",
            "contents/chunk_2.md",
            "contents/chunk_3.md",
            "vectors/vectors.jsonl",
        ]
|
||||
|
||||
|
||||
class TestBuildVectors:
    """Tests for build_vectors function."""

    def test_build_vectors_basic(self):
        chunks = [
            {"metadata": {"id": "doc_1"}, "page_content": "content 1"},
            {"metadata": {"id": "doc_2"}, "page_content": "content 2"},
        ]
        embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

        vectors = build_vectors(chunks, embeddings, "documents/reports")

        assert len(vectors) == 2
        first, second = vectors
        assert first["id"] == "doc_1"
        assert first["embedding"] == [0.1, 0.2, 0.3]
        # Only the top-level folder becomes the source restrict token.
        assert first["restricts"] == [{"namespace": "source", "allow": ["documents"]}]
        assert second["id"] == "doc_2"
        assert second["embedding"] == [0.4, 0.5, 0.6]

    def test_build_vectors_empty_source(self):
        chunks = [{"metadata": {"id": "doc_1"}, "page_content": "content"}]

        vectors = build_vectors(chunks, [[0.1, 0.2]], "")

        assert len(vectors) == 1
        assert vectors[0]["restricts"] == [{"namespace": "source", "allow": [""]}]

    def test_build_vectors_nested_path(self):
        chunks = [{"metadata": {"id": "doc_1"}, "page_content": "content"}]

        vectors = build_vectors(chunks, [[0.1]], "a/b/c/d")

        # Deep paths still collapse to the first path segment.
        assert vectors[0]["restricts"] == [{"namespace": "source", "allow": ["a"]}]
|
||||
|
||||
|
||||
class TestCreateVectorIndex:
    """Tests for create_vector_index function."""

    @patch("knowledge_pipeline.pipeline.aiplatform.MatchingEngineIndexEndpoint")
    @patch("knowledge_pipeline.pipeline.aiplatform.MatchingEngineIndex")
    def test_create_vector_index(self, mock_index_class, mock_endpoint_class):
        """Index creation, endpoint creation and deployment wired in one pass."""
        mock_index = Mock()
        mock_endpoint = Mock()

        mock_index_class.create_tree_ah_index.return_value = mock_index
        mock_endpoint_class.create.return_value = mock_endpoint

        create_vector_index(
            index_name="test-index",
            index_vectors_dir="gs://bucket/vectors",
            index_dimensions=768,
            index_distance_measure_type=DistanceMeasureType.DOT_PRODUCT_DISTANCE,
            index_deployment="test_index_deployed",
            index_machine_type="e2-standard-16",
        )

        # Index built from the vectors dir; neighbor/leaf counts fall back to
        # their defaults since they were not passed above.
        mock_index_class.create_tree_ah_index.assert_called_once_with(
            display_name="test-index",
            contents_delta_uri="gs://bucket/vectors",
            dimensions=768,
            approximate_neighbors_count=150,
            distance_measure_type=DistanceMeasureType.DOT_PRODUCT_DISTANCE,
            leaf_node_embedding_count=1000,
            leaf_nodes_to_search_percent=10,
        )

        # Endpoint display name is the index name plus "-endpoint".
        mock_endpoint_class.create.assert_called_once_with(
            display_name="test-index-endpoint",
            public_endpoint_enabled=True,
        )

        # Deployment is kicked off asynchronously (sync=False).
        mock_endpoint.deploy_index.assert_called_once_with(
            index=mock_index,
            deployed_index_id="test_index_deployed",
            machine_type="e2-standard-16",
            sync=False,
        )
|
||||
|
||||
|
||||
class TestProcessFile:
    """Tests for process_file function."""

    def test_process_file_success(self):
        """Happy path: download, convert, chunk, embed, and return results."""
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)

            # Mock dependencies
            mock_client = Mock()
            mock_bucket = Mock()
            mock_blob = Mock()
            mock_client.bucket.return_value = mock_bucket
            mock_bucket.blob.return_value = mock_blob

            mock_converter = Mock()
            mock_result = Mock()
            mock_result.text_content = "Converted markdown content"
            mock_converter.convert.return_value = mock_result

            mock_embedder = Mock()
            mock_embeddings_result = Mock()
            mock_embeddings_result.embeddings = [[0.1, 0.2, 0.3]]
            mock_embedder.embed_documents_sync.return_value = mock_embeddings_result

            # Large max_chunk_size so the converted text stays one chunk.
            mock_chunker = Mock(spec=BaseChunker)
            mock_chunker.max_chunk_size = 1000

            file_uri = "gs://my-bucket/docs/test-file.pdf"

            chunks, vectors = process_file(
                file_uri,
                temp_path,
                mock_client,
                mock_converter,
                mock_embedder,
                mock_chunker,
            )

            # Verify download was called
            mock_client.bucket.assert_called_with("my-bucket")
            mock_bucket.blob.assert_called_with("docs/test-file.pdf")
            assert mock_blob.download_to_filename.called

            # Verify converter was called
            assert mock_converter.convert.called

            # Verify embedder was called
            mock_embedder.embed_documents_sync.assert_called_once()

            # Verify results
            assert len(chunks) == 1
            assert chunks[0]["page_content"] == "Converted markdown content"
            assert len(vectors) == 1
            assert vectors[0]["embedding"] == [0.1, 0.2, 0.3]

    def test_process_file_cleans_up_temp_file(self):
        """The downloaded temp file must be removed even when conversion raises."""
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)

            mock_client = Mock()
            mock_bucket = Mock()
            mock_blob = Mock()
            mock_client.bucket.return_value = mock_bucket
            mock_bucket.blob.return_value = mock_blob

            # Force the failure mid-pipeline, after the download step.
            mock_converter = Mock()
            mock_converter.convert.side_effect = Exception("Conversion failed")

            mock_embedder = Mock()
            mock_chunker = Mock(spec=BaseChunker)

            file_uri = "gs://my-bucket/docs/test.pdf"

            # This should raise an exception but still clean up
            with pytest.raises(Exception, match="Conversion failed"):
                process_file(
                    file_uri,
                    temp_path,
                    mock_client,
                    mock_converter,
                    mock_embedder,
                    mock_chunker,
                )

            # File should be cleaned up even after exception
            temp_file = temp_path / "test.pdf"
            assert not temp_file.exists()
|
||||
|
||||
|
||||
class TestRunPipeline:
    """Tests for run_pipeline function."""

    @patch("knowledge_pipeline.pipeline.create_vector_index")
    @patch("knowledge_pipeline.pipeline.upload_to_gcs")
    @patch("knowledge_pipeline.pipeline.process_file")
    @patch("knowledge_pipeline.pipeline.gather_pdfs")
    def test_run_pipeline_integration(
        self,
        mock_gather,
        mock_process,
        mock_upload,
        mock_create_index,
    ):
        """Single file: gather -> process -> upload -> create index, in order."""
        # Mock settings
        mock_settings = Mock()
        mock_settings.index_origin = "gs://bucket/input"
        mock_settings.index_contents_dir = "gs://bucket/contents"
        mock_settings.index_vectors_jsonl_path = "gs://bucket/vectors/vectors.jsonl"
        mock_settings.index_name = "test-index"
        mock_settings.index_vectors_dir = "gs://bucket/vectors"
        mock_settings.index_dimensions = 768
        mock_settings.index_distance_measure_type = (
            DistanceMeasureType.DOT_PRODUCT_DISTANCE
        )
        mock_settings.index_deployment = "test_index_deployed"
        mock_settings.index_machine_type = "e2-standard-16"

        mock_gcs_client = Mock()
        mock_bucket = Mock()
        mock_blob = Mock()
        mock_gcs_client.bucket.return_value = mock_bucket
        mock_bucket.blob.return_value = mock_blob
        mock_settings.gcs_client = mock_gcs_client

        mock_settings.converter = Mock()
        mock_settings.embedder = Mock()
        mock_settings.chunker = Mock()

        # Mock gather_files to return test files
        mock_gather.return_value = ["gs://bucket/input/file1.pdf"]

        # Mock process_file to return chunks and vectors
        mock_chunks = [{"page_content": "content", "metadata": {"id": "chunk_1"}}]
        mock_vectors = [
            {
                "id": "chunk_1",
                "embedding": [0.1, 0.2],
                "restricts": [{"namespace": "source", "allow": ["input"]}],
            }
        ]
        mock_process.return_value = (mock_chunks, mock_vectors)

        run_pipeline(mock_settings)

        # Verify all steps were called
        mock_gather.assert_called_once_with("gs://bucket/input", mock_gcs_client)
        mock_process.assert_called_once()
        mock_upload.assert_called_once_with(
            mock_chunks,
            mock_vectors,
            "gs://bucket/contents",
            "gs://bucket/vectors/vectors.jsonl",
            mock_gcs_client,
        )
        mock_create_index.assert_called_once()

    @patch("knowledge_pipeline.pipeline.create_vector_index")
    @patch("knowledge_pipeline.pipeline.upload_to_gcs")
    @patch("knowledge_pipeline.pipeline.process_file")
    @patch("knowledge_pipeline.pipeline.gather_pdfs")
    def test_run_pipeline_multiple_files(
        self,
        mock_gather,
        mock_process,
        mock_upload,
        mock_create_index,
    ):
        """Multiple files: one process_file call each, but a single upload."""
        mock_settings = Mock()
        mock_settings.index_origin = "gs://bucket/input"
        mock_settings.index_contents_dir = "gs://bucket/contents"
        mock_settings.index_vectors_jsonl_path = "gs://bucket/vectors/vectors.jsonl"
        mock_settings.index_name = "test-index"
        mock_settings.index_vectors_dir = "gs://bucket/vectors"
        mock_settings.index_dimensions = 768
        mock_settings.index_distance_measure_type = (
            DistanceMeasureType.DOT_PRODUCT_DISTANCE
        )
        mock_settings.index_deployment = "test_index_deployed"
        mock_settings.index_machine_type = "e2-standard-16"

        mock_gcs_client = Mock()
        mock_bucket = Mock()
        mock_blob = Mock()
        mock_gcs_client.bucket.return_value = mock_bucket
        mock_bucket.blob.return_value = mock_blob
        mock_settings.gcs_client = mock_gcs_client

        mock_settings.converter = Mock()
        mock_settings.embedder = Mock()
        mock_settings.chunker = Mock()

        # Return multiple files
        mock_gather.return_value = [
            "gs://bucket/input/file1.pdf",
            "gs://bucket/input/file2.pdf",
        ]

        mock_process.return_value = (
            [{"page_content": "content", "metadata": {"id": "chunk_1"}}],
            [{"id": "chunk_1", "embedding": [0.1], "restricts": []}],
        )

        run_pipeline(mock_settings)

        # Verify process_file was called for each file
        assert mock_process.call_count == 2
        # Upload is called once with all accumulated chunks and vectors
        mock_upload.assert_called_once()
|
||||
40
utils/delete_endpoint.py
Normal file
40
utils/delete_endpoint.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Delete a GCP Vector Search endpoint by ID.
|
||||
|
||||
Undeploys any deployed indexes before deleting the endpoint.
|
||||
|
||||
Usage:
|
||||
uv run python utils/delete_endpoint.py <endpoint_id> [--project PROJECT] [--location LOCATION]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
from google.cloud import aiplatform
|
||||
|
||||
|
||||
def delete_endpoint(endpoint_id: str, project: str, location: str) -> None:
    """Undeploy every index on the endpoint, then delete the endpoint itself.

    Args:
        endpoint_id: Vertex AI MatchingEngineIndexEndpoint resource id.
        project: GCP project to initialize the aiplatform client with.
        location: GCP region of the endpoint.
    """
    aiplatform.init(project=project, location=location)
    endpoint = aiplatform.MatchingEngineIndexEndpoint(endpoint_id)

    print(f"Endpoint: {endpoint.display_name}")

    # An endpoint cannot be deleted while indexes are still deployed on it.
    for deployed in endpoint.deployed_indexes:
        print(f"Undeploying index: {deployed.id}")
        endpoint.undeploy_index(deployed_index_id=deployed.id)
        print(f"Undeployed: {deployed.id}")

    endpoint.delete()
    print(f"Endpoint {endpoint_id} deleted successfully.")
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments and delete the given endpoint."""
    parser = argparse.ArgumentParser(description="Delete a GCP Vector Search endpoint.")
    parser.add_argument("endpoint_id", help="The endpoint ID to delete.")
    parser.add_argument("--project", default="bnt-orquestador-cognitivo-dev")
    parser.add_argument("--location", default="us-central1")
    args = parser.parse_args()

    delete_endpoint(args.endpoint_id, args.project, args.location)


if __name__ == "__main__":
    main()
|
||||
132
utils/search_index.py
Normal file
132
utils/search_index.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Search a deployed Vertex AI Vector Search index.
|
||||
|
||||
Embeds a query, finds nearest neighbors, and retrieves chunk contents from GCS.
|
||||
|
||||
Usage:
|
||||
uv run python utils/search_index.py "your search query" <endpoint_id> <index_deployment_id> \
|
||||
[--source SOURCE] [--top-k 5] [--project PROJECT] [--location LOCATION]
|
||||
|
||||
Examples:
|
||||
# Basic search
|
||||
uv run python utils/search_index.py "¿Cómo funciona el proceso?" 123456 blue_ivy_deployed
|
||||
|
||||
# Filter by source folder
|
||||
uv run python utils/search_index.py "requisitos" 123456 blue_ivy_deployed --source "manuales"
|
||||
|
||||
# Return more results
|
||||
uv run python utils/search_index.py "políticas" 123456 blue_ivy_deployed --top-k 10
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
from google.cloud import aiplatform, storage
|
||||
from pydantic_ai import Embedder
|
||||
|
||||
|
||||
def search_index(
    query: str,
    endpoint_id: str,
    deployed_index_id: str,
    project: str,
    location: str,
    embedding_model: str,
    contents_dir: str,
    top_k: int,
    source: str | None,
) -> None:
    """Embed *query*, search the deployed index, and print matching chunks.

    Side effects only: results (or "No results found.") go to stdout.
    """
    aiplatform.init(project=project, location=location)

    # Embed the query with the same model family used at indexing time.
    embedder = Embedder(f"google-vertex:{embedding_model}")
    query_embedding = embedder.embed_documents_sync([query]).embeddings[0]

    endpoint = aiplatform.MatchingEngineIndexEndpoint(endpoint_id)

    # Optional metadata filter restricting matches to one source folder.
    restricts = None
    if source:
        namespace = aiplatform.matching_engine.matching_engine_index_endpoint.Namespace(
            name="source",
            allow_tokens=[source],
        )
        restricts = [namespace]

    response = endpoint.find_neighbors(
        deployed_index_id=deployed_index_id,
        queries=[list(query_embedding)],
        num_neighbors=top_k,
        filter=restricts,
    )

    # find_neighbors returns one neighbor list per query; we sent exactly one.
    if not response or not response[0]:
        print("No results found.")
        return

    gcs_client = storage.Client()
    neighbors = response[0]

    print(f"Found {len(neighbors)} results for: {query!r}\n")
    for rank, neighbor in enumerate(neighbors, 1):
        chunk_id = neighbor.id
        distance = neighbor.distance
        content = _fetch_chunk_content(gcs_client, contents_dir, chunk_id)
        print(f"--- Result {rank} (id={chunk_id}, distance={distance:.4f}) ---")
        print(content)
        print()
|
||||
|
||||
|
||||
def _fetch_chunk_content(
    gcs_client: storage.Client, contents_dir: str, chunk_id: str
) -> str:
    """Fetches a chunk's markdown content from GCS, or a placeholder if absent."""
    uri = f"{contents_dir}/{chunk_id}.md"
    # Split "gs://bucket/path..." into its bucket and object components.
    bucket_name, _, obj_path = uri.removeprefix("gs://").partition("/")
    blob = gcs_client.bucket(bucket_name).blob(obj_path)
    return blob.download_as_text() if blob.exists() else f"[content not found: {uri}]"
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments and run a single index search."""
    parser = argparse.ArgumentParser(
        description="Search a deployed Vertex AI Vector Search index."
    )
    parser.add_argument("query", help="The search query text.")
    parser.add_argument("endpoint_id", help="The deployed endpoint ID.")
    parser.add_argument("deployed_index_id", help="The deployed index ID.")
    parser.add_argument(
        "--source",
        default=None,
        help="Filter results by source folder (metadata namespace).",
    )
    parser.add_argument(
        "--top-k", type=int, default=5, help="Number of results to return (default: 5)."
    )
    parser.add_argument("--project", default="bnt-orquestador-cognitivo-dev")
    parser.add_argument("--location", default="us-central1")
    parser.add_argument(
        "--embedding-model", default="gemini-embedding-001", help="Embedding model name."
    )
    parser.add_argument(
        "--contents-dir",
        default="gs://bnt_orquestador_cognitivo_gcs_configs_dev/blue-ivy/contents",
        help="GCS URI of the contents directory.",
    )
    args = parser.parse_args()

    search_index(
        query=args.query,
        endpoint_id=args.endpoint_id,
        deployed_index_id=args.deployed_index_id,
        project=args.project,
        location=args.location,
        embedding_model=args.embedding_model,
        contents_dir=args.contents_dir,
        top_k=args.top_k,
        source=args.source,
    )


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user