diff --git a/examples/nlp/gpt/conf/gpt_dpo.yaml b/examples/nlp/gpt/conf/gpt_dpo.yaml index 2a165bf9d..4aa67ec38 100644 --- a/examples/nlp/gpt/conf/gpt_dpo.yaml +++ b/examples/nlp/gpt/conf/gpt_dpo.yaml @@ -57,6 +57,7 @@ model: micro_batch_size: 1 global_batch_size: 64 megatron_amp_O2: True + mamba_hybrid: False dpo: # This default value ensures there are no numeric differences beween trained and reference policies when computing log probs. diff --git a/examples/nlp/gpt/conf/gpt_sft.yaml b/examples/nlp/gpt/conf/gpt_sft.yaml index bdd757f31..745f6ae01 100644 --- a/examples/nlp/gpt/conf/gpt_sft.yaml +++ b/examples/nlp/gpt/conf/gpt_sft.yaml @@ -191,7 +191,7 @@ model: output_original_text: True # needed for the proper metrics support optim: - name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work. + name: fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work. lr: 3e-5 weight_decay: 0.01 betas: diff --git a/examples/nlp/gpt/train_gpt_dpo.py b/examples/nlp/gpt/train_gpt_dpo.py index aefa0c5ac..402528f43 100644 --- a/examples/nlp/gpt/train_gpt_dpo.py +++ b/examples/nlp/gpt/train_gpt_dpo.py @@ -21,7 +21,7 @@ from nemo.utils.exp_manager import exp_manager from nemo_aligner.algorithms.dpo import DPOTrainer, dpo_custom_collate from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets -from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel +from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel, MegatronMambaDPOModel from nemo_aligner.utils.distributed import Timer from nemo_aligner.utils.train_script_utils import ( CustomLoggerWrapper, @@ -53,7 +53,7 @@ def main(cfg) -> None: logger = CustomLoggerWrapper(trainer.loggers) ptl_model = load_from_nemo( - MegatronGPTDPOModel, + MegatronMambaDPOModel if cfg.model.mamba_hybrid else MegatronGPTDPOModel, cfg.model, trainer, strict=True, diff --git a/examples/nlp/gpt/train_gpt_sft.py b/examples/nlp/gpt/train_gpt_sft.py index f52445637..064b6c4ad 100644 --- a/examples/nlp/gpt/train_gpt_sft.py +++ b/examples/nlp/gpt/train_gpt_sft.py @@ -27,7 +27,7 @@ from nemo.utils.exp_manager import exp_manager from nemo_aligner.algorithms.supervised import SupervisedTrainer from nemo_aligner.data.nlp.builders import build_dataloader, build_sft_dataset -from nemo_aligner.models.nlp.gpt.gpt_sft_model import GPTSFTModel +from nemo_aligner.models.nlp.gpt.gpt_sft_model import GPTSFTModel, MambaSFTModel from nemo_aligner.utils.distributed import Timer from nemo_aligner.utils.train_script_utils import ( CustomLoggerWrapper, @@ -39,8 +39,7 @@ resolve_and_create_trainer, retrieve_custom_trainer_state_dict, ) -from nemo_aligner.utils.utils import load_from_nemo - +from nemo_aligner.utils.utils import load_and_override_model_config, load_from_nemo """Script to start SFT training""" OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True) @@ -115,6 +114,7 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False): @hydra_runner(config_path="conf", config_name="gpt_sft") def main(cfg) -> None: + cfg.model = load_and_override_model_config(cfg.model.restore_from_path, cfg.model) logging.info("\n\n************** Experiment configuration ***********") logging.info(f"\n{OmegaConf.to_yaml(cfg)}") @@ -126,17 +126,15 @@ def main(cfg) -> None: with open_dict(cfg): cfg.model.precision = cfg.trainer.precision - ptl_model, updated_cfg = load_from_nemo( - GPTSFTModel, + ptl_model = load_from_nemo( + MambaSFTModel if cfg.model.mamba_hybrid else GPTSFTModel, cfg, trainer, strict=True, - modify_config_fn=_modify_config, restore_path=cfg.model.restore_from_path, - return_updated_cfg=True, ) - init_peft(ptl_model, updated_cfg) + init_peft(ptl_model, cfg.model) with open_dict(cfg): # overwrite the model config with the config from the checkpoint @@ -170,6 +168,7 @@ def main(cfg) -> None: train_data_cfg, ptl_model.tokenizer, num_samples, + is_mamba=cfg.model.mamba_hybrid, answer_only_loss=True, is_chat=cfg.model.data.chat, special_tokens=cfg.model.data.chat_prompt_tokens, @@ -182,6 +181,7 @@ def main(cfg) -> None: val_data_cfg, ptl_model.tokenizer, num_samples, + is_mamba=cfg.model.mamba_hybrid, answer_only_loss=True, is_chat=cfg.model.data.chat, special_tokens=cfg.model.data.chat_prompt_tokens, diff --git a/nemo_aligner/algorithms/dpo.py b/nemo_aligner/algorithms/dpo.py index 6b2103328..b12a6c87d 100644 --- a/nemo_aligner/algorithms/dpo.py +++ b/nemo_aligner/algorithms/dpo.py @@ -29,6 +29,15 @@ from nemo_aligner.utils.trainer_utils import check_progress, compute_limit_batches, compute_num_steps_per_epoch from nemo_aligner.utils.utils import clear_memory +def pad_sequence_to_max(sequences, max_len, padding_value=0): + # Then, pad further to match `max_len` + if sequences.size(1) > max_len: + raise RuntimeError("max len has to be > seq len") + elif sequences.size(1) <= max_len: + pad_size = max_len - sequences.size(1) + padding = torch.full((sequences.size(0), pad_size), padding_value) + padded_sequences = torch.cat([sequences, padding], dim=1) + return padded_sequences def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False): chosen_tokens = [item["chosen"] for item in batch] @@ -317,6 +326,15 @@ def augment_dataloader(self, dataloader): while True: try: batch = next(iter_dataloader) + if self.model.cfg.mamba_hybrid: + max_seq_len = max([batch['chosen'].size(-1), batch['rejected'].size(-1), batch['chosen_labels'].size(-1), batch['rejected_labels'].size(-1)]) + max_seq_len = torch.tensor(max_seq_len, device=torch.cuda.current_device()) + torch.distributed.all_reduce(max_seq_len, op=torch.distributed.ReduceOp.MAX) + max_seq_len = ((max_seq_len.item() + 255) // 256) * 256 + batch["chosen"] = pad_sequence_to_max(batch["chosen"], max_seq_len, padding_value=self.model.tokenizer.eos_id) + batch["chosen_labels"] = pad_sequence_to_max(batch["chosen_labels"], max_seq_len, padding_value=-100) + batch["rejected"] = pad_sequence_to_max(batch["rejected"], max_seq_len, padding_value=self.model.tokenizer.eos_id) + batch["rejected_labels"] = pad_sequence_to_max(batch["rejected_labels"], max_seq_len, padding_value=-100) logprobs = self.model.get_ref_policy_logprobs(batch).cpu() chosen_logps, reject_logps = torch.split(logprobs, len(logprobs) // 2, dim=0) batch["ref_policy_log_probs_chosen"] = chosen_logps diff --git a/nemo_aligner/data/nlp/builders.py b/nemo_aligner/data/nlp/builders.py index a61fb46f9..43d9231ef 100644 --- a/nemo_aligner/data/nlp/builders.py +++ b/nemo_aligner/data/nlp/builders.py @@ -266,7 +266,7 @@ def build_dataset(index, name): build_train_valid_test_regression_rm_datasets = partial(build_train_valid_test_datasets, RegressionRewardModelDataset) -def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, is_chat=True, special_tokens=None): +def build_sft_dataset(data_cfg, tokenizer, num_samples, is_mamba, answer_only_loss=True, is_chat=True, special_tokens=None): packed_sequence = data_cfg.get("packed_sequence", False) dataset_kwargs = {} @@ -298,6 +298,7 @@ def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, i answer_only_loss=answer_only_loss, truncation_field=data_cfg.get("truncation_field", "text"), pad_to_max_length=data_cfg.get("pad_to_max_length", False), + pad_seq_length_to_mult=256 if is_mamba else 16, index_mapping_dir=data_cfg.get("index_mapping_dir", None), prompt_template=data_cfg.get("prompt_template", None), virtual_tokens=0, diff --git a/nemo_aligner/data/nlp/scripts/undo_special_tokens.py b/nemo_aligner/data/nlp/scripts/undo_special_tokens.py new file mode 100644 index 000000000..3b06f9c8a --- /dev/null +++ b/nemo_aligner/data/nlp/scripts/undo_special_tokens.py @@ -0,0 +1,70 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Script to remove special tokens from dpo datasets +and convert them into list of messages format""" + +import json +import re +import sys +input_jsonl = sys.argv[1] +output_jsonl = input_jsonl.replace(".jsonl", ".no_special_toks.jsonl") + +def format_conversation(input_string): + # Define roles and patterns + role_patterns = { + "System": "system", + "User": "user", + "Assistant": "assistant" + } + + # Initialize an empty output list + conversation = [] + + # Use regex to find each segment's role and content + segments = re.findall(r"([^\n]+)\n(.*?)((?=")[0] + rejected = rejected.split("\n")[0] + chosen_message = {"role": empty_assistant["role"], "content": chosen} + rejected_message = {"role": empty_assistant["role"], "content": rejected} + j_out = {"prompt": undo_spl_prompt, "chosen_response": chosen_message, "rejected_response": rejected_message, "chosen_reward": j["chosen_reward"], "rejected_reward": j["rejected_reward"]} + w.write(json.dumps(j_out) + "\n") + +old_format_example="""System\n\nUser\nDoes GPT-4 use RNN or Transformer models, and which one is better for this type of project?\nAssistant\nGPT-4 uses a transformer architecture, not a recurrent neural network. Both models are commonly used for natural language processing tasks, and both have advantages and disadvantages, so it depends on the specific application of the model.\nUser\nCould you explain in detail both the advantages and disadvantages from different perspectives?\nAssistant\nYes, here are a few different perspectives on the pros and cons of RNN and transformer models:\n \n\n The RNN model has better sequential memory and can better model sequences of data, whereas the transformer model has better parallel computation and can better model large input datasets. This makes RNNs better for applications like speech recognition and machine translation, where the input is a sequence of sounds or words, and transformers better for applications like natural language processing and text classification, where the input is a large collection of documents or social media posts.\n \n\n RNNs are more straightforward and easier to understand and debug, whereas transformers are more flexible and can be used for more diverse applications. This makes RNNs better for teaching and research, and transformers better for actual applications in industry and the real world.\n \n\n RNNs are more efficient and faster to run, whereas transformers are more accurate and produce better results. This makes RNNs better for fast prototyping and testing, and transformers better for actual deployment in production applications.\nUser\ncould you add more in a table\nAssistant\nHere are some more pros and cons of RNN and transformer models:\n \n\n Pros of RNNs:\n \n\n - Better sequential memory and handling of sequences\n - Simpler and easier to understand\n - Faster to run\n - Good for teaching and research\n \n\n Cons of RNNs:\n \n\n - Limited parallelism\n - Poor handling of large input datasets\n - Not as accurate as transformers\n - Not as flexible as transformers\n \n\n Pros of transformers:\n \n\n - Better parallelism and handling of large input datasets\n - More flexible and versatile\n - Higher accuracy\n - Good for actual applications in industry and the real world\n \n\n Cons of transformers:\n \n\n - More complex and harder to understand and debug\n - Slower to run\n - Not as good for teaching and research\nUser\nFor a chat-based music recommendation platform, which model is recommended? Please mention any model that works well, including those mentioned above or any others.\nAssistant\nFor a chat-based music recommendation platform, I would recommend using a transformer architecture, such as the GPT-4 model. Transformers are better for handling large input datasets, and they are more flexible and versatile than RNNs. They are also more accurate, so they are better for actually recommending music to users in a real-world application. RNNs are better for teaching and research, as they are simpler and easier to understand and debug.\nUser\ncould you add more?\nAssistant\n""" + +print("****************************************") +print("old format prompt") +print("****************************************") +print(old_format_example) +print("****************************************") +print("new format prompt") +print("****************************************") +new_format_example, _ = format_conversation(old_format_example) +print(json.dumps(new_format_example, indent=2)) diff --git a/nemo_aligner/models/nlp/gpt/gpt_sft_model.py b/nemo_aligner/models/nlp/gpt/gpt_sft_model.py index d3a615500..15bc69c00 100644 --- a/nemo_aligner/models/nlp/gpt/gpt_sft_model.py +++ b/nemo_aligner/models/nlp/gpt/gpt_sft_model.py @@ -22,6 +22,7 @@ from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split from nemo.collections.nlp.modules.common.text_generation_strategy import TextGenerationStrategy from nemo.collections.nlp.modules.common.text_generation_utils import ( @@ -225,3 +226,8 @@ def finish_inference(self): self._restore_activation_checkpointing_args() self._restore_sequence_parallelism_args() set_train(self) + + +class MambaSFTModel(MegatronMambaModel, GPTSFTModel): + def __init__(self, cfg: DictConfig, trainer: Trainer): + super().__init__(cfg, trainer=trainer) \ No newline at end of file diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py index 952b4e897..d7e69d7ef 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py @@ -16,6 +16,8 @@ from functools import partial import torch +from megatron.core import parallel_state +from megatron.core.models.mamba import MambaModel from megatron.core.num_microbatches_calculator import get_num_microbatches from megatron.core.pipeline_parallel.schedules import get_forward_backward_func from megatron.core.utils import divide @@ -23,6 +25,7 @@ from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel from nemo.collections.nlp.modules.common.megatron.utils import ( average_losses_across_data_parallel_group, get_iterator_k_split, @@ -460,3 +463,7 @@ def get_ref_policy_logprobs(self, batch): # return in GPU, trainer needs to move to cpu return ref_log_probs + +class MegatronMambaDPOModel(MegatronMambaModel, MegatronGPTDPOModel): # @adithyare inherence order matters + def __init__(self, cfg: DictConfig, trainer: Trainer): + super().__init__(cfg, trainer=trainer) \ No newline at end of file diff --git a/nemo_aligner/utils/utils.py b/nemo_aligner/utils/utils.py index e29bc28e6..fe01cff2a 100644 --- a/nemo_aligner/utils/utils.py +++ b/nemo_aligner/utils/utils.py @@ -28,7 +28,7 @@ import torch from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensorFactory -from megatron.core.num_microbatches_calculator import reconfigure_microbatch_calculator +from megatron.core.num_microbatches_calculator import reconfigure_num_microbatches_calculator as reconfigure_microbatch_calculator from omegaconf import DictConfig, OmegaConf from torch.masked import as_masked_tensor @@ -122,7 +122,8 @@ def load_checkpoint_model_config(restore_path): return OmegaConf.load(cfg_path) with tempfile.TemporaryDirectory() as tmpdir: - NLPSaveRestoreConnector._unpack_nemo_file(restore_path, tmpdir, extract_config_only=True) + members = NLPSaveRestoreConnector._filtered_tar_info(restore_path, filter_fn=lambda name: '.yaml' in name) + NLPSaveRestoreConnector._unpack_nemo_file(restore_path, tmpdir, members=members) cfg = OmegaConf.load(os.path.join(tmpdir, config_name_in_ckpt)) return cfg