# LLM-assisted chunker: splits PDF/text into token-bounded chunks, optionally
# merging and rewriting them with an LLM, and rendering special-format pages
# to images.
import hashlib
|
|
import json
|
|
import os
|
|
import time
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from pathlib import Path
|
|
from typing import Annotated, List
|
|
|
|
import tiktoken
|
|
import typer
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
from langchain_core.documents import Document as LangchainDocument
|
|
from llm.vertex_ai import VertexAILLM
|
|
from pdf2image import convert_from_path
|
|
from pypdf import PdfReader
|
|
from rag_eval.config import Settings
|
|
|
|
from .base_chunker import BaseChunker, Document
|
|
|
|
|
|
class TokenManager:
    """Counts tokens and truncates text to a token budget."""

    def __init__(self, model_name: str = "gpt-3.5-turbo"):
        """Resolve the tiktoken encoding for *model_name*.

        Unknown model names fall back to the generic ``cl100k_base`` encoding.
        """
        try:
            self.encoding = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # Model not known to tiktoken: use the encoding shared by
            # most recent OpenAI models.
            self.encoding = tiktoken.get_encoding("cl100k_base")

    def count_tokens(self, text: str) -> int:
        """Return how many tokens *text* encodes to."""
        return len(self.encoding.encode(text))

    def truncate_to_tokens(
        self, text: str, max_tokens: int, preserve_sentences: bool = True
    ) -> str:
        """Cut *text* down to at most *max_tokens* tokens.

        When *preserve_sentences* is true and a period falls within the last
        30% of the truncated text, the cut is moved back to that period so
        the result ends on a sentence boundary.
        """
        encoded = self.encoding.encode(text)
        if len(encoded) <= max_tokens:
            return text

        clipped = self.encoding.decode(encoded[:max_tokens])
        if not preserve_sentences:
            return clipped

        # rfind returns -1 when no period exists, which never beats the
        # 70% threshold, so we fall through to the raw clip.
        boundary = clipped.rfind(".")
        if boundary > 0.7 * len(clipped):
            return clipped[: boundary + 1]
        return clipped
