import logging
from pathlib import Path
from typing import Annotated, List, Sequence, Literal, Any, AsyncGenerator
from typing_extensions import TypedDict

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, AIMessageChunk
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from pydantic import BaseModel, Field

from banortegpt.storage.azure_storage import AzureStorage
from banortegpt.vector.qdrant import AsyncQdrant

import api.context as ctx
from api.config import config

logger = logging.getLogger(__name__)

parent = Path(__file__).parent
SYSTEM_PROMPT = (parent / "system_prompt.md").read_text()

AZURE_AI_URI = "https://eastus2.api.cognitive.microsoft.com"
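

# Tool schema the chat model is bound to, so the model can emit structured
# search requests. Note that in this graph the retrieval node calls
# get_information() directly; tool calls emitted by the model are not executed here.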
class get_information(BaseModel):
    """Search a private repository for information."""
    question: str = Field(..., description="The user question")


class MayaNormativaState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], add_messages]
    query: str
    search_results: List[dict]
    iteration_count: int
    max_iterations: int
    # Set by evaluate_node and read by _decide_next_step to drive the
    # retrieve/generate branch; it must be declared in the state schema.
    continue_search: bool
    final_response: str


class MayaNormativa:
    system_prompt = SYSTEM_PROMPT
    generation_config = {
        "temperature": config.model_temperature,
    }
    message_limit = config.message_limit
    index = config.vector_index
    limit = config.search_limit
    bucket = config.storage_bucket
    search = AsyncQdrant.from_config(config)
    llm = AzureAIChatCompletionsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.model}",
        credential=config.openai_api_key,
    ).bind_tools([get_information])
    embedder = AzureAIEmbeddingsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.embedding_model}",
        credential=config.openai_api_key,
    )
    storage = AzureStorage.from_config(config)

    def __init__(self) -> None:
        self.tool_map = {"get_information": self.get_information}
        self.memory = MemorySaver()
        self.graph = self._build_complete_langgraph()
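
    # build_response renders each retrieved chunk as a numbered "REFERENCIA"
    # block with its source (file name and page when available), so the model
    # can cite passages as "texto[1]".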

    def build_response(self, payloads):
        """Improved with more source info than the original OCP version."""
        preface = ["Recuerda citar las referencias en el formato: texto[1]."]
        template = "------ REFERENCIA {index} ----- \n\n{content}\n\n**Fuente:** {source_info}"
        filled_templates = []
        for idx, payload in enumerate(payloads):
            content = payload.get("content", "") or payload.get("page_content", "")
            metadata = payload.get("metadata", {})
            source_info = ""
            if metadata:
                file_name = metadata.get("file_name", "") or metadata.get("file", "")
                page = metadata.get("page", "")
                if file_name and page:
                    source_info = f"{file_name} - Página {page}"
                elif file_name:
                    source_info = file_name
                else:
                    source_info = "Documento interno"
            if not source_info:
                source_info = "No disponible"
            filled_template = template.format(
                index=idx + 1,
                content=content,
                source_info=source_info
            )
            filled_templates.append(filled_template)
        return "\n".join(preface + filled_templates)

    async def get_information(self, question: str):
        logger.info(f"Embedding question: {question} with model {self.embedder.model_name}")
        embedding = await self.embedder.aembed_query(question)
        results = await self.search.semantic_search(
            embedding=embedding, collection=self.index, limit=self.limit
        )
        tool_response = self.build_response(results)
        return tool_response, results

    async def get_shareable_urls(self, metadatas: list):
        reference_urls = []
        image_urls = []
        for metadata in metadatas:
            if file := metadata.get("file"):
                reference_url = await self.storage.get_file_url(
                    filename=file,
                    bucket=self.bucket,
                    minute_duration=20,
                    image=False,
                )
                reference_urls.append(reference_url)
            if image_file := metadata.get("image"):
                image_url = await self.storage.get_file_url(
                    filename=image_file,
                    bucket=self.bucket,
                    minute_duration=20,
                    image=True,
                )
                image_urls.append(image_url)
        return reference_urls, image_urls
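
    # The URLs returned above are time-limited (20-minute) links to the source
    # documents and images in Azure storage; presumably the API layer feeds
    # this method the metadata captured in self.last_metadatas after stream().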

    def _generation_config_overwrite(self, overwrites: dict | None) -> dict[str, Any]:
        generation_config_copy = self.generation_config.copy()
        if overwrites:
            for k, v in overwrites.items():
                generation_config_copy[k] = v
        return generation_config_copy

    async def retrieve_node(self, state: MayaNormativaState) -> dict:
        query = state["query"]
        logger.info(f"Retrieving information for: {query}")
        try:
            _, results = await self.get_information(query)
            logger.info(f"Retrieved {len(results)} results")
            return {
                "search_results": results,
                "iteration_count": state["iteration_count"] + 1
            }
        except Exception as e:
            logger.error(f"Error in retrieve_node: {e}")
            return {
                "search_results": [],
                "iteration_count": state["iteration_count"] + 1
            }
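
    # Retry heuristic: if fewer than 2 chunks came back and iterations remain,
    # retry the search with "circular artículo" prepended to the query;
    # otherwise hand off to generate_node.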

    async def evaluate_node(self, state: MayaNormativaState) -> dict:
        results = state["search_results"]
        iteration = state["iteration_count"]
        max_iter = state["max_iterations"]

        has_sufficient_results = len(results) >= 2
        reached_max_iterations = iteration >= max_iter

        if has_sufficient_results or reached_max_iterations:
            logger.info(f"Stopping search: {len(results)} results, iteration {iteration}")
            return {"continue_search": False}
        else:
            original_query = state["query"]
            new_query = f"circular artículo {original_query}"
            logger.info(f"Continuing search with modified query: {new_query}")
            return {
                "continue_search": True,
                "query": new_query
            }

    async def generate_node(self, state: MayaNormativaState) -> dict:
        results = state["search_results"]
        query = state["query"]
        messages = state.get("messages", [])

        logger.info(f"Generating response for query: {query}")
        logger.info(f"Using {len(results)} search results")
        logger.info(f"Message history length: {len(messages)}")

        if not results:
            final_response = "No encontré información sobre este tema en la documentación actual."
        else:
            context_text = self.build_response(results)
            try:
                history = [
                    {"role": "system", "content": self.system_prompt}
                ]
                for msg in messages[:-1]:
                    if isinstance(msg, HumanMessage):
                        history.append({"role": "user", "content": msg.content})
                    elif isinstance(msg, AIMessage):
                        history.append({"role": "assistant", "content": msg.content})

                current_prompt = f"""
Consulta del usuario: {query}
Información encontrada:
{context_text}
INSTRUCCIONES:
- Reproduce la información EXACTAMENTE como aparece en la documentación
- NO parafrasees ni interpretes
- Usa las palabras exactas del documento original
- Mantén los tiempos verbales originales
- Mejora el formato con emojis
- Respuestas extensas y completas
- Siempre haz referencia al artículo, ley o sección de la página donde encontraste la información
- Pregunta por información relacionada con la respuesta que se requiera al final
- Considera el contexto de la información anterior si existe
"""
                history.append({"role": "user", "content": current_prompt})

                generation_config = self._generation_config_overwrite(None)
                response_chunks = []
                async for delta in self.llm.astream(input=history, **generation_config):
                    assert isinstance(delta, AIMessageChunk)
                    if delta.content:
                        response_chunks.append(delta.content)
                final_response = "".join(response_chunks)
                logger.info(f"Generated response length: {len(final_response)}")
            except Exception as e:
                logger.error(f"ERROR generando respuesta: {e}")
                final_response = f"Error generando respuesta: {str(e)}"

        return {
            "final_response": final_response,
            "messages": [AIMessage(content=final_response)]
        }
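
    # Graph topology: START -> retrieve -> evaluate -> (continue -> retrieve,
    # finish -> generate) -> END, compiled with a MemorySaver checkpointer so
    # message history persists per thread_id.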

    def _build_complete_langgraph(self) -> StateGraph:
        workflow = StateGraph(MayaNormativaState)
        workflow.add_node("retrieve", self.retrieve_node)
        workflow.add_node("evaluate", self.evaluate_node)
        workflow.add_node("generate", self.generate_node)

        workflow.add_edge(START, "retrieve")
        workflow.add_edge("retrieve", "evaluate")
        workflow.add_conditional_edges(
            "evaluate",
            self._decide_next_step,
            {
                "continue": "retrieve",
                "finish": "generate"
            }
        )
        workflow.add_edge("generate", END)

        return workflow.compile(checkpointer=self.memory)

    def _decide_next_step(self, state: MayaNormativaState) -> Literal["continue", "finish"]:
        if state.get("continue_search", False):
            return "continue"
        else:
            return "finish"
async def stream(self, history, overwrites: dict | None = None, thread_id: str = "default"):
"""Stream simplificado que mantiene memoria"""
last_message = history[-1] if history else {"content": ""}
query = last_message.get("content", "")
if not query:
yield "Error: No se encontró pregunta en el historial"
return
logger.info(f"Processing query: {query}")
logger.info(f"Thread ID: {thread_id}")
try:
config_with_thread = {
"configurable": {"thread_id": thread_id}
}
initial_state = {
"messages": [HumanMessage(content=query)],
"query": query,
"search_results": [],
"iteration_count": 0,
"max_iterations": 2,
"final_response": ""
}
logger.info("Invoking LangGraph...")
final_state = await self.graph.ainvoke(initial_state, config=config_with_thread)
logger.info("LangGraph execution completed")
self.last_search_results = final_state.get("search_results", [])
# Extraer metadatos
if self.last_search_results:
try:
metadatas = []
for result in self.last_search_results:
metadata = result.get("metadata", {})
if metadata:
metadatas.append(metadata)
self.last_metadatas = metadatas
logger.info(f"Extracted {len(metadatas)} metadata objects")
except Exception as e:
logger.error(f"Error extrayendo metadatos: {e}")
self.last_metadatas = []
else:
self.last_metadatas = []
final_response = final_state.get("final_response", "Error: No se pudo generar respuesta")
chunk_size = 50
for i in range(0, len(final_response), chunk_size):
chunk = final_response[i:i + chunk_size]
ctx.buffer.set(ctx.buffer.get() + chunk)
yield chunk
except Exception as e:
error_msg = f"Error en stream: {str(e)}"
logger.error(error_msg)
yield error_msg
async def get_conversation_history(self, thread_id: str = "default") -> List[BaseMessage]:
try:
config_with_thread = {
"configurable": {"thread_id": thread_id}
}
checkpoint = await self.graph.aget_state(config=config_with_thread)
if checkpoint and checkpoint.values:
return checkpoint.values.get("messages", [])
else:
return []
except Exception as e:
logger.error(f"Error obteniendo historial: {e}")
return []
async def debug_memory(self, thread_id: str = "default"):
try:
history = await self.get_conversation_history(thread_id)
logger.info(f"MEMORY DEBUG (thread: {thread_id}) ===")
logger.info(f"Total messages: {len(history)}")
for i, msg in enumerate(history):
msg_type = "USER" if isinstance(msg, HumanMessage) else "ASSISTANT"
content_preview = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content
logger.info(f"{i+1}. {msg_type}: {content_preview}")
except Exception as e:
logger.error(f"Error in debug_memory: {e}")


Agent = MayaNormativa
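
# Minimal local usage sketch (assumes `config` is fully populated and that
# `api.context.buffer` has been initialized by the API layer before streaming):
#
#     import asyncio
#
#     async def demo():
#         agent = Agent()
#         history = [{"role": "user", "content": "¿Qué dice la circular sobre límites de crédito?"}]
#         async for chunk in agent.stream(history, thread_id="demo-thread"):
#             print(chunk, end="", flush=True)
#
#     asyncio.run(demo())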