From 3ed1cb1cfa73cad94e1fdef9c0e0ffdd85e2bc57 Mon Sep 17 00:00:00 2001 From: arendu Date: Wed, 30 Oct 2024 21:00:54 +0000 Subject: [PATCH 1/7] wip Signed-off-by: arendu --- examples/nlp/gpt/conf/gpt_dpo.yaml | 1 + examples/nlp/gpt/train_gpt_dpo.py | 4 +- .../models/nlp/gpt/megatron_gpt_dpo_model.py | 421 ++++++++++++++++++ nemo_aligner/utils/utils.py | 6 +- 4 files changed, 428 insertions(+), 4 deletions(-) diff --git a/examples/nlp/gpt/conf/gpt_dpo.yaml b/examples/nlp/gpt/conf/gpt_dpo.yaml index 2a165bf9d..4aa67ec38 100644 --- a/examples/nlp/gpt/conf/gpt_dpo.yaml +++ b/examples/nlp/gpt/conf/gpt_dpo.yaml @@ -57,6 +57,7 @@ model: micro_batch_size: 1 global_batch_size: 64 megatron_amp_O2: True + mamba_hybrid: False dpo: # This default value ensures there are no numeric differences beween trained and reference policies when computing log probs. diff --git a/examples/nlp/gpt/train_gpt_dpo.py b/examples/nlp/gpt/train_gpt_dpo.py index aefa0c5ac..402528f43 100644 --- a/examples/nlp/gpt/train_gpt_dpo.py +++ b/examples/nlp/gpt/train_gpt_dpo.py @@ -21,7 +21,7 @@ from nemo.utils.exp_manager import exp_manager from nemo_aligner.algorithms.dpo import DPOTrainer, dpo_custom_collate from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets -from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel +from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel, MegatronMambaDPOModel from nemo_aligner.utils.distributed import Timer from nemo_aligner.utils.train_script_utils import ( CustomLoggerWrapper, @@ -53,7 +53,7 @@ def main(cfg) -> None: logger = CustomLoggerWrapper(trainer.loggers) ptl_model = load_from_nemo( - MegatronGPTDPOModel, + MegatronMambaDPOModel if cfg.model.mamba_hybrid else MegatronGPTDPOModel, cfg.model, trainer, strict=True, diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py index 952b4e897..692e6a732 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py @@ -16,6 +16,8 @@ from functools import partial import torch +from megatron.core import parallel_state +from megatron.core.models.mamba import MambaModel from megatron.core.num_microbatches_calculator import get_num_microbatches from megatron.core.pipeline_parallel.schedules import get_forward_backward_func from megatron.core.utils import divide @@ -23,6 +25,7 @@ from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel from nemo.collections.nlp.modules.common.megatron.utils import ( average_losses_across_data_parallel_group, get_iterator_k_split, @@ -460,3 +463,421 @@ def get_ref_policy_logprobs(self, batch): # return in GPU, trainer needs to move to cpu return ref_log_probs + +class MegatronMambaDPOModel(NLPAdapterModelMixin, MegatronMambaModel, SupervisedInterface): + """ + Megatron GPT DPO Model Training. 
+ """ + + def __init__(self, cfg: DictConfig, trainer: Trainer): + super().__init__(cfg, trainer=trainer) + + if self.cfg.pipeline_model_parallel_size > 1 and not self.cfg.megatron_amp_O2: + warnings.warn( + "when using pipeline parallelism, it is recommended to set megatron_amp_O2 to be True to " + "avoid explicit casting for pipeline communication" + ) + self.automatic_optimization = False + self.ref_policy_state_dict = None + + self.ref_policy_kl_penalty = self.cfg.dpo.get("ref_policy_kl_penalty", 0.0) + self.preference_avg_log_probs = self.cfg.dpo.get("preference_average_log_probs", False) + self.sft_avg_log_probs = self.cfg.dpo.get("sft_average_log_probs", self.preference_avg_log_probs) + + self.preference_loss_weight = self.cfg.dpo.get("preference_loss_weight", 1) + self.sft_loss_weight = self.cfg.dpo.get("sft_loss_weight", 0) + assert ( + self.preference_loss_weight != 0 or self.sft_loss_weight != 0 + ), "sft loss weight and preference loss weight cannot both be 0" + + # variants of preference losses, by default DPO. + self.preference_loss = self.cfg.dpo.get("preference_loss", "dpo") + self.gt_reward_scale = self.cfg.dpo.get("gt_reward_scale", 1.0) + + @torch.no_grad() + def gather_and_split_rewards(self, pi_logprobs, ref_logprobs, labels, average_log_probs=False): + pi_logprobs = pi_logprobs.detach() + + dp_group = parallel_state.get_data_parallel_group() + + batch_logs = self.get_reduced_masked_logps( + pi_logprobs - ref_logprobs, labels[:, 1:], average_log_probs=average_log_probs + ) + + output_list = [torch.zeros_like(batch_logs) for _ in range(dp_group.size())] + + torch.distributed.all_gather(output_list, batch_logs, group=dp_group) + + split_iter = map(self.split_output_tensor, output_list) + + out_chosen, out_rejected = map(torch.cat, zip(*split_iter)) + + return out_chosen.flatten(), out_rejected.flatten() + + def get_forward_output_and_loss_func(self, validation_step=False, logprobs_only=False): + def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None): + batch = next(dataloader_iter) + + required_keys = set() + if parallel_state.get_pipeline_model_parallel_world_size() == 1: + required_keys.update(batch.keys()) + else: + # there is a problem with apex ignoring the mask on the older models + # so we will always give the attention mask + required_keys.add("attention_mask") + + if parallel_state.is_pipeline_first_stage(): + required_keys.update(("chosen", "rejected", "position_ids")) + + if parallel_state.is_pipeline_last_stage(): + required_keys.update( + ( + "ref_policy_log_probs_chosen", + "ref_policy_log_probs_rejected", + "chosen_labels", + "rejected_labels", + "chosen_rewards", + "rejected_rewards", + ) + ) + + batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()} + + tokens, labels, ref_logprobs, gt_rewards = None, None, None, None + if batch["chosen"] is not None and batch["rejected"] is not None: + tokens = torch.cat((batch["chosen"], batch["rejected"]), dim=0) + + if batch["chosen_labels"] is not None and batch["rejected_labels"] is not None: + labels = torch.cat((batch["chosen_labels"], batch["rejected_labels"]), dim=0) + + if ( + batch.get("ref_policy_log_probs_chosen") is not None + and batch.get("ref_policy_log_probs_rejected") is not None + ): + ref_logprobs = torch.cat( + (batch["ref_policy_log_probs_chosen"], batch["ref_policy_log_probs_rejected"]), dim=0 + ) + + if batch["chosen_rewards"] is not None and batch["rejected_rewards"] is not None: + gt_rewards = 
torch.cat((batch["chosen_rewards"], batch["rejected_rewards"]), dim=0) + + # this is necessary if MBS > 1 with the new GBS padding logic, as you may get batch dim > 1 in some configs + # these two lines ensure your position_ids and attn_mask are always B=1 + # position_ids = batch["position_ids"][0:1] + attention_mask = batch["attention_mask"][0:1] + + # Model forward pass + forward_args = { + "input_ids": tokens, + "position_ids": batch["position_ids"], + "attention_mask": attention_mask, + "labels": None, + "loss_mask": None, + } + + # TODO: we can remove this someday when we no longer support legacy models + if not self.mcore_gpt: + forward_args["checkpoint_activations_all_layers"] = checkpoint_activations_all_layers + if not self.use_loss_mask: + forward_args.pop("loss_mask") + else: + forward_args.pop("loss_mask") + + output_tensor = model(**forward_args) + + # in this nemo version the model and autocast dtypes are not synced + # so we need to explicitly cast it + if not parallel_state.is_pipeline_last_stage(): + output_tensor = output_tensor.to(dtype=self.autocast_dtype) + + def logprobs_func(output_tensor, non_loss_data=True): + # This function is expected to be used only when `collect_non_loss_data=True` in the fwd_bwd_function of Megatron-LM. + # See https://github.com/NVIDIA/Megatron-LM/blob/0bc3547702464501feefeb5523b7a17e591b21fa/megatron/core/pipeline_parallel/schedules.py#L228 + assert non_loss_data + logprobs = from_parallel_logits_to_logprobs( + vocab_parallel_logits=output_tensor, target=labels, inference_only=True, higher_stability=True, + ) + return {"logprobs": logprobs} + + def loss_func(output_tensor): + if validation_step and not self.cfg.data.get("validation_drop_last", True): + raise NotImplementedError("DPO does not support validation when cfg.data.drop_last=False") + + per_token_logps = from_parallel_logits_to_logprobs( + vocab_parallel_logits=output_tensor, + target=labels, + inference_only=validation_step, + higher_stability=True, + ) + + preference_loss, acc_chosen = self.loss_func( + per_token_logps, + ref_logprobs, + labels[:, 1:], + gt_rewards, + average_log_probs=self.preference_avg_log_probs, + ) + + sft_loss = torch.zeros_like(preference_loss) + if self.sft_loss_weight != 0: + sft_loss = self.sft_loss_func( + per_token_logps, labels[:, 1:], average_log_probs=self.sft_avg_log_probs + ) + loss = self.preference_loss_weight * preference_loss + self.sft_loss_weight * sft_loss + + ( + reduced_loss, + reduced_preference_loss, + reduced_sft_loss, + reduced_acc, + ) = average_losses_across_data_parallel_group([loss, preference_loss, sft_loss, acc_chosen]) + + out_chosen, out_rejected = self.gather_and_split_rewards( + per_token_logps, ref_logprobs, labels, average_log_probs=self.preference_avg_log_probs + ) + + return ( + loss, + { + "avg": reduced_loss, + "avg_sft_loss": reduced_sft_loss, + "avg_preference_loss": reduced_preference_loss, + "acc": reduced_acc, + "out_chosen": out_chosen, + "out_rejected": out_rejected, + }, + ) + + if logprobs_only: + return output_tensor, logprobs_func + else: + return output_tensor, loss_func + + return fwd_output_and_loss_func + + def split_output_tensor(self, output_tensor): + chosen_logps, reject_logps = torch.split(output_tensor.float(), len(output_tensor) // 2, dim=0) + return chosen_logps, reject_logps + + def get_reduced_masked_logps(self, logps, labels, average_log_probs=False): + assert logps.shape == labels.shape, "logps and labels shape mismatch" + + loss_mask = (labels > -1).float() + + if average_log_probs: + # 
need to guard against divide by zero in case labels are all -100 + return (logps * loss_mask).sum(-1) / loss_mask.sum(-1).clamp(min=1) + else: + return (logps * loss_mask).sum(-1) + + def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_probs=False): + rewards = self.get_reduced_masked_logps( + pi_logprobs - ref_logprobs, labels, average_log_probs=average_log_probs + ) + chosen_rewards, reject_rewards = self.split_output_tensor(rewards) + rewards_delta = chosen_rewards - reject_rewards + + if self.preference_loss == "dpo": + loss = -torch.nn.functional.logsigmoid(self.ref_policy_kl_penalty * rewards_delta).mean(0) + elif self.preference_loss == "rpo_bwd_kl": + logbeta_hat_chosen = torch.nn.functional.logsigmoid(self.ref_policy_kl_penalty * rewards_delta) + logbeta_hat_rejected = torch.nn.functional.logsigmoid(-self.ref_policy_kl_penalty * rewards_delta) + + chosen_gt_rewards, reject_gt_rewards = self.split_output_tensor(gt_rewards) + gt_rewards_delta = self.gt_reward_scale * (chosen_gt_rewards - reject_gt_rewards) + logalpha_hat_chosen = torch.nn.functional.logsigmoid(gt_rewards_delta) + logalpha_hat_rejected = torch.nn.functional.logsigmoid(-gt_rewards_delta) + + loss = ( + torch.exp(logalpha_hat_chosen) * (logalpha_hat_chosen - logbeta_hat_chosen) + + torch.exp(logalpha_hat_rejected) * (logalpha_hat_rejected - logbeta_hat_rejected) + ).mean(0) + elif self.preference_loss == "rpo_fwd_kl": + logbeta_hat_chosen = torch.nn.functional.logsigmoid(self.ref_policy_kl_penalty * rewards_delta) + logbeta_hat_rejected = torch.nn.functional.logsigmoid(-self.ref_policy_kl_penalty * rewards_delta) + + chosen_gt_rewards, reject_gt_rewards = self.split_output_tensor(gt_rewards) + gt_rewards_delta = self.gt_reward_scale * (chosen_gt_rewards - reject_gt_rewards) + logalpha_hat_chosen = torch.nn.functional.logsigmoid(gt_rewards_delta) + logalpha_hat_rejected = torch.nn.functional.logsigmoid(-gt_rewards_delta) + + loss = ( + torch.exp(logbeta_hat_chosen) * (logbeta_hat_chosen - logalpha_hat_chosen) + + torch.exp(logbeta_hat_rejected) * (logbeta_hat_rejected - logalpha_hat_rejected) + ).mean(0) + elif self.preference_loss == "ipo": + loss = torch.mean((chosen_rewards - reject_rewards - 1.0 / (2.0 * self.ref_policy_kl_penalty)) ** 2, 0) + elif self.preference_loss == "rpo_sq": + chosen_gt_rewards, reject_gt_rewards = self.split_output_tensor(gt_rewards) + gt_rewards_delta = self.gt_reward_scale * (chosen_gt_rewards - reject_gt_rewards) + + loss = torch.mean((self.ref_policy_kl_penalty * rewards_delta - gt_rewards_delta) ** 2, 0) + else: + raise NotImplementedError(f"preference_loss {self.preference_loss} is not implemented") + + with torch.no_grad(): + comp = chosen_rewards > reject_rewards + acc_chosen = comp.float().mean() + + return loss, acc_chosen + + def sft_loss_func(self, pi_logprobs, labels, average_log_probs=False): + logprobs = self.get_reduced_masked_logps(pi_logprobs, labels, average_log_probs=average_log_probs) + chosen_logprobs, _ = self.split_output_tensor(logprobs) + return -chosen_logprobs.mean(0) + + def get_loss_and_metrics(self, batch, forward_only): + seq_length = batch["chosen"].shape[1] + + data_iter = get_iterator_k_split(batch, get_num_microbatches()) + set_sync_funcs(self, forward_only) + + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(forward_only, logprobs_only=False), + data_iterator=data_iter, + model=self.model, + 
num_microbatches=get_num_microbatches(), + forward_only=forward_only, + seq_length=seq_length, + micro_batch_size=self.cfg.micro_batch_size + * 2, # each minibatch has 2 comparisons so tensor shape will be mbs * 2 + ) + + # only the last stages of the pipeline return losses + if losses_reduced_per_micro_batch: + # NOTE: assume that the returned values are already gathered across the DP workers + rewards_chosen = torch.cat([item["out_chosen"] for item in losses_reduced_per_micro_batch]) + rewards_rejected = torch.cat([item["out_rejected"] for item in losses_reduced_per_micro_batch]) + + rewards_all = torch.cat((rewards_chosen, rewards_rejected)) + rewards_chosen_mean = rewards_chosen.mean() + rewards_rejected_mean = rewards_rejected.mean() + rewards_all_mean = rewards_all.mean() + rewards_all_std = rewards_all.std() + + loss_mean = torch.as_tensor( + [loss_reduced["avg"] for loss_reduced in losses_reduced_per_micro_batch], + device=torch.cuda.current_device(), + ).mean() + sft_loss_mean = torch.as_tensor( + [loss_reduced["avg_sft_loss"] for loss_reduced in losses_reduced_per_micro_batch], + device=torch.cuda.current_device(), + ).mean() + preference_loss_mean = torch.as_tensor( + [loss_reduced["avg_preference_loss"] for loss_reduced in losses_reduced_per_micro_batch], + device=torch.cuda.current_device(), + ).mean() + acc_mean = torch.as_tensor( + [loss_reduced["acc"] for loss_reduced in losses_reduced_per_micro_batch], + device=torch.cuda.current_device(), + ).mean() + else: + loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + sft_loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + preference_loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + acc_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + + rewards_chosen_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + rewards_rejected_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + rewards_all_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + rewards_all_std = torch.tensor(0.0, device=torch.cuda.current_device()) + + # we can only log on one rank if it is rank zero so we broadcast from last rank + torch.distributed.broadcast(loss_mean, get_last_rank()) + torch.distributed.broadcast(sft_loss_mean, get_last_rank()) + torch.distributed.broadcast(preference_loss_mean, get_last_rank()) + torch.distributed.broadcast(acc_mean, get_last_rank()) + + torch.distributed.broadcast(rewards_chosen_mean, get_last_rank()) + torch.distributed.broadcast(rewards_rejected_mean, get_last_rank()) + torch.distributed.broadcast(rewards_all_mean, get_last_rank()) + torch.distributed.broadcast(rewards_all_std, get_last_rank()) + + metrics = { + "loss": loss_mean, + "sft_loss": sft_loss_mean, + "preference_loss": preference_loss_mean, + "acc": acc_mean, + "rewards_chosen_mean": rewards_chosen_mean, + "rewards_rejected_mean": rewards_rejected_mean, + "rewards_all_mean": rewards_all_mean, + "rewards_all_std": rewards_all_std, + } + + # move to CPU + metrics = {k: v.item() for k, v in metrics.items()} + + return loss_mean.item(), metrics + + def prepare_for_training_step(self): + # custom trainers will always zero grad for us + prepare_for_training_step(self, zero_grad=False) + + def finish_training_step(self): + grad_reductions(self) + + def prepare_for_validation_step(self): + prepare_for_validation_step(self) + + def finish_validation_step(self): + finish_validation_step(self) + + @torch.no_grad() + def get_logprob_batch(self, batch): + seq_length = 
batch["chosen"].shape[1] + batch_size = batch["chosen"].shape[0] + + num_microbatches = divide(batch_size, self.cfg.dpo.log_prob_forward_micro_batch_size) + data_iter = get_iterator_k_split(batch, num_microbatches) + set_sync_funcs(self, forward_only=True) + + fwd_bwd_function = get_forward_backward_func() + + logprobs_list = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(logprobs_only=True), + data_iterator=data_iter, + model=self.model, + num_microbatches=num_microbatches, + forward_only=True, + seq_length=seq_length, + micro_batch_size=self.cfg.dpo.log_prob_forward_micro_batch_size * 2, + collect_non_loss_data=True, + ) + + if len(logprobs_list) > 0: + chosen_logprobs_list = [] + rejected_logprobs_list = [] + for item in logprobs_list: + chosen_logprobs, rejected_logprobs = self.split_output_tensor(item["logprobs"]) + chosen_logprobs_list.append(chosen_logprobs) + rejected_logprobs_list.append(rejected_logprobs) + + logprobs = torch.cat([torch.cat(chosen_logprobs_list), torch.cat(rejected_logprobs_list)], dim=0) + else: + logprobs = None + + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + # broadcast it from last PP stage to everything else + logprobs = broadcast_2d_tensor( + logprobs, + parallel_state.get_pipeline_model_parallel_last_rank(), + parallel_state.get_pipeline_model_parallel_group(), + ) + + return logprobs + + def get_ref_policy_logprobs(self, batch): + + if self.use_peft and self.ref_policy_state_dict is None: + # when using adapters instead of full-tuning, the actor is reference model + adapters + with adapter_control(self): + # With adapters disabled (meaning using the reference model), calculate ref_log_probs + ref_log_probs = self.get_logprob_batch(batch) + else: + with cpu_weight_swap(self, self.ref_policy_state_dict, megatron_amp_O2=self.megatron_amp_O2): + ref_log_probs = self.get_logprob_batch(batch) + + # return in GPU, trainer needs to move to cpu + return ref_log_probs \ No newline at end of file diff --git a/nemo_aligner/utils/utils.py b/nemo_aligner/utils/utils.py index e29bc28e6..cf9a85012 100644 --- a/nemo_aligner/utils/utils.py +++ b/nemo_aligner/utils/utils.py @@ -28,7 +28,7 @@ import torch from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensorFactory -from megatron.core.num_microbatches_calculator import reconfigure_microbatch_calculator +from megatron.core.num_microbatches_calculator import reconfigure_num_microbatches_calculator as reconfigure_microbatch_calculator from omegaconf import DictConfig, OmegaConf from torch.masked import as_masked_tensor @@ -122,7 +122,9 @@ def load_checkpoint_model_config(restore_path): return OmegaConf.load(cfg_path) with tempfile.TemporaryDirectory() as tmpdir: - NLPSaveRestoreConnector._unpack_nemo_file(restore_path, tmpdir, extract_config_only=True) + # Extracts only model config + members = NLPSaveRestoreConnector._filtered_tar_info(restore_path, filter_fn=lambda name: '.yaml' in name) + NLPSaveRestoreConnector._unpack_nemo_file(restore_path, tmpdir, members=members) cfg = OmegaConf.load(os.path.join(tmpdir, config_name_in_ckpt)) return cfg From bc96c958192c677ab3991bd1c5bd22901fbaa691 Mon Sep 17 00:00:00 2001 From: arendu Date: Fri, 1 Nov 2024 18:00:42 +0000 Subject: [PATCH 2/7] dpo and sft Signed-off-by: arendu --- examples/nlp/gpt/conf/gpt_sft.yaml | 2 +- examples/nlp/gpt/train_gpt_dpo.py | 2 +- examples/nlp/gpt/train_gpt_sft.py | 4 +- nemo_aligner/models/nlp/gpt/gpt_sft_model.py | 6 + .../models/nlp/gpt/megatron_gpt_dpo_model.py | 418 
+----------------- nemo_aligner/utils/utils.py | 4 +- 6 files changed, 13 insertions(+), 423 deletions(-) diff --git a/examples/nlp/gpt/conf/gpt_sft.yaml b/examples/nlp/gpt/conf/gpt_sft.yaml index bdd757f31..745f6ae01 100644 --- a/examples/nlp/gpt/conf/gpt_sft.yaml +++ b/examples/nlp/gpt/conf/gpt_sft.yaml @@ -191,7 +191,7 @@ model: output_original_text: True # needed for the proper metrics support optim: - name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work. + name: fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work. lr: 3e-5 weight_decay: 0.01 betas: diff --git a/examples/nlp/gpt/train_gpt_dpo.py b/examples/nlp/gpt/train_gpt_dpo.py index 402528f43..473fb89eb 100644 --- a/examples/nlp/gpt/train_gpt_dpo.py +++ b/examples/nlp/gpt/train_gpt_dpo.py @@ -53,7 +53,7 @@ def main(cfg) -> None: logger = CustomLoggerWrapper(trainer.loggers) ptl_model = load_from_nemo( - MegatronMambaDPOModel if cfg.model.mamba_hybrid else MegatronGPTDPOModel, + MegatronMambaDPOModel if cfg.mamba_hybrid else MegatronGPTDPOModel, cfg.model, trainer, strict=True, diff --git a/examples/nlp/gpt/train_gpt_sft.py b/examples/nlp/gpt/train_gpt_sft.py index f52445637..271f7cec8 100644 --- a/examples/nlp/gpt/train_gpt_sft.py +++ b/examples/nlp/gpt/train_gpt_sft.py @@ -27,7 +27,7 @@ from nemo.utils.exp_manager import exp_manager from nemo_aligner.algorithms.supervised import SupervisedTrainer from nemo_aligner.data.nlp.builders import build_dataloader, build_sft_dataset -from nemo_aligner.models.nlp.gpt.gpt_sft_model import GPTSFTModel +from nemo_aligner.models.nlp.gpt.gpt_sft_model import GPTSFTModel, MambaSFTModel from nemo_aligner.utils.distributed import Timer from nemo_aligner.utils.train_script_utils import ( CustomLoggerWrapper, @@ -127,7 +127,7 @@ def main(cfg) -> None: cfg.model.precision = cfg.trainer.precision ptl_model, updated_cfg = load_from_nemo( - GPTSFTModel, + MambaSFTModel if cfg.model.mamba_hybrid else GPTSFTModel, cfg, trainer, strict=True, diff --git a/nemo_aligner/models/nlp/gpt/gpt_sft_model.py b/nemo_aligner/models/nlp/gpt/gpt_sft_model.py index d3a615500..15bc69c00 100644 --- a/nemo_aligner/models/nlp/gpt/gpt_sft_model.py +++ b/nemo_aligner/models/nlp/gpt/gpt_sft_model.py @@ -22,6 +22,7 @@ from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split from nemo.collections.nlp.modules.common.text_generation_strategy import TextGenerationStrategy from nemo.collections.nlp.modules.common.text_generation_utils import ( @@ -225,3 +226,8 @@ def finish_inference(self): self._restore_activation_checkpointing_args() self._restore_sequence_parallelism_args() set_train(self) + + +class MambaSFTModel(MegatronMambaModel, GPTSFTModel): + def __init__(self, cfg: DictConfig, trainer: Trainer): + super().__init__(cfg, trainer=trainer) \ No newline at end of file diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py index 692e6a732..d7e69d7ef 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py +++ 
b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py @@ -464,420 +464,6 @@ def get_ref_policy_logprobs(self, batch): # return in GPU, trainer needs to move to cpu return ref_log_probs -class MegatronMambaDPOModel(NLPAdapterModelMixin, MegatronMambaModel, SupervisedInterface): - """ - Megatron GPT DPO Model Training. - """ - +class MegatronMambaDPOModel(MegatronMambaModel, MegatronGPTDPOModel): # @adithyare inherence order matters def __init__(self, cfg: DictConfig, trainer: Trainer): - super().__init__(cfg, trainer=trainer) - - if self.cfg.pipeline_model_parallel_size > 1 and not self.cfg.megatron_amp_O2: - warnings.warn( - "when using pipeline parallelism, it is recommended to set megatron_amp_O2 to be True to " - "avoid explicit casting for pipeline communication" - ) - self.automatic_optimization = False - self.ref_policy_state_dict = None - - self.ref_policy_kl_penalty = self.cfg.dpo.get("ref_policy_kl_penalty", 0.0) - self.preference_avg_log_probs = self.cfg.dpo.get("preference_average_log_probs", False) - self.sft_avg_log_probs = self.cfg.dpo.get("sft_average_log_probs", self.preference_avg_log_probs) - - self.preference_loss_weight = self.cfg.dpo.get("preference_loss_weight", 1) - self.sft_loss_weight = self.cfg.dpo.get("sft_loss_weight", 0) - assert ( - self.preference_loss_weight != 0 or self.sft_loss_weight != 0 - ), "sft loss weight and preference loss weight cannot both be 0" - - # variants of preference losses, by default DPO. - self.preference_loss = self.cfg.dpo.get("preference_loss", "dpo") - self.gt_reward_scale = self.cfg.dpo.get("gt_reward_scale", 1.0) - - @torch.no_grad() - def gather_and_split_rewards(self, pi_logprobs, ref_logprobs, labels, average_log_probs=False): - pi_logprobs = pi_logprobs.detach() - - dp_group = parallel_state.get_data_parallel_group() - - batch_logs = self.get_reduced_masked_logps( - pi_logprobs - ref_logprobs, labels[:, 1:], average_log_probs=average_log_probs - ) - - output_list = [torch.zeros_like(batch_logs) for _ in range(dp_group.size())] - - torch.distributed.all_gather(output_list, batch_logs, group=dp_group) - - split_iter = map(self.split_output_tensor, output_list) - - out_chosen, out_rejected = map(torch.cat, zip(*split_iter)) - - return out_chosen.flatten(), out_rejected.flatten() - - def get_forward_output_and_loss_func(self, validation_step=False, logprobs_only=False): - def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None): - batch = next(dataloader_iter) - - required_keys = set() - if parallel_state.get_pipeline_model_parallel_world_size() == 1: - required_keys.update(batch.keys()) - else: - # there is a problem with apex ignoring the mask on the older models - # so we will always give the attention mask - required_keys.add("attention_mask") - - if parallel_state.is_pipeline_first_stage(): - required_keys.update(("chosen", "rejected", "position_ids")) - - if parallel_state.is_pipeline_last_stage(): - required_keys.update( - ( - "ref_policy_log_probs_chosen", - "ref_policy_log_probs_rejected", - "chosen_labels", - "rejected_labels", - "chosen_rewards", - "rejected_rewards", - ) - ) - - batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()} - - tokens, labels, ref_logprobs, gt_rewards = None, None, None, None - if batch["chosen"] is not None and batch["rejected"] is not None: - tokens = torch.cat((batch["chosen"], batch["rejected"]), dim=0) - - if batch["chosen_labels"] is not None and batch["rejected_labels"] is not None: - labels = 
torch.cat((batch["chosen_labels"], batch["rejected_labels"]), dim=0) - - if ( - batch.get("ref_policy_log_probs_chosen") is not None - and batch.get("ref_policy_log_probs_rejected") is not None - ): - ref_logprobs = torch.cat( - (batch["ref_policy_log_probs_chosen"], batch["ref_policy_log_probs_rejected"]), dim=0 - ) - - if batch["chosen_rewards"] is not None and batch["rejected_rewards"] is not None: - gt_rewards = torch.cat((batch["chosen_rewards"], batch["rejected_rewards"]), dim=0) - - # this is necessary if MBS > 1 with the new GBS padding logic, as you may get batch dim > 1 in some configs - # these two lines ensure your position_ids and attn_mask are always B=1 - # position_ids = batch["position_ids"][0:1] - attention_mask = batch["attention_mask"][0:1] - - # Model forward pass - forward_args = { - "input_ids": tokens, - "position_ids": batch["position_ids"], - "attention_mask": attention_mask, - "labels": None, - "loss_mask": None, - } - - # TODO: we can remove this someday when we no longer support legacy models - if not self.mcore_gpt: - forward_args["checkpoint_activations_all_layers"] = checkpoint_activations_all_layers - if not self.use_loss_mask: - forward_args.pop("loss_mask") - else: - forward_args.pop("loss_mask") - - output_tensor = model(**forward_args) - - # in this nemo version the model and autocast dtypes are not synced - # so we need to explicitly cast it - if not parallel_state.is_pipeline_last_stage(): - output_tensor = output_tensor.to(dtype=self.autocast_dtype) - - def logprobs_func(output_tensor, non_loss_data=True): - # This function is expected to be used only when `collect_non_loss_data=True` in the fwd_bwd_function of Megatron-LM. - # See https://github.com/NVIDIA/Megatron-LM/blob/0bc3547702464501feefeb5523b7a17e591b21fa/megatron/core/pipeline_parallel/schedules.py#L228 - assert non_loss_data - logprobs = from_parallel_logits_to_logprobs( - vocab_parallel_logits=output_tensor, target=labels, inference_only=True, higher_stability=True, - ) - return {"logprobs": logprobs} - - def loss_func(output_tensor): - if validation_step and not self.cfg.data.get("validation_drop_last", True): - raise NotImplementedError("DPO does not support validation when cfg.data.drop_last=False") - - per_token_logps = from_parallel_logits_to_logprobs( - vocab_parallel_logits=output_tensor, - target=labels, - inference_only=validation_step, - higher_stability=True, - ) - - preference_loss, acc_chosen = self.loss_func( - per_token_logps, - ref_logprobs, - labels[:, 1:], - gt_rewards, - average_log_probs=self.preference_avg_log_probs, - ) - - sft_loss = torch.zeros_like(preference_loss) - if self.sft_loss_weight != 0: - sft_loss = self.sft_loss_func( - per_token_logps, labels[:, 1:], average_log_probs=self.sft_avg_log_probs - ) - loss = self.preference_loss_weight * preference_loss + self.sft_loss_weight * sft_loss - - ( - reduced_loss, - reduced_preference_loss, - reduced_sft_loss, - reduced_acc, - ) = average_losses_across_data_parallel_group([loss, preference_loss, sft_loss, acc_chosen]) - - out_chosen, out_rejected = self.gather_and_split_rewards( - per_token_logps, ref_logprobs, labels, average_log_probs=self.preference_avg_log_probs - ) - - return ( - loss, - { - "avg": reduced_loss, - "avg_sft_loss": reduced_sft_loss, - "avg_preference_loss": reduced_preference_loss, - "acc": reduced_acc, - "out_chosen": out_chosen, - "out_rejected": out_rejected, - }, - ) - - if logprobs_only: - return output_tensor, logprobs_func - else: - return output_tensor, loss_func - - return 
fwd_output_and_loss_func - - def split_output_tensor(self, output_tensor): - chosen_logps, reject_logps = torch.split(output_tensor.float(), len(output_tensor) // 2, dim=0) - return chosen_logps, reject_logps - - def get_reduced_masked_logps(self, logps, labels, average_log_probs=False): - assert logps.shape == labels.shape, "logps and labels shape mismatch" - - loss_mask = (labels > -1).float() - - if average_log_probs: - # need to guard against divide by zero in case labels are all -100 - return (logps * loss_mask).sum(-1) / loss_mask.sum(-1).clamp(min=1) - else: - return (logps * loss_mask).sum(-1) - - def loss_func(self, pi_logprobs, ref_logprobs, labels, gt_rewards, average_log_probs=False): - rewards = self.get_reduced_masked_logps( - pi_logprobs - ref_logprobs, labels, average_log_probs=average_log_probs - ) - chosen_rewards, reject_rewards = self.split_output_tensor(rewards) - rewards_delta = chosen_rewards - reject_rewards - - if self.preference_loss == "dpo": - loss = -torch.nn.functional.logsigmoid(self.ref_policy_kl_penalty * rewards_delta).mean(0) - elif self.preference_loss == "rpo_bwd_kl": - logbeta_hat_chosen = torch.nn.functional.logsigmoid(self.ref_policy_kl_penalty * rewards_delta) - logbeta_hat_rejected = torch.nn.functional.logsigmoid(-self.ref_policy_kl_penalty * rewards_delta) - - chosen_gt_rewards, reject_gt_rewards = self.split_output_tensor(gt_rewards) - gt_rewards_delta = self.gt_reward_scale * (chosen_gt_rewards - reject_gt_rewards) - logalpha_hat_chosen = torch.nn.functional.logsigmoid(gt_rewards_delta) - logalpha_hat_rejected = torch.nn.functional.logsigmoid(-gt_rewards_delta) - - loss = ( - torch.exp(logalpha_hat_chosen) * (logalpha_hat_chosen - logbeta_hat_chosen) - + torch.exp(logalpha_hat_rejected) * (logalpha_hat_rejected - logbeta_hat_rejected) - ).mean(0) - elif self.preference_loss == "rpo_fwd_kl": - logbeta_hat_chosen = torch.nn.functional.logsigmoid(self.ref_policy_kl_penalty * rewards_delta) - logbeta_hat_rejected = torch.nn.functional.logsigmoid(-self.ref_policy_kl_penalty * rewards_delta) - - chosen_gt_rewards, reject_gt_rewards = self.split_output_tensor(gt_rewards) - gt_rewards_delta = self.gt_reward_scale * (chosen_gt_rewards - reject_gt_rewards) - logalpha_hat_chosen = torch.nn.functional.logsigmoid(gt_rewards_delta) - logalpha_hat_rejected = torch.nn.functional.logsigmoid(-gt_rewards_delta) - - loss = ( - torch.exp(logbeta_hat_chosen) * (logbeta_hat_chosen - logalpha_hat_chosen) - + torch.exp(logbeta_hat_rejected) * (logbeta_hat_rejected - logalpha_hat_rejected) - ).mean(0) - elif self.preference_loss == "ipo": - loss = torch.mean((chosen_rewards - reject_rewards - 1.0 / (2.0 * self.ref_policy_kl_penalty)) ** 2, 0) - elif self.preference_loss == "rpo_sq": - chosen_gt_rewards, reject_gt_rewards = self.split_output_tensor(gt_rewards) - gt_rewards_delta = self.gt_reward_scale * (chosen_gt_rewards - reject_gt_rewards) - - loss = torch.mean((self.ref_policy_kl_penalty * rewards_delta - gt_rewards_delta) ** 2, 0) - else: - raise NotImplementedError(f"preference_loss {self.preference_loss} is not implemented") - - with torch.no_grad(): - comp = chosen_rewards > reject_rewards - acc_chosen = comp.float().mean() - - return loss, acc_chosen - - def sft_loss_func(self, pi_logprobs, labels, average_log_probs=False): - logprobs = self.get_reduced_masked_logps(pi_logprobs, labels, average_log_probs=average_log_probs) - chosen_logprobs, _ = self.split_output_tensor(logprobs) - return -chosen_logprobs.mean(0) - - def get_loss_and_metrics(self, batch, 
forward_only): - seq_length = batch["chosen"].shape[1] - - data_iter = get_iterator_k_split(batch, get_num_microbatches()) - set_sync_funcs(self, forward_only) - - fwd_bwd_function = get_forward_backward_func() - - losses_reduced_per_micro_batch = fwd_bwd_function( - forward_step_func=self.get_forward_output_and_loss_func(forward_only, logprobs_only=False), - data_iterator=data_iter, - model=self.model, - num_microbatches=get_num_microbatches(), - forward_only=forward_only, - seq_length=seq_length, - micro_batch_size=self.cfg.micro_batch_size - * 2, # each minibatch has 2 comparisons so tensor shape will be mbs * 2 - ) - - # only the last stages of the pipeline return losses - if losses_reduced_per_micro_batch: - # NOTE: assume that the returned values are already gathered across the DP workers - rewards_chosen = torch.cat([item["out_chosen"] for item in losses_reduced_per_micro_batch]) - rewards_rejected = torch.cat([item["out_rejected"] for item in losses_reduced_per_micro_batch]) - - rewards_all = torch.cat((rewards_chosen, rewards_rejected)) - rewards_chosen_mean = rewards_chosen.mean() - rewards_rejected_mean = rewards_rejected.mean() - rewards_all_mean = rewards_all.mean() - rewards_all_std = rewards_all.std() - - loss_mean = torch.as_tensor( - [loss_reduced["avg"] for loss_reduced in losses_reduced_per_micro_batch], - device=torch.cuda.current_device(), - ).mean() - sft_loss_mean = torch.as_tensor( - [loss_reduced["avg_sft_loss"] for loss_reduced in losses_reduced_per_micro_batch], - device=torch.cuda.current_device(), - ).mean() - preference_loss_mean = torch.as_tensor( - [loss_reduced["avg_preference_loss"] for loss_reduced in losses_reduced_per_micro_batch], - device=torch.cuda.current_device(), - ).mean() - acc_mean = torch.as_tensor( - [loss_reduced["acc"] for loss_reduced in losses_reduced_per_micro_batch], - device=torch.cuda.current_device(), - ).mean() - else: - loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) - sft_loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) - preference_loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) - acc_mean = torch.tensor(0.0, device=torch.cuda.current_device()) - - rewards_chosen_mean = torch.tensor(0.0, device=torch.cuda.current_device()) - rewards_rejected_mean = torch.tensor(0.0, device=torch.cuda.current_device()) - rewards_all_mean = torch.tensor(0.0, device=torch.cuda.current_device()) - rewards_all_std = torch.tensor(0.0, device=torch.cuda.current_device()) - - # we can only log on one rank if it is rank zero so we broadcast from last rank - torch.distributed.broadcast(loss_mean, get_last_rank()) - torch.distributed.broadcast(sft_loss_mean, get_last_rank()) - torch.distributed.broadcast(preference_loss_mean, get_last_rank()) - torch.distributed.broadcast(acc_mean, get_last_rank()) - - torch.distributed.broadcast(rewards_chosen_mean, get_last_rank()) - torch.distributed.broadcast(rewards_rejected_mean, get_last_rank()) - torch.distributed.broadcast(rewards_all_mean, get_last_rank()) - torch.distributed.broadcast(rewards_all_std, get_last_rank()) - - metrics = { - "loss": loss_mean, - "sft_loss": sft_loss_mean, - "preference_loss": preference_loss_mean, - "acc": acc_mean, - "rewards_chosen_mean": rewards_chosen_mean, - "rewards_rejected_mean": rewards_rejected_mean, - "rewards_all_mean": rewards_all_mean, - "rewards_all_std": rewards_all_std, - } - - # move to CPU - metrics = {k: v.item() for k, v in metrics.items()} - - return loss_mean.item(), metrics - - def 
prepare_for_training_step(self): - # custom trainers will always zero grad for us - prepare_for_training_step(self, zero_grad=False) - - def finish_training_step(self): - grad_reductions(self) - - def prepare_for_validation_step(self): - prepare_for_validation_step(self) - - def finish_validation_step(self): - finish_validation_step(self) - - @torch.no_grad() - def get_logprob_batch(self, batch): - seq_length = batch["chosen"].shape[1] - batch_size = batch["chosen"].shape[0] - - num_microbatches = divide(batch_size, self.cfg.dpo.log_prob_forward_micro_batch_size) - data_iter = get_iterator_k_split(batch, num_microbatches) - set_sync_funcs(self, forward_only=True) - - fwd_bwd_function = get_forward_backward_func() - - logprobs_list = fwd_bwd_function( - forward_step_func=self.get_forward_output_and_loss_func(logprobs_only=True), - data_iterator=data_iter, - model=self.model, - num_microbatches=num_microbatches, - forward_only=True, - seq_length=seq_length, - micro_batch_size=self.cfg.dpo.log_prob_forward_micro_batch_size * 2, - collect_non_loss_data=True, - ) - - if len(logprobs_list) > 0: - chosen_logprobs_list = [] - rejected_logprobs_list = [] - for item in logprobs_list: - chosen_logprobs, rejected_logprobs = self.split_output_tensor(item["logprobs"]) - chosen_logprobs_list.append(chosen_logprobs) - rejected_logprobs_list.append(rejected_logprobs) - - logprobs = torch.cat([torch.cat(chosen_logprobs_list), torch.cat(rejected_logprobs_list)], dim=0) - else: - logprobs = None - - if parallel_state.get_pipeline_model_parallel_world_size() > 1: - # broadcast it from last PP stage to everything else - logprobs = broadcast_2d_tensor( - logprobs, - parallel_state.get_pipeline_model_parallel_last_rank(), - parallel_state.get_pipeline_model_parallel_group(), - ) - - return logprobs - - def get_ref_policy_logprobs(self, batch): - - if self.use_peft and self.ref_policy_state_dict is None: - # when using adapters instead of full-tuning, the actor is reference model + adapters - with adapter_control(self): - # With adapters disabled (meaning using the reference model), calculate ref_log_probs - ref_log_probs = self.get_logprob_batch(batch) - else: - with cpu_weight_swap(self, self.ref_policy_state_dict, megatron_amp_O2=self.megatron_amp_O2): - ref_log_probs = self.get_logprob_batch(batch) - - # return in GPU, trainer needs to move to cpu - return ref_log_probs \ No newline at end of file + super().__init__(cfg, trainer=trainer) \ No newline at end of file diff --git a/nemo_aligner/utils/utils.py b/nemo_aligner/utils/utils.py index cf9a85012..97b0f7184 100644 --- a/nemo_aligner/utils/utils.py +++ b/nemo_aligner/utils/utils.py @@ -122,9 +122,7 @@ def load_checkpoint_model_config(restore_path): return OmegaConf.load(cfg_path) with tempfile.TemporaryDirectory() as tmpdir: - # Extracts only model config - members = NLPSaveRestoreConnector._filtered_tar_info(restore_path, filter_fn=lambda name: '.yaml' in name) - NLPSaveRestoreConnector._unpack_nemo_file(restore_path, tmpdir, members=members) + NLPSaveRestoreConnector._unpack_nemo_file(restore_path, tmpdir, extract_config_only=True) cfg = OmegaConf.load(os.path.join(tmpdir, config_name_in_ckpt)) return cfg From b8049cd7ec1118d2a79332be5d820a8e33187c94 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 1 Nov 2024 17:20:41 -0700 Subject: [PATCH 3/7] dpo support Signed-off-by: root --- examples/nlp/gpt/train_gpt_dpo.py | 2 +- nemo_aligner/utils/utils.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/nlp/gpt/train_gpt_dpo.py 
b/examples/nlp/gpt/train_gpt_dpo.py index 473fb89eb..402528f43 100644 --- a/examples/nlp/gpt/train_gpt_dpo.py +++ b/examples/nlp/gpt/train_gpt_dpo.py @@ -53,7 +53,7 @@ def main(cfg) -> None: logger = CustomLoggerWrapper(trainer.loggers) ptl_model = load_from_nemo( - MegatronMambaDPOModel if cfg.mamba_hybrid else MegatronGPTDPOModel, + MegatronMambaDPOModel if cfg.model.mamba_hybrid else MegatronGPTDPOModel, cfg.model, trainer, strict=True, diff --git a/nemo_aligner/utils/utils.py b/nemo_aligner/utils/utils.py index 97b0f7184..fe01cff2a 100644 --- a/nemo_aligner/utils/utils.py +++ b/nemo_aligner/utils/utils.py @@ -122,7 +122,8 @@ def load_checkpoint_model_config(restore_path): return OmegaConf.load(cfg_path) with tempfile.TemporaryDirectory() as tmpdir: - NLPSaveRestoreConnector._unpack_nemo_file(restore_path, tmpdir, extract_config_only=True) + members = NLPSaveRestoreConnector._filtered_tar_info(restore_path, filter_fn=lambda name: '.yaml' in name) + NLPSaveRestoreConnector._unpack_nemo_file(restore_path, tmpdir, members=members) cfg = OmegaConf.load(os.path.join(tmpdir, config_name_in_ckpt)) return cfg From 050e767e8e9f248aaf78951d784e576d4107e008 Mon Sep 17 00:00:00 2001 From: arendu Date: Tue, 5 Nov 2024 01:03:25 +0000 Subject: [PATCH 4/7] mamba padding Signed-off-by: arendu --- nemo_aligner/algorithms/dpo.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/nemo_aligner/algorithms/dpo.py b/nemo_aligner/algorithms/dpo.py index 6b2103328..b12a6c87d 100644 --- a/nemo_aligner/algorithms/dpo.py +++ b/nemo_aligner/algorithms/dpo.py @@ -29,6 +29,15 @@ from nemo_aligner.utils.trainer_utils import check_progress, compute_limit_batches, compute_num_steps_per_epoch from nemo_aligner.utils.utils import clear_memory +def pad_sequence_to_max(sequences, max_len, padding_value=0): + # Then, pad further to match `max_len` + if sequences.size(1) > max_len: + raise RuntimeError("max len has to be > seq len") + elif sequences.size(1) <= max_len: + pad_size = max_len - sequences.size(1) + padding = torch.full((sequences.size(0), pad_size), padding_value) + padded_sequences = torch.cat([sequences, padding], dim=1) + return padded_sequences def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False): chosen_tokens = [item["chosen"] for item in batch] @@ -317,6 +326,15 @@ def augment_dataloader(self, dataloader): while True: try: batch = next(iter_dataloader) + if self.model.cfg.mamba_hybrid: + max_seq_len = max([batch['chosen'].size(-1), batch['rejected'].size(-1), batch['chosen_labels'].size(-1), batch['rejected_labels'].size(-1)]) + max_seq_len = torch.tensor(max_seq_len, device=torch.cuda.current_device()) + torch.distributed.all_reduce(max_seq_len, op=torch.distributed.ReduceOp.MAX) + max_seq_len = ((max_seq_len.item() + 255) // 256) * 256 + batch["chosen"] = pad_sequence_to_max(batch["chosen"], max_seq_len, padding_value=self.model.tokenizer.eos_id) + batch["chosen_labels"] = pad_sequence_to_max(batch["chosen_labels"], max_seq_len, padding_value=-100) + batch["rejected"] = pad_sequence_to_max(batch["rejected"], max_seq_len, padding_value=self.model.tokenizer.eos_id) + batch["rejected_labels"] = pad_sequence_to_max(batch["rejected_labels"], max_seq_len, padding_value=-100) logprobs = self.model.get_ref_policy_logprobs(batch).cpu() chosen_logps, reject_logps = torch.split(logprobs, len(logprobs) // 2, dim=0) batch["ref_policy_log_probs_chosen"] = chosen_logps From 1a4acc95fa6dc707ab0afe8a314a21f68e366883 Mon Sep 17 00:00:00 2001 
From: adithyare Date: Wed, 13 Nov 2024 17:40:30 -0800 Subject: [PATCH 5/7] convenience script to remove old format of DPO data Signed-off-by: adithyare --- .../data/nlp/scripts/undo_special_tokens.py | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 nemo_aligner/data/nlp/scripts/undo_special_tokens.py diff --git a/nemo_aligner/data/nlp/scripts/undo_special_tokens.py b/nemo_aligner/data/nlp/scripts/undo_special_tokens.py new file mode 100644 index 000000000..d47f0d272 --- /dev/null +++ b/nemo_aligner/data/nlp/scripts/undo_special_tokens.py @@ -0,0 +1,70 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Script to remove special tokens from dpo datasets +and convert them into list of messages format""" + +import json +import re +import sys +input_jsonl = sys.argv[1] +output_jsonl = input_jsonl.replace(".jsonl", "no_special_toks.jsonl") + +def format_conversation(input_string): + # Define roles and patterns + role_patterns = { + "System": "system", + "User": "user", + "Assistant": "assistant" + } + + # Initialize an empty output list + conversation = [] + + # Use regex to find each segment's role and content + segments = re.findall(r"([^\n]+)\n(.*?)((?=")[0] + rejected = rejected.split("\n")[0] + chosen_message = {"role": empty_assistant["role"], "content": chosen} + rejected_message = {"role": empty_assistant["role"], "content": rejected} + j_out = {"prompt": undo_spl_prompt, "chosen_response": chosen_message, "rejected_response": rejected_message, "chosen_reward": j["chosen_reward"], "rejected_reward": j["rejected_reward"]} + w.write(json.dumps(j_out) + "\n") + +old_format_example="""System\n\nUser\nDoes GPT-4 use RNN or Transformer models, and which one is better for this type of project?\nAssistant\nGPT-4 uses a transformer architecture, not a recurrent neural network. Both models are commonly used for natural language processing tasks, and both have advantages and disadvantages, so it depends on the specific application of the model.\nUser\nCould you explain in detail both the advantages and disadvantages from different perspectives?\nAssistant\nYes, here are a few different perspectives on the pros and cons of RNN and transformer models:\n \n\n The RNN model has better sequential memory and can better model sequences of data, whereas the transformer model has better parallel computation and can better model large input datasets. This makes RNNs better for applications like speech recognition and machine translation, where the input is a sequence of sounds or words, and transformers better for applications like natural language processing and text classification, where the input is a large collection of documents or social media posts.\n \n\n RNNs are more straightforward and easier to understand and debug, whereas transformers are more flexible and can be used for more diverse applications. 
This makes RNNs better for teaching and research, and transformers better for actual applications in industry and the real world.\n \n\n RNNs are more efficient and faster to run, whereas transformers are more accurate and produce better results. This makes RNNs better for fast prototyping and testing, and transformers better for actual deployment in production applications.\nUser\ncould you add more in a table\nAssistant\nHere are some more pros and cons of RNN and transformer models:\n \n\n Pros of RNNs:\n \n\n - Better sequential memory and handling of sequences\n - Simpler and easier to understand\n - Faster to run\n - Good for teaching and research\n \n\n Cons of RNNs:\n \n\n - Limited parallelism\n - Poor handling of large input datasets\n - Not as accurate as transformers\n - Not as flexible as transformers\n \n\n Pros of transformers:\n \n\n - Better parallelism and handling of large input datasets\n - More flexible and versatile\n - Higher accuracy\n - Good for actual applications in industry and the real world\n \n\n Cons of transformers:\n \n\n - More complex and harder to understand and debug\n - Slower to run\n - Not as good for teaching and research\nUser\nFor a chat-based music recommendation platform, which model is recommended? Please mention any model that works well, including those mentioned above or any others.\nAssistant\nFor a chat-based music recommendation platform, I would recommend using a transformer architecture, such as the GPT-4 model. Transformers are better for handling large input datasets, and they are more flexible and versatile than RNNs. They are also more accurate, so they are better for actually recommending music to users in a real-world application. RNNs are better for teaching and research, as they are simpler and easier to understand and debug.\nUser\ncould you add more?\nAssistant\n""" + +print("****************************************") +print("old format prompt") +print("****************************************") +print(old_format_example) +print("****************************************") +print("new format prompt") +print("****************************************") +new_format_example, _ = format_conversation(old_format_example) +print(json.dumps(new_format_example, indent=2)) From 93eea807e17f1dc0c40798d96b08618ba2a40fb6 Mon Sep 17 00:00:00 2001 From: arendu Date: Thu, 14 Nov 2024 04:38:30 +0000 Subject: [PATCH 6/7] pad to mult 256 Signed-off-by: arendu --- examples/nlp/gpt/train_gpt_sft.py | 2 ++ nemo_aligner/data/nlp/builders.py | 3 ++- nemo_aligner/data/nlp/scripts/undo_special_tokens.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/nlp/gpt/train_gpt_sft.py b/examples/nlp/gpt/train_gpt_sft.py index 271f7cec8..8272f238b 100644 --- a/examples/nlp/gpt/train_gpt_sft.py +++ b/examples/nlp/gpt/train_gpt_sft.py @@ -170,6 +170,7 @@ def main(cfg) -> None: train_data_cfg, ptl_model.tokenizer, num_samples, + is_mamba=cfg.model.mamba_hybrid, answer_only_loss=True, is_chat=cfg.model.data.chat, special_tokens=cfg.model.data.chat_prompt_tokens, @@ -182,6 +183,7 @@ def main(cfg) -> None: val_data_cfg, ptl_model.tokenizer, num_samples, + is_mamba=cfg.model.mamba_hybrid, answer_only_loss=True, is_chat=cfg.model.data.chat, special_tokens=cfg.model.data.chat_prompt_tokens, diff --git a/nemo_aligner/data/nlp/builders.py b/nemo_aligner/data/nlp/builders.py index a61fb46f9..43d9231ef 100644 --- a/nemo_aligner/data/nlp/builders.py +++ b/nemo_aligner/data/nlp/builders.py @@ -266,7 +266,7 @@ def build_dataset(index, name): 
build_train_valid_test_regression_rm_datasets = partial(build_train_valid_test_datasets, RegressionRewardModelDataset) -def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, is_chat=True, special_tokens=None): +def build_sft_dataset(data_cfg, tokenizer, num_samples, is_mamba, answer_only_loss=True, is_chat=True, special_tokens=None): packed_sequence = data_cfg.get("packed_sequence", False) dataset_kwargs = {} @@ -298,6 +298,7 @@ def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, i answer_only_loss=answer_only_loss, truncation_field=data_cfg.get("truncation_field", "text"), pad_to_max_length=data_cfg.get("pad_to_max_length", False), + pad_seq_length_to_mult=256 if is_mamba else 16, index_mapping_dir=data_cfg.get("index_mapping_dir", None), prompt_template=data_cfg.get("prompt_template", None), virtual_tokens=0, diff --git a/nemo_aligner/data/nlp/scripts/undo_special_tokens.py b/nemo_aligner/data/nlp/scripts/undo_special_tokens.py index d47f0d272..3b06f9c8a 100644 --- a/nemo_aligner/data/nlp/scripts/undo_special_tokens.py +++ b/nemo_aligner/data/nlp/scripts/undo_special_tokens.py @@ -19,7 +19,7 @@ import re import sys input_jsonl = sys.argv[1] -output_jsonl = input_jsonl.replace(".jsonl", "no_special_toks.jsonl") +output_jsonl = input_jsonl.replace(".jsonl", ".no_special_toks.jsonl") def format_conversation(input_string): # Define roles and patterns From 5721741c73d5aebd120c285a4664562f6a6051cd Mon Sep 17 00:00:00 2001 From: arendu Date: Thu, 14 Nov 2024 18:35:27 +0000 Subject: [PATCH 7/7] copy dpo style cfg overrides Signed-off-by: arendu --- examples/nlp/gpt/train_gpt_sft.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/nlp/gpt/train_gpt_sft.py b/examples/nlp/gpt/train_gpt_sft.py index 8272f238b..064b6c4ad 100644 --- a/examples/nlp/gpt/train_gpt_sft.py +++ b/examples/nlp/gpt/train_gpt_sft.py @@ -39,8 +39,7 @@ resolve_and_create_trainer, retrieve_custom_trainer_state_dict, ) -from nemo_aligner.utils.utils import load_from_nemo - +from nemo_aligner.utils.utils import load_and_override_model_config, load_from_nemo """Script to start SFT training""" OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True) @@ -115,6 +114,7 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False): @hydra_runner(config_path="conf", config_name="gpt_sft") def main(cfg) -> None: + cfg.model = load_and_override_model_config(cfg.model.restore_from_path, cfg.model) logging.info("\n\n************** Experiment configuration ***********") logging.info(f"\n{OmegaConf.to_yaml(cfg)}") @@ -126,17 +126,15 @@ def main(cfg) -> None: with open_dict(cfg): cfg.model.precision = cfg.trainer.precision - ptl_model, updated_cfg = load_from_nemo( + ptl_model = load_from_nemo( MambaSFTModel if cfg.model.mamba_hybrid else GPTSFTModel, cfg, trainer, strict=True, - modify_config_fn=_modify_config, restore_path=cfg.model.restore_from_path, - return_updated_cfg=True, ) - init_peft(ptl_model, updated_cfg) + init_peft(ptl_model, cfg.model) with open_dict(cfg): # overwrite the model config with the config from the checkpoint
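
For reference, the sequence-length handling that patch 4 ("mamba padding") adds in nemo_aligner/algorithms/dpo.py can be exercised in isolation. The sketch below is not part of the patch series: it mirrors the pad_sequence_to_max helper and the round-up-to-a-multiple-of-256 step from augment_dataloader, but round_up_to_multiple and the __main__ demo are illustrative names added here, and the torch.distributed.all_reduce of the max length across data-parallel ranks is replaced by a plain local max so the example runs standalone on CPU.

# Standalone sketch of the Mamba-path padding scheme (assumptions noted above).
import torch


def pad_sequence_to_max(sequences: torch.Tensor, max_len: int, padding_value: int = 0) -> torch.Tensor:
    """Right-pad a (batch, seq_len) tensor with `padding_value` up to `max_len`."""
    if sequences.size(1) > max_len:
        raise RuntimeError("max_len has to be >= the current sequence length")
    pad_size = max_len - sequences.size(1)
    if pad_size == 0:
        return sequences
    padding = torch.full((sequences.size(0), pad_size), padding_value, dtype=sequences.dtype)
    return torch.cat([sequences, padding], dim=1)


def round_up_to_multiple(n: int, multiple: int = 256) -> int:
    """Round `n` up to the next multiple of `multiple` (256 for the Mamba hybrid path)."""
    return ((n + multiple - 1) // multiple) * multiple


if __name__ == "__main__":
    eos_id, label_pad = 0, -100  # stand-ins for tokenizer.eos_id and the label ignore index
    chosen = torch.randint(1, 100, (2, 300))
    rejected = torch.randint(1, 100, (2, 287))

    # In the patch, max_seq_len is additionally all-reduced (MAX) across data-parallel ranks.
    max_seq_len = round_up_to_multiple(max(chosen.size(-1), rejected.size(-1)))
    chosen_padded = pad_sequence_to_max(chosen, max_seq_len, padding_value=eos_id)
    rejected_padded = pad_sequence_to_max(rejected, max_seq_len, padding_value=eos_id)
    print(chosen_padded.shape, rejected_padded.shape)  # both torch.Size([2, 512])

The same multiple-of-256 constraint is what the SFT side of the series expresses through pad_seq_length_to_mult=256 when cfg.model.mamba_hybrid is set in build_sft_dataset.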