|
|
|
|
|
|
class OptimizedChunkProcessor:
    """Uses an LLM to merge and enhance text chunks.

    Consecutive chunks are merged when the model (or a heuristic fallback)
    judges them semantically continuous, then each chunk is rewritten to fit
    a token budget. Both decisions are memoized so identical content is only
    sent to the model once.
    """

    def __init__(
        self,
        model: str,
        max_tokens: int = 1000,
        target_tokens: int = 800,
        chunks_per_batch: int = 5,
        gemini_client: VertexAILLM | None = None,
        model_name: str = "gpt-3.5-turbo",
        custom_instructions: str = "",
    ):
        """Configure the processor.

        Args:
            model: Model identifier passed to ``gemini_client.generate``.
            max_tokens: Hard upper bound of tokens per chunk.
            target_tokens: Preferred chunk size used by prompts/heuristics.
            chunks_per_batch: Progress is reported (with a short pause) every
                this many processed chunks.
            gemini_client: Optional LLM client; when ``None`` heuristic
                fallbacks replace all model calls.
            model_name: tiktoken model name used for token counting.
            custom_instructions: Extra instructions appended to the
                enhancement prompt; also mixed into cache keys.
        """
        self.model = model
        self.client = gemini_client
        self.chunks_per_batch = chunks_per_batch
        self.max_tokens = max_tokens
        self.target_tokens = target_tokens
        self.token_manager = TokenManager(model_name)
        self.custom_instructions = custom_instructions
        # Memoization caches: short content hash -> decision / enhanced text.
        self._merge_cache = {}
        self._enhance_cache = {}

    def _get_cache_key(self, text: str) -> str:
        """Short MD5-based cache key.

        Includes the custom instructions so changing them invalidates any
        previously cached result for the same text.
        """
        combined = text + self.custom_instructions
        return hashlib.md5(combined.encode()).hexdigest()[:16]

    def should_merge_chunks(self, chunk1: str, chunk2: str) -> bool:
        """Decide whether two consecutive chunks should be joined.

        Uses the LLM when a client is available, otherwise a punctuation/size
        heuristic. The per-pair decision is cached; any failure defaults to
        "do not merge".
        """
        cache_key = f"{self._get_cache_key(chunk1)}_{self._get_cache_key(chunk2)}"
        if cache_key in self._merge_cache:
            return self._merge_cache[cache_key]

        try:
            combined_text = f"{chunk1}\n\n{chunk2}"
            combined_tokens = self.token_manager.count_tokens(combined_text)

            # Hard ceiling: never merge past the absolute token limit.
            if combined_tokens > self.max_tokens:
                self._merge_cache[cache_key] = False
                return False

            if self.client:
                # Spanish prompt asking the model for a strict SI/NO verdict;
                # only the first 500 chars of each chunk are shown to it.
                base_prompt = f"""Analiza estos dos fragmentos de texto y determina si deben unirse.

LÍMITES ESTRICTOS:
- Tokens combinados: {combined_tokens}/{self.max_tokens}
- Solo unir si hay continuidad semántica clara

Criterios de unión:
1. El primer fragmento termina abruptamente
2. El segundo fragmento continúa la misma idea/concepto
3. La unión mejora la coherencia del contenido
4. Exceder {self.max_tokens} tokens, SOLAMENTE si es necesario para mantener el contexto"""

                base_prompt += f"""

Responde SOLO 'SI' o 'NO'.

Fragmento 1 ({self.token_manager.count_tokens(chunk1)} tokens):
{chunk1[:500]}...

Fragmento 2 ({self.token_manager.count_tokens(chunk2)} tokens):
{chunk2[:500]}..."""

                response = self.client.generate(self.model, base_prompt).text
                # NOTE(review): only an exact "SI" counts as yes; an accented
                # "SÍ" reply would be treated as no — confirm the model
                # reliably answers unaccented.
                result = response.strip().upper() == "SI"
                self._merge_cache[cache_key] = result
                return result

            # Heuristic fallback: merge when the first chunk ends
            # mid-sentence and the combined size stays within the target.
            result = (
                not chunk1.rstrip().endswith((".", "!", "?"))
                and combined_tokens <= self.target_tokens
            )
            self._merge_cache[cache_key] = result
            return result

        except Exception as e:
            # Best-effort: on any failure (e.g. LLM error) default to
            # "no merge" and cache that decision.
            print(f"Error analizando chunks para merge: {e}")
            self._merge_cache[cache_key] = False
            return False

    def enhance_chunk(self, chunk_text: str) -> str:
        """Rewrite a chunk with the LLM so it fits the token budget.

        Falls back to plain token truncation when there is no client, when
        the chunk already meets/exceeds the hard limit, or when the LLM call
        fails. Results are cached by content + custom instructions.
        """
        cache_key = self._get_cache_key(chunk_text)
        if cache_key in self._enhance_cache:
            return self._enhance_cache[cache_key]

        current_tokens = self.token_manager.count_tokens(chunk_text)

        try:
            if self.client and current_tokens < self.max_tokens:
                # Spanish prompt with strict token-budget rules for the
                # rewrite.
                base_prompt = f"""Optimiza este texto siguiendo estas reglas ESTRICTAS:

LÍMITES DE TOKENS:
- Actual: {current_tokens} tokens
- Máximo permitido: {self.max_tokens} tokens
- Objetivo: {self.target_tokens} tokens

REGLAS FUNDAMENTALES:
NO exceder {self.max_tokens} tokens bajo ninguna circunstancia
Mantener TODA la información esencial y metadatos
NO cambiar términos técnicos o palabras clave
Asegurar oraciones completas y coherentes
Optimizar claridad y estructura sin añadir contenido
SOLO devuelve el texto no agregues conclusiones NUNCA

Si el texto está cerca del límite, NO expandir. Solo mejorar estructura."""

                if self.custom_instructions.strip():
                    base_prompt += (
                        f"\n\nINSTRUCCIONES ADICIONALES:\n{self.custom_instructions}"
                    )

                base_prompt += f"\n\nTexto a optimizar:\n{chunk_text}"

                response = self.client.generate(self.model, base_prompt).text
                enhanced_text = response.strip()

                # Defensive: the model may overshoot the limit despite the
                # prompt, so re-count and truncate if needed.
                enhanced_tokens = self.token_manager.count_tokens(enhanced_text)
                if enhanced_tokens > self.max_tokens:
                    print(
                        f"Advertencia: Texto optimizado excede límite ({enhanced_tokens} > {self.max_tokens})"
                    )
                    enhanced_text = self.token_manager.truncate_to_tokens(
                        enhanced_text, self.max_tokens
                    )

                self._enhance_cache[cache_key] = enhanced_text
                return enhanced_text
            else:
                # No client, or chunk already at/over the hard limit:
                # truncate when necessary, otherwise pass through unchanged.
                if current_tokens > self.max_tokens:
                    truncated = self.token_manager.truncate_to_tokens(
                        chunk_text, self.max_tokens
                    )
                    self._enhance_cache[cache_key] = truncated
                    return truncated

                self._enhance_cache[cache_key] = chunk_text
                return chunk_text

        except Exception as e:
            # Keep the pipeline going with the raw (possibly truncated)
            # chunk instead of aborting the whole batch.
            print(f"Error procesando chunk: {e}")
            if current_tokens > self.max_tokens:
                truncated = self.token_manager.truncate_to_tokens(
                    chunk_text, self.max_tokens
                )
                self._enhance_cache[cache_key] = truncated
                return truncated

            self._enhance_cache[cache_key] = chunk_text
            return chunk_text

    def process_chunks_batch(
        self, chunks: List[LangchainDocument], merge_related: bool = False
    ) -> List[LangchainDocument]:
        """Sequentially merge (optionally) and enhance a list of chunks.

        Returns new documents whose metadata is that of the first chunk of
        each merge group plus a ``final_tokens`` entry.
        """
        processed_chunks = []
        total_chunks = len(chunks)

        print(f"Procesando {total_chunks} chunks en lotes de {self.chunks_per_batch}")
        if self.custom_instructions:
            print(
                f"Con instrucciones personalizadas: {self.custom_instructions[:100]}..."
            )

        i = 0
        while i < len(chunks):
            batch_start = time.time()
            current_chunk = chunks[i]
            merged_content = current_chunk.page_content
            original_tokens = self.token_manager.count_tokens(merged_content)

            if merge_related and i < len(chunks) - 1:
                # Greedily absorb following chunks while the model/heuristic
                # approves; after incrementing merge_count, chunks[i +
                # merge_count] is exactly the chunk just tested. `i` is then
                # advanced past everything absorbed.
                merge_count = 0
                while i + merge_count < len(chunks) - 1 and self.should_merge_chunks(
                    merged_content, chunks[i + merge_count + 1].page_content
                ):
                    merge_count += 1
                    merged_content += "\n\n" + chunks[i + merge_count].page_content
                    print(f" Uniendo chunk {i + 1} con chunk {i + merge_count + 1}")

                i += merge_count

            print(f"\nProcesando chunk {i + 1}/{total_chunks}")
            print(f" Tokens originales: {original_tokens}")

            enhanced_content = self.enhance_chunk(merged_content)
            final_tokens = self.token_manager.count_tokens(enhanced_content)

            # NOTE(review): merged groups keep only the FIRST chunk's
            # metadata (e.g. its page number) — confirm this is intended.
            processed_chunks.append(
                LangchainDocument(
                    page_content=enhanced_content,
                    metadata={
                        **current_chunk.metadata,
                        "final_tokens": final_tokens,
                    },
                )
            )

            print(f" Tokens finales: {final_tokens}")
            print(f" Tiempo de procesamiento: {time.time() - batch_start:.2f}s")

            i += 1

            # Light pacing between batches to avoid hammering the API.
            if i % self.chunks_per_batch == 0 and i < len(chunks):
                print(f"\nCompletados {i}/{total_chunks} chunks")
                time.sleep(0.1)

        return processed_chunks
