From eb2db8b0b6f15e708b6bd2a16c1d348ad2a2b9b9 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Wed, 20 Nov 2024 12:27:14 -0800 Subject: [PATCH 1/5] feat: TRTLLM API handle tokenizers without pad_id (e.g., tiktoken) (#399) Signed-off-by: Terry Kong Signed-off-by: NeMo-Aligner CI Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- nemo_aligner/utils/trt_llm.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/nemo_aligner/utils/trt_llm.py b/nemo_aligner/utils/trt_llm.py index 1f879064d..9f1a9345f 100644 --- a/nemo_aligner/utils/trt_llm.py +++ b/nemo_aligner/utils/trt_llm.py @@ -44,8 +44,9 @@ def append_and_repad_list(list_of_items, item_to_append, pad_id): class GPTGenerateTRTLLM: - # If a tokenizer does not have a pad_id, we use a large negative number and replace - # with self.eos_id after generation. + # Use a reserved negative number since there is variation between tokenizers if + # they (1) have a pad_id (2) don't have a pad_id or (3) have None as the pad_id. + # This pad_id is replaced with eos_id after generation. DEFAULT_PAD_ID = -42 def __init__( @@ -72,12 +73,6 @@ def __init__( "You are trying to use NeMo-Aligner's TensorRT-LLM acceleration for LLM generation. Please build the dockerfile to enable this feature: https://github.com/NVIDIA/NeMo-Aligner/blob/main/Dockerfile" ) - # If this assert turns out to be a blocker with some tokenizers, potential workarounds could be to: - # - add a config option to allow specifying which token we pass as `end_id` to TRT-LLM (should - # be a token that the model is guaranteed to never generate) - assert ( - tokenizer.pad_id != tokenizer.eos_id - ), f"We require tokenizers to have a different {tokenizer.pad_id=} than {tokenizer.eos_id=} when using TRT-LLM. This is to make sure all code goes into the same path and include the eos_id when the response lengths are computed" assert max_input_len > 0 assert max_generation_length > 0 assert ( @@ -104,7 +99,7 @@ def __init__( rng_generator.manual_seed(seed) self.rng_generator = rng_generator - self.pad_id = tokenizer.pad_id if tokenizer.pad_id is not None else GPTGenerateTRTLLM.DEFAULT_PAD_ID + self.pad_id = GPTGenerateTRTLLM.DEFAULT_PAD_ID self.eos_id = tokenizer.eos_id end_strings = list(end_strings) From 716e5033a462f7f0b48f1ffd0b2852e276645aa5 Mon Sep 17 00:00:00 2001 From: Alexander Bukharin <59148829+abukharin3@users.noreply.github.com> Date: Fri, 22 Nov 2024 17:07:55 -0500 Subject: [PATCH 2/5] feat: adds REINFORCE algorithm (#357) Signed-off-by: Terry Kong Signed-off-by: NeMo-Aligner CI Signed-off-by: abukharin Co-authored-by: abukharin Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Terry Kong --- .github/workflows/cicd-main.yml | 1 + CHANGELOG.md | 1 + README.md | 2 + docs/user-guide/reinforce.rst | 256 ++++++++ .../nlp/gpt/conf/gpt_reinforce_actor.yaml | 216 +++++++ examples/nlp/gpt/train_gpt_reinforce_actor.py | 198 ++++++ nemo_aligner/algorithms/reinforce.py | 599 ++++++++++++++++++ .../nlp/gpt/megatron_gpt_reinforce_actor.py | 393 ++++++++++++ nemo_aligner/utils/ppo_utils.py | 28 + tests/functional/reinforce.sh | 178 ++++++ .../test_cases/reinforce-llama3-pp2-reshard | 28 + tests/test_ppo_utils.py | 22 +- 12 files changed, 1921 insertions(+), 1 deletion(-) create mode 100644 docs/user-guide/reinforce.rst create mode 100644 examples/nlp/gpt/conf/gpt_reinforce_actor.yaml create mode 100644 examples/nlp/gpt/train_gpt_reinforce_actor.py create mode 100644 nemo_aligner/algorithms/reinforce.py create mode 100644 nemo_aligner/models/nlp/gpt/megatron_gpt_reinforce_actor.py create mode 100755 tests/functional/reinforce.sh create mode 100755 tests/functional/test_cases/reinforce-llama3-pp2-reshard diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index a2784d592..d2d27e95a 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -90,6 +90,7 @@ jobs: matrix: test_case: - ppo-llama3-pp2-reshard + - reinforce-llama3-pp2-reshard - dpo-llama3 - kd-llama3 - sft-llama3 diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c2f34819..63cd9ba5c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) durations = timer.consume_durations() ``` - Add code and instructions for replicating Reward Modeling training in HelpSteer2 and HelpSteer2-Preference +- Implement REINFORCE algorithm. ### Breaking Changes - Upgrade TRTLLM dependency from v0.10.0 to v0.12.0 and migrate from `GPTSession` cpp runtime to `ModelRunner` python runtime. Please use the latest Dockerfile. diff --git a/README.md b/README.md index 66029e4da..a2a500e26 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,8 @@ The toolkit is currently in it's early stages. We are committed to improving the * **Reward Model Training** * **Reinforcement Learning from Human Feedback using the [PPO](https://arxiv.org/pdf/1707.06347.pdf) Algorithm** * [Llama3-70B-PPO-Chat](https://huggingface.co/nvidia/Llama3-70B-PPO-Chat) aligned with NeMo-Aligner using TRT-LLM. +* **Reinforcement Learning from Human Feedback using the REINFORCE Algorithm** + * [Llama-3.1-Nemotron-70B-Instruct](https://huggingface.co/nvidia/Llama-3.1-Nemotron-70B-Instruct) aligned with NeMo-Aligner using TRT-LLM. * **Direct Preference Optimization** as described in [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/pdf/2305.18290) * [Llama3-70B-DPO-Chat](https://huggingface.co/nvidia/Llama3-70B-DPO-Chat) aligned with NeMo Aligner. * **Self-Play Fine-Tuning (SPIN)** as described in [Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models](https://arxiv.org/pdf/2401.01335) diff --git a/docs/user-guide/reinforce.rst b/docs/user-guide/reinforce.rst new file mode 100644 index 000000000..cc3005db1 --- /dev/null +++ b/docs/user-guide/reinforce.rst @@ -0,0 +1,256 @@ +.. include:: /content/nemo.rsts + +.. _model-aligner-reinforce: + +Model Alignment by REINFORCE +@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + +In this tutorial, we will guide you through the process of aligning a NeMo Framework model using REINFORCE. This method can be applied to various models, including LLaMa2 and Mistral, with our scripts functioning consistently across different models. + +REINFORCE is usually preceded by a Supervised Fine-Tuning (SFT). We should first follow the :ref:`Prerequisite guide ` and the :ref:`SFT guide `. After obtaining the SFT model, we will also need to train a reward model as in :ref:`PPO guide `. We will use the REINFORCE algorithm on the `Anthropic-HH-RLHF `__ dataset. + +REINFORCE Training +############ + +After you have fine-tuned a GPT model using Supervised Fine-Tuning (SFT), and trained a reward model as explained in the preceding section, you can start aligning the policy using REINFORCE. + +During REINFORCE training, we have three models interacting with each other, which Aligner runs in two separate jobs: + +#. The Policy Network: This is the model we are training and it should start from an SFT model. +#. The Reward Model (RM): This model accepts a prompt combined with a response as input and produces a single scalar value, known as the reward. The REINFORCE algorithm aims to maximize this reward. +#. The Initial Policy Network (also known as the Reference Model): We use this model to compute a KL Divergence penalty term that ensures that the PPO Actor does not diverge too much from the Initial Policy. This way, we prevent the REINFORCE Actor from overfitting to the rewards given by the RM, and ensure it does not forget the knowledge it acquired during pretraining and SFT. This model should be the one used to initialize the REINFORCE Actor Network. + +The next section discusses how to launch each of these two jobs. + +Launching the Reward Model and Critic Server +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +To launch the server: + +.. code-block:: bash + + #!/bin/bash + RM_NEMO_FILE="/path/to/trained_rm.nemo" + GPFS="/path/to/nemo-aligner-repo" + + RESULTS_DIR="critic_results_dir" + + cd ${GPFS} + export PYTHONPATH="${GPFS}:${PYTHONPATH}" \ + && export HYDRA_FULL_ERROR=1 \ + && python -u examples/nlp/gpt/serve_reward_model.py \ + trainer.num_nodes=1 \ + trainer.devices=8 \ + ++model.tensor_model_parallel_size=4 \ + rm_model_file=${RM_NEMO_FILE} + + +The above example launches the reward model server on eight GPUs and one node. Make sure to change trainer.devices, trainer.num_nodes depending on your model size and scale. Aligner will work on any scale. Also, make sure to tune the trainer.reinforce.inference_micro_batch_size argument. This argument sets the size of the batch the REINFORCE actor is allowed to send to the reward per DP rank. + +Launch the Initial Policy and REINFORCE Actor Training +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +The REINFORCE Actor training job contains the master controller that makes the HTTP calls to all servers when needed. To launch the REINFORCE Actor and Initial Policy server: + +.. code-block:: bash + + GPFS="/path/to/nemo-aligner-repo" + TRAIN_DATA_PATH="/path/to/train_prompts.jsonl" + VALID_DATA_PATH="/path/to/test_prompts.jsonl" + + PRETRAINED_ACTOR_NEMO_FILE="/path/to/sft_checkpoint.nemo" + RESULTS_DIR="/path/to/actor_results_dir" + + USE_FLASK=False + ACTOR_LR=1e-6 + KL=0.01 + NUM_ROLLOUTS=32 + ACTOR_GBS=32 + REWARD_PORT=5555 + # Change this to the hostname of server hosting the reward model + host_reward="localhost" + + cd ${GPFS} + export PYTHONPATH="${GPFS}:${PYTHONPATH}" \ + && export HYDRA_FULL_ERROR=1 \ + && python -u examples/nlp/gpt/train_gpt_reinforce_actor.py \ + "model.data.data_prefix={train: [${TRAIN_DATA_PATH}], validation: [${VALID_DATA_PATH}], test: [${VALID_DATA_PATH}]}" \ + pretrained_checkpoint.restore_from_path=\"${ACTOR_NEMO_FILE}\" \ + exp_manager.checkpoint_callback_params.save_top_k=1 \ + exp_manager.explicit_log_dir=\"${RESULTS_DIR}\" \ + trainer.reinforce.max_epochs=1 \ + trainer.reinforce.max_steps=313 \ + trainer.reinforce.val_check_interval=4 \ + trainer.num_nodes=1 \ + trainer.devices=8 \ + trainer.reinforce.trt_llm.enable=True \ + trainer.reinforce.trt_llm.reshard=True \ + trainer.reinforce.trt_llm.unload_engine_train=False \ + ++model.tensor_model_parallel_size=4 \ + ++model.reinforce.num_rollout_samples=${NUM_ROLLOUTS} \ + model.global_batch_size=${ACTOR_GBS} \ + model.micro_batch_size=1 \ + model.optim.lr=\"\\\$\{multiply:${ACTOR_LR},1.001\}\" \ + model.optim.sched.warmup_steps=0 \ + model.optim.sched.constant_steps=312 \ + model.optim.sched.min_lr=${ACTOR_LR} \ + model.optim.weight_decay=0.01 \ + model.reinforce.rollout_micro_batch_size=16 \ + model.reinforce.forward_micro_batch_size=16 \ + model.reinforce.val_rollout_micro_batch_size=8 \ + model.data.data_impl=jsonl \ + remote_rm.reward_model.ip=${host_reward} \ + remote_rm.reward_model.port=${REWARD_PORT} \ + ++model.reinforce.length_params.max_length=2048 \ + trainer.reinforce.initial_policy_kl_penalty="${KL}" \ + ++model.optim.bucket_cap_mb=200 \ + ++model.dist_ckpt_format=zarr \ + ++model.optim.overlap_grad_sync=False \ + ++model.optim.contiguous_grad_buffer=True \ + ++model.enable_nge=True \ + trainer.reinforce.batch_iterator.use_flask=${USE_FLASK} \ + trainer.reinforce.rollout_batch_seq_length=4096 + +The above command launches the initial and actor server on one node with eight GPUs. + +Launching Both Servers for REINFORCE training +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +You can use slurm to launch the two jobs and get them to coordinate together in a full REINFORCE job through the following: + +.. code-block:: bash + + #!/bin/bash + #SBATCH -N 1 --ntasks-per-node 8 -A <> -p <> --job-name <> -t 4:00:00 --exclusive + #SBATCH hetjob + #SBATCH -N 1 --ntasks-per-node 8 -A <> -p <> --job-name <> -t 4:00:00 --exclusive + + NAME="reinforce" + + # PARAMETERS + RM_NEMO_FILE="/path/to/trained_rm.nemo" + + ACTOR_NEMO_FILE="/path/to/sft_model.nemo" + + TRAIN_DATA_PATH="/path/to/train_prompts.jsonl" + VALID_DATA_PATH="/path/to/test_prompts.jsonl" + + RESULTS_DIR="/path/to/results_dir" + mkdir -p $RESULTS_DIR + + GPFS="/path/to/nemo-aligner-repo" + MOUNTS="--container-mounts=MOUNTS" # mounts + + CONTAINER=<<>> # use the latest NeMo Training container, Aligner will work there + + PROJECT=reinforce_run + + CRITIC_LOG_DIR="${RESULTS_DIR}/critic_results" + CRITIC_OUTFILE="${CRITIC_LOG_DIR}/critic_output_%j_%t.log" + CRITIC_ERRFILE="${CRITIC_LOG_DIR}/critic_error_%j_%t.err" + REWARD_PORT=5567 + CRITIC_CONFIG_PATH="${GPFS}/examples/nlp/gpt/conf" + CRITIC_CONFIG_NAME="inference_rm" + + CONF_DIR="${GPFS}/examples/nlp/gpt/conf" + CONFIG_NAME="gpt_reinforce_actor" + + mkdir -p $CRITIC_LOG_DIR + + CRITIC_NAME="${NAME}_critic" + + read -r -d '' cmd_critic_inference <`__ script from the NeMo codebase to run more rigorous evaluation of your trained model. \ No newline at end of file diff --git a/examples/nlp/gpt/conf/gpt_reinforce_actor.yaml b/examples/nlp/gpt/conf/gpt_reinforce_actor.yaml new file mode 100644 index 000000000..8efe26bb5 --- /dev/null +++ b/examples/nlp/gpt/conf/gpt_reinforce_actor.yaml @@ -0,0 +1,216 @@ +defaults: + - optional tp_overlap@model.ub_tp_comm_overlap_cfg: + +trainer: + # these args are respected + num_nodes: 8 + devices: 8 + accelerator: gpu + precision: bf16 + + reinforce: + + max_epochs: 1 + max_steps: -1 # max REINFORCE steps (-1 to go through the whole train set) + val_check_interval: 10 + save_interval: ${.val_check_interval} + gradient_clip_val: 1.0 + + # REINFORCE args to generate the data for training + initial_policy_kl_penalty: 0.01 + use_absolute_kl: True + num_rollouts_per_prompt: 4 + + + # the sequence length to pad the rollout batch for training to + # reduce fragmentation at the cost of using more + # memory, set to null if we don't want to pad it + # to a constant size + # if actual seq length is higher than this a warning will be raised + # but will not crash and training will still proceed on the larger + # sequence length + rollout_batch_seq_length: null + + # Speed-up training by accelerating inference stage using TRTLLM + trt_llm: + enable: True + reshard: False # if True then reshard the model into TP only for inference + + # TRTLLM preallocates activation memory according to the number of input tokens + # By default, assume the max input length is the difference between the model sequence length and the max number of tokens to generate + max_input_len: ${subtract:${model.encoder_seq_length}, ${model.reinforce.length_params.max_length}} + + # the seed to use for trt-llm generation + seed: ${model.seed} + + # for supported values see: https://github.com/NVIDIA/NeMo/blob/db6244857af3b012f645c7f4672522978bb608b1/nemo/export/trt_llm/converter/utils.py#L26 + model_type: llama # can be gptj, gptnext, llama, gemma, falcon + + # Save GPU memory by unloading and reloading the TRTLLM engine before and after the training stage + # Reloading the engine incurs a constant time overhead + unload_engine_train: False + + batch_iterator: + # When use_flask is True, we will spawn a flask server on rank 0 to balance the work of policy rollouts. + # This option is useful in cases where the generation length varies greatly across DP ranks since + # the flask server will allow DP ranks with shorter responses to process more samples and DP ranks + # with longer responses to process less samples. Thereby lowering the DP wait time. + use_flask: False + port: 5557 + + # pick up from the model + # *do not change this* + model_gbs: ${model.global_batch_size} + model_mbs: ${model.micro_batch_size} + + # no need to change these + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_time: null + max_epochs: ${.reinforce.max_epochs} + max_steps: ${.reinforce.max_steps} + +remote_rm: + # what to batch the inputs to + # set to None if no batching when sending inference to the reward model + pad_to_length: ${model.encoder_seq_length} + + # reward model server + reward_model: + name: reward_model + ip: localhost + port: 5555 + + +exp_manager: + explicit_log_dir: /results + exp_dir: null + name: megatron_gpt_reinforce_actor + create_wandb_logger: False + wandb_logger_kwargs: + project: nemo_aligner_reinforce + name: gpt3_reinforce_2b + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_rewards + save_top_k: 1 + mode: max + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits + filename: 'megatron_gpt-{step}-{consumed_samples}-{reinforce_optimization_step}-{epoch}-{val_rewards:.3f}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + +pretrained_checkpoint: + restore_from_path: null + +model: + + reinforce: + # training generation mbs + rollout_micro_batch_size: 8 + num_rollout_samples: 512 + + # mbs to do log prob inference, can be set to + # lower than rollout_micro_batch_size to reduce + # memory usage + forward_micro_batch_size: ${.rollout_micro_batch_size} + + # val generation mbs + val_rollout_micro_batch_size: ${.rollout_micro_batch_size} + num_val_samples: ${.num_rollout_samples} + + # to offload during generation or not + offload_adam_states: True + + # params for generation + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 0 + top_p: 1.0 + repetition_penalty: 1.0 + add_BOS: False + all_probs: False + compute_logprob: False + # will be used in NeMo version > 1.20.0 + # keeping it for now + end_strings: ["<|endoftext|>", ""] + + # length argument for autoregressive sampling + # max length means max amount of tokens to generate + length_params: + max_length: ${int_div:${model.encoder_seq_length}, 2} + min_length: 1 + + trt_llm: ${trainer.reinforce.trt_llm} + + peft: + peft_scheme: "none" # ["lora", "none"] + restore_from_path: null + restore_from_ckpt: + checkpoint_dir: null + checkpoint_name: null + + lora_tuning: + target_modules: ['attention_qkv'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', 'attention' (qkv & dense), 'mlp' (fc1 & fc2), 'all' + adapter_dim: 32 + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + mcore_gpt: True + # these control the mbs/gbs during REINFORCE training + micro_batch_size: 1 + global_batch_size: 64 + megatron_amp_O2: True + + encoder_seq_length: 4096 + max_position_embeddings: ${model.encoder_seq_length} + + ## Sequence Parallelism + sequence_parallel: False + + # miscellaneous + seed: 1234 + + optim: + name: distributed_fused_adam + bucket_cap_mb: 200 + overlap_grad_sync: False + contiguous_grad_buffer: True + lr: 9e-7 + weight_decay: 0.1 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 10 + constant_steps: 1000 + min_lr: 9e-8 + + precision: ${trainer.precision} + + data: + data_impl: jsonl + splits_string: null + seq_length: ${model.encoder_seq_length} + skip_warmup: True + num_workers: 0 + reset_position_ids: False # Reset position ids after end-of-document token + reset_attention_mask: False # Reset attention mask after end-of-document token + eod_mask_loss: False # Mask loss for the end of document tokens + index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix + data_prefix: null + + # define fields from the base model's config that should be ignored when merging with this config. + overwrite_base_config: + data: + data_prefix: True \ No newline at end of file diff --git a/examples/nlp/gpt/train_gpt_reinforce_actor.py b/examples/nlp/gpt/train_gpt_reinforce_actor.py new file mode 100644 index 000000000..0aa238fc4 --- /dev/null +++ b/examples/nlp/gpt/train_gpt_reinforce_actor.py @@ -0,0 +1,198 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +import torch +import torch.multiprocessing as mp +from megatron.core.utils import divide +from omegaconf.omegaconf import OmegaConf + +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager +from nemo_aligner.algorithms.reinforce import ReinforceTrainer +from nemo_aligner.data.nlp.builders import ( + build_dataloader, + build_train_valid_test_rlhf_datasets, + collate_with_pad_to_max_batch, +) +from nemo_aligner.models.nlp.gpt.megatron_gpt_reinforce_actor import MegatronGPTReinforceActorModel +from nemo_aligner.models.nlp.gpt.reward_critic_clients import RemoteGPTRMClient +from nemo_aligner.utils import parallel_state +from nemo_aligner.utils.batch_iterators import get_batch_iterator_cls +from nemo_aligner.utils.distributed import Timer +from nemo_aligner.utils.train_script_utils import ( + CustomLoggerWrapper, + add_custom_checkpoint_callback, + extract_optimizer_scheduler_from_ptl_model, + init_distributed, + init_peft, + init_using_ptl, + resolve_and_create_trainer, + retrieve_custom_trainer_state_dict, +) +from nemo_aligner.utils.utils import load_and_override_model_config, load_from_nemo, retrieve_model_state_dict_in_cpu + +"""Script to start REINFORCE training""" + +OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True) +OmegaConf.register_new_resolver("int_div", lambda x, y: x // y, replace=True) +OmegaConf.register_new_resolver("subtract", lambda x, y: x - y, replace=True) + +mp.set_start_method("spawn", force=True) + + +@hydra_runner(config_path="conf", config_name="gpt_reinforce_actor") +def main(cfg) -> None: + cfg.model = load_and_override_model_config(cfg.pretrained_checkpoint.restore_from_path, cfg.model) + + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f"\n{OmegaConf.to_yaml(cfg)}") + + trainer = resolve_and_create_trainer(cfg, "reinforce") + + exp_manager(trainer, cfg.exp_manager) + + logger = CustomLoggerWrapper(trainer.loggers) + + ptl_model = load_from_nemo( + MegatronGPTReinforceActorModel, + cfg.model, + trainer, + strict=True, + restore_path=cfg.pretrained_checkpoint.restore_from_path, + ) + + init_peft(ptl_model, cfg.model) + + init_policy_state_dict = None + + # only need this if we are running with inital kl penalty & full-parameter tuning + if cfg.trainer.reinforce.initial_policy_kl_penalty > 0 and cfg.model.peft.peft_scheme == "none": + init_policy_state_dict = retrieve_model_state_dict_in_cpu( + ptl_model, megatron_amp_O2=cfg.model.get("megatron_amp_O2", False) + ) + + ptl_model.init_policy_state_dict = init_policy_state_dict + + # pull values from checkpoint + trainer_restore_path = trainer.ckpt_path + + # TODO: log this restore path + if trainer_restore_path is not None: + custom_trainer_state_dict = retrieve_custom_trainer_state_dict(trainer) + else: + custom_trainer_state_dict = None + + init_distributed(trainer, ptl_model, cfg.model.get("transformer_engine", False)) + + # use the entire dataset + train_valid_test_num_samples = [-1, -1, -1] + train_ds, validation_ds, _ = build_train_valid_test_rlhf_datasets( + cfg=cfg.model, + data_prefix=cfg.model.data.data_prefix, + data_impl=cfg.model.data.data_impl, + splits_string=cfg.model.data.splits_string, + train_valid_test_num_samples=train_valid_test_num_samples, + seq_length=cfg.model.data.seq_length, + seed=cfg.model.seed, + tokenizer=ptl_model.tokenizer, + ) + + max_seqlen = cfg.model.reinforce.length_params.max_length + eos_id = ptl_model.tokenizer.eos_id + + # collate fn to pad to the max seq length in the batch + collate_fn = collate_with_pad_to_max_batch(max_seqlen, eos_id, cfg, generate_masks_and_position_ids=False) + + train_dataloader_builder = partial( + build_dataloader, + cfg=cfg, + dataset=train_ds, + mbs=cfg.model.reinforce.rollout_micro_batch_size, + gbs=cfg.model.reinforce.num_rollout_samples, + collate_fn=collate_fn, + load_gbs=False, + ) + + val_dataloader_builder = partial( + build_dataloader, + cfg=cfg, + dataset=validation_ds, + mbs=cfg.model.reinforce.val_rollout_micro_batch_size, + gbs=cfg.model.reinforce.num_val_samples, + collate_fn=collate_fn, + load_gbs=False, + use_random_sampler=False, + ) + + # nemo uses the train dataloader to figure out + # max steps to take when max_steps = -1 + # but our train dataloader is for the prompts + # so we instaniate a dummy dataloader + # to get the proper max *optimization* steps + # nemo treats batch size of normal dataloader as GBS/DP + # so we need to offset it by DP + dummy_train_dataloader = torch.utils.data.DataLoader( + dataset=train_ds, batch_size=divide(cfg.model.global_batch_size, parallel_state.get_data_parallel_world_size()) + ) + + init_using_ptl(trainer, ptl_model, dummy_train_dataloader, train_ds) + # make sure the dummy train dataloader is never used + del ptl_model._train_dl + del dummy_train_dataloader + + optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model) + ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) + + logger.log_hyperparams(OmegaConf.to_container(cfg)) + + rm = RemoteGPTRMClient(cfg.remote_rm) + timer = Timer(cfg.exp_manager.get("max_time_per_run") if cfg.exp_manager else None) + + batch_iterator_cfg = cfg.trainer.reinforce.get("batch_iterator", {}) + batch_iterator_cls = get_batch_iterator_cls(batch_iterator_cfg) + + reinforce_trainer = ReinforceTrainer( + cfg=cfg.trainer.reinforce, + model=ptl_model, + optimizer=optimizer, + scheduler=scheduler, + train_dataloader_builder=train_dataloader_builder, + val_dataloader_builder=val_dataloader_builder, + collate_fn=collate_fn, + rm=rm, + batch_iterator_cls=batch_iterator_cls, + logger=logger, + ckpt_callback=ckpt_callback, + run_timer=timer, + ) + + if custom_trainer_state_dict is not None: + reinforce_trainer.load_state_dict(custom_trainer_state_dict) + + reinforce_trainer.fit() + + # Note: The main loop creates multiple HTTPCommunicators which own a + # pytriton.client.FuturesModelClient. At the end of the loop, we manually + # close all FuturesModelClients since we do not use the context manager + # syntax. This guarantees all dangling threads are no longer blocking. + # `atexit` does not suffice since the registered cleanup function can be + # queued behind another blocking atexit registered function. + # TODO: utilize context managers to avoid manual cleanup + rm.communicator.close() + + +if __name__ == "__main__": + main() diff --git a/nemo_aligner/algorithms/reinforce.py b/nemo_aligner/algorithms/reinforce.py new file mode 100644 index 000000000..3bb127cee --- /dev/null +++ b/nemo_aligner/algorithms/reinforce.py @@ -0,0 +1,599 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +from collections import UserDict +from contextlib import nullcontext +from typing import Dict, List, Optional, Union + +import pandas as pd +import torch +from megatron.core import parallel_state as mcore_parallel_state +from megatron.core.utils import divide +from omegaconf.dictconfig import DictConfig +from tqdm import tqdm +from typing_extensions import Self + +from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import MegatronPretrainingRandomSampler +from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split +from nemo.utils import logging +from nemo_aligner.models.nlp.gpt.megatron_gpt_reinforce_actor import MegatronGPTReinforceActorModel +from nemo_aligner.utils import parallel_state +from nemo_aligner.utils.distributed import ( + ScopedTimer, + all_reduce_dict, + masked_global_mean_var, + normalize_tensor, + rebalance_nd_tensor, +) +from nemo_aligner.utils.parallel_state import is_trt_llm_reshard, trt_llm_reshard_region +from nemo_aligner.utils.ppo_utils import calculate_kl_penalty, calculate_rloo_baseline, create_mask +from nemo_aligner.utils.server_utils import FutureResult +from nemo_aligner.utils.train_utils import clip_gradients +from nemo_aligner.utils.trainer_utils import check_progress, compute_num_steps_per_epoch +from nemo_aligner.utils.utils import clear_memory, cpu_dict, masked_mean + + +class ReinforceRolloutBatch(UserDict): + @classmethod + def from_rollout_batches( + cls: Self, rollout_batches: List[Dict], eos_id: int, rollout_batch_seq_length: Optional[int] + ) -> Self: + """Given a list of rollout batches, stack the tensors within and put them in a single dictionary + """ + stacked_dict = cls() + + for k in sorted(rollout_batches[0]): + + list_of_tensors = [item[k] for item in rollout_batches] + + if all(x.ndim == 1 for x in list_of_tensors): + tensor = torch.cat(list_of_tensors) + else: + pad_value = eos_id if k == "response_tokens" else 0 + + list_of_tensors = [row.flatten() for tensor in list_of_tensors for row in tensor] + # TODO: can we avoid padding locally then padding globally? + tensor = torch.nn.utils.rnn.pad_sequence(list_of_tensors, batch_first=True, padding_value=pad_value) + + # find the max sequence length globally + max_seqlen = torch.tensor([tensor.size(-1)], dtype=torch.long, device=torch.cuda.current_device()) + torch.distributed.all_reduce(max_seqlen, op=torch.distributed.ReduceOp.MAX) + + if rollout_batch_seq_length is None or max_seqlen >= rollout_batch_seq_length: + pad_seq_len = max_seqlen.item() + else: + # response tokens must be B x S because computing log probs requires us to offset by 1 + pad_seq_len = rollout_batch_seq_length if k == "response_tokens" else rollout_batch_seq_length - 1 + + tensor = torch.nn.functional.pad(tensor, (0, pad_seq_len - tensor.size(-1)), value=pad_value) + + stacked_dict[k] = tensor + + return stacked_dict + + def gather_and_balance_globally(self): + global_rollout_batch = type(self)() + + for k, tensor in self.data.items(): + # with reshard enabled, PP groups turn into DP groups. So need to balance them first and then + # balance by all the original DP groups + # NOTE: this logic needs to use the pure parallel state, that is one without sharding but needs + # to ping the is_trt_llm_reshard variable + if is_trt_llm_reshard(): + tensor = rebalance_nd_tensor(tensor, group=mcore_parallel_state.get_pipeline_model_parallel_group()) + + tensor = rebalance_nd_tensor(tensor, group=mcore_parallel_state.get_data_parallel_group()) + global_rollout_batch[k] = tensor + + return global_rollout_batch + + def chunk(self, rank, split_size, seed): + chunked_rollout_batch = type(self)() + + batch_set = set(tensor.size(0) for tensor in self.data.values()) + assert len(batch_set) == 1, "batch sizes are not the same across the rollout batch" + B = batch_set.pop() + + g_cpu = torch.Generator() + g_cpu.manual_seed(seed) + indices = torch.arange(B) + + for k in self.data: + chunked_rollout_batch[k] = self.data[k][indices].clone() + + return chunked_rollout_batch + + +def compute_num_rollout_microbatches(dataloader): + return divide( + divide(dataloader.batch_sampler.global_batch_size, dataloader.batch_sampler.micro_batch_size), + parallel_state.get_data_parallel_world_size(), + ) + + +class ReinforceTrainer: + """Trainer to coordinate REINFORCE training + """ + + def __init__( + self, + cfg: DictConfig, + model: MegatronGPTReinforceActorModel, + optimizer, + scheduler, + train_dataloader_builder, + val_dataloader_builder, + collate_fn, + rm, + batch_iterator_cls, + logger, + ckpt_callback, + run_timer, + ): + self.cfg = cfg + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + self.train_dataloader_builder = train_dataloader_builder + self.val_dataloader_builder = val_dataloader_builder + self.collate_fn = collate_fn + self.rm = rm + self.batch_iterator_cls = batch_iterator_cls + self.logger = logger + self.ckpt_callback = ckpt_callback + + # this timer checks if we should stop training + self.run_timer = run_timer + + self.trtllm_reshard = "trt_llm" in cfg and cfg.trt_llm.enable and cfg.trt_llm.reshard + + self.consumed_samples = 0 + # the step here is REINFORCE step + self.step = 0 + # keep track of how many times we optimized the actor + self.reinforce_optimization_step = 0 + + # compute `max_steps` + train_dataloader = self.train_dataloader_builder(consumed_samples=0) + if (not isinstance(train_dataloader.batch_sampler, MegatronPretrainingRandomSampler)) and ( + self.cfg.max_epochs is not None and self.cfg.max_epochs > 1 + ): + # if you use MegatronPretrainingBatchSampler as the batch_sampler passed to your train dataloader (in builders.py) + # then each epoch will repeat all your samples in the same order as the previous epoch, there is no shuffling + # to fix this, you should use MegatronPretrainingRandomSampler instead, which alleviates this issue and allows + # random shuffling for each epoch. + raise ValueError( + "max_epochs > 1 is not supported unless using `MegatronPretrainingRandomSampler` as the batch_sampler for your train dataloader" + ) + + self.num_steps_per_epoch = compute_num_steps_per_epoch(train_dataloader.batch_sampler) + self.set_max_steps() + + self.compute_init_policy_kl = self.cfg.initial_policy_kl_penalty > 0 + # size to pad our rollout batch to + self.rollout_batch_seq_length = self.cfg.rollout_batch_seq_length + + # for wandb table + self.train_df = pd.DataFrame(columns=["step", "prompt", "response", "reward"]) + self.val_df = pd.DataFrame(columns=["step", "prompt", "response", "reward"]) + + self.timer = ScopedTimer(reduction="mean", sync_cuda=True, buffer_size=1) + + def generate_reinforce_data(self, rollout_batch): + """generate reinforce specific data for training + """ + reinforce_rollout_data = {} + reinforce_rollout_metrics = {} + + prompt_lengths = rollout_batch["prompt_lengths"] + response_lengths = rollout_batch["response_lengths"] + prompt_tokens = rollout_batch["prompt_tokens"] + response_tokens = rollout_batch["response_tokens"] + rewards = rollout_batch["rewards"] + logprobs = rollout_batch["logprobs"] + is_end = rollout_batch["is_end"] + + if self.compute_init_policy_kl: + init_policy_kl = calculate_kl_penalty( + log_probs_a=rollout_batch["logprobs"], + log_probs_b=rollout_batch["init_logprobs"], + use_absolute_kl=self.cfg.use_absolute_kl, + ) + else: + init_policy_kl = torch.tensor(0, dtype=logprobs.dtype, device=logprobs.device) + + mask = create_mask(values=logprobs, prompt_lengths=prompt_lengths, response_lengths=response_lengths) + + init_policy_kl = masked_mean(init_policy_kl, mask, dim=-1) + rewards_with_kl = rewards - self.cfg.initial_policy_kl_penalty * init_policy_kl + + baseline = calculate_rloo_baseline(prompts=prompt_tokens, reward=rewards_with_kl, mask=is_end.float()) + + # collect everything we need to train REINFORCE + reinforce_rollout_data["mask"] = mask + reinforce_rollout_data["rewards_with_kl"] = rewards_with_kl + reinforce_rollout_data["baseline"] = baseline + reinforce_rollout_data["response_tokens"] = response_tokens + reinforce_rollout_data["is_end"] = is_end + + # compute metrics + # these are not global yet + reinforce_rollout_metrics["init_policy_kl"] = init_policy_kl.sum().item() if self.compute_init_policy_kl else 0 + reinforce_rollout_metrics["rewards_with_kl"] = rewards_with_kl.sum().item() + reinforce_rollout_metrics["num_samples"] = prompt_lengths.size(0) + + # now the metrics are global + reinforce_rollout_metrics = all_reduce_dict( + reinforce_rollout_metrics, + group=parallel_state.get_data_parallel_group(), + op=torch.distributed.ReduceOp.SUM, + ) + num_samples = reinforce_rollout_metrics.pop("num_samples") + reinforce_rollout_metrics = {k: v / num_samples for k, v in reinforce_rollout_metrics.items()} + + return reinforce_rollout_data, cpu_dict(reinforce_rollout_metrics) + + def _run_inference(self, dataloader_builder, consumed_samples, is_validation): + """this function is run per DP so the metrics need to be computed globally + assumes that the dataloader is built with the proper consumed samples value + """ + reshard_context = trt_llm_reshard_region if self.trtllm_reshard else nullcontext + + rollout_batches, futures = [], [] + + with reshard_context(): + # dataloader must be built within the reshard context because it uses DP rank and size + dataloader = dataloader_builder(consumed_samples=consumed_samples) + sampler_iter = iter(dataloader.batch_sampler) + + # must compute the number of microbatches in the reshard context + # so the DP groups are correct + num_microbatches = compute_num_rollout_microbatches(dataloader) + + with self.timer("batch_iterator_init"): + batch_iterator = self.batch_iterator_cls( + sampler_iter, num_microbatches, dataloader.dataset, self.collate_fn + ) + + with self.timer("generate"): + for batch in batch_iterator: + if not is_validation: + for _ in range(self.cfg.num_rollouts_per_prompt): + rollout_batch = self.model.infer(batch) + rollout_batch["prompt_tokens"] = batch["text"] + rollout_batches.append(rollout_batch) + futures.append(self.rm.infer_rm(rollout_batch)) + else: + rollout_batch = self.model.infer(batch) + rollout_batches.append(rollout_batch) + futures.append(self.rm.infer_rm(rollout_batch)) + + unbalanced_local_batch = ReinforceRolloutBatch.from_rollout_batches( + rollout_batches, + eos_id=self.model.tokenizer.eos_id, + rollout_batch_seq_length=self.cfg.rollout_batch_seq_length, + ) + global_rollout_batch = unbalanced_local_batch.gather_and_balance_globally() + + padded_rollout_sequence_length = global_rollout_batch["response_tokens"].size(-1) + + # the chunking must be outside of the TRT-LLM context because we do logprob calculation in nemo + balanced_local_batch = global_rollout_batch.chunk( + rank=parallel_state.get_data_parallel_rank(), + split_size=parallel_state.get_data_parallel_world_size(), + seed=self.step, + ) + # since we compute the logprobs in nemo we need to disable the resharding + batched_response_tokens = balanced_local_batch["response_tokens"] + + with self.timer("logprobs"): + rollout_logprobs = self.model.get_inference_log_probs(batched_response_tokens) + balanced_local_batch["logprobs"] = rollout_logprobs + + compute_init_policy_kl = not is_validation and self.compute_init_policy_kl + if compute_init_policy_kl: + with self.timer("init_logprobs"): + rollout_init_logprobs = self.model.get_init_policy_logprobs(batched_response_tokens) + balanced_local_batch["init_logprobs"] = rollout_init_logprobs + + # we send the request in sharded context, so we need to keep this sharding and then undo it + with reshard_context(): + with self.timer("critic_wait"): + rm_rollout_batches = [] + for future in futures: + rewards = future.result().squeeze(1) + rm_rollout_batches.append({"rewards": rewards}) + + unbalanced_rm_batch = ReinforceRolloutBatch.from_rollout_batches( + rm_rollout_batches, + eos_id=self.model.tokenizer.eos_id, + rollout_batch_seq_length=padded_rollout_sequence_length, + ) + global_rm_batch = unbalanced_rm_batch.gather_and_balance_globally() + + # chunking needs to be outside of reshard region + # NOTE: the seed here must be the same as the chunk before since we need to shuffle + # these values the same way as the other values + balanced_rm_batch = global_rm_batch.chunk( + rank=parallel_state.get_data_parallel_rank(), + split_size=parallel_state.get_data_parallel_world_size(), + seed=self.step, + ) + balanced_local_batch.update(balanced_rm_batch) + + global_rollout_batch.update(global_rm_batch) + + return balanced_local_batch, cpu_dict(self.compute_rollout_metrics(global_rollout_batch)) + + def compute_rollout_metrics(self, rollout_batch): + table = {} + + prompt_lengths = rollout_batch["prompt_lengths"] + response_lengths = rollout_batch["response_lengths"] + response_tokens = rollout_batch["response_tokens"] + rewards = rollout_batch["rewards"] + is_end = rollout_batch["is_end"] + + # take the first sample for logging + reward = rewards[0] + prompt_length = prompt_lengths[0] + response_length = response_lengths[0] + response_token = response_tokens[0] + + table["reward"] = reward.item() + table["prompt"] = self.model.tokenizer.ids_to_text(response_token[:prompt_length].tolist()) + table["response"] = self.model.tokenizer.ids_to_text(response_token[prompt_length:response_length].tolist()) + + metrics = { + "table": table, + "rollout_size": prompt_lengths.size(0), + "response_lengths": response_lengths.float().mean().item(), + "prompt_lengths": prompt_lengths.float().mean().item(), + "generation_length": (response_lengths - prompt_lengths).float().mean().item(), + "rewards": rewards.mean().item(), + "fraction_of_samples_properly_ended": is_end.float().mean().item(), + } + + return metrics + + @torch.no_grad() + def run_validation(self): + self.model.prepare_for_inference() + + _, rollout_metrics = self._run_inference(self.val_dataloader_builder, consumed_samples=0, is_validation=True) + + self.model.finish_inference() + return rollout_metrics + + @torch.no_grad() + def generate_rollouts(self): + with self.timer("prepare_for_inference"): + # Timing includes build if first step and refit if step > 1 + self.model.prepare_for_inference() + + rollout_batch, rollout_metrics = self._run_inference( + self.train_dataloader_builder, consumed_samples=self.consumed_samples, is_validation=False + ) + + self.consumed_samples += rollout_metrics["rollout_size"] + + reinforce_rollout_data, reinforce_rollout_metrics = self.generate_reinforce_data(rollout_batch) + + with self.timer("finish_inference"): + # Timing includes engine unloading if enabled + self.model.finish_inference() + + return ( + reinforce_rollout_data, + rollout_metrics | reinforce_rollout_metrics | {"consumed_samples": self.consumed_samples}, + self.timer.consume_durations(), + ) + + def run_training(self, dataloader_iter): + self.model.prepare_for_training() + + for batch in dataloader_iter: + self.optimizer.zero_grad() + + self.model.prepare_for_training_step() + loss_mean, metrics = self.model.get_loss_and_metrics(batch=batch, forward_only=False) + self.model.finish_training_step() + + grad_norm = clip_gradients(self.model, self.cfg.gradient_clip_val) + grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm + lr = self.optimizer.param_groups[0]["lr"] + + self.optimizer.step() + self.scheduler.step() + + if grad_norm is not None: + metrics["grad_norm"] = grad_norm + if lr is not None: + # Some optimizers like adafactor do not require a LR in their initializer + metrics["lr"] = lr + + metrics.update({"loss": loss_mean, "optim_step": self.reinforce_optimization_step}) + + self.logger.log_metrics( + metrics, step=self.step, prefix="train_optim/", + ) + + self.reinforce_optimization_step += 1 + + self.model.finish_training() + + # zero grad again incase it frees up grad mem + self.optimizer.zero_grad() + return loss_mean, metrics + + def fit(self): + epoch_iter = range(self.epoch, self.cfg.max_epochs) + if len(epoch_iter) <= 0: + # epoch done + return + + for _ in epoch_iter: + num_steps_in_epoch = min( + self.max_steps - self.step, self.num_steps_per_epoch - self.step % self.num_steps_per_epoch + ) + loop_iter = range(num_steps_in_epoch) + + if not loop_iter: + return # training ended + + global_pbar = tqdm( + loop_iter, initial=self.step, total=self.max_steps, leave=True, desc="REINFORCE Global Step" + ) + + dp_size = parallel_state.get_data_parallel_world_size() + + num_to_load_on_each_dp = divide(self.cfg.model_gbs, dp_size) + + self.run_timer.start_time() + for _ in global_pbar: + step_metrics = {} + timing_metrics = {} + + clear_memory() + with self.timer("rollout_time"): + reinforce_rollout_data, metrics, rollout_timer_metrics = self.generate_rollouts() + # Consume rollout_time + timing_metrics.update(self.timer.consume_durations()) + + rollout_timer_metrics = all_reduce_dict(rollout_timer_metrics, op=torch.distributed.ReduceOp.MAX) + timing_metrics.update(rollout_timer_metrics) + + # logging + table_metrics = metrics.pop("table") + self.train_df.loc[len(self.train_df)] = [ + self.step, + table_metrics["prompt"], + table_metrics["response"], + table_metrics["reward"], + ] + metrics["epoch"] = self.epoch + 1 + self.logger.log_metrics( + metrics, step=self.step, prefix="train_rollouts/", + ) + self.logger.log_table( + key="table/train_rollouts", dataframe=self.train_df, step=self.step, + ) + + rollout_size = reinforce_rollout_data["response_tokens"].size(0) + rollout_dataloader_iter = get_iterator_k_split( + reinforce_rollout_data, divide(rollout_size, num_to_load_on_each_dp) + ) + # start training + clear_memory() + with self.timer("train_time"): + self.run_training(rollout_dataloader_iter) + + self.logger.log_metrics( + timing_metrics | self.timer.consume_durations(), step=self.step, prefix="timers/" + ) + + self.step += 1 + + run_time_exceeded = self.run_timer.is_finished() + run_val, save_model, is_train_end = check_progress( + self.step, + self.max_steps, + self.cfg.val_check_interval, + self.cfg.save_interval, + 1.0, # TODO:(geshen): allow for limit val batches + run_time_exceeded=run_time_exceeded, + ) + + if run_val: + with self.timer("validation_time"): + val_metrics = self.run_validation() + # Note: validation_time is logged one step behind (val step 5 means we've completed step 4) + timing_metrics.update(self.timer.consume_durations()) + + val_table_metrics = val_metrics.pop("table") + + self.val_df.loc[len(self.val_df)] = [ + self.step, + val_table_metrics["prompt"], + val_table_metrics["response"], + val_table_metrics["reward"], + ] + self.logger.log_metrics(val_metrics, step=self.step, prefix="val_rollouts/") + self.logger.log_table("table/val_rollouts", dataframe=self.val_df, step=self.step) + + step_metrics.update({f"val_{k}": v for k, v in val_metrics.items()}) + + step_metrics.update(timing_metrics) + step_metrics.update({f"train_{k}": v for k, v in metrics.items()}) + global_pbar.set_postfix(step_metrics) + + if save_model: + step_metrics = {k: torch.as_tensor(v) for k, v in step_metrics.items()} + self.save(step_metrics, is_train_end=is_train_end) + + if run_time_exceeded: + logging.info(f"Time limit given by run_timer={self.run_timer} reached. Stopping run") + return + + self.logger.finalize() + + def state_dict(self): + return { + "step": self.step, + "consumed_samples": self.consumed_samples, + "epoch": self.epoch, + "reinforce_optimization_step": self.reinforce_optimization_step, + } + + def load_state_dict(self, state_dict): + self.step = state_dict["step"] + self.consumed_samples = state_dict["consumed_samples"] + self.reinforce_optimization_step = state_dict["reinforce_optimization_step"] + + loaded_values = [self.step, self.consumed_samples, self.reinforce_optimization_step] + + # make sure everyone loaded the same checkpoint as rank 0 + to_broadcast = torch.tensor(loaded_values, dtype=torch.float32, device=torch.cuda.current_device()) + torch.distributed.broadcast(to_broadcast, 0) + + assert loaded_values == to_broadcast.tolist() + # restore max steps we need to run for + self.set_max_steps() + + def save(self, extra_candidates=None, is_train_end=False): + self.model.prepare_for_training() + # load back in the adam states if needed + torch.cuda.synchronize() + torch.distributed.barrier() + + if extra_candidates is None: + extra_candidates = {} + + monitor_candidates = {k: torch.tensor(v, dtype=torch.int32) for k, v in self.state_dict().items()} + monitor_candidates.update(extra_candidates) + + self.ckpt_callback.custom_save(monitor_candidates=monitor_candidates, is_train_end=is_train_end) + + self.model.finish_training() + + def set_max_steps(self): + self.max_steps = self.num_steps_per_epoch * self.cfg.max_epochs + + if (max_steps := self.cfg.get("max_steps", -1)) >= 0: + self.max_steps = min(self.max_steps, max_steps) + + @property + def epoch(self): + return self.step // self.num_steps_per_epoch diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_reinforce_actor.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_reinforce_actor.py new file mode 100644 index 000000000..a98180fe8 --- /dev/null +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_reinforce_actor.py @@ -0,0 +1,393 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import nullcontext + +import torch +import torch.distributed +from megatron.core.num_microbatches_calculator import get_num_microbatches +from megatron.core.pipeline_parallel.schedules import get_forward_backward_func +from megatron.core.utils import divide +from omegaconf import OmegaConf +from omegaconf.dictconfig import DictConfig +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.megatron.utils import ( + average_losses_across_data_parallel_group, + get_iterator_k_split, + get_ltor_masks_and_position_ids, +) +from nemo.collections.nlp.parts.mixins.nlp_adapter_mixins import NLPAdapterModelMixin +from nemo.collections.nlp.parts.utils_funcs import get_last_rank +from nemo.utils import logging +from nemo_aligner.models.alignable_interface import AlignableGenerativeInterface +from nemo_aligner.utils import parallel_state +from nemo_aligner.utils.distributed import ( + broadcast_2d_tensor_within_pp, + calculate_distributed_entropy, + from_parallel_logits_to_logprobs, +) +from nemo_aligner.utils.text_generation_utils import ( + TrackLengthGPTModelTextGenerationStrategy, + verify_is_valid_and_clamp_range_, +) +from nemo_aligner.utils.train_utils import ( + grad_reductions, + prepare_for_training_step, + set_eval, + set_sync_funcs, + set_train, +) +from nemo_aligner.utils.trt_llm import GPTGenerateTRTLLM +from nemo_aligner.utils.utils import ( + adapter_control, + clear_memory, + configure_batch_sizes, + cpu_weight_swap, + masked_mean, + offload_distributed_adam, +) + + +class MegatronGPTReinforceActorModel(NLPAdapterModelMixin, MegatronGPTModel, AlignableGenerativeInterface): + def __init__(self, cfg: DictConfig, trainer: Trainer): + super().__init__(cfg, trainer=trainer) + self.automatic_optimization = False + + self.init_policy_state_dict = None + self.distributed_adam_offload_manager = None + + # length parameters for generation + self._length_params = OmegaConf.to_container(self.cfg.reinforce.length_params, resolve=True) + # sampling parameters for generation + self._sampling_params = OmegaConf.to_container(self.cfg.reinforce.sampling_params, resolve=True) + + self.to_offload_adam_states = self.cfg.reinforce.offload_adam_states and self.with_distributed_adam + self.forward_micro_batch_size = self.cfg.reinforce.forward_micro_batch_size + + self.use_trtllm_generation = "trt_llm" in self.cfg.reinforce and self.cfg.reinforce.trt_llm.enable + if self.use_trtllm_generation: + self.trtllm_generate = GPTGenerateTRTLLM( + model_cfg=self.cfg, + max_generation_length=self.cfg.reinforce.length_params.get("max_length", 1024), + max_input_len=self.cfg.reinforce.trt_llm.get("max_input_len", 1024), + generation_batch_size=self.cfg.reinforce.get("rollout_micro_batch_size", 4), + unload_engine_train=self.cfg.reinforce.trt_llm.get("unload_engine_train", False), + trt_model_type=self.cfg.reinforce.trt_llm.get("model_type", "llama"), + end_strings=self.cfg.reinforce.sampling_params["end_strings"], + reshard_model=self.cfg.reinforce.trt_llm.get("reshard", False), + sample_temperature=self.cfg.reinforce.sampling_params["temperature"], + sample_top_k=self.cfg.reinforce.sampling_params["top_k"], + sample_top_p=self.cfg.reinforce.sampling_params["top_p"], + repetition_penalty=self.cfg.reinforce.sampling_params["repetition_penalty"], + use_greedy=self.cfg.reinforce.sampling_params.get("use_greedy", False), + tokenizer=self.tokenizer, + seed=self.cfg.reinforce.trt_llm.get("seed", self.cfg.seed), + ) + + # training calls + def get_actor_forward_output_and_loss_func(self): + def fwd_output_and_loss_func(data_iterator, model): + batch = next(data_iterator) + required_keys = set() + if parallel_state.get_pipeline_model_parallel_world_size() == 1: + required_keys.update(batch.keys()) + else: + required_keys.add("attention_mask") + + if parallel_state.is_pipeline_first_stage(): + required_keys.update(("response_tokens", "position_ids")) + + if parallel_state.is_pipeline_last_stage(): + required_keys.update(("response_tokens", "baseline", "mask", "rewards_with_kl", "is_end")) + + batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()} + + parallel_logits = model( + batch["response_tokens"], batch["position_ids"], batch["attention_mask"], labels=None, + ) + + def loss_func(parallel_logits): + mask = batch["mask"] + rewards_with_kl = batch["rewards_with_kl"] + baseline = batch["baseline"] + tokens = batch["response_tokens"] + is_end = batch["is_end"] + + is_end_mask = mask * is_end.view(-1, 1) + + curr_log_probs = from_parallel_logits_to_logprobs( + vocab_parallel_logits=parallel_logits, target=tokens, higher_stability=True + ) + + reinforce_loss = -1 * curr_log_probs * (rewards_with_kl - baseline) + + if is_end_mask.sum() > 0: + loss = masked_mean(reinforce_loss, mask) + else: + # hack to disable this update since there are no valid tokens + loss = reinforce_loss.view(-1)[0] * 0 + + reduced_actor_loss = average_losses_across_data_parallel_group([loss]) + return ( + loss, + {"loss": reduced_actor_loss,}, + ) + + return parallel_logits, loss_func + + return fwd_output_and_loss_func + + def prepare_for_training(self): + configure_batch_sizes( + mbs=self.cfg.micro_batch_size, + gbs=self.cfg.global_batch_size, + dp=parallel_state.get_data_parallel_world_size(), + ) + self.onload_adam_states() + + def prepare_for_training_step(self): + # custom trainers will always zero grad for us + prepare_for_training_step(self, zero_grad=False) + + def get_loss_and_metrics(self, batch, forward_only): + sequence_length = batch["response_tokens"].size(1) + + attention_mask, _, position_ids = self.get_ltor_masks_and_position_ids(tokens=batch["response_tokens"]) + batch["attention_mask"] = attention_mask + batch["position_ids"] = position_ids + + data_iter = get_iterator_k_split(batch, get_num_microbatches()) + set_sync_funcs(self, forward_only) + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_actor_forward_output_and_loss_func(), + data_iterator=self._make_data_iterator_list(data_iter), + model=self.model, + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + seq_length=sequence_length, + micro_batch_size=self.cfg.micro_batch_size, + ) + + metrics = {} + + for key in ["loss"]: + if losses_reduced_per_micro_batch: + metric_mean = torch.stack( + [loss_reduced[key] for loss_reduced in losses_reduced_per_micro_batch] + ).mean() + else: + metric_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + + torch.distributed.broadcast(metric_mean, get_last_rank()) + + metrics[key] = metric_mean.cpu().item() + + return metrics["loss"], metrics + + def finish_training_step(self): + grad_reductions(self) + + def finish_training(self): + """no need to offload adam states here + """ + + # inference calls + def get_logprob_output_only_func(self, inference_only=True): + fwd_output_only_func = self.get_forward_output_only_func() + + def log_prob_output_only_func(dataloader_iter, model): + batch = next(dataloader_iter) + + output_tensor, _ = fwd_output_only_func(iter([batch,]), model) + + def id_func(output_tensor, non_loss_data=True): + logprobs = from_parallel_logits_to_logprobs( + vocab_parallel_logits=output_tensor, + target=batch[0], + inference_only=inference_only, + higher_stability=True, + ) + return logprobs + + return output_tensor, id_func + + return log_prob_output_only_func + + @torch.no_grad() + def get_inference_log_probs(self, response_tokens, forward_micro_batch_size=None): + if forward_micro_batch_size is None: + forward_micro_batch_size = self.forward_micro_batch_size + + set_sync_funcs(self, forward_only=True) + + mbs, seq_length = response_tokens.size() + num_microbatches = divide(mbs, forward_micro_batch_size) + attention_mask, _, position_ids = self.get_ltor_masks_and_position_ids(response_tokens) + + batch_iter = get_iterator_k_split([response_tokens, attention_mask, position_ids], num_microbatches) + + fwd_bwd_function = get_forward_backward_func() + logprobs_list = fwd_bwd_function( + forward_step_func=self.get_logprob_output_only_func(inference_only=True), + data_iterator=self._make_data_iterator_list(batch_iter), + model=self.model, + num_microbatches=num_microbatches, + forward_only=True, + seq_length=seq_length, + micro_batch_size=forward_micro_batch_size, + collect_non_loss_data=True, + ) + + logprobs = torch.cat(logprobs_list) if len(logprobs_list) > 0 else None + + # Broadcast it from last PP stage to everything else. + logprobs = broadcast_2d_tensor_within_pp(logprobs) + + return logprobs + + def prepare_for_inference(self): + """normally we would configure the micro batch calculator here + but the nemo generation already does the configuration""" + self._reset_activation_checkpointing_args() + self._reset_sequence_parallelism_args() + set_eval(self) + self.offload_adam_states() + + if self.use_trtllm_generation: + # TODO this might be optimized to avoid calling `refit()` twice in a row after a validation step + self.trtllm_generate.refit(self.model) + clear_memory() + + @torch.no_grad() + def infer(self, inference_batch): + prompt_tokens = inference_batch["text"].cuda(non_blocking=True) + prompt_lengths = inference_batch["length"].cuda(non_blocking=True) + inputs = (prompt_tokens, prompt_lengths) + + strategy = TrackLengthGPTModelTextGenerationStrategy( + model=self, context_lengths=prompt_lengths, max_length=self._length_params["max_length"] + ) + + if self.use_trtllm_generation: + actor_output = self.trtllm_generate.generate(inputs) + response_tokens = actor_output["response_tokens"] + response_lengths = actor_output["response_lengths"] + else: + actor_output = self.generate( + inputs=inputs, + length_params=self._length_params, + sampling_params=self._sampling_params, + strategy=strategy, + ) + response_tokens = torch.cuda.LongTensor(actor_output["token_ids"]) if actor_output else None + response_tokens = broadcast_2d_tensor_within_pp(response_tokens, dtype=torch.long) + response_lengths = strategy.get_lengths() + + max_response_length = response_lengths.max().item() + + # Sanity check to validate response length. + if max_response_length != response_tokens.size(1): + # This may actually happen because NeMo does not always stop generation after `max_length` in batch mode + # => `response_tokens` may contain up to `max_length + max_context_length` tokens. + # TODO once NeMo fixes this issue we should be able to always raise an exception when the check above fails, + # and remove the `if` below. + if ( + max_response_length >= response_tokens.size(1) + or response_tokens.size(1) != prompt_lengths.max().item() + self._length_params["max_length"] + ): + raise AssertionError( + f"max response length ({max_response_length}) does not match the size of " + f"`response_tokens` ({response_tokens.size(1)})" + ) + + # sometimes backends like TRT-LLM will generate invalid tokens + # so we need to also inplace mutate the response_tokens to be within the tokenizer range + is_valid = verify_is_valid_and_clamp_range_( + response_tokens, + response_lengths, + strategy, + self.tokenizer, + self.cfg.reinforce.sampling_params["end_strings"], + ) + + rollout_batch = { + "response_tokens": response_tokens, + "response_lengths": response_lengths, + "prompt_lengths": prompt_lengths, + "is_end": is_valid, + } + + # return in GPU, trainer needs to move to cpu + + return rollout_batch + + def get_init_policy_logprobs(self, response_tokens): + use_peft_init_policy = self.use_peft and self.init_policy_state_dict is None + + context_mgr = ( + adapter_control(self) + if use_peft_init_policy + else cpu_weight_swap(self, self.init_policy_state_dict, megatron_amp_O2=self.megatron_amp_O2) + ) + + with context_mgr: + return self.get_inference_log_probs(response_tokens) + + def finish_inference(self): + # training will onload the adam states, no need to onload it here + self._restore_activation_checkpointing_args() + self._restore_sequence_parallelism_args() + + if self.use_trtllm_generation: + self.trtllm_generate.free() + + set_train(self) + + def offload_adam_states(self): + if self.distributed_adam_offload_manager is None: + + self.distributed_adam_offload_manager = ( + offload_distributed_adam( + self._optimizer.state_dict(state_dict_format=1, gather_on_root=False), force_clear_memory=True + ) + if self.to_offload_adam_states + else nullcontext() + ) + + # offload onto cpu + self.distributed_adam_offload_manager.__enter__() + + def onload_adam_states(self): + if self.distributed_adam_offload_manager is not None: + # load back onto GPU + self.distributed_adam_offload_manager.__exit__(None, None, None) + + self.distributed_adam_offload_manager = None + + def get_ltor_masks_and_position_ids(self, tokens): + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + data=tokens, + eod_token=self.tokenizer.eos_id, + reset_position_ids=self.cfg.data.get("reset_position_ids", False), + reset_attention_mask=self.cfg.data.get("reset_attention_mask", False), + eod_mask_loss=False, # since we ignore the loss mask here + ) + attention_mask = attention_mask.expand(tokens.size(0), -1, -1, -1) + position_ids = position_ids.expand(tokens.size(0), -1) + + return attention_mask, loss_mask, position_ids diff --git a/nemo_aligner/utils/ppo_utils.py b/nemo_aligner/utils/ppo_utils.py index 1d1f5cf67..0a69e3b9a 100644 --- a/nemo_aligner/utils/ppo_utils.py +++ b/nemo_aligner/utils/ppo_utils.py @@ -112,3 +112,31 @@ def select_topk(batch, num_select=1): selected_batch = {k: batch[k][selected_idx] for k in batch.keys()} return selected_batch + + +def calculate_rloo_baseline(prompts, reward, mask): + """ + Function to select the RLOO baseline for each (prompt, response) pair in the batch. + The same baseline is calculated for each prompt. Masked samples are not included + in the baseline calculation. + """ + unique_prompts = torch.unique(prompts, dim=0) + + baseline = torch.zeros_like(reward) + reward_device = reward.get_device() + if reward_device == -1: + reward_device = "cpu" + + for i in range(len(unique_prompts)): + is_matching_prompt = (prompts == unique_prompts[i]).all(1) + prompt_idx = torch.arange(len(prompts), device=reward_device)[is_matching_prompt] + rloo_mat = (1 - torch.eye(len(prompt_idx))).to(reward_device) + + if mask[prompt_idx].sum() <= 1: + # Ignore sample: set baseline equal to reward + baseline[prompt_idx] = reward[prompt_idx] + else: + rloo = torch.matmul(rloo_mat, reward[prompt_idx] * mask[prompt_idx]) / (mask[prompt_idx].sum() - 1) + baseline[prompt_idx] = rloo + + return baseline diff --git a/tests/functional/reinforce.sh b/tests/functional/reinforce.sh new file mode 100755 index 000000000..3529acdbb --- /dev/null +++ b/tests/functional/reinforce.sh @@ -0,0 +1,178 @@ +#!/bin/bash + +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd $SCRIPT_DIR +set -eoux pipefail + +export NCCL_ALGO=Tree +export NVTE_APPLY_QK_LAYER_SCALING=1 + +KL=${KL:-0.03} +LR=${LR:-9e-7} +RUN_ONLY=${RUN_ONLY:-} +GBS=${GBS:-2} +TP_SIZE=${TP_SIZE:-1} +PP_SIZE=${PP_SIZE:-2} +RESHARD=${RESHARD:-True} +RM_NEMO_FILE=${RM_NEMO_FILE} +ACTOR_NEMO_FILE=${ACTOR_NEMO_FILE} + + +MIN_LR=$(awk -v var="$LR" 'BEGIN {print var - 1e-11}') + +TRAIN_DATA_PATH=$SCRIPT_DIR/test_data/synthetic-123.jsonl +VALID_DATA_PATH=$SCRIPT_DIR/test_data/synthetic-123.jsonl + +NAME="reinforce_test" + +# PARAMETERS +RESULTS_DIR="/tmp/${NAME}" +mkdir -p $RESULTS_DIR + +GPFS=$(git rev-parse --show-toplevel) + +# W&B Logging +PROJECT=reinforce_test + +REWARD_LOG_DIR="${RESULTS_DIR}/reward_results" +REWARD_PORT=5555 + +mkdir -p $REWARD_LOG_DIR + +REWARD_NAME="${NAME}_reward" + +reward() { +export CUDA_VISIBLE_DEVICES=0 +export PYTHONPATH="${GPFS}:${PYTHONPATH:-}" +export HYDRA_FULL_ERROR=1 +python -u ${GPFS}/examples/nlp/gpt/serve_reward_model.py \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + inference.port=${REWARD_PORT} \ + ++model.tensor_model_parallel_size=1 \ + ++model.pipeline_model_parallel_size=1 \ + ++model.dist_ckpt_load_strictness=log_all \ + rm_model_file=${RM_NEMO_FILE} +} +reward_log_file=$(mktemp /tmp/reward-reinforce-log-XXXXXX) +if [[ $RUN_ONLY =~ actor* ]]; then + echo SKIPPING REWARD +elif [[ $RUN_ONLY == reward ]]; then + reward 2>&1 | stdbuf -o0 sed 's/^/[REWARD_SERVER]: /' | tee $reward_log_file + exit $? +else + reward 2>&1 | stdbuf -o0 sed 's/^/[REWARD_SERVER]: /' | tee $reward_log_file & +fi + +if [[ -z "${FAST:-}" ]]; then + sleep 15 +fi +######################################################################################### + +ACTOR_LOG_DIR="${RESULTS_DIR}/actor_results" +mkdir -p $ACTOR_LOG_DIR + +ACTOR_NAME="${NAME}_actor" +host_reward=localhost + +actor() { +export CUDA_VISIBLE_DEVICES=0,1 +export PYTHONPATH="${GPFS}:${PYTHONPATH:-}" +export HYDRA_FULL_ERROR=1 +mpirun -np 2 --allow-run-as-root python -u ${GPFS}/examples/nlp/gpt/train_gpt_reinforce_actor.py \ + "++model.data.data_prefix={train: [${TRAIN_DATA_PATH}], validation: [${VALID_DATA_PATH}], test: [${VALID_DATA_PATH}]}" \ + pretrained_checkpoint.restore_from_path=${ACTOR_NEMO_FILE} \ + exp_manager.explicit_log_dir=${ACTOR_LOG_DIR} \ + exp_manager.create_wandb_logger=True \ + exp_manager.wandb_logger_kwargs.name=${ACTOR_NAME} \ + exp_manager.wandb_logger_kwargs.project=${PROJECT} \ + exp_manager.create_checkpoint_callback=True \ + trainer.num_nodes=1 \ + trainer.devices=2 \ + trainer.reinforce.trt_llm.enable=True \ + ++model.offload_adam_states=False \ + trainer.reinforce.trt_llm.reshard=${RESHARD} \ + trainer.reinforce.val_check_interval=2 \ + ++trainer.reinforce.save_interval=2 \ + ++model.micro_batch_size=1 \ + ++model.global_batch_size=${GBS} \ + ++model.tensor_model_parallel_size=${TP_SIZE} \ + ++model.pipeline_model_parallel_size=${PP_SIZE} \ + ++model.reinforce.entropy_bonus=0.0 \ + ++model.reinforce.ratio_eps=0.2 \ + ++model.encoder_seq_length=64 \ + ++exp_manager.checkpoint_callback_params.save_top_k=10 \ + ++model.reinforce.num_rollout_samples=${GBS} \ + ++model.reinforce.rollout_micro_batch_size=1 \ + ++model.reinforce.length_params.max_length=32 \ + ++model.reinforce.forward_micro_batch_size=1 \ + trainer.reinforce.initial_policy_kl_penalty="${KL}" \ + trainer.reinforce.rollout_batch_seq_length=32 \ + ++trainer.reinforce.flask_server.enable=True \ + ++model.optim.lr=${LR} \ + ++model.optim.sched.min_lr=${MIN_LR} \ + ++model.activations_checkpoint_granularity=full \ + ++model.activations_checkpoint_method=uniform \ + ++model.activations_checkpoint_num_layers=1 \ + ++model.optim.bucket_cap_mb=200 \ + ++model.optim.overlap_grad_sync=False \ + ++model.optim.contiguous_grad_buffer=True \ + ++model.enable_nge=True \ + remote_rm.reward_model.ip=${host_reward} \ + remote_rm.reward_model.port=${REWARD_PORT} \ + \ + +model.overwrite_base_config.optim=True \ + '~model.optim' \ + '++model.optim={name:sgd}' \ + model.reinforce.sampling_params.use_greedy=True \ + trainer.reinforce.save_interval=0 \ + trainer.reinforce.max_steps=3 \ + trainer.reinforce.trt_llm.model_type=llama \ + ++exp_manager=null \ + \ + ++model.dist_ckpt_load_strictness=log_all \ + $@ +} + +actor_log_file=$(mktemp /tmp/actor-reinforce-log-XXXXXX) +if [[ -z "$RUN_ONLY" || "$RUN_ONLY" == actor_trt || "$RUN_ONLY" == trt ]]; then + actor 2>&1 | stdbuf -o0 sed 's/^/[ACTOR_TRT]: /' +elif [[ "$RUN_ONLY" == actor_nemo || "$RUN_ONLY" == nemo ]]; then + actor trainer.reinforce.trt_llm.enable=False 2>&1 | stdbuf -o0 sed 's/^/[ACTOR_NEMO]: /' +else + echo "Only accepts RUN_ONLY=actor_nemo or actor_trt" + exit 1 +fi | tee $actor_log_file || true + +REWARD_ID=$(grep -oP "kill -SIGINT \K\d+" $reward_log_file) +if [[ $REWARD_ID =~ ^[0-9]+$ ]]; then + echo "Valid integer: $REWARD_ID" + kill -SIGINT $REWARD_ID +else + echo "Invalid REWARD_ID=$REWARD_ID detected!" + exit 1 +fi + +if ! fgrep 'Cleaning up communicator' $actor_log_file &>/dev/null; then + echo "[ERROR] Did not find 'Cleaning up communicator' in the actor logs ($actor_log_file) indicating the actor reached the end" + exit 1 +fi + +echo "Waiting for backgrounded processes to finish..." +wait +set +x +echo "[Finished] $0" diff --git a/tests/functional/test_cases/reinforce-llama3-pp2-reshard b/tests/functional/test_cases/reinforce-llama3-pp2-reshard new file mode 100755 index 000000000..7ee8342d7 --- /dev/null +++ b/tests/functional/test_cases/reinforce-llama3-pp2-reshard @@ -0,0 +1,28 @@ +#!/bin/bash + +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd $SCRIPT_DIR + +set -eoux pipefail + +GBS=2 \ + TP_SIZE=1 \ + PP_SIZE=2 \ + RESHARD=True \ + RM_NEMO_FILE=${ALIGNER_CI_DIR}/checkpoints/llama3--nlayers4-hidden64-ffn224-dummy_rm-megatron_gpt.nemo \ + ACTOR_NEMO_FILE=${ALIGNER_CI_DIR}/checkpoints/tiny-llama3-results-nlayers2-hidden128-ffn448-nhead4-qgroup2-megatron_gpt.nemo \ + bash ../reinforce.sh 2>&1 | tee $(basename $0).log diff --git a/tests/test_ppo_utils.py b/tests/test_ppo_utils.py index a12db274b..0959e9add 100644 --- a/tests/test_ppo_utils.py +++ b/tests/test_ppo_utils.py @@ -17,7 +17,12 @@ import torch import torch.nn.functional as F -from nemo_aligner.utils.ppo_utils import calculate_advantages_and_returns, calculate_entropy, calculate_ppo_rewards +from nemo_aligner.utils.ppo_utils import ( + calculate_advantages_and_returns, + calculate_entropy, + calculate_ppo_rewards, + calculate_rloo_baseline, +) class TestCalculateEntropy: @@ -120,3 +125,18 @@ def test_calculate_advantage_and_returns_small_example(self): assert torch.allclose(advantages, gt_advantages), "computed advantage is not the same as hand example" assert torch.allclose(returns, gt_advantages + values), "computed returns is not the same as hand example" + + +class TestCalculateRLOOBaseline: + def test_calculate_rloo_baseline_small_example(self): + + prompts = torch.Tensor([[1, 0], [1, 0], [0, 1], [1, 0], [1, 0], [0, 1], [0, 1], [0, 1],]) + + rewards = torch.Tensor([1, 0, 2, -3, 5, 7, -1, 0]) + mask = torch.Tensor([1, 1, 1, 1, 1, 1, 1, 0]) + + baseline = calculate_rloo_baseline(prompts, rewards, mask) + + gt_baseline = torch.Tensor([2 / 3, 1.0, 3.0, 2.0, -2 / 3, 1 / 2, 9 / 2, 8 / 2]) + + assert torch.allclose(baseline, gt_baseline), "computed baseline is not the same as hand example" From 79eed888144bf4c730858700fa6025cfa01bf5ef Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Thu, 14 Nov 2024 00:07:31 -0800 Subject: [PATCH 3/5] feat: update Aligner to use mcore export Signed-off-by: Terry Kong --- Dockerfile | 27 ++------------------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/Dockerfile b/Dockerfile index 44a9f8651..273bb2f64 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,8 +13,8 @@ ARG MAX_JOBS=8 # Git refs for dependencies ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea ARG PYTRITON_VERSION=0.5.10 -ARG NEMO_TAG=19668e5320a2e2af0199b6d5e0b841993be3a634 # On: main -ARG MLM_TAG=25059d3bbf68be0751800f3644731df12a88f3f3 # On: main +ARG NEMO_TAG=aligner_export_mcore # On: main (TODO: change to upstream main commit after merge) +ARG MLM_TAG=main # On: main (TODO: change to correct commit after merge) ARG ALIGNER_COMMIT=main ARG TRTLLM_VERSION=v0.13.0 ARG PROTOBUF_VERSION=4.24.4 @@ -121,26 +121,3 @@ RUN cd /opt/NeMo-Aligner && \ RUN cd TensorRT-LLM && patch -p1 < ../NeMo-Aligner/setup/trtllm.patch -# TODO(terryk): This layer should be deleted ASAP after NeMo is bumped to include all of these PRs -RUN <<"EOF" bash -exu -cd NeMo -# Ensures we don't cherry-pick "future" origin/main commits -git fetch -a -# 0c92fe17df4642ffc33d5d8c0c83fda729e3910c: [fix] Ensures disabling exp_manager with exp_manager=null does not error NeMo#10651 -# 60e677423667c029dd05875da72bf0719774f844: [feat] Update get_model_parallel_src_rank to support tp-pp-dp ordering NeMo#10652 -# 0deaf6716cb4f20766c995ce25d129795f1ae200: fix[export]: update API for disabling device reassignment in TRTLLM for Aligner NeMo#10863 -# (superceded by 10863) 148543d6e9c66ff1f8562e84484448202249811d: feat: Migrate GPTSession refit path in Nemo export to ModelRunner for Aligner NeMo#10654 -for pr_and_commit in \ - "10651 0c92fe17df4642ffc33d5d8c0c83fda729e3910c" \ - "10652 60e677423667c029dd05875da72bf0719774f844" \ - "10863 0deaf6716cb4f20766c995ce25d129795f1ae200" \ -; do - pr=$(cut -f1 -d' ' <<<"$pr_and_commit") - head_pr_commit=$(cut -f2 -d' ' <<<"$pr_and_commit") - git fetch origin $head_pr_commit:PR-${pr} - # cherry-picks all commits between main and the top of the PR - git cherry-pick --allow-empty $(git merge-base origin/main PR-${pr})..PR-${pr} - # Tag cherry-picks to help - git tag cherry-pick-PR-${pr} -done -EOF From 7f75e47f8421bf4abb476f2bef89c7d7859889fb Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Thu, 14 Nov 2024 00:09:04 -0800 Subject: [PATCH 4/5] require grad_sync to be true now for this ckpt Signed-off-by: Terry Kong --- tests/functional/ppo.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/functional/ppo.sh b/tests/functional/ppo.sh index 46f4974b5..fb8e7957d 100755 --- a/tests/functional/ppo.sh +++ b/tests/functional/ppo.sh @@ -82,7 +82,7 @@ python -u ${GPFS}/examples/nlp/gpt/serve_ppo_critic.py \ ++model.global_batch_size=1 \ ++model.tensor_model_parallel_size=1 \ ++model.optim.bucket_cap_mb=200 \ - ++model.optim.overlap_grad_sync=False \ + ++model.optim.overlap_grad_sync=True \ ++model.optim.contiguous_grad_buffer=True \ ++trainer.ppo.pad_sequence_length_to_multiple=32 \ model.reward_standardization.enable=True \ From 06817b095f82153bb6c062c97de74eba589a959a Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Fri, 22 Nov 2024 10:56:06 -0800 Subject: [PATCH 5/5] bump commits to public ones that have mcore-export support Signed-off-by: Terry Kong --- Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 273bb2f64..20441c4ca 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,8 +13,8 @@ ARG MAX_JOBS=8 # Git refs for dependencies ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea ARG PYTRITON_VERSION=0.5.10 -ARG NEMO_TAG=aligner_export_mcore # On: main (TODO: change to upstream main commit after merge) -ARG MLM_TAG=main # On: main (TODO: change to correct commit after merge) +ARG NEMO_TAG=a153b8c58d56aa930749587017ee70d56f75445e # On: main +ARG MLM_TAG=2e355b7889258fc0756bd6a8036d09a896cc9caa # On: main ARG ALIGNER_COMMIT=main ARG TRTLLM_VERSION=v0.13.0 ARG PROTOBUF_VERSION=4.24.4