import json
import logging
from functools import partial
from pathlib import Path
from typing import Any

import aiosqlite
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel
from langchain_core.messages.ai import AIMessageChunk
from langchain_qdrant import QdrantVectorStore

import api.context as ctx
from api.config import config
from api.prompts import ORCHESTRATOR_PROMPT, TOOL_SCHEMAS

logger = logging.getLogger(__name__)

AZURE_AI_URI = "https://eastus2.api.cognitive.microsoft.com"
SQLITE_DB_PATH = Path(__file__).parent / "db.sqlite"


class MayaInversionistas:
    """Tool-calling chat orchestrator backed by Azure-hosted models.

    The chat model is bound to ``TOOL_SCHEMAS``. ``tool_map`` maps tool names
    (presumably the ones declared in ``TOOL_SCHEMAS`` — verify) to coroutine
    functions that answer a call from the local SQLite database
    (``SQLITE_DB_PATH``) or from a Qdrant vector collection.

    Note: the chat model, embedder and vector-store clients below are created
    when this class body executes (i.e. at import time) and are shared by all
    instances.
    """

    system_prompt = ORCHESTRATOR_PROMPT
    # Base keyword arguments forwarded to every model call; per-call overrides
    # are merged on top by _generation_config_overwrite.
    generation_config = {
        "temperature": config.model_temperature,
    }
    message_limit = config.message_limit  # not read inside this class
    index = config.vector_index  # Qdrant collection name used by `search`
    limit = config.search_limit  # number of documents run_qdrant_tool retrieves
    bucket = config.storage_bucket  # not read inside this class

    llm = AzureAIChatCompletionsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.model}",
        credential=config.openai_api_key,
    ).bind_tools(TOOL_SCHEMAS)
    embedder = AzureAIEmbeddingsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.embedding_model}",
        credential=config.openai_api_key,
    )
    search = QdrantVectorStore.from_existing_collection(
        embedding=embedder,
        collection_name=index,
        url=config.qdrant_url,
        api_key=config.qdrant_api_key,
    )

    # Tool name -> SQLite table queried by that tool. Every SQLite-backed tool
    # shares run_sqlite_tool and differs only in the table name.
    _SQLITE_TOOL_TABLES: dict[str, str] = {
        "getGFNORTEData": "gf_norte",
        "getBanorteConsolidadoData": "banorte_consolidado",
        "getAlmacenadoraConsolidadoData": "almacenadora_consolidado",
        "getArrendadoraFactorConsolidado": "arrendadora_factor_consolidado",
        # NOTE(review): "conosolidado" looks like a typo of "consolidado". Kept
        # verbatim because it must match the real table name in db.sqlite —
        # confirm against the database before renaming.
        "getCasadeBolsaConsolidado": "casa_bolsa_conosolidado",
        "getOperadoradeFondos": "op_fondos",
        "getSectorBursatil": "sector_bursatil",
        "getSectorBAPConsolidado": "sector_bap_consolidado",
        "getSeguros": "seguros",
        "getPensiones": "pensiones",
        "getBineo": "bineo",
        "getSectorBanca": "sector_banca",
        "getHolding": "holding",
        "getBanorteFinancialServices": "banorte_financial_services",
        "getFideicomisoBursaGEM": "fideicomiso_bursa_gem",
        "getTarjetasdelFuturo": "tarjetas_del_futuro",
        "getAfore": "afore",
        "getBanorteFuturo": "banorte_futuro",
        "getSegurosSinBanorteFuturo": "seguros_sin_banorte_futuro",
    }

    def __init__(self) -> None:
        # SQLite tools are called with (year, quarter, concept); partial binds
        # the table so each entry returns the run_sqlite_tool coroutine.
        self.tool_map: dict[str, Any] = {
            tool_name: partial(self.run_sqlite_tool, table=table)
            for tool_name, table in self._SQLITE_TOOL_TABLES.items()
        }
        # Free-text questions go to the vector store; called with (question).
        self.tool_map["getInformationalData"] = self.run_qdrant_tool

    @staticmethod
    def build_response(results: list[dict]) -> str:
        """Wrap tool *results* as JSON in the text returned to the model.

        The message ends with a pointer to Banorte's quarterly-report page.
        """
        return (
            "I have retrieved the following results from the database:\n"
            + json.dumps(results)
            + "\nPara mayor información consultar el Reporte de Resultados Trimestral (URL: https://investors.banorte.com/es/financial-information/quarterly-reports)"
        )

    async def run_sqlite_tool(self, year: int, quarter: int, concept: str, table: str) -> str:
        """Query *table* for the given year/quarter/concept and format the rows.

        Returns:
            The ``build_response`` text containing every matching row as a dict.
        """
        results = await self.get_data_from_sqlite(year, quarter, concept, table)
        data = [dict(row) for row in results]
        return self.build_response(data)

    async def run_qdrant_tool(self, question: str) -> str:
        """Run a similarity search for *question* and format the hits' metadata.

        Returns:
            The ``build_response`` text containing each hit's metadata dict.
        """
        logger.info(
            "Embedding question: %s with model %s", question, self.embedder.model_name
        )
        # The async variant keeps the embedding + search round-trip off the
        # event loop. The sync call previously blocked every other coroutine.
        # `limit` (config.search_limit) was previously ignored, so the default
        # k=4 was always used.
        results = await self.search.asimilarity_search(question, k=self.limit)
        data = [dict(doc.metadata) for doc in results]
        return self.build_response(data)

    @staticmethod
    async def get_data_from_sqlite(year: int, quarter: int, concept: str, table: str):
        """Fetch all rows of *table* matching year, quarter (``trim``) and concept.

        Returns:
            A list of ``aiosqlite.Row`` objects (mapping-like, so ``dict(row)``
            works).

        Raises:
            sqlite3.OperationalError: if *table* does not exist.
        """
        # Identifiers cannot be bound as SQL parameters. Quote the table name,
        # doubling embedded quotes, instead of splicing it in raw, so an
        # arbitrary string cannot inject SQL.
        quoted_table = '"{}"'.format(table.replace('"', '""'))
        query = f"SELECT * FROM {quoted_table} WHERE year = ? AND trim = ? AND concept = ?"
        async with aiosqlite.connect(SQLITE_DB_PATH) as db:
            db.row_factory = aiosqlite.Row
            # `async with` closes the cursor as soon as the rows are fetched.
            async with db.execute(query, (year, quarter, concept)) as cursor:
                return await cursor.fetchall()

    def _generation_config_overwrite(self, overwrites: dict | None) -> dict[str, Any]:
        """Return a copy of ``generation_config`` with *overwrites* merged on top.

        The class-level default dict is never mutated.
        """
        return {**self.generation_config, **(overwrites or {})}

    async def stream(self, history, overwrites: dict | None = None):
        """Stream the model's reply to *history*, yielding text deltas.

        Text content is appended to ``ctx.buffer`` and yielded. Tool-call
        chunks are not yielded. Their id and name are stored in
        ``ctx.tool_id`` / ``ctx.tool_name``, and their argument fragments are
        appended to ``ctx.tool_buffer`` (presumably context variables defined
        in ``api.context`` — verify).
        """
        generation_config = self._generation_config_overwrite(overwrites)
        async for chunk in self.llm.astream(input=history, **generation_config):
            assert isinstance(chunk, AIMessageChunk)
            if call := chunk.tool_call_chunks:
                # NOTE(review): only the first tool-call chunk is read, and all
                # argument fragments share one buffer. Parallel tool calls would
                # therefore be merged into a single argument string.
                first = call[0]
                if tool_id := first.get("id"):
                    ctx.tool_id.set(tool_id)
                if name := first.get("name"):
                    ctx.tool_name.set(name)
                if args := first.get("args"):
                    ctx.tool_buffer.set(ctx.tool_buffer.get() + args)
            elif buffer := chunk.content:
                assert isinstance(buffer, str)
                ctx.buffer.set(ctx.buffer.get() + buffer)
                yield buffer

    async def generate(self, history, overwrites: dict | None = None):
        """Return the model's complete (non-streamed) reply to *history*."""
        generation_config = self._generation_config_overwrite(overwrites)
        return await self.llm.ainvoke(input=history, **generation_config)