|
|
|
|
|
|
class LLMChunker(BaseChunker):
    """Implements a chunker that uses an LLM to optimize PDF and text content.

    Splits input with a token-aware recursive splitter, optionally merges and
    rewrites chunks through ``OptimizedChunkProcessor``, and can render pages
    containing tables/diagrams to PNG images that are then referenced from
    the chunk text.
    """

    def __init__(
        self,
        output_dir: str,
        model: str,
        max_tokens: int = 1000,
        target_tokens: int = 800,
        gemini_client: VertexAILLM | None = None,
        custom_instructions: str = "",
        extract_images: bool = True,
        max_workers: int = 4,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        merge_related: bool = True,
    ):
        """Configure the chunker.

        Args:
            output_dir: Directory where extracted page images are written.
            model: Model identifier forwarded to the LLM client.
            max_tokens: Hard token ceiling per chunk.
            target_tokens: Preferred chunk size for the optimizer.
            gemini_client: Optional LLM client; heuristics are used without it.
            custom_instructions: Extra instructions for chunk enhancement.
            extract_images: Render special-format pages to PNG when True.
            max_workers: Thread pool size for format detection / page export.
            chunk_size: Base chunk size (in tokens) for the splitter.
            chunk_overlap: Overlap (in tokens) between consecutive chunks.
            merge_related: Whether to merge semantically continuous chunks.
        """
        self.output_dir = output_dir
        self.model = model
        self.client = gemini_client
        self.max_workers = max_workers
        self.token_manager = TokenManager()
        self.custom_instructions = custom_instructions
        self.extract_images = extract_images
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.merge_related = merge_related
        # Cache of "has special format" verdicts keyed by content hash.
        self._format_cache = {}

        self.chunk_processor = OptimizedChunkProcessor(
            model=self.model,
            max_tokens=max_tokens,
            target_tokens=target_tokens,
            gemini_client=gemini_client,
            custom_instructions=custom_instructions,
        )

    def _page_image_path(self, output_dir: str, page_number, pdf_path: str) -> str:
        """Canonical path of the PNG rendered for *page_number* of *pdf_path*.

        Single source of truth for the image filename. Previously the
        extraction step saved ``{page}_{file}.png`` while the reference step
        looked for ``page_{page}_{file}.png``, so the existence check always
        failed and image references were never attached.
        """
        pdf_filename = os.path.basename(pdf_path)
        return os.path.join(output_dir, f"page_{page_number}_{pdf_filename}.png")

    def process_text(self, text: str) -> List[Document]:
        """Processes raw text using the LLM optimizer."""
        print("\n=== Iniciando procesamiento de texto ===")
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            # Lengths are measured in tokens, not characters.
            length_function=self.token_manager.count_tokens,
            separators=["\n\n", "\n", ". ", " ", ""],
        )
        # Create dummy LangchainDocuments for compatibility with process_chunks_batch
        langchain_docs = text_splitter.create_documents([text])

        processed_docs = self.chunk_processor.process_chunks_batch(
            langchain_docs, self.merge_related
        )

        # Convert from LangchainDocument to our Document TypedDict
        final_documents: List[Document] = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in processed_docs
        ]
        print(
            f"\n=== Procesamiento de texto completado: {len(final_documents)} chunks creados ==="
        )
        return final_documents

    def process_path(self, path: Path) -> List[Document]:
        """Processes a PDF file, extracts text and images, and optimizes chunks."""
        overall_start = time.time()
        print(f"\n=== Iniciando procesamiento optimizado de PDF: {path.name} ===")
        os.makedirs(self.output_dir, exist_ok=True)

        print("\n1. Creando chunks del PDF...")
        chunks = self._create_optimized_chunks(
            str(path), self.chunk_size, self.chunk_overlap
        )
        print(f" Total chunks creados: {len(chunks)}")

        # Collect page numbers whose chunks contain tables/diagrams.
        pages_to_extract = set()
        if self.extract_images:
            print("\n2. Detectando formatos especiales...")
            format_results = self.detect_special_format_batch(chunks)
            for i, has_special_format in format_results.items():
                if has_special_format:
                    page_number = chunks[i].metadata.get("page")
                    # Pages are 1-based, so truthiness also filters None.
                    if page_number:
                        pages_to_extract.add(page_number)
            print(f" Páginas con formato especial: {sorted(pages_to_extract)}")

        if self.extract_images and pages_to_extract:
            print(f"\n3. Extrayendo {len(pages_to_extract)} páginas como imágenes...")
            self._extract_pages_parallel(str(path), self.output_dir, pages_to_extract)

        print("\n4. Procesando y optimizando chunks...")
        processed_chunks = self.chunk_processor.process_chunks_batch(
            chunks, self.merge_related
        )

        if self.extract_images:
            final_chunks = self._add_image_references(
                processed_chunks, pages_to_extract, str(path), self.output_dir
            )
        else:
            final_chunks = processed_chunks

        total_time = time.time() - overall_start
        print(f"\n=== Procesamiento completado en {total_time:.2f}s ===")

        # Convert from LangchainDocument to our Document TypedDict
        final_documents: List[Document] = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in final_chunks
        ]
        return final_documents

    def detect_special_format_batch(
        self, chunks: List[LangchainDocument]
    ) -> dict[int, bool]:
        """Return, per chunk index, whether the chunk has special formatting.

        Cached verdicts are reused; uncached chunks are classified in
        parallel when an LLM client is available, sequentially otherwise.
        """
        results = {}
        chunks_to_process = []
        for i, chunk in enumerate(chunks):
            cache_key = hashlib.md5(chunk.page_content.encode()).hexdigest()[:16]
            if cache_key in self._format_cache:
                results[i] = self._format_cache[cache_key]
            else:
                chunks_to_process.append((i, chunk, cache_key))

        if not chunks_to_process:
            return results

        if self.client and len(chunks_to_process) > 1:
            with ThreadPoolExecutor(
                max_workers=min(self.max_workers, len(chunks_to_process))
            ) as executor:
                futures = {
                    executor.submit(self._detect_single_format, chunk): (i, cache_key)
                    for i, chunk, cache_key in chunks_to_process
                }
                for future in futures:
                    i, cache_key = futures[future]
                    try:
                        result = future.result()
                        results[i] = result
                        self._format_cache[cache_key] = result
                    except Exception as e:
                        # A failed detection is treated as "no special format".
                        print(f"Error procesando chunk {i}: {e}")
                        results[i] = False
        else:
            for i, chunk, cache_key in chunks_to_process:
                result = self._detect_single_format(chunk)
                results[i] = result
                self._format_cache[cache_key] = result
        return results

    def _detect_single_format(self, chunk: LangchainDocument) -> bool:
        """Classify one chunk as containing tables/diagrams or not.

        Without a client, a character-based heuristic (box-drawing chars,
        many tabs / runs of spaces) is used instead of the LLM.
        """
        if not self.client:
            content = chunk.page_content
            table_indicators = ["│", "├", "┼", "┤", "┬", "┴", "|", "+", "-"]
            has_table_chars = any(char in content for char in table_indicators)
            has_multiple_columns = content.count("\t") > 10 or content.count("  ") > 20
            return has_table_chars or has_multiple_columns
        try:
            prompt = f"""¿Contiene este texto tablas estructuradas, diagramas ASCII, o elementos que requieren formato especial?

Responde SOLO 'SI' o 'NO'.

Texto:
{chunk.page_content[:1000]}"""
            response = self.client.generate(self.model, prompt).text
            return response.strip().upper() == "SI"
        except Exception as e:
            print(f"Error detectando formato: {e}")
            return False

    def _create_optimized_chunks(
        self, pdf_path: str, chunk_size: int, chunk_overlap: int
    ) -> List[LangchainDocument]:
        """Split each PDF page's text into token-bounded chunks.

        Every chunk carries its 1-based ``page`` number and the source
        ``file_name`` in metadata; pages with no extractable text are skipped.
        """
        pdf = PdfReader(pdf_path)
        chunks = []
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=self.token_manager.count_tokens,
            separators=["\n\n", "\n", ". ", " ", ""],
        )
        for page_num, page in enumerate(pdf.pages, 1):
            text = page.extract_text()
            if text.strip():
                page_chunks = text_splitter.create_documents(
                    [text],
                    metadatas=[
                        {
                            "page": page_num,
                            "file_name": os.path.basename(pdf_path),
                        }
                    ],
                )
                chunks.extend(page_chunks)
        return chunks

    def _extract_pages_parallel(self, pdf_path: str, output_dir: str, pages: set):
        """Render the given page numbers of the PDF to PNG files in parallel."""

        def extract_single_page(page_number):
            try:
                # Save under the canonical name so _add_image_references
                # finds the file later.
                image_path = self._page_image_path(output_dir, page_number, pdf_path)
                images = convert_from_path(
                    pdf_path,
                    first_page=page_number,
                    last_page=page_number,
                    dpi=150,
                    thread_count=1,
                    grayscale=False,
                )
                if images:
                    images[0].save(image_path, "PNG", optimize=True)
            except Exception as e:
                # Best-effort: a failed page export must not abort the rest.
                print(f" Error extrayendo página {page_number}: {e}")

        with ThreadPoolExecutor(
            max_workers=min(self.max_workers, len(pages))
        ) as executor:
            futures = [executor.submit(extract_single_page, page) for page in pages]
            for future in futures:
                future.result()  # Wait for completion

    def _add_image_references(
        self,
        chunks: List[LangchainDocument],
        pages_to_extract: set,
        pdf_path: str,
        output_dir: str,
    ) -> List[LangchainDocument]:
        """Prepend an image reference to chunks whose page was exported.

        Mutates the chunks in place (content prefix plus ``has_image`` /
        ``image_path`` metadata) and returns the same list.
        """
        for chunk in chunks:
            page_number = chunk.metadata.get("page")
            if page_number in pages_to_extract:
                image_path = self._page_image_path(output_dir, page_number, pdf_path)
                if os.path.exists(image_path):
                    image_reference = (
                        f"\n[IMAGEN DISPONIBLE - Página {page_number}: {image_path}]\n"
                    )
                    chunk.page_content = image_reference + chunk.page_content
                    chunk.metadata["has_image"] = True
                    chunk.metadata["image_path"] = image_path
        return chunks
