# Forked from innovacion/Mayacontigo
from pathlib import Path
from typing import Any

from langchain_core.messages import AIMessageChunk
from pydantic import BaseModel, Field
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
from langchain_azure_ai.embeddings import AzureAIEmbeddingsModel
from banortegpt.storage.azure_storage import AzureStorage
from banortegpt.vector.qdrant import AsyncQdrant

from api import context
from api.config import config

parent = Path(__file__).parent
SYSTEM_PROMPT = (parent / "system_prompt.md").read_text()

AZURE_AI_URI = "https://eastus2.api.cognitive.microsoft.com"


# Tool schema: the class name is the tool name exposed to the model via bind_tools.
class get_information(BaseModel):
    """Search a private repository for information."""

    question: str = Field(..., description="The user question")


class MayaBursatil:
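    """Maya Bursátil assistant: tool-calling RAG chat over a private FAQ index.

    Class-level attributes wire together an Azure AI chat model (with the
    `get_information` tool bound), an Azure AI embedding model, a Qdrant
    vector index, and Azure storage for shareable image URLs.
    """
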
    system_prompt = SYSTEM_PROMPT
    generation_config = {
        "temperature": config.model_temperature,
    }
    embedding_model = config.embedding_model
    message_limit = config.message_limit
    index = config.vector_index
    limit = config.search_limit
    bucket = config.storage_bucket

    # Shared clients, created once when the class body is evaluated at import time.
    search = AsyncQdrant.from_config(config)
    llm = AzureAIChatCompletionsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.model}",
        credential=config.openai_api_key,
    ).bind_tools([get_information])
    embedder = AzureAIEmbeddingsModel(
        endpoint=f"{AZURE_AI_URI}/openai/deployments/{config.embedding_model}",
        credential=config.openai_api_key,
    )
    storage = AzureStorage.from_config(config)

    def __init__(self) -> None:
        # Maps tool names emitted by the model to the coroutines that execute them.
        self.tool_map = {
            "get_information": self.get_information,
        }

    def build_response(self, payloads, fallback):
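        """Wrap each payload's content in <FAQ i> tags and append a <FALLBACK> block."""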
        template = "<FAQ {index}>\n\n{content}\n\n</FAQ {index}>"

        filled_templates = [
            template.format(index=idx, content=payload["content"])
            for idx, payload in enumerate(payloads)
        ]

        filled_templates.append(f"<FALLBACK>\n{fallback}\n</FALLBACK>")

        return "\n".join(filled_templates)

    async def get_information(self, question: str):
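        """Embed the question, run semantic search, and return the formatted tool response plus raw payloads."""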
        embedding = await self.embedder.aembed_query(question)

        payloads = await self.search.semantic_search(
            embedding=embedding, collection=self.index, limit=self.limit
        )

        # Tally the fallback message attached to each hit and keep the most common one.
        fallback_messages: dict[str, int] = {}

        for payload in payloads:
            fallback_message = payload.get("fallback_message", "None")
            if fallback_message not in fallback_messages:
                fallback_messages[fallback_message] = 1
            else:
                fallback_messages[fallback_message] += 1

        fallback = max(fallback_messages, key=fallback_messages.get)  # type: ignore

        tool_response = self.build_response(payloads, fallback)

        return tool_response, payloads

    async def get_shareable_urls(self, payloads: list):
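        """Resolve storage URLs for images referenced by the payloads; reference_urls is currently always returned empty."""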
        reference_urls = []
        image_urls = []

        for payload in payloads:
            if imagen := payload.get("imagen"):
                image_url = await self.storage.get_file_url(
                    filename=imagen,
                    bucket=self.bucket,
                    minute_duration=20,
                    image=True,
                )

                if image_url:
                    image_urls.append(image_url)
                else:
                    print("Image not found")

        return reference_urls, image_urls

    def _generation_config_overwrite(self, overwrites: dict | None) -> dict[str, Any]:
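        """Return a copy of the class-level generation config with any overrides applied."""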
        generation_config_copy = self.generation_config.copy()
        if overwrites:
            for k, v in overwrites.items():
                generation_config_copy[k] = v
        return generation_config_copy

    async def stream(self, history, overwrites: dict | None = None):
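        """Stream assistant text chunks while mirroring content and tool-call fragments into `api.context`."""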
        generation_config = self._generation_config_overwrite(overwrites)

        async for delta in self.llm.astream(input=history, **generation_config):
            assert isinstance(delta, AIMessageChunk)
            if call := delta.tool_call_chunks:
                # Tool-call chunks arrive incrementally; accumulate the argument fragments.
                if tool_id := call[0].get("id"):
                    context.tool_id.set(tool_id)
                if name := call[0].get("name"):
                    context.tool_name.set(name)
                if args := call[0].get("args"):
                    context.tool_buffer.set(context.tool_buffer.get() + args)
            else:
                if buffer := delta.content:
                    assert isinstance(buffer, str)
                    context.buffer.set(context.buffer.get() + buffer)
                    yield buffer

    async def generate(self, history, overwrites: dict | None = None):
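        """Return a single non-streamed completion for the given history."""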
        generation_config = self._generation_config_overwrite(overwrites)
        return await self.llm.ainvoke(input=history, **generation_config)
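

# Illustrative usage sketch: how a caller might drive MayaBursatil directly,
# assuming `api.config` is populated and the `api.context` contextvars are
# initialised by the surrounding API layer (both are assumptions, not shown here):
#
#     import asyncio
#     from langchain_core.messages import HumanMessage, SystemMessage
#
#     async def _demo() -> None:
#         maya = MayaBursatil()
#         history = [
#             SystemMessage(content=maya.system_prompt),
#             HumanMessage(content="¿Qué es Maya Bursátil?"),
#         ]
#         async for chunk in maya.stream(history):
#             print(chunk, end="", flush=True)
#
#     asyncio.run(_demo())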