From 100ecbe2d53c35c00279b2d1d9aa31754ec265e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sat, 4 Feb 2023 23:23:16 +0100 Subject: [PATCH] Limit initial prompts, ensure max_active_trees = growing trees --- .github/workflows/deploy-to-node.yaml | 1 + ansible/deploy-to-node.yaml | 3 + backend/oasst_backend/config.py | 3 + backend/oasst_backend/tree_manager.py | 88 ++++++++++++++++++++++----- website/public/locales/en/tasks.json | 2 +- 5 files changed, 82 insertions(+), 15 deletions(-) diff --git a/.github/workflows/deploy-to-node.yaml b/.github/workflows/deploy-to-node.yaml index 436e7f582a..d710eca61e 100644 --- a/.github/workflows/deploy-to-node.yaml +++ b/.github/workflows/deploy-to-node.yaml @@ -38,6 +38,7 @@ jobs: AWS_ACCESS_KEY: ${{ secrets.AWS_ACCESS_KEY }} AWS_SECRET_KEY: ${{ secrets.AWS_SECRET_KEY }} MAX_ACTIVE_TREES: ${{ vars.MAX_ACTIVE_TREES }} + MAX_INITIAL_PROMPT_REVIEW: ${{ vars.MAX_INITIAL_PROMPT_REVIEW }} MAX_TREE_DEPTH: ${{ vars.MAX_TREE_DEPTH }} MAX_CHILDREN_COUNT: ${{ vars.MAX_CHILDREN_COUNT }} LONELY_CHILDREN_COUNT: ${{ vars.LONELY_CHILDREN_COUNT }} diff --git a/ansible/deploy-to-node.yaml b/ansible/deploy-to-node.yaml index 639003a42f..958e5172ac 100644 --- a/ansible/deploy-to-node.yaml +++ b/ansible/deploy-to-node.yaml @@ -113,6 +113,9 @@ TREE_MANAGER__MAX_ACTIVE_TREES: "{{ lookup('ansible.builtin.env', 'MAX_ACTIVE_TREES') | default('10', true) }}" + TREE_MANAGER__MAX_INITIAL_PROMPT_REVIEW: + "{{ lookup('ansible.builtin.env', 'MAX_INITIAL_PROMPT_REVIEW') | + default('100', true) }}" TREE_MANAGER__MAX_TREE_DEPTH: "{{ lookup('ansible.builtin.env', 'MAX_TREE_DEPTH') | default('5', true) }}" diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index ff4c549715..4d6dea30da 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -13,6 +13,9 @@ class TreeManagerConfiguration(BaseModel): No new initial prompt tasks are handed out to users if this number is reached.""" + 
max_initial_prompt_review: int = 100 + """Maximum number of initial prompts under review before no more initial prompt tasks will be handed out.""" + max_tree_depth: int = 3 """Maximum depth of message tree.""" diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 6b4f523615..a6c803be62 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -51,6 +51,19 @@ class TaskRole(Enum): ASSISTANT = 2 +class TreeStateStats(pydantic.BaseModel): + initial_prompt_review: int + growing: int + ranking: int + ready_for_scoring: int + scoring_failed: int + ready_for_export: int + aborted_low_grade: int + halted_by_moderator: int + backlog_ranking: int + prompt_lottery_waiting: int + + class ActiveTreeSizeRow(pydantic.BaseModel): message_tree_id: UUID goal_tree_size: int @@ -173,7 +186,7 @@ def _determine_task_availability_internal( ) -> dict[protocol_schema.TaskRequestType, int]: task_count_by_type: dict[protocol_schema.TaskRequestType, int] = {t: 0 for t in protocol_schema.TaskRequestType} - task_count_by_type[protocol_schema.TaskRequestType.initial_prompt] = max(1, num_missing_prompts) + task_count_by_type[protocol_schema.TaskRequestType.initial_prompt] = num_missing_prompts task_count_by_type[protocol_schema.TaskRequestType.prompter_reply] = len( list(filter(lambda x: x.parent_role == "assistant", extendible_parents)) @@ -209,10 +222,15 @@ def _prompt_lottery(self, lang: str) -> int: MAX_RETRIES = 5 retry = 0 while True: - num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True) - num_missing_prompts = self.cfg.max_active_trees - num_active_trees - if num_missing_prompts <= 0: - return 0 + + stats = self.tree_counts_by_state_stats(lang=lang, only_active=True) + + remaining_prompt_review = max(0, self.cfg.max_initial_prompt_review - stats.initial_prompt_review) + num_missing_growing = max(0, self.cfg.max_active_trees - stats.growing) + logger.debug(f"_prompt_lottery 
{remaining_prompt_review=}, {num_missing_growing=}") + + if num_missing_growing == 0: + return remaining_prompt_review # select among distinct users authors_qry = ( @@ -231,9 +249,9 @@ def _prompt_lottery(self, lang: str) -> int: author_ids = authors_qry.all() if len(author_ids) == 0: logger.info( - f"No prompts for prompt lottery available ({num_missing_prompts} trees missing for {lang=})." + f"No prompts for prompt lottery available ({num_missing_growing=}, trees missing for {lang=})." ) - return num_missing_prompts + return num_missing_growing + remaining_prompt_review # first select an authour prompt_author_id: UUID = random.choice(author_ids)["user_id"] @@ -261,7 +279,7 @@ def _prompt_lottery(self, lang: str) -> int: continue else: logger.warning("Max retries in prompt lottery reached.") - return num_missing_prompts + return num_missing_growing + remaining_prompt_review winner_prompt = random.choice(prompt_candidates) message: Message = winner_prompt.Message @@ -1268,8 +1286,23 @@ def ensure_tree_states(self) -> None: for t in ranking_trees: self.check_condition_for_scoring_state(t.message_tree_id) - def query_num_active_trees(self, lang: str, exclude_ranking: bool = True) -> int: - """Count all active trees (optionally exclude those in ranking state).""" + def query_num_growing_trees(self, lang: str) -> int: + """Count all active trees in growing state.""" + query = ( + self.db.query(func.count(MessageTreeState.message_tree_id)) + .join(Message, MessageTreeState.message_tree_id == Message.id) + .filter( + MessageTreeState.active, + MessageTreeState.state == message_tree_state.State.GROWING, + Message.lang == lang, + ) + ) + return query.scalar() + + def query_num_active_trees( + self, lang: str, exclude_ranking: bool = True, exclude_prompt_review: bool = True + ) -> int: + """Count all active trees (optionally exclude those in ranking and initial prompt review states).""" query = ( self.db.query(func.count(MessageTreeState.message_tree_id)) .join(Message, 
MessageTreeState.message_tree_id == Message.id) @@ -1277,6 +1310,8 @@ def query_num_active_trees(self, lang: str, exclude_ranking: bool = True) -> int ) if exclude_ranking: query = query.filter(MessageTreeState.state != message_tree_state.State.RANKING) + if exclude_prompt_review: + query = query.filter(MessageTreeState.state != message_tree_state.State.INITIAL_PROMPT_REVIEW) return query.scalar() def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]: @@ -1363,12 +1398,37 @@ def _insert_default_state( active=True, ) - def tree_counts_by_state(self) -> dict[str, int]: - qry = self.db.query( - MessageTreeState.state, func.count(MessageTreeState.message_tree_id).label("count") - ).group_by(MessageTreeState.state) + def tree_counts_by_state(self, lang: str = None, only_active: bool = False) -> dict[str, int]: + qry = self.db.query(MessageTreeState.state, func.count(MessageTreeState.message_tree_id).label("count")) + + if lang is not None: + qry = ( + qry.select_from(MessageTreeState) + .join(Message, MessageTreeState.message_tree_id == Message.id) + .filter(Message.lang == lang) + ) + if only_active: + qry = qry.filter(MessageTreeState.active) + + qry = qry.group_by(MessageTreeState.state) return {x["state"]: x["count"] for x in qry} + def tree_counts_by_state_stats(self, lang: str = None, only_active: bool = False) -> TreeStateStats: + count_by_state = self.tree_counts_by_state(lang=lang, only_active=only_active) + r = TreeStateStats( + initial_prompt_review=count_by_state.get(message_tree_state.State.INITIAL_PROMPT_REVIEW) or 0, + growing=count_by_state.get(message_tree_state.State.GROWING) or 0, + ranking=count_by_state.get(message_tree_state.State.RANKING) or 0, + ready_for_scoring=count_by_state.get(message_tree_state.State.READY_FOR_SCORING) or 0, + ready_for_export=count_by_state.get(message_tree_state.State.READY_FOR_EXPORT) or 0, + scoring_failed=count_by_state.get(message_tree_state.State.SCORING_FAILED) or 0, + 
halted_by_moderator=count_by_state.get(message_tree_state.State.HALTED_BY_MODERATOR) or 0, + backlog_ranking=count_by_state.get(message_tree_state.State.BACKLOG_RANKING) or 0, + prompt_lottery_waiting=count_by_state.get(message_tree_state.State.PROMPT_LOTTERY_WAITING) or 0, + aborted_low_grade=count_by_state.get(message_tree_state.State.ABORTED_LOW_GRADE) or 0, + ) + return r + def tree_message_count_stats(self, only_active: bool = True) -> list[TreeMessageCountStats]: qry = ( self.db.query( diff --git a/website/public/locales/en/tasks.json b/website/public/locales/en/tasks.json index 7fa7c45262..ee3da48c06 100644 --- a/website/public/locales/en/tasks.json +++ b/website/public/locales/en/tasks.json @@ -9,7 +9,7 @@ }, "create_initial_prompt": { "label": "Create Initial Prompts", - "desc": "Write initial prompts to help Open Assistant to try replying to diverse messages.", + "desc": "Write initial prompts to help Open Assistant to try replying to diverse messages. (enter into lottery)", "overview": "Create an initial message to send to the assistant", "instruction": "Provide the initial prompts", "response_placeholder": "Write your prompt here..."