forked from innovacion/Mayacontigo
ic
This commit is contained in:
181
packages/postgres/banortegpt/database/postgres/crud.py
Normal file
181
packages/postgres/banortegpt/database/postgres/crud.py
Normal file
@@ -0,0 +1,181 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError, NoResultFound
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from .models import Assistant, Conversation, Message, User
|
||||
|
||||
### Users
|
||||
|
||||
|
||||
async def get_all_users(session: AsyncSession) -> Sequence[User]:
|
||||
statement = select(User)
|
||||
scalars = await session.scalars(statement)
|
||||
return scalars.all()
|
||||
|
||||
|
||||
async def get_user(session: AsyncSession, user_id: int) -> User:
|
||||
try:
|
||||
return await session.get_one(User, user_id)
|
||||
except NoResultFound as e:
|
||||
raise ValueError("No user by that id!") from e
|
||||
|
||||
|
||||
async def get_user_by_name(session: AsyncSession, name: str) -> User:
|
||||
try:
|
||||
statement = select(User).where(User.name == name)
|
||||
scalar = await session.scalars(statement)
|
||||
return scalar.one()
|
||||
except NoResultFound as e:
|
||||
raise ValueError("No user by that name!") from e
|
||||
|
||||
|
||||
async def create_user(session: AsyncSession, name: str) -> User:
|
||||
try:
|
||||
new_user = User(name=name)
|
||||
session.add(new_user)
|
||||
await session.commit()
|
||||
await session.refresh(new_user)
|
||||
return new_user
|
||||
except IntegrityError as e:
|
||||
raise ValueError("User by that name already exists!") from e
|
||||
|
||||
|
||||
async def delete_user(session: AsyncSession, name: str):
|
||||
try:
|
||||
statement = select(User).where(User.name == name)
|
||||
db_user: User = (await session.scalars(statement)).one()
|
||||
await session.delete(db_user)
|
||||
await session.commit()
|
||||
except NoResultFound as e:
|
||||
raise ValueError("No assistant by that id exists.") from e
|
||||
|
||||
|
||||
### Assistants
|
||||
|
||||
|
||||
async def get_all_assistants(session: AsyncSession) -> Sequence[Assistant]:
|
||||
statement = select(Assistant)
|
||||
scalars = await session.scalars(statement)
|
||||
return scalars.all()
|
||||
|
||||
|
||||
async def get_assistant(session: AsyncSession, assistant_id: int) -> Assistant:
|
||||
try:
|
||||
statement = select(Assistant).where(Assistant.id == assistant_id)
|
||||
scalars = await session.scalars(statement)
|
||||
return scalars.one()
|
||||
except NoResultFound as e:
|
||||
raise ValueError("No assistant by that id!") from e
|
||||
|
||||
|
||||
async def get_assistant_by_name(session: AsyncSession, name: str) -> Assistant:
|
||||
try:
|
||||
statement = select(Assistant).where(Assistant.name == name)
|
||||
scalars = await session.scalars(statement)
|
||||
return scalars.one()
|
||||
except NoResultFound as e:
|
||||
raise ValueError("No assistant by that name!") from e
|
||||
|
||||
|
||||
async def create_assistant(
|
||||
session: AsyncSession, name: str, system_prompt: str
|
||||
) -> Assistant:
|
||||
try:
|
||||
new_assistant = Assistant(name=name, system_prompt=system_prompt)
|
||||
session.add(new_assistant)
|
||||
await session.commit()
|
||||
await session.refresh(new_assistant)
|
||||
return new_assistant
|
||||
except IntegrityError as e:
|
||||
raise ValueError("Assistant with that name already exists.") from e
|
||||
|
||||
|
||||
async def delete_assistant(session: AsyncSession, assistant_name: str) -> None:
|
||||
try:
|
||||
statement = select(Assistant).where(Assistant.name == assistant_name)
|
||||
db_assistant: Assistant = (await session.scalars(statement)).one()
|
||||
await session.delete(db_assistant)
|
||||
await session.commit()
|
||||
except NoResultFound as e:
|
||||
raise ValueError("No assistant by that name exists.") from e
|
||||
|
||||
|
||||
### Conversations
|
||||
|
||||
|
||||
async def get_conversation(session: AsyncSession, conversation_id: int) -> Conversation:
|
||||
try:
|
||||
conversation = await session.get_one(Conversation, conversation_id)
|
||||
await session.refresh(conversation, ["messages", "id"])
|
||||
return conversation
|
||||
except NoResultFound as e:
|
||||
raise ValueError("No conversation by that id exists.") from e
|
||||
|
||||
|
||||
async def create_conversation(
|
||||
session: AsyncSession, user: str, assistant: str, system_prompt: str | None = None
|
||||
) -> Conversation:
|
||||
try:
|
||||
db_user = await get_user_by_name(session=session, name=user)
|
||||
db_assistant = await get_assistant_by_name(session=session, name=assistant)
|
||||
|
||||
await session.refresh(db_user, ["id"])
|
||||
await session.refresh(db_assistant, ["id"])
|
||||
|
||||
db_conversation = Conversation(user_id=db_user.id, assistant_id=db_assistant.id)
|
||||
|
||||
if system_prompt is not None:
|
||||
db_conversation.add_message(role="system", content=system_prompt)
|
||||
|
||||
session.add(db_conversation)
|
||||
|
||||
await session.commit()
|
||||
|
||||
return db_conversation
|
||||
except ValueError as e:
|
||||
raise ValueError("User or assistant do not exist!") from e
|
||||
|
||||
|
||||
async def delete_conversation(session: AsyncSession, conversation_id: int) -> None:
|
||||
try:
|
||||
db_conversation = await session.get_one(Conversation, conversation_id)
|
||||
await session.delete(db_conversation)
|
||||
await session.commit()
|
||||
except NoResultFound as e:
|
||||
raise ValueError("No conversation by that id exists.") from e
|
||||
|
||||
|
||||
async def soft_delete_conversation(session: AsyncSession, conversation_id: int) -> None:
|
||||
try:
|
||||
db_conversation = await session.get_one(Conversation, conversation_id)
|
||||
await session.refresh(db_conversation, ["active"])
|
||||
db_conversation.active = False
|
||||
await session.commit()
|
||||
except NoResultFound as e:
|
||||
raise ValueError("No conversation by that id exists.") from e
|
||||
|
||||
|
||||
### Messages
|
||||
|
||||
|
||||
async def get_message_by_id(session: AsyncSession, message_id: int) -> Message:
|
||||
return await session.get_one(Message, message_id)
|
||||
|
||||
|
||||
async def update_message_feedback_by_id(
|
||||
session: AsyncSession, message_id: int, rating: bool | None
|
||||
):
|
||||
try:
|
||||
db_message = await session.get_one(Message, message_id)
|
||||
db_message.feedback = rating
|
||||
|
||||
session.add(db_message)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_message)
|
||||
|
||||
return db_message
|
||||
except ValueError as e:
|
||||
raise ValueError("Message does not exist!") from e
|
||||
123
packages/postgres/banortegpt/database/postgres/models.py
Normal file
123
packages/postgres/banortegpt/database/postgres/models.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import ForeignKey, select
|
||||
from sqlalchemy.dialects.postgresql import JSONB, TIMESTAMP
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import (
|
||||
DeclarativeBase,
|
||||
Mapped,
|
||||
MappedAsDataclass,
|
||||
column_property,
|
||||
mapped_column,
|
||||
relationship,
|
||||
)
|
||||
|
||||
|
||||
class Base(DeclarativeBase, MappedAsDataclass, AsyncAttrs):
|
||||
type_annotation_map = {
|
||||
datetime: TIMESTAMP(timezone=True),
|
||||
dict: JSONB,
|
||||
}
|
||||
|
||||
|
||||
class CommonMixin(MappedAsDataclass):
|
||||
id: Mapped[int] = mapped_column(init=False, primary_key=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
init=False, default_factory=lambda: datetime.now(UTC)
|
||||
)
|
||||
updated_at: Mapped[datetime | None] = mapped_column(
|
||||
init=False, default=None, onupdate=lambda: datetime.now(UTC)
|
||||
)
|
||||
active: Mapped[bool] = mapped_column(init=False, default=True)
|
||||
|
||||
|
||||
class User(CommonMixin, Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
name: Mapped[str] = mapped_column(unique=True, index=True)
|
||||
|
||||
conversations: Mapped[list[Conversation]] = relationship(
|
||||
init=False, back_populates="user"
|
||||
)
|
||||
|
||||
|
||||
class Assistant(CommonMixin, Base):
|
||||
__tablename__ = "assistants"
|
||||
|
||||
name: Mapped[str] = mapped_column(unique=True, index=True)
|
||||
system_prompt: Mapped[str]
|
||||
|
||||
conversations: Mapped[list[Conversation]] = relationship(
|
||||
init=False, back_populates="assistant"
|
||||
)
|
||||
|
||||
|
||||
class Conversation(CommonMixin, Base):
|
||||
__tablename__ = "conversations"
|
||||
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"))
|
||||
assistant_id: Mapped[int] = mapped_column(ForeignKey("assistants.id"))
|
||||
|
||||
user: Mapped[User] = relationship(init=False, back_populates="conversations")
|
||||
assistant: Mapped[Assistant] = relationship(
|
||||
init=False, back_populates="conversations"
|
||||
)
|
||||
messages: Mapped[list[Message]] = relationship(
|
||||
init=False, order_by="Message.created_at"
|
||||
)
|
||||
|
||||
assistant_name = column_property(
|
||||
select(Assistant.name).where(Assistant.id == assistant_id).scalar_subquery()
|
||||
)
|
||||
|
||||
def add_message(
|
||||
self,
|
||||
role: str,
|
||||
content: str | None = None,
|
||||
tools: dict | None = None,
|
||||
query_id: int | None = None,
|
||||
):
|
||||
self.messages.append(
|
||||
Message(
|
||||
role=role,
|
||||
content=content,
|
||||
tools=tools,
|
||||
query_id=query_id,
|
||||
conversation_id=self.id,
|
||||
)
|
||||
)
|
||||
|
||||
async def to_openai_format(self):
|
||||
messages = await self.awaitable_attrs.messages
|
||||
return [(await m.to_openai_format()) for m in messages]
|
||||
|
||||
|
||||
class Message(CommonMixin, Base):
|
||||
__tablename__ = "messages"
|
||||
|
||||
conversation_id: Mapped[int] = mapped_column(ForeignKey("conversations.id"))
|
||||
|
||||
role: Mapped[str]
|
||||
content: Mapped[str | None] = mapped_column(default=None)
|
||||
feedback: Mapped[bool | None] = mapped_column(default=None)
|
||||
tools: Mapped[dict | None] = mapped_column(default=None)
|
||||
query_id: Mapped[int | None] = mapped_column(default=None)
|
||||
|
||||
async def to_openai_format(self):
|
||||
role = await self.awaitable_attrs.role
|
||||
content = await self.awaitable_attrs.content
|
||||
tools = await self.awaitable_attrs.tools
|
||||
return {
|
||||
"role": role,
|
||||
"content": content,
|
||||
**(tools or {}),
|
||||
}
|
||||
|
||||
|
||||
class Comment(CommonMixin, Base):
|
||||
__tablename__ = "comments"
|
||||
|
||||
message_id: Mapped[int] = mapped_column(ForeignKey("messages.id"))
|
||||
content: Mapped[str]
|
||||
Reference in New Issue
Block a user