Files
Mayacontigo/apps/inversionistas/api/agent.py
Rogelio 325f1ef439 ic
2025-10-13 18:16:25 +00:00

134 lines
6.7 KiB
Python

import json
import logging
from typing import Any
from pathlib import Path
import aiosqlite
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel
from langchain_core.messages.ai import AIMessageChunk
from langchain_qdrant import QdrantVectorStore
import api.context as ctx
from api.config import config
from api.prompts import ORCHESTRATOR_PROMPT, TOOL_SCHEMAS
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)

# Base URI of the Azure AI endpoint; model deployment paths are appended to it.
AZURE_AI_URI = "https://eastus2.api.cognitive.microsoft.com"

# SQLite database file located in the same directory as this module.
SQLITE_DB_PATH = Path(__file__).parent.joinpath("db.sqlite")
class MayaInversionistas:
    """Investor-relations chat agent backed by an Azure AI chat model.

    The chat model is bound to ``TOOL_SCHEMAS``. Tool implementations are
    exposed through ``self.tool_map``: most of them query a local SQLite
    database (one table per entity), while ``getInformationalData`` runs a
    similarity search against a Qdrant collection.

    NOTE(review): ``llm``, ``embedder`` and ``search`` are created when the
    class body executes, i.e. at import time. ``from_existing_collection``
    presumably contacts Qdrant; confirm that import-time network access is
    intended.
    """

    system_prompt = ORCHESTRATOR_PROMPT
    # Default keyword arguments forwarded to the LLM on every call; individual
    # keys can be overridden per call via ``overwrites``.
    generation_config = {
        "temperature": config.model_temperature,
    }
    message_limit = config.message_limit
    index = config.vector_index
    # NOTE(review): ``limit`` is not passed to the similarity search below
    # (which therefore uses the vector store's default result count);
    # confirm whether it is meant to be used as ``k`` there.
    limit = config.search_limit
    bucket = config.storage_bucket
    llm = AzureAIChatCompletionsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.model}",
        credential=config.openai_api_key,
    ).bind_tools(TOOL_SCHEMAS)
    embedder = AzureAIEmbeddingsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.embedding_model}",
        credential=config.openai_api_key,
    )
    search = QdrantVectorStore.from_existing_collection(
        embedding=embedder,
        collection_name=index,
        url=config.qdrant_url,
        api_key=config.qdrant_api_key,
    )

    # Maps each SQLite-backed tool name in ``tool_map`` to the table it queries.
    SQLITE_TOOL_TABLES: dict[str, str] = {
        "getGFNORTEData": "gf_norte",
        "getBanorteConsolidadoData": "banorte_consolidado",
        "getAlmacenadoraConsolidadoData": "almacenadora_consolidado",
        "getArrendadoraFactorConsolidado": "arrendadora_factor_consolidado",
        # NOTE(review): "conosolidado" looks like a typo of "consolidado"; kept
        # verbatim. Confirm against the table name in db.sqlite before changing.
        "getCasadeBolsaConsolidado": "casa_bolsa_conosolidado",
        "getOperadoradeFondos": "op_fondos",
        "getSectorBursatil": "sector_bursatil",
        "getSectorBAPConsolidado": "sector_bap_consolidado",
        "getSeguros": "seguros",
        "getPensiones": "pensiones",
        "getBineo": "bineo",
        "getSectorBanca": "sector_banca",
        "getHolding": "holding",
        "getBanorteFinancialServices": "banorte_financial_services",
        "getFideicomisoBursaGEM": "fideicomiso_bursa_gem",
        "getTarjetasdelFuturo": "tarjetas_del_futuro",
        "getAfore": "afore",
        "getBanorteFuturo": "banorte_futuro",
        "getSegurosSinBanorteFuturo": "seguros_sin_banorte_futuro",
    }

    def __init__(self) -> None:
        """Build ``tool_map``: tool name -> callable returning an awaitable."""
        self.tool_map = {
            name: self._make_sqlite_tool(table)
            for name, table in self.SQLITE_TOOL_TABLES.items()
        }
        self.tool_map["getInformationalData"] = self.run_qdrant_tool

    def _make_sqlite_tool(self, table: str):
        """Return a ``(year, quarter, concept)`` tool bound to ``table``.

        The returned callable returns the ``run_sqlite_tool`` coroutine, so
        callers await it exactly as they would any other tool.
        """

        def tool(year, quarter, concept):
            return self.run_sqlite_tool(year, quarter, concept, table)

        return tool

    @staticmethod
    def build_response(results: list[dict]) -> str:
        """Serialize ``results`` as JSON inside the text returned to the LLM.

        The text is followed by a pointer to Banorte's quarterly reports page.
        """
        return (
            "I have retrieved the following results from the database:\n"
            + json.dumps(results)
            + "\nPara mayor información consultar el Reporte de Resultados Trimestral (URL: https://investors.banorte.com/es/financial-information/quarterly-reports)"
        )

    async def run_sqlite_tool(self, year: int, quarter: int, concept: str, table: str):
        """Query ``table`` for the given period/concept and format the rows."""
        rows = await self.get_data_from_sqlite(year, quarter, concept, table)
        return self.build_response([dict(row) for row in rows])

    async def run_qdrant_tool(self, question: str):
        """Run a similarity search for ``question``; format the hits' metadata."""
        logger.info(
            "Embedding question: %s with model %s", question, self.embedder.model_name
        )
        # Use the async API so the search does not block the event loop.
        results = await self.search.asimilarity_search(question)
        data = [dict(doc.metadata) for doc in results]
        return self.build_response(data)

    @staticmethod
    async def get_data_from_sqlite(year: int, quarter: int, concept: str, table: str):
        """Fetch all rows of ``table`` matching ``year``, ``quarter`` and ``concept``.

        ``quarter`` is compared against the table's ``trim`` column. Rows are
        returned as ``aiosqlite.Row`` objects (convertible with ``dict(row)``).
        """
        # Identifiers cannot be bound as SQL parameters, so quote the table
        # name (doubling any embedded quotes) instead of splicing it in raw.
        quoted_table = '"{}"'.format(table.replace('"', '""'))
        query = (
            f"SELECT * FROM {quoted_table} "
            "WHERE year = ? AND trim = ? AND concept = ?"
        )
        async with aiosqlite.connect(SQLITE_DB_PATH) as db:
            db.row_factory = aiosqlite.Row
            async with db.execute(query, (year, quarter, concept)) as cursor:
                return await cursor.fetchall()

    def _generation_config_overwrite(self, overwrites: dict | None) -> dict[str, Any]:
        """Return a copy of ``generation_config`` with ``overwrites`` applied on top."""
        return {**self.generation_config, **(overwrites or {})}

    async def stream(self, history, overwrites: dict | None = None):
        """Stream the LLM reply to ``history``, yielding text fragments.

        Text fragments are also appended to ``ctx.buffer``. Tool-call chunks
        are not yielded; their id, name and argument fragments are recorded via
        ``ctx.tool_id``, ``ctx.tool_name`` and ``ctx.tool_buffer``.
        """
        generation_config = self._generation_config_overwrite(overwrites)
        async for chunk in self.llm.astream(input=history, **generation_config):
            assert isinstance(chunk, AIMessageChunk)
            if call := chunk.tool_call_chunks:
                # Only the first tool-call chunk of each message chunk is read.
                if tool_id := call[0].get("id"):
                    ctx.tool_id.set(tool_id)
                if name := call[0].get("name"):
                    ctx.tool_name.set(name)
                if args := call[0].get("args"):
                    # Arguments arrive as partial strings; accumulate them.
                    ctx.tool_buffer.set(ctx.tool_buffer.get() + args)
            elif buffer := chunk.content:
                assert isinstance(buffer, str)
                ctx.buffer.set(ctx.buffer.get() + buffer)
                yield buffer

    async def generate(self, history, overwrites: dict | None = None):
        """Invoke the LLM once on ``history`` and return its response."""
        generation_config = self._generation_config_overwrite(overwrites)
        return await self.llm.ainvoke(input=history, **generation_config)