import marimo

__generated_with = "0.13.15"
app = marimo.App(width="medium")

with app.setup:
    import hashlib
    import json
    import logging
    import textwrap
    import time
    from pathlib import Path

    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain_core.documents import Document
    from pdf2image import convert_from_path
    from pypdf import PdfReader
    from qdrant_client.models import Distance, PointStruct, VectorParams

    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np

    from banortegpt.embedding.azure_ada import Ada
    from banortegpt.generation.vertex_ai_gemini import Gemini
    from banortegpt.vector.qdrant import Qdrant

    logger = logging.getLogger(__name__)

    def load_prompt(prompt_file: str) -> str:
        prompt_dir = Path("prompts/")
        return (prompt_dir / prompt_file).read_text()

    class TempFile:
        """Context manager that writes bytes into temp_dir/ and removes the file on exit."""

        temp_dir = Path("temp_dir/")

        def __init__(self, name: str, contents: bytes):
            self.name = name
            self.contents = contents

        def __enter__(self):
            self.file = self.temp_dir / self.name
            self.file.write_bytes(self.contents)
            return self.file

        def __exit__(self, exc_type, exc_val, exc_tb):
            self.file.unlink()

    def id_from_json(json_data: dict) -> int:
        """Derive a deterministic point ID from a JSON-serializable dict."""
        json_str = json.dumps(json_data, sort_keys=True)
        hash_obj = hashlib.sha256(json_str.encode("utf-8"))
        # Qdrant integer point IDs must fit in an unsigned 64-bit integer,
        # so only the first 8 bytes of the digest are used.
        return int.from_bytes(hash_obj.digest()[:8], byteorder="big")


@app.class_definition(hide_code=True)
class PDFPageExtractor:
    detect_special_format_prompt = load_prompt("detect_special_format_prompt.md")

    def __init__(self, gemini_client: Gemini):
        self.client = gemini_client
        self._cache = {}  # Cache for detection results

    def detect_special_format(self, chunk: Document) -> bool:
        """
        Detect whether a chunk contains tables or other special formatting.
        Results are cached to avoid repeated API calls.
        """
        # Use a simple hash of the content as the cache key
        cache_key = hash(chunk.page_content)
        if cache_key in self._cache:
            return self._cache[cache_key]

        start_time = time.time()
        try:
            prompt = self.detect_special_format_prompt.format(chunk.page_content)
            response = self.client.generate(prompt).text
            # The prompt is expected to answer "SI"/"NO" (Spanish)
            result = response.strip().upper() == "SI"
            self._cache[cache_key] = result
            logger.info(f"Tiempo de análisis de chunk: {time.time() - start_time:.2f}s")
            return result
        except Exception as e:
            logger.error(f"Error detectando formato especial: {e}")
            return False

    def _create_chunks_from_pdf(
        self, pdf_path: Path, chunk_size: int = 1000, chunk_overlap: int = 200
    ) -> list[Document]:
        """
        Create chunks from a PDF, preserving the original page information.
        """
        start_time = time.time()
        logger.info(f"Iniciando lectura del PDF: {pdf_path}")

        pdf = PdfReader(pdf_path)
        total_pages = len(pdf.pages)
        logger.info(f"Total de páginas en el PDF: {total_pages}")

        chunks = []
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            separators=["\n\n", "\n", " ", ""],
        )

        # Split page by page so every chunk keeps the page number it came from
        for page_num in range(total_pages):
            page_start = time.time()
            logger.info(f"Procesando página {page_num + 1}/{total_pages}...")

            page = pdf.pages[page_num]
            text = page.extract_text()

            if text.strip():
                page_chunks = text_splitter.create_documents(
                    [text],
                    metadatas=[{"page": page_num + 1, "file_name": pdf_path.name}],
                )
                chunks.extend(page_chunks)
                logger.info(
                    f" - Chunks creados para página {page_num + 1}: {len(page_chunks)}"
                )
            else:
                logger.info(f" - Página {page_num + 1} está vacía o no contiene texto")

            logger.info(
                f" - Tiempo de procesamiento página {page_num + 1}: {time.time() - page_start:.2f}s"
            )

        logger.info(
            f"Tiempo total de procesamiento PDF: {time.time() - start_time:.2f}s"
        )
        logger.info(f"Total de chunks creados: {len(chunks)}")
        return chunks

    def process_pdf(
        self,
        pdf_path: Path,
        output_dir: Path,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
    ) -> list[Document]:
        """
        Process a complete PDF, detecting special formats and extracting pages.
        """
        overall_start = time.time()
        logger.info("\n=== Iniciando procesamiento de PDF ===")

        if not output_dir.exists():
            output_dir.mkdir()
            logger.info(f"Directorio de salida creado: {output_dir}")

        # Create chunks from the PDF
        logger.info("\n1. Creando chunks del PDF...")
        chunks_start = time.time()
        chunks = self._create_chunks_from_pdf(pdf_path, chunk_size, chunk_overlap)
        logger.info(f"Chunks creados en {time.time() - chunks_start:.2f}s")

        processed_chunks = []
        pages_to_extract = set()

        # Identify pages with special formats
        logger.info("\n2. Analizando chunks para detectar formatos especiales...")
        analysis_start = time.time()
        for i, chunk in enumerate(chunks, 1):
            logger.info(f"\nAnalizando chunk {i}/{len(chunks)}")
            if self.detect_special_format(chunk):
                page_number = chunk.metadata.get("page")
                if page_number not in pages_to_extract:
                    pages_to_extract.add(page_number)
                    logger.info(
                        f" - Formato especial detectado en página {page_number}"
                    )
        logger.info(f"Análisis completado en {time.time() - analysis_start:.2f}s")
        logger.info(f"Páginas a extraer: {sorted(pages_to_extract)}")

        # Extract pages with special formats as images
        if pages_to_extract:
            logger.info("\n3. Extrayendo páginas como imágenes...")
            extraction_start = time.time()
            for page_number in sorted(pages_to_extract):
                page_start = time.time()
                logger.info(f"\nProcesando página {page_number}...")
                pdf_filename = pdf_path.name
                image_path = output_dir / f"{page_number}_{pdf_filename}.png"
                try:
                    images = convert_from_path(
                        pdf_path,
                        first_page=page_number,
                        last_page=page_number,
                        dpi=150,
                        thread_count=4,
                        grayscale=False,
                    )
                    if images:
                        images[0].save(image_path, "PNG", optimize=True)
                        logger.info(f" - Imagen guardada: {image_path}")
                        logger.info(
                            f" - Tiempo de extracción: {time.time() - page_start:.2f}s"
                        )
                except Exception as e:
                    logger.error(f" - Error extrayendo página {page_number}: {e}")
            logger.info(
                f"Extracción de imágenes completada en {time.time() - extraction_start:.2f}s"
            )

        # Process chunks and add image references
        logger.info("\n4. Procesando chunks finales...")
        for chunk in chunks:
            page_number = chunk.metadata.get("page")
            if page_number in pages_to_extract:
                pdf_filename = pdf_path.name
                image_path = output_dir / f"{page_number}_{pdf_filename}.png"
                if image_path.exists():
                    image_reference = f"\n[Ver página {page_number} completa en imagen: {image_path}]\n"
                    chunk.page_content = image_reference + chunk.page_content
            processed_chunks.append(chunk)

        total_time = time.time() - overall_start
        logger.info(f"\n=== Procesamiento completado en {total_time:.2f}s ===")
        logger.info(f"Total de chunks procesados: {len(processed_chunks)}")
        logger.info(f"Total de páginas extraídas como imagen: {len(pages_to_extract)}")
        return processed_chunks


@app.class_definition(hide_code=True)
class ChunkProcessor:
    should_merge_prompt = load_prompt("should_merge_prompt.md")
    enhance_chunk_prompt = load_prompt("enhance_chunk_prompt.md")
    MAX_TOKENS = 750  # maximum token budget per chunk

    def __init__(self, gemini_client: Gemini, chunks_per_page: int = 5):
        self.client = gemini_client
        self.chunks_per_page = chunks_per_page

    def should_merge_chunks(self, chunk1: str, chunk2: str) -> bool:
        """
        Decide whether two chunks should be merged, based on content and length.
        """
        try:
            combined_length = len(chunk1) + len(chunk2)
            # 3375 characters approximates MAX_TOKENS (750) at ~4.5 characters per token
            if combined_length > 3375:
                return False

            prompt = self.should_merge_prompt.format(chunk1, chunk2)
            response = self.client.generate(prompt).text
            return response.strip().upper() == "SI"
        except Exception as e:
            logger.error(f"Error analizando chunks: {e}")
            return False

    def enhance_chunk(self, chunk_text: str) -> str:
        """Improve a single chunk while keeping it under the token limit."""
        try:
            prompt = self.enhance_chunk_prompt.format(chunk_text)
            response = self.client.generate(prompt).text
            enhanced_text = response.strip()

            if len(enhanced_text) > 3375:
                logger.warning(
                    "Advertencia: Texto optimizado excede el límite de tokens"
                )
                # Truncate at the last sentence boundary that fits within the limit
                truncated = enhanced_text[:3375].rsplit(".", 1)[0] + "."
                return truncated

            return enhanced_text
        except Exception as e:
            logger.error(f"Error procesando chunk: {e}")
            return chunk_text

    def process_chunks(
        self, chunks: list[Document], merge_related: bool = False
    ) -> list[Document]:
        """
        Process chunks and optionally merge related ones.

        Args:
            chunks: List of chunks to process
            merge_related: If True, try to merge related chunks

        Returns:
            list[Document]: List of processed chunks
        """
        processed_chunks = []
        i = 0

        while i < len(chunks):
            current_chunk = chunks[i]
            merged_content = current_chunk.page_content

            if merge_related and i < len(chunks) - 1:
                while i < len(chunks) - 1 and self.should_merge_chunks(
                    merged_content, chunks[i + 1].page_content
                ):
                    logger.info(f"\nUniendo chunks {i + 1} y {i + 2}...")
                    merged_content += "\n\n" + chunks[i + 1].page_content
                    i += 1

            logger.info(f"\nProcesando chunk {i + 1}:")
            logger.info(textwrap.fill(merged_content, width=80))

            logger.info("\nMejorando contenido")
            enhanced_content = self.enhance_chunk(merged_content)
            processed_chunks.append(
                Document(page_content=enhanced_content, metadata=current_chunk.metadata)
            )

            logger.info("\nContenido mejorado")
            logger.info(textwrap.fill(enhanced_content, width=80))
            logger.info("-" * 80)

            i += 1
            if i % self.chunks_per_page == 0 and i < len(chunks):
                continue_processing = "s"  # input("\n¿Continuar con la siguiente página? (s/n): ").lower()
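                # Interactive confirmation is disabled: "s" is hard-coded so the
                # batch runs unattended inside the notebook.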
                if continue_processing != "s":
                    break

        return processed_chunks


@app.class_definition(hide_code=True)
class Pipeline:
    def __init__(self, *, ada: Ada, qdrant: Qdrant, gemini: Gemini):
        self.ada = ada
        self.qdrant = qdrant
        self.gemini = gemini
        self.extractor = PDFPageExtractor(gemini_client=gemini)
        self.processor = ChunkProcessor(gemini_client=gemini)

    def run(self, name: str, contents: bytes):
        with TempFile(name=name, contents=contents) as pdf:
            chunks = self.extractor.process_pdf(pdf, Path("output_images"))
            merged_enhanced_chunks = self.processor.process_chunks(
                chunks, merge_related=True
            )
            points = self._build_points_from_chunks(merged_enhanced_chunks)
        return points

    def _build_points_from_chunks(self, chunks):
        points = [
            PointStruct(
                # Hash the content together with the metadata: the metadata alone
                # (page + file name) is shared by every chunk of a page and would
                # produce colliding IDs that overwrite each other in Qdrant.
                id=id_from_json(
                    {**document.metadata, "page_content": document.page_content}
                ),
                payload={
                    "page_content": document.page_content,
                    "metadata": document.metadata,
                },
                vector={self.ada.model: self.ada.embed(input=document.page_content)},
            )
            for document in chunks
        ]
        return points

    def upload_points(self, points: list[PointStruct]):
        self.qdrant.create_collection_if_not_exists(
            vector_config={
                # text-embedding-3-large produces 3072-dimensional vectors
                self.ada.model: VectorParams(size=3072, distance=Distance.COSINE)
            }
        )
        self.qdrant.upload_to_collection(points=points)

    @classmethod
    def from_vault(
        cls, vault: str, *, collection: str, embedding_model: str, gemini_model: str
    ):
        return cls(
            ada=Ada.from_vault(vault, model=embedding_model),
            qdrant=Qdrant.from_vault(vault, collection=collection),
            gemini=Gemini.from_vault(vault, model=gemini_model),
        )


@app.class_definition(hide_code=True)
class ChunkDistGraph:
    def __init__(
        self,
        points: list[PointStruct],
        campo_texto: str = "page_content",
        titulo: str = "Distribución de Chunks por Longitud",
    ) -> None:
        self.points = points
        self.campo_texto = campo_texto
        self.title = titulo

    def show(self):
        longitudes = self._obtener_longitudes()
        plot = self._visualizar_distribucion_chunks(longitudes)
        return plot.gcf()

    def _obtener_longitudes(self) -> list[int]:
        """
        Get the length of every text chunk in a list of points.
        """
        longitudes = []
        for point in self.points:
            texto = point.payload[self.campo_texto]
            longitudes.append(len(str(texto)))
        return longitudes

    def _visualizar_distribucion_chunks(self, longitudes: list[int]):
        """
        Build a visualization of the chunk distribution by length.
        """
        plt.figure(figsize=(15, 6))

        # Number of bins from log2(n) + 1 (Sturges-style rule)
        n_bins = int(np.log2(len(longitudes)) + 1)
        n, bins, patches = plt.hist(
            longitudes, bins=n_bins, color="skyblue", edgecolor="black", alpha=0.7
        )

        from scipy.stats import gaussian_kde

        # Overlay a KDE curve scaled to the histogram counts
        density = gaussian_kde(longitudes)
        xs = np.linspace(min(longitudes), max(longitudes), 200)
        plt.plot(
            xs,
            density(xs) * len(longitudes) * (bins[1] - bins[0]),
            color="red",
            linewidth=2,
            label="Tendencia",
        )

        # Customize the plot
        plt.title(self.title, fontsize=14, pad=20)
        plt.xlabel("Cantidad de Caracteres", fontsize=12)
        plt.ylabel("Cantidad de Chunks", fontsize=12)

        media = np.mean(longitudes)
        mediana = np.median(longitudes)
        desv_std = np.std(longitudes)
        stats_text = (
            f"Estadísticas:\n"
            f"• Media: {media:.1f} caracteres\n"
            f"• Mediana: {mediana:.1f} caracteres\n"
            f"• Desv. Estándar: {desv_std:.1f}\n"
            f"• Total de chunks: {len(longitudes)}"
        )
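        # Place the stats box just outside the right edge of the axes
        # (coordinates are axes fractions via transform=transAxes)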
        plt.text(
            1.02,
            0.95,
            stats_text,
            transform=plt.gca().transAxes,
            bbox=dict(facecolor="white", alpha=0.8),
            verticalalignment="top",
        )

        plt.tight_layout()
        return plt


@app.class_definition(hide_code=True)
class ChunkDistGraph2:
    def __init__(
        self,
        points: list[PointStruct],
        campo_texto: str = "page_content",
        titulo: str = "Distribución de longitud de chunks",
    ) -> None:
        self.points = points
        self.campo_texto = campo_texto
        self.titulo = titulo

    def show(self):
        chunks_info = self._obtener_longitudes_chunks()
        longitudes = [length for length, _, _, _ in chunks_info]
        chunks_extremos = self._encontrar_chunks_extremos(chunks_info)

        print("\nInformación de la colección:")
        print(f"Número total de chunks: {len(longitudes)}")
        print(f"Número de longitudes únicas: {len(set(longitudes))}")
        if longitudes:
            print(f"Rango de longitudes: {min(longitudes)} a {max(longitudes)}")

        fig = self._visualizar_distribucion(longitudes, chunks_extremos)
        return fig.gcf()

    def _obtener_longitudes_chunks(self) -> list[tuple[int, str, str, str]]:
        """
        Get the length of every text chunk in a Qdrant collection.
        """
        chunks_info = []
        for point in self.points:  # Fixed: was using 'points' instead of 'self.points'
            texto = point.payload[self.campo_texto]
            chunks_info.append(
                (
                    len(str(texto)),
                    str(texto)[:100],
                    str(point.id),
                    point.payload.get("metadata", {}).get("page", "N/A"),
                )
            )
        return chunks_info

    def _encontrar_chunks_extremos(
        self, chunks_info: list[tuple[int, str, str, str]]
    ) -> dict:
        """
        Find the longest and shortest chunks.
        """
        if not chunks_info:
            return {}

        chunk_mas_corto = min(chunks_info, key=lambda x: x[0])
        chunk_mas_largo = max(chunks_info, key=lambda x: x[0])

        return {
            "mas_corto": {
                "longitud": chunk_mas_corto[0],
                "preview": chunk_mas_corto[1] + "..."
                if len(chunk_mas_corto[1]) == 100
                else chunk_mas_corto[1],
                "id": chunk_mas_corto[2],
                "page": chunk_mas_corto[3],
            },
            "mas_largo": {
                "longitud": chunk_mas_largo[0],
                "preview": chunk_mas_largo[1] + "..."
                if len(chunk_mas_largo[1]) == 100
                else chunk_mas_largo[1],
                "id": chunk_mas_largo[2],
                "page": chunk_mas_largo[3],
            },
        }

    def _visualizar_distribucion(self, longitudes: list[int], chunks_extremos: dict):
        """
        Build a smoothed visualization of the length distribution.
        """
        if not longitudes:
            raise ValueError("No hay datos para visualizar")

        longitudes = [float(x) for x in longitudes]
        plt.figure(figsize=(15, 6))

        n_bins = max(10, min(50, len(set(longitudes)) // 2))

        if len(longitudes) < 2:
            plt.text(
                0.5,
                0.5,
                "Datos insuficientes para visualización",
                ha="center",
                va="center",
            )
            # Return the module, as in the normal path, so show() can call gcf()
            return plt

        counts, bins, _ = plt.hist(
            longitudes,
            bins=n_bins,
            density=True,
            alpha=0.6,
            color="skyblue",
            edgecolor="black",
        )

        # Smooth the histogram with a moving average over the bin counts
        bin_centers = (bins[:-1] + bins[1:]) / 2
        window_size = 5
        if len(counts) > window_size:
            smoothed = np.convolve(
                counts, np.ones(window_size) / window_size, mode="valid"
            )
            smoothed_x = bin_centers[window_size - 1 :]
            plt.plot(smoothed_x, smoothed, color="blue", linewidth=2, alpha=0.8)

        plt.title(self.titulo, fontsize=14, pad=5)  # Reduced pad from 20 to 5
        plt.xlabel("Longitud del chunk (caracteres)", fontsize=12)
        plt.ylabel("Densidad", fontsize=12)

        media = np.mean(longitudes)
        mediana = np.median(longitudes)
        desv_std = np.std(longitudes)
        info_text = (
            f"Estadísticas:\n"
            f"• Media: {media:.1f} caracteres\n"
            f"• Mediana: {mediana:.1f} caracteres\n"
            f"• Desv. Estándar: {desv_std:.1f}\n\n"
            f"Chunks Extremos:\n\n"
            f"• Más corto: {chunks_extremos['mas_corto']['longitud']} caracteres\n"
            f" ID para buscar en dashboard: \n"
            f" {chunks_extremos['mas_corto']['id']}\n"
            f" Página: {chunks_extremos['mas_corto'].get('page', 'N/A')}\n"
            f" Preview: {chunks_extremos['mas_corto']['preview']}\n\n"
            f"• Más largo: {chunks_extremos['mas_largo']['longitud']} caracteres\n"
            f" ID para buscar en dashboard: \n"
            f" {chunks_extremos['mas_largo']['id']}\n"
            f" Página: {chunks_extremos['mas_largo'].get('page', 'N/A')}\n"
            f" Preview: {chunks_extremos['mas_largo']['preview']}"
        )
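        # Show the summary statistics and the extreme chunks in a panel
        # to the right of the plot area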
        plt.figtext(
            1.02,
            0.5,
            info_text,
            fontsize=10,
            bbox=dict(facecolor="white", alpha=0.8, edgecolor="none"),
            wrap=True,
        )

        # Remove whitespace at the top by adjusting subplots
        plt.subplots_adjust(top=0.92, bottom=0.1, left=0.08, right=0.75)

        return plt


@app.cell
def _():
    import marimo as mo

    logger.setLevel(logging.INFO)
    return (mo,)


@app.cell
def _():
    pipeline = Pipeline.from_vault(
        "banortegpt",
        collection="MayaNormativa",
        embedding_model="text-embedding-3-large",
        gemini_model="gemini-1.5-flash",
    )
    return (pipeline,)


@app.cell
def _(mo):
    uploads = mo.ui.file(filetypes=[".pdf"], kind="area", multiple=True).form()
    uploads
    return (uploads,)


@app.cell
def _(mo, pipeline, uploads):
    mo.stop(uploads.value is None)

    points = [
        point
        for upload in mo.status.progress_bar(uploads.value, remove_on_exit=True)
        for point in pipeline.run(upload.name, upload.contents)
    ]
    return (points,)


@app.cell
def _(points):
    ChunkDistGraph(points).show()
    return


@app.cell
def _():
    # ChunkDistGraph2(points).show()
    return


@app.cell
def _(points):
    import polars as pl

    pl.from_records([p.payload for p in points])
    return


@app.cell
def _(mo):
    upload_button = mo.ui.run_button(label="Upload to Qdrant", kind="success")
    upload_button
    return (upload_button,)


@app.cell
def _(mo, pipeline, points, upload_button):
    mo.stop(upload_button.value is False)

    pipeline.upload_points(points)
    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()