Skip to content

Commit

Permalink
Exclude extendible parents with young reply tasks (#1196)
Browse files Browse the repository at this point in the history
* exclude extendible parents with young reply tasks

* fix typos
  • Loading branch information
andreaskoepf authored Feb 5, 2023
1 parent 8bc7d08 commit 2434d44
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 20 deletions.
2 changes: 1 addition & 1 deletion backend/oasst_backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ class TreeManagerConfiguration(BaseModel):
p_lonely_child_extension: float = 0.75
"""Probability to select a prompter message parent with less than lonely_children_count children."""

recent_tasks_span_sec: int = 3 * 60 # 3 min
recent_tasks_span_sec: int = 5 * 60 # 5 min
"""Time in seconds of recent tasks to consider for exclusion during task selection."""


Expand Down
2 changes: 1 addition & 1 deletion backend/oasst_backend/task_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def fetch_recent_reply_tasks(
self, max_age: timedelta = timedelta(minutes=5), done: bool = False, skipped: bool = False, limit: int = 100
) -> list[Task]:
qry = self.db.query(Task).filter(
func.age(Task.created_date) < max_age,
func.age(func.current_timestamp(), Task.created_date) < max_age,
or_(Task.payload_type == "AssistantReplyPayload", Task.payload_type == "PrompterReplyPayload"),
)
if done is not None:
Expand Down
33 changes: 15 additions & 18 deletions backend/oasst_backend/tree_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,14 +562,6 @@ def next_task(

case TaskType.REPLY:

recent_reply_tasks = self.pr.task_repository.fetch_recent_reply_tasks(
max_age=timedelta(seconds=self.cfg.recent_tasks_span_sec),
done=False,
skipped=False,
limit=500,
)
recent_reply_task_parents = {t.parent_message_id for t in recent_reply_tasks}

if task_role == TaskRole.PROMPTER:
extendible_parents = list(filter(lambda x: x.parent_role == "assistant", extendible_parents))
elif task_role == TaskRole.ASSISTANT:
Expand All @@ -580,24 +572,17 @@ def next_task(
random_parent: ExtendibleParentRow = None
if self.cfg.p_lonely_child_extension > 0 and self.cfg.lonely_children_count > 1:
# check if we have extendible prompter parents with a small number of replies

lonely_children_parents = [
p
for p in extendible_parents
if 0 < p.active_children_count < self.cfg.lonely_children_count
and p.parent_role == "prompter"
and p.parent_id not in recent_reply_task_parents
]
if len(lonely_children_parents) > 0 and random.random() < self.cfg.p_lonely_child_extension:
random_parent = random.choice(lonely_children_parents)

if random_parent is None:
# try to exclude parents for which tasks were recently handed out
fresh_parents = [p for p in extendible_parents if p.parent_id not in recent_reply_task_parents]
if len(fresh_parents) > 0:
random_parent = random.choice(fresh_parents)
else:
random_parent = random.choice(extendible_parents)
random_parent = random.choice(extendible_parents)

# fetch random conversation to extend
logger.debug(f"selected {random_parent=}")
Expand Down Expand Up @@ -895,7 +880,7 @@ def update_message_ranks(
logger.warning("The intersection of ranking results ID sets has less than two elements. Skipping.")
continue

# keep only elements in command set
# keep only elements in common set
ordered_ids_list = [list(filter(lambda x: x in common_set, ids)) for ids in ordered_ids_list]
assert all(len(x) == len(common_set) for x in ordered_ids_list)

Expand Down Expand Up @@ -1087,14 +1072,23 @@ def query_incomplete_rankings(self, lang: str) -> list[IncompleteRankingsRow]:

_sql_find_extendible_parents = """
-- find all extendible parent nodes
WITH recent_reply_tasks (parent_message_id) AS (
-- recent incomplete tasks to exclude
SELECT parent_message_id FROM task
WHERE not done
AND not skipped
AND created_date > (CURRENT_TIMESTAMP - :recent_tasks_interval)
AND (payload_type = 'AssistantReplyPayload' OR payload_type = 'PrompterReplyPayload')
)
SELECT m.id as parent_id, m.role as parent_role, m.depth, m.message_tree_id, COUNT(c.id) active_children_count
FROM message_tree_state mts
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
LEFT JOIN message_emoji me ON
(m.id = me.message_id
AND :skip_user_id IS NOT NULL
AND me.user_id = :skip_user_id
AND me.emoji = :skip_reply)
LEFT JOIN recent_reply_tasks rrt ON m.id = rrt.parent_message_id -- recent tasks
LEFT JOIN message c ON m.id = c.parent_id -- child nodes
WHERE mts.active -- only consider active trees
AND mts.state = :growing_state -- message tree must be growing
Expand All @@ -1103,6 +1097,7 @@ def query_incomplete_rankings(self, lang: str) -> list[IncompleteRankingsRow]:
AND m.review_result -- parent node must have positive review
AND m.lang = :lang -- parent matches lang
AND me.message_id IS NULL -- no skip reply emoji for user
AND rrt.parent_message_id IS NULL -- no recent reply task found
AND NOT coalesce(c.deleted, FALSE) -- don't count deleted children
AND (c.review_result OR coalesce(c.review_count, 0) < :num_reviews_reply) -- don't count children with negative review but count elements under review
GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count
Expand All @@ -1125,6 +1120,7 @@ def query_extendible_parents(self, lang: str) -> tuple[list[ExtendibleParentRow]
"user_id": user_id,
"skip_user_id": self.pr.user_id,
"skip_reply": protocol_schema.EmojiCode.skip_reply,
"recent_tasks_interval": timedelta(seconds=self.cfg.recent_tasks_span_sec),
},
)

Expand Down Expand Up @@ -1165,6 +1161,7 @@ def query_extendible_trees(self, lang: str) -> list[ActiveTreeSizeRow]:
"user_id": user_id,
"skip_user_id": self.pr.user_id,
"skip_reply": protocol_schema.EmojiCode.skip_reply,
"recent_tasks_interval": timedelta(seconds=self.cfg.recent_tasks_span_sec),
},
)
return [ActiveTreeSizeRow.from_orm(x) for x in r.all()]
Expand Down

0 comments on commit 2434d44

Please sign in to comment.