RPO on multiple responses #311

Draft · wants to merge 20 commits into base branch main
140 changes: 140 additions & 0 deletions examples/nlp/gpt/conf/gpt_rpo.yaml
@@ -0,0 +1,140 @@
defaults:
  - optional tp_overlap@model.ub_tp_comm_overlap_cfg:

trainer:
num_nodes: 1
devices: 8
accelerator: gpu
precision: bf16

# rpo specific args
rpo:
max_epochs: 1
max_steps: -1
val_check_interval: 100
save_interval: 100
limit_train_batches: 1.0

    # how many global batches (GBS) to loop over during validation
    limit_val_batches: 10
gradient_clip_val: 1.0
    num_responses: 4 # number of candidate responses per prompt used by the RPO loss

# do not change these
logger: False # logger provided by exp_manager
enable_checkpointing: False
use_distributed_sampler: False
max_time: null
max_epochs: ${.rpo.max_epochs}
max_steps: ${.rpo.max_steps}

exp_manager:
explicit_log_dir: /results
exp_dir: null
name: megatron_gpt
max_time_per_run: ${trainer.max_time}
create_wandb_logger: False
wandb_logger_kwargs:
project: nemo_aligner_rpo
name: rlhf_gpt3_rpo
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
resume_if_exists: True
resume_ignore_no_checkpoint: True
create_checkpoint_callback: True
checkpoint_callback_params:
monitor: val_loss
save_top_k: 5
mode: min
always_save_nemo: False # saves nemo file during validation, not implemented for model parallel
save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits
filename: "megatron_gpt--{${.monitor}:.3f}-{step}-{consumed_samples}-{epoch}"
model_parallel_size: 4

pretrained_checkpoint:
restore_from_path: /models/llama3_8b_sft_alpha_nodes8_tp4_3e-6_bs384_rerun_1200.nemo

model:
mcore_gpt: True
micro_batch_size: 4
global_batch_size: 8
megatron_amp_O2: True

rpo:
    # This default value ensures there are no numeric differences between trained and reference policies when computing log probs.
    # A higher value can be used to speed up log prob computations, but may cause numeric differences.
    log_prob_forward_micro_batch_size: ${multiply:${model.micro_batch_size}, ${trainer.rpo.num_responses}}
    preference_average_log_probs: False # whether to normalize log probs by sequence length in preference_loss
    sft_average_log_probs: ${.preference_average_log_probs} # whether to normalize log probs by sequence length in sft_loss
gt_reward_scale: 1. # the scale of the rewards in RPO
    preference_loss: rpo # the preference loss; supported options: dpo, ipo, rpo_sq, rpo_bwd_kl, rpo_fwd_kl
preference_loss_weight: 1 # the coefficient of the preference loss
sft_loss_weight: 0.05 # the coefficient of the SFT loss
beta: 0.2
eta: 0.2
num_responses: ${trainer.rpo.num_responses}

#encoder_seq_length: 4096
#max_position_embeddings: ${model.encoder_seq_length}

# miscellaneous
seed: 1234

#peft
peft:
peft_scheme: "none" # ["lora", "none"]
restore_from_path: null
restore_from_ckpt:
checkpoint_dir: null
checkpoint_name: null

lora_tuning:
target_modules: ["attention_qkv"] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', 'attention' (qkv & dense), 'mlp' (fc1 & fc2), 'all'
adapter_dim: 32
adapter_dropout: 0.0
column_init_method: "xavier" # IGNORED if linear_adapter is used, options: xavier, zero or normal
row_init_method: "zero" # IGNORED if linear_adapter is used, options: xavier, zero or normal
layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers
weight_tying: False
position_embedding_strategy: null # used only when weight_tying is True

optim:
name: distributed_fused_adam
bucket_cap_mb: 200
overlap_grad_sync: False
contiguous_grad_buffer: True
lr: 9e-6
weight_decay: 0.1
betas:
- 0.9
- 0.98
sched:
name: CosineAnnealing
warmup_steps: 10
constant_steps: 1000
min_lr: 9e-7

data:
data_impl: jsonl
splits_string: null
seq_length: ${model.encoder_seq_length}
skip_warmup: True
num_workers: 0
dataloader_type: single # cyclic
reset_position_ids: False # Reset position ids after end-of-document token
reset_attention_mask: False # Reset attention mask after end-of-document token
eod_mask_loss: False # Mask loss for the end of document tokens
index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix
data_prefix:
train:
- /data/responses_general.rm.formatted.4resp.jsonl
test:
- /data/rpo_test_set.jsonl
validation:
- /data/rpo_test_set.jsonl

# define fields from the base model's config that should be ignored when merging with this config.
overwrite_base_config:
data:
data_prefix: True

precision: ${trainer.precision}
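
Note on the derived batch size above: log_prob_forward_micro_batch_size uses the custom multiply resolver that train_gpt_rpo.py (the second file in this PR) registers with OmegaConf. Below is a minimal stand-alone sketch of how that interpolation resolves with the values in this config (micro_batch_size 4, num_responses 4); it is an illustration only, not code from the PR.

from omegaconf import OmegaConf

# Same resolver registration as in train_gpt_rpo.py.
OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True)

cfg = OmegaConf.create(
    {
        "trainer": {"rpo": {"num_responses": 4}},
        "model": {
            "micro_batch_size": 4,
            "rpo": {
                # Both arguments must be wrapped in ${...}; otherwise OmegaConf passes
                # the literal string "trainer.rpo.num_responses" to the resolver.
                "log_prob_forward_micro_batch_size": "${multiply:${model.micro_batch_size}, ${trainer.rpo.num_responses}}"
            },
        },
    }
)

print(cfg.model.rpo.log_prob_forward_micro_batch_size)  # 16

Because the script is driven by @hydra_runner, these values can also be overridden on the command line (for example model.micro_batch_size=2 trainer.rpo.num_responses=8), and the derived log_prob_forward_micro_batch_size is recomputed from the overridden values.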
162 changes: 162 additions & 0 deletions examples/nlp/gpt/train_gpt_rpo.py
@@ -0,0 +1,162 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf

from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
from nemo_aligner.algorithms.rpo import RPOTrainer, rpo_custom_collate
from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_rpo_datasets
from nemo_aligner.models.nlp.gpt.megatron_gpt_rpo_model import MegatronGPTRPOModel
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
CustomLoggerWrapper,
add_custom_checkpoint_callback,
extract_optimizer_scheduler_from_ptl_model,
init_distributed,
init_peft,
init_using_ptl,
resolve_and_create_trainer,
retrieve_custom_trainer_state_dict,
)
from nemo_aligner.utils.utils import load_and_override_model_config, load_from_nemo, retrieve_model_state_dict_in_cpu

OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True)
OmegaConf.register_new_resolver("int_div", lambda x, y: x // y, replace=True)

mp.set_start_method("spawn", force=True)


@hydra_runner(config_path="conf", config_name="gpt_rpo")
def main(cfg) -> None:
cfg.model = load_and_override_model_config(cfg.pretrained_checkpoint.restore_from_path, cfg.model)

logging.info("\n\n************** Experiment configuration ***********")
logging.info(f"\n{OmegaConf.to_yaml(cfg)}")

trainer = resolve_and_create_trainer(cfg, "rpo")
exp_manager(trainer, cfg.exp_manager)
logger = CustomLoggerWrapper(trainer.loggers)

ptl_model = load_from_nemo(
MegatronGPTRPOModel,
cfg.model,
trainer,
strict=True,
load_base_model_only=False,
restore_path=cfg.pretrained_checkpoint.restore_from_path,
)

init_peft(ptl_model, cfg.model)

if cfg.model.peft.peft_scheme == "none":
ref_policy_state_dict = retrieve_model_state_dict_in_cpu(
ptl_model, megatron_amp_O2=cfg.model.get("megatron_amp_O2", False)
)
ptl_model.ref_policy_state_dict = ref_policy_state_dict

# pull values from checkpoint
trainer_restore_path = trainer.ckpt_path

# TODO: log this restore path
if trainer_restore_path is not None:
custom_trainer_state_dict = retrieve_custom_trainer_state_dict(trainer)
consumed_samples = custom_trainer_state_dict["consumed_samples"]
else:
custom_trainer_state_dict = None
consumed_samples = 0

init_distributed(trainer, ptl_model, cfg.model.get("transformer_engine", False))

# use the entire dataset
train_valid_test_num_samples = [-1 * cfg.model.global_batch_size] * 3

train_ds, validation_ds, test_ds = build_train_valid_test_rpo_datasets(
cfg=cfg.model,
data_prefix=cfg.model.data.data_prefix,
data_impl=cfg.model.data.data_impl,
splits_string=cfg.model.data.splits_string,
train_valid_test_num_samples=train_valid_test_num_samples,
seq_length=cfg.model.data.seq_length,
seed=cfg.model.seed,
tokenizer=ptl_model.tokenizer,
)

train_dataloader = build_dataloader(
cfg=cfg,
dataset=train_ds,
consumed_samples=consumed_samples,
mbs=cfg.model.micro_batch_size,
gbs=cfg.model.global_batch_size,
load_gbs=True,
pad_samples_to_global_batch_size=False,
collate_fn=partial(
rpo_custom_collate,
eos_id=ptl_model.tokenizer.eos_id,
reset_position_ids=cfg.model.data.get("reset_position_ids", False),
reset_attention_mask=cfg.model.data.get("reset_attention_mask", False),
eod_mask_loss=cfg.model.data.get("eod_mask_loss", False),
),
)

val_dataloader = build_dataloader(
cfg=cfg,
dataset=validation_ds,
consumed_samples=0,
mbs=cfg.model.micro_batch_size,
gbs=cfg.model.global_batch_size,
load_gbs=True,
pad_samples_to_global_batch_size=False,
collate_fn=partial(
rpo_custom_collate,
eos_id=ptl_model.tokenizer.eos_id,
reset_position_ids=cfg.model.data.get("reset_position_ids", False),
reset_attention_mask=cfg.model.data.get("reset_attention_mask", False),
eod_mask_loss=cfg.model.data.get("eod_mask_loss", False),
),
use_random_sampler=False,
)

init_using_ptl(trainer, ptl_model, train_dataloader, train_ds)
optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model)

ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model)

logger.log_hyperparams(OmegaConf.to_container(cfg))

timer = Timer(cfg.exp_manager.get("max_time_per_run"))
    rpo_trainer = RPOTrainer(
cfg=cfg.trainer.rpo,
model=ptl_model,
optimizer=optimizer,
scheduler=scheduler,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
test_dataloader=None,
logger=logger,
ckpt_callback=ckpt_callback,
run_timer=timer,
)

    if custom_trainer_state_dict is not None:
        rpo_trainer.load_state_dict(custom_trainer_state_dict)

    rpo_trainer.fit()


if __name__ == "__main__":
main()
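
When peft_scheme is "none", the script snapshots the starting weights on CPU via retrieve_model_state_dict_in_cpu and attaches them to the model as ref_policy_state_dict, so the frozen reference policy used by the preference loss never drifts while training updates the live weights. A minimal sketch of that pattern with a toy torch module follows; it illustrates the idea only and is not the NeMo-Aligner implementation.

import torch
import torch.nn as nn

policy = nn.Linear(8, 8)

# Detach and copy every tensor in the state dict to CPU so later optimizer
# steps on `policy` cannot change the weights used for reference log probs.
ref_policy_state_dict = {
    name: tensor.detach().to("cpu").clone()
    for name, tensor in policy.state_dict().items()
}

# The live policy keeps changing; the reference snapshot stays fixed.
with torch.no_grad():
    policy.weight.add_(1.0)

assert not torch.equal(policy.weight.cpu(), ref_policy_state_dict["weight"])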