Skip to content

Commit

Permalink
Limit initial prompts, ensure max_active_trees = growing trees
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaskoepf committed Feb 4, 2023
1 parent 280979c commit 100ecbe
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 15 deletions.
1 change: 1 addition & 0 deletions .github/workflows/deploy-to-node.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ jobs:
AWS_ACCESS_KEY: ${{ secrets.AWS_ACCESS_KEY }}
AWS_SECRET_KEY: ${{ secrets.AWS_SECRET_KEY }}
MAX_ACTIVE_TREES: ${{ vars.MAX_ACTIVE_TREES }}
MAX_INITIAL_PROMPT_REVIEW: ${{ vars.MAX_INITIAL_PROMPT_REVIEW }}
MAX_TREE_DEPTH: ${{ vars.MAX_TREE_DEPTH }}
MAX_CHILDREN_COUNT: ${{ vars.MAX_CHILDREN_COUNT }}
LONELY_CHILDREN_COUNT: ${{ vars.LONELY_CHILDREN_COUNT }}
Expand Down
3 changes: 3 additions & 0 deletions ansible/deploy-to-node.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@
TREE_MANAGER__MAX_ACTIVE_TREES:
"{{ lookup('ansible.builtin.env', 'MAX_ACTIVE_TREES') |
default('10', true) }}"
TREE_MANAGER__MAX_INITIAL_PROMPT_REVIEW:
"{{ lookup('ansible.builtin.env', 'MAX_INITIAL_PROMPT_REVIEW') |
default('100', true) }}"
TREE_MANAGER__MAX_TREE_DEPTH:
"{{ lookup('ansible.builtin.env', 'MAX_TREE_DEPTH') | default('5',
true) }}"
Expand Down
3 changes: 3 additions & 0 deletions backend/oasst_backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ class TreeManagerConfiguration(BaseModel):
No new initial prompt tasks are handed out to users if this
number is reached."""

max_initial_prompt_review: int = 100
"""Maximum number of initial prompts under review before no more initial prompt tasks will be handed out."""

max_tree_depth: int = 3
"""Maximum depth of message tree."""

Expand Down
88 changes: 74 additions & 14 deletions backend/oasst_backend/tree_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,19 @@ class TaskRole(Enum):
ASSISTANT = 2


class TreeStateStats(pydantic.BaseModel):
    # Per-state counts of message trees; field names mirror the
    # message_tree_state.State values they are populated from
    # (see tree_counts_by_state_stats).
    initial_prompt_review: int  # State.INITIAL_PROMPT_REVIEW
    growing: int  # State.GROWING
    ranking: int  # State.RANKING
    ready_for_scoring: int  # State.READY_FOR_SCORING
    scoring_failed: int  # State.SCORING_FAILED
    ready_for_export: int  # State.READY_FOR_EXPORT
    aborted_low_grade: int  # State.ABORTED_LOW_GRADE
    halted_by_moderator: int  # State.HALTED_BY_MODERATOR
    backlog_ranking: int  # State.BACKLOG_RANKING
    prompt_lottery_waiting: int  # State.PROMPT_LOTTERY_WAITING


class ActiveTreeSizeRow(pydantic.BaseModel):
message_tree_id: UUID
goal_tree_size: int
Expand Down Expand Up @@ -173,7 +186,7 @@ def _determine_task_availability_internal(
) -> dict[protocol_schema.TaskRequestType, int]:
task_count_by_type: dict[protocol_schema.TaskRequestType, int] = {t: 0 for t in protocol_schema.TaskRequestType}

task_count_by_type[protocol_schema.TaskRequestType.initial_prompt] = max(1, num_missing_prompts)
task_count_by_type[protocol_schema.TaskRequestType.initial_prompt] = num_missing_prompts

task_count_by_type[protocol_schema.TaskRequestType.prompter_reply] = len(
list(filter(lambda x: x.parent_role == "assistant", extendible_parents))
Expand Down Expand Up @@ -209,10 +222,15 @@ def _prompt_lottery(self, lang: str) -> int:
MAX_RETRIES = 5
retry = 0
while True:
num_active_trees = self.query_num_active_trees(lang=lang, exclude_ranking=True)
num_missing_prompts = self.cfg.max_active_trees - num_active_trees
if num_missing_prompts <= 0:
return 0

stats = self.tree_counts_by_state_stats(lang=lang, only_active=True)

remaining_prompt_review = max(0, self.cfg.max_initial_prompt_review - stats.initial_prompt_review)
num_missing_growing = max(0, self.cfg.max_active_trees - stats.growing)
logger.debug(f"_prompt_lottery {remaining_prompt_review=}, {num_missing_growing=}")

if num_missing_growing == 0:
return remaining_prompt_review

# select among distinct users
authors_qry = (
Expand All @@ -231,9 +249,9 @@ def _prompt_lottery(self, lang: str) -> int:
author_ids = authors_qry.all()
if len(author_ids) == 0:
logger.info(
f"No prompts for prompt lottery available ({num_missing_prompts} trees missing for {lang=})."
f"No prompts for prompt lottery available ({num_missing_growing=}, trees missing for {lang=})."
)
return num_missing_prompts
return num_missing_growing + remaining_prompt_review

# first select an author
prompt_author_id: UUID = random.choice(author_ids)["user_id"]
Expand Down Expand Up @@ -261,7 +279,7 @@ def _prompt_lottery(self, lang: str) -> int:
continue
else:
logger.warning("Max retries in prompt lottery reached.")
return num_missing_prompts
return num_missing_growing + remaining_prompt_review

winner_prompt = random.choice(prompt_candidates)
message: Message = winner_prompt.Message
Expand Down Expand Up @@ -1268,15 +1286,32 @@ def ensure_tree_states(self) -> None:
for t in ranking_trees:
self.check_condition_for_scoring_state(t.message_tree_id)

def query_num_active_trees(self, lang: str, exclude_ranking: bool = True) -> int:
"""Count all active trees (optionally exclude those in ranking state)."""
def query_num_growing_trees(self, lang: str) -> int:
    """Return the number of active message trees in the GROWING state for *lang*."""
    # join the tree state to its message row (message_tree_id == Message.id)
    # so we can filter by the message language
    qry = self.db.query(func.count(MessageTreeState.message_tree_id)).join(
        Message, MessageTreeState.message_tree_id == Message.id
    )
    qry = qry.filter(MessageTreeState.active)
    qry = qry.filter(MessageTreeState.state == message_tree_state.State.GROWING)
    qry = qry.filter(Message.lang == lang)
    return qry.scalar()

def query_num_active_trees(
    self, lang: str, exclude_ranking: bool = True, exclude_prompt_review: bool = True
) -> int:
    """Count active message trees for *lang*.

    Trees in the RANKING or INITIAL_PROMPT_REVIEW states can be excluded
    from the count via the corresponding keyword flags (both on by default).
    """
    conditions = [MessageTreeState.active, Message.lang == lang]
    if exclude_ranking:
        conditions.append(MessageTreeState.state != message_tree_state.State.RANKING)
    if exclude_prompt_review:
        conditions.append(MessageTreeState.state != message_tree_state.State.INITIAL_PROMPT_REVIEW)
    qry = (
        self.db.query(func.count(MessageTreeState.message_tree_id))
        .join(Message, MessageTreeState.message_tree_id == Message.id)
        .filter(*conditions)
    )
    return qry.scalar()

def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]:
Expand Down Expand Up @@ -1363,12 +1398,37 @@ def _insert_default_state(
active=True,
)

def tree_counts_by_state(self) -> dict[str, int]:
qry = self.db.query(
MessageTreeState.state, func.count(MessageTreeState.message_tree_id).label("count")
).group_by(MessageTreeState.state)
def tree_counts_by_state(self, lang: str = None, only_active: bool = False) -> dict[str, int]:
    """Return a mapping of tree state -> number of message trees in that state.

    Optionally restricts the counts to trees whose message has language
    *lang* and/or to active trees only. States with no trees are absent
    from the result.
    """
    qry = self.db.query(MessageTreeState.state, func.count(MessageTreeState.message_tree_id).label("count"))
    if lang is not None:
        # join to the Message row (message_tree_id == Message.id) to filter by language
        qry = (
            qry.select_from(MessageTreeState)
            .join(Message, MessageTreeState.message_tree_id == Message.id)
            .filter(Message.lang == lang)
        )
    if only_active:
        qry = qry.filter(MessageTreeState.active)
    rows = qry.group_by(MessageTreeState.state)
    return {row["state"]: row["count"] for row in rows}

def tree_counts_by_state_stats(self, lang: str = None, only_active: bool = False) -> TreeStateStats:
    """Return a TreeStateStats snapshot of per-state tree counts.

    Wraps tree_counts_by_state(); states with no trees are reported as 0.
    Optionally restricts counts by message language and/or to active trees.
    """
    count_by_state = self.tree_counts_by_state(lang=lang, only_active=only_active)
    # hoist the repeated attribute lookup; use dict.get default instead of
    # the `.get(k) or 0` idiom (values are ints, so the two are equivalent)
    state = message_tree_state.State
    return TreeStateStats(
        initial_prompt_review=count_by_state.get(state.INITIAL_PROMPT_REVIEW, 0),
        growing=count_by_state.get(state.GROWING, 0),
        ranking=count_by_state.get(state.RANKING, 0),
        ready_for_scoring=count_by_state.get(state.READY_FOR_SCORING, 0),
        ready_for_export=count_by_state.get(state.READY_FOR_EXPORT, 0),
        scoring_failed=count_by_state.get(state.SCORING_FAILED, 0),
        halted_by_moderator=count_by_state.get(state.HALTED_BY_MODERATOR, 0),
        backlog_ranking=count_by_state.get(state.BACKLOG_RANKING, 0),
        prompt_lottery_waiting=count_by_state.get(state.PROMPT_LOTTERY_WAITING, 0),
        aborted_low_grade=count_by_state.get(state.ABORTED_LOW_GRADE, 0),
    )

def tree_message_count_stats(self, only_active: bool = True) -> list[TreeMessageCountStats]:
qry = (
self.db.query(
Expand Down
2 changes: 1 addition & 1 deletion website/public/locales/en/tasks.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
},
"create_initial_prompt": {
"label": "Create Initial Prompts",
"desc": "Write initial prompts to help Open Assistant to try replying to diverse messages.",
"desc": "Write initial prompts to help Open Assistant try replying to diverse messages. (entered into a lottery)",
"overview": "Create an initial message to send to the assistant",
"instruction": "Provide the initial prompts",
"response_placeholder": "Write your prompt here..."
Expand Down

0 comments on commit 100ecbe

Please sign in to comment.