26
26
)
27
27
from oasst_backend .prompt_repository import PromptRepository
28
28
from oasst_backend .utils import tree_export
29
- from oasst_backend .utils .database_utils import CommitMode , async_managed_tx_method , managed_tx_method
29
+ from oasst_backend .utils .database_utils import (
30
+ CommitMode ,
31
+ async_managed_tx_method ,
32
+ managed_tx_function ,
33
+ managed_tx_method ,
34
+ )
30
35
from oasst_backend .utils .hugging_face import HfClassificationModel , HfEmbeddingModel , HfUrl , HuggingFaceAPI
31
36
from oasst_backend .utils .ranking import ranked_pairs
32
37
from oasst_shared .exceptions .oasst_api_error import OasstError , OasstErrorCode
@@ -218,17 +223,12 @@ def _determine_task_availability_internal(
218
223
219
224
return task_count_by_type
220
225
221
- def _prompt_lottery (self , lang : str ) -> int :
222
- MAX_RETRIES = 5
223
-
226
+ def _prompt_lottery (self , lang : str , max_activate : int = 1 ) -> int :
224
227
# Under high load the DB runs into deadlocks when many trees are released
225
228
# simultaneously (happens whens the max_active_trees setting is increased).
226
229
# To reduce the chance of write conflicts during updates of rows in the
227
230
# message_tree_state table we limit the number of trees that are activated
228
- # per _prompt_lottery() call to MAX_ACTIVATE.
229
- MAX_ACTIVATE = 2
230
-
231
- retry = 0
231
+ # per _prompt_lottery() call to max_activate.
232
232
activated = 0
233
233
234
234
while True :
@@ -237,67 +237,76 @@ def _prompt_lottery(self, lang: str) -> int:
237
237
238
238
remaining_prompt_review = max (0 , self .cfg .max_initial_prompt_review - stats .initial_prompt_review )
239
239
num_missing_growing = max (0 , self .cfg .max_active_trees - stats .growing )
240
- logger .debug (f"_prompt_lottery { remaining_prompt_review = } , { num_missing_growing = } " )
240
+ logger .info (f"_prompt_lottery { remaining_prompt_review = } , { num_missing_growing = } " )
241
241
242
- if num_missing_growing == 0 or activated >= MAX_ACTIVATE :
242
+ if num_missing_growing == 0 or activated >= max_activate :
243
243
return num_missing_growing + remaining_prompt_review
244
244
245
- # select among distinct users
246
- authors_qry = (
247
- self .db .query (Message .user_id )
248
- .select_from (MessageTreeState )
249
- .join (Message , MessageTreeState .message_tree_id == Message .id )
250
- .filter (
251
- MessageTreeState .state == message_tree_state .State .PROMPT_LOTTERY_WAITING ,
252
- Message .lang == lang ,
253
- not_ (Message .deleted ),
254
- Message .review_result ,
245
+ @managed_tx_function (CommitMode .COMMIT )
246
+ def activate_one (db : Session ) -> int :
247
+ # select among distinct users
248
+ authors_qry = (
249
+ db .query (Message .user_id )
250
+ .select_from (MessageTreeState )
251
+ .join (Message , MessageTreeState .message_tree_id == Message .id )
252
+ .filter (
253
+ MessageTreeState .state == message_tree_state .State .PROMPT_LOTTERY_WAITING ,
254
+ Message .lang == lang ,
255
+ not_ (Message .deleted ),
256
+ Message .review_result ,
257
+ )
258
+ .distinct (Message .user_id )
255
259
)
256
- .distinct (Message .user_id )
257
- )
258
260
259
- author_ids = authors_qry .all ()
260
- if len (author_ids ) == 0 :
261
- logger .info (
262
- f"No prompts for prompt lottery available ({ num_missing_growing = } , trees missing for { lang = } )."
261
+ author_ids = authors_qry .all ()
262
+ if len (author_ids ) == 0 :
263
+ logger .info (
264
+ f"No prompts for prompt lottery available ({ num_missing_growing = } , trees missing for { lang = } )."
265
+ )
266
+ return False
267
+
268
+ # first select an authour
269
+ prompt_author_id : UUID = random .choice (author_ids )["user_id" ]
270
+ logger .info (f"Selected random prompt author { prompt_author_id } among { len (author_ids )} candidates." )
271
+
272
+ # select random prompt of author
273
+ qry = (
274
+ db .query (MessageTreeState , Message )
275
+ .select_from (MessageTreeState )
276
+ .join (Message , MessageTreeState .message_tree_id == Message .id )
277
+ .filter (
278
+ MessageTreeState .state == message_tree_state .State .PROMPT_LOTTERY_WAITING ,
279
+ Message .user_id == prompt_author_id ,
280
+ Message .lang == lang ,
281
+ not_ (Message .deleted ),
282
+ Message .review_result ,
283
+ )
284
+ .limit (100 )
263
285
)
264
- return num_missing_growing + remaining_prompt_review
265
286
266
- # first select an authour
267
- prompt_author_id : UUID = random .choice (author_ids )["user_id" ]
268
- logger .info (f"Selected random prompt author { prompt_author_id } among { len (author_ids )} candidates." )
287
+ prompt_candidates = qry .all ()
288
+ if len (prompt_candidates ) == 0 :
289
+ logger .warning ("No prompt candidates of selected author found." )
290
+ return False
269
291
270
- # select random prompt of author
271
- qry = (
272
- self .db .query (MessageTreeState , Message )
273
- .select_from (MessageTreeState )
274
- .join (Message , MessageTreeState .message_tree_id == Message .id )
275
- .filter (
276
- MessageTreeState .state == message_tree_state .State .PROMPT_LOTTERY_WAITING ,
277
- Message .user_id == prompt_author_id ,
278
- Message .lang == lang ,
279
- not_ (Message .deleted ),
280
- Message .review_result ,
281
- )
282
- .limit (100 )
283
- )
292
+ winner_prompt = random .choice (prompt_candidates )
293
+ message : Message = winner_prompt .Message
294
+ logger .info (f"Prompt lottery winner: { message .id = } " )
284
295
285
- prompt_candidates = qry .all ()
286
- if len (prompt_candidates ) == 0 :
287
- retry += 1 # not sure if this can happen with repeatable read isolation level, just in case we retry
288
- if retry < MAX_RETRIES :
289
- continue
290
- else :
291
- logger .warning ("Max retries in prompt lottery reached." )
292
- return num_missing_growing + remaining_prompt_review
296
+ mts : MessageTreeState = winner_prompt .MessageTreeState
297
+ mts .state = message_tree_state .State .GROWING
298
+ mts .active = True
299
+ db .add (mts )
293
300
294
- winner_prompt = random .choice (prompt_candidates )
295
- message : Message = winner_prompt .Message
296
- logger .info (f"Prompt lottery winner: { message .id = } " )
301
+ if mts .won_prompt_lottery_date is None :
302
+ mts .won_prompt_lottery_date = utcnow ()
303
+ logger .info (f"Tree entered '{ mts .state } ' state ({ mts .message_tree_id = } )" )
304
+
305
+ return True
306
+
307
+ if not activate_one ():
308
+ return num_missing_growing + remaining_prompt_review
297
309
298
- mts : MessageTreeState = winner_prompt .MessageTreeState
299
- self ._enter_state (mts , message_tree_state .State .GROWING )
300
- self .db .flush ()
301
310
activated += 1
302
311
303
312
def _auto_moderation (self , lang : str ) -> None :
@@ -333,7 +342,7 @@ def determine_task_availability(self, lang: str) -> dict[protocol_schema.TaskReq
333
342
logger .warning ("Task availability request without lang tag received, assuming lang='en'." )
334
343
335
344
self ._auto_moderation (lang = lang )
336
- num_missing_prompts = self ._prompt_lottery (lang = lang )
345
+ num_missing_prompts = self ._prompt_lottery (lang = lang , max_activate = 1 )
337
346
extendible_parents , _ = self .query_extendible_parents (lang = lang )
338
347
prompts_need_review = self .query_prompts_need_review (lang = lang )
339
348
replies_need_review = self .query_replies_need_review (lang = lang )
@@ -371,7 +380,7 @@ def next_task(
371
380
logger .warning ("Task request without lang tag received, assuming 'en'." )
372
381
373
382
self ._auto_moderation (lang = lang )
374
- num_missing_prompts = self ._prompt_lottery (lang = lang )
383
+ num_missing_prompts = self ._prompt_lottery (lang = lang , max_activate = 2 )
375
384
376
385
prompts_need_review = self .query_prompts_need_review (lang = lang )
377
386
replies_need_review = self .query_replies_need_review (lang = lang )
0 commit comments