# Forked from innovacion/Mayacontigo (134 lines, 6.7 KiB, Python).
import json
import logging
from functools import partial
from pathlib import Path
from typing import Any

import aiosqlite
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel
from langchain_core.messages.ai import AIMessageChunk
from langchain_qdrant import QdrantVectorStore

import api.context as ctx
from api.config import config
from api.prompts import ORCHESTRATOR_PROMPT, TOOL_SCHEMAS

# Module-wide logger, named after this module.
logger = logging.getLogger(name=__name__)

# Regional Azure AI Services endpoint; model deployment paths are appended to it.
AZURE_AI_URI = "https://eastus2.api.cognitive.microsoft.com"

# SQLite database file located in the same directory as this module.
SQLITE_DB_PATH = Path(__file__).with_name("db.sqlite")


class MayaInversionistas:
    """Chat orchestrator that answers with the help of tool calls.

    The Azure AI chat model is bound to ``TOOL_SCHEMAS``. Tool calls are
    dispatched through ``self.tool_map`` to one of two handlers:

    - ``run_sqlite_tool``: looks a value up in a SQLite table.
    - ``run_qdrant_tool``: runs a similarity search in a Qdrant collection.

    NOTE: ``llm``, ``embedder`` and ``search`` are created when the class body
    executes (at import time) and are shared by every instance.
    """

    system_prompt = ORCHESTRATOR_PROMPT

    # Default sampling parameters; per-call overrides are merged in by
    # ``_generation_config_overwrite``.
    generation_config = {
        "temperature": config.model_temperature,
    }

    message_limit = config.message_limit
    index = config.vector_index
    # NOTE(review): ``limit`` is not read by any method of this class; the
    # Qdrant search uses the vector store's default ``k``. Confirm whether it
    # was meant to be passed as ``k`` to the similarity search.
    limit = config.search_limit
    bucket = config.storage_bucket

    llm = AzureAIChatCompletionsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.model}",
        credential=config.openai_api_key,
    ).bind_tools(TOOL_SCHEMAS)

    embedder = AzureAIEmbeddingsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.embedding_model}",
        credential=config.openai_api_key,
    )

    search = QdrantVectorStore.from_existing_collection(
        embedding=embedder,
        collection_name=index,
        url=config.qdrant_url,
        api_key=config.qdrant_api_key,
    )

    # Tool name (as requested by the model) -> SQLite table it reads from.
    # Every entry is served by ``run_sqlite_tool`` with (year, quarter, concept).
    sqlite_tool_tables: dict[str, str] = {
        "getGFNORTEData": "gf_norte",
        "getBanorteConsolidadoData": "banorte_consolidado",
        "getAlmacenadoraConsolidadoData": "almacenadora_consolidado",
        "getArrendadoraFactorConsolidado": "arrendadora_factor_consolidado",
        # NOTE(review): "conosolidado" looks like a typo of "consolidado". It is
        # kept verbatim because it presumably matches the table name in
        # db.sqlite; verify before renaming.
        "getCasadeBolsaConsolidado": "casa_bolsa_conosolidado",
        "getOperadoradeFondos": "op_fondos",
        "getSectorBursatil": "sector_bursatil",
        "getSectorBAPConsolidado": "sector_bap_consolidado",
        "getSeguros": "seguros",
        "getPensiones": "pensiones",
        "getBineo": "bineo",
        "getSectorBanca": "sector_banca",
        "getHolding": "holding",
        "getBanorteFinancialServices": "banorte_financial_services",
        "getFideicomisoBursaGEM": "fideicomiso_bursa_gem",
        "getTarjetasdelFuturo": "tarjetas_del_futuro",
        "getAfore": "afore",
        "getBanorteFuturo": "banorte_futuro",
        "getSegurosSinBanorteFuturo": "seguros_sin_banorte_futuro",
    }

    def __init__(self) -> None:
        """Build ``tool_map``: tool name -> async callable returning the tool's text."""
        # ``partial`` freezes the table per entry; a lambda created inside the
        # comprehension would late-bind ``table`` to the last value.
        self.tool_map: dict[str, Any] = {
            name: partial(self.run_sqlite_tool, table=table)
            for name, table in self.sqlite_tool_tables.items()
        }
        self.tool_map["getInformationalData"] = self.run_qdrant_tool

    @staticmethod
    def build_response(results: list[dict]) -> str:
        """Format tool results as the text message handed back to the model.

        The results are serialized with ``json.dumps`` and placed between a
        fixed English preamble and a fixed Spanish footer. The footer links to
        Banorte's quarterly-reports page.
        """
        return (
            "I have retrieved the following results from the database:\n"
            + json.dumps(results)
            + "\nPara mayor información consultar el Reporte de Resultados Trimestral (URL: https://investors.banorte.com/es/financial-information/quarterly-reports)"
        )

    async def run_sqlite_tool(
        self, year: int, quarter: int, concept: str, table: str
    ) -> str:
        """Query ``table`` for (year, quarter, concept) and format the rows.

        Each row is converted to a plain ``dict`` before serialization.
        """
        rows = await self.get_data_from_sqlite(year, quarter, concept, table)
        return self.build_response([dict(row) for row in rows])

    async def run_qdrant_tool(self, question: str) -> str:
        """Run a similarity search for ``question`` and format the hits' metadata."""
        logger.info(
            "Embedding question: %s with model %s",
            question,
            self.embedder.model_name,
        )
        # ``similarity_search`` is synchronous (embedding call + Qdrant request)
        # and would block the event loop inside this coroutine. The async
        # variant is awaited instead; it uses the same default ``k``.
        documents = await self.search.asimilarity_search(question)
        return self.build_response([dict(doc.metadata) for doc in documents])

    @staticmethod
    async def get_data_from_sqlite(year: int, quarter: int, concept: str, table: str):
        """Return all rows of ``table`` matching year, quarter (``trim``) and concept.

        Rows are ``aiosqlite.Row`` objects, so they can be converted with
        ``dict(row)``.
        """
        # Identifiers cannot be bound as SQL parameters, so the table name is
        # interpolated. Quoting it (and doubling embedded quotes) guarantees it
        # is parsed as a single identifier rather than arbitrary SQL.
        quoted_table = '"' + table.replace('"', '""') + '"'
        query = f"""
            SELECT * FROM {quoted_table}
            WHERE year = ? AND trim = ? AND concept = ?
        """

        async with aiosqlite.connect(SQLITE_DB_PATH) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute(query, (year, quarter, concept))
            return await cursor.fetchall()

    def _generation_config_overwrite(self, overwrites: dict | None) -> dict[str, Any]:
        """Return a new dict of ``generation_config`` with ``overwrites`` applied on top.

        The class-level ``generation_config`` is never mutated. ``None`` or an
        empty dict yields an unmodified copy.
        """
        return {**self.generation_config, **(overwrites or {})}

    async def stream(self, history, overwrites: dict | None = None):
        """Stream the model's reply to ``history``, yielding text chunks.

        Text chunks are yielded and also appended to ``ctx.buffer``.

        Tool-call chunks are not yielded. Instead, their id and name (when
        present) are stored in ``ctx.tool_id`` / ``ctx.tool_name``, and their
        argument fragments are concatenated into ``ctx.tool_buffer``. This is
        presumably so the caller can dispatch the tool after streaming; confirm
        with the caller.
        """
        generation_config = self._generation_config_overwrite(overwrites)

        async for chunk in self.llm.astream(input=history, **generation_config):
            assert isinstance(chunk, AIMessageChunk)
            if call := chunk.tool_call_chunks:
                # Only the first entry of ``tool_call_chunks`` is inspected.
                if tool_id := call[0].get("id"):
                    ctx.tool_id.set(tool_id)
                if name := call[0].get("name"):
                    ctx.tool_name.set(name)
                if args := call[0].get("args"):
                    ctx.tool_buffer.set(ctx.tool_buffer.get() + args)
            else:
                if buffer := chunk.content:
                    assert isinstance(buffer, str)
                    ctx.buffer.set(ctx.buffer.get() + buffer)
                    yield buffer

    async def generate(self, history, overwrites: dict | None = None):
        """Return the model's full (non-streamed) reply to ``history``."""
        generation_config = self._generation_config_overwrite(overwrites)
        return await self.llm.ainvoke(input=history, **generation_config)