Skip to content

Commit af0711e

Browse files
committed
include initial prompt review in user stats
1 parent f547aa0 commit af0711e

File tree

1 file changed

+25
-5
lines changed

1 file changed

+25
-5
lines changed

backend/oasst_backend/user_stats_repository.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,20 @@
1818
)
1919
from oasst_backend.models.db_payload import (
2020
LabelAssistantReplyPayload,
21+
LabelInitialPromptPayload,
2122
LabelPrompterReplyPayload,
2223
RankingReactionPayload,
2324
)
2425
from oasst_backend.models.message_tree_state import State as TreeState
25-
from oasst_shared.schemas.protocol import EmojiCode, LeaderboardStats, TextLabel, TrollboardStats, TrollScore, UserScore
26+
from oasst_shared.schemas.protocol import (
27+
EmojiCode,
28+
LabelTaskMode,
29+
LeaderboardStats,
30+
TextLabel,
31+
TrollboardStats,
32+
TrollScore,
33+
UserScore,
34+
)
2635
from oasst_shared.utils import log_timing, utcnow
2736
from sqlalchemy.dialects import postgresql
2837
from sqlalchemy.sql.functions import coalesce
@@ -310,9 +319,9 @@ def get_stats(id: UUID) -> UserStats:
310319
for r in qry:
311320
uid, mode, count = r
312321
s = get_stats(uid)
313-
if mode == "simple":
322+
if mode == LabelTaskMode.simple:
314323
s.labels_simple = count
315-
elif mode == "full":
324+
elif mode == LabelTaskMode.full:
316325
s.labels_full = count
317326

318327
qry = self.query_labels_by_mode_per_user(
@@ -321,9 +330,20 @@ def get_stats(id: UUID) -> UserStats:
321330
for r in qry:
322331
uid, mode, count = r
323332
s = get_stats(uid)
324-
if mode == "simple":
333+
if mode == LabelTaskMode.simple:
334+
s.labels_simple += count
335+
elif mode == LabelTaskMode.full:
336+
s.labels_full += count
337+
338+
qry = self.query_labels_by_mode_per_user(
339+
payload_type=LabelInitialPromptPayload.__name__, reference_time=base_date
340+
)
341+
for r in qry:
342+
uid, mode, count = r
343+
s = get_stats(uid)
344+
if mode == LabelTaskMode.simple:
325345
s.labels_simple += count
326-
elif mode == "full":
346+
elif mode == LabelTaskMode.full:
327347
s.labels_full += count
328348

329349
qry = self.query_rankings_per_user(reference_time=base_date)

0 commit comments

Comments
 (0)