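"""MayaPyme: a LangGraph-based retrieval-augmented agent.

The graph runs three nodes in sequence: retrieve (Qdrant semantic search),
evaluate (decide whether to retry with a broadened query), and generate
(answer with the Azure AI chat model). A MemorySaver checkpointer keeps
conversation history per thread_id.
"""
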
import logging
from pathlib import Path
from typing import Annotated, List, Sequence, Literal, Any, AsyncGenerator
from typing_extensions import TypedDict

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, AIMessageChunk
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from pydantic import BaseModel, Field

from banortegpt.storage.azure_storage import AzureStorage
from banortegpt.vector.qdrant import AsyncQdrant

import api.context as ctx
from api.config import config

logger = logging.getLogger(__name__)

parent = Path(__file__).parent
SYSTEM_PROMPT = (parent / "system_prompt.md").read_text()
AZURE_AI_URI = "https://eastus2.api.cognitive.microsoft.com"


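# Tool schema bound to the chat model via bind_tools(); note that retrieve_node
# calls the MayaPyme.get_information coroutine directly rather than going
# through tool dispatch.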
class get_information(BaseModel):
    """Search a private repository for information."""

    question: str = Field(..., description="The user question")


class MayaPymeState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], add_messages]
    query: str
    search_results: List[dict]
    iteration_count: int
    max_iterations: int
    final_response: str
    # Written by evaluate_node and read by _decide_next_step; it must be
    # declared in the state schema for the node update to be applied.
    continue_search: bool


class MayaPyme:
    system_prompt = SYSTEM_PROMPT
    generation_config = {
        "temperature": config.model_temperature,
    }
    message_limit = config.message_limit
    index = config.vector_index
    limit = config.search_limit
    bucket = config.storage_bucket

    search = AsyncQdrant.from_config(config)
    llm = AzureAIChatCompletionsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.model}",
        credential=config.openai_api_key,
    ).bind_tools([get_information])
    embedder = AzureAIEmbeddingsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.embedding_model}",
        credential=config.openai_api_key,
    )
    storage = AzureStorage.from_config(config)

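    # Each instance compiles its own graph and owns its own MemorySaver, so
    # checkpointed conversation threads are scoped to that instance.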
    def __init__(self) -> None:
        self.tool_map = {"get_information": self.get_information}
        self.memory = MemorySaver()
        self.graph = self._build_complete_langgraph()

    def build_response(self, payloads):
        """Build the reference block for the prompt (richer than the original OCP version)."""
        preface = ["Recuerda citar las referencias en el formato: texto[1]."]
        template = "------ REFERENCIA {index} ----- \n\n{content}\n\n**Fuente:** {source_info}"

        filled_templates = []
        for idx, payload in enumerate(payloads):
            content = payload.get("content", "") or payload.get("page_content", "")
            metadata = payload.get("metadata", {})

            source_info = ""
            if metadata:
                file_name = metadata.get("file_name", "") or metadata.get("file", "")
                page = metadata.get("page", "")

                if file_name and page:
                    source_info = f"{file_name} - Página {page}"
                elif file_name:
                    source_info = file_name
                else:
                    source_info = "Documento interno"

            if not source_info:
                source_info = "No disponible"

            filled_template = template.format(
                index=idx + 1,
                content=content,
                source_info=source_info,
            )
            filled_templates.append(filled_template)

        return "\n".join(preface + filled_templates)

    async def get_information(self, question: str):
        """Embed the question, run a semantic search in Qdrant, and return
        (formatted reference block, raw result payloads)."""
        logger.info(f"Embedding question: {question} with model {self.embedder.model_name}")
        embedding = await self.embedder.aembed_query(question)

        results = await self.search.semantic_search(
            embedding=embedding, collection=self.index, limit=self.limit
        )

        tool_response = self.build_response(results)
        return tool_response, results

    async def get_shareable_urls(self, metadatas: list):
        reference_urls = []
        image_urls = []

        for metadata in metadatas:
            if file := metadata.get("file"):
                reference_url = await self.storage.get_file_url(
                    filename=file,
                    bucket=self.bucket,
                    minute_duration=20,
                    image=False,
                )
                reference_urls.append(reference_url)
            if image_file := metadata.get("image"):
                image_url = await self.storage.get_file_url(
                    filename=image_file,
                    bucket=self.bucket,
                    minute_duration=20,
                    image=True,
                )
                image_urls.append(image_url)

        return reference_urls, image_urls

    def _generation_config_overwrite(self, overwrites: dict | None) -> dict[str, Any]:
        generation_config_copy = self.generation_config.copy()
        if overwrites:
            for k, v in overwrites.items():
                generation_config_copy[k] = v
        return generation_config_copy

    async def retrieve_node(self, state: MayaPymeState) -> dict:
        query = state["query"]
        logger.info(f"Retrieving information for: {query}")

        try:
            _, results = await self.get_information(query)
            logger.info(f"Retrieved {len(results)} results")
            return {
                "search_results": results,
                "iteration_count": state["iteration_count"] + 1,
            }
        except Exception as e:
            logger.error(f"Error in retrieve_node: {e}")
            return {
                "search_results": [],
                "iteration_count": state["iteration_count"] + 1,
            }

    async def evaluate_node(self, state: MayaPymeState) -> dict:
        results = state["search_results"]
        iteration = state["iteration_count"]
        max_iter = state["max_iterations"]

        has_sufficient_results = len(results) >= 2
        reached_max_iterations = iteration >= max_iter

        if has_sufficient_results or reached_max_iterations:
            logger.info(f"Stopping search: {len(results)} results, iteration {iteration}")
            return {"continue_search": False}
        else:
            # Retry heuristic: prefix regulatory keywords onto the query and search again.
            original_query = state["query"]
            new_query = f"circular artículo {original_query}"
            logger.info(f"Continuing search with modified query: {new_query}")
            return {
                "continue_search": True,
                "query": new_query,
            }

    async def generate_node(self, state: MayaPymeState) -> dict:
        results = state["search_results"]
        query = state["query"]
        messages = state.get("messages", [])

        logger.info(f"Generating response for query: {query}")
        logger.info(f"Using {len(results)} search results")
        logger.info(f"Message history length: {len(messages)}")

        if not results:
            final_response = "No encontré información sobre este tema en la documentación actual."
        else:
            context_text = self.build_response(results)

            try:
                history = [
                    {"role": "system", "content": self.system_prompt}
                ]

                # Replay prior turns (everything except the current question) so the
                # model sees the conversation context stored by the checkpointer.
                for msg in messages[:-1]:
                    if isinstance(msg, HumanMessage):
                        history.append({"role": "user", "content": msg.content})
                    elif isinstance(msg, AIMessage):
                        history.append({"role": "assistant", "content": msg.content})

                current_prompt = f"""
Consulta del usuario: {query}

Información encontrada:
{context_text}

INSTRUCCIONES:
- Reproduce la información EXACTAMENTE como aparece en la documentación
- NO parafrasees ni interpretes
- Usa las palabras exactas del documento original
- Mantén los tiempos verbales originales
- Mejora el formato con emojis
- Respuestas extensas y completas
- Siempre haz referencia al artículo, ley o sección de la página donde encontraste la información
- Al final, pregunta por información relacionada con la respuesta que el usuario pueda requerir
- Considera el contexto de la información anterior si existe
"""

                history.append({"role": "user", "content": current_prompt})

                generation_config = self._generation_config_overwrite(None)

                response_chunks = []
                async for delta in self.llm.astream(input=history, **generation_config):
                    assert isinstance(delta, AIMessageChunk)
                    if delta.content:
                        response_chunks.append(delta.content)

                final_response = "".join(response_chunks)
                logger.info(f"Generated response length: {len(final_response)}")

            except Exception as e:
                logger.error(f"Error generating response: {e}")
                final_response = f"Error generando respuesta: {str(e)}"

        return {
            "final_response": final_response,
            "messages": [AIMessage(content=final_response)],
        }

    def _build_complete_langgraph(self) -> StateGraph:
        """Compile the retrieve -> evaluate -> generate graph with memory checkpointing."""
        workflow = StateGraph(MayaPymeState)

        workflow.add_node("retrieve", self.retrieve_node)
        workflow.add_node("evaluate", self.evaluate_node)
        workflow.add_node("generate", self.generate_node)

        workflow.add_edge(START, "retrieve")
        workflow.add_edge("retrieve", "evaluate")

        workflow.add_conditional_edges(
            "evaluate",
            self._decide_next_step,
            {
                "continue": "retrieve",
                "finish": "generate",
            },
        )

        workflow.add_edge("generate", END)

        return workflow.compile(checkpointer=self.memory)

    def _decide_next_step(self, state: MayaPymeState) -> Literal["continue", "finish"]:
        if state.get("continue_search", False):
            return "continue"
        else:
            return "finish"

    async def stream(
        self, history, overwrites: dict | None = None, thread_id: str = "default"
    ) -> AsyncGenerator[str, None]:
        """Simplified streaming entry point that keeps per-thread memory."""

        last_message = history[-1] if history else {"content": ""}
        query = last_message.get("content", "")

        if not query:
            yield "Error: No se encontró pregunta en el historial"
            return

        logger.info(f"Processing query: {query}")
        logger.info(f"Thread ID: {thread_id}")

        try:
            config_with_thread = {
                "configurable": {"thread_id": thread_id}
            }

            initial_state = {
                "messages": [HumanMessage(content=query)],
                "query": query,
                "search_results": [],
                "iteration_count": 0,
                "max_iterations": 2,
                "final_response": "",
                "continue_search": False,
            }

            logger.info("Invoking LangGraph...")

            final_state = await self.graph.ainvoke(initial_state, config=config_with_thread)

            logger.info("LangGraph execution completed")

            self.last_search_results = final_state.get("search_results", [])

            # Extract metadata from the search results.
            if self.last_search_results:
                try:
                    metadatas = []
                    for result in self.last_search_results:
                        metadata = result.get("metadata", {})
                        if metadata:
                            metadatas.append(metadata)

                    self.last_metadatas = metadatas
                    logger.info(f"Extracted {len(metadatas)} metadata objects")

                except Exception as e:
                    logger.error(f"Error extracting metadata: {e}")
                    self.last_metadatas = []
            else:
                self.last_metadatas = []

            final_response = final_state.get("final_response", "Error: No se pudo generar respuesta")

            # The graph returns the full answer at once; re-chunk it so callers
            # can still consume it as a stream.
            chunk_size = 50
            for i in range(0, len(final_response), chunk_size):
                chunk = final_response[i:i + chunk_size]
                ctx.buffer.set(ctx.buffer.get() + chunk)
                yield chunk

        except Exception as e:
            error_msg = f"Error en stream: {str(e)}"
            logger.error(error_msg)
            yield error_msg

    async def get_conversation_history(self, thread_id: str = "default") -> List[BaseMessage]:
        try:
            config_with_thread = {
                "configurable": {"thread_id": thread_id}
            }

            checkpoint = await self.graph.aget_state(config=config_with_thread)

            if checkpoint and checkpoint.values:
                return checkpoint.values.get("messages", [])
            else:
                return []

        except Exception as e:
            logger.error(f"Error retrieving conversation history: {e}")
            return []

    async def debug_memory(self, thread_id: str = "default"):
        try:
            history = await self.get_conversation_history(thread_id)
            logger.info(f"=== MEMORY DEBUG (thread: {thread_id}) ===")
            logger.info(f"Total messages: {len(history)}")

            for i, msg in enumerate(history):
                msg_type = "USER" if isinstance(msg, HumanMessage) else "ASSISTANT"
                content_preview = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content
                logger.info(f"{i+1}. {msg_type}: {content_preview}")

        except Exception as e:
            logger.error(f"Error in debug_memory: {e}")


Agent = MayaPyme
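

# Minimal usage sketch, not part of the API wiring: it assumes the api.config
# settings, Azure credentials, the Qdrant collection, and api.context.buffer
# are available in the runtime environment, and that `history` follows the
# {"role", "content"} dict format whose last entry stream() reads. The sample
# question is illustrative only.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        agent = Agent()
        history = [{"role": "user", "content": "¿Qué requisitos tiene una cuenta PyME?"}]
        async for chunk in agent.stream(history, thread_id="demo"):
            print(chunk, end="", flush=True)

    asyncio.run(_demo())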