Skip to content

Commit 263edba

Browse files
committed
move prompt lottery tree activation into separate transaction scope
1 parent ae8b9ea commit 263edba

File tree

2 files changed

+70
-61
lines changed

2 files changed

+70
-61
lines changed

backend/oasst_backend/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class TreeManagerConfiguration(BaseModel):
4747
"""Automatically set tree state to `halted_by_moderator` when more than the specified number
4848
of users skip replying to a message. (auto moderation)"""
4949

50-
auto_mod_red_flags: int = 3
50+
auto_mod_red_flags: int = 4
5151
"""Delete messages that receive more than this number of red flags if it is a reply or
5252
set the tree to `aborted_low_grade` when a prompt is flagged. (auto moderation)"""
5353

backend/oasst_backend/tree_manager.py

Lines changed: 69 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@
2626
)
2727
from oasst_backend.prompt_repository import PromptRepository
2828
from oasst_backend.utils import tree_export
29-
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
29+
from oasst_backend.utils.database_utils import (
30+
CommitMode,
31+
async_managed_tx_method,
32+
managed_tx_function,
33+
managed_tx_method,
34+
)
3035
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
3136
from oasst_backend.utils.ranking import ranked_pairs
3237
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
@@ -218,17 +223,12 @@ def _determine_task_availability_internal(
218223

219224
return task_count_by_type
220225

221-
def _prompt_lottery(self, lang: str) -> int:
222-
MAX_RETRIES = 5
223-
226+
def _prompt_lottery(self, lang: str, max_activate: int = 1) -> int:
224227
# Under high load the DB runs into deadlocks when many trees are released
225228
# simultaneously (happens whens the max_active_trees setting is increased).
226229
# To reduce the chance of write conflicts during updates of rows in the
227230
# message_tree_state table we limit the number of trees that are activated
228-
# per _prompt_lottery() call to MAX_ACTIVATE.
229-
MAX_ACTIVATE = 2
230-
231-
retry = 0
231+
# per _prompt_lottery() call to max_activate.
232232
activated = 0
233233

234234
while True:
@@ -237,67 +237,76 @@ def _prompt_lottery(self, lang: str) -> int:
237237

238238
remaining_prompt_review = max(0, self.cfg.max_initial_prompt_review - stats.initial_prompt_review)
239239
num_missing_growing = max(0, self.cfg.max_active_trees - stats.growing)
240-
logger.debug(f"_prompt_lottery {remaining_prompt_review=}, {num_missing_growing=}")
240+
logger.info(f"_prompt_lottery {remaining_prompt_review=}, {num_missing_growing=}")
241241

242-
if num_missing_growing == 0 or activated >= MAX_ACTIVATE:
242+
if num_missing_growing == 0 or activated >= max_activate:
243243
return num_missing_growing + remaining_prompt_review
244244

245-
# select among distinct users
246-
authors_qry = (
247-
self.db.query(Message.user_id)
248-
.select_from(MessageTreeState)
249-
.join(Message, MessageTreeState.message_tree_id == Message.id)
250-
.filter(
251-
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
252-
Message.lang == lang,
253-
not_(Message.deleted),
254-
Message.review_result,
245+
@managed_tx_function(CommitMode.COMMIT)
246+
def activate_one(db: Session) -> int:
247+
# select among distinct users
248+
authors_qry = (
249+
db.query(Message.user_id)
250+
.select_from(MessageTreeState)
251+
.join(Message, MessageTreeState.message_tree_id == Message.id)
252+
.filter(
253+
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
254+
Message.lang == lang,
255+
not_(Message.deleted),
256+
Message.review_result,
257+
)
258+
.distinct(Message.user_id)
255259
)
256-
.distinct(Message.user_id)
257-
)
258260

259-
author_ids = authors_qry.all()
260-
if len(author_ids) == 0:
261-
logger.info(
262-
f"No prompts for prompt lottery available ({num_missing_growing=}, trees missing for {lang=})."
261+
author_ids = authors_qry.all()
262+
if len(author_ids) == 0:
263+
logger.info(
264+
f"No prompts for prompt lottery available ({num_missing_growing=}, trees missing for {lang=})."
265+
)
266+
return False
267+
268+
# first select an authour
269+
prompt_author_id: UUID = random.choice(author_ids)["user_id"]
270+
logger.info(f"Selected random prompt author {prompt_author_id} among {len(author_ids)} candidates.")
271+
272+
# select random prompt of author
273+
qry = (
274+
db.query(MessageTreeState, Message)
275+
.select_from(MessageTreeState)
276+
.join(Message, MessageTreeState.message_tree_id == Message.id)
277+
.filter(
278+
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
279+
Message.user_id == prompt_author_id,
280+
Message.lang == lang,
281+
not_(Message.deleted),
282+
Message.review_result,
283+
)
284+
.limit(100)
263285
)
264-
return num_missing_growing + remaining_prompt_review
265286

266-
# first select an authour
267-
prompt_author_id: UUID = random.choice(author_ids)["user_id"]
268-
logger.info(f"Selected random prompt author {prompt_author_id} among {len(author_ids)} candidates.")
287+
prompt_candidates = qry.all()
288+
if len(prompt_candidates) == 0:
289+
logger.warning("No prompt candidates of selected author found.")
290+
return False
269291

270-
# select random prompt of author
271-
qry = (
272-
self.db.query(MessageTreeState, Message)
273-
.select_from(MessageTreeState)
274-
.join(Message, MessageTreeState.message_tree_id == Message.id)
275-
.filter(
276-
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
277-
Message.user_id == prompt_author_id,
278-
Message.lang == lang,
279-
not_(Message.deleted),
280-
Message.review_result,
281-
)
282-
.limit(100)
283-
)
292+
winner_prompt = random.choice(prompt_candidates)
293+
message: Message = winner_prompt.Message
294+
logger.info(f"Prompt lottery winner: {message.id=}")
284295

285-
prompt_candidates = qry.all()
286-
if len(prompt_candidates) == 0:
287-
retry += 1 # not sure if this can happen with repeatable read isolation level, just in case we retry
288-
if retry < MAX_RETRIES:
289-
continue
290-
else:
291-
logger.warning("Max retries in prompt lottery reached.")
292-
return num_missing_growing + remaining_prompt_review
296+
mts: MessageTreeState = winner_prompt.MessageTreeState
297+
mts.state = message_tree_state.State.GROWING
298+
mts.active = True
299+
db.add(mts)
293300

294-
winner_prompt = random.choice(prompt_candidates)
295-
message: Message = winner_prompt.Message
296-
logger.info(f"Prompt lottery winner: {message.id=}")
301+
if mts.won_prompt_lottery_date is None:
302+
mts.won_prompt_lottery_date = utcnow()
303+
logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})")
304+
305+
return True
306+
307+
if not activate_one():
308+
return num_missing_growing + remaining_prompt_review
297309

298-
mts: MessageTreeState = winner_prompt.MessageTreeState
299-
self._enter_state(mts, message_tree_state.State.GROWING)
300-
self.db.flush()
301310
activated += 1
302311

303312
def _auto_moderation(self, lang: str) -> None:
@@ -333,7 +342,7 @@ def determine_task_availability(self, lang: str) -> dict[protocol_schema.TaskReq
333342
logger.warning("Task availability request without lang tag received, assuming lang='en'.")
334343

335344
self._auto_moderation(lang=lang)
336-
num_missing_prompts = self._prompt_lottery(lang=lang)
345+
num_missing_prompts = self._prompt_lottery(lang=lang, max_activate=1)
337346
extendible_parents, _ = self.query_extendible_parents(lang=lang)
338347
prompts_need_review = self.query_prompts_need_review(lang=lang)
339348
replies_need_review = self.query_replies_need_review(lang=lang)
@@ -371,7 +380,7 @@ def next_task(
371380
logger.warning("Task request without lang tag received, assuming 'en'.")
372381

373382
self._auto_moderation(lang=lang)
374-
num_missing_prompts = self._prompt_lottery(lang=lang)
383+
num_missing_prompts = self._prompt_lottery(lang=lang, max_activate=2)
375384

376385
prompts_need_review = self.query_prompts_need_review(lang=lang)
377386
replies_need_review = self.query_replies_need_review(lang=lang)

0 commit comments

Comments
 (0)