From ed540c2e5c829175df08e5f25e93ec4422cc5a54 Mon Sep 17 00:00:00 2001 From: Manasa Manohara Date: Mon, 3 Nov 2025 13:40:51 -0800 Subject: [PATCH 1/4] MergeGRPO to main --- nemo_rl/utils/logger.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo_rl/utils/logger.py b/nemo_rl/utils/logger.py index fa76f1295d..97457ea54d 100644 --- a/nemo_rl/utils/logger.py +++ b/nemo_rl/utils/logger.py @@ -815,7 +815,7 @@ def __init__(self, cfg: LoggerConfig): self.wandb_logger = WandbLogger(cfg["wandb"], log_dir=wandb_log_dir) self.loggers.append(self.wandb_logger) - if cfg["swanlab_enabled"]: + if cfg.get("swanlab_enabled", False): swanlab_log_dir = os.path.join(self.base_log_dir, "swanlab") os.makedirs(swanlab_log_dir, exist_ok=True) self.swanlab_logger = SwanlabLogger(cfg["swanlab"], log_dir=swanlab_log_dir) @@ -845,7 +845,7 @@ def __init__(self, cfg: LoggerConfig): f"{metric_prefix}/*", step_metric=step_metric ) - if cfg["swanlab_enabled"] and self.swanlab_logger: + if cfg.get("swanlab_enabled", False) and self.swanlab_logger: self.swanlab_logger.define_metric( f"{metric_prefix}/*", step_metric=step_metric ) From 8f3b99528e5078c894de0e421346a7d6f065bf4c Mon Sep 17 00:00:00 2001 From: Manasa Manohara Date: Mon, 3 Nov 2025 13:43:38 -0800 Subject: [PATCH 2/4] MergeGRPO to main --- ...helpsteer3-llama-3.2-1b-1n8g-fsdp2tp1.yaml | 186 +++++++++++ ...nemotron-super-49b-v1.5-4n8g-fsdp2tp8.yaml | 191 +++++++++++ .../llama_nemotron_super_49b_custom_plan.py | 91 ++++-- examples/configs/sft_nemotron_super_49b.yaml | 134 ++++++++ .../sft_nemotron_super_49b_tulu_v3.yaml | 115 +++++++ examples/run_grpo_helpsteer3.py | 305 ++++++++++++++++++ .../preference_datasets/helpsteer3.py | 1 + .../datasets/response_datasets/__init__.py | 10 + .../data/datasets/response_datasets/tulu3.py | 148 +++++++++ .../ray_actor_environment_registry.py | 1 + .../environments/helpsteer3_environment.py | 259 +++++++++++++++ 11 files changed, 1413 insertions(+), 28 deletions(-) create mode 100644 examples/configs/recipes/llm/grpo-helpsteer3-llama-3.2-1b-1n8g-fsdp2tp1.yaml create mode 100644 examples/configs/recipes/llm/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-4n8g-fsdp2tp8.yaml create mode 100644 examples/configs/sft_nemotron_super_49b.yaml create mode 100644 examples/configs/sft_nemotron_super_49b_tulu_v3.yaml create mode 100644 examples/run_grpo_helpsteer3.py create mode 100644 nemo_rl/data/datasets/response_datasets/tulu3.py create mode 100644 nemo_rl/environments/helpsteer3_environment.py diff --git a/examples/configs/recipes/llm/grpo-helpsteer3-llama-3.2-1b-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/grpo-helpsteer3-llama-3.2-1b-1n8g-fsdp2tp1.yaml new file mode 100644 index 0000000000..1b7c7334c8 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-helpsteer3-llama-3.2-1b-1n8g-fsdp2tp1.yaml @@ -0,0 +1,186 @@ +# GRPO Algorithm Configuration for Llama-3.2-1B with HelpSteer3 +grpo: + num_prompts_per_step: 32 + num_generations_per_prompt: 16 + max_rollout_turns: 1 # for multi-turn rollouts. 
HelpSteer3 conversations can have multiple turns + max_num_epochs: 3 + max_num_steps: 500 + normalize_rewards: true + use_leave_one_out_baseline: true + val_period: 10 + val_at_start: false + overlong_filtering: false + max_val_samples: 256 + val_batch_size: 256 + seed: 42 + use_dynamic_sampling: false + batch_multiplier: 1 + dynamic_sampling_max_gen_batches: 10 + reward_shaping: + enabled: false + overlong_buffer_length: 128 + overlong_buffer_penalty: 1 + max_response_length: ${policy.max_total_sequence_length} + reward_scaling: + enabled: false + source_min: 0.0 + source_max: 1.0 + target_min: 0.0 + target_max: 1.0 + + async_grpo: + enabled: false # Set to true to enable async training mode + # Max age (in training steps) for trajectories used in training + max_trajectory_age_steps: 1 + +loss_fn: + reference_policy_kl_penalty: 0.01 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 + ratio_clip_c: null + # (default off) loss formulation improvements (docs/guides/grpo.md#loss) + use_on_policy_kl_approximation: false + use_importance_sampling_correction: false + sequence_level_importance_ratios: false + token_level_loss: true + truncated_importance_sampling_ratio: null + +checkpointing: + enabled: true + checkpoint_dir: "results/grpo-helpsteer3-llama-3.2-1b-5" + metric_name: "val_reward" + higher_is_better: true + keep_top_k: 3 + save_period: 10 + checkpoint_must_save_by: null + model_save_format: "safetensors" + save_consolidated: false + +policy: + model_name: "meta-llama/Llama-3.2-1B-Instruct" + tokenizer: + name: "meta-llama/Llama-3.2-1B-Instruct" + max_total_sequence_length: 2048 + precision: "bfloat16" + train_global_batch_size: 512 + train_micro_batch_size: 4 + logprob_batch_size: 4 + logprob_chunk_size: null + + dtensor_cfg: + _v2: true + enabled: true + cpu_offload: false + sequence_parallel: false + activation_checkpointing: false + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + + megatron_cfg: + enabled: false + + # See docs/design-docs/sequence-packing-and-dynamic-batching.md + # for more details on dynamic batching and sequence packing. 
+  dynamic_batching:
+    enabled: True
+    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
+    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
+    sequence_length_round: 64
+
+  sequence_packing:
+    enabled: False
+    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
+    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
+    algorithm: "modified_first_fit_decreasing"
+    sequence_length_round: 64
+
+  make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
+  max_grad_norm: 1.0
+
+  optimizer:
+    name: "torch.optim.AdamW"
+    kwargs:
+      lr: 5.0e-6
+      weight_decay: 0.01
+      betas: [0.9, 0.999]
+      eps: 1e-8
+
+  scheduler:
+    - name: "torch.optim.lr_scheduler.LinearLR"
+      kwargs:
+        start_factor: 0.1
+        end_factor: 1.0
+        # The scheduler iteration is per GRPO step and is decoupled from the optimizer step (there may be >=1 optimizer steps per GRPO step)
+        total_iters: 50
+    - name: "torch.optim.lr_scheduler.ConstantLR"
+      kwargs:
+        factor: 1.0
+        total_iters: 10000000000
+    - milestones: [50]
+
+  generation:
+    backend: "vllm"
+    max_new_tokens: ${policy.max_total_sequence_length}
+    temperature: 1.0
+    top_p: 1.0
+    top_k: null
+    stop_token_ids:
+      - 128009 # <|eot_id|> for Llama-3.2
+    stop_strings: null
+    vllm_cfg:
+      async_engine: false
+      precision: ${policy.precision}
+      tensor_parallel_size: 1
+      pipeline_parallel_size: 1
+      expert_parallel_size: 1
+      gpu_memory_utilization: 0.6
+      max_model_len: ${policy.max_total_sequence_length}
+      enforce_eager: False
+      vllm_kwargs: {}
+    colocated:
+      # true: generation shares training GPUs
+      # false: uses dedicated generation resources
+      enabled: true
+      # only relevant when enabled is false
+      resources:
+        gpus_per_node: null # Decides the number of GPUs dedicated to generation when there is one node in the cluster, i.e. cluster.num_nodes == 1
+        num_nodes: null # Decides the number of nodes dedicated to generation
+
+data:
+  max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len
+  prompt_file: null # HelpSteer3 contains its own prompts
+  system_prompt_file: null
+  shuffle: true
+  num_workers: 1
+  dataset_name: "HelpSteer3"
+  # HelpSteer3 preference dataset will be converted to response format for GRPO
+  # The preferred responses will be used as target responses for the environment
+
+env:
+  helpsteer3:
+    num_workers: 8
+    # Environment configuration for HelpSteer3 preference-based rewards
+    reward_model: "preference_based" # Use preference scores as rewards
+
+logger:
+  log_dir: "logs" # Base directory for all logs
+  wandb_enabled: true # Make sure you do a ``wandb login [Your API key]'' before running
+  tensorboard_enabled: true
+  mlflow_enabled: false
+  monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard
+  wandb:
+    project: "grpo-helpsteer3-llama-3.2-1b"
+    name: "grpo-helpsteer3-llama-3.2-1b-tp${policy.dtensor_cfg.tensor_parallel_size}"
+  tensorboard:
+    log_dir: "tb_logs-grpo-helpsteer3-llama-3.2-1b"
+  mlflow:
+    experiment_name: "grpo-helpsteer3"
+    run_name: "grpo-helpsteer3-llama-3.2-1b"
+  gpu_monitoring:
+    collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
+    flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)
+
+cluster:
+  gpus_per_node: 8
+  num_nodes: 1
diff --git a/examples/configs/recipes/llm/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-4n8g-fsdp2tp8.yaml
b/examples/configs/recipes/llm/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-4n8g-fsdp2tp8.yaml new file mode 100644 index 0000000000..ea9e471ac7 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-4n8g-fsdp2tp8.yaml @@ -0,0 +1,191 @@ +# GRPO Algorithm Configuration for Llama-3.3-Nemotron-Super-49B-v1.5 with HelpSteer3 +grpo: + num_prompts_per_step: 64 + num_generations_per_prompt: 16 + max_rollout_turns: 1 # for multi-turn rollouts. HelpSteer3 conversations can have multiple turns + max_num_epochs: 1 + max_num_steps: 10 + normalize_rewards: true + use_leave_one_out_baseline: true + val_period: 10 + val_at_start: false + overlong_filtering: false + max_val_samples: 256 + val_batch_size: 256 + seed: 42 + use_dynamic_sampling: false + batch_multiplier: 1 + dynamic_sampling_max_gen_batches: 10 + reward_shaping: + enabled: false + overlong_buffer_length: 128 + overlong_buffer_penalty: 1 + max_response_length: ${policy.max_total_sequence_length} + reward_scaling: + enabled: false + source_min: 0.0 + source_max: 1.0 + target_min: 0.0 + target_max: 1.0 + + async_grpo: + enabled: false # Set to true to enable async training mode + # Max age (in training steps) for trajectories used in training + max_trajectory_age_steps: 1 + +loss_fn: + reference_policy_kl_penalty: 0.01 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 + ratio_clip_c: null + # (default off) loss formulation improvements (docs/guides/grpo.md#loss) + use_on_policy_kl_approximation: false + use_importance_sampling_correction: false + sequence_level_importance_ratios: false + truncated_importance_sampling_ratio: null + token_level_loss: true + +checkpointing: + enabled: true + checkpoint_dir: "results/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-3" + metric_name: "val_reward" + higher_is_better: true + keep_top_k: 3 + save_period: 10 + checkpoint_must_save_by: null + model_save_format: "safetensors" + save_consolidated: false + +policy: + model_name: /lustre/fsw/portfolios/coreai/users/joyang/models/llama-3_3-nemotron-49b-instruct-128k-v1_2-hf + tokenizer: + name: ${policy.model_name} + max_total_sequence_length: 32768 + precision: "bfloat16" + train_global_batch_size: 64 + train_micro_batch_size: 1 + logprob_batch_size: 1 + logprob_chunk_size: null + + dtensor_cfg: + _v2: true + activation_checkpointing: true + context_parallel_size: 4 + cpu_offload: true + enabled: true + sequence_parallel: false + tensor_parallel_size: 8 + custom_parallel_plan: examples.configs.recipes.llm.llama_nemotron_super_49b_custom_plan.custom_parallel_plan + + megatron_cfg: + enabled: false + + # See docs/design-docs/sequence-packing-and-dynamic-batching.md + # for more details on dynamic batching and sequence packing. 
+  dynamic_batching:
+    enabled: True
+    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
+    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
+    sequence_length_round: 64
+
+  sequence_packing:
+    enabled: False
+    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
+    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
+    algorithm: "modified_first_fit_decreasing"
+    sequence_length_round: 64
+
+  make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
+  max_grad_norm: 1.0
+
+  optimizer:
+    name: "torch.optim.AdamW"
+    kwargs:
+      lr: 3.0e-7
+      weight_decay: 0.01
+      betas: [0.9, 0.999]
+      eps: 1e-8
+
+  scheduler:
+    - name: "torch.optim.lr_scheduler.LinearLR"
+      kwargs:
+        start_factor: 0.1
+        end_factor: 1.0
+        # The scheduler iteration is per GRPO step and is decoupled from the optimizer step (there may be >=1 optimizer steps per GRPO step)
+        total_iters: 13
+    - name: "torch.optim.lr_scheduler.ConstantLR"
+      kwargs:
+        factor: 1.0
+        total_iters: 10000000000
+    - milestones: [13]
+
+  generation:
+    backend: "vllm"
+    max_new_tokens: ${policy.max_total_sequence_length}
+    temperature: 1.0
+    top_p: 1.0
+    top_k: null
+    stop_token_ids: null
+    stop_strings: null
+    vllm_cfg:
+      async_engine: false
+      precision: ${policy.precision}
+      tensor_parallel_size: 4
+      pipeline_parallel_size: 1
+      expert_parallel_size: 1 # When EP > 1, EP must be a multiple of TP since vLLM's EP = DP * TP
+      gpu_memory_utilization: 0.6
+      max_model_len: ${policy.max_total_sequence_length}
+      # When enforce_eager is False, it is optional to set ++policy.generation.vllm_kwargs.compilation_config.use_inductor=False for better accuracy;
+      # with that flag, vLLM will use the custom CUDA kernels instead of the Triton kernels generated by torch.compile.
+      # For more details, see convergence issue https://github.com/NVIDIA-NeMo/RL/issues/998
+      enforce_eager: False
+      use_deep_gemm: False
+      num_last_layers_in_bf16: 0
+      num_first_layers_in_bf16: 0
+      vllm_kwargs: {}
+    colocated:
+      # true: generation shares training GPUs
+      # false: uses dedicated generation resources
+      enabled: true
+      # only relevant when enabled is false
+      resources:
+        gpus_per_node: null # Decides the number of GPUs dedicated to generation when there is one node in the cluster, i.e. cluster.num_nodes == 1
+        num_nodes: null # Decides the number of nodes dedicated to generation
+
+data:
+  max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len
+  prompt_file: null # HelpSteer3 contains its own prompts
+  system_prompt_file: null
+  shuffle: true
+  num_workers: 1
+  dataset_name: "HelpSteer3"
+  # HelpSteer3 preference dataset will be converted to response format for GRPO
+  # The preferred responses will be used as target responses for the environment
+
+env:
+  helpsteer3:
+    num_workers: 8
+    # Environment configuration for HelpSteer3 preference-based rewards
+    reward_model: "preference_based" # Use preference scores as rewards
+
+logger:
+  log_dir: "logs" # Base directory for all logs
+  wandb_enabled: true # Make sure you do a ``wandb login [Your API key]'' before running
+  tensorboard_enabled: false
+  mlflow_enabled: false
+  monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard
+  wandb:
+    project: "grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5"
+    name:
"grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-tp${policy.dtensor_cfg.tensor_parallel_size}" + tensorboard: + log_dir: "tb_logs-grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5" + mlflow: + experiment_name: "grpo-helpsteer3" + run_name: "grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5" + gpu_monitoring: + collection_interval: 10 # How often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) + +cluster: + gpus_per_node: 8 + num_nodes: 16 diff --git a/examples/configs/recipes/llm/llama_nemotron_super_49b_custom_plan.py b/examples/configs/recipes/llm/llama_nemotron_super_49b_custom_plan.py index a0381adf9c..2922c69f9e 100644 --- a/examples/configs/recipes/llm/llama_nemotron_super_49b_custom_plan.py +++ b/examples/configs/recipes/llm/llama_nemotron_super_49b_custom_plan.py @@ -12,38 +12,73 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import cast + from torch.distributed.tensor.parallel import ( ColwiseParallel, ParallelStyle, - PrepareModuleInput, - PrepareModuleOutput, RowwiseParallel, + SequenceParallel, ) from torch.distributed.tensor.placement_types import Replicate, Shard -custom_parallel_plan: dict[str, ParallelStyle] = { - "model.layers.*.self_attn": PrepareModuleInput( - input_kwarg_layouts={"attention_mask": Replicate()}, - desired_input_kwarg_layouts={"attention_mask": Replicate()}, - ), - "model.embed_tokens": RowwiseParallel( - input_layouts=Replicate(), output_layouts=Replicate(), use_local_output=True - ), - "model.layers.*.self_attn.q_proj": ColwiseParallel(use_local_output=False), - "model.layers.*.self_attn.k_proj": ColwiseParallel(use_local_output=False), - "model.layers.*.self_attn.v_proj": ColwiseParallel(use_local_output=False), - "model.layers.*.self_attn.o_proj": RowwiseParallel( - output_layouts=Replicate(), use_local_output=True - ), - "model.layers.*.self_attn.rotary_emb": PrepareModuleOutput( - output_layouts=(Replicate(), Replicate()), - desired_output_layouts=(Replicate(), Replicate()), - use_local_output=False, - ), - "model.layers.*.mlp.up_proj": ColwiseParallel(), - "model.layers.*.mlp.gate_proj": ColwiseParallel(), - "model.layers.*.mlp.down_proj": RowwiseParallel( - output_layouts=Replicate(), use_local_output=True - ), - "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False), -} + +def get_custom_parallel_plan(): + # Reuse llama default parallel plan + base_model_tp_plan: dict[str, ParallelStyle] = { + "model.embed_tokens": RowwiseParallel(input_layouts=Replicate()), + "model.layers.*.self_attn.q_proj": ColwiseParallel(), + "model.layers.*.self_attn.k_proj": ColwiseParallel(), + "model.layers.*.self_attn.v_proj": ColwiseParallel(), + "model.layers.*.self_attn.o_proj": RowwiseParallel(), + "model.layers.*.mlp.up_proj": ColwiseParallel(), + "model.layers.*.mlp.gate_proj": ColwiseParallel(), + "model.layers.*.mlp.down_proj": RowwiseParallel(), + "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False), + } + + base_model_sp_plan = { + "model.embed_tokens": RowwiseParallel( + input_layouts=Replicate(), output_layouts=Shard(1) + ), + "model.norm": SequenceParallel(), + "model.layers.*.input_layernorm": SequenceParallel(), + "model.layers.*.self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)), + "model.layers.*.post_attention_layernorm": SequenceParallel(), + "model.layers.*.mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)), + "lm_head": 
ColwiseParallel( + input_layouts=Shard(1), output_layouts=Shard(-1), use_local_output=False + ), + } + + if False: + # Enable sequence parallelism only if TP size > 1 + base_model_tp_plan.update(cast(dict[str, ParallelStyle], base_model_sp_plan)) + + return base_model_tp_plan + + +custom_parallel_plan: dict[str, ParallelStyle] = get_custom_parallel_plan() +# { + +# "model.embed_tokens": RowwiseParallel( +# input_layouts=Replicate(), output_layouts=Replicate(), use_local_output=True +# ), +# "model.layers.*.self_attn.q_proj": ColwiseParallel(use_local_output=False), +# "model.layers.*.self_attn.k_proj": ColwiseParallel(use_local_output=False), +# "model.layers.*.self_attn.v_proj": ColwiseParallel(use_local_output=False), +# "model.layers.*.self_attn.o_proj": RowwiseParallel( +# output_layouts=Replicate(), use_local_output=True +# ), +# "model.layers.*.self_attn.rotary_emb": PrepareModuleOutput( +# output_layouts=(Replicate(), Replicate()), +# desired_output_layouts=(Replicate(), Replicate()), +# use_local_output=False, +# ), +# "model.layers.*.mlp.up_proj": ColwiseParallel(), +# "model.layers.*.mlp.gate_proj": ColwiseParallel(), +# "model.layers.*.mlp.down_proj": RowwiseParallel( +# output_layouts=Replicate(), use_local_output=True +# ), +# "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False), +# } diff --git a/examples/configs/sft_nemotron_super_49b.yaml b/examples/configs/sft_nemotron_super_49b.yaml new file mode 100644 index 0000000000..d79837dbb8 --- /dev/null +++ b/examples/configs/sft_nemotron_super_49b.yaml @@ -0,0 +1,134 @@ +# SFT Algorithm Configuration +sft: + max_num_epochs: 3 + max_num_steps: 100 + val_period: 10 + val_batches: 8 + val_global_batch_size: 128 + val_micro_batch_size: 1 + val_at_start: true + seed: 42 + +checkpointing: + enabled: true + checkpoint_dir: "results/sft_nemotron_super_49b" + metric_name: "val_loss" + higher_is_better: false + keep_top_k: 100 + save_period: 500 + checkpoint_must_save_by: null + +policy: + # model_name: Qwen/Qwen2.5-7B-Instruct + # tokenizer: + # name: Qwen/Qwen2.5-7B-Instruct + model_name: "/lustre/fsw/portfolios/coreai/users/joyang/models/llama-3_3-nemotron-49b-instruct-128k-v1_2-hf" + tokenizer: + name: ${policy.model_name} + max_total_sequence_length: 4096 + precision: "bfloat16" + train_global_batch_size: 128 + train_micro_batch_size: 8 + + dtensor_cfg: + _v2: true + activation_checkpointing: true + context_parallel_size: 2 + cpu_offload: false + enabled: true + sequence_parallel: false + tensor_parallel_size: 4 + custom_parallel_plan: examples.configs.recipes.llm.llama_nemotron_super_49b_custom_plan.custom_parallel_plan + + megatron_cfg: + enabled: false + + dynamic_batching: + enabled: false + train_mb_tokens: 4096 + logprob_mb_tokens: 8192 + sequence_length_round: 64 + + sequence_packing: + enabled: false + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + + + # makes the training sequence length divisible by the tensor parallel size + # this is useful for sequence parallel training + make_sequence_length_divisible_by: ${max:${mul:${policy.dtensor_cfg.context_parallel_size}, 2}, ${policy.max_total_sequence_length}} + max_grad_norm: null + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 2e-5 + weight_decay: 0.01 + betas: [0.9, 0.98] + eps: 1e-8 + # when using Dtensor, we need to set foreach + # and fused to False + foreach: False + fused: False + +# data: +# add_bos: true +# 
add_eos: true +# add_generation_prompt: false +# dataset_name: "tulu3_sft_mixture" +# cache_dir: "/lustre/fsw/portfolios/coreai/users/gvenkatakris/data-cache" +# max_input_seq_length: 1024 +# max_samples: 10000 +# shuffle: true +# test_size: 0.05 + +data: + max_input_seq_length: ${policy.max_total_sequence_length} + add_bos: true + add_eos: true + add_generation_prompt: false + shuffle: true + num_workers: 20 + + dataset_name: "squad" + # You can use custom response datasets for training and validation. For example: + # data: + # dataset_name: ResponseDataset + # train_data_path: # e.g., /path/to/local/dataset.jsonl or hf_org/hf_dataset_name (HuggingFace) + # val_data_path: + # input_key: , default is "input" + # output_key: , default is "output" + # train_split: , default is None # used for HuggingFace datasets + # val_split: , default is None # used for HuggingFace datasets + # See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets for more details. + + ## unused with squad dataset + prompt_file: null + split: null + output_key: null + seed: null + +logger: + log_dir: "logs" # Base directory for all logs + wandb_enabled: true # Make sure you do a ``wandb login [Your API key]'' before running + tensorboard_enabled: false + mlflow_enabled: false + monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard + swanlab_enabled: false + wandb: + project: "sft-nemotron-joyang" + name: "sft-${data.dataset_name}-nemotron-super-49b-joyang" + tensorboard: + log_dir: "tb_logs-openmathinstruct-nemorl-1M_train" + mlflow: + experiment_name: "sft-dev" + run_name: "openmathinstruct-nemorl-1M_train" + gpu_monitoring: + collection_interval: 10 # How often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) + +cluster: + gpus_per_node: 8 + num_nodes: 1 diff --git a/examples/configs/sft_nemotron_super_49b_tulu_v3.yaml b/examples/configs/sft_nemotron_super_49b_tulu_v3.yaml new file mode 100644 index 0000000000..7c819f0e9f --- /dev/null +++ b/examples/configs/sft_nemotron_super_49b_tulu_v3.yaml @@ -0,0 +1,115 @@ +# SFT on Tulu3-SFT-Mixture dataset +sft: + max_num_epochs: 1 + max_num_steps: 50 + val_period: 5 + val_batches: 8 + val_global_batch_size: 128 + val_micro_batch_size: 1 + val_at_start: true + seed: 42 + +checkpointing: + enabled: true + checkpoint_dir: "results/sft_nemotron_super_49b" + metric_name: "val_loss" + higher_is_better: false + keep_top_k: 100 + save_period: 500 + checkpoint_must_save_by: null + +policy: + model_name: "/lustre/fsw/portfolios/coreai/users/joyang/models/llama-3_3-nemotron-49b-instruct-128k-v1_2-hf" + tokenizer: + name: ${policy.model_name} + max_total_sequence_length: 32768 + precision: "bfloat16" + train_global_batch_size: 128 + train_micro_batch_size: 1 + + dtensor_cfg: + _v2: true + activation_checkpointing: true + context_parallel_size: 8 + cpu_offload: false + enabled: true + sequence_parallel: false + tensor_parallel_size: 4 + custom_parallel_plan: examples.configs.recipes.llm.llama_nemotron_super_49b_custom_plan.custom_parallel_plan + + megatron_cfg: + enabled: false + + dynamic_batching: + enabled: false + train_mb_tokens: 4096 + logprob_mb_tokens: 8192 + sequence_length_round: 64 + + sequence_packing: + enabled: false + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + + # makes the training sequence length divisible by the 
tensor parallel size + # this is useful for sequence parallel training + make_sequence_length_divisible_by: ${max:${mul:${policy.dtensor_cfg.context_parallel_size}, 2}, ${policy.max_total_sequence_length}} + max_grad_norm: null + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 1e-5 + weight_decay: 0.01 + betas: [0.9, 0.98] + eps: 1e-8 + foreach: False + fused: False + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.1 + end_factor: 1.0 + total_iters: 10 # warmup_steps + - name: "torch.optim.lr_scheduler.CosineAnnealingLR" + kwargs: + # max_num_steps - warmup_steps = cosine steps + T_max: 40 + eta_min: 2e-6 + - milestones: [10] + +data: + max_input_seq_length: ${policy.max_total_sequence_length} + dataset_name: "tulu3_sft_mixture" + add_bos: true + add_eos: true + add_generation_prompt: false + shuffle: true + num_workers: 20 + test_size: 0.05 + # max_samples: 10000 # remove this line to use all data + +logger: + log_dir: "logs" + wandb_enabled: true + tensorboard_enabled: false + mlflow_enabled: false + monitor_gpus: true + swanlab_enabled: false + num_val_samples_to_print: 0 + wandb: + project: "nemotron-tulu-3-sft" + name: "nemotron-tulu-3" + tensorboard: + log_dir: "tb_logs-nemotron-tulu-3-sft" + mlflow: + experiment_name: "nemotron-tulu-3-sft" + run_name: "nemotron-tulu-3-sft" + gpu_monitoring: + collection_interval: 10 + flush_interval: 10 + +cluster: + gpus_per_node: 8 + num_nodes: 8 \ No newline at end of file diff --git a/examples/run_grpo_helpsteer3.py b/examples/run_grpo_helpsteer3.py new file mode 100644 index 0000000000..048f72d64c --- /dev/null +++ b/examples/run_grpo_helpsteer3.py @@ -0,0 +1,305 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import os +import pprint +from collections import defaultdict +from typing import Any, Optional +from omegaconf import OmegaConf +from transformers import PreTrainedTokenizerBase + +from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup +from nemo_rl.algorithms.utils import get_tokenizer +from nemo_rl.data import DataConfig +from nemo_rl.data.datasets import AllTaskProcessedDataset, load_preference_dataset +from nemo_rl.data.interfaces import ( + DatumSpec, + LLMMessageLogType, + TaskDataProcessFnCallable, + TaskDataSpec, +) +from nemo_rl.distributed.ray_actor_environment_registry import ( + get_actor_python_env, +) +from nemo_rl.distributed.virtual_cluster import init_ray +from nemo_rl.environments.helpsteer3_environment import HelpSteer3Environment +from nemo_rl.environments.interfaces import EnvironmentInterface +from nemo_rl.models.generation import configure_generation_config +from nemo_rl.utils.config import load_config, parse_hydra_overrides +from nemo_rl.utils.logger import get_next_experiment_dir + +OmegaConf.register_new_resolver("mul", lambda a, b: a * b) + +def parse_args() -> tuple[argparse.Namespace, list[str]]: + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Run GRPO training with HelpSteer3 configuration") + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + + # Parse known args for the script + args, overrides = parser.parse_known_args() + + return args, overrides + + +# =============================================================================== +# HelpSteer3 Data Processor +# =============================================================================== +TokenizerType = PreTrainedTokenizerBase + + +def helpsteer3_data_processor( + datum_dict: dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer: TokenizerType, + max_seq_length: int, + idx: int, +) -> DatumSpec: + """Process a HelpSteer3 preference datum into a DatumSpec for GRPO training. + + This function converts HelpSteer3 preference data to work with GRPO by: + 1. Using the context as the prompt + 2. Using the preferred completion as the target response + 3. 
Creating a reward signal based on preference scores + """ + # Extract context and completions from HelpSteer3 format + context = datum_dict["context"] + completions = datum_dict["completions"] + + # Sort completions by rank (0 is preferred, 1 is rejected) + completions = sorted(completions, key=lambda x: x["rank"]) + preferred_completion = completions[0]["completion"] + + # Build the conversation from context + message_log: LLMMessageLogType = [] + + # Add context messages + if isinstance(context, list): + for msg in context: + message_log.append({ + "role": msg["role"], + "content": msg["content"], + }) + else: + # If context is a string, treat it as a user message + message_log.append({ + "role": "user", + "content": context, + }) + + # Add the preferred completion as the target + for completion_msg in preferred_completion: + message_log.append({ + "role": completion_msg["role"], + "content": completion_msg["content"], + }) + + # Apply chat template and tokenize + formatted_conversation = tokenizer.apply_chat_template( + message_log, + tokenize=False, + add_generation_prompt=False, + add_special_tokens=True, + ) + + # Tokenize the entire conversation + full_tokens = tokenizer( + formatted_conversation, + return_tensors="pt", + add_special_tokens=False, # Already added by chat template + )["input_ids"][0] + + # For simplicity, assign all tokens to the first message + # In a more sophisticated implementation, you might want to split tokens properly + message_log[0]["token_ids"] = full_tokens + message_log[0]["content"] = formatted_conversation + + # Clear token_ids for other messages to avoid double counting + for i in range(1, len(message_log)): + message_log[i]["token_ids"] = tokenizer("", return_tensors="pt")["input_ids"][0] # Empty tensor + + length = sum(len(m["token_ids"]) for m in message_log) + + # Create ground truth from the preferred completion for environment evaluation + ground_truth = " ".join([msg["content"] for msg in preferred_completion]) + extra_env_info = {"ground_truth": ground_truth} + + loss_multiplier = 1.0 + if length > max_seq_length: + # Truncate if too long + for chat_message in message_log: + chat_message["token_ids"] = chat_message["token_ids"][ + : min(max_seq_length // len(message_log), len(chat_message["token_ids"])) + ] + loss_multiplier = 0.1 # Reduce loss for truncated sequences + + output: DatumSpec = { + "message_log": message_log, + "length": length, + "extra_env_info": extra_env_info, + "loss_multiplier": loss_multiplier, + "idx": idx, + "task_name": task_data_spec.task_name, + } + return output + + +def setup_data( + tokenizer: TokenizerType, + data_config: DataConfig, + env_configs: dict[str, Any], + seed: int, +) -> tuple[ + AllTaskProcessedDataset, + Optional[AllTaskProcessedDataset], + dict[str, EnvironmentInterface], + dict[str, EnvironmentInterface], +]: + print("\nā–¶ Setting up HelpSteer3 data and environment...") + helpsteer3_task_spec = TaskDataSpec( + task_name="helpsteer3", + prompt_file=data_config.get("prompt_file"), + system_prompt_file=data_config.get("system_prompt_file"), + ) + + # Load HelpSteer3 preference dataset + data: Any = load_preference_dataset(data_config) + + # Data processor for HelpSteer3 + task_data_processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = ( + defaultdict(lambda: (helpsteer3_task_spec, helpsteer3_data_processor)) + ) + task_data_processors["helpsteer3"] = (helpsteer3_task_spec, helpsteer3_data_processor) + + # Setup dedicated HelpSteer3Environment + helpsteer3_env = 
HelpSteer3Environment.options( # type: ignore # it's wrapped with ray.remote + runtime_env={ + "py_executable": get_actor_python_env( + "nemo_rl.environments.helpsteer3_environment.HelpSteer3Environment" + ), + "env_vars": dict(os.environ), # Pass thru all user environment variables + } + ).remote(env_configs.get("helpsteer3", {"num_workers": 8})) + + # Create training dataset + dataset = AllTaskProcessedDataset( + data.formatted_ds["train"], + tokenizer, + helpsteer3_task_spec, + task_data_processors, + max_seq_length=data_config["max_input_seq_length"], + ) + + # Create validation dataset if available + val_dataset: Optional[AllTaskProcessedDataset] = None + if "validation" in data.formatted_ds and data.formatted_ds["validation"]: + val_dataset = AllTaskProcessedDataset( + data.formatted_ds["validation"], + tokenizer, + helpsteer3_task_spec, + task_data_processors, + max_seq_length=data_config["max_input_seq_length"], + ) + + # Map tasks to environments + task_to_env: dict[str, EnvironmentInterface] = defaultdict(lambda: helpsteer3_env) + task_to_env["helpsteer3"] = helpsteer3_env + + return dataset, val_dataset, task_to_env, task_to_env + + +def main() -> None: + """Main entry point.""" + # Parse arguments + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join( + os.path.dirname(__file__), "configs", "recipes", "llm", "grpo-helpsteer3-llama-3.2-1b-1n8g-fsdp2tp1.yaml" + ) + + config = load_config(args.config) + print(f"Loaded configuration from: {args.config}") + + if overrides: + print(f"Overrides: {overrides}") + config = parse_hydra_overrides(config, overrides) + + config: MasterConfig = OmegaConf.to_container(config, resolve=True) + print("Applied CLI overrides") + + # Print config + print("Final config:") + pprint.pprint(config) + + # Get the next experiment directory with incremented ID + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + print(f"šŸ“Š Using log directory: {config['logger']['log_dir']}") + if config["checkpointing"]["enabled"]: + print( + f"šŸ“Š Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}" + ) + + init_ray() + + # Setup tokenizer + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + assert config["policy"]["generation"] is not None, ( + "A generation config is required for GRPO" + ) + config["policy"]["generation"] = configure_generation_config( + config["policy"]["generation"], tokenizer + ) + + # Setup data + ( + dataset, + val_dataset, + task_to_env, + val_task_to_env, + ) = setup_data(tokenizer, config["data"], config["env"], config["grpo"]["seed"]) + + ( + policy, + policy_generation, + cluster, + dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + grpo_state, + master_config, + ) = setup(config, tokenizer, dataset, val_dataset) + + grpo_train( + policy, + policy_generation, + dataloader, + val_dataloader, + tokenizer, + loss_fn, + task_to_env, + val_task_to_env, + logger, + checkpointer, + grpo_state, + master_config, + ) + + +if __name__ == "__main__": + main() diff --git a/nemo_rl/data/datasets/preference_datasets/helpsteer3.py b/nemo_rl/data/datasets/preference_datasets/helpsteer3.py index e80fbff302..570756d311 100644 --- a/nemo_rl/data/datasets/preference_datasets/helpsteer3.py +++ b/nemo_rl/data/datasets/preference_datasets/helpsteer3.py @@ -51,6 +51,7 @@ def to_preference_data_format( {"rank": 0, "completion": [{"role": "assistant", "content": chosen}]}, {"rank": 1, "completion": [{"role": "assistant", "content": rejected}]}, ], + 
"task_name": "helpsteer3", # Add task_name for GRPO compatibility } diff --git a/nemo_rl/data/datasets/response_datasets/__init__.py b/nemo_rl/data/datasets/response_datasets/__init__.py index 0fc279ac8c..4250765676 100644 --- a/nemo_rl/data/datasets/response_datasets/__init__.py +++ b/nemo_rl/data/datasets/response_datasets/__init__.py @@ -27,6 +27,7 @@ from nemo_rl.data.datasets.response_datasets.refcoco import RefCOCODataset from nemo_rl.data.datasets.response_datasets.response_dataset import ResponseDataset from nemo_rl.data.datasets.response_datasets.squad import SquadDataset +from nemo_rl.data.datasets.response_datasets.tulu3 import Tulu3SftMixtureDataset from nemo_rl.data.datasets.utils import get_extra_kwargs @@ -93,6 +94,14 @@ def load_response_dataset(data_config, seed: int = 42): base_dataset: Any = Geometry3KDataset( split=data_config["split"], ) + elif dataset_name == "tulu3_sft_mixture": + print("Loading allenai/tulu-3-sft-mixture for training and validation") + base_dataset: Any = Tulu3SftMixtureDataset( + test_size=data_config.get("test_size", 0.05), + prompt_file=data_config.get("prompt_file", None), + max_samples=data_config.get("max_samples", None), + seed=seed, + ) # fall back to load from JSON file elif dataset_name == "ResponseDataset": if "train_data_path" not in data_config: @@ -133,4 +142,5 @@ def load_response_dataset(data_config, seed: int = 42): "RefCOCODataset", "ResponseDataset", "SquadDataset", + "Tulu3SftMixtureDataset", ] diff --git a/nemo_rl/data/datasets/response_datasets/tulu3.py b/nemo_rl/data/datasets/response_datasets/tulu3.py new file mode 100644 index 0000000000..0efc3841a9 --- /dev/null +++ b/nemo_rl/data/datasets/response_datasets/tulu3.py @@ -0,0 +1,148 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Any + +from datasets import load_dataset + +from nemo_rl.data.interfaces import TaskDataSpec + + +def to_preference_data_format( + data: dict[str, Any], +) -> dict[ + str, list[dict[str, int | list[dict[str, str | Any]]]] | list[dict[str, str]] +]: + chosen_conversation = data["chosen"] + rejected_conversation = data["rejected"] + + context = chosen_conversation[:-1] + + # We assume that except last assistant response, all messages in + # chosen and rejected conversations are similar. Validating this... + assert json.dumps(context, ensure_ascii=False) == json.dumps( + rejected_conversation[:-1], ensure_ascii=False + ), ( + f"Context mismatch.\n\nchosen: {chosen_conversation}\n\n rejected: {rejected_conversation}" + ) + + # We assume that last response is always from the assistant. Validating this... + assert chosen_conversation[-1]["role"] == "assistant", ( + f"The last chosen response ({chosen_conversation[-1]}) is not from assistant!" + ) + assert rejected_conversation[-1]["role"] == "assistant", ( + f"The last rejected response ({rejected_conversation[-1]}) is not from assistant!" 
+ ) + + chosen_response = chosen_conversation[-1]["content"] + rejected_response = rejected_conversation[-1]["content"] + + return { + "context": context, + "completions": [ + { + "rank": 0, + "completion": [{"role": "assistant", "content": chosen_response}], + }, + { + "rank": 1, + "completion": [{"role": "assistant", "content": rejected_response}], + }, + ], + } + + +class Tulu3PreferenceDataset: + """Tulu3 preference dataset for DPO training.""" + + def __init__(self) -> None: + ds = load_dataset( + path="allenai/llama-3.1-tulu-3-8b-preference-mixture", + trust_remote_code=True, + ) + self.formatted_ds = ds.map(to_preference_data_format) + + self.task_spec = TaskDataSpec( + task_name="Tulu3Preference", + ) + +def format_tulu3_sft_mixture(data: dict[str, Any]) -> dict[str, str | dict[str, str]]: + """format for Tulu3 SFT data.""" + messages = data["messages"] + + # Ensure last message is from assistant + if not messages or messages[-1]["role"] != "assistant": + raise ValueError(f"Expected last message to be from assistant, got: {messages}") + + return { + "messages": messages, + "task_name": "tulu3_sft_mixture", + } + + +class Tulu3SftMixtureDataset: + """Tulu3 SFT mixture dataset.""" + + def __init__( + self, + seed: int = 42, + test_size: float = 0.05, + prompt_file: str | None = None, + max_samples: int | None = None, + ) -> None: + """Initialize the Tulu3 SFT mixture dataset. + + Args: + seed: Random seed for train/validation split + test_size: Proportion of data to use for validation (0.0-1.0) + prompt_file: Optional prompt file path to be applied via TaskDataSpec + max_samples: Optional maximum number of samples to use from the dataset + """ + print( + "WARNING: For reproducible experiments, preprocess the dataset once and define your own HfDataset subclass that directly uses the preprocessed datasets." 
+ ) + + # Load the original dataset + original_ds = load_dataset( + path="allenai/tulu-3-sft-mixture", + trust_remote_code=True, + )["train"] # This dataset only has a train split + + # Optionally limit the number of samples + if max_samples is not None and max_samples > 0: + original_ds = original_ds.shuffle(seed=seed).select(range(min(max_samples, len(original_ds)))) + + # Split into train and validation sets + split_ds = original_ds.train_test_split(test_size=test_size, seed=seed) + + # Format the examples without any reasoning processing + train_formatted = split_ds["train"].map( + format_tulu3_sft_mixture, + remove_columns=split_ds["train"].column_names, + ) + val_formatted = split_ds["test"].map( + format_tulu3_sft_mixture, + remove_columns=split_ds["test"].column_names, + ) + + self.formatted_ds = { + "train": train_formatted, + "validation": val_formatted, + } + + self.task_spec = TaskDataSpec( + task_name="Tulu3SftMixture", + prompt_file=prompt_file, + ) \ No newline at end of file diff --git a/nemo_rl/distributed/ray_actor_environment_registry.py b/nemo_rl/distributed/ray_actor_environment_registry.py index 6a3529d4a1..8528b740d6 100644 --- a/nemo_rl/distributed/ray_actor_environment_registry.py +++ b/nemo_rl/distributed/ray_actor_environment_registry.py @@ -36,6 +36,7 @@ "nemo_rl.environments.vlm_environment.VLMEnvironment": PY_EXECUTABLES.SYSTEM, "nemo_rl.environments.code_environment.CodeEnvironment": PY_EXECUTABLES.SYSTEM, "nemo_rl.environments.reward_model_environment.RewardModelEnvironment": PY_EXECUTABLES.SYSTEM, + "nemo_rl.environments.helpsteer3_environment.HelpSteer3Environment": PY_EXECUTABLES.SYSTEM, "nemo_rl.environments.games.sliding_puzzle.SlidingPuzzleEnv": PY_EXECUTABLES.SYSTEM, # AsyncTrajectoryCollector needs vLLM environment to handle exceptions from VllmGenerationWorker "nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector": PY_EXECUTABLES.VLLM, diff --git a/nemo_rl/environments/helpsteer3_environment.py b/nemo_rl/environments/helpsteer3_environment.py new file mode 100644 index 0000000000..df4c773824 --- /dev/null +++ b/nemo_rl/environments/helpsteer3_environment.py @@ -0,0 +1,259 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Optional, TypedDict, Union + +import ray +import torch + +from nemo_rl.data.interfaces import LLMMessageLogType +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.virtual_cluster import PY_EXECUTABLES +from nemo_rl.environments.interfaces import ( + EnvironmentInterface, + EnvironmentReturn, +) +from nemo_rl.environments.utils import chunk_list_to_workers + + +class HelpSteer3EnvConfig(TypedDict): + num_workers: int + stop_strings: Optional[list[str]] # Default stop strings for this env + + +class HelpSteer3EnvironmentMetadata(TypedDict): + ground_truth: str + + +@ray.remote # pragma: no cover +class HelpSteer3VerifyWorker: + """Worker for evaluating HelpSteer3 responses based on preference alignment.""" + + def __init__(self) -> None: + pass + + def verify( + self, + pred_responses: list[str], + ground_truths: list[str], + return_extracted_answer: bool = False, + ) -> Union[list[float], tuple[list[float], list[str | None]]]: + """Verify HelpSteer3 responses against preferred completions. + + For HelpSteer3, we use a simple text similarity approach to evaluate + how well the model's response aligns with the preferred completion. + + Args: + pred_responses: list[str]. The predicted responses from the LLM. + ground_truths: list[str]. The preferred completion responses. + return_extracted_answer: bool. Whether to return extracted answers. + + Returns: + Union[list[float], tuple[list[float], list[str | None]]]. + If return_extracted_answer is False, returns only the scores. + If return_extracted_answer is True, returns (scores, extracted_answers). + """ + results = [] + extracted_answers: list[str | None] = [] + + for response, ground_truth in zip(pred_responses, ground_truths): + try: + # Simple reward based on text similarity/alignment + # This is a basic implementation - could be enhanced with more sophisticated metrics + score = self._calculate_preference_score(response, ground_truth) + results.append(float(score)) + + if return_extracted_answer: + # For HelpSteer3, the "extracted answer" is the full response + extracted_answers.append(response.strip()) + + except Exception: + results.append(0.0) + if return_extracted_answer: + extracted_answers.append(None) + + if return_extracted_answer: + return results, extracted_answers + else: + return results + + def _calculate_preference_score(self, response: str, ground_truth: str) -> float: + """Calculate a preference alignment score between response and ground truth. + + This is a simplified scoring function. 
In practice, you might want to use: + - Semantic similarity models + - BLEU/ROUGE scores + - Human preference models + - Quality metrics specific to HelpSteer3 + + Args: + response: The model's response + ground_truth: The preferred completion + + Returns: + float: Score between 0.0 and 1.0 + """ + # Normalize both texts + response_clean = response.strip().lower() + ground_truth_clean = ground_truth.strip().lower() + + # Simple exact match (could be enhanced) + if response_clean == ground_truth_clean: + return 1.0 + + # Basic similarity based on common words + response_words = set(response_clean.split()) + ground_truth_words = set(ground_truth_clean.split()) + + if not ground_truth_words: + return 0.0 + + # Jaccard similarity + intersection = len(response_words & ground_truth_words) + union = len(response_words | ground_truth_words) + + if union == 0: + return 0.0 + + jaccard_score = intersection / union + + # Length penalty for responses that are too short or too long + len_ratio = min(len(response_clean), len(ground_truth_clean)) / max(len(response_clean), len(ground_truth_clean), 1) + + # Combine scores + final_score = jaccard_score * len_ratio + + return min(1.0, max(0.0, final_score)) + + +@ray.remote(max_restarts=-1, max_task_retries=-1) # pragma: no cover +class HelpSteer3Environment(EnvironmentInterface[HelpSteer3EnvironmentMetadata]): + """Environment for evaluating HelpSteer3 preference alignment.""" + + def __init__(self, cfg: HelpSteer3EnvConfig): + self.cfg = cfg + self.num_workers = cfg["num_workers"] + + # Create worker pool + self.workers = [ + HelpSteer3VerifyWorker.options( # type: ignore # (decorated with @ray.remote) + runtime_env={"py_executable": PY_EXECUTABLES.SYSTEM} + ).remote() + for _ in range(self.num_workers) + ] + + def shutdown(self) -> None: + """Shutdown all workers.""" + for worker in self.workers: + ray.kill(worker) + + def step( + self, + message_log_batch: list[LLMMessageLogType], + metadata: list[HelpSteer3EnvironmentMetadata], + return_extracted_answer: bool = False, + ) -> EnvironmentReturn[HelpSteer3EnvironmentMetadata]: + """Runs a step in the HelpSteer3 environment. + + Args: + message_log_batch: Batch of OpenAI-API-like message logs. + metadata: Batch of HelpSteer3EnvironmentMetadata with ground truth. + return_extracted_answer: Whether to return extracted answers. + + Returns: + EnvironmentReturn: Tuple containing observations, metadata, stop strings, rewards, and done flags. 
+ """ + # Extract the assistant's responses from the message history + assistant_response_batch = [] + for conversation in message_log_batch: + assistant_responses = [ + str(interaction["content"]) + for interaction in conversation + if interaction["role"] == "assistant" + ] + assistant_response_batch.append("".join(assistant_responses)) + + ground_truths = [g["ground_truth"] for g in metadata] + + # Chunk work across workers + chunked_assistant_response_batch = chunk_list_to_workers( + assistant_response_batch, self.num_workers + ) + chunked_ground_truths = chunk_list_to_workers(ground_truths, self.num_workers) + + # Process each chunk in parallel + futures = [ + self.workers[i].verify.remote( + chunk, ground_truth_chunk, return_extracted_answer + ) + for i, (chunk, ground_truth_chunk) in enumerate( + zip(chunked_assistant_response_batch, chunked_ground_truths) + ) + ] + + worker_results = ray.get(futures) + + # Flatten the results and extract both scores and answers + results = [] + extracted_answers: list[str | None] | None = ( + [] if return_extracted_answer else None + ) + + for worker_result in worker_results: + if return_extracted_answer: + worker_scores, worker_answers = worker_result + results.extend(worker_scores) + extracted_answers.extend(worker_answers) + else: + results.extend(worker_result) + + # Create observations based on preference alignment + observations = [ + { + "role": "environment", + "content": f"Environment: preference aligned (score: {result:.2f})" + if result > 0.5 + else f"Environment: preference misaligned (score: {result:.2f})", + } + for result in results + ] + + # Create reward and done tensors + rewards = torch.tensor(results).cpu() + done = torch.ones_like(rewards).cpu() + next_stop_strings = [None] * len(message_log_batch) + + return EnvironmentReturn( + observations=observations, + metadata=metadata, + next_stop_strings=next_stop_strings, + rewards=rewards, + terminateds=done, + answers=extracted_answers, + ) + + def global_post_process_and_metrics( + self, batch: BatchedDataDict[Any] + ) -> tuple[BatchedDataDict[Any], dict[str, float | int]]: + """Post-process batch and compute metrics for HelpSteer3.""" + # Calculate preference alignment metrics + rewards = batch["rewards"] + + metrics = { + "preference_alignment_rate": float(torch.mean((rewards > 0.5).float())), + "average_preference_score": float(torch.mean(rewards)), + "high_alignment_rate": float(torch.mean((rewards > 0.8).float())), + } + + return batch, metrics From 25ef3965b295eb61d3f1154db953e6ca8182daad Mon Sep 17 00:00:00 2001 From: Manasa Manohara Date: Mon, 3 Nov 2025 20:42:49 -0800 Subject: [PATCH 3/4] SFT update --- examples/run_sft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/run_sft.py b/examples/run_sft.py index b804b4e19f..233817947b 100644 --- a/examples/run_sft.py +++ b/examples/run_sft.py @@ -32,7 +32,7 @@ from nemo_rl.utils.logger import get_next_experiment_dir OmegaConf.register_new_resolver("mul", lambda a, b: a * b) - +OmegaConf.register_new_resolver("max", lambda a, b: max(a, b)) def parse_args(): """Parse command line arguments.""" From 86f1dea0e015c9a10367c5c0a86da60bbc2e4ce6 Mon Sep 17 00:00:00 2001 From: Manasa Manohara Date: Sat, 8 Nov 2025 20:19:59 -0800 Subject: [PATCH 4/4] Resolving comments --- examples/configs/grpo_helpsteer3.yaml | 192 ++++++++++++++++++ ...helpsteer3-llama-3.2-1b-1n8g-fsdp2tp1.yaml | 168 +-------------- ...nemotron-super-49b-v1.5-4n8g-fsdp2tp8.yaml | 140 +------------ ...-super-49b-v.5-4n8g-fsdp2tp8.yaml.disabled 
| 47 ----- .../llm/sft-nemotron-super-49b-tulu-v3.yaml | 49 +++++ .../recipes/llm/sft-nemotron-super-49b.yaml | 18 ++ ....yaml => sft_nemotron_super_49b_base.yaml} | 48 ++--- .../sft_nemotron_super_49b_tulu_v3.yaml | 115 ----------- .../{ => custom_parallel}/custom_parallel.py | 0 .../llama_nemotron_super_49b_custom_plan.py | 0 .../data/datasets/response_datasets/tulu3.py | 59 ------ ...o-helpsteer3-llama-3.2-1b-1n8g-fsdp2tp1.sh | 40 ++++ ...3-nemotron-super-49b-v1.5-4n8g-fsdp2tp8.sh | 40 ++++ .../llm/sft-nemotron-super-49b-tulu-v3.sh | 40 ++++ .../test_suites/llm/sft-nemotron-super-49b.sh | 40 ++++ tests/test_suites/nightly.txt | 8 + 16 files changed, 449 insertions(+), 555 deletions(-) create mode 100644 examples/configs/grpo_helpsteer3.yaml delete mode 100644 examples/configs/recipes/llm/grpo-math-llama-nemotron-super-49b-v.5-4n8g-fsdp2tp8.yaml.disabled create mode 100644 examples/configs/recipes/llm/sft-nemotron-super-49b-tulu-v3.yaml create mode 100644 examples/configs/recipes/llm/sft-nemotron-super-49b.yaml rename examples/configs/{sft_nemotron_super_49b.yaml => sft_nemotron_super_49b_base.yaml} (68%) delete mode 100644 examples/configs/sft_nemotron_super_49b_tulu_v3.yaml rename examples/{ => custom_parallel}/custom_parallel.py (100%) rename examples/{configs/recipes/llm => custom_parallel}/llama_nemotron_super_49b_custom_plan.py (100%) create mode 100755 tests/test_suites/llm/grpo-helpsteer3-llama-3.2-1b-1n8g-fsdp2tp1.sh create mode 100755 tests/test_suites/llm/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-4n8g-fsdp2tp8.sh create mode 100755 tests/test_suites/llm/sft-nemotron-super-49b-tulu-v3.sh create mode 100755 tests/test_suites/llm/sft-nemotron-super-49b.sh diff --git a/examples/configs/grpo_helpsteer3.yaml b/examples/configs/grpo_helpsteer3.yaml new file mode 100644 index 0000000000..867a6bd5ab --- /dev/null +++ b/examples/configs/grpo_helpsteer3.yaml @@ -0,0 +1,192 @@ +# Base GRPO Algorithm Configuration for HelpSteer3 dataset +grpo: + num_prompts_per_step: 32 + num_generations_per_prompt: 16 + max_rollout_turns: 1 # for multi-turn rollouts. 
HelpSteer3 conversations can have multiple turns + max_num_epochs: 1 + max_num_steps: 500 + normalize_rewards: true + use_leave_one_out_baseline: true + val_period: 10 + val_at_start: false + overlong_filtering: false + max_val_samples: 256 + val_batch_size: 256 + seed: 42 + use_dynamic_sampling: false + batch_multiplier: 1 + dynamic_sampling_max_gen_batches: 10 + reward_shaping: + enabled: false + overlong_buffer_length: 128 + overlong_buffer_penalty: 1 + max_response_length: ${policy.max_total_sequence_length} + reward_scaling: + enabled: false + source_min: 0.0 + source_max: 1.0 + target_min: 0.0 + target_max: 1.0 + + async_grpo: + enabled: false # Set to true to enable async training mode + # Max age (in training steps) for trajectories used in training + max_trajectory_age_steps: 1 + +loss_fn: + reference_policy_kl_penalty: 0.01 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 + ratio_clip_c: null + # (default off) loss formulation improvements (docs/guides/grpo.md#loss) + use_on_policy_kl_approximation: false + use_importance_sampling_correction: false + sequence_level_importance_ratios: false + truncated_importance_sampling_ratio: null + token_level_loss: true + +checkpointing: + enabled: true + checkpoint_dir: "results/grpo-helpsteer3" + metric_name: "val_reward" + higher_is_better: true + keep_top_k: 3 + save_period: 10 + checkpoint_must_save_by: null + model_save_format: "safetensors" + save_consolidated: false + +policy: + model_name: "meta-llama/Llama-3.2-1B-Instruct" + tokenizer: + name: ${policy.model_name} + max_total_sequence_length: 2048 + precision: "bfloat16" + train_global_batch_size: 512 + train_micro_batch_size: 4 + logprob_batch_size: 4 + logprob_chunk_size: null + + dtensor_cfg: + _v2: true + enabled: true + cpu_offload: false + sequence_parallel: false + activation_checkpointing: false + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + + megatron_cfg: + enabled: false + + # See docs/design-docs/sequence-packing-and-dynamic-batching.md + # for more details on dynamic batching and sequence packing. 
+ dynamic_batching: + enabled: True + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + sequence_length_round: 64 + + sequence_packing: + enabled: False + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + + make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} + max_grad_norm: 1.0 + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-6 + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1e-8 + + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.1 + end_factor: 1.0 + # The scheduler iteration is per GPRO step and is decoupled with the optimizer step (may be >=1 per GPRO step) + total_iters: 50 + - name: "torch.optim.lr_scheduler.ConstantLR" + kwargs: + factor: 1.0 + total_iters: 10000000000 + - milestones: [50] + + generation: + backend: "vllm" + max_new_tokens: ${policy.max_total_sequence_length} + temperature: 1.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null + vllm_cfg: + async_engine: false + precision: ${policy.precision} + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + expert_parallel_size: 1 + gpu_memory_utilization: 0.6 + max_model_len: ${policy.max_total_sequence_length} + # when enforce_eager is False, it is optional to set ++policy.generation.vllm_kwargs.compilation_config.use_inductor=False for better accuracy, + # with the flag, vllm will use the custom CUDA kernels instead of the Triton kernels generated by torch.compile + # for more details, see convergence issue https://github.com/NVIDIA-NeMo/RL/issues/998 + enforce_eager: False + use_deep_gemm: False + num_last_layers_in_bf16: 0 + num_first_layers_in_bf16: 0 + vllm_kwargs: {} + colocated: + # true: generation shares training GPUs + # false: uses dedicated generation resources + enabled: true + # only relevant when enabled is false + resources: + gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster i.e cluster.num_nodes == 1 + num_nodes: null # Decides number of nodes to be dedicated to generation + +data: + max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len + prompt_file: null # HelpSteer3 contains its own prompts + system_prompt_file: null + shuffle: true + num_workers: 1 + dataset_name: "HelpSteer3" + # HelpSteer3 preference dataset will be converted to response format for GRPO + # The preferred responses will be used as target responses for the environment + +env: + helpsteer3: + num_workers: 8 + # Environment configuration for HelpSteer3 preference-based rewards + reward_model: "preference_based" # Use preference scores as rewards + +logger: + log_dir: "logs" # Base directory for all logs + wandb_enabled: true # Make sure you do a ``wandb login [Your API key]'' before running + tensorboard_enabled: false + mlflow_enabled: false + monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard + wandb: + project: "grpo-helpsteer3" + name: "grpo-helpsteer3" + tensorboard: + log_dir: "tb_logs-grpo-helpsteer3" + mlflow: + experiment_name: "grpo-helpsteer3" + run_name: "grpo-helpsteer3" + gpu_monitoring: + collection_interval: 10 # How 
often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) + +cluster: + gpus_per_node: 8 + num_nodes: 1 + diff --git a/examples/configs/recipes/llm/grpo-helpsteer3-llama-3.2-1b-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/grpo-helpsteer3-llama-3.2-1b-1n8g-fsdp2tp1.yaml index 1b7c7334c8..5ffc3fc415 100644 --- a/examples/configs/recipes/llm/grpo-helpsteer3-llama-3.2-1b-1n8g-fsdp2tp1.yaml +++ b/examples/configs/recipes/llm/grpo-helpsteer3-llama-3.2-1b-1n8g-fsdp2tp1.yaml @@ -1,186 +1,24 @@ +defaults: ../../grpo_helpsteer3.yaml + # GRPO Algorithm Configuration for Llama-3.2-1B with HelpSteer3 grpo: - num_prompts_per_step: 32 - num_generations_per_prompt: 16 - max_rollout_turns: 1 # for multi-turn rollouts. HelpSteer3 conversations can have multiple turns max_num_epochs: 3 - max_num_steps: 500 - normalize_rewards: true - use_leave_one_out_baseline: true - val_period: 10 - val_at_start: false - overlong_filtering: false - max_val_samples: 256 - val_batch_size: 256 - seed: 42 - use_dynamic_sampling: false - batch_multiplier: 1 - dynamic_sampling_max_gen_batches: 10 - reward_shaping: - enabled: false - overlong_buffer_length: 128 - overlong_buffer_penalty: 1 - max_response_length: ${policy.max_total_sequence_length} - reward_scaling: - enabled: false - source_min: 0.0 - source_max: 1.0 - target_min: 0.0 - target_max: 1.0 - - async_grpo: - enabled: false # Set to true to enable async training mode - # Max age (in training steps) for trajectories used in training - max_trajectory_age_steps: 1 - -loss_fn: - reference_policy_kl_penalty: 0.01 - ratio_clip_min: 0.2 - ratio_clip_max: 0.2 - ratio_clip_c: null - # (default off) loss formulation improvements (docs/guides/grpo.md#loss) - use_on_policy_kl_approximation: false - use_importance_sampling_correction: false - sequence_level_importance_ratios: false - token_level_loss: true - truncated_importance_sampling_ratio: null checkpointing: - enabled: true checkpoint_dir: "results/grpo-helpsteer3-llama-3.2-1b-5" - metric_name: "val_reward" - higher_is_better: true - keep_top_k: 3 - save_period: 10 - checkpoint_must_save_by: null - model_save_format: "safetensors" - save_consolidated: false policy: - model_name: "meta-llama/Llama-3.2-1B-Instruct" - tokenizer: - name: "meta-llama/Llama-3.2-1B-Instruct" - max_total_sequence_length: 2048 - precision: "bfloat16" - train_global_batch_size: 512 - train_micro_batch_size: 4 - logprob_batch_size: 4 - logprob_chunk_size: null - - dtensor_cfg: - _v2: true - enabled: true - cpu_offload: false - sequence_parallel: false - activation_checkpointing: false - tensor_parallel_size: 1 - context_parallel_size: 1 - custom_parallel_plan: null - - megatron_cfg: - enabled: false - - # See docs/design-docs/sequence-packing-and-dynamic-batching.md - # for more details on dynamic batching and sequence packing. 
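# The recipe above keeps only its overrides next to `defaults: ../../grpo_helpsteer3.yaml`,
# and the duplicated keys below are deleted because they are now inherited from the base
# file. A minimal sketch of the intended override semantics, assuming an OmegaConf-style
# deep merge (the project's own config loader may resolve `defaults:` differently):

from omegaconf import OmegaConf

base = OmegaConf.load("examples/configs/grpo_helpsteer3.yaml")
recipe = OmegaConf.load(
    "examples/configs/recipes/llm/grpo-helpsteer3-llama-3.2-1b-1n8g-fsdp2tp1.yaml"
)
recipe.pop("defaults", None)            # the base is loaded explicitly above
merged = OmegaConf.merge(base, recipe)  # recipe keys override base keys
print(merged.grpo.max_num_epochs)       # 3 (from the recipe), not the base's 1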
- dynamic_batching: - enabled: True - train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} - logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} - sequence_length_round: 64 - - sequence_packing: - enabled: False - train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} - logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} - algorithm: "modified_first_fit_decreasing" - sequence_length_round: 64 - - make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} - max_grad_norm: 1.0 - - optimizer: - name: "torch.optim.AdamW" - kwargs: - lr: 5.0e-6 - weight_decay: 0.01 - betas: [0.9, 0.999] - eps: 1e-8 - - scheduler: - - name: "torch.optim.lr_scheduler.LinearLR" - kwargs: - start_factor: 0.1 - end_factor: 1.0 - # The scheduler iteration is per GPRO step and is decoupled with the optimizer step (may be >=1 per GPRO step) - total_iters: 50 - - name: "torch.optim.lr_scheduler.ConstantLR" - kwargs: - factor: 1.0 - total_iters: 10000000000 - - milestones: [50] - generation: - backend: "vllm" - max_new_tokens: ${policy.max_total_sequence_length} - temperature: 1.0 - top_p: 1.0 - top_k: null stop_token_ids: - 128009 # <|eot_id|> for Llama-3.2 - stop_strings: null - vllm_cfg: - async_engine: false - precision: ${policy.precision} - tensor_parallel_size: 1 - pipeline_parallel_size: 1 - expert_parallel_size: 1 - gpu_memory_utilization: 0.6 - max_model_len: ${policy.max_total_sequence_length} - enforce_eager: False - vllm_kwargs: {} - colocated: - # true: generation shares training GPUs - # false: uses dedicated generation resources - enabled: true - # only relevant when enabled is false - resources: - gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster i.e cluster.num_nodes == 1 - num_nodes: null # Decides number of nodes to be dedicated to generation - -data: - max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len - prompt_file: null # HelpSteer3 contains its own prompts - system_prompt_file: null - shuffle: true - num_workers: 1 - dataset_name: "HelpSteer3" - # HelpSteer3 preference dataset will be converted to response format for GRPO - # The preferred responses will be used as target responses for the environment - -env: - helpsteer3: - num_workers: 8 - # Environment configuration for HelpSteer3 preference-based rewards - reward_model: "preference_based" # Use preference scores as rewards logger: - log_dir: "logs" # Base directory for all logs - wandb_enabled: true # Make sure you do a ``wandb login [Your API key]'' before running tensorboard_enabled: true - mlflow_enabled: false - monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard + monitor_gpus: true wandb: project: "grpo-helpsteer3-llama-3.2-1b" name: "grpo-helpsteer3-llama-3.2-1b-tp${policy.dtensor_cfg.tensor_parallel_size}" tensorboard: log_dir: "tb_logs-grpo-helpsteer3-llama-3.2-1b" mlflow: - experiment_name: "grpo-helpsteer3" run_name: "grpo-helpsteer3-llama-3.2-1b" - gpu_monitoring: - collection_interval: 10 # How often to collect GPU usage metrics (in seconds) - flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) - -cluster: - gpus_per_node: 8 - num_nodes: 1 diff --git a/examples/configs/recipes/llm/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-4n8g-fsdp2tp8.yaml 
b/examples/configs/recipes/llm/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-4n8g-fsdp2tp8.yaml index ea9e471ac7..d39028ffda 100644 --- a/examples/configs/recipes/llm/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-4n8g-fsdp2tp8.yaml +++ b/examples/configs/recipes/llm/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-4n8g-fsdp2tp8.yaml @@ -1,117 +1,38 @@ +defaults: ../../grpo_helpsteer3.yaml + # GRPO Algorithm Configuration for Llama-3.3-Nemotron-Super-49B-v1.5 with HelpSteer3 grpo: num_prompts_per_step: 64 - num_generations_per_prompt: 16 - max_rollout_turns: 1 # for multi-turn rollouts. HelpSteer3 conversations can have multiple turns max_num_epochs: 1 max_num_steps: 10 - normalize_rewards: true - use_leave_one_out_baseline: true - val_period: 10 - val_at_start: false - overlong_filtering: false - max_val_samples: 256 - val_batch_size: 256 - seed: 42 - use_dynamic_sampling: false - batch_multiplier: 1 - dynamic_sampling_max_gen_batches: 10 - reward_shaping: - enabled: false - overlong_buffer_length: 128 - overlong_buffer_penalty: 1 - max_response_length: ${policy.max_total_sequence_length} - reward_scaling: - enabled: false - source_min: 0.0 - source_max: 1.0 - target_min: 0.0 - target_max: 1.0 - - async_grpo: - enabled: false # Set to true to enable async training mode - # Max age (in training steps) for trajectories used in training - max_trajectory_age_steps: 1 - -loss_fn: - reference_policy_kl_penalty: 0.01 - ratio_clip_min: 0.2 - ratio_clip_max: 0.2 - ratio_clip_c: null - # (default off) loss formulation improvements (docs/guides/grpo.md#loss) - use_on_policy_kl_approximation: false - use_importance_sampling_correction: false - sequence_level_importance_ratios: false - truncated_importance_sampling_ratio: null - token_level_loss: true checkpointing: - enabled: true checkpoint_dir: "results/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-3" - metric_name: "val_reward" - higher_is_better: true - keep_top_k: 3 - save_period: 10 - checkpoint_must_save_by: null - model_save_format: "safetensors" - save_consolidated: false policy: model_name: /lustre/fsw/portfolios/coreai/users/joyang/models/llama-3_3-nemotron-49b-instruct-128k-v1_2-hf - tokenizer: - name: ${policy.model_name} max_total_sequence_length: 32768 - precision: "bfloat16" train_global_batch_size: 64 train_micro_batch_size: 1 logprob_batch_size: 1 - logprob_chunk_size: null dtensor_cfg: - _v2: true activation_checkpointing: true context_parallel_size: 4 cpu_offload: true - enabled: true sequence_parallel: false tensor_parallel_size: 8 - custom_parallel_plan: examples.configs.recipes.llm.llama_nemotron_super_49b_custom_plan.custom_parallel_plan - - megatron_cfg: - enabled: false - - # See docs/design-docs/sequence-packing-and-dynamic-batching.md - # for more details on dynamic batching and sequence packing. 
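# custom_parallel_plan is a dotted import path to the plan object, updated in this
# patch to the new examples/custom_parallel/ location. A rough sketch of how such a
# string can be resolved, assuming a plain importlib lookup (the project's actual
# resolver may add caching or validation):

import importlib

def resolve_dotted_attr(dotted_path: str):
    """Split 'pkg.module.attr' into module and attribute, then import and fetch it."""
    module_path, attr_name = dotted_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)

plan = resolve_dotted_attr(
    "examples.custom_parallel.llama_nemotron_super_49b_custom_plan.custom_parallel_plan"
)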
- dynamic_batching: - enabled: True - train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} - logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} - sequence_length_round: 64 - - sequence_packing: - enabled: False - train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} - logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} - algorithm: "modified_first_fit_decreasing" - sequence_length_round: 64 - - make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} - max_grad_norm: 1.0 + custom_parallel_plan: examples.custom_parallel.llama_nemotron_super_49b_custom_plan.custom_parallel_plan optimizer: - name: "torch.optim.AdamW" kwargs: lr: 3.0e-7 - weight_decay: 0.01 - betas: [0.9, 0.999] - eps: 1e-8 scheduler: - name: "torch.optim.lr_scheduler.LinearLR" kwargs: start_factor: 0.1 end_factor: 1.0 - # The scheduler iteration is per GPRO step and is decoupled with the optimizer step (may be >=1 per GPRO step) total_iters: 13 - name: "torch.optim.lr_scheduler.ConstantLR" kwargs: @@ -120,72 +41,17 @@ policy: - milestones: [13] generation: - backend: "vllm" - max_new_tokens: ${policy.max_total_sequence_length} - temperature: 1.0 - top_p: 1.0 - top_k: null - stop_token_ids: null - stop_strings: null vllm_cfg: - async_engine: false - precision: ${policy.precision} tensor_parallel_size: 4 - pipeline_parallel_size: 1 - expert_parallel_size: 1 # When EP > 1, EP must be a multiple of TP since vLLM's EP = DP * TP - gpu_memory_utilization: 0.6 - max_model_len: ${policy.max_total_sequence_length} - # when enforce_eager is False, it is optional to set ++policy.generation.vllm_kwargs.compilation_config.use_inductor=False for better accuracy, - # with the flag, vllm will use the custom CUDA kernels instead of the Triton kernels generated by torch.compile - # for more details, see convergence issue https://github.com/NVIDIA-NeMo/RL/issues/998 - enforce_eager: False - use_deep_gemm: False - num_last_layers_in_bf16: 0 - num_first_layers_in_bf16: 0 - vllm_kwargs: {} - colocated: - # true: generation shares training GPUs - # false: uses dedicated generation resources - enabled: true - # only relevant when enabled is false - resources: - gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster i.e cluster.num_nodes == 1 - num_nodes: null # Decides number of nodes to be dedicated to generation - -data: - max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len - prompt_file: null # HelpSteer3 contains its own prompts - system_prompt_file: null - shuffle: true - num_workers: 1 - dataset_name: "HelpSteer3" - # HelpSteer3 preference dataset will be converted to response format for GRPO - # The preferred responses will be used as target responses for the environment - -env: - helpsteer3: - num_workers: 8 - # Environment configuration for HelpSteer3 preference-based rewards - reward_model: "preference_based" # Use preference scores as rewards logger: - log_dir: "logs" # Base directory for all logs - wandb_enabled: true # Make sure you do a ``wandb login [Your API key]'' before running - tensorboard_enabled: false - mlflow_enabled: false - monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard wandb: project: "grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5" name: 
"grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-tp${policy.dtensor_cfg.tensor_parallel_size}" tensorboard: log_dir: "tb_logs-grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5" mlflow: - experiment_name: "grpo-helpsteer3" run_name: "grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5" - gpu_monitoring: - collection_interval: 10 # How often to collect GPU usage metrics (in seconds) - flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) cluster: - gpus_per_node: 8 num_nodes: 16 diff --git a/examples/configs/recipes/llm/grpo-math-llama-nemotron-super-49b-v.5-4n8g-fsdp2tp8.yaml.disabled b/examples/configs/recipes/llm/grpo-math-llama-nemotron-super-49b-v.5-4n8g-fsdp2tp8.yaml.disabled deleted file mode 100644 index 574db88263..0000000000 --- a/examples/configs/recipes/llm/grpo-math-llama-nemotron-super-49b-v.5-4n8g-fsdp2tp8.yaml.disabled +++ /dev/null @@ -1,47 +0,0 @@ -defaults: ../../grpo_math_1B.yaml -grpo: - num_prompts_per_step: 128 -policy: - model_name: nvidia/Llama-3_3-Nemotron-Super-49B-v1_5 - tokenizer: - name: nvidia/Llama-3_3-Nemotron-Super-49B-v1_5 - max_total_sequence_length: 1024 - train_global_batch_size: 128 - dtensor_cfg: - activation_checkpointing: true - tensor_parallel_size: 8 - custom_parallel_plan: examples.configs.recipes.llm.llama_nemotron_super_49b_custom_plan.custom_parallel_plan - dynamic_batching: - enabled: true - sequence_packing: - enabled: false - optimizer: - kwargs: - lr: 3.0e-07 - scheduler: - - name: torch.optim.lr_scheduler.LinearLR - kwargs: - start_factor: 0.1 - end_factor: 1.0 - total_iters: 13 - - name: torch.optim.lr_scheduler.ConstantLR - kwargs: - factor: 1.0 - total_iters: 10000000000 - - milestones: - - 13 - generation: - vllm_cfg: - tensor_parallel_size: 4 -logger: - wandb_enabled: true - monitor_gpus: false - wandb: - project: grpo-nemotron-super-49b - name: grpo-${data.dataset_name}-nemotron-super-49b-tp${policy.dtensor_cfg.tensor_parallel_size} - mlflow: - experiment_name: sft-dev - run_name: grpo-nemotron-super-49b -cluster: - gpus_per_node: 8 - num_nodes: 4 diff --git a/examples/configs/recipes/llm/sft-nemotron-super-49b-tulu-v3.yaml b/examples/configs/recipes/llm/sft-nemotron-super-49b-tulu-v3.yaml new file mode 100644 index 0000000000..8ce6aeb26d --- /dev/null +++ b/examples/configs/recipes/llm/sft-nemotron-super-49b-tulu-v3.yaml @@ -0,0 +1,49 @@ +defaults: ../../sft_nemotron_super_49b_base.yaml + +# SFT on Tulu3-SFT-Mixture dataset +sft: + max_num_steps: 50 + val_period: 5 + +policy: + max_total_sequence_length: 32768 + train_micro_batch_size: 1 + + dtensor_cfg: + context_parallel_size: 8 + + optimizer: + kwargs: + lr: 1e-5 + + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.1 + end_factor: 1.0 + total_iters: 10 # warmup_steps + - name: "torch.optim.lr_scheduler.CosineAnnealingLR" + kwargs: + # max_num_steps - warmup_steps = cosine steps + T_max: 40 + eta_min: 2e-6 + - milestones: [10] + +data: + dataset_name: "tulu3_sft_mixture" + test_size: 0.05 + # max_samples: 10000 # remove this line to use all data + +logger: + monitor_gpus: true + wandb: + project: "nemotron-tulu-3-sft" + name: "nemotron-tulu-3" + tensorboard: + log_dir: "tb_logs-nemotron-tulu-3-sft" + mlflow: + experiment_name: "nemotron-tulu-3-sft" + run_name: "nemotron-tulu-3-sft" + +cluster: + num_nodes: 8 diff --git a/examples/configs/recipes/llm/sft-nemotron-super-49b.yaml b/examples/configs/recipes/llm/sft-nemotron-super-49b.yaml new file mode 100644 index 0000000000..42b115a782 --- /dev/null +++ 
b/examples/configs/recipes/llm/sft-nemotron-super-49b.yaml @@ -0,0 +1,18 @@ +defaults: ../../sft_nemotron_super_49b_base.yaml + +# SFT Algorithm Configuration for Nemotron Super 49B +sft: + max_num_epochs: 3 + +policy: + train_micro_batch_size: 8 + +logger: + wandb: + project: "sft-nemotron-joyang" + name: "sft-${data.dataset_name}-nemotron-super-49b-joyang" + tensorboard: + log_dir: "tb_logs-openmathinstruct-nemorl-1M_train" + mlflow: + experiment_name: "sft-dev" + run_name: "openmathinstruct-nemorl-1M_train" diff --git a/examples/configs/sft_nemotron_super_49b.yaml b/examples/configs/sft_nemotron_super_49b_base.yaml similarity index 68% rename from examples/configs/sft_nemotron_super_49b.yaml rename to examples/configs/sft_nemotron_super_49b_base.yaml index d79837dbb8..4b153d53a5 100644 --- a/examples/configs/sft_nemotron_super_49b.yaml +++ b/examples/configs/sft_nemotron_super_49b_base.yaml @@ -1,6 +1,6 @@ -# SFT Algorithm Configuration +# Base SFT Configuration for Nemotron Super 49B sft: - max_num_epochs: 3 + max_num_epochs: 1 max_num_steps: 100 val_period: 10 val_batches: 8 @@ -19,16 +19,13 @@ checkpointing: checkpoint_must_save_by: null policy: - # model_name: Qwen/Qwen2.5-7B-Instruct - # tokenizer: - # name: Qwen/Qwen2.5-7B-Instruct model_name: "/lustre/fsw/portfolios/coreai/users/joyang/models/llama-3_3-nemotron-49b-instruct-128k-v1_2-hf" tokenizer: name: ${policy.model_name} max_total_sequence_length: 4096 precision: "bfloat16" train_global_batch_size: 128 - train_micro_batch_size: 8 + train_micro_batch_size: 1 dtensor_cfg: _v2: true @@ -38,7 +35,7 @@ policy: enabled: true sequence_parallel: false tensor_parallel_size: 4 - custom_parallel_plan: examples.configs.recipes.llm.llama_nemotron_super_49b_custom_plan.custom_parallel_plan + custom_parallel_plan: examples.custom_parallel.llama_nemotron_super_49b_custom_plan.custom_parallel_plan megatron_cfg: enabled: false @@ -54,7 +51,6 @@ policy: train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} algorithm: "modified_first_fit_decreasing" sequence_length_round: 64 - # makes the training sequence length divisible by the tensor parallel size # this is useful for sequence parallel training @@ -68,22 +64,9 @@ policy: weight_decay: 0.01 betas: [0.9, 0.98] eps: 1e-8 - # when using Dtensor, we need to set foreach - # and fused to False foreach: False fused: False -# data: -# add_bos: true -# add_eos: true -# add_generation_prompt: false -# dataset_name: "tulu3_sft_mixture" -# cache_dir: "/lustre/fsw/portfolios/coreai/users/gvenkatakris/data-cache" -# max_input_seq_length: 1024 -# max_samples: 10000 -# shuffle: true -# test_size: 0.05 - data: max_input_seq_length: ${policy.max_total_sequence_length} add_bos: true @@ -91,7 +74,6 @@ data: add_generation_prompt: false shuffle: true num_workers: 20 - dataset_name: "squad" # You can use custom response datasets for training and validation. 
For example: # data: @@ -111,24 +93,26 @@ data: seed: null logger: - log_dir: "logs" # Base directory for all logs - wandb_enabled: true # Make sure you do a ``wandb login [Your API key]'' before running + log_dir: "logs" + wandb_enabled: true tensorboard_enabled: false mlflow_enabled: false - monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard + monitor_gpus: false swanlab_enabled: false + num_val_samples_to_print: 0 wandb: - project: "sft-nemotron-joyang" - name: "sft-${data.dataset_name}-nemotron-super-49b-joyang" + project: "sft-nemotron-super-49b" + name: "sft-nemotron-super-49b" tensorboard: - log_dir: "tb_logs-openmathinstruct-nemorl-1M_train" + log_dir: "tb_logs-sft-nemotron-super-49b" mlflow: - experiment_name: "sft-dev" - run_name: "openmathinstruct-nemorl-1M_train" + experiment_name: "sft-nemotron-super-49b" + run_name: "sft-nemotron-super-49b" gpu_monitoring: - collection_interval: 10 # How often to collect GPU usage metrics (in seconds) - flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) + collection_interval: 10 + flush_interval: 10 cluster: gpus_per_node: 8 num_nodes: 1 + diff --git a/examples/configs/sft_nemotron_super_49b_tulu_v3.yaml b/examples/configs/sft_nemotron_super_49b_tulu_v3.yaml deleted file mode 100644 index 7c819f0e9f..0000000000 --- a/examples/configs/sft_nemotron_super_49b_tulu_v3.yaml +++ /dev/null @@ -1,115 +0,0 @@ -# SFT on Tulu3-SFT-Mixture dataset -sft: - max_num_epochs: 1 - max_num_steps: 50 - val_period: 5 - val_batches: 8 - val_global_batch_size: 128 - val_micro_batch_size: 1 - val_at_start: true - seed: 42 - -checkpointing: - enabled: true - checkpoint_dir: "results/sft_nemotron_super_49b" - metric_name: "val_loss" - higher_is_better: false - keep_top_k: 100 - save_period: 500 - checkpoint_must_save_by: null - -policy: - model_name: "/lustre/fsw/portfolios/coreai/users/joyang/models/llama-3_3-nemotron-49b-instruct-128k-v1_2-hf" - tokenizer: - name: ${policy.model_name} - max_total_sequence_length: 32768 - precision: "bfloat16" - train_global_batch_size: 128 - train_micro_batch_size: 1 - - dtensor_cfg: - _v2: true - activation_checkpointing: true - context_parallel_size: 8 - cpu_offload: false - enabled: true - sequence_parallel: false - tensor_parallel_size: 4 - custom_parallel_plan: examples.configs.recipes.llm.llama_nemotron_super_49b_custom_plan.custom_parallel_plan - - megatron_cfg: - enabled: false - - dynamic_batching: - enabled: false - train_mb_tokens: 4096 - logprob_mb_tokens: 8192 - sequence_length_round: 64 - - sequence_packing: - enabled: false - train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} - algorithm: "modified_first_fit_decreasing" - sequence_length_round: 64 - - # makes the training sequence length divisible by the tensor parallel size - # this is useful for sequence parallel training - make_sequence_length_divisible_by: ${max:${mul:${policy.dtensor_cfg.context_parallel_size}, 2}, ${policy.max_total_sequence_length}} - max_grad_norm: null - - optimizer: - name: "torch.optim.AdamW" - kwargs: - lr: 1e-5 - weight_decay: 0.01 - betas: [0.9, 0.98] - eps: 1e-8 - foreach: False - fused: False - scheduler: - - name: "torch.optim.lr_scheduler.LinearLR" - kwargs: - start_factor: 0.1 - end_factor: 1.0 - total_iters: 10 # warmup_steps - - name: "torch.optim.lr_scheduler.CosineAnnealingLR" - kwargs: - # max_num_steps - warmup_steps = cosine steps - T_max: 40 - eta_min: 2e-6 - - milestones: [10] - -data: - max_input_seq_length: 
${policy.max_total_sequence_length} - dataset_name: "tulu3_sft_mixture" - add_bos: true - add_eos: true - add_generation_prompt: false - shuffle: true - num_workers: 20 - test_size: 0.05 - # max_samples: 10000 # remove this line to use all data - -logger: - log_dir: "logs" - wandb_enabled: true - tensorboard_enabled: false - mlflow_enabled: false - monitor_gpus: true - swanlab_enabled: false - num_val_samples_to_print: 0 - wandb: - project: "nemotron-tulu-3-sft" - name: "nemotron-tulu-3" - tensorboard: - log_dir: "tb_logs-nemotron-tulu-3-sft" - mlflow: - experiment_name: "nemotron-tulu-3-sft" - run_name: "nemotron-tulu-3-sft" - gpu_monitoring: - collection_interval: 10 - flush_interval: 10 - -cluster: - gpus_per_node: 8 - num_nodes: 8 \ No newline at end of file diff --git a/examples/custom_parallel.py b/examples/custom_parallel/custom_parallel.py similarity index 100% rename from examples/custom_parallel.py rename to examples/custom_parallel/custom_parallel.py diff --git a/examples/configs/recipes/llm/llama_nemotron_super_49b_custom_plan.py b/examples/custom_parallel/llama_nemotron_super_49b_custom_plan.py similarity index 100% rename from examples/configs/recipes/llm/llama_nemotron_super_49b_custom_plan.py rename to examples/custom_parallel/llama_nemotron_super_49b_custom_plan.py diff --git a/nemo_rl/data/datasets/response_datasets/tulu3.py b/nemo_rl/data/datasets/response_datasets/tulu3.py index 0efc3841a9..dd1956a018 100644 --- a/nemo_rl/data/datasets/response_datasets/tulu3.py +++ b/nemo_rl/data/datasets/response_datasets/tulu3.py @@ -19,65 +19,6 @@ from nemo_rl.data.interfaces import TaskDataSpec - -def to_preference_data_format( - data: dict[str, Any], -) -> dict[ - str, list[dict[str, int | list[dict[str, str | Any]]]] | list[dict[str, str]] -]: - chosen_conversation = data["chosen"] - rejected_conversation = data["rejected"] - - context = chosen_conversation[:-1] - - # We assume that except last assistant response, all messages in - # chosen and rejected conversations are similar. Validating this... - assert json.dumps(context, ensure_ascii=False) == json.dumps( - rejected_conversation[:-1], ensure_ascii=False - ), ( - f"Context mismatch.\n\nchosen: {chosen_conversation}\n\n rejected: {rejected_conversation}" - ) - - # We assume that last response is always from the assistant. Validating this... - assert chosen_conversation[-1]["role"] == "assistant", ( - f"The last chosen response ({chosen_conversation[-1]}) is not from assistant!" - ) - assert rejected_conversation[-1]["role"] == "assistant", ( - f"The last rejected response ({rejected_conversation[-1]}) is not from assistant!" 
- ) - - chosen_response = chosen_conversation[-1]["content"] - rejected_response = rejected_conversation[-1]["content"] - - return { - "context": context, - "completions": [ - { - "rank": 0, - "completion": [{"role": "assistant", "content": chosen_response}], - }, - { - "rank": 1, - "completion": [{"role": "assistant", "content": rejected_response}], - }, - ], - } - - -class Tulu3PreferenceDataset: - """Tulu3 preference dataset for DPO training.""" - - def __init__(self) -> None: - ds = load_dataset( - path="allenai/llama-3.1-tulu-3-8b-preference-mixture", - trust_remote_code=True, - ) - self.formatted_ds = ds.map(to_preference_data_format) - - self.task_spec = TaskDataSpec( - task_name="Tulu3Preference", - ) - def format_tulu3_sft_mixture(data: dict[str, Any]) -> dict[str, str | dict[str, str]]: """format for Tulu3 SFT data.""" messages = data["messages"] diff --git a/tests/test_suites/llm/grpo-helpsteer3-llama-3.2-1b-1n8g-fsdp2tp1.sh b/tests/test_suites/llm/grpo-helpsteer3-llama-3.2-1b-1n8g-fsdp2tp1.sh new file mode 100755 index 0000000000..9f1e6b9c94 --- /dev/null +++ b/tests/test_suites/llm/grpo-helpsteer3-llama-3.2-1b-1n8g-fsdp2tp1.sh @@ -0,0 +1,40 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=50 +MAX_STEPS=50 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=120 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo_helpsteer3.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/token_mult_prob_error"]) < 1.1' \ + "data['train/token_mult_prob_error']['$MAX_STEPS'] < 1.1" +fi + diff --git a/tests/test_suites/llm/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-4n8g-fsdp2tp8.sh b/tests/test_suites/llm/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-4n8g-fsdp2tp8.sh new file mode 100755 index 0000000000..f907d332b7 --- /dev/null +++ b/tests/test_suites/llm/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-4n8g-fsdp2tp8.sh @@ -0,0 +1,40 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=16 +STEPS_PER_RUN=10 +MAX_STEPS=10 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=240 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo_helpsteer3.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ 
\ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/token_mult_prob_error"]) < 1.1' \ + "data['train/token_mult_prob_error']['$MAX_STEPS'] < 1.1" +fi + diff --git a/tests/test_suites/llm/sft-nemotron-super-49b-tulu-v3.sh b/tests/test_suites/llm/sft-nemotron-super-49b-tulu-v3.sh new file mode 100755 index 0000000000..c9dd4028b1 --- /dev/null +++ b/tests/test_suites/llm/sft-nemotron-super-49b-tulu-v3.sh @@ -0,0 +1,40 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=8 +STEPS_PER_RUN=50 +MAX_STEPS=50 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=120 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_sft.py \ + --config $CONFIG_PATH \ + sft.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/loss"]["1"] < 2.0' \ + 'data["train/loss"]["50"] < 1.5' +fi + diff --git a/tests/test_suites/llm/sft-nemotron-super-49b.sh b/tests/test_suites/llm/sft-nemotron-super-49b.sh new file mode 100755 index 0000000000..c40d368558 --- /dev/null +++ b/tests/test_suites/llm/sft-nemotron-super-49b.sh @@ -0,0 +1,40 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=50 +MAX_STEPS=50 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=90 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_sft.py \ + --config $CONFIG_PATH \ + sft.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/loss"]["1"] < 2.0' \ + 'data["train/loss"]["50"] < 1.5' +fi + diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 92cc2ea932..71c5a54100 100644 --- a/tests/test_suites/nightly.txt +++ 
b/tests/test_suites/nightly.txt @@ -34,6 +34,10 @@ tests/test_suites/llm/grpo-deepscaler-1.5b-8K.sh # Deepscaler (GSPO) tests/test_suites/llm/grpo-gspo-deepscaler-1.5b-8K.sh +# HelpSteer3 tests +tests/test_suites/llm/grpo-helpsteer3-llama-3.2-1b-1n8g-fsdp2tp1.sh +tests/test_suites/llm/grpo-helpsteer3-llama-3.3-nemotron-super-49b-v1.5-4n8g-fsdp2tp8.sh + # GRPO math test run (32K context mcore) tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh @@ -68,6 +72,10 @@ tests/test_suites/llm/sft-llama3.1-8b-1n8g-megatron.sh # sequence packing tests/test_suites/llm/sft-llama3.1-8b-1n8g-megatron-seqpack.sh +# Nemotron Super 49B SFT tests +tests/test_suites/llm/sft-nemotron-super-49b.sh +tests/test_suites/llm/sft-nemotron-super-49b-tulu-v3.sh + ####### # DPO # #######
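# The jq gate in the test scripts above assumes the dumped TensorBoard metrics are
# keyed by metric name and then by step, e.g. {"train/loss": {"1": ..., "50": ...}};
# check_metrics.py is only invoked once the largest recorded step for train/loss
# reaches MAX_STEPS. A small Python equivalent of that gate (loss values are
# illustrative, not measured):

import json

metrics = json.loads('{"train/loss": {"1": 1.92, "50": 1.41}}')
max_step = max(int(step) for step in metrics["train/loss"])  # mirrors `keys | map(tonumber) | max`
if max_step >= 50:  # MAX_STEPS in the GRPO/SFT scripts above
    print("target step reached; run the check_metrics.py assertions")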