Skip to content

Commit

Permalink
Add /messages/{message_id}/emoji endpoint to toggle, add, remove mess…
Browse files Browse the repository at this point in the history
…age emojis (#925)

* add endpoint to set message emojis

* make refresh result optional in db utils
  • Loading branch information
andreaskoepf authored Jan 25, 2023
1 parent 4146930 commit 558b207
Show file tree
Hide file tree
Showing 10 changed files with 187 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""add message_emoji
Revision ID: 40ed93df0ed5
Revises: 8ba17b5f467a
Create Date: 2023-01-24 22:56:28.229408
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "40ed93df0ed5"
down_revision = "8ba17b5f467a"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"message_emoji",
sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column(
"created_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False
),
sa.Column("emoji", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
sa.ForeignKeyConstraint(["message_id"], ["message.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("message_id", "user_id", "emoji"),
)
op.create_index("ix_message_emoji__user_id__message_id", "message_emoji", ["user_id", "message_id"], unique=False)
op.add_column("message", sa.Column("emojis", postgresql.JSONB(astext_type=sa.Text()), nullable=True))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("message", "emojis")
op.drop_index("ix_message_emoji__user_id__message_id", table_name="message_emoji")
op.drop_table("message_emoji")
# ### end Alembic commands ###
20 changes: 20 additions & 0 deletions backend/oasst_backend/api/v1/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol
from sqlmodel import Session
Expand Down Expand Up @@ -229,3 +230,22 @@ def mark_message_deleted(
):
pr = PromptRepository(db, api_client)
pr.mark_messages_deleted(message_id)


@router.post("/{message_id}/emoji", response_model=protocol.Message)
def post_message_emoji(
*,
message_id: UUID,
request: protocol.MessageEmojiRequest,
api_client: ApiClient = Depends(deps.get_api_client),
) -> protocol.Message:
"""
Toggle, add or remove message emoji.
"""

@managed_tx_function(CommitMode.COMMIT)
def emoji_tx(session: deps.Session):
pr = PromptRepository(session, api_client, client_user=request.user)
return pr.handle_message_emoji(message_id, request.op, request.emoji)

return utils.prepare_message(emoji_tx())
1 change: 1 addition & 0 deletions backend/oasst_backend/api/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def prepare_message(m: Message) -> protocol.Message:
lang=m.lang,
is_assistant=(m.role == "assistant"),
created_date=m.created_date,
emojis=m.emojis,
)


Expand Down
2 changes: 2 additions & 0 deletions backend/oasst_backend/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .journal import Journal, JournalIntegration
from .message import Message
from .message_embedding import MessageEmbedding
from .message_emoji import MessageEmoji
from .message_reaction import MessageReaction
from .message_toxicity import MessageToxicity
from .message_tree_state import MessageTreeState
Expand All @@ -24,4 +25,5 @@
"TextLabels",
"Journal",
"JournalIntegration",
"MessageEmoji",
]
2 changes: 2 additions & 0 deletions backend/oasst_backend/models/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class Message(SQLModel, table=True):

rank: Optional[int] = Field(nullable=True)

emojis: dict[str, int] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False)

def ensure_is_message(self) -> None:
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR)
Expand Down
27 changes: 27 additions & 0 deletions backend/oasst_backend/models/message_emoji.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from datetime import datetime
from typing import Optional
from uuid import UUID

import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, Index, SQLModel


class MessageEmoji(SQLModel, table=True):
__tablename__ = "message_emoji"
__table_args__ = (Index("ix_message_emoji__user_id__message_id", "user_id", "message_id", unique=False),)

message_id: Optional[UUID] = Field(
sa_column=sa.Column(
pg.UUID(as_uuid=True), sa.ForeignKey("message.id", ondelete="CASCADE"), nullable=False, primary_key=True
)
)
user_id: UUID = Field(
sa_column=sa.Column(
pg.UUID(as_uuid=True), sa.ForeignKey("user.id", ondelete="CASCADE"), nullable=False, primary_key=True
)
)
emoji: str = Field(nullable=False, max_length=128, primary_key=True)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
)
61 changes: 61 additions & 0 deletions backend/oasst_backend/prompt_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ApiClient,
Message,
MessageEmbedding,
MessageEmoji,
MessageReaction,
MessageToxicity,
MessageTreeState,
Expand All @@ -29,6 +30,7 @@
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import SystemStats
from oasst_shared.utils import unaware_to_utc
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import Session, and_, func, not_, or_, text, update
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND

Expand Down Expand Up @@ -843,3 +845,62 @@ def get_stats(self) -> SystemStats:
deleted=result.get(True, 0),
message_trees=result.get(None, 0),
)

def handle_message_emoji(self, message_id: UUID, op: protocol_schema.EmojiOp, emoji: protocol_schema) -> Message:
self.ensure_user_is_enabled()

message = self.fetch_message(message_id)

# check if emoji exists
existing_emoji = (
self.db.query(MessageEmoji)
.filter(
MessageEmoji.message_id == message_id, MessageEmoji.user_id == self.user_id, MessageEmoji.emoji == emoji
)
.one_or_none()
)

if existing_emoji:
if op == protocol_schema.EmojiOp.add:
logger.info(f"Emoji record already exists {message_id=}, {emoji=}, {self.user_id=}")
return message
elif op == protocol_schema.EmojiOp.togggle:
op = protocol_schema.EmojiOp.remove

if existing_emoji is None:
if op == protocol_schema.EmojiOp.remove:
logger.info(f"Emoji record not found {message_id=}, {emoji=}, {self.user_id=}")
return message
elif op == protocol_schema.EmojiOp.togggle:
op = protocol_schema.EmojiOp.add

if op == protocol_schema.EmojiOp.add:
# insert emoji record & increment count
message_emoji = MessageEmoji(message_id=message.id, user_id=self.user_id, emoji=emoji)
self.db.add(message_emoji)
emoji_counts = message.emojis
if not emoji_counts:
message.emojis = {emoji.value: 1}
else:
count = emoji_counts.get(emoji.value) or 0
emoji_counts[emoji.value] = count + 1
elif op == protocol_schema.EmojiOp.remove:
# remove emoji record and & decrement count
message = self.fetch_message(message_id)
self.db.delete(existing_emoji)
emoji_counts = message.emojis
count = emoji_counts.get(emoji.value)
if count is not None:
if count == 1:
del emoji_counts[emoji.value]
else:
emoji_counts[emoji.value] = count - 1
flag_modified(message, "emojis")
self.db.add(message)
else:
raise OasstError("Emoji op not supported", OasstErrorCode.EMOJI_OP_UNSUPPORTED)

flag_modified(message, "emojis")
self.db.add(message)
self.db.flush()
return message
3 changes: 3 additions & 0 deletions backend/oasst_backend/utils/database_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def managed_tx_function(
auto_commit: CommitMode = CommitMode.COMMIT,
num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT,
session_factory: Callable[..., Session] = default_session_factor,
refresh_result: bool = True,
):
"""Passes Session object as first argument to wrapped function."""

Expand All @@ -124,6 +125,8 @@ def wrapped_f(*args, **kwargs):
session.flush()
elif auto_commit == CommitMode.ROLLBACK:
session.rollback()
if refresh_result and isinstance(result, SQLModel):
session.refresh(result)
return result
except OperationalError:
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
Expand Down
2 changes: 2 additions & 0 deletions oasst-shared/oasst_shared/exceptions/oasst_api_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class OasstErrorCode(IntEnum):
USER_DISABLED = 4001
USER_NOT_FOUND = 4002

EMOJI_OP_UNSUPPORTED = 5000


class OasstError(Exception):
"""Base class for Open-Assistant exceptions."""
Expand Down
25 changes: 25 additions & 0 deletions oasst-shared/oasst_shared/schemas/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def is_prompter_turn(self) -> bool:
class Message(ConversationMessage):
parent_id: Optional[UUID] = None
created_date: Optional[datetime] = None
emojis: Optional[dict] = None


class MessagePage(PageResult):
Expand Down Expand Up @@ -432,3 +433,27 @@ class OasstErrorResponse(BaseModel):

error_code: OasstErrorCode
message: str


class EmojiCode(str, enum.Enum):
thumbs_up = "+1" # 👍
thumbs_down = "-1" # 👎
red_flag = "red_flag" # 🚩
hundred = "100" # 💯
rofl = "rofl" # 🤣"
heart_eyes = "heart_eyes" # 😍
disappointed = "disappointed" # 😞
poop = "poop" # 💩
skull = "skull" # 💀


class EmojiOp(str, enum.Enum):
togggle = "toggle"
add = "add"
remove = "remove"


class MessageEmojiRequest(BaseModel):
user: User
op: EmojiOp = EmojiOp.togggle
emoji: EmojiCode

0 comments on commit 558b207

Please sign in to comment.