|
|
|
|
|
|
# Typer application exposing this module as a command-line tool.
app = typer.Typer()
|
|
|
|
|
|
@app.command()
def main(
    pdf_path: Annotated[str, typer.Argument(help="Ruta al archivo PDF")],
    output_dir: Annotated[
        str, typer.Argument(help="Directorio de salida para imágenes y chunks")
    ],
    model: Annotated[
        str, typer.Option(help="Modelo a usar para el procesamiento")
    ] = "gemini-2.0-flash",
    max_tokens: Annotated[
        int, typer.Option(help="Límite máximo de tokens por chunk")
    ] = 950,
    target_tokens: Annotated[
        int, typer.Option(help="Tokens objetivo para optimización")
    ] = 800,
    chunk_size: Annotated[int, typer.Option(help="Tamaño base de chunks")] = 1000,
    chunk_overlap: Annotated[int, typer.Option(help="Solapamiento entre chunks")] = 200,
    merge_related: Annotated[
        bool, typer.Option(help="Si unir chunks relacionados")
    ] = True,
    custom_instructions: Annotated[
        str, typer.Option(help="Instrucciones adicionales para optimización")
    ] = "",
    extract_images: Annotated[
        bool,
        typer.Option(help="Si True, extrae páginas con formato especial como imágenes"),
    ] = True,
):
    """
    Entry point: process a PDF with full token control.
    """
    # Build the Vertex AI client from project settings (env/config driven).
    settings = Settings()
    llm = VertexAILLM(project=settings.project_id, location=settings.location)

    chunker = LLMChunker(
        output_dir=output_dir,
        model=model,
        max_tokens=max_tokens,
        target_tokens=target_tokens,
        gemini_client=llm,
        custom_instructions=custom_instructions,
        extract_images=extract_images,
        max_workers=4,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        merge_related=merge_related,
    )

    documents = chunker.process_path(Path(pdf_path))
    print(f"Processed {len(documents)} documents.")

    # Persist the chunks as JSON Lines, one document per line.
    output_file_path = os.path.join(output_dir, "chunked_documents.jsonl")
    with open(output_file_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(doc, ensure_ascii=False) + "\n" for doc in documents)

    print(f"Saved {len(documents)} documents to {output_file_path}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Launch the Typer CLI when executed as a script.
    app()
|