from pathlib import Path
from typing import Any

from langchain_core.messages import AIMessageChunk
from pydantic import BaseModel, Field

from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel

from banortegpt.vector.qdrant import AsyncQdrant

from api import context
from api.config import config

parent = Path(__file__).parent

SYSTEM_PROMPT = (parent / "system_prompt.md").read_text()

AZURE_AI_URI = "https://eastus2.api.cognitive.microsoft.com"


class get_information(BaseModel):
    """Search a private repository for information."""

    # NOTE: the class is intentionally snake_case; its name doubles as the
    # tool name exposed to the model via bind_tools and must match the
    # "get_information" key in Agent.tool_map.
    question: str = Field(..., description="The user question")


class Agent:
    system_prompt = SYSTEM_PROMPT
    generation_config = {
        "temperature": config.model_temperature,
    }
    embedding_model = config.embedding_model
    message_limit = config.message_limit
    index = config.vector_index
    limit = config.search_limit

    search = AsyncQdrant.from_config(config)

    llm = AzureAIChatCompletionsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.model}",
        credential=config.openai_api_key,
    ).bind_tools([get_information])

    embedder = AzureAIEmbeddingsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.embedding_model}",
        credential=config.openai_api_key,
    )

    def __init__(self) -> None:
        self.tool_map = {"get_information": self.get_information}

    def build_response(self, payloads, fallback):
        # NOTE: the XML-style tag names below are reconstructed; the original
        # markup was stripped during extraction. The format call already passed
        # index=idx, so the wrapper is assumed to carry the document index.
        template = "<document index={index}>\n{content}\n</document>"
        filled_templates = [
            template.format(index=idx, content=payload["content"])
            for idx, payload in enumerate(payloads)
        ]
        filled_templates.append(f"<fallback>\n{fallback}\n</fallback>")
        return "\n".join(filled_templates)

    async def get_information(self, question: str):
        embedding = await self.embedder.aembed_query(question)
        payloads = await self.search.semantic_search(
            embedding=embedding,
            collection=self.index,
            limit=self.limit,
        )

        fallback_messages = {}
        images = []
        for idx, payload in enumerate(payloads):
            fallback_message = payload.get("fallback_message", "None")
            fallback_messages[fallback_message] = fallback_messages.get(fallback_message, 0) + 1

            # Only extract images from the first (highest-ranked) payload
            if idx == 0 and "images" in payload:
                images.extend(payload["images"])

        # Use the most frequent fallback message across the retrieved payloads
        fallback = max(fallback_messages, key=fallback_messages.get)  # type: ignore

        response = self.build_response(payloads, fallback)
        return str(response), images[:3]  # Cap the result at 3 images

    def _generation_config_overwrite(self, overwrites: dict | None) -> dict[str, Any]:
        if not overwrites:
            return self.generation_config.copy()
        return {**self.generation_config, **overwrites}

    async def stream(self, history, overwrites: dict | None = None):
        generation_config = self._generation_config_overwrite(overwrites)
        async for delta in self.llm.astream(input=history, **generation_config):
            assert isinstance(delta, AIMessageChunk)
            if call := delta.tool_call_chunks:
                # Accumulate streamed tool-call fragments into context vars;
                # id and name arrive once, args arrive incrementally.
                if tool_id := call[0].get("id"):
                    context.tool_id.set(tool_id)
                if name := call[0].get("name"):
                    context.tool_name.set(name)
                if args := call[0].get("args"):
                    context.tool_buffer.set(context.tool_buffer.get() + args)
            elif delta.content:
                assert isinstance(delta.content, str)
                context.buffer.set(context.buffer.get() + delta.content)
                yield delta.content

    async def generate(self, history, overwrites: dict | None = None):
        generation_config = self._generation_config_overwrite(overwrites)
        return await self.llm.ainvoke(input=history, **generation_config)
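
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the production module).
# A minimal example of consuming Agent.stream. Assumptions: api.context
# initializes its ContextVars with empty-string defaults, the deployed model
# accepts `temperature` as a call-time kwarg, and LangChain's (role, content)
# tuple message format is acceptable here. The question text is hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        agent = Agent()
        history = [
            ("system", agent.system_prompt),
            ("human", "What information do you have about account limits?"),
        ]
        # Print content deltas as they stream; tool-call chunks are consumed
        # silently into the context vars rather than yielded.
        async for chunk in agent.stream(history, overwrites={"temperature": 0.0}):
            print(chunk, end="", flush=True)

    asyncio.run(_demo())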