
Commit b7fedb9

unify run_grpo with multiple env

Signed-off-by: ruit <[email protected]>

1 parent 249788d

File tree

7 files changed (+84, -170 lines)


examples/configs/grpo_helpsteer3.yaml

Lines changed: 0 additions & 106 deletions
This file was deleted.

examples/configs/recipes/llm/sft-nemotron-super-49b-tulu-v3.yaml

Lines changed: 24 additions & 3 deletions
@@ -1,19 +1,36 @@
-defaults:
-  - ../../sft.yaml
-  - ../../sft_nemotron_super_49b_base.yaml
+defaults: ../../sft.yaml
sft:
  max_num_steps: 50
  val_period: 5
+  val_global_batch_size: 128
+
+checkpointing:
+  checkpoint_dir: results/sft_nemotron_super_49b
+  metric_name: val_loss
+  keep_top_k: 100
+  save_period: 500
+
policy:
  model_name: /lustre/fsw/portfolios/coreai/users/joyang/models/llama-3_3-nemotron-49b-instruct-128k-v1_2-hf
  max_total_sequence_length: 32768
+  train_global_batch_size: 128
  dtensor_cfg:
    _v2: true
    activation_checkpointing: true
    context_parallel_size: 8
+    tensor_parallel_size: 4
+    custom_parallel_plan: examples.custom_parallel.llama_nemotron_super_49b_custom_plan.custom_parallel_plan
+  dynamic_batching:
+    train_mb_tokens: 4096
+    logprob_mb_tokens: 8192
+  make_sequence_length_divisible_by: ${max:${mul:${policy.dtensor_cfg.context_parallel_size},
+    2}, ${policy.max_total_sequence_length}}
+  max_grad_norm: null
  optimizer:
    kwargs:
      lr: 1.0e-05
+      weight_decay: 0.01
+      eps: 1.0e-08
  scheduler:
    - name: torch.optim.lr_scheduler.LinearLR
      kwargs:
@@ -28,8 +45,12 @@ policy:
      - 10
data:
  dataset_name: tulu3_sft_mixture
+  num_workers: 20
  test_size: 0.05
logger:
+  tensorboard_enabled: false
+  monitor_gpus: false
+  num_val_samples_to_print: 0
  wandb:
    project: nemotron-tulu-3-sft
    name: nemotron-tulu-3
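
The make_sequence_length_divisible_by entry above chains custom interpolations. Below is a minimal sketch of how such ${mul:...} and ${max:...} expressions could be evaluated with OmegaConf custom resolvers; registering the resolvers like this is an assumption for illustration, not NeMo RL's actual wiring.

```python
# Minimal sketch: resolving the ${max:${mul:...}} interpolation from the
# recipe with OmegaConf custom resolvers. Registering "mul" and "max" this
# way is assumed for illustration; NeMo RL may wire them up differently.
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
OmegaConf.register_new_resolver("max", lambda a, b: max(a, b))

cfg = OmegaConf.create(
    {
        "policy": {
            "max_total_sequence_length": 32768,
            "dtensor_cfg": {"context_parallel_size": 8},
            # max(context_parallel_size * 2, max_total_sequence_length)
            "make_sequence_length_divisible_by": (
                "${max:${mul:${policy.dtensor_cfg.context_parallel_size},2},"
                "${policy.max_total_sequence_length}}"
            ),
        }
    }
)
print(cfg.policy.make_sequence_length_divisible_by)  # max(8 * 2, 32768) -> 32768
```

For this recipe the expression resolves to 32768, since the sequence length dominates twice the context-parallel size.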
Lines changed: 33 additions & 3 deletions
@@ -1,10 +1,37 @@
-defaults:
-  - ../../sft.yaml
-  - ../../sft_nemotron_super_49b_base.yaml
+defaults: ../../sft.yaml
sft:
  max_num_epochs: 3
+  max_num_steps: 100
+  val_global_batch_size: 128
+checkpointing:
+  checkpoint_dir: results/sft-nemotron-super-49b
+  metric_name: val_loss
+  keep_top_k: 100
+  save_period: 500
policy:
+  model_name: /lustre/fsw/portfolios/coreai/users/joyang/models/llama-3_3-nemotron-49b-instruct-128k-v1_2-hf
+  max_total_sequence_length: 4096
+  train_global_batch_size: 128
  train_micro_batch_size: 8
+  dtensor_cfg:
+    _v2: true
+    activation_checkpointing: true
+    context_parallel_size: 2
+    tensor_parallel_size: 4
+    custom_parallel_plan: examples.custom_parallel.llama_nemotron_super_49b_custom_plan.custom_parallel_plan
+  dynamic_batching:
+    train_mb_tokens: 4096
+    logprob_mb_tokens: 8192
+  make_sequence_length_divisible_by: ${max:${mul:${policy.dtensor_cfg.context_parallel_size},
+    2}, ${policy.max_total_sequence_length}}
+  max_grad_norm: null
+  optimizer:
+    kwargs:
+      lr: 2.0e-05
+      weight_decay: 0.01
+      eps: 1.0e-08
+data:
+  num_workers: 20
logger:
  tensorboard_enabled: false
  monitor_gpus: false
@@ -15,4 +42,7 @@ logger:
  tensorboard:
    log_dir: tb_logs-openmathinstruct-nemorl-1M_train
  mlflow:
+    experiment_name: sft-nemotron-super-49b
    run_name: openmathinstruct-nemorl-1M_train
+cluster:
+  gpus_per_node: 8
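
Both recipes collapse their defaults from two parents to one, inlining what previously came from the now-deleted sft_nemotron_super_49b_base.yaml (see below). A rough sketch of the merge semantics this relies on, assuming OmegaConf-style composition where the recipe overrides its single remaining parent; the loader shown is illustrative, not NeMo RL's config machinery.

```python
# Illustrative only: how a recipe with a single `defaults` parent could be
# composed, assuming OmegaConf merge semantics (recipe values win key-by-key).
from omegaconf import OmegaConf

base = OmegaConf.create(    # stand-in for ../../sft.yaml
    {"sft": {"max_num_epochs": 1, "val_period": 10}}
)
recipe = OmegaConf.create(  # stand-in for the recipe file above
    {"sft": {"max_num_epochs": 3, "max_num_steps": 100}}
)

merged = OmegaConf.merge(base, recipe)  # later configs override earlier ones
assert merged.sft.max_num_epochs == 3   # overridden by the recipe
assert merged.sft.val_period == 10      # inherited from the parent
```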

examples/configs/sft_nemotron_super_49b_base.yaml

Lines changed: 0 additions & 47 deletions
This file was deleted.

examples/run_grpo_math.py

Lines changed: 5 additions & 2 deletions
@@ -82,12 +82,15 @@ def setup_data(

    # load dataset
    data: Any = load_response_dataset(data_config, seed)
+    task_name = (
+        data.task_name if hasattr(data, "task_name") else data.task_spec.task_name
+    )

    # data processor
    task_data_processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = (
        defaultdict(lambda: (math_task_spec, math_hf_data_processor))
    )
-    task_data_processors["math"] = (math_task_spec, math_hf_data_processor)
+    task_data_processors[task_name] = (math_task_spec, math_hf_data_processor)

    # setup math environment
    math_env = MathEnvironment.options(  # type: ignore # it's wrapped with ray.remote
@@ -120,7 +123,7 @@ def setup_data(
    val_dataset = None

    task_to_env: dict[str, EnvironmentInterface] = defaultdict(lambda: math_env)
-    task_to_env["math"] = math_env
+    task_to_env[task_name] = math_env
    return dataset, val_dataset, task_to_env, task_to_env
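This is the core of the commit: rather than hard-coding the "math" task, the dataset's own task_name keys both the processor table and the task-to-environment table, with the defaultdict factory as a catch-all for anything else. A self-contained sketch of that routing pattern follows; StubEnv and the task names are placeholders, not NeMo RL's real EnvironmentInterface.

```python
# Sketch of the task-routing pattern from the diff: an explicit entry for the
# dataset's task_name plus a defaultdict fallback for unseen tasks. StubEnv
# is a placeholder; the real code uses Ray-wrapped environment actors.
from collections import defaultdict


class StubEnv:
    """Placeholder environment that just echoes the batch."""

    def step(self, batch):
        return batch


math_env = StubEnv()

task_to_env = defaultdict(lambda: math_env)
task_to_env["helpsteer3"] = math_env  # example task_name, as the diff registers it

# An unseen task name resolves through the defaultdict factory, not KeyError:
assert task_to_env["anything_else"] is math_env
```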
nemo_rl/data/datasets/response_datasets/helpsteer3.py

Lines changed: 20 additions & 8 deletions
@@ -1,3 +1,16 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import Any

from absl import logging
@@ -7,11 +20,7 @@


# Choose the chosen response as the response and the rejected response as the target
-def to_response_data_format(
-    data: dict[str, Any],
-) -> dict[
-    str, list[dict[str, int | list[dict[str, str | Any]]]] | list[dict[str, str]]
-]:
+def to_response_data_format(data: dict[str, Any]) -> dict:
    response_1 = data["response1"]
    response_2 = data["response2"]
    overall_preference = data["overall_preference"]
@@ -28,10 +37,13 @@ def to_response_data_format(
    else:
        chosen = response_2

+    if isinstance(data["context"], str):
+        context = [{"role": "user", "content": data["context"]}]
+    else:
+        context = data["context"]
+
    return {
-        "context": [{"role": "user", "content": data["context"]}]
-        if isinstance(data["context"], str)
-        else data["context"],
+        "context": context,
        "response": [{"role": "assistant", "content": chosen}],
    }

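For reference, a hedged usage sketch of the simplified to_response_data_format. The field names mirror what the function reads in the diff; the sample values and the sign convention on overall_preference (negative favoring response1) are assumptions for illustration.

```python
# Hedged usage sketch for to_response_data_format. Field names follow the
# diff; the sample values and the assumption that a negative
# overall_preference selects response1 are illustrative only.
from nemo_rl.data.datasets.response_datasets.helpsteer3 import (
    to_response_data_format,
)

sample = {
    "context": "Summarize the commit in one line.",
    "response1": "Unifies run_grpo so one script serves multiple environments.",
    "response2": "Misc changes.",
    "overall_preference": -2,
}

out = to_response_data_format(sample)
# The string context is wrapped into a chat-style message list:
#   out["context"]  == [{"role": "user", "content": "Summarize the commit in one line."}]
# and out["response"] holds the chosen reply as a single assistant turn:
#   out["response"] == [{"role": "assistant", "content": <chosen response>}]
```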
pyrefly.toml

Lines changed: 2 additions & 1 deletion
@@ -64,6 +64,7 @@ project-includes = [
    "nemo_rl/data/datasets/response_datasets/oai_format_dataset.py",
    "nemo_rl/data/datasets/response_datasets/oasst.py",
    "nemo_rl/data/datasets/response_datasets/openmathinstruct2.py",
+    "nemo_rl/data/datasets/response_datasets/helpsteer3.py",
    "nemo_rl/data/datasets/response_datasets/refcoco.py",
    "nemo_rl/data/datasets/response_datasets/response_dataset.py",
    "nemo_rl/data/datasets/response_datasets/squad.py",
@@ -81,7 +82,7 @@ project-includes = [
    "nemo_rl/distributed/worker_group_utils.py",
    "nemo_rl/environments/__init__.py",
    "nemo_rl/environments/games/sliding_puzzle.py",
-    "nemo_rl/environments/helpsteer3_environment.py",
+    "nemo_rl/environments/code_jaccard_environment.py",
    "nemo_rl/environments/interfaces.py",
    "nemo_rl/environments/math_environment.py",
    "nemo_rl/environments/metrics.py",
