From 2434d44c01ae4aa0f1d548384aa8a706d5e5fad8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sun, 5 Feb 2023 20:22:45 +0100 Subject: [PATCH] Exclude extendible parents with young reply tasks (#1196) * exclude extendible parents with young reply tasks * fix typos --- backend/oasst_backend/config.py | 2 +- backend/oasst_backend/task_repository.py | 2 +- backend/oasst_backend/tree_manager.py | 33 +++++++++++------------- 3 files changed, 17 insertions(+), 20 deletions(-) diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 2a21b3ecfa..87b14c6c68 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -138,7 +138,7 @@ class TreeManagerConfiguration(BaseModel): p_lonely_child_extension: float = 0.75 """Probability to select a prompter message parent with less than lonely_children_count children.""" - recent_tasks_span_sec: int = 3 * 60 # 3 min + recent_tasks_span_sec: int = 5 * 60 # 5 min """Time in seconds of recent tasks to consider for exclusion during task selection.""" diff --git a/backend/oasst_backend/task_repository.py b/backend/oasst_backend/task_repository.py index 99bb07bb97..2df9efa45a 100644 --- a/backend/oasst_backend/task_repository.py +++ b/backend/oasst_backend/task_repository.py @@ -225,7 +225,7 @@ def fetch_recent_reply_tasks( self, max_age: timedelta = timedelta(minutes=5), done: bool = False, skipped: bool = False, limit: int = 100 ) -> list[Task]: qry = self.db.query(Task).filter( - func.age(Task.created_date) < max_age, + func.age(func.current_timestamp(), Task.created_date) < max_age, or_(Task.payload_type == "AssistantReplyPayload", Task.payload_type == "PrompterReplyPayload"), ) if done is not None: diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 72a4c31f4c..9556b832e8 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -562,14 +562,6 @@ def next_task( case TaskType.REPLY: - recent_reply_tasks = self.pr.task_repository.fetch_recent_reply_tasks( - max_age=timedelta(seconds=self.cfg.recent_tasks_span_sec), - done=False, - skipped=False, - limit=500, - ) - recent_reply_task_parents = {t.parent_message_id for t in recent_reply_tasks} - if task_role == TaskRole.PROMPTER: extendible_parents = list(filter(lambda x: x.parent_role == "assistant", extendible_parents)) elif task_role == TaskRole.ASSISTANT: @@ -580,24 +572,17 @@ def next_task( random_parent: ExtendibleParentRow = None if self.cfg.p_lonely_child_extension > 0 and self.cfg.lonely_children_count > 1: # check if we have extendible prompter parents with a small number of replies - lonely_children_parents = [ p for p in extendible_parents if 0 < p.active_children_count < self.cfg.lonely_children_count and p.parent_role == "prompter" - and p.parent_id not in recent_reply_task_parents ] if len(lonely_children_parents) > 0 and random.random() < self.cfg.p_lonely_child_extension: random_parent = random.choice(lonely_children_parents) if random_parent is None: - # try to exclude parents for which tasks were recently handed out - fresh_parents = [p for p in extendible_parents if p.parent_id not in recent_reply_task_parents] - if len(fresh_parents) > 0: - random_parent = random.choice(fresh_parents) - else: - random_parent = random.choice(extendible_parents) + random_parent = random.choice(extendible_parents) # fetch random conversation to extend logger.debug(f"selected {random_parent=}") @@ -895,7 +880,7 @@ def update_message_ranks( logger.warning("The intersection of ranking results ID sets has less than two elements. Skipping.") continue - # keep only elements in command set + # keep only elements in common set ordered_ids_list = [list(filter(lambda x: x in common_set, ids)) for ids in ordered_ids_list] assert all(len(x) == len(common_set) for x in ordered_ids_list) @@ -1087,14 +1072,23 @@ def query_incomplete_rankings(self, lang: str) -> list[IncompleteRankingsRow]: _sql_find_extendible_parents = """ -- find all extendible parent nodes +WITH recent_reply_tasks (parent_message_id) AS ( + -- recent incomplete tasks to exclude + SELECT parent_message_id FROM task + WHERE not done + AND not skipped + AND created_date > (CURRENT_TIMESTAMP - :recent_tasks_interval) + AND (payload_type = 'AssistantReplyPayload' OR payload_type = 'PrompterReplyPayload') +) SELECT m.id as parent_id, m.role as parent_role, m.depth, m.message_tree_id, COUNT(c.id) active_children_count FROM message_tree_state mts - INNER JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree + INNER JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree LEFT JOIN message_emoji me ON (m.id = me.message_id AND :skip_user_id IS NOT NULL AND me.user_id = :skip_user_id AND me.emoji = :skip_reply) + LEFT JOIN recent_reply_tasks rrt ON m.id = rrt.parent_message_id -- recent tasks LEFT JOIN message c ON m.id = c.parent_id -- child nodes WHERE mts.active -- only consider active trees AND mts.state = :growing_state -- message tree must be growing @@ -1103,6 +1097,7 @@ def query_incomplete_rankings(self, lang: str) -> list[IncompleteRankingsRow]: AND m.review_result -- parent node must have positive review AND m.lang = :lang -- parent matches lang AND me.message_id IS NULL -- no skip reply emoji for user + AND rrt.parent_message_id IS NULL -- no recent reply task found AND NOT coalesce(c.deleted, FALSE) -- don't count deleted children AND (c.review_result OR coalesce(c.review_count, 0) < :num_reviews_reply) -- don't count children with negative review but count elements under review GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count @@ -1125,6 +1120,7 @@ def query_extendible_parents(self, lang: str) -> tuple[list[ExtendibleParentRow] "user_id": user_id, "skip_user_id": self.pr.user_id, "skip_reply": protocol_schema.EmojiCode.skip_reply, + "recent_tasks_interval": timedelta(seconds=self.cfg.recent_tasks_span_sec), }, ) @@ -1165,6 +1161,7 @@ def query_extendible_trees(self, lang: str) -> list[ActiveTreeSizeRow]: "user_id": user_id, "skip_user_id": self.pr.user_id, "skip_reply": protocol_schema.EmojiCode.skip_reply, + "recent_tasks_interval": timedelta(seconds=self.cfg.recent_tasks_span_sec), }, ) return [ActiveTreeSizeRow.from_orm(x) for x in r.all()]