Skip to content

Commit c0bfaa6

Browse files
committed
fix unit test failed
Signed-off-by: ruit <[email protected]>
1 parent 9078e33 commit c0bfaa6

File tree

6 files changed

+53
-5
lines changed

6 files changed

+53
-5
lines changed

nemo_rl/data/datasets/preference_datasets/helpsteer3.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ class HelpSteer3Dataset:
5858

5959
def __init__(self) -> None:
6060
ds = load_dataset("nvidia/HelpSteer3", "preference")
61-
self.task_name = "helpsteer3"
61+
self.task_name = "HelpSteer3"
6262
self.formatted_ds = ds.map(to_preference_data_format)
6363
self.formatted_ds = self.formatted_ds.map(
6464
lambda _: {"task_name": self.task_name}

nemo_rl/data/datasets/response_datasets/helpsteer3.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ class HelpSteer3Dataset:
5353

5454
def __init__(self) -> None:
5555
ds = load_dataset("nvidia/HelpSteer3", "preference")
56-
self.task_name = "helpsteer3"
56+
self.task_name = "HelpSteer3"
5757
self.formatted_ds = ds.map(to_response_data_format)
5858
self.formatted_ds = self.formatted_ds.map(
5959
lambda _: {"task_name": self.task_name}

tests/functional/grpo.sh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ mkdir -p $EXP_DIR $LOG_DIR
1919

2020
cd $PROJECT_ROOT
2121
uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
22-
$PROJECT_ROOT/examples/run_grpo_math.py \
22+
$PROJECT_ROOT/examples/run_grpo.py \
2323
policy.model_name=Qwen/Qwen3-0.6B \
2424
grpo.num_prompts_per_step=2 \
2525
grpo.num_generations_per_prompt=4 \

tests/functional/grpo_math_env.sh

Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,42 @@
1+
#!/bin/bash
2+
3+
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
4+
PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
5+
# Mark the current repo as safe, since wandb fetches metadata about the repo
6+
git config --global --add safe.directory $PROJECT_ROOT
7+
8+
set -eou pipefail
9+
10+
EXP_NAME=$(basename $0 .sh)
11+
EXP_DIR=$SCRIPT_DIR/$EXP_NAME
12+
LOG_DIR=$EXP_DIR/logs
13+
JSON_METRICS=$EXP_DIR/metrics.json
14+
RUN_LOG=$EXP_DIR/run.log
15+
export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}
16+
17+
rm -rf $EXP_DIR $LOG_DIR
18+
mkdir -p $EXP_DIR $LOG_DIR
19+
20+
cd $PROJECT_ROOT
21+
uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
22+
$PROJECT_ROOT/examples/run_grpo_math.py \
23+
policy.model_name=Qwen/Qwen3-0.6B \
24+
grpo.num_prompts_per_step=2 \
25+
grpo.num_generations_per_prompt=4 \
26+
policy.train_global_batch_size=4 \
27+
policy.train_micro_batch_size=1 \
28+
cluster.gpus_per_node=2 \
29+
grpo.max_num_steps=2 \
30+
logger.tensorboard_enabled=true \
31+
logger.log_dir=$LOG_DIR \
32+
logger.wandb_enabled=false \
33+
logger.monitor_gpus=true \
34+
checkpointing.enabled=false \
35+
$@ \
36+
2>&1 | tee $RUN_LOG
37+
38+
uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
39+
40+
uv run tests/check_metrics.py $JSON_METRICS \
41+
'max(data["train/token_mult_prob_error"]) < 1.05'
42+

tests/unit/data/test_data_shuffle_reproducity.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -43,21 +43,24 @@ def create_dataloader(
4343
"""Create a dataloader with consistent configuration for testing."""
4444
# Initialize dataset
4545
data = OpenMathInstruct2Dataset(seed=seed)
46+
task_name = (
47+
data.task_name if hasattr(data, "task_name") else data.task_spec.task_name
48+
)
4649

4750
# Setup tokenizer
4851
tokenizer = get_tokenizer(TOKENIZER_CONFIG)
4952

5053
# Configure task specification
5154
math_task_spec = TaskDataSpec(
52-
task_name="math",
55+
task_name=task_name,
5356
prompt_file=f"{os.path.dirname(os.path.abspath(__file__))}/../../../examples/prompts/cot.txt",
5457
system_prompt_file=None,
5558
)
5659

5760
task_data_processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = (
5861
defaultdict(lambda: (math_task_spec, math_hf_data_processor))
5962
)
60-
task_data_processors["math"] = (math_task_spec, math_hf_data_processor)
63+
task_data_processors[task_name] = (math_task_spec, math_hf_data_processor)
6164

6265
dataset = AllTaskProcessedDataset(
6366
dataset=data.formatted_ds["train"].select(range(1000)),

tests/unit/test_config_validation.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,9 @@
3535
if not OmegaConf.has_resolver("mul"):
3636
OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
3737

38+
if not OmegaConf.has_resolver("max"):
39+
OmegaConf.register_new_resolver("max", lambda a, b: max(a, b))
40+
3841

3942
def validate_config_section(
4043
section_config: Dict[str, Any],

0 commit comments

Comments
 (0)