Add KTO training #1244

Merged · 37 commits · Sep 14, 2024

Commits
a950f8b
Add a chat data preprocessing script
dmahan93 Jun 21, 2024
e360e24
add EOT at end of a chat
dmahan93 Jun 21, 2024
9ee4a8f
- add different packing impl (Unpacked, packing until overflow)
dmahan93 Jun 21, 2024
0678573
update README.md
dmahan93 Jun 21, 2024
15e3059
Merge remote-tracking branch 'origin/add-chat-template-based-datasets…
dmahan93 Jun 24, 2024
2d20d86
- Add metrics to forward step to add DPO specific metrics that are us…
dmahan93 Jun 25, 2024
c045006
Update arguments.py to use train_label_data_paths instead of label_da…
dmahan93 Jun 25, 2024
eed3643
Merge remote-tracking branch 'origin/add-different-packing-impl' into…
dmahan93 Jun 25, 2024
0392080
- Bugfixes from upstreaming....
dmahan93 Jun 25, 2024
361f459
- add precompute logprobs...
dmahan93 Jun 25, 2024
7398e07
- Finishing up precompute logprobs...
dmahan93 Jun 26, 2024
51af714
- update readme for DPO...
dmahan93 Jun 26, 2024
b7bc196
- Add KTO
dmahan93 Jun 28, 2024
0cdcf2b
Update README.md
dmahan93 Jun 28, 2024
ba29aef
Merge branch 'main' into add-kto
Quentin-Anthony Sep 9, 2024
daab1ab
precommit
Quentin-Anthony Sep 9, 2024
0760116
- KTO implementation from main...
dmahan93 Sep 9, 2024
c83ff44
Merge branch 'add-kto' of https://github.com/dmahan93/gpt-neox into HEAD
dmahan93 Sep 9, 2024
f9ead88
Merge branch 'add-kto' of https://github.com/dmahan93/gpt-neox into HEAD
dmahan93 Sep 9, 2024
b6f9d5c
initial changes...
dmahan93 Sep 9, 2024
38aef24
pre-commit...
dmahan93 Sep 9, 2024
303ba80
hotfix + data loader update
dmahan93 Sep 9, 2024
9b33c7b
merge in main
dmahan93 Sep 9, 2024
a126566
Merge branch 'main' into add-kto
Quentin-Anthony Sep 10, 2024
792570e
- add different packing impl (Unpacked, packing until overflow)
dmahan93 Jun 21, 2024
4eca43f
- Add metrics to forward step to add DPO specific metrics that are us…
dmahan93 Jun 25, 2024
243b716
- Bugfixes from upstreaming....
dmahan93 Jun 25, 2024
3e966b0
- Add KTO
dmahan93 Jun 28, 2024
fadf3e6
Update README.md
dmahan93 Jun 28, 2024
b4c5a91
precommit
Quentin-Anthony Sep 9, 2024
d701aea
- KTO implementation from main...
dmahan93 Sep 9, 2024
9daa1bc
initial changes...
dmahan93 Sep 9, 2024
486840a
pre-commit...
dmahan93 Sep 9, 2024
f39ccbc
hotfix + data loader update
dmahan93 Sep 9, 2024
c730c6d
Merge branch 'add-kto' of https://github.com/dmahan93/gpt-neox into a…
Quentin-Anthony Sep 11, 2024
42123b5
Merge branch 'main' into add-kto
Quentin-Anthony Sep 14, 2024
cd6b2f2
precommit
Quentin-Anthony Sep 14, 2024
27 changes: 26 additions & 1 deletion configs/README.md
@@ -235,7 +235,32 @@ Additional DeepSpeed settings besides those mentioned above should be wrapped in
"eval_iters": 10,
```

However, if you want to use DPO style training you'll need to set pos/neg data paths instead of a single one, e.g.
For KTO style training, you'll need to add the reward & label data paths, e.g.:

```yaml
"data_impl": "mmap",
# Suggested data paths when using GPT-NeoX locally
"train_data_path": "data/enwik8/enwik8_text_document",
"train_label_data_path": "data/enwik8/enwik8_text_label_document",
"train_reward_data_path": "data/enwik8/enwik8_text_reward_document",
"test_data_path": "data/enwik8/enwik8_text_document",
"test_label_data_path": "data/enwik8/enwik8_text_label_document",
"test_reward_data_path": "data/enwik8/enwik8_text_reward_document",
"valid_data_path": "data/enwik8/enwik8_text_document",
"valid_label_data_path": "data/enwik8/enwik8_text_label_document",
"valid_reward_data_path": "data/enwik8/enwik8_text_reward_document",
"vocab_file": "data/gpt2-vocab.json",
"merge_file": "data/gpt2-merges.txt",
"save": "checkpoints",
"load": "checkpoints",
"tensorboard_dir": "tensorboard",
"log_dir": "logs",
"checkpoint_factor": 10000,
"eval_interval": 1000,
"eval_iters": 10,
```

For DPO style training, you'll need to set pos/neg data paths instead of a single one, e.g.

```yaml
"dataset_impl": "pairwise",
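A note on what the `*_reward_document` files hold: judging from the loader changes in `megatron/data/gpt2_dataset.py` further down, each document in the reward file carries a single scalar that is broadcast across the whole sequence at load time. The sketch below illustrates that pairing; the token values, the sign convention (positive for desirable, negative for undesirable), and the dict layout are assumptions for illustration, not the repo's on-disk `.bin`/`.idx` format.

```python
# Illustration only, not the actual .bin/.idx on-disk format: each tokenized
# document pairs with one scalar reward. The sign convention here (positive
# for desirable, negative for undesirable) is an assumption for this sketch.
documents = [
    {"tokens": [59, 302, 11], "reward": 1.0},   # desirable completion
    {"tokens": [59, 302, 99], "reward": -1.0},  # undesirable completion
]

for doc in documents:
    # At load time the scalar is repeated once per token (see gpt2_dataset.py).
    per_token_reward = [doc["reward"]] * len(doc["tokens"])
    assert len(per_token_reward) == len(doc["tokens"])
```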
27 changes: 24 additions & 3 deletions megatron/data/data_utils.py
@@ -70,6 +70,7 @@ def build_the_dataset(
pos_label_prefix=None,
neg_label_prefix=None,
precompute_model_name=None,
reward_prefix=None,
):
"""Build train/valid/test datasets."""
if dataset_impl == "gpt2":
@@ -84,6 +85,12 @@ def build_the_dataset(
data_prefix + "_" + precompute_model_name, data_impl, skip_warmup
)
precompute_indexed_dataset = precompute_indexed_dataset
else:
precompute_indexed_dataset = None
if reward_prefix is not None:
reward_dataset = make_indexed_dataset(reward_prefix, data_impl, skip_warmup)
else:
reward_dataset = None
elif dataset_impl == "pairwise":
pos_indexed_dataset = make_indexed_dataset(
pos_data_prefix, data_impl, skip_warmup
@@ -127,7 +134,6 @@
print_rank_0(" no. of documents:{}".format(total_num_of_documents))
dataset = None
documents = np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32)

if dataset_impl == "gpt2":
dataset = GPT2Dataset(
name,
Expand All @@ -141,6 +147,8 @@ def build_the_dataset(
allow_chopped=allow_chopped,
build_index_mappings=build_index_mappings,
label_dataset=label_dataset,
reward_dataset=reward_dataset,
ref_dataset=precompute_indexed_dataset,
)
elif dataset_impl == "pairwise":
dataset = PairwiseDataset(
Expand All @@ -160,7 +168,6 @@ def build_the_dataset(
pos_ref_dataset=pos_ref_dataset,
neg_ref_dataset=neg_ref_dataset,
)

return dataset


@@ -285,10 +292,13 @@ def build_weighted_datasets(
for i, (
train_path,
train_label_path,
train_reward_path,
valid_path,
valid_label_path,
valid_reward_path,
test_path,
test_label_path,
test_reward_path,
pos_train_path,
neg_train_path,
pos_train_label_path,
@@ -307,12 +317,21 @@ def build_weighted_datasets(
neox_args.train_label_data_paths
if neox_args.train_label_data_paths
else [],
neox_args.train_reward_data_paths
if neox_args.train_reward_data_paths
else [],
neox_args.valid_data_paths if neox_args.valid_data_paths else [],
neox_args.valid_label_data_paths
if neox_args.valid_label_data_paths
else [],
neox_args.valid_reward_data_paths
if neox_args.valid_reward_data_paths
else [],
neox_args.test_data_paths if neox_args.test_data_paths else [],
neox_args.test_label_data_paths if neox_args.test_label_data_paths else [],
neox_args.test_reward_data_paths
if neox_args.test_reward_data_paths
else [],
neox_args.pos_train_data_paths if neox_args.pos_train_data_paths else [],
neox_args.neg_train_data_paths if neox_args.neg_train_data_paths else [],
neox_args.pos_train_label_data_paths
@@ -359,6 +378,7 @@ def build_weighted_datasets(
pos_label_prefix=pos_train_label_path,
neg_label_prefix=neg_train_label_path,
precompute_model_name=neox_args.precompute_model_name,
reward_prefix=train_reward_path,
)
)

@@ -382,6 +402,7 @@ def build_weighted_datasets(
pos_label_prefix=pos_valid_label_path,
neg_label_prefix=neg_valid_label_path,
precompute_model_name=neox_args.precompute_model_name,
reward_prefix=valid_reward_path,
)
)

@@ -405,6 +426,7 @@ def build_weighted_datasets(
pos_label_prefix=pos_test_label_path,
neg_label_prefix=neg_test_label_path,
precompute_model_name=neox_args.precompute_model_name,
reward_prefix=test_reward_path,
)
)
return train_datasets, valid_datasets, test_datasets
@@ -502,7 +524,6 @@ def build_train_valid_test_data_iterators(neox_args):
)

if neox_args.weight_by_num_documents:

# gets the number of documents in each datapath
get_num_docs_list = lambda datasets: [
dataset.indexed_dataset.sizes.shape[0] for dataset in datasets
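The plumbing above zips each split's data, label, and reward path lists positionally, so the i-th reward file must correspond to the i-th text file (the real code also tolerates missing lists by defaulting them to empty). A minimal sketch of that contract, with hypothetical paths:

```python
# Hypothetical paths; the contract is positional: reward file i belongs to
# text file i within each split.
train_data_paths = ["data/a_text_document", "data/b_text_document"]
train_label_paths = ["data/a_text_label_document", "data/b_text_label_document"]
train_reward_paths = ["data/a_text_reward_document", "data/b_text_reward_document"]

for text, label, reward in zip(train_data_paths, train_label_paths, train_reward_paths):
    print(f"dataset built from: {text}, {label}, {reward}")
```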
143 changes: 87 additions & 56 deletions megatron/data/gpt2_dataset.py
@@ -41,16 +41,23 @@ def __init__(
build_index_mappings=True,
use_shared_fs=True,
label_dataset=None,
reward_dataset=None,
ref_dataset=None,
):

self.name = name
self.pack_impl = pack_impl
self.allow_chopped = allow_chopped
self.indexed_dataset = indexed_dataset
self.label_dataset = label_dataset
self.reward_dataset = reward_dataset
self.ref_dataset = ref_dataset
self.seq_length = seq_length

# Checks
assert self.reward_dataset is None or (
pack_impl == "unpacked"
), "Reward dataset only supported with unpacked data."
assert np.min(documents) >= 0
assert np.max(documents) < indexed_dataset.sizes.shape[0]

@@ -90,77 +97,101 @@ def __getitem__(self, idx):
offset_f = self.sample_idx[idx][1]
offset_l = self.sample_idx[idx + 1][1]
# Labels and texts are supposed to be fully in sync.
datasets = (
[self.indexed_dataset]
if self.label_dataset is None
else [self.indexed_dataset, self.label_dataset]
)
datasets = [self.indexed_dataset]
rw_indx = 1
if self.label_dataset is not None:
rw_indx += 1
datasets.append(self.label_dataset)
if self.reward_dataset is not None:
datasets.append(self.reward_dataset)
else:
rw_indx = -1
if self.ref_dataset is not None:
datasets.append(self.ref_dataset)
samples = []
sample_lengths = []
# If we are within the same document, just extract the chunk.
for n, dataset in enumerate(datasets):
if doc_index_f == doc_index_l:
samples.append(
dataset.get(
self.doc_idx[doc_index_f],
offset=offset_f,
length=offset_l - offset_f + 1,
if rw_indx == n:
# If we are in the reward dataset, we only need the last token.
rw = dataset.get(self.doc_idx[doc_index_f])
samples.append(
np.array([rw[0] for _ in range(len(samples[-1]))])
)
else:
samples.append(
dataset.get(
self.doc_idx[doc_index_f],
offset=offset_f,
length=offset_l - offset_f + 1,
)
)
)
else:
if n != rw_indx:
# reset
sample_lengths = []
# Otherwise, get the rest of the initial document.
sample_list = [
dataset.get(self.doc_idx[doc_index_f], offset=offset_f)
]
if n == rw_indx:
rw = dataset.get(self.doc_idx[doc_index_f])
sample_list = [
np.array([rw[0] for _ in range(sample_lengths.pop(0))])
]
else:
sample_list = [
dataset.get(self.doc_idx[doc_index_f], offset=offset_f)
]
sample_lengths.append(len(sample_list[-1]))
# Loop over all in between documents and add the entire document.
for i in range(doc_index_f + 1, doc_index_l):
sample_list.append(dataset.get(self.doc_idx[i]))
if n == rw_indx:
rw = dataset.get(self.doc_idx[i])
sample_list.append(
np.array([rw[0] for _ in range(sample_lengths.pop(0))])
)
else:
sample_list.append(dataset.get(self.doc_idx[i]))
sample_lengths.append(len(sample_list[-1]))
# And finally add the relevant portion of last document.
sample_list.append(
dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1)
)
if n == rw_indx:
rw = dataset.get(self.doc_idx[doc_index_l])
sample_list.append(
np.array([rw[0] for _ in range(sample_lengths.pop(0))])
)
else:
sample_list.append(
dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1)
)
sample_lengths.append(len(sample_list[-1]))
samples.append(np.concatenate(sample_list))

if len(datasets) == 1:
if len(samples[0]) < (self.seq_length + 1):
# Pad with -100s so the masking function can ignore these.
samples[0] = np.pad(
samples[0],
(0, (self.seq_length + 1) - len(samples[0])),
mode="constant",
constant_values=-100,
)
elif len(samples[0]) > (self.seq_length + 1):
# Check for overflow and truncate.
samples[0] = samples[0][: (self.seq_length + 1)]
return {"text": np.array(samples[0], dtype=np.int64)}
else:
if len(samples[0]) < (self.seq_length + 1):
# Pad with 0s, can use any number since it's masked.
samples[0] = np.pad(
samples[0],
(0, (self.seq_length + 1) - len(samples[0])),
mode="constant",
constant_values=0,
)
# pad with -100s so we can mask it out
samples[1] = np.pad(
samples[1],
(0, (self.seq_length + 1) - len(samples[1])),
for i in range(len(samples)):
mask = (self.label_dataset is not None) and (i == 1)
if len(samples[i]) < (self.seq_length + 1):
# Pad
samples[i] = np.pad(
samples[i],
(0, (self.seq_length + 1) - len(samples[i])),
mode="constant",
constant_values=-100,
constant_values=-100 if mask else 0,
)
elif len(samples[0]) > (self.seq_length + 1):
# Check for overflow and truncate.
samples[0] = samples[0][: (self.seq_length + 1)]
samples[1] = samples[1][: (self.seq_length + 1)]
return {
"text": np.array(samples[0], dtype=np.int64),
"label": np.array(samples[1], dtype=np.int64),
}
except IndexError:
elif len(samples[i]) > (self.seq_length + 1):
# Truncate
samples[i] = samples[i][: (self.seq_length + 1)]
ret = {"text": np.array(samples[0], dtype=np.int64)}
next_idx = 1
if self.label_dataset is not None:
ret["label"] = np.array(samples[next_idx], dtype=np.int64)
next_idx += 1
if self.reward_dataset is not None:
ret["reward"] = np.array(samples[next_idx], dtype=np.float32)
next_idx += 1
if self.ref_dataset is not None:
ret["ref"] = np.array(samples[next_idx], dtype=np.float32)
return ret
except IndexError as err:
new_idx = idx % len(self)
print(
f"WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx})"
f"WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx}), error: {err}"
)
return self[new_idx]

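To make the unpacked-sample assembly above concrete, here is a self-contained sketch of the two behaviors this diff adds: broadcasting a per-document scalar reward across the token sample (the `rw[0]` expansion), and padding every returned array to `seq_length + 1` (labels with -100 so they can be masked, everything else with 0). This is a simplified illustration, not the class itself:

```python
import numpy as np

def assemble_unpacked(tokens, reward_scalar, seq_length):
    # Broadcast the document-level reward over every token (mirrors rw[0] above).
    reward = np.full(len(tokens), reward_scalar, dtype=np.float32)
    pad = (seq_length + 1) - len(tokens)
    if pad > 0:
        # Pad; tokens and rewards use 0, labels would use -100 to be maskable.
        tokens = np.pad(tokens, (0, pad), mode="constant", constant_values=0)
        reward = np.pad(reward, (0, pad), mode="constant", constant_values=0)
    else:
        # Truncate on overflow.
        tokens = tokens[: seq_length + 1]
        reward = reward[: seq_length + 1]
    return {"text": tokens.astype(np.int64), "reward": reward}

sample = assemble_unpacked(np.array([5, 17, 42]), reward_scalar=1.0, seq_length=7)
assert sample["text"].shape == (8,) and sample["reward"].shape == (8,)
```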
25 changes: 19 additions & 6 deletions megatron/neox_arguments/neox_args.py
@@ -1052,9 +1052,9 @@ class NeoXArgsTraining(NeoXArgsTemplate):
Dataset implementation, can be one of "gpt2" or "pairwise"
"""

train_impl: Literal["normal", "dpo", "rm"] = "normal"
train_impl: Literal["normal", "dpo", "rm", "kto"] = "normal"
"""
Training implementation, can be one of "normal", "dpo", or "rm"
Training implementation, can be one of "normal", "dpo", "kto", or "rm"
"""

dpo_fp32: bool = True
@@ -1072,11 +1072,24 @@ class NeoXArgsTraining(NeoXArgsTemplate):
Beta value for DPO
"""

z_loss: float = 0.0
kto_fp32: bool = True
"""
Z-loss parameter, only implemented for RM training currently.
https://arxiv.org/pdf/2204.02311
https://arxiv.org/pdf/2309.10305
Whether to cast logits to fp32 for KTO loss calculation.
"""

kto_desirable_weight: float = 1.0
"""
Weight for desirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes.
"""

kto_undesirable_weight: float = 1.0
"""
Weight for undesirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes.
"""

kto_beta: float = 0.1
"""
Beta value for KTO
"""

allow_chopped: bool = True
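For context on how `kto_beta`, `kto_desirable_weight`, and `kto_undesirable_weight` typically combine, here is a sketch of the KTO loss from Ethayarajh et al. 2024 (https://arxiv.org/abs/2402.01306). It is an illustration of the published formulation under simplifying assumptions (sequence-level logprobs already reduced, batch-level KL reference point), not the exact gpt-neox implementation:

```python
import torch

def kto_loss(policy_logps, ref_logps, desirable_mask,
             beta=0.1, desirable_weight=1.0, undesirable_weight=1.0, fp32=True):
    # kto_fp32: do the loss arithmetic in float32 for numerical stability.
    if fp32:
        policy_logps, ref_logps = policy_logps.float(), ref_logps.float()
    # r_theta = log pi(y|x) - log pi_ref(y|x), per sequence.
    log_ratio = policy_logps - ref_logps
    # Reference point z0: batch-level KL estimate, clamped at zero and detached.
    z0 = log_ratio.mean().clamp(min=0).detach()
    desirable = desirable_weight * (1 - torch.sigmoid(beta * (log_ratio - z0)))
    undesirable = undesirable_weight * (1 - torch.sigmoid(beta * (z0 - log_ratio)))
    return torch.where(desirable_mask, desirable, undesirable).mean()
```

The unequal class weights matter when, per the docstrings above, desirable and undesirable examples are unbalanced; the paper suggests choosing them so the weighted counts of the two classes stay within a small ratio of each other.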
5 changes: 4 additions & 1 deletion megatron/text_generation_utils.py
@@ -19,6 +19,7 @@

import copy
import json
import math
import os
import time
from typing import List, Union
@@ -876,7 +877,9 @@ def precompute_logits(neox_args, model):
out_dataset = make_builder(out_path + ".bin", neox_args.data_impl)
out_dataset._dtype = np.float32
i = 0
while i < len(dataset):

# TODO: Not sure why this requires a multiple of 8? Investigate later.
while i < int(math.ceil(len(dataset) / 8.0) * 8):
start = time.time()
model.module.clear_cache() # clear kv cache between batches
if is_mp_rank_0():
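The new loop bound rounds the dataset length up to the next multiple of 8; the TODO above leaves the reason unconfirmed (batch or pipeline divisibility is a plausible guess, not established here). Out-of-range indices produced by the rounding wrap around via the modulo fallback in `GPT2Dataset.__getitem__`. A minimal sketch of the bound:

```python
import math

def padded_length(n, multiple=8):
    # Round n up to the next multiple (the loop bound above uses multiple=8).
    return int(math.ceil(n / float(multiple)) * multiple)

assert padded_length(13) == 16 and padded_length(16) == 16
```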