Skip to content

Commit

Permalink
968 add flagged message table and endpoints (#1325)
Browse files Browse the repository at this point in the history
* Added flagged message table

* Added alembic migration and updated imports to match style

* Added GET endpoint to query all flagged messages

* Updates from linter

* Added POST endpoint for processing flagged messages

* Added pydantic interface model and fixed limit update bug

* fixed session in admin endpoint and added require session refresh for returned update

* removed unused import
  • Loading branch information
GraemeHarris authored Feb 7, 2023
1 parent 1153734 commit 66b7ed2
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Added new table for flagged messages
Revision ID: caee1e8ee0bc
Revises: 8c8241d1f973
Create Date: 2023-02-07 19:22:12.696257
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "caee1e8ee0bc"
down_revision = "8c8241d1f973"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"flagged_message",
sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column(
"created_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False
),
sa.Column("processed", sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(["message_id"], ["message.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("message_id"),
)
op.create_index(op.f("ix_flagged_message_created_date"), "flagged_message", ["created_date"], unique=False)
op.create_index(op.f("ix_flagged_message_processed"), "flagged_message", ["processed"], unique=False)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_flagged_message_processed"), table_name="flagged_message")
op.drop_index(op.f("ix_flagged_message_created_date"), table_name="flagged_message")
op.drop_table("flagged_message")
# ### end Alembic commands ###
35 changes: 35 additions & 0 deletions backend/oasst_backend/api/v1/admin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Optional
from uuid import UUID

import pydantic
Expand Down Expand Up @@ -162,3 +163,37 @@ def purge_user_messages_tx(session: deps.Session):

logger.info(f"{before=}; {after=}")
return PurgeResultModel(before=before, after=after, preview=preview, duration=timer.elapsed)


class FlaggedMessageResponse(pydantic.BaseModel):
message_id: UUID
processed: bool
created_date: Optional[datetime]


@router.get("/flagged_messages", response_model=list[FlaggedMessageResponse])
async def get_flagged_messages(
max_count: Optional[int],
session: deps.Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
) -> str:
assert api_client.trusted

pr = PromptRepository(session, api_client)
flagged_messages = pr.fetch_flagged_messages(max_count=max_count)
resp = [FlaggedMessageResponse(**msg.__dict__) for msg in flagged_messages]
return resp


@router.post("/admin/flagged_messages/{message_id}/processed", response_model=FlaggedMessageResponse)
async def process_flagged_messages(
message_id: UUID,
session: deps.Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
) -> str:
assert api_client.trusted

pr = PromptRepository(session, api_client)
flagged_msg = pr.process_flagged_message(message_id=message_id)
resp = FlaggedMessageResponse(**flagged_msg.__dict__)
return resp
2 changes: 2 additions & 0 deletions backend/oasst_backend/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .api_client import ApiClient
from .flagged_message import FlaggedMessage
from .journal import Journal, JournalIntegration
from .message import Message
from .message_embedding import MessageEmbedding
Expand Down Expand Up @@ -28,4 +29,5 @@
"JournalIntegration",
"MessageEmoji",
"TrollStats",
"FlaggedMessage",
]
23 changes: 23 additions & 0 deletions backend/oasst_backend/models/flagged_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from datetime import datetime
from typing import Optional
from uuid import UUID

import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, SQLModel


class FlaggedMessage(SQLModel, table=True):
__tablename__ = "flagged_message"

message_id: Optional[UUID] = Field(
sa_column=sa.Column(
pg.UUID(as_uuid=True), sa.ForeignKey("message.id", ondelete="CASCADE"), nullable=False, primary_key=True
)
)
processed: bool = Field(nullable=False, index=True)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(
sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp(), index=True
)
)
31 changes: 31 additions & 0 deletions backend/oasst_backend/prompt_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
from uuid import UUID, uuid4

import oasst_backend.models.db_payload as db_payload
import sqlalchemy.dialects.postgresql as pg
from loguru import logger
from oasst_backend.api.deps import FrontendUserId
from oasst_backend.config import settings
from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import (
ApiClient,
FlaggedMessage,
Message,
MessageEmbedding,
MessageEmoji,
Expand Down Expand Up @@ -1092,6 +1094,15 @@ def handle_message_emoji(
logger.debug(f"Ignoring add emoji op for user's own message ({emoji=})")
return message

# Add to flagged_message table if the red flag emoji is applied
if emoji == protocol_schema.EmojiCode.red_flag:
flagged_message = FlaggedMessage(
message_id=message_id, processed=False, created_date=datetime.now().astimezone()
)
insert_stmt = pg.insert(FlaggedMessage).values(**flagged_message.__dict__)
upsert_stmt = insert_stmt.on_conflict_do_update(constraint="message_id", set_=flagged_message.__dict__)
self.db.execute(upsert_stmt)

# insert emoji record & increment count
message_emoji = MessageEmoji(message_id=message.id, user_id=self.user_id, emoji=emoji)
self.db.add(message_emoji)
Expand Down Expand Up @@ -1127,3 +1138,23 @@ def handle_message_emoji(
self.db.add(message)
self.db.flush()
return message

def fetch_flagged_messages(self, max_count: Optional[int]) -> list[FlaggedMessage]:
qry = self.db.query(FlaggedMessage)
if max_count is not None:
qry = qry.limit(max_count)

return qry.all()

def process_flagged_message(self, message_id: UUID) -> FlaggedMessage:

message = self.db.query(FlaggedMessage).get(message_id)

if not message:
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTPStatus.NOT_FOUND)

message.processed = True
self.db.commit()
self.db.refresh(message)

return message

0 comments on commit 66b7ed2

Please sign in to comment.