|
9 | 9 |
|
10 | 10 | import numpy as np
|
11 | 11 | import pydantic
|
| 12 | +import sqlalchemy as sa |
12 | 13 | from fastapi.encoders import jsonable_encoder
|
13 | 14 | from loguru import logger
|
14 | 15 | from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
|
|
31 | 32 | from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
32 | 33 | from oasst_shared.schemas import protocol as protocol_schema
|
33 | 34 | from oasst_shared.utils import utcnow
|
| 35 | +from sqlalchemy.sql.functions import coalesce |
34 | 36 | from sqlmodel import Session, and_, func, not_, or_, text, update
|
35 | 37 |
|
36 | 38 |
|
@@ -269,13 +271,39 @@ def _prompt_lottery(self, lang: str) -> int:
|
269 | 271 | self._enter_state(mts, message_tree_state.State.GROWING)
|
270 | 272 | self.db.flush()
|
271 | 273 |
|
def _auto_moderation(self, lang: str) -> None:
    """Apply automatic moderation to messages that users flagged heavily.

    Returns immediately unless ``self.cfg.auto_mod_enabled`` is set. Otherwise
    every message returned by ``query_moderation_bad_messages(lang=lang)`` is
    checked against two independent thresholds:

    * red flags >= ``self.cfg.auto_mod_red_flags``: for a message without a
      parent (an initial prompt) ``enter_low_grade_state`` is called for its
      tree; any other message is marked deleted recursively via
      ``self.pr.mark_messages_deleted``.
    * skip-reply count >= ``self.cfg.auto_mod_max_skip_reply``:
      ``halt_tree`` is called for the message.

    Both checks run for every message, so one message can trigger both
    actions.

    Args:
        lang: Language tag forwarded to ``query_moderation_bad_messages``.
    """
    if not self.cfg.auto_mod_enabled:
        return

    bad_messages = self.query_moderation_bad_messages(lang=lang)
    for m in bad_messages:
        # A message may carry no emoji mapping at all (e.g. when a threshold
        # is configured as 0 the query also matches rows without emoji data);
        # treat that as "no reactions" instead of raising AttributeError.
        emojis = m.emojis or {}

        num_red_flag = emojis.get(protocol_schema.EmojiCode.red_flag)
        if num_red_flag is not None and num_red_flag >= self.cfg.auto_mod_red_flags:
            if m.parent_id is None:
                # No parent: this is the initial prompt, so the whole tree is affected.
                logger.warning(
                    f"[AUTO MOD] Halting tree {m.message_tree_id}, initial prompt got too many red flags ({m.emojis})."
                )
                self.enter_low_grade_state(m.message_tree_id)
            else:
                logger.warning(f"[AUTO MOD] Deleting message {m.id=}, it received too many red flags ({m.emojis}).")
                self.pr.mark_messages_deleted(m.id, recursive=True)

        num_skip_reply = emojis.get(protocol_schema.EmojiCode.skip_reply)
        if num_skip_reply is not None and num_skip_reply >= self.cfg.auto_mod_max_skip_reply:
            logger.warning(
                f"[AUTO MOD] Halting tree {m.message_tree_id} due to high skip-reply count of message {m.id=} ({m.emojis})."
            )
            self.halt_tree(m.id, halt=True)
| 298 | + |
272 | 299 | def determine_task_availability(self, lang: str) -> dict[protocol_schema.TaskRequestType, int]:
|
273 | 300 | self.pr.ensure_user_is_enabled()
|
274 | 301 |
|
275 | 302 | if not lang:
|
276 | 303 | lang = "en"
|
277 | 304 | logger.warning("Task availability request without lang tag received, assuming lang='en'.")
|
278 | 305 |
|
| 306 | + self._auto_moderation(lang=lang) |
279 | 307 | num_missing_prompts = self._prompt_lottery(lang=lang)
|
280 | 308 | extendible_parents, _ = self.query_extendible_parents(lang=lang)
|
281 | 309 | prompts_need_review = self.query_prompts_need_review(lang=lang)
|
@@ -313,6 +341,7 @@ def next_task(
|
313 | 341 | lang = "en"
|
314 | 342 | logger.warning("Task request without lang tag received, assuming 'en'.")
|
315 | 343 |
|
| 344 | + self._auto_moderation(lang=lang) |
316 | 345 | num_missing_prompts = self._prompt_lottery(lang=lang)
|
317 | 346 |
|
318 | 347 | prompts_need_review = self.query_prompts_need_review(lang=lang)
|
@@ -1254,6 +1283,37 @@ def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]:
|
1254 | 1283 | )
|
1255 | 1284 | return qry.all()
|
1256 | 1285 |
|
def query_moderation_bad_messages(self, lang: str) -> list[Message]:
    """Fetch messages whose emoji counters reached an auto-moderation threshold.

    A message is returned when all of the following hold:

    * its tree state row is active and in ``INITIAL_PROMPT_REVIEW`` or
      ``GROWING`` state,
    * it has no parent, or has a truthy ``review_result``, or has a parent
      and fewer than ``self.cfg.num_reviews_reply`` reviews,
    * it is not deleted,
    * its ``red_flag`` count is >= ``self.cfg.auto_mod_red_flags`` or its
      ``skip_reply`` count is >= ``self.cfg.auto_mod_max_skip_reply``
      (a missing counter is read as 0).

    Args:
        lang: When not ``None``, only messages with this ``lang`` are returned.

    Returns:
        All matching ``Message`` rows.
    """
    tree_in_moderated_state = or_(
        MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW,
        MessageTreeState.state == message_tree_state.State.GROWING,
    )
    message_is_candidate = or_(
        Message.parent_id.is_(None),
        Message.review_result,
        and_(Message.parent_id.is_not(None), Message.review_count < self.cfg.num_reviews_reply),
    )

    # Missing emoji keys yield NULL; coalesce turns them into a count of 0.
    red_flag_count = coalesce(Message.emojis[protocol_schema.EmojiCode.red_flag].cast(sa.Integer), 0)
    skip_reply_count = coalesce(Message.emojis[protocol_schema.EmojiCode.skip_reply].cast(sa.Integer), 0)
    threshold_reached = or_(
        red_flag_count >= self.cfg.auto_mod_red_flags,
        skip_reply_count >= self.cfg.auto_mod_max_skip_reply,
    )

    criteria = [
        MessageTreeState.active,
        tree_in_moderated_state,
        message_is_candidate,
        not_(Message.deleted),
        threshold_reached,
    ]
    if lang is not None:
        criteria.append(Message.lang == lang)

    return (
        self.db.query(Message)
        .select_from(MessageTreeState)
        .join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
        .filter(*criteria)
        .all()
    )
