Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Reward Model training #1246

Merged
merged 17 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions megatron/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ def load_checkpoint(
load_lr_scheduler_states=load_optim_and_scheduler,
load_module_only=not load_optim_and_scheduler,
tag=tag,
load_module_strict=neox_args.train_impl != "rm",
)

if checkpoint_name is None:
Expand Down
54 changes: 35 additions & 19 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ def __init__(
is_last_layer=False,
):
super().__init__()
parallelism = neox_args.output_layer_parallelism
self.is_rm = neox_args.train_impl == "rm"
parallelism = neox_args.output_layer_parallelism if not self.is_rm else "row"
if parallelism == "column":
self.final_linear = mpu.ColumnParallelLinear(
neox_args=neox_args,
Expand All @@ -212,26 +213,41 @@ def __init__(
mup_rescale_parameters=is_last_layer, # rescale params only called if neox_args.use_mup = True, despite it not being included here
seq_dim=1, # important: must mark that this layer receives shape [b, s, h] not [s, b, h] and so Seq. Parallel comms must gather along dim=1 rather than dim=0
)

# else:
# print(
# 'ERROR: Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905). Please run with output_layer_parallelism = "column" until this issue is fixed.'
# )
# exit()
# self.final_linear = mpu.RowParallelLinear(
# neox_args=neox_args,
# input_size=neox_args.hidden_size,
# output_size=neox_args.padded_vocab_size,
# bias=False,
# input_is_parallel=False,
# init_method=init_method,
# parallel_output=parallel_output,
# skip_bias_add=False,
# mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here
# )
else:
if not self.is_rm:
print(
'ERROR: Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905). Please run with output_layer_parallelism = "column" until this issue is fixed.'
)
exit()
# self.final_linear = mpu.RowParallelLinear(
# neox_args=neox_args,
# input_size=neox_args.hidden_size,
# output_size=neox_args.padded_vocab_size,
# bias=False,
# input_is_parallel=False,
# init_method=init_method,
# parallel_output=parallel_output,
# skip_bias_add=False,
# mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here
# )
else: # Not using cross entropy loss for RMs
self.rm_linear = mpu.RowParallelLinear(
neox_args=neox_args,
input_size=neox_args.hidden_size,
output_size=1,
bias=False,
input_is_parallel=False,
init_method=init_method,
parallel_output=False,
skip_bias_add=False,
mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here
)

def forward(self, hidden_states):
return self.final_linear(hidden_states)
if not self.is_rm:
return self.final_linear(hidden_states)
else:
return self.rm_linear(hidden_states)


class _MegablocksAdapter(nn.Module):
Expand Down
11 changes: 9 additions & 2 deletions megatron/neox_arguments/neox_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,9 +997,9 @@ class NeoXArgsTraining(NeoXArgsTemplate):
Dataset implementation, can be one of "gpt2" or "pairwise"
"""

train_impl: Literal["normal", "dpo"] = "normal"
train_impl: Literal["normal", "dpo", "rm"] = "normal"
"""
Training implementation, can be one of "normal" or "dpo"
Training implementation, can be one of "normal", "dpo", or "rm"
"""

dpo_fp32: bool = True
Expand All @@ -1012,6 +1012,13 @@ class NeoXArgsTraining(NeoXArgsTemplate):
Beta value for DPO
"""

z_loss: float = 0.0
"""
Z-loss coefficient (weight applied to the squared reward values); currently only implemented for RM training.
https://arxiv.org/pdf/2204.02311
https://arxiv.org/pdf/2309.10305
"""

allow_chopped: bool = True
"""
WARNING: if your packing impl is packed, this is ignored.
Expand Down
36 changes: 31 additions & 5 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def get_batch(neox_args, data_iterator):
# Items and their type.
if neox_args.train_impl == "normal":
keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]
elif neox_args.train_impl == "dpo":
elif neox_args.train_impl in ["dpo", "rm"]:
keys = (
[["pos", "pos_label"], ["neg", "neg_label"]]
if neox_args.pos_train_label_data_paths
Expand All @@ -338,7 +338,7 @@ def get_batch(neox_args, data_iterator):
data=data,
datatype=datatype,
)
elif neox_args.train_impl == "dpo":
elif neox_args.train_impl in ["dpo", "rm"]:
pos_tup = _get_batch(
neox_args=neox_args,
tokenizer=neox_args.tokenizer,
Expand All @@ -353,7 +353,7 @@ def get_batch(neox_args, data_iterator):
data=data,
datatype=datatype,
)
if neox_args.precompute_model_name:
if (neox_args.precompute_model_name) and (neox_args.train_impl == "dpo"):
ref_data = mpu.broadcast_data(["pos_ref", "neg_ref"], data, torch.float)
else:
ref_data = {"pos_ref": None}
Expand Down Expand Up @@ -491,7 +491,7 @@ def forward_step(
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
neox_args=neox_args, data_iterator=data_iterator
)
if neox_args.train_impl == "dpo":
if neox_args.train_impl in ["dpo", "rm"]:
tokens, labels, loss_mask, attention_mask, position_ids, ref_logp = get_batch(
neox_args=neox_args, data_iterator=data_iterator
)
Expand Down Expand Up @@ -532,6 +532,32 @@ def forward_step(
else:
moe_loss = 0.0
loss = main_loss + moe_loss
elif neox_args.train_impl == "rm":
maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args)
if type(maybe_tuple) is tuple:
outputs, _ = maybe_tuple
else:
outputs = maybe_tuple
pos, neg = torch.chunk(outputs, 2, 0)
pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0)
# We assume that the pos/neg pairs occur in the same order,
# e.g. the second nonzero pos corresponds to the second nonzero neg,
# and that each sequence contains an equal number of pos and neg entries.
pos_indx = pos_loss_mask.nonzero()
neg_indx = neg_loss_mask.nonzero()
# indx[:, 0] is the batch index and indx[:, 1] is the token index; we only care about the token index.
pos_indx = pos_indx[:, 1].unsqueeze(1)
neg_indx = neg_indx[:, 1].unsqueeze(1)
pos = torch.gather(pos.squeeze(), dim=1, index=pos_indx)
neg = torch.gather(neg.squeeze(), dim=1, index=neg_indx)
with torch.no_grad():
metrics["pos_values"] = pos.clone().detach().mean()
metrics["neg_values"] = neg.clone().detach().mean()
metrics["margin"] = (pos - neg).clone().detach().mean()
metrics["accuracy"] = ((pos - neg) > 0).clone().detach().float().mean()
loss = (-F.logsigmoid(pos - neg).mean()) + (
(neox_args.z_loss * (pos**2 + neg**2)).mean()
)
elif neox_args.train_impl == "dpo":
# Based on https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
with torch.no_grad():
Expand Down Expand Up @@ -616,7 +642,7 @@ def get_model(neox_args, use_cache=False):
model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
parallel_output=True,
parallel_output=True if neox_args.train_impl != "rm" else False,
topology=mpu.get_topology(),
use_cache=use_cache,
)
Expand Down
25 changes: 24 additions & 1 deletion tools/datasets/preprocess_data_with_chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def build_chat(
apply_mask: bool,
tokenizer: PreTrainedTokenizer,
only_last_turn: bool = False,
for_rm: bool = False,
) -> Tuple[List[int], List[int]]:
"""
Build a chat from a list of dictionaries. Each dictionary should have a "role" and "content" key, this follows the
Expand All @@ -91,12 +92,28 @@ def build_chat(
:param apply_mask: Whether to apply a loss mask to the chat, if False, all tokens will be included in the loss
:param tokenizer: A HF tokenizer
:param only_last_turn: Whether to only include the last turn in the chat, needed for some fine-tuning tasks
:param for_rm: Whether the data is for reward model training; if so, every token except the appended EOS token is masked (ignored when apply_mask is False).
If you need a more complicated setup, you can modify this function to suit your needs.
"""
tokens = []
mask = []
if apply_mask is False:
tokens = tokenizer.apply_chat_template(chat)
mask = tokens
if tokenizer.eos_token_id is not None:
mask.append(tokenizer.eos_token_id)
tokens.append(tokenizer.eos_token_id)
return tokens, mask
elif for_rm:
tokens = tokenizer.apply_chat_template(chat)
mask = [-100] * len(tokens)
if tokenizer.eos_token_id is not None:
mask.append(tokenizer.eos_token_id)
tokens.append(tokenizer.eos_token_id)
else:
raise ValueError(
"Tokenizer does not have an EOS token, unable to determine good mask, please edit and make your own."
)
return tokens, mask
for i, turn in enumerate(chat):
add_gen = (
Expand All @@ -105,7 +122,7 @@ def build_chat(
chat_tokens = tokenizer.apply_chat_template(
chat[: i + 1], add_generation_prompt=add_gen
)[len(tokens) :]

# remove previous stuff...
tokens.extend(chat_tokens)
if only_last_turn and (i != len(chat) - 1):
Expand Down Expand Up @@ -137,6 +154,7 @@ def encode(self, text):
not self.args.no_mask,
Encoder.tokenizer,
self.args.only_last,
self.args.for_rm,
)
ids[key] = (text_ids, label_ids)
return ids, len(text)
Expand All @@ -163,6 +181,11 @@ def get_args():
help="If set, this will not mask any tokens in the input data.",
action="store_true",
)
group.add_argument(
"--for-rm",
help="If set, this will mask everything except the last token in the chat.",
action="store_true",
)
group.add_argument(
"--generation-role",
type=str,
Expand Down
Loading