Skip to content

Commit ad6c39b

Browse files
authored
Auto-moderation & remaining v1 fixes (#1089)
* add expiry date for tasks and periodic removal, fix purge user messages sibling ranking counts * add auto-moderation feature * fix doc strings * fix bad message query * add debug log on insert message * fix >= comparison
1 parent 323432c commit ad6c39b

File tree

5 files changed

+145
-6
lines changed

5 files changed

+145
-6
lines changed

backend/main.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
from oasst_backend.config import settings
1919
from oasst_backend.database import engine
2020
from oasst_backend.models import message_tree_state
21-
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
21+
from oasst_backend.prompt_repository import PromptRepository, UserRepository
22+
from oasst_backend.task_repository import TaskRepository, delete_expired_tasks
2223
from oasst_backend.tree_manager import TreeManager
2324
from oasst_backend.user_repository import User
2425
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
@@ -318,6 +319,13 @@ def update_user_streak(session: Session) -> None:
318319
return
319320

320321

322+
@app.on_event("startup")
323+
@repeat_every(seconds=60 * 60) # 1 hour
324+
@managed_tx_function(auto_commit=CommitMode.COMMIT)
325+
def cronjob_delete_expired_tasks(session: Session) -> None:
326+
delete_expired_tasks(session)
327+
328+
321329
app.include_router(api_router, prefix=settings.API_V1_STR)
322330

323331

backend/oasst_backend/config.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,29 @@ class TreeManagerConfiguration(BaseModel):
2525
goal_tree_size: int = 12
2626
"""Total number of messages to gather per tree."""
2727

28+
random_goal_tree_size: bool = False
29+
"""If set to true goal tree sizes will be generated randomly within range [min_goal_tree_size, goal_tree_size]."""
30+
31+
min_goal_tree_size: int = 5
32+
"""Minimum tree size for random goal sizes."""
33+
2834
num_reviews_initial_prompt: int = 3
2935
"""Number of peer review checks to collect in INITIAL_PROMPT_REVIEW state."""
3036

3137
num_reviews_reply: int = 3
3238
"""Number of peer review checks to collect per reply (other than initial_prompt)."""
3339

40+
auto_mod_enabled: bool = True
41+
"""Flag to enable/disable auto moderation."""
42+
43+
auto_mod_max_skip_reply: int = 25
44+
"""Automatically set tree state to `halted_by_moderator` when more than the specified number
45+
of users skip replying to a message. (auto moderation)"""
46+
47+
auto_mod_red_flags: int = 3
48+
"""Delete messages that receive more than this number of red flags if it is a reply or
49+
set the tree to `aborted_low_grade` when a prompt is flagged. (auto moderation)"""
50+
3451
p_full_labeling_review_prompt: float = 1.0
3552
"""Probability of full text-labeling (instead of mandatory only) for initial prompts."""
3653

@@ -222,6 +239,8 @@ def validate_user_stats_intervals(cls, v: int):
222239
RATE_LIMIT_TASK_API_TIMES: int = 10_000
223240
RATE_LIMIT_TASK_API_MINUTES: int = 1
224241

242+
TASK_VALIDITY_MINUTES: int = 60 * 24 * 2 # tasks expire after 2 days
243+
225244
class Config:
226245
env_file = ".env"
227246
env_file_encoding = "utf-8"

backend/oasst_backend/prompt_repository.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,6 @@ def insert_message(
155155
review_result=review_result,
156156
)
157157
self.db.add(message)
158-
159-
# self.db.refresh(message)
160158
return message
161159

162160
def _validate_task(
@@ -288,6 +286,10 @@ def store_text_reply(
288286
task.done = True
289287
self.db.add(task)
290288
self.journal.log_text_reply(task=task, message_id=new_message_id, role=role, length=len(text))
289+
logger.debug(
290+
f"Inserted message id={user_message.id}, tree={user_message.message_tree_id}, user_id={user_message.user_id}, "
291+
f"text[:100]='{user_message.text[:100]}', role='{user_message.role}', lang='{user_message.lang}'"
292+
)
291293
return user_message
292294

293295
@managed_tx_method(CommitMode.FLUSH)

backend/oasst_backend/task_repository.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1-
from datetime import timedelta
1+
from datetime import datetime, timedelta
22
from typing import Optional
33
from uuid import UUID
44

55
import oasst_backend.models.db_payload as db_payload
66
from loguru import logger
7+
from oasst_backend.config import settings
78
from oasst_backend.models import ApiClient, Task
89
from oasst_backend.models.payload_column_type import PayloadContainer
910
from oasst_backend.user_repository import UserRepository
1011
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
1112
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
1213
from oasst_shared.schemas import protocol as protocol_schema
13-
from sqlmodel import Session, func, or_
14+
from oasst_shared.utils import utcnow
15+
from sqlmodel import Session, delete, func, or_
1416
from starlette.status import HTTP_404_NOT_FOUND
1517

1618

@@ -24,6 +26,13 @@ def validate_frontend_message_id(message_id: str) -> None:
2426
raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
2527

2628

29+
def delete_expired_tasks(session: Session) -> int:
30+
stm = delete(Task).where(Task.expiry_date < utcnow())
31+
result = session.exec(stm)
32+
logger.info(f"Deleted {result.rowcount} expired tasks.")
33+
return result.rowcount
34+
35+
2736
class TaskRepository:
2837
def __init__(
2938
self,
@@ -118,12 +127,18 @@ def store_task(
118127
case _:
119128
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
120129

130+
if not collective and settings.TASK_VALIDITY_MINUTES > 0:
131+
expiry_date = utcnow() + timedelta(minutes=settings.TASK_VALIDITY_MINUTES)
132+
else:
133+
expiry_date = None
134+
121135
task_model = self.insert_task(
122136
payload=payload,
123137
id=task.id,
124138
message_tree_id=message_tree_id,
125139
parent_message_id=parent_message_id,
126140
collective=collective,
141+
expiry_date=expiry_date,
127142
)
128143
assert task_model.id == task.id
129144
return task_model
@@ -175,6 +190,7 @@ def insert_task(
175190
message_tree_id: UUID = None,
176191
parent_message_id: UUID = None,
177192
collective: bool = False,
193+
expiry_date: datetime = None,
178194
) -> Task:
179195
c = PayloadContainer(payload=payload)
180196
task = Task(
@@ -186,6 +202,7 @@ def insert_task(
186202
message_tree_id=message_tree_id,
187203
parent_message_id=parent_message_id,
188204
collective=collective,
205+
expiry_date=expiry_date,
189206
)
190207
logger.debug(f"inserting {task=}")
191208
self.db.add(task)
@@ -218,3 +235,6 @@ def fetch_recent_reply_tasks(
218235
if limit:
219236
qry = qry.limit(limit)
220237
return qry.all()
238+
239+
def delete_expired_tasks(self) -> int:
240+
return delete_expired_tasks(self.db)

backend/oasst_backend/tree_manager.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import numpy as np
1111
import pydantic
12+
import sqlalchemy as sa
1213
from fastapi.encoders import jsonable_encoder
1314
from loguru import logger
1415
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
@@ -31,6 +32,7 @@
3132
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
3233
from oasst_shared.schemas import protocol as protocol_schema
3334
from oasst_shared.utils import utcnow
35+
from sqlalchemy.sql.functions import coalesce
3436
from sqlmodel import Session, and_, func, not_, or_, text, update
3537

3638

@@ -269,13 +271,39 @@ def _prompt_lottery(self, lang: str) -> int:
269271
self._enter_state(mts, message_tree_state.State.GROWING)
270272
self.db.flush()
271273

274+
def _auto_moderation(self, lang: str) -> None:
275+
if not self.cfg.auto_mod_enabled:
276+
return
277+
278+
bad_messages = self.query_moderation_bad_messages(lang=lang)
279+
for m in bad_messages:
280+
num_red_flag = m.emojis.get(protocol_schema.EmojiCode.red_flag)
281+
282+
if num_red_flag is not None and num_red_flag >= self.cfg.auto_mod_red_flags:
283+
if m.parent_id is None:
284+
logger.warning(
285+
f"[AUTO MOD] Halting tree {m.message_tree_id}, inital prompt got too many red flags ({m.emojis})."
286+
)
287+
self.enter_low_grade_state(m.message_tree_id)
288+
else:
289+
logger.warning(f"[AUTO MOD] Deleting message {m.id=}, it received too many red flags ({m.emojis}).")
290+
self.pr.mark_messages_deleted(m.id, recursive=True)
291+
292+
num_skip_reply = m.emojis.get(protocol_schema.EmojiCode.skip_reply)
293+
if num_skip_reply is not None and num_skip_reply >= self.cfg.auto_mod_max_skip_reply:
294+
logger.warning(
295+
f"[AUTO MOD] Halting tree {m.message_tree_id} due to high skip-reply count of message {m.id=} ({m.emojis})."
296+
)
297+
self.halt_tree(m.id, halt=True)
298+
272299
def determine_task_availability(self, lang: str) -> dict[protocol_schema.TaskRequestType, int]:
273300
self.pr.ensure_user_is_enabled()
274301

275302
if not lang:
276303
lang = "en"
277304
logger.warning("Task availability request without lang tag received, assuming lang='en'.")
278305

306+
self._auto_moderation(lang=lang)
279307
num_missing_prompts = self._prompt_lottery(lang=lang)
280308
extendible_parents, _ = self.query_extendible_parents(lang=lang)
281309
prompts_need_review = self.query_prompts_need_review(lang=lang)
@@ -313,6 +341,7 @@ def next_task(
313341
lang = "en"
314342
logger.warning("Task request without lang tag received, assuming 'en'.")
315343

344+
self._auto_moderation(lang=lang)
316345
num_missing_prompts = self._prompt_lottery(lang=lang)
317346

318347
prompts_need_review = self.query_prompts_need_review(lang=lang)
@@ -1254,6 +1283,37 @@ def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]:
12541283
)
12551284
return qry.all()
12561285

1286+
def query_moderation_bad_messages(self, lang: str) -> list[Message]:
1287+
qry = (
1288+
self.db.query(Message)
1289+
.select_from(MessageTreeState)
1290+
.join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
1291+
.filter(
1292+
MessageTreeState.active,
1293+
or_(
1294+
MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW,
1295+
MessageTreeState.state == message_tree_state.State.GROWING,
1296+
),
1297+
or_(
1298+
Message.parent_id.is_(None),
1299+
Message.review_result,
1300+
and_(Message.parent_id.is_not(None), Message.review_count < self.cfg.num_reviews_reply),
1301+
),
1302+
not_(Message.deleted),
1303+
or_(
1304+
coalesce(Message.emojis[protocol_schema.EmojiCode.red_flag].cast(sa.Integer), 0)
1305+
>= self.cfg.auto_mod_red_flags,
1306+
coalesce(Message.emojis[protocol_schema.EmojiCode.skip_reply].cast(sa.Integer), 0)
1307+
>= self.cfg.auto_mod_max_skip_reply,
1308+
),
1309+
)
1310+
)
1311+
1312+
if lang is not None:
1313+
qry = qry.filter(Message.lang == lang)
1314+
1315+
return qry.all()
1316+
12571317
@managed_tx_method(CommitMode.FLUSH)
12581318
def _insert_tree_state(
12591319
self,
@@ -1281,10 +1341,17 @@ def _insert_default_state(
12811341
self,
12821342
root_message_id: UUID,
12831343
state: message_tree_state.State = message_tree_state.State.INITIAL_PROMPT_REVIEW,
1344+
*,
1345+
goal_tree_size: int = None,
12841346
) -> MessageTreeState:
1347+
if goal_tree_size is None:
1348+
if self.cfg.random_goal_tree_size and self.cfg.min_goal_tree_size < self.cfg.goal_tree_size:
1349+
goal_tree_size = random.randint(self.cfg.min_goal_tree_size, self.cfg.goal_tree_size)
1350+
else:
1351+
goal_tree_size = self.cfg.goal_tree_size
12851352
return self._insert_tree_state(
12861353
root_message_id=root_message_id,
1287-
goal_tree_size=self.cfg.goal_tree_size,
1354+
goal_tree_size=goal_tree_size,
12881355
max_depth=self.cfg.max_tree_depth,
12891356
max_children_count=self.cfg.max_children_count,
12901357
state=state,
@@ -1379,9 +1446,32 @@ def _purge_message_internal(self, message_id: UUID) -> None:
13791446
DELETE FROM task t WHERE t.parent_message_id = :message_id;
13801447
DELETE FROM message WHERE id = :message_id;
13811448
"""
1449+
parent_id = self.pr.fetch_message(message_id=message_id).parent_id
13821450
r = self.db.execute(text(sql_purge_message), {"message_id": message_id})
13831451
logger.debug(f"purge_message({message_id=}): {r.rowcount} rows.")
13841452

1453+
sql_update_ranking_counts = """
1454+
WITH r AS (
1455+
-- find ranking results and count per child
1456+
SELECT c.id,
1457+
count(*) FILTER (
1458+
WHERE mr.payload#>'{payload, ranked_message_ids}' ? CAST(c.id AS varchar)
1459+
) AS ranking_count
1460+
FROM message c
1461+
LEFT JOIN message_reaction mr ON mr.payload_type = 'RankingReactionPayload'
1462+
AND mr.message_id = c.parent_id
1463+
WHERE c.parent_id = :parent_id
1464+
GROUP BY c.id
1465+
)
1466+
UPDATE message m SET ranking_count = r.ranking_count
1467+
FROM r WHERE m.id = r.id AND m.ranking_count != r.ranking_count;
1468+
"""
1469+
1470+
if parent_id is not None:
1471+
# update ranking counts of remaining children
1472+
r = self.db.execute(text(sql_update_ranking_counts), {"parent_id": parent_id})
1473+
logger.debug(f"ranking_count updated for {r.rowcount} rows.")
1474+
13851475
def purge_message_tree(self, message_tree_id: UUID) -> None:
13861476
sql_purge_message_tree = """
13871477
DELETE FROM journal j USING message m WHERE j.message_id = m.Id AND m.message_tree_id = :message_tree_id;

0 commit comments

Comments
 (0)