| 1316 | + |
1257 | 1317 | @managed_tx_method(CommitMode.FLUSH)
|
1258 | 1318 | def _insert_tree_state(
|
1259 | 1319 | self,
|
@@ -1281,10 +1341,17 @@ def _insert_default_state(
|
1281 | 1341 | self,
|
1282 | 1342 | root_message_id: UUID,
|
1283 | 1343 | state: message_tree_state.State = message_tree_state.State.INITIAL_PROMPT_REVIEW,
|
| 1344 | + *, |
| 1345 | + goal_tree_size: int = None, |
1284 | 1346 | ) -> MessageTreeState:
|
| 1347 | + if goal_tree_size is None: |
| 1348 | + if self.cfg.random_goal_tree_size and self.cfg.min_goal_tree_size < self.cfg.goal_tree_size: |
| 1349 | + goal_tree_size = random.randint(self.cfg.min_goal_tree_size, self.cfg.goal_tree_size) |
| 1350 | + else: |
| 1351 | + goal_tree_size = self.cfg.goal_tree_size |
1285 | 1352 | return self._insert_tree_state(
|
1286 | 1353 | root_message_id=root_message_id,
|
1287 |
| - goal_tree_size=self.cfg.goal_tree_size, |
| 1354 | + goal_tree_size=goal_tree_size, |
1288 | 1355 | max_depth=self.cfg.max_tree_depth,
|
1289 | 1356 | max_children_count=self.cfg.max_children_count,
|
1290 | 1357 | state=state,
|
@@ -1379,9 +1446,32 @@ def _purge_message_internal(self, message_id: UUID) -> None:
|
1379 | 1446 | DELETE FROM task t WHERE t.parent_message_id = :message_id;
|
1380 | 1447 | DELETE FROM message WHERE id = :message_id;
|
1381 | 1448 | """
|
| 1449 | + parent_id = self.pr.fetch_message(message_id=message_id).parent_id |
1382 | 1450 | r = self.db.execute(text(sql_purge_message), {"message_id": message_id})
|
1383 | 1451 | logger.debug(f"purge_message({message_id=}): {r.rowcount} rows.")
|
1384 | 1452 |
|
| 1453 | + sql_update_ranking_counts = """ |
| 1454 | +WITH r AS ( |
| 1455 | + -- find ranking results and count per child |
| 1456 | + SELECT c.id, |
| 1457 | + count(*) FILTER ( |
| 1458 | + WHERE mr.payload#>'{payload, ranked_message_ids}' ? CAST(c.id AS varchar) |
| 1459 | + ) AS ranking_count |
| 1460 | + FROM message c |
| 1461 | + LEFT JOIN message_reaction mr ON mr.payload_type = 'RankingReactionPayload' |
| 1462 | + AND mr.message_id = c.parent_id |
| 1463 | + WHERE c.parent_id = :parent_id |
| 1464 | + GROUP BY c.id |
| 1465 | +) |
| 1466 | +UPDATE message m SET ranking_count = r.ranking_count |
| 1467 | +FROM r WHERE m.id = r.id AND m.ranking_count != r.ranking_count; |
| 1468 | +""" |
| 1469 | + |
| 1470 | + if parent_id is not None: |
| 1471 | + # update ranking counts of remaining children |
| 1472 | + r = self.db.execute(text(sql_update_ranking_counts), {"parent_id": parent_id}) |
| 1473 | + logger.debug(f"ranking_count updated for {r.rowcount} rows.") |
| 1474 | + |
1385 | 1475 | def purge_message_tree(self, message_tree_id: UUID) -> None:
|
1386 | 1476 | sql_purge_message_tree = """
|
1387 | 1477 | DELETE FROM journal j USING message m WHERE j.message_id = m.Id AND m.message_tree_id = :message_tree_id;
|
|
0 commit comments