import logging
from pathlib import Path
from typing import Annotated, List, Sequence, Literal, Any, AsyncGenerator

from typing_extensions import TypedDict

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, AIMessageChunk
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from pydantic import BaseModel, Field

from banortegpt.storage.azure_storage import AzureStorage
from banortegpt.vector.qdrant import AsyncQdrant

import api.context as ctx
from api.config import config

logger = logging.getLogger(__name__)

parent = Path(__file__).parent
SYSTEM_PROMPT = (parent / "system_prompt.md").read_text()

AZURE_AI_URI = "https://eastus2.api.cognitive.microsoft.com"


class get_information(BaseModel):
    """Search a private repository for information."""

    question: str = Field(..., description="The user question")


class MayaPymeState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], add_messages]
    query: str
    search_results: List[dict]
    iteration_count: int
    max_iterations: int
    final_response: str
    continue_search: bool  # set by evaluate_node, read by _decide_next_step


class MayaPyme:
    system_prompt = SYSTEM_PROMPT
    generation_config = {
        "temperature": config.model_temperature,
    }
    message_limit = config.message_limit
    index = config.vector_index
    limit = config.search_limit
    bucket = config.storage_bucket

    search = AsyncQdrant.from_config(config)
    llm = AzureAIChatCompletionsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.model}",
        credential=config.openai_api_key,
    ).bind_tools([get_information])
    embedder = AzureAIEmbeddingsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.embedding_model}",
        credential=config.openai_api_key,
    )
    storage = AzureStorage.from_config(config)

    def __init__(self) -> None:
        self.tool_map = {"get_information": self.get_information}
        self.memory = MemorySaver()
        self.graph = self._build_complete_langgraph()

    def build_response(self, payloads):
        """Build the reference block for the prompt (improved with more source info than the original OCP version)."""
        preface = ["Recuerda citar las referencias en el formato: texto[1]."]
        template = "------ REFERENCIA {index} ----- \n\n{content}\n\n**Fuente:** {source_info}"
        filled_templates = []
        for idx, payload in enumerate(payloads):
            content = payload.get("content", "") or payload.get("page_content", "")
            metadata = payload.get("metadata", {})
            source_info = ""
            if metadata:
                file_name = metadata.get("file_name", "") or metadata.get("file", "")
                page = metadata.get("page", "")
                if file_name and page:
                    source_info = f"{file_name} - Página {page}"
                elif file_name:
                    source_info = file_name
                else:
                    source_info = "Documento interno"
            if not source_info:
                source_info = "No disponible"
            filled_template = template.format(
                index=idx + 1, content=content, source_info=source_info
            )
            filled_templates.append(filled_template)
        return "\n".join(preface + filled_templates)

    async def get_information(self, question: str):
        logger.info(f"Embedding question: {question} with model {self.embedder.model_name}")
        embedding = await self.embedder.aembed_query(question)
        results = await self.search.semantic_search(
            embedding=embedding, collection=self.index, limit=self.limit
        )
        tool_response = self.build_response(results)
        return tool_response, results

    async def get_shareable_urls(self, metadatas: list):
        reference_urls = []
        image_urls = []
        for metadata in metadatas:
            if file := metadata.get("file"):
                reference_url = await self.storage.get_file_url(
                    filename=file,
                    bucket=self.bucket,
                    minute_duration=20,
                    image=False,
                )
                reference_urls.append(reference_url)
            if image_file := metadata.get("image"):
                image_url = await self.storage.get_file_url(
                    filename=image_file,
                    bucket=self.bucket,
                    minute_duration=20,
                    image=True,
                )
                image_urls.append(image_url)
        return reference_urls, image_urls

    def _generation_config_overwrite(self, overwrites: dict | None) -> dict[str, Any]:
        generation_config_copy = self.generation_config.copy()
        if overwrites:
            for k, v in overwrites.items():
                generation_config_copy[k] = v
        return generation_config_copy

    async def retrieve_node(self, state: MayaPymeState) -> dict:
        query = state["query"]
        logger.info(f"Retrieving information for: {query}")
        try:
            _, results = await self.get_information(query)
            logger.info(f"Retrieved {len(results)} results")
            return {
                "search_results": results,
                "iteration_count": state["iteration_count"] + 1,
            }
        except Exception as e:
            logger.error(f"Error in retrieve_node: {e}")
            return {
                "search_results": [],
                "iteration_count": state["iteration_count"] + 1,
            }

    async def evaluate_node(self, state: MayaPymeState) -> dict:
        results = state["search_results"]
        iteration = state["iteration_count"]
        max_iter = state["max_iterations"]
        has_sufficient_results = len(results) >= 2
        reached_max_iterations = iteration >= max_iter
        if has_sufficient_results or reached_max_iterations:
            logger.info(f"Stopping search: {len(results)} results, iteration {iteration}")
            return {"continue_search": False}
        else:
            original_query = state["query"]
            new_query = f"circular artículo {original_query}"
            logger.info(f"Continuing search with modified query: {new_query}")
            return {
                "continue_search": True,
                "query": new_query,
            }

    async def generate_node(self, state: MayaPymeState) -> dict:
        results = state["search_results"]
        query = state["query"]
        messages = state.get("messages", [])
        logger.info(f"Generating response for query: {query}")
        logger.info(f"Using {len(results)} search results")
        logger.info(f"Message history length: {len(messages)}")
        if not results:
            final_response = "No encontré información sobre este tema en la documentación actual."
        else:
            context_text = self.build_response(results)
            try:
                history = [
                    {"role": "system", "content": self.system_prompt}
                ]
                for msg in messages[:-1]:
                    if isinstance(msg, HumanMessage):
                        history.append({"role": "user", "content": msg.content})
                    elif isinstance(msg, AIMessage):
                        history.append({"role": "assistant", "content": msg.content})

                current_prompt = f"""
Consulta del usuario: {query}

Información encontrada:
{context_text}

INSTRUCCIONES:
- Reproduce la información EXACTAMENTE como aparece en la documentación
- NO parafrasees ni interpretes
- Usa las palabras exactas del documento original
- Mantén los tiempos verbales originales
- Mejora el formato con emojis
- Respuestas extensas y completas
- Siempre haz referencia al artículo, ley o sección de la página donde encontraste la información
- Al final, pregunta por información relacionada con la respuesta que el usuario pudiera requerir
- Considera el contexto de la información anterior si existe
"""
                history.append({"role": "user", "content": current_prompt})

                generation_config = self._generation_config_overwrite(None)
                response_chunks = []
                async for delta in self.llm.astream(input=history, **generation_config):
                    assert isinstance(delta, AIMessageChunk)
                    if delta.content:
                        response_chunks.append(delta.content)
                final_response = "".join(response_chunks)
                logger.info(f"Generated response length: {len(final_response)}")
            except Exception as e:
                logger.error(f"Error generating response: {e}")
                final_response = f"Error generando respuesta: {str(e)}"

        return {
            "final_response": final_response,
            "messages": [AIMessage(content=final_response)],
        }

    def _build_complete_langgraph(self) -> StateGraph:
        workflow = StateGraph(MayaPymeState)
        workflow.add_node("retrieve", self.retrieve_node)
        workflow.add_node("evaluate", self.evaluate_node)
        workflow.add_node("generate", self.generate_node)
        workflow.add_edge(START, "retrieve")
        workflow.add_edge("retrieve", "evaluate")
        workflow.add_conditional_edges(
            "evaluate",
            self._decide_next_step,
            {
                "continue": "retrieve",
                "finish": "generate",
            },
        )
        workflow.add_edge("generate", END)
        return workflow.compile(checkpointer=self.memory)

    def _decide_next_step(self, state: MayaPymeState) -> Literal["continue", "finish"]:
        if state.get("continue_search", False):
            return "continue"
        else:
            return "finish"
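
    # Graph topology for reference: START -> retrieve -> evaluate, then either back to
    # retrieve (with the query rewritten as "circular artículo ...", up to max_iterations)
    # or on to generate -> END. stream() below drives one full pass of this loop per turn.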
    async def stream(
        self, history, overwrites: dict | None = None, thread_id: str = "default"
    ) -> AsyncGenerator[str, None]:
        """Simplified streaming entry point that keeps conversation memory per thread."""
        last_message = history[-1] if history else {"content": ""}
        query = last_message.get("content", "")
        if not query:
            yield "Error: No se encontró pregunta en el historial"
            return

        logger.info(f"Processing query: {query}")
        logger.info(f"Thread ID: {thread_id}")
        try:
            config_with_thread = {
                "configurable": {"thread_id": thread_id}
            }
            initial_state = {
                "messages": [HumanMessage(content=query)],
                "query": query,
                "search_results": [],
                "iteration_count": 0,
                "max_iterations": 2,
                "final_response": "",
            }

            logger.info("Invoking LangGraph...")
            final_state = await self.graph.ainvoke(initial_state, config=config_with_thread)
            logger.info("LangGraph execution completed")

            self.last_search_results = final_state.get("search_results", [])

            # Extract metadata from the last search results
            if self.last_search_results:
                try:
                    metadatas = []
                    for result in self.last_search_results:
                        metadata = result.get("metadata", {})
                        if metadata:
                            metadatas.append(metadata)
                    self.last_metadatas = metadatas
                    logger.info(f"Extracted {len(metadatas)} metadata objects")
                except Exception as e:
                    logger.error(f"Error extracting metadata: {e}")
                    self.last_metadatas = []
            else:
                self.last_metadatas = []

            final_response = final_state.get("final_response", "Error: No se pudo generar respuesta")

            chunk_size = 50
            for i in range(0, len(final_response), chunk_size):
                chunk = final_response[i:i + chunk_size]
                ctx.buffer.set(ctx.buffer.get() + chunk)
                yield chunk
        except Exception as e:
            error_msg = f"Error en stream: {str(e)}"
            logger.error(error_msg)
            yield error_msg

    async def get_conversation_history(self, thread_id: str = "default") -> List[BaseMessage]:
        try:
            config_with_thread = {
                "configurable": {"thread_id": thread_id}
            }
            checkpoint = await self.graph.aget_state(config=config_with_thread)
            if checkpoint and checkpoint.values:
                return checkpoint.values.get("messages", [])
            else:
                return []
        except Exception as e:
            logger.error(f"Error getting conversation history: {e}")
            return []

    async def debug_memory(self, thread_id: str = "default"):
        try:
            history = await self.get_conversation_history(thread_id)
            logger.info(f"=== MEMORY DEBUG (thread: {thread_id}) ===")
            logger.info(f"Total messages: {len(history)}")
            for i, msg in enumerate(history):
                msg_type = "USER" if isinstance(msg, HumanMessage) else "ASSISTANT"
                content_preview = (msg.content[:50] + "...") if len(msg.content) > 50 else msg.content
                logger.info(f"{i+1}. {msg_type}: {content_preview}")
        except Exception as e:
            logger.error(f"Error in debug_memory: {e}")


Agent = MayaPyme
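
# Minimal manual-run sketch: drives one streamed turn through the compiled graph and then
# dumps the thread's memory. The sample question and thread id are placeholders, and this
# assumes the Azure, Qdrant, and storage backends referenced by api.config are reachable
# and that api.context's streaming buffer is initialised in this context.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        agent = Agent()
        # Single-turn history in the dict shape stream() expects ("content" key is read).
        history = [{"role": "user", "content": "¿Qué requisitos aplican para un crédito PyME?"}]
        async for chunk in agent.stream(history, thread_id="demo-thread"):
            print(chunk, end="", flush=True)
        print()
        await agent.debug_memory(thread_id="demo-thread")

    asyncio.run(_demo())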