18
18
)
19
19
from oasst_backend .models .db_payload import (
20
20
LabelAssistantReplyPayload ,
21
+ LabelInitialPromptPayload ,
21
22
LabelPrompterReplyPayload ,
22
23
RankingReactionPayload ,
23
24
)
24
25
from oasst_backend .models .message_tree_state import State as TreeState
25
- from oasst_shared .schemas .protocol import EmojiCode , LeaderboardStats , TextLabel , TrollboardStats , TrollScore , UserScore
26
+ from oasst_shared .schemas .protocol import (
27
+ EmojiCode ,
28
+ LabelTaskMode ,
29
+ LeaderboardStats ,
30
+ TextLabel ,
31
+ TrollboardStats ,
32
+ TrollScore ,
33
+ UserScore ,
34
+ )
26
35
from oasst_shared .utils import log_timing , utcnow
27
36
from sqlalchemy .dialects import postgresql
28
37
from sqlalchemy .sql .functions import coalesce
@@ -310,9 +319,9 @@ def get_stats(id: UUID) -> UserStats:
310
319
for r in qry :
311
320
uid , mode , count = r
312
321
s = get_stats (uid )
313
- if mode == " simple" :
322
+ if mode == LabelTaskMode . simple :
314
323
s .labels_simple = count
315
- elif mode == " full" :
324
+ elif mode == LabelTaskMode . full :
316
325
s .labels_full = count
317
326
318
327
qry = self .query_labels_by_mode_per_user (
@@ -321,9 +330,20 @@ def get_stats(id: UUID) -> UserStats:
321
330
for r in qry :
322
331
uid , mode , count = r
323
332
s = get_stats (uid )
324
- if mode == "simple" :
333
+ if mode == LabelTaskMode .simple :
334
+ s .labels_simple += count
335
+ elif mode == LabelTaskMode .full :
336
+ s .labels_full += count
337
+
338
+ qry = self .query_labels_by_mode_per_user (
339
+ payload_type = LabelInitialPromptPayload .__name__ , reference_time = base_date
340
+ )
341
+ for r in qry :
342
+ uid , mode , count = r
343
+ s = get_stats (uid )
344
+ if mode == LabelTaskMode .simple :
325
345
s .labels_simple += count
326
- elif mode == " full" :
346
+ elif mode == LabelTaskMode . full :
327
347
s .labels_full += count
328
348
329
349
qry = self .query_rankings_per_user (reference_time = base_date )
0 commit comments