forked from innovacion/Mayacontigo
ic
This commit is contained in:
133
apps/inversionistas/api/agent.py
Normal file
133
apps/inversionistas/api/agent.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
|
||||
from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel
|
||||
from langchain_core.messages.ai import AIMessageChunk
|
||||
from langchain_qdrant import QdrantVectorStore
|
||||
|
||||
import api.context as ctx
|
||||
from api.config import config
|
||||
from api.prompts import ORCHESTRATOR_PROMPT, TOOL_SCHEMAS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AZURE_AI_URI = "https://eastus2.api.cognitive.microsoft.com"
|
||||
SQLITE_DB_PATH = Path(__file__).parent / "db.sqlite"
|
||||
|
||||
|
||||
class MayaInversionistas:
|
||||
system_prompt = ORCHESTRATOR_PROMPT
|
||||
generation_config = {
|
||||
"temperature": config.model_temperature,
|
||||
}
|
||||
message_limit = config.message_limit
|
||||
index = config.vector_index
|
||||
limit = config.search_limit
|
||||
bucket = config.storage_bucket
|
||||
|
||||
llm = AzureAIChatCompletionsModel(
|
||||
endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.model}",
|
||||
credential=config.openai_api_key,
|
||||
).bind_tools(TOOL_SCHEMAS)
|
||||
embedder = AzureAIEmbeddingsModel(
|
||||
endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.embedding_model}",
|
||||
credential=config.openai_api_key,
|
||||
)
|
||||
search = QdrantVectorStore.from_existing_collection(
|
||||
embedding=embedder,
|
||||
collection_name=index,
|
||||
url=config.qdrant_url,
|
||||
api_key=config.qdrant_api_key,
|
||||
)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.tool_map = {
|
||||
"getGFNORTEData": lambda year, quarter, concept: self.run_sqlite_tool(year, quarter, concept, "gf_norte"),
|
||||
"getBanorteConsolidadoData": lambda year, quarter, concept: self.run_sqlite_tool(year, quarter, concept, "banorte_consolidado"),
|
||||
"getAlmacenadoraConsolidadoData": lambda year, quarter, concept: self.run_sqlite_tool(year, quarter, concept, "almacenadora_consolidado"),
|
||||
"getArrendadoraFactorConsolidado": lambda year, quarter, concept: self.run_sqlite_tool(year, quarter, concept, "arrendadora_factor_consolidado"),
|
||||
"getCasadeBolsaConsolidado": lambda year, quarter, concept: self.run_sqlite_tool(year, quarter, concept, "casa_bolsa_conosolidado"),
|
||||
"getOperadoradeFondos": lambda year, quarter, concept: self.run_sqlite_tool(year, quarter, concept, "op_fondos"),
|
||||
"getSectorBursatil": lambda year, quarter, concept: self.run_sqlite_tool(year, quarter, concept, "sector_bursatil"),
|
||||
"getSectorBAPConsolidado": lambda year, quarter, concept: self.run_sqlite_tool(year, quarter, concept, "sector_bap_consolidado"),
|
||||
"getSeguros": lambda year, quarter, concept: self.run_sqlite_tool(year, quarter, concept, "seguros"),
|
||||
"getPensiones": lambda year, quarter, concept: self.run_sqlite_tool(year, quarter, concept, "pensiones"),
|
||||
"getBineo": lambda year, quarter, concept: self.run_sqlite_tool(year, quarter, concept, "bineo"),
|
||||
"getSectorBanca": lambda year, quarter, concept: self.run_sqlite_tool(year, quarter, concept, "sector_banca"),
|
||||
"getHolding": lambda year, quarter, concept: self.run_sqlite_tool(year, quarter, concept, "holding"),
|
||||
"getBanorteFinancialServices": lambda year, quarter, concept: self.run_sqlite_tool(year, quarter, concept, "banorte_financial_services"),
|
||||
"getFideicomisoBursaGEM": lambda year, quarter, concept: self.run_sqlite_tool(year, quarter, concept, "fideicomiso_bursa_gem"),
|
||||
"getTarjetasdelFuturo": lambda year, quarter, concept: self.run_sqlite_tool(year, quarter, concept, "tarjetas_del_futuro"),
|
||||
"getAfore": lambda year, quarter, concept: self.run_sqlite_tool(year, quarter, concept, "afore"),
|
||||
"getBanorteFuturo": lambda year, quarter, concept: self.run_sqlite_tool(year, quarter, concept, "banorte_futuro"),
|
||||
"getSegurosSinBanorteFuturo": lambda year, quarter, concept: self.run_sqlite_tool(year, quarter, concept, "seguros_sin_banorte_futuro"),
|
||||
"getInformationalData": self.run_qdrant_tool,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def build_response(results: list[dict]) -> str:
|
||||
return (
|
||||
"I have retrieved the following results from the database:\n"
|
||||
+ json.dumps(results)
|
||||
+ "\nPara mayor información consultar el Reporte de Resultados Trimestral (URL: https://investors.banorte.com/es/financial-information/quarterly-reports)"
|
||||
)
|
||||
|
||||
async def run_sqlite_tool(self, year: int, quarter: int, concept: str, table: str):
|
||||
results = await self.get_data_from_sqlite(year, quarter, concept, table)
|
||||
data = [dict(row) for row in results]
|
||||
return self.build_response(data)
|
||||
|
||||
async def run_qdrant_tool(self, question: str):
|
||||
logger.info(
|
||||
f"Embedding question: {question} with model {self.embedder.model_name}"
|
||||
)
|
||||
results = self.search.similarity_search(question)
|
||||
data = [dict(row.metadata) for row in results]
|
||||
tool_response = self.build_response(data)
|
||||
return tool_response
|
||||
|
||||
@staticmethod
|
||||
async def get_data_from_sqlite(year: int, quarter: int, concept: str, table: str):
|
||||
async with aiosqlite.connect(SQLITE_DB_PATH) as db:
|
||||
query = """
|
||||
SELECT * FROM {}
|
||||
WHERE year = ? AND trim = ? AND concept = ?
|
||||
""".format(table)
|
||||
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(query, (year, quarter, concept))
|
||||
rows = await cursor.fetchall()
|
||||
return rows
|
||||
|
||||
def _generation_config_overwrite(self, overwrites: dict | None) -> dict[str, Any]:
|
||||
generation_config_copy = self.generation_config.copy()
|
||||
if overwrites:
|
||||
for k, v in overwrites.items():
|
||||
generation_config_copy[k] = v
|
||||
return generation_config_copy
|
||||
|
||||
async def stream(self, history, overwrites: dict | None = None):
|
||||
generation_config = self._generation_config_overwrite(overwrites)
|
||||
|
||||
async for chunk in self.llm.astream(input=history, **generation_config):
|
||||
assert isinstance(chunk, AIMessageChunk)
|
||||
if call := chunk.tool_call_chunks:
|
||||
if tool_id := call[0].get("id"):
|
||||
ctx.tool_id.set(tool_id)
|
||||
if name := call[0].get("name"):
|
||||
ctx.tool_name.set(name)
|
||||
if args := call[0].get("args"):
|
||||
ctx.tool_buffer.set(ctx.tool_buffer.get() + args)
|
||||
else:
|
||||
if buffer := chunk.content:
|
||||
assert isinstance(buffer, str)
|
||||
ctx.buffer.set(ctx.buffer.get() + buffer)
|
||||
yield buffer
|
||||
|
||||
async def generate(self, history, overwrites: dict | None = None):
|
||||
generation_config = self._generation_config_overwrite(overwrites)
|
||||
return await self.llm.ainvoke(input=history, **generation_config)
|
||||
Reference in New Issue
Block a user