Mayacontigo/apps/ChatEgresos/api/agent/main.py
from pathlib import Path
from typing import Any

from langchain_core.messages import AIMessageChunk
from pydantic import BaseModel, Field
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel
from banortegpt.vector.qdrant import AsyncQdrant

from api import context
from api.config import config

parent = Path(__file__).parent
SYSTEM_PROMPT = (parent / "system_prompt.md").read_text()

AZURE_AI_URI = "https://eastus2.api.cognitive.microsoft.com"


class get_information(BaseModel):
    """Search a private repository for information."""

    question: str = Field(..., description="The user question")


class Agent:
    """Chat agent that answers questions with Azure-hosted models, grounding
    responses in FAQ content retrieved from a Qdrant vector index."""

    system_prompt = SYSTEM_PROMPT
    generation_config = {
        "temperature": config.model_temperature,
    }
    embedding_model = config.embedding_model
    message_limit = config.message_limit
    index = config.vector_index
    limit = config.search_limit

    # Shared clients, configured once at class definition time.
    search = AsyncQdrant.from_config(config)
    llm = AzureAIChatCompletionsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.model}",
        credential=config.openai_api_key,
    ).bind_tools([get_information])
    embedder = AzureAIEmbeddingsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.embedding_model}",
        credential=config.openai_api_key,
    )

    def __init__(self) -> None:
        # Map tool names reported by the model to their implementations.
        self.tool_map = {
            "get_information": self.get_information,
        }

    def build_response(self, payloads, fallback):
        """Wrap each retrieved payload in <FAQ n> tags and append the fallback."""
        template = "<FAQ {index}>\n\n{content}\n\n</FAQ {index}>"
        filled_templates = [
            template.format(index=idx, content=payload["content"])
            for idx, payload in enumerate(payloads)
        ]
        filled_templates.append(f"<FALLBACK>\n{fallback}\n</FALLBACK>")
        return "\n".join(filled_templates)

    async def get_information(self, question: str):
        """Embed the question, search the vector index, and return the built
        context string plus any images attached to the first result."""
        embedding = await self.embedder.aembed_query(question)
        payloads = await self.search.semantic_search(
            embedding=embedding,
            collection=self.index,
            limit=self.limit,
        )
        fallback_messages = {}
        images = []
        for idx, payload in enumerate(payloads):
            fallback_message = payload.get("fallback_message", "None")
            fallback_messages[fallback_message] = fallback_messages.get(fallback_message, 0) + 1
            # Only extract images from the first payload
            if idx == 0 and "images" in payload:
                images.extend(payload["images"])
        # Use the most frequent fallback message across the retrieved payloads
        fallback = max(fallback_messages, key=fallback_messages.get)  # type: ignore
        response = self.build_response(payloads, fallback)
        return str(response), images[:3]  # limit to at most 3 images
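    # Based on the accesses above, each payload returned by semantic_search is
    # expected to be a dict with at least a "content" field, plus optional
    # "fallback_message" and "images" fields.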

    def _generation_config_overwrite(self, overwrites: dict | None) -> dict[str, Any]:
        """Return the class generation config, optionally overridden per call."""
        if not overwrites:
            return self.generation_config.copy()
        return {**self.generation_config, **overwrites}

    async def stream(self, history, overwrites: dict | None = None):
        """Stream the model response, yielding text chunks and accumulating any
        tool-call chunks into the context variables in `api.context`."""
        generation_config = self._generation_config_overwrite(overwrites)
        async for delta in self.llm.astream(input=history, **generation_config):
            assert isinstance(delta, AIMessageChunk)
            if call := delta.tool_call_chunks:
                # Tool-call deltas: capture the tool id/name and buffer the
                # argument fragments as they arrive.
                if tool_id := call[0].get("id"):
                    context.tool_id.set(tool_id)
                if name := call[0].get("name"):
                    context.tool_name.set(name)
                if args := call[0].get("args"):
                    context.tool_buffer.set(context.tool_buffer.get() + args)
            elif delta.content:
                assert isinstance(delta.content, str)
                context.buffer.set(context.buffer.get() + delta.content)
                yield delta.content

    async def generate(self, history, overwrites: dict | None = None):
        """Return a single, non-streamed model response for the given history."""
        generation_config = self._generation_config_overwrite(overwrites)
        return await self.llm.ainvoke(input=history, **generation_config)
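

# A minimal usage sketch (illustrative, not part of the deployed service):
# it assumes `history` is a list of LangChain messages and that the context
# variables in `api.context` have been initialised by the caller. The names
# `agent` and `history` below are hypothetical.
#
#   from langchain_core.messages import HumanMessage, SystemMessage
#
#   agent = Agent()
#   history = [
#       SystemMessage(content=Agent.system_prompt),
#       HumanMessage(content="How do I register an expense?"),
#   ]
#   async for chunk in agent.stream(history):
#       ...  # forward each text chunk to the client
#   # If the model requested the tool instead, the buffered tool name and
#   # arguments in `api.context` can be parsed and dispatched via `agent.tool_map`.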