Skip to content

Commit

Permalink
feat: nemotron-5 features
Browse files Browse the repository at this point in the history
wip

Signed-off-by: arendu <[email protected]>

docs: 0.5.0 documentation updates (#346)

Signed-off-by: ashors1 <[email protected]>

ci: Sign-off cherry pick (#366)

Signed-off-by: Oliver Koenig <[email protected]>

docs: main readme and sft docs (#367)

Signed-off-by: Oleksii Kuchaiev <[email protected]>
Co-authored-by: Gerald Shen <[email protected]>

docs: fix code block rendering (#369)

Signed-off-by: ashors1 <[email protected]>

dpo and sft

Signed-off-by: arendu <[email protected]>

dpo support

Signed-off-by: root <[email protected]>

mamba padding

Signed-off-by: arendu <[email protected]>

convenience script to remove old format of DPO data

Signed-off-by: adithyare <[email protected]>

pad to mult 256

Signed-off-by: arendu <[email protected]>

copy dpo style cfg overrides

Signed-off-by: arendu <[email protected]>

remove _modify_config

Signed-off-by: arendu <[email protected]>

fix config issue

Signed-off-by: Jiaqi Zeng <[email protected]>

fix mamba config issue

Signed-off-by: Jiaqi Zeng <[email protected]>

is mamba default false

Signed-off-by: arendu <[email protected]>

revert cherry-pick-release-commit

Signed-off-by: Terry Kong <[email protected]>

Revert "revert cherry-pick-release-commit"

This reverts commit 911337c.

undo .github/workflows

Signed-off-by: Terry Kong <[email protected]>

revert docs changes that weren't supposed to be there

Signed-off-by: Terry Kong <[email protected]>
  • Loading branch information
arendu authored and terrykong committed Nov 21, 2024
1 parent bd590d6 commit ddd829b
Show file tree
Hide file tree
Showing 10 changed files with 118 additions and 77 deletions.
1 change: 1 addition & 0 deletions examples/nlp/gpt/conf/gpt_dpo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ model:
micro_batch_size: 1
global_batch_size: 64
megatron_amp_O2: True
mamba_hybrid: False

dpo:
# This default value ensures there are no numeric differences beween trained and reference policies when computing log probs.
Expand Down
2 changes: 1 addition & 1 deletion examples/nlp/gpt/conf/gpt_sft.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ model:
output_original_text: True # needed for the proper metrics support

optim:
name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
name: fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
lr: 3e-5
weight_decay: 0.01
betas:
Expand Down
4 changes: 2 additions & 2 deletions examples/nlp/gpt/train_gpt_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from nemo.utils.exp_manager import exp_manager
from nemo_aligner.algorithms.dpo import DPOTrainer, dpo_custom_collate
from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets
from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel
from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel, MegatronMambaDPOModel
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
CustomLoggerWrapper,
Expand Down Expand Up @@ -53,7 +53,7 @@ def main(cfg) -> None:
logger = CustomLoggerWrapper(trainer.loggers)

ptl_model = load_from_nemo(
MegatronGPTDPOModel,
MegatronMambaDPOModel if cfg.model.mamba_hybrid else MegatronGPTDPOModel,
cfg.model,
trainer,
strict=True,
Expand Down
79 changes: 8 additions & 71 deletions examples/nlp/gpt/train_gpt_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from nemo.utils.exp_manager import exp_manager
from nemo_aligner.algorithms.supervised import SupervisedTrainer
from nemo_aligner.data.nlp.builders import build_dataloader, build_sft_dataset
from nemo_aligner.models.nlp.gpt.gpt_sft_model import GPTSFTModel
from nemo_aligner.models.nlp.gpt.gpt_sft_model import GPTSFTModel, MambaSFTModel
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
CustomLoggerWrapper,
Expand All @@ -39,7 +39,7 @@
resolve_and_create_trainer,
retrieve_custom_trainer_state_dict,
)
from nemo_aligner.utils.utils import load_from_nemo
from nemo_aligner.utils.utils import load_and_override_model_config, load_from_nemo

"""Script to start SFT training"""

Expand All @@ -49,72 +49,9 @@
mp.set_start_method("spawn", force=True)


def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
"""
This function modifies the original gpt pre-training config (gpt_cfg) with attributes from the finetuning config (cfg).
The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`.
"""
OmegaConf.set_struct(gpt_cfg, True)
OmegaConf.resolve(cfg)
with open_dict(gpt_cfg):
gpt_cfg.megatron_amp_O2 = cfg.model.get("megatron_amp_O2", False)
gpt_cfg.micro_batch_size = cfg.model.data.train_ds.micro_batch_size
gpt_cfg.global_batch_size = cfg.model.data.train_ds.global_batch_size
gpt_cfg.sequence_parallel = cfg.model.get("sequence_parallel", False)
gpt_cfg.activations_checkpoint_granularity = cfg.model.get("activations_checkpoint_granularity", None)
gpt_cfg.activations_checkpoint_num_layers = cfg.model.get("activations_checkpoint_num_layers", None)
gpt_cfg.activations_checkpoint_method = cfg.model.get("activations_checkpoint_method", None)
gpt_cfg.activations_checkpoint_layers_per_pipeline = cfg.model.get(
"activations_checkpoint_layers_per_pipeline", None
)
gpt_cfg.peft = cfg.model.peft
gpt_cfg.data = cfg.model.data
gpt_cfg.optim = cfg.model.optim
gpt_cfg.precision = cfg.trainer.precision
gpt_cfg.answer_only_loss = cfg.model.answer_only_loss
gpt_cfg.restore_from_path = cfg.model.restore_from_path
gpt_cfg.resume_from_checkpoint = cfg.model.resume_from_checkpoint
gpt_cfg.save_nemo_on_validation_end = cfg.model.save_nemo_on_validation_end
gpt_cfg.gradient_as_bucket_view = cfg.model.gradient_as_bucket_view
gpt_cfg.hidden_dropout = cfg.model.get("hidden_dropout", 0.0)
gpt_cfg.attention_dropout = cfg.model.get("attention_dropout", 0.0)
gpt_cfg.ffn_dropout = cfg.model.ffn_dropout
gpt_cfg.use_flash_attention = cfg.model.get("use_flash_attention", False)
# if TP/PP size is -1, use default TP/PP size as original model
if cfg.model.get("tensor_model_parallel_size", 1) > 0:
gpt_cfg.tensor_model_parallel_size = cfg.model.get("tensor_model_parallel_size", 1)
if cfg.model.get("pipeline_model_parallel_size", 1) > 0:
gpt_cfg.pipeline_model_parallel_size = cfg.model.get("pipeline_model_parallel_size", 1)
gpt_cfg.pipeline_model_parallel_split_rank = cfg.model.get("pipeline_model_parallel_split_rank", 0)

if cfg.model.data.get("chat", False):
# chat model, overwrite the prompt template
prompt_template = get_prompt_template_example(cfg.model.data.chat_prompt_tokens)
gpt_cfg.data.train_ds.prompt_template = prompt_template
gpt_cfg.data.validation_ds.prompt_template = prompt_template

sft_cls = GPTSFTModel
gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}"

if cfg.model.get("use_flash_attention", None) is not None:
gpt_cfg.use_flash_attention = cfg.model.use_flash_attention

if cfg.model.get("seq_len_interpolation_factor", None) is not None:
gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor

gpt_cfg.inference = cfg.model.get("inference", {})

# This is needed when modifying a hparam file directly to load `.ckpt` files.
# This is not needed to modify the cfg in `.nemo` files.
if add_cfg_to_tree:
OmegaConf.resolve(gpt_cfg)
gpt_cfg.cfg = gpt_cfg

return gpt_cfg


@hydra_runner(config_path="conf", config_name="gpt_sft")
def main(cfg) -> None:
cfg.model = load_and_override_model_config(cfg.model.restore_from_path, cfg.model)
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f"\n{OmegaConf.to_yaml(cfg)}")

Expand All @@ -126,17 +63,15 @@ def main(cfg) -> None:
with open_dict(cfg):
cfg.model.precision = cfg.trainer.precision

ptl_model, updated_cfg = load_from_nemo(
GPTSFTModel,
ptl_model = load_from_nemo(
MambaSFTModel if cfg.model.get("mamba_hybrid", False) else GPTSFTModel,
cfg,
trainer,
strict=True,
modify_config_fn=_modify_config,
restore_path=cfg.model.restore_from_path,
return_updated_cfg=True,
)

init_peft(ptl_model, updated_cfg)
init_peft(ptl_model, cfg.model)

with open_dict(cfg):
# overwrite the model config with the config from the checkpoint
Expand Down Expand Up @@ -170,6 +105,7 @@ def main(cfg) -> None:
train_data_cfg,
ptl_model.tokenizer,
num_samples,
is_mamba=cfg.model.get("mamba_hybrid", False),
answer_only_loss=True,
is_chat=cfg.model.data.chat,
special_tokens=cfg.model.data.chat_prompt_tokens,
Expand All @@ -182,6 +118,7 @@ def main(cfg) -> None:
val_data_cfg,
ptl_model.tokenizer,
num_samples,
is_mamba=cfg.model.get("mamba_hybrid", False),
answer_only_loss=True,
is_chat=cfg.model.data.chat,
special_tokens=cfg.model.data.chat_prompt_tokens,
Expand Down
18 changes: 18 additions & 0 deletions nemo_aligner/algorithms/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@
from nemo_aligner.utils.trainer_utils import check_progress, compute_limit_batches, compute_num_steps_per_epoch
from nemo_aligner.utils.utils import clear_memory

def pad_sequence_to_max(sequences, max_len, padding_value=0):
# Then, pad further to match `max_len`
if sequences.size(1) > max_len:
raise RuntimeError("max len has to be > seq len")
elif sequences.size(1) <= max_len:
pad_size = max_len - sequences.size(1)
padding = torch.full((sequences.size(0), pad_size), padding_value)
padded_sequences = torch.cat([sequences, padding], dim=1)
return padded_sequences

def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False):
chosen_tokens = [item["chosen"] for item in batch]
Expand Down Expand Up @@ -317,6 +326,15 @@ def augment_dataloader(self, dataloader):
while True:
try:
batch = next(iter_dataloader)
if self.model.cfg.mamba_hybrid:
max_seq_len = max([batch['chosen'].size(-1), batch['rejected'].size(-1), batch['chosen_labels'].size(-1), batch['rejected_labels'].size(-1)])
max_seq_len = torch.tensor(max_seq_len, device=torch.cuda.current_device())
torch.distributed.all_reduce(max_seq_len, op=torch.distributed.ReduceOp.MAX)
max_seq_len = ((max_seq_len.item() + 255) // 256) * 256
batch["chosen"] = pad_sequence_to_max(batch["chosen"], max_seq_len, padding_value=self.model.tokenizer.eos_id)
batch["chosen_labels"] = pad_sequence_to_max(batch["chosen_labels"], max_seq_len, padding_value=-100)
batch["rejected"] = pad_sequence_to_max(batch["rejected"], max_seq_len, padding_value=self.model.tokenizer.eos_id)
batch["rejected_labels"] = pad_sequence_to_max(batch["rejected_labels"], max_seq_len, padding_value=-100)
logprobs = self.model.get_ref_policy_logprobs(batch).cpu()
chosen_logps, reject_logps = torch.split(logprobs, len(logprobs) // 2, dim=0)
batch["ref_policy_log_probs_chosen"] = chosen_logps
Expand Down
3 changes: 2 additions & 1 deletion nemo_aligner/data/nlp/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def build_dataset(index, name):
build_train_valid_test_regression_rm_datasets = partial(build_train_valid_test_datasets, RegressionRewardModelDataset)


def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, is_chat=True, special_tokens=None):
def build_sft_dataset(data_cfg, tokenizer, num_samples, is_mamba=False, answer_only_loss=True, is_chat=True, special_tokens=None):
packed_sequence = data_cfg.get("packed_sequence", False)
dataset_kwargs = {}

Expand Down Expand Up @@ -298,6 +298,7 @@ def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, i
answer_only_loss=answer_only_loss,
truncation_field=data_cfg.get("truncation_field", "text"),
pad_to_max_length=data_cfg.get("pad_to_max_length", False),
pad_seq_length_to_mult=256 if is_mamba else 16,
index_mapping_dir=data_cfg.get("index_mapping_dir", None),
prompt_template=data_cfg.get("prompt_template", None),
virtual_tokens=0,
Expand Down
70 changes: 70 additions & 0 deletions nemo_aligner/data/nlp/scripts/undo_special_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Script to remove special tokens from dpo datasets
and convert them into list of messages format"""

import json
import re
import sys
input_jsonl = sys.argv[1]
output_jsonl = input_jsonl.replace(".jsonl", ".no_special_toks.jsonl")

def format_conversation(input_string):
# Define roles and patterns
role_patterns = {
"<extra_id_0>System": "system",
"<extra_id_1>User": "user",
"<extra_id_1>Assistant": "assistant"
}

# Initialize an empty output list
conversation = []

# Use regex to find each segment's role and content
segments = re.findall(r"(<extra_id_[0-1]>[^\n]+)\n(.*?)((?=<extra_id_)|$)", input_string, re.DOTALL)

for segment in segments:
role_tag, content, _ = segment
role = role_patterns.get(role_tag.strip(), "unknown")
conversation.append({"role": role, "content": content.strip()})

empty_asst = conversation.pop()

return conversation, empty_asst

with open(input_jsonl, "r") as f, open(output_jsonl, "w") as w:
for line in f:
j = json.loads(line)
prompt = j["prompt"]
undo_spl_prompt, empty_assistant = format_conversation(prompt)
chosen, rejected = j["chosen_response"], j["rejected_response"]
chosen = chosen.split("\n<extra_id_1>")[0]
rejected = rejected.split("\n<extra_id_1>")[0]
chosen_message = {"role": empty_assistant["role"], "content": chosen}
rejected_message = {"role": empty_assistant["role"], "content": rejected}
j_out = {"prompt": undo_spl_prompt, "chosen_response": chosen_message, "rejected_response": rejected_message, "chosen_reward": j["chosen_reward"], "rejected_reward": j["rejected_reward"]}
w.write(json.dumps(j_out) + "\n")

old_format_example="""<extra_id_0>System\n\n<extra_id_1>User\nDoes GPT-4 use RNN or Transformer models, and which one is better for this type of project?\n<extra_id_1>Assistant\nGPT-4 uses a transformer architecture, not a recurrent neural network. Both models are commonly used for natural language processing tasks, and both have advantages and disadvantages, so it depends on the specific application of the model.\n<extra_id_1>User\nCould you explain in detail both the advantages and disadvantages from different perspectives?\n<extra_id_1>Assistant\nYes, here are a few different perspectives on the pros and cons of RNN and transformer models:\n \n\n The RNN model has better sequential memory and can better model sequences of data, whereas the transformer model has better parallel computation and can better model large input datasets. This makes RNNs better for applications like speech recognition and machine translation, where the input is a sequence of sounds or words, and transformers better for applications like natural language processing and text classification, where the input is a large collection of documents or social media posts.\n \n\n RNNs are more straightforward and easier to understand and debug, whereas transformers are more flexible and can be used for more diverse applications. This makes RNNs better for teaching and research, and transformers better for actual applications in industry and the real world.\n \n\n RNNs are more efficient and faster to run, whereas transformers are more accurate and produce better results. This makes RNNs better for fast prototyping and testing, and transformers better for actual deployment in production applications.\n<extra_id_1>User\ncould you add more in a table\n<extra_id_1>Assistant\nHere are some more pros and cons of RNN and transformer models:\n \n\n Pros of RNNs:\n \n\n - Better sequential memory and handling of sequences\n - Simpler and easier to understand\n - Faster to run\n - Good for teaching and research\n \n\n Cons of RNNs:\n \n\n - Limited parallelism\n - Poor handling of large input datasets\n - Not as accurate as transformers\n - Not as flexible as transformers\n \n\n Pros of transformers:\n \n\n - Better parallelism and handling of large input datasets\n - More flexible and versatile\n - Higher accuracy\n - Good for actual applications in industry and the real world\n \n\n Cons of transformers:\n \n\n - More complex and harder to understand and debug\n - Slower to run\n - Not as good for teaching and research\n<extra_id_1>User\nFor a chat-based music recommendation platform, which model is recommended? Please mention any model that works well, including those mentioned above or any others.\n<extra_id_1>Assistant\nFor a chat-based music recommendation platform, I would recommend using a transformer architecture, such as the GPT-4 model. Transformers are better for handling large input datasets, and they are more flexible and versatile than RNNs. They are also more accurate, so they are better for actually recommending music to users in a real-world application. RNNs are better for teaching and research, as they are simpler and easier to understand and debug.\n<extra_id_1>User\ncould you add more?\n<extra_id_1>Assistant\n"""

print("****************************************")
print("old format prompt")
print("****************************************")
print(old_format_example)
print("****************************************")
print("new format prompt")
print("****************************************")
new_format_example, _ = format_conversation(old_format_example)
print(json.dumps(new_format_example, indent=2))
6 changes: 6 additions & 0 deletions nemo_aligner/models/nlp/gpt/gpt_sft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from pytorch_lightning.trainer.trainer import Trainer

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel
from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split
from nemo.collections.nlp.modules.common.text_generation_strategy import TextGenerationStrategy
from nemo.collections.nlp.modules.common.text_generation_utils import (
Expand Down Expand Up @@ -225,3 +226,8 @@ def finish_inference(self):
self._restore_activation_checkpointing_args()
self._restore_sequence_parallelism_args()
set_train(self)


class MambaSFTModel(MegatronMambaModel, GPTSFTModel):
def __init__(self, cfg: DictConfig, trainer: Trainer):
super().__init__(cfg, trainer=trainer)
7 changes: 7 additions & 0 deletions nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
from functools import partial

import torch
from megatron.core import parallel_state
from megatron.core.models.mamba import MambaModel
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
from megatron.core.utils import divide
from omegaconf.dictconfig import DictConfig
from pytorch_lightning.trainer.trainer import Trainer

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel
from nemo.collections.nlp.modules.common.megatron.utils import (
average_losses_across_data_parallel_group,
get_iterator_k_split,
Expand Down Expand Up @@ -460,3 +463,7 @@ def get_ref_policy_logprobs(self, batch):

# return in GPU, trainer needs to move to cpu
return ref_log_probs

class MegatronMambaDPOModel(MegatronMambaModel, MegatronGPTDPOModel): # @adithyare inherence order matters
def __init__(self, cfg: DictConfig, trainer: Trainer):
super().__init__(cfg, trainer=trainer)
5 changes: 3 additions & 2 deletions nemo_aligner/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

import torch
from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedTensorFactory
from megatron.core.num_microbatches_calculator import reconfigure_microbatch_calculator
from megatron.core.num_microbatches_calculator import reconfigure_num_microbatches_calculator as reconfigure_microbatch_calculator
from omegaconf import DictConfig, OmegaConf
from torch.masked import as_masked_tensor

Expand Down Expand Up @@ -122,7 +122,8 @@ def load_checkpoint_model_config(restore_path):
return OmegaConf.load(cfg_path)

with tempfile.TemporaryDirectory() as tmpdir:
NLPSaveRestoreConnector._unpack_nemo_file(restore_path, tmpdir, extract_config_only=True)
members = NLPSaveRestoreConnector._filtered_tar_info(restore_path, filter_fn=lambda name: '.yaml' in name)
NLPSaveRestoreConnector._unpack_nemo_file(restore_path, tmpdir, members=members)
cfg = OmegaConf.load(os.path.join(tmpdir, config_name_in_ckpt))

return cfg
Expand Down

0 comments on commit ddd829b

Please sign in to comment.