
Commit c32778d

feat: Support for nano-v2 (#1514)
Signed-off-by: Yi-Fu Wu <[email protected]>
1 parent 775fc34 commit c32778d

File tree: 9 files changed (+191, -5 lines)

Submodule Megatron-Bridge updated 48 files
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
defaults: ../../grpo_math_1B.yaml
grpo:
  max_num_steps: 30
checkpointing:
  checkpoint_dir: results/grpo-nano-v2-12b-1n8g-megatron
policy:
  model_name: nvidia/NVIDIA-Nemotron-Nano-12B-v2
  tokenizer:
    name: nvidia/NVIDIA-Nemotron-Nano-12B-v2
  optimizer: null
  megatron_cfg:
    enabled: true
    bias_activation_fusion: false
    tensor_model_parallel_size: 8
  dtensor_cfg:
    enabled: false
  make_sequence_length_divisible_by: 1
  generation:
    max_new_tokens: 512
    vllm_cfg:
      max_model_len: 512
  sequence_packing:
    enabled: false
data:
  max_input_seq_length: 512
logger:
  log_dir: logs/grpo-nano-v2-12b-1n8g-megatron
  wandb_enabled: true
  tensorboard_enabled: true
  wandb:
    project: nemo-rl
    name: grpo-nano-v2-12b-1n8g-megatron
cluster:
  gpus_per_node: 8
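
The recipe above only lists overrides; every other setting is inherited from the referenced ../../grpo_math_1B.yaml base config. A minimal sketch of that layering, assuming an OmegaConf-style deep merge (load_recipe and the filename below are illustrative, not the repo's actual loader):

# Sketch only: resolve a recipe that declares `defaults: <base>.yaml` by
# deep-merging its overrides on top of the base config. The real NeMo RL
# loader may differ; `load_recipe` and the example path are hypothetical.
from pathlib import Path

from omegaconf import OmegaConf


def load_recipe(recipe_path: str):
    recipe = OmegaConf.load(recipe_path)
    base_rel = recipe.pop("defaults", None)
    if base_rel is None:
        return recipe
    base = OmegaConf.load(Path(recipe_path).parent / base_rel)
    # Later arguments win, so recipe keys override the base defaults.
    return OmegaConf.merge(base, recipe)


cfg = load_recipe("grpo-nano-v2-12b-1n8g-megatron.yaml")  # hypothetical path
print(cfg.policy.megatron_cfg.tensor_model_parallel_size)  # -> 8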
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
defaults: ../../grpo_math_1B.yaml
grpo:
  max_num_steps: 30
checkpointing:
  checkpoint_dir: results/grpo-nano-v2-12b-2n8g-fsdp2tp1
policy:
  model_name: nvidia/NVIDIA-Nemotron-Nano-12B-v2
  tokenizer:
    name: nvidia/NVIDIA-Nemotron-Nano-12B-v2
  dtensor_cfg:
    cpu_offload: true
    activation_checkpointing: true
  dynamic_batching:
    enabled: true
  sequence_packing:
    enabled: false
  make_sequence_length_divisible_by: 1
  generation:
    max_new_tokens: 512
    vllm_cfg:
      max_model_len: 512
  scheduler:
    - name: "torch.optim.lr_scheduler.LinearLR"
      kwargs:
        start_factor: 0.1
        end_factor: 1.0
        total_iters: 13
    - name: "torch.optim.lr_scheduler.ConstantLR"
      kwargs:
        factor: 1.0
        total_iters: 10000000000
    - milestones: [13]
data:
  max_input_seq_length: 512
logger:
  log_dir: logs/grpo-nano-v2-12b-2n8g-fsdp2tp1
  wandb_enabled: true
  tensorboard_enabled: true
  wandb:
    project: nemo-rl
    name: grpo-nano-v2-12b-2n8g-fsdp2tp1
cluster:
  gpus_per_node: 8
  num_nodes: 2

nemo_rl/models/megatron/common.py

Lines changed: 6 additions & 1 deletion
@@ -348,12 +348,17 @@ def forward_step_arbitrary_loss(
    if len(multimodal_data) > 0:
        position_ids = None

+    additional_kwargs = {}
+    # Mamba models currently do not support packed_seq_params
+    if packed_seq_params is not None:
+        additional_kwargs["packed_seq_params"] = packed_seq_params
+
    with straggler_timer:
        output_tensor = model(
            input_ids=input_ids_cp_sharded,
            position_ids=position_ids,
            attention_mask=attention_mask,
-            packed_seq_params=packed_seq_params,
+            **additional_kwargs,
            **multimodal_data,
        )
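
The change above only forwards packed_seq_params when it is set, because Mamba-based models such as Nemotron Nano v2 do not accept that keyword. A self-contained sketch of the same pattern with toy stand-ins (fake_transformer_forward and fake_mamba_forward are not the repo's API):

# Only pass packed_seq_params when it is actually set, so a forward() that does
# not declare the parameter (e.g. a Mamba-based stack) never sees the keyword.
# Both forward functions below are toy stand-ins for illustration.
def fake_transformer_forward(input_ids, position_ids=None, attention_mask=None,
                             packed_seq_params=None):
    return {"saw_packed_seq_params": packed_seq_params is not None}


def fake_mamba_forward(input_ids, position_ids=None, attention_mask=None):
    # No packed_seq_params parameter at all.
    return {"saw_packed_seq_params": False}


def call_model(forward_fn, input_ids, packed_seq_params=None):
    additional_kwargs = {}
    if packed_seq_params is not None:
        additional_kwargs["packed_seq_params"] = packed_seq_params
    return forward_fn(input_ids=input_ids, **additional_kwargs)


print(call_model(fake_transformer_forward, [1, 2, 3], packed_seq_params=object()))
print(call_model(fake_mamba_forward, [1, 2, 3]))  # works: the kwarg is simply omitted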

nemo_rl/models/megatron/community_import.py

Lines changed: 8 additions & 0 deletions
@@ -42,6 +42,7 @@ def import_model_from_hf_name(
    # Keep track of defaults so can restore them to the config after loading the model
    orig_tensor_model_parallel_size = model_provider.tensor_model_parallel_size
    orig_pipeline_model_parallel_size = model_provider.pipeline_model_parallel_size
+    orig_context_parallel_size = model_provider.context_parallel_size
    orig_expert_model_parallel_size = model_provider.expert_model_parallel_size
    orig_expert_tensor_parallel_size = model_provider.expert_tensor_parallel_size
    orig_num_layers_in_first_pipeline_stage = (
@@ -59,6 +60,7 @@ def import_model_from_hf_name(
    model_provider.pipeline_model_parallel_size = megatron_config[
        "pipeline_model_parallel_size"
    ]
+    model_provider.context_parallel_size = megatron_config["context_parallel_size"]
    model_provider.expert_model_parallel_size = megatron_config[
        "expert_model_parallel_size"
    ]
@@ -83,6 +85,7 @@ def import_model_from_hf_name(
    config = megatron_model[0].config
    config.tensor_model_parallel_size = orig_tensor_model_parallel_size
    config.pipeline_model_parallel_size = orig_pipeline_model_parallel_size
+    config.context_parallel_size = orig_context_parallel_size
    config.expert_model_parallel_size = orig_expert_model_parallel_size
    config.expert_tensor_parallel_size = orig_expert_tensor_parallel_size
    config.num_layers_in_first_pipeline_stage = orig_num_layers_in_first_pipeline_stage
@@ -123,6 +126,11 @@ def export_model_from_megatron(

    # Export performs on CPU with proper distributed context
    with temporary_distributed_context(backend="gloo"):
+        # Need to set model parallel cuda manual seed for mamba mixer
+        from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed
+
+        model_parallel_cuda_manual_seed(0)
+
        # Load the Megatron model
        megatron_model = bridge.load_megatron_model(
            input_path, skip_temp_dist_context=True
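
context_parallel_size is now saved and restored alongside the other parallel sizes so import-time overrides do not leak into the exported config, and the manual seed is required before building the Mamba mixer. A stripped-down sketch of the save/override/restore pattern (Provider is a toy dataclass, not Megatron-Bridge's model provider):

# Toy illustration of the pattern in import_model_from_hf_name: remember the
# provider's parallelism defaults, override them for the import, then write the
# originals back onto the resulting config. `Provider` is a stand-in dataclass.
from dataclasses import dataclass


@dataclass
class Provider:
    tensor_model_parallel_size: int = 1
    pipeline_model_parallel_size: int = 1
    context_parallel_size: int = 1


def import_with_overrides(provider: Provider, megatron_config: dict) -> Provider:
    originals = {
        "tensor_model_parallel_size": provider.tensor_model_parallel_size,
        "pipeline_model_parallel_size": provider.pipeline_model_parallel_size,
        "context_parallel_size": provider.context_parallel_size,
    }
    for key, value in megatron_config.items():
        setattr(provider, key, value)

    config = provider  # in the real code this is the loaded model's config

    # Restore the defaults so import-time parallelism does not stick to the config.
    for key, value in originals.items():
        setattr(config, key, value)
    return config


cfg = import_with_overrides(Provider(), {"context_parallel_size": 2})
print(cfg.context_parallel_size)  # -> 1: the original default is restored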

nemo_rl/models/policy/megatron_policy_worker.py

Lines changed: 12 additions & 3 deletions
@@ -269,7 +269,7 @@ def freeze_moe_router(megatron_model):
        if hasattr(model_module, "language_model"):
            model_module = model_module.language_model
        for layer in model_module.decoder.layers:
-            if hasattr(layer.mlp, "router"):
+            if hasattr(layer, "mlp") and hasattr(layer.mlp, "router"):
                layer.mlp.router.weight.requires_grad = False

    mixed_precision_wrapper = CustomFloat16Module
@@ -1271,12 +1271,17 @@ def forward_step_fn(
            if len(multimodal_data) > 0:
                position_ids = None

+            additional_kwargs = {}
+            # Mamba models currently do not support packed_seq_params
+            if packed_seq_params is not None:
+                additional_kwargs["packed_seq_params"] = packed_seq_params
+
            output_tensor = model(
                input_ids=input_ids_cp_sharded,
                position_ids=position_ids,
                attention_mask=attention_mask,
-                packed_seq_params=packed_seq_params,
                **multimodal_data,
+                **additional_kwargs,
            )

            # Apply temperature scaling to logits for training
@@ -1550,11 +1555,15 @@ def forward_step_fn(
            if len(multimodal_data) > 0:
                position_ids = None

+            additional_kwargs = {}
+            if packed_seq_params is not None:
+                additional_kwargs["packed_seq_params"] = packed_seq_params
+
            output_tensor = model(
                input_ids=input_ids_cp_sharded,
                position_ids=position_ids,
                attention_mask=attention_mask,
-                packed_seq_params=packed_seq_params,
+                **additional_kwargs,
                **multimodal_data,
            )
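
freeze_moe_router previously assumed every decoder layer has an mlp attribute, which is not true for the Mamba layers in a hybrid stack like Nemotron Nano v2. A toy illustration of why the extra hasattr(layer, "mlp") guard is needed (the classes below are stand-ins, not Megatron modules):

# Without the hasattr(layer, "mlp") guard, evaluating layer.mlp on a Mamba layer
# raises AttributeError before hasattr(layer.mlp, "router") can return False.
class Router:
    pass


class MoEMLP:
    router = Router()


class TransformerLayer:
    mlp = MoEMLP()


class MambaLayer:
    pass  # no mlp attribute


for layer in [TransformerLayer(), MambaLayer()]:
    # The guarded check simply skips layers without an mlp instead of crashing.
    if hasattr(layer, "mlp") and hasattr(layer.mlp, "router"):
        print("would freeze router on", type(layer).__name__)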

tests/test_suites/llm/grpo-nano-v2-12b-1n8g-megatron.sh

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
#!/bin/bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
source $SCRIPT_DIR/common.env

# ===== BEGIN CONFIG =====
NUM_NODES=1
STEPS_PER_RUN=30
MAX_STEPS=30
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # Round up
NUM_MINUTES=60
# ===== END CONFIG =====

exit_if_max_steps_reached

# Run the experiment
cd $PROJECT_ROOT
uv run examples/run_grpo_math.py \
    --config $CONFIG_PATH \
    grpo.max_num_steps=$MAX_STEPS \
    logger.log_dir=$LOG_DIR \
    logger.wandb_enabled=True \
    logger.wandb.project=nemo-rl \
    logger.wandb.name=$EXP_NAME \
    logger.monitor_gpus=True \
    logger.tensorboard_enabled=True \
    checkpointing.enabled=True \
    checkpointing.checkpoint_dir=$CKPT_DIR \
    $@ \
    2>&1 | tee $RUN_LOG

# Convert tensorboard logs to json
uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

# Only run metrics if the target step is reached
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
    uv run tests/check_metrics.py $JSON_METRICS \
        'mean(data["train/token_mult_prob_error"]) < 1.05' \
        'data["train/token_mult_prob_error"]["30"] < 1.05' \
        'data["train/reward"]["30"] > 0.4' \
        'mean(data["timing/train/total_step_time"], -6, -1) < 80'
fi
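
The jq gate above finds the highest step recorded under train/loss in the dumped metrics and only runs the check_metrics assertions once that step reaches MAX_STEPS. The same check written out in Python (the metrics layout is inferred from the jq query; the path is hypothetical):

# Python equivalent of the jq gate: find the highest step logged under
# "train/loss" and only proceed if it reached MAX_STEPS. The metrics layout
# {metric_name: {step_str: value}} is inferred from the jq expression.
import json

MAX_STEPS = 30


def reached_target_step(json_metrics_path: str, max_steps: int = MAX_STEPS) -> bool:
    with open(json_metrics_path) as f:
        data = json.load(f)
    steps = [int(step) for step in data.get("train/loss", {})]
    return bool(steps) and max(steps) >= max_steps


# if reached_target_step("metrics.json"):  # hypothetical path
#     ... run the check_metrics assertions ...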
tests/test_suites/llm/grpo-nano-v2-12b-2n8g-fsdp2tp1.sh

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
#!/bin/bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
source $SCRIPT_DIR/common.env

# ===== BEGIN CONFIG =====
NUM_NODES=2
STEPS_PER_RUN=30
MAX_STEPS=30
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # Round up
NUM_MINUTES=60
# ===== END CONFIG =====

exit_if_max_steps_reached

# Run the experiment
cd $PROJECT_ROOT
uv run examples/run_grpo_math.py \
    --config $CONFIG_PATH \
    grpo.max_num_steps=$MAX_STEPS \
    logger.log_dir=$LOG_DIR \
    logger.wandb_enabled=True \
    logger.wandb.project=nemo-rl \
    logger.wandb.name=$EXP_NAME \
    logger.monitor_gpus=True \
    logger.tensorboard_enabled=True \
    checkpointing.enabled=True \
    checkpointing.checkpoint_dir=$CKPT_DIR \
    $@ \
    2>&1 | tee $RUN_LOG

# Convert tensorboard logs to json
uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

# Only run metrics if the target step is reached
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
    uv run tests/check_metrics.py $JSON_METRICS \
        'mean(data["train/token_mult_prob_error"]) < 1.05' \
        'data["train/token_mult_prob_error"]["30"] < 1.05' \
        'data["train/reward"]["30"] > 0.4' \
        'mean(data["timing/train/total_step_time"], -6, -1) < 60'
fi

tests/test_suites/nightly.txt

Lines changed: 4 additions & 0 deletions
@@ -48,6 +48,10 @@ tests/test_suites/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.sh
#https://github.com/NVIDIA-NeMo/RL/issues/1374
#tests/test_suites/llm/grpo-math-llama-nemotron-super-49b-v.5-4n8g-fsdp2tp8.sh

+# Nano-v2
+tests/test_suites/llm/grpo-nano-v2-12b-1n8g-megatron.sh
+tests/test_suites/llm/grpo-nano-v2-12b-2n8g-fsdp2tp1.sh
+
#######
# SFT #
#######