# forked from innovacion/Mayacontigo
import json
from enum import StrEnum
from typing import TypeAlias
from uuid import UUID

from banortegpt.database.mongo_memory import crud
from langfuse.decorators import langfuse_context, observe
from pydantic import BaseModel

from api import context as ctx
from api.agent import MayaRiesgos
class ChunkType(StrEnum):
    """Discriminator for the kind of payload carried by a ResponseChunk.

    Values are plain strings (StrEnum), so they serialize as-is in the
    streamed JSON protocol consumed by the client.
    """

    START = "start"          # first chunk of a stream; empty content
    TEXT = "text"            # incremental model text
    REFERENCE = "reference"  # list of shareable reference URLs
    IMAGE = "image"          # list of shareable image URLs
    TOOL = "tool"            # marker emitted before a tool call is executed
    END = "end"              # last chunk; content is the langfuse trace id
    ERROR = "error"          # NOTE(review): declared but not emitted in this file —
                             # presumably used by the caller's error handling; confirm
# Scalar payload type a chunk may carry; chunks may also hold a list of these
# or None (see ResponseChunk.content).
ContentType: TypeAlias = str | int
class ResponseChunk(BaseModel):
    """One unit of the server-sent response stream.

    `type` tells the client how to interpret `content`:
    TEXT carries a scalar, REFERENCE/IMAGE carry lists of URLs,
    TOOL carries None, END carries the trace id.
    """

    type: ChunkType                                  # discriminator for the payload
    content: ContentType | list[ContentType] | None  # payload; shape depends on `type`
@observe(capture_input=False, capture_output=False)
async def stream(agent: MayaRiesgos, prompt: str, conversation_id: UUID):
    """Stream the agent's answer to *prompt* as a sequence of ResponseChunk.

    Protocol emitted, in order: START, then TEXT chunks; if the model
    requested a tool call, a TOOL marker followed by a second round of TEXT
    chunks and optional REFERENCE/IMAGE chunks; finally END carrying the
    langfuse trace id.

    Side effects: appends user/assistant/tool messages to the conversation,
    persists it, and updates the current langfuse trace.

    Raises:
        ValueError: if no conversation exists for *conversation_id*.
        AssertionError: if tool contextvars are inconsistent (see below).
    """
    yield ResponseChunk(type=ChunkType.START, content="")

    conversation = await crud.get_conversation(conversation_id)

    if conversation is None:
        raise ValueError(f"Conversation with id {conversation_id} not found")

    conversation.add(role="user", content=prompt)

    # First pass: let the model answer (or request a tool) given the history.
    history = conversation.to_openai_format(agent.message_limit, langchain_compat=True)
    async for content in agent.stream(history):
        yield ResponseChunk(type=ChunkType.TEXT, content=content)

    # agent.stream communicates a pending tool call via contextvars in `ctx`:
    # tool_id being set is the signal that a call was requested.
    if (tool_id := ctx.tool_id.get()) is not None:
        tool_buffer = ctx.tool_buffer.get()
        # NOTE(review): asserts are stripped under `python -O`; these guard an
        # internal invariant (buffer/name are always set alongside tool_id).
        assert tool_buffer is not None

        tool_name = ctx.tool_name.get()
        assert tool_name is not None

        # Tell the client a tool is about to run (no payload).
        yield ResponseChunk(type=ChunkType.TOOL, content=None)

        # tool_buffer is the JSON-encoded arguments accumulated during streaming.
        buffer_dict = json.loads(tool_buffer)

        response, payloads = await agent.tool_map[tool_name](**buffer_dict)

        # Record the tool call and its result in OpenAI message format so the
        # follow-up completion can see them.
        conversation.add(
            role="assistant",
            tool_calls=[
                {
                    "id": tool_id,
                    "function": {
                        "name": tool_name,
                        "arguments": tool_buffer,
                    },
                    "type": "function",
                }
            ],
        )
        conversation.add(role="tool", content=response, tool_call_id=tool_id)

        # Second pass: stream the model's answer grounded on the tool result;
        # `{"tools": None}` presumably disables further tool calls — confirm
        # against agent.stream's signature.
        history = conversation.to_openai_format(agent.message_limit, langchain_compat=True)
        async for content in agent.stream(history, {"tools": None}):
            yield ResponseChunk(type=ChunkType.TEXT, content=content)

        # Turn tool payloads into client-visible URLs.
        ref_urls, image_urls = await agent.get_shareable_urls(payloads)  # type: ignore

        if len(ref_urls) > 0:
            yield ResponseChunk(type=ChunkType.REFERENCE, content=ref_urls)

        if len(image_urls) > 0:
            yield ResponseChunk(type=ChunkType.IMAGE, content=image_urls)

    # Full assistant text accumulated by agent.stream in a contextvar.
    buffer = ctx.buffer.get()

    conversation.add(role="assistant", content=buffer)

    await conversation.save()

    # capture_input/capture_output were disabled on @observe, so record the
    # trace fields explicitly here.
    langfuse_context.update_current_trace(
        name=agent.__class__.__name__,
        session_id=str(conversation_id),
        input=prompt,
        output=buffer,
    )

    # END carries the trace id so the client can correlate feedback/telemetry.
    yield ResponseChunk(
        type=ChunkType.END, content=langfuse_context.get_current_trace_id()
    )