forked from innovacion/Mayacontigo
131 apps/riesgos/api/agent.py (new file)
@@ -0,0 +1,131 @@
from pathlib import Path
from typing import Any, Literal

from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel
from langchain_core.messages import AIMessageChunk
from pydantic import BaseModel, Field
from qdrant_client import models

from banortegpt.storage.azure_storage import AzureStorage
from banortegpt.vector.qdrant import AsyncQdrant

from . import config, context

parent = Path(__file__).parent
SYSTEM_PROMPT = (parent / "system_prompt.md").read_text()

AZURE_AI_URI = "https://eastus2.api.cognitive.microsoft.com"

# Note: the lowercase class name is intentional; it doubles as the tool name
# that the bound model calls and that tool_map keys on below.
class get_information(BaseModel):
    """Search for relevant information in Banorte documents about a particular system."""

    question: str = Field(
        ...,
        description="The user's question, rewritten so it is understandable out of context.",
    )
    system: str = Field(
        ...,
        description=(
            "The system to search information about. Must be one of: "
            "['ML', 'SACS', 'ED', 'CARATULA', 'SICRED']. "
            "'ML' is 'Master de lineas'; 'ED' is 'Expediente Digital'."
        ),
    )
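
# Illustrative only (not from the source): once bound via bind_tools, the model
# is expected to emit tool calls matching this schema, along the lines of
#   {"name": "get_information",
#    "args": {"question": "Como consulto una linea?", "system": "ML"}}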


class MayaRiesgos:
    system_prompt = SYSTEM_PROMPT
    generation_config = {
        "temperature": config.model_temperature,
    }
    message_limit = config.message_limit
    index = config.vector_index
    limit = config.search_limit
    bucket = config.storage_bucket

    # Shared clients are built once at class-definition time and reused
    # across instances.
    search = AsyncQdrant.from_config(config)
    llm = AzureAIChatCompletionsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.model}",
        credential=config.openai_api_key,
    ).bind_tools([get_information])
    embedder = AzureAIEmbeddingsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.embedding_model}",
        credential=config.openai_api_key,
    )
    storage = AzureStorage.from_config(config)

    def __init__(self) -> None:
        self.tool_map = {"get_information": self.get_information}

    def build_response(self, payloads):
        # Render each payload as a numbered reference block, joined by newlines.
        template = "------ REFERENCIA {index} ----- \n\n{content}"

        filled_templates = [
            template.format(index=idx, content=payload["content"])
            for idx, payload in enumerate(payloads)
        ]

        return "\n".join(filled_templates)

    async def get_information(
        self, question: str, system: Literal["ML", "SACS", "ED", "CARATULA", "SICRED"]
    ):
        embedding = await self.embedder.aembed_query(question)

        # Match chunks tagged for the requested system, plus chunks tagged
        # "ALL" that apply to every system.
        conditions = models.Filter(
            must=[
                models.FieldCondition(
                    key="system",
                    match=models.MatchAny(any=["ALL", system]),
                )
            ]
        )

        payloads = await self.search.semantic_search(
            collection=self.index,
            embedding=embedding,
            limit=self.limit,
            conditions=conditions,
        )

        tool_response = self.build_response(payloads)

        return tool_response, payloads

    async def get_shareable_urls(self, metadatas: list):
        # NOTE: reference_urls is currently never populated; only page-image
        # URLs are generated.
        reference_urls = []
        image_urls = []

        for metadata in metadatas:
            if (pagina := metadata.get("pagina")) and (
                archivo := metadata.get("archivo")
            ):
                image_file = f"{pagina}_{archivo}.png"

                image_url = await self.storage.get_file_url(
                    filename=image_file,
                    bucket=self.bucket,
                    minute_duration=20,
                    image=True,
                )
                image_urls.append(image_url)

        return reference_urls, image_urls

    def _generation_config_overwrite(self, overwrites: dict | None) -> dict[str, Any]:
        # Return a copy of the class-level generation config with any
        # per-request overrides applied on top.
        generation_config_copy = self.generation_config.copy()
        if overwrites:
            generation_config_copy.update(overwrites)
        return generation_config_copy

    async def stream(self, history, overwrites: dict | None = None):
        generation_config = self._generation_config_overwrite(overwrites)

        async for chunk in self.llm.astream(input=history, **generation_config):
            assert isinstance(chunk, AIMessageChunk)
            if call := chunk.tool_call_chunks:
                # Accumulate the streamed tool call (id, name, argument
                # fragments) in request-scoped context variables.
                if tool_id := call[0].get("id"):
                    context.tool_id.set(tool_id)
                if name := call[0].get("name"):
                    context.tool_name.set(name)
                if args := call[0].get("args"):
                    context.tool_buffer.set(context.tool_buffer.get() + args)
            else:
                # Plain assistant text: buffer it and stream it to the caller.
                if buffer := chunk.content:
                    assert isinstance(buffer, str)
                    context.buffer.set(context.buffer.get() + buffer)
                    yield buffer

    async def generate(self, history, overwrites: dict | None = None):
        generation_config = self._generation_config_overwrite(overwrites)
        return await self.llm.ainvoke(input=history, **generation_config)
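
A minimal sketch of how this agent might be driven, assuming LangChain-style role/content dicts for `history` and that the `context` variables are initialized by the caller's framework; the import path and the message contents are illustrative, not from the source:

import asyncio

from apps.riesgos.api.agent import MayaRiesgos  # hypothetical import path


async def main() -> None:
    agent = MayaRiesgos()
    history = [
        {"role": "system", "content": agent.system_prompt},
        {"role": "user", "content": "Como consulto una linea en ML?"},
    ]

    # Stream assistant text; tool calls are captured in context variables
    # rather than yielded, so only plain tokens are printed here.
    async for token in agent.stream(history, overwrites={"temperature": 0.2}):
        print(token, end="", flush=True)


asyncio.run(main())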