File tree Expand file tree Collapse file tree 2 files changed +16
-17
lines changed Expand file tree Collapse file tree 2 files changed +16
-17
lines changed Original file line number Diff line number Diff line change @@ -31,11 +31,10 @@ uv run examples/run_grpo_math.py \
3131# Convert tensorboard logs to json
3232uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
3333
34- # Only run metrics if the target step is reached
35- # TODO(ahmadki): set metrics
36- # if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
37- # uv run tests/check_metrics.py $JSON_METRICS \
38- # 'data["train/token_mult_prob_error"]["30"] < 1.1' \
39- # 'data["train/reward"]["30"] > 0.43' \
40- # 'mean(data["timing/train/total_step_time"], -6, -1) < 220'
41- # fi
34+ Only run metrics if the target step is reached
35+ if [[ $( jq ' to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS ) -ge $MAX_STEPS ]]; then
36+ uv run tests/check_metrics.py $JSON_METRICS \
37+ ' data["train/token_mult_prob_error"]["30"] < 1.1' \
38+ ' data["train/reward"]["30"] > 0.43' \
39+ ' mean(data["timing/train/total_step_time"], -6, -1) < 305'
40+ fi
Original file line number Diff line number Diff line change @@ -4,10 +4,10 @@ source $SCRIPT_DIR/common.env
44
55# ===== BEGIN CONFIG =====
66NUM_NODES=8
7- STEPS_PER_RUN=80
8- MAX_STEPS=80
7+ STEPS_PER_RUN=200
8+ MAX_STEPS=200
99NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1 ) / STEPS_PER_RUN )) # Round up
10- NUM_MINUTES=30
10+ NUM_MINUTES=120
1111# ===== END CONFIG =====
1212
1313exit_if_max_steps_reached
@@ -32,9 +32,9 @@ uv run examples/run_sft.py \
3232uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
3333
3434# Only run metrics if the target step is reached
35- # # TODO(ahmadki): set metrics
36- # if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
37- # uv run tests/check_metrics.py $JSON_METRICS \
38- # 'data["train/loss "]["80 "] < 0.301 ' \
39- # ' data["validation/val_loss"]["80"] < 0.304 '
40- # fi
35+ if [[ $( jq ' to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max ' $JSON_METRICS ) -ge $MAX_STEPS ]] ; then
36+ uv run tests/check_metrics.py $JSON_METRICS \
37+ ' data["train/loss"]["200"] < 0.3 ' \
38+ ' data["validation/val_loss "]["200 "] < 0.3 ' \
39+ ' mean( data["timing/train/total_step_time"], -6, -1) < 20 '
40+ fi
You can’t perform that action at this time.
0 commit comments