From a950f8b44f89f0f5bcec54dde4a64ddab91bac80 Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Fri, 21 Jun 2024 11:27:31 -0500 Subject: [PATCH 01/28] Add a chat data preprocessing script --- .../preprocess_data_with_chat_template.py | 345 ++++++++++++++++++ 1 file changed, 345 insertions(+) create mode 100644 tools/datasets/preprocess_data_with_chat_template.py diff --git a/tools/datasets/preprocess_data_with_chat_template.py b/tools/datasets/preprocess_data_with_chat_template.py new file mode 100644 index 000000000..df3f3cbae --- /dev/null +++ b/tools/datasets/preprocess_data_with_chat_template.py @@ -0,0 +1,345 @@ +# Copyright (c) 2024, EleutherAI +# This file is based on code by the authors denoted below and has been modified from its original version. +# +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A script for processing a dataset such that chat templates are utilized in the creation of the data. +These are then used to perform instruction/chat model finetunes (for example, finetuning a model on only the assistant +portions of a chatml dataset). + +This follows the same output format as 'preprocess_data_with_mask.py' but using chat templates to generate the data. +This way we can support multiturn chat data in the finetuning process. instead of relying on a single turn of data. 
+ +To run this script, first edit `tools/datasets/corpora.py` such that the command to call + `tools/datasets/preprocess_data_with_chat_template.py` is as follows: + +``` +cmd = f"python tools/datasets/preprocess_data_with_with_chat_template.py \ + --input {jsonl_filepath} \ + --output-prefix {parent_folder}/{self.name} \ + --tokenizer-path {hf-tokenizer} \ + --jsonl-keys {jsonl_keys} \ + --dataset-impl mmap \ + --workers {self.num_workers} " + +if self.only_last: + cmd += f"--only-last " + +if self.no_mask: + cmd += f"--no-mask " +``` + +Then, specify +``` +"train_data_paths": ["/path/to/dataset/name_text_document"], +"label_data_paths": ["/path/to/dataset/name_label_document"] +``` +in your YML config. This will then allow for finetuning on the data with loss masks set appropriately. + +""" + +import argparse +import multiprocessing +import os +import sys + +import lm_dataformat as lmd +import numpy as np + +sys.path.append( + os.path.abspath( + os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir) + ) +) + +import time +import tqdm +import jsonlines + +from megatron.data import indexed_dataset +from threading import Semaphore +from typing import List, Dict, Tuple +from transformers import AutoTokenizer, PreTrainedTokenizer + + +def build_chat( + chat: List[Dict[str, str]], + generation_role: str, + apply_mask: bool, + tokenizer: PreTrainedTokenizer, + only_last_turn: bool = False, +) -> Tuple[List[int], List[int]]: + """ + Build a chat from a list of dictionaries. 
Each dictionary should have a "role" and "content" key, this follows the + Chat Template from https://huggingface.co/docs/transformers/main/en/chat_templating + + :param chat: A list of dictionaries with "role" and "content" keys + :param generation_role: The role of the model generating the chat, usually "assistant" + :param apply_mask: Whether to apply a loss mask to the chat, if False, all tokens will be included in the loss + :param tokenizer: A HF tokenizer + :param only_last_turn: Whether to only include the last turn in the chat, needed for some fine-tuning tasks + """ + tokens = [] + mask = [] + if apply_mask is False: + tokens = tokenizer.apply_chat_template(chat) + mask = tokens + return tokens, mask + for i, turn in enumerate(chat): + add_gen = ( + False if i == len(chat) - 1 else chat[i + 1]["role"] == generation_role + ) + chat_tokens = tokenizer.apply_chat_template( + chat[: i + 1], add_generation_prompt=add_gen + ) + # remove previous stuff... + tokens.extend(chat_tokens) + if only_last_turn and (i != len(chat) - 1): + mask.extend([-100] * len(chat_tokens)) + elif apply_mask and (turn["role"] != generation_role): + mask.extend([-100] * len(chat_tokens)) + else: + mask.extend(chat_tokens) + return tokens, mask + + +class Encoder(object): + def __init__(self, args): + self.args = args + + def initializer(self): + # Use Encoder class as a container for global data + Encoder.tokenizer = AutoTokenizer.from_pretrained(self.args.tokenizer_path) + + def encode(self, text): + ids = {} + for key in self.args.jsonl_keys: + text_ids, label_ids = build_chat( + text[key], + self.args.generation_role, + not self.args.no_mask, + Encoder.tokenizer, + self.args.only_last, + ) + ids[key] = (text_ids, label_ids) + return ids, len(text) + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="input data") + group.add_argument( + "--input", + type=str, + required=True, + help="Path to input jsonl files or lmd archive(s) - if 
using multiple archives, put them in a comma separated " + "list", + ) + group.add_argument( + "--jsonl-keys", + nargs="+", + default=["conversation"], + help="space separate listed of keys to extract from jsonl. Default: text", + ) + group.add_argument( + "--no-mask", + help="If set, this will not mask any tokens in the input data.", + action="store_true", + ) + group.add_argument( + "--generation-role", + type=str, + default="assistant", + help="The role of the model generating the chat, usually 'assistant'. Default: assistant", + ) + group.add_argument( + "--only-last", + help="If set, this will mask everything except the last turn in the chat.", + action="store_true", + ) + group.add_argument( + "--num-docs", + default=None, + help="Optional: Number of documents in the input data (if known) for an accurate progress bar.", + type=int, + ) + group = parser.add_argument_group(title="tokenizer") + group.add_argument( + "--tokenizer-path", + type=str, + required=True, + help="Path to HF Tokenizer.", + ) + group.add_argument("--ftfy", action="store_true", help="Use ftfy to clean text") + group = parser.add_argument_group(title="output data") + group.add_argument( + "--output-prefix", + type=str, + required=True, + help="Path to binary output file without suffix", + ) + group.add_argument( + "--dataset-impl", + type=str, + default="mmap", + choices=["lazy", "cached", "mmap"], + help="Dataset implementation to use. 
Default: mmap", + ) + + group = parser.add_argument_group(title="runtime") + group.add_argument( + "--workers", type=int, default=1, help="Number of worker processes to launch" + ) + group.add_argument( + "--log-interval", + type=int, + default=100, + help="Interval between progress updates", + ) + args = parser.parse_args() + args.keep_empty = False + + # some default/dummy values for the tokenizer + args.rank = 0 + args.make_vocab_size_divisible_by = 128 + args.model_parallel_size = 1 + + return args + + +def yield_from_files(fnames: list, semaphore): + """ + Iterator over input documents using lm_dataformat. Should be able to handle jsons / texts / + other compressed formats. Also filters out empty documents. + + :param fnames: list of filenames + """ + + def yielder(fname, semaphore): + with open(fname, encoding="utf-8") as f: + reader = jsonlines.Reader(f) + for f in reader: + semaphore.acquire() + yield f + + for fname in fnames: + semaphore.acquire() + + yield from yielder(fname, semaphore) + + +def main(): + args = get_args() + encoder = Encoder(args) + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) + print(f"Vocab size: {tokenizer.vocab_size}") + print(f"Output prefix: {args.output_prefix}") + + # build a semaphore object to stop `yield_from_files` from getting ahead of encoder.encode and + # hence building up memory + semaphore = Semaphore(10000 + args.workers) + + # use multiprocessing to iterate over input documents + fin = yield_from_files(args.input.split(","), semaphore) + + if args.workers > 1: + pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer) + encoded_docs = pool.imap(encoder.encode, fin, chunksize=25) + else: + encoder.initializer() + encoded_docs = (encoder.encode(doc) for doc in fin) + + # make a dataset builder for each key in args.jsonl_keys + # each key will output to a different file beginning with args.output_prefix + output_bin_files = {} + output_idx_files = {} + builders = {} + for key in 
args.jsonl_keys: + output_bin_files[key] = "{}_{}_{}.bin".format( + args.output_prefix, key, "document" + ) + output_idx_files[key] = "{}_{}_{}.idx".format( + args.output_prefix, key, "document" + ) + builders[key] = indexed_dataset.make_builder( + output_bin_files[key], + impl=args.dataset_impl, + vocab_size=tokenizer.vocab_size, + ) + builders[key]._dtype = np.int32 + if not args.no_mask: + assert ( + key + "_label" not in args.jsonl_keys + ), "label should not be included as it will be generated according to the mask." + key += "_label" + output_bin_files[key] = "{}_{}_{}.bin".format( + args.output_prefix, key, "document" + ) + output_idx_files[key] = "{}_{}_{}.idx".format( + args.output_prefix, key, "document" + ) + builders[key] = indexed_dataset.make_builder( + output_bin_files[key], + impl=args.dataset_impl, + vocab_size=tokenizer.vocab_size, + ) + builders[key]._dtype = np.int32 + + # actually do tokenization + proc_start = time.time() + total_bytes_processed = 0 + pbar = tqdm.tqdm() + for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1): + total_bytes_processed += bytes_processed + + # release semaphore so `yield_from_files` can add another file to the buffer + semaphore.release() + + # add each tokenized document / sentence + for key, conv in doc.items(): + tokens = conv[0] + token_mask = conv[1] + builders[key].add_item(np.array(tokens, dtype=builders[key].dtype)) + builders[key + "_label"].add_item( + np.array(token_mask, dtype=builders[key + "_label"].dtype) + ) + # add indx... 
+ builders[key].end_document() + builders[key + "_label"].end_document() + if i == 1: + print("key: ", key) + print("tokens: ", tokens) + print("token_mask: ", token_mask) + # log progress + if i % args.log_interval == 0: + current = time.time() + elapsed = current - proc_start + mbs = total_bytes_processed / elapsed / 1024 / 1024 + pbar.set_description( + f"Processed {i}{'' if args.num_docs is None else '/' + str(args.num_docs)} documents ({i / elapsed} docs/s, {mbs} MB/s)." + ) + if i != 0: + pbar.update(args.log_interval) + + # save output file + update_keys = args.jsonl_keys + for key in update_keys: + builders[key].finalize(output_idx_files[key]) + builders[key + "_label"].finalize(output_idx_files[key + "_label"]) + + +if __name__ == "__main__": + main() From e360e24c2905b09458d80f4a2ab8b6eaadf2065a Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Fri, 21 Jun 2024 11:29:52 -0500 Subject: [PATCH 02/28] add EOT at end of a chat --- tools/datasets/preprocess_data_with_chat_template.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tools/datasets/preprocess_data_with_chat_template.py b/tools/datasets/preprocess_data_with_chat_template.py index df3f3cbae..81770deff 100644 --- a/tools/datasets/preprocess_data_with_chat_template.py +++ b/tools/datasets/preprocess_data_with_chat_template.py @@ -113,6 +113,9 @@ def build_chat( mask.extend([-100] * len(chat_tokens)) else: mask.extend(chat_tokens) + if tokenizer.eos_token_id is not None: + mask.append(tokenizer.eos_token_id if mask[-1] != -100 else -100) + tokens.append(tokenizer.eos_token_id) return tokens, mask From 9ee4a8f642983b7e6a030b0a8fbcb80cb5f3c327 Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Fri, 21 Jun 2024 15:00:51 -0500 Subject: [PATCH 03/28] - add different packing impl (Unpacked, packing until overflow) - fix labels to also have valid/test implementations - fix label masking in _get_batch to also include anything from get_ltor_masks_and_position_ids --- megatron/data/data_utils.py | 37 +++++- 
megatron/data/gpt2_dataset.py | 188 ++++++++++++++++++++++----- megatron/neox_arguments/neox_args.py | 29 ++++- megatron/training.py | 21 +-- 4 files changed, 225 insertions(+), 50 deletions(-) diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index bc5754cdb..7e4dbdb37 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -55,6 +55,8 @@ def build_the_dataset( data_prefix, name, data_impl, + pack_impl, + allow_chopped, num_samples, seq_length, seed, @@ -83,6 +85,8 @@ def build_the_dataset( num_samples, seq_length, seed, + pack_impl=pack_impl, + allow_chopped=allow_chopped, build_index_mappings=build_index_mappings, label_dataset=label_dataset, ) @@ -93,6 +97,8 @@ def build_train_valid_test_datasets( data_prefix, use_shared_fs, data_impl, + pack_impl, + allow_chopped, splits_string, train_valid_test_num_samples, seq_length, @@ -138,6 +144,8 @@ def build_dataset(index, name): train_valid_test_num_samples[index], seq_length, seed, + pack_impl=pack_impl, + allow_chopped=allow_chopped, use_shared_fs=use_shared_fs, ) return dataset @@ -204,12 +212,25 @@ def build_weighted_datasets( ): # build individual datasets train_datasets, valid_datasets, test_datasets = [], [], [] - for i, (train_path, label_path, valid_path, test_path) in enumerate( + for i, ( + train_path, + train_label_path, + valid_path, + valid_label_path, + test_path, + test_label_path, + ) in enumerate( zip_longest( neox_args.train_data_paths, - neox_args.label_data_paths if neox_args.label_data_paths else [], + neox_args.train_label_data_paths + if neox_args.train_label_data_paths + else [], neox_args.valid_data_paths, + neox_args.valid_label_data_paths + if neox_args.valid_label_data_paths + else [], neox_args.test_data_paths, + neox_args.test_label_data_paths if neox_args.test_label_data_paths else [], ) ): if train_path: @@ -218,12 +239,14 @@ def build_weighted_datasets( data_prefix=train_path, name=f"train_{i}", data_impl=neox_args.data_impl, + 
pack_impl=neox_args.pack_impl, + allow_chopped=neox_args.allow_chopped, num_samples=train_num_samples[i], seq_length=neox_args.seq_length, seed=neox_args.seed, skip_warmup=(not neox_args.mmap_warmup), build_index_mappings=build_index_mappings, - label_prefix=label_path, + label_prefix=train_label_path, ) ) @@ -233,11 +256,14 @@ def build_weighted_datasets( data_prefix=valid_path, name=f"valid_{i}", data_impl=neox_args.data_impl, + pack_impl=neox_args.pack_impl, + allow_chopped=neox_args.allow_chopped, num_samples=valid_num_samples[i], seq_length=neox_args.seq_length, seed=neox_args.seed, skip_warmup=(not neox_args.mmap_warmup), build_index_mappings=build_index_mappings, + label_prefix=valid_label_path, ) ) @@ -247,11 +273,14 @@ def build_weighted_datasets( data_prefix=test_path, name=f"test_{i}", data_impl=neox_args.data_impl, + pack_impl=neox_args.pack_impl, + allow_chopped=neox_args.allow_chopped, num_samples=test_num_samples[i], seq_length=neox_args.seq_length, seed=neox_args.seed, skip_warmup=(not neox_args.mmap_warmup), build_index_mappings=build_index_mappings, + label_prefix=test_label_path, ) ) return train_datasets, valid_datasets, test_datasets @@ -414,6 +443,8 @@ def build_train_valid_test_data_iterators(neox_args): seq_length=neox_args.seq_length, seed=neox_args.seed, skip_warmup=(not neox_args.mmap_warmup), + pack_impl=neox_args.pack_impl, + allow_chopped=neox_args.allow_chopped, ) # Build dataloders. 
diff --git a/megatron/data/gpt2_dataset.py b/megatron/data/gpt2_dataset.py index 75e601fda..edba57df2 100644 --- a/megatron/data/gpt2_dataset.py +++ b/megatron/data/gpt2_dataset.py @@ -36,14 +36,19 @@ def __init__( num_samples, seq_length, seed, + pack_impl="packed", + allow_chopped=True, build_index_mappings=True, use_shared_fs=True, label_dataset=None, ): self.name = name + self.pack_impl = pack_impl + self.allow_chopped = allow_chopped self.indexed_dataset = indexed_dataset self.label_dataset = label_dataset + self.seq_length = seq_length # Checks assert np.min(documents) >= 0 @@ -56,10 +61,13 @@ def __init__( data_prefix, documents, self.indexed_dataset.sizes, + self.label_dataset, num_samples, seq_length, seed, + self.pack_impl, use_shared_fs=use_shared_fs, + allow_chopped=self.allow_chopped, ) self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1 self.sample_idx_len = self.sample_idx.shape[0] - 1 @@ -113,8 +121,38 @@ def __getitem__(self, idx): samples.append(np.concatenate(sample_list)) if len(datasets) == 1: + if len(samples[0]) < (self.seq_length + 1): + # Pad with -100s so the masking function can ignore these. + samples[0] = np.pad( + samples[0], + (0, (self.seq_length + 1) - len(samples[0])), + mode="constant", + constant_values=-100, + ) + elif len(samples[0]) > (self.seq_length + 1): + # Check for overflow and truncate. + samples[0] = samples[0][: (self.seq_length + 1)] return {"text": np.array(samples[0], dtype=np.int64)} else: + if len(samples[0]) < (self.seq_length + 1): + # Pad with 0s, can use any number since it's masked. + samples[0] = np.pad( + samples[0], + (0, (self.seq_length + 1) - len(samples[0])), + mode="constant", + constant_values=0, + ) + # pad with -100s so we can mask it out + samples[1] = np.pad( + samples[1], + (0, (self.seq_length + 1) - len(samples[1])), + mode="constant", + constant_values=-100, + ) + elif len(samples[0]) > (self.seq_length + 1): + # Check for overflow and truncate. 
+ samples[0] = samples[0][: (self.seq_length + 1)] + samples[1] = samples[1][: (self.seq_length + 1)] return { "text": np.array(samples[0], dtype=np.int64), "label": np.array(samples[1], dtype=np.int64), @@ -132,10 +170,13 @@ def _build_index_mappings( data_prefix, documents, sizes, + label_dataset, num_samples, seq_length, seed, + packing_impl, use_shared_fs=True, + allow_chopped=True, ): """Build doc-idx, sample-idx, and shuffle-idx. doc-idx: is an array (ordered) of documents to be used in training. @@ -155,6 +196,9 @@ def _build_index_mappings( _filename += "_{}ns".format(num_samples) _filename += "_{}sl".format(seq_length) _filename += "_{}s".format(seed) + _filename += "_{}pi".format(packing_impl) + if allow_chopped: + _filename += "_ac" doc_idx_filename = _filename + "_doc_idx.npy" sample_idx_filename = _filename + "_sample_idx.npy" shuffle_idx_filename = _filename + "_shuffle_idx.npy" @@ -177,44 +221,116 @@ def _build_index_mappings( ) # doc-idx. start_time = time.time() - doc_idx = _build_doc_idx(documents, num_epochs, np_rng) - np.save(doc_idx_filename, doc_idx, allow_pickle=True) - print_rank_0( - " > elapsed time to build and save doc-idx mapping " - "(seconds): {:4f}".format(time.time() - start_time) - ) - # sample-idx. - start_time = time.time() - # Use C++ implementation for speed. 
- from megatron.data import helpers - - assert doc_idx.dtype == np.int32 - assert sizes.dtype == np.int32 - - num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length - if 2 * (num_samples + 1) < np.iinfo(np.int32).max: - sample_idx = helpers.build_sample_idx_int32( - sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch + if packing_impl == "packed": + doc_idx = _build_doc_idx(documents, num_epochs, np_rng) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + print_rank_0( + " > elapsed time to build and save doc-idx mapping " + "(seconds): {:4f}".format(time.time() - start_time) ) - else: - sample_idx = helpers.build_sample_idx_int64( - sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch + # sample-idx. + start_time = time.time() + # Use C++ implementation for speed. + from megatron.data import helpers + + assert doc_idx.dtype == np.int32 + assert sizes.dtype == np.int32 + + num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length + if 2 * (num_samples + 1) < np.iinfo(np.int32).max: + sample_idx = helpers.build_sample_idx_int32( + sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch + ) + else: + sample_idx = helpers.build_sample_idx_int64( + sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch + ) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + print_rank_0( + " > elapsed time to build and save sample-idx mapping " + "(seconds): {:4f}".format(time.time() - start_time) ) - np.save(sample_idx_filename, sample_idx, allow_pickle=True) - print_rank_0( - " > elapsed time to build and save sample-idx mapping " - "(seconds): {:4f}".format(time.time() - start_time) - ) - # shuffle-idx. 
- start_time = time.time() - # -1 is due to data structure used to retrieve the index: - # sample i --> [sample_idx[i], sample_idx[i+1]) - shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng) - np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) - print_rank_0( - " > elapsed time to build and save shuffle-idx mapping" - " (seconds): {:4f}".format(time.time() - start_time) - ) + # shuffle-idx. + start_time = time.time() + # -1 is due to data structure used to retrieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + print_rank_0( + " > elapsed time to build and save shuffle-idx mapping" + " (seconds): {:4f}".format(time.time() - start_time) + ) + elif packing_impl == "pack_until_overflow": + # Naively pack data until it overflows, then roll it over to a new one instead. + shuffle_idx = np.arange(num_samples) # Shuffle index around epochs + np_rng.shuffle(shuffle_idx) + sample_idx = [] + doc_idx = [] + # Iterate over files until we have enough samples. + temp_shuffle_idx = np.arange(len(documents)) + np_rng.shuffle(temp_shuffle_idx) + running_length = 0 + curr_shuffle_idx = 0 + while len(sample_idx) < num_samples: + if not allow_chopped: + # +1 since we shift left/right by 1 + if sizes[temp_shuffle_idx[curr_shuffle_idx]] > seq_length + 1: + curr_shuffle_idx += 1 + continue + # First, check if we need to skip this item... 
+ if label_dataset is not None: + if np.all( + label_dataset.get(temp_shuffle_idx[curr_shuffle_idx])[ + : seq_length + 1 + ] + == -100 + ): + curr_shuffle_idx += 1 + continue + doc_length = sizes[temp_shuffle_idx[curr_shuffle_idx]] + if running_length == 0: + sample_idx.append(np.array([len(doc_idx), 0])) + doc_idx.append(temp_shuffle_idx[curr_shuffle_idx]) + running_length += doc_length + else: + if running_length + doc_length > (seq_length + 1): + running_length = doc_length + sample_idx.append(np.array([len(doc_idx), 0])) + else: + running_length += doc_length + doc_idx.append(temp_shuffle_idx[curr_shuffle_idx]) + curr_shuffle_idx += 1 + if curr_shuffle_idx == len(documents): + curr_shuffle_idx = 0 + np_rng.shuffle(temp_shuffle_idx) + sample_idx.append(np.array([len(doc_idx), 0])) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + elif packing_impl == "unpacked": + # Unpacked data, one sample per document. + shuffle_idx = np.arange(num_samples) # Shuffle index around epochs + np_rng.shuffle(shuffle_idx) + sample_idx = np.zeros((num_samples + 1, 2), dtype=np.int64) + sample_idx[:, 0] = np.array([i for i in range(num_samples + 1)]) + sample_idx[:, 1] = 0 + doc_idx = list() + doc_i = 0 + while len(doc_idx) <= num_samples: + if not allow_chopped: + # +1 since we shift left/right by 1 + if sizes[doc_i] > seq_length + 1: + doc_i = (doc_i + 1) % len(documents) + continue + # Just in case we have bad data in the loop... 
+ if np.all(label_dataset.get(doc_i)[:seq_length] == -100): + doc_i = (doc_i + 1) % len(documents) + continue + doc_idx.append(doc_i) + doc_i = (doc_i + 1) % len(documents) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) # This should be a barrier but nccl barrier assumes # device_index=rank which is not the case for model diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index febefb3c2..6878c79eb 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -848,9 +848,9 @@ class NeoXArgsTraining(NeoXArgsTemplate): List of paths to train datasets. """ - label_data_paths: list = None + train_label_data_paths: list = None """ - List of paths to label datasets (not shifted by 1 yet!). + List of paths to train label datasets (not shifted by 1 yet!). """ test_data_paths: list = None @@ -858,11 +858,21 @@ class NeoXArgsTraining(NeoXArgsTemplate): List of paths to test datasets. """ + test_label_data_paths: list = None + """ + List of paths to test label datasets (not shifted by 1 yet!). + """ + valid_data_paths: list = None """ List of paths to validation datasets. """ + valid_label_data_paths: list = None + """ + List of paths to validation label datasets (not shifted by 1 yet!). + """ + train_data_weights: list = None """ List of 'weights' that decide how often to sample from each training dataset when blending datasets. If None, defaults to equal weighting. @@ -912,6 +922,21 @@ class NeoXArgsTraining(NeoXArgsTemplate): Implementation of indexed datasets, can be one of "infer", "cached", or "mmap" """ + pack_impl: Literal["packed", "pack_until_overflow", "unpacked"] = "packed" + """ + Packing implementation, can be one of "packed", "pack_until_overflow", or "unpacked". 
+ + warning: pack_until_overflow is very naive and will likely have issues with pretraining scale datasets + """ + + allow_chopped: bool = True + """ + WARNING: if your packing impl is packed, this is ignored. + + Allow chopped samples in the dataset. + (e.g if your sequence length is 1024 and you have a sample of length 1026, it will be chopped to 1024) + """ + mmap_warmup: bool = False """ Warm up mmap files. diff --git a/megatron/training.py b/megatron/training.py index 3265680c5..482b4154f 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -277,16 +277,19 @@ def pretrain(neox_args): def _get_batch(neox_args, tokenizer, keys, data, datatype): """Support function for get_batch / get_batch pipe (to avoid code repetition)""" data_b = mpu.broadcast_data(keys, data, datatype) - + token_key = keys[0] + label_key = keys[1] if len(keys) > 1 else None # Unpack. - tokens_ = data_b["text"].long() + tokens_ = data_b[token_key].long() if "label" in data_b: + label_mask = (data_b[label_key].long() >= 0)[:, 1:].contiguous() labels = torch.where( - data_b["label"].long() >= 0, - data_b["label"].long(), + data_b[label_key].long() >= 0, + data_b[label_key].long(), torch.zeros_like(data_b["label"].long()), )[:, 1:].contiguous() else: + label_mask = (tokens_.long() >= 0)[:, 1:].contiguous() labels = tokens_[:, 1:].contiguous() tokens = tokens_[:, :-1].contiguous() @@ -297,9 +300,9 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype): eod_mask_loss=neox_args.eod_mask_loss, sliding_window_width=neox_args.sliding_window_width, ) - # If `label` is present, any token < 0 (e.g., -100, the default for torch) skips the loss computation - if "label" in data_b: - loss_mask = (data_b["label"][:, 1:] >= 0).to(loss_mask.dtype) + + # combine loss masks from get_ltor_masks_and_position_ids with loss masks from data + loss_mask = label_mask.to(loss_mask.dtype) * loss_mask return tokens, labels, loss_mask, attention_mask, position_ids @@ -307,7 +310,7 @@ def 
get_batch(neox_args, data_iterator): """Generate a batch""" # Items and their type. - keys = ["text", "label"] if neox_args.label_data_paths else ["text"] + keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] datatype = torch.int64 # Broadcast data. @@ -327,7 +330,7 @@ def get_batch(neox_args, data_iterator): def get_batch_pipe(data, neox_args, curr_scheduler=None): """A modification of get_batch() to work with the latest batch instead of an iterator.""" # Items and their type. - keys = ["text", "label"] if neox_args.label_data_paths else ["text"] + keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] datatype = torch.int64 tokens, labels, loss_mask, attention_mask, position_ids = _get_batch( From 0678573a4f3e4940b2b030f2d4add4fc4a1c61f5 Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Fri, 21 Jun 2024 15:08:23 -0500 Subject: [PATCH 04/28] update README.md --- tools/datasets/README.md | 51 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tools/datasets/README.md b/tools/datasets/README.md index f8215959c..af3009a23 100644 --- a/tools/datasets/README.md +++ b/tools/datasets/README.md @@ -93,6 +93,57 @@ output data: --dataset-impl {lazy,cached,mmap} Dataset implementation to use. Default: mmap +runtime: + --workers WORKERS Number of worker processes to launch + --log-interval LOG_INTERVAL + Interval between progress updates +``` +## `preprocess_data_with_chat_template.py` +Similar, but uses huggingface's [chat templates](https://huggingface.co/docs/transformers/main/en/chat_templating) to +tokenize the data to support multiturn and more complicated use cases. + +N.B. 
If using this, you **must** specify your data when training/finetuning with the following configs +```json +"train_data_paths": ["train_documents"], +"test_data_paths": ["test_documents"], +"valid_data_paths": ["test_documents"], +"label_data_paths": ["label_documents"] +``` + +the `"data_path"` option will not work with `"label_data_paths"`. + + +``` +usage: preprocess_data_with_chat_template.py [-h] --input INPUT [--jsonl-keys JSONL_KEYS [JSONL_KEYS ...]] [--no-mask] + [--generation-role GENERATION_ROLE] [--only-last] [--num-docs NUM_DOCS] + --tokenizer-path TOKENIZER_PATH [--ftfy] --output-prefix OUTPUT_PREFIX + [--dataset-impl {lazy,cached,mmap}] [--workers WORKERS] + [--log-interval LOG_INTERVAL] + +options: + -h, --help show this help message and exit + +input data: + --input INPUT Path to input jsonl files or lmd archive(s) - if using multiple archives, put them in a comma separated list + --jsonl-keys JSONL_KEYS [JSONL_KEYS ...] + space separate listed of keys to extract from jsonl. Default: text + --no-mask If set, this will not mask any tokens in the input data. + --generation-role GENERATION_ROLE + The role of the model generating the chat, usually 'assistant'. Default: assistant + --only-last If set, this will mask everything except the last turn in the chat. + --num-docs NUM_DOCS Optional: Number of documents in the input data (if known) for an accurate progress bar. + +tokenizer: + --tokenizer-path TOKENIZER_PATH + Path to HF Tokenizer. + --ftfy Use ftfy to clean text + +output data: + --output-prefix OUTPUT_PREFIX + Path to binary output file without suffix + --dataset-impl {lazy,cached,mmap} + Dataset implementation to use. 
Default: mmap + runtime: --workers WORKERS Number of worker processes to launch --log-interval LOG_INTERVAL From 2d20d86526f0714a475434f16fe9bc9ad7d48e8c Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Mon, 24 Jun 2024 20:27:37 -0500 Subject: [PATCH 05/28] - Add metrics to forward step to add DPO specific metrics that are useful (accuracy, etc) - Add reference model setup for DPO - Add pairwise dataset for positive/negative pairs - Add DPO loss --- megatron/data/data_utils.py | 159 ++++++-- megatron/data/pairwise_dataset.py | 585 +++++++++++++++++++++++++++ megatron/neox_arguments/neox_args.py | 56 +++ megatron/training.py | 267 +++++++++--- megatron/utils.py | 2 +- 5 files changed, 994 insertions(+), 75 deletions(-) create mode 100644 megatron/data/pairwise_dataset.py diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index 7e4dbdb37..2c548077d 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -23,6 +23,7 @@ from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset from megatron.data.blendable_dataset import BlendableDataset from megatron.data.gpt2_dataset import GPT2Dataset +from megatron.data.pairwise_dataset import PairwiseDataset from megatron.data.samplers import DistributedBatchSampler @@ -53,9 +54,12 @@ def make_data_loader(dataset, neox_args): def build_the_dataset( data_prefix, + pos_data_prefix, + neg_data_prefix, name, data_impl, pack_impl, + dataset_impl, allow_chopped, num_samples, seq_length, @@ -63,33 +67,92 @@ def build_the_dataset( skip_warmup, build_index_mappings=True, label_prefix=None, + pos_label_prefix=None, + neg_label_prefix=None, + pos_ref_prefix=None, + neg_ref_prefix=None, ): """Build train/valid/test datasets.""" - - indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) - if label_prefix is None: - label_dataset = None + if dataset_impl == "gpt2": + indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) + if label_prefix is 
None: + label_dataset = None + else: + label_dataset = make_indexed_dataset(label_prefix, data_impl, skip_warmup) + elif dataset_impl == "pairwise": + pos_indexed_dataset = make_indexed_dataset( + pos_data_prefix, data_impl, skip_warmup + ) + neg_indexed_dataset = make_indexed_dataset( + neg_data_prefix, data_impl, skip_warmup + ) + if pos_label_prefix is None: + pos_label_dataset = None + # Also do neg here since they both must be the same + assert neg_label_prefix is None + neg_label_dataset = None + else: + pos_label_dataset = make_indexed_dataset( + pos_label_prefix, data_impl, skip_warmup + ) + # Also do neg here since they both must be the same + assert neg_label_prefix is not None + neg_label_dataset = make_indexed_dataset( + neg_label_prefix, data_impl, skip_warmup + ) + if pos_ref_prefix is not None: + pos_ref_dataset = make_indexed_dataset( + pos_ref_prefix, data_impl, skip_warmup + ) + # Also do neg here since they both must be the same + assert neg_ref_prefix is not None + neg_ref_dataset = make_indexed_dataset( + neg_ref_prefix, data_impl, skip_warmup + ) else: - label_dataset = make_indexed_dataset(label_prefix, data_impl, skip_warmup) + raise NotImplementedError(f"dataset_impl={dataset_impl} not implemented") - total_num_of_documents = indexed_dataset.sizes.shape[0] + total_num_of_documents = ( + indexed_dataset.sizes.shape[0] + if dataset_impl == "gpt2" + else pos_indexed_dataset.sizes.shape[0] + ) print_rank_0(" {}:".format(name)) print_rank_0(" no. 
of documents:{}".format(total_num_of_documents)) dataset = None documents = np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32) - dataset = GPT2Dataset( - name, - data_prefix, - documents, - indexed_dataset, - num_samples, - seq_length, - seed, - pack_impl=pack_impl, - allow_chopped=allow_chopped, - build_index_mappings=build_index_mappings, - label_dataset=label_dataset, - ) + if dataset_impl == "gpt2": + dataset = GPT2Dataset( + name, + data_prefix, + documents, + indexed_dataset, + num_samples, + seq_length, + seed, + pack_impl=pack_impl, + allow_chopped=allow_chopped, + build_index_mappings=build_index_mappings, + label_dataset=label_dataset, + ) + elif dataset_impl == "pairwise": + dataset = PairwiseDataset( + name, + pos_data_prefix, + documents, + pos_indexed_dataset, + neg_indexed_dataset, + num_samples, + seq_length, + seed, + pack_impl=pack_impl, + allow_chopped=allow_chopped, + build_index_mappings=build_index_mappings, + pos_label_dataset=pos_label_dataset, + neg_label_dataset=neg_label_dataset, + pos_ref_dataset=pos_ref_dataset, + neg_ref_dataset=neg_ref_dataset, + ) return dataset @@ -135,7 +198,6 @@ def build_dataset(index, name): documents = np.arange( start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32 ) - dataset = GPT2Dataset( name, data_prefix, @@ -219,18 +281,54 @@ def build_weighted_datasets( valid_label_path, test_path, test_label_path, + pos_train_path, + neg_train_path, + pos_train_label_path, + neg_train_label_path, + pos_valid_path, + neg_valid_path, + pos_valid_label_path, + neg_valid_label_path, + pos_test_path, + neg_test_path, + pos_test_label_path, + neg_test_label_path, ) in enumerate( zip_longest( - neox_args.train_data_paths, + neox_args.train_data_paths if neox_args.train_data_paths else [], neox_args.train_label_data_paths if neox_args.train_label_data_paths else [], - neox_args.valid_data_paths, + neox_args.valid_data_paths if neox_args.valid_data_paths else [], 
neox_args.valid_label_data_paths
            if neox_args.valid_label_data_paths
            else [],
-            neox_args.test_data_paths,
+            neox_args.test_data_paths if neox_args.test_data_paths else [],
             neox_args.test_label_data_paths
             if neox_args.test_label_data_paths
             else [],
+            neox_args.pos_train_data_paths if neox_args.pos_train_data_paths else [],
+            neox_args.neg_train_data_paths if neox_args.neg_train_data_paths else [],
+            neox_args.pos_train_label_data_paths
+            if neox_args.pos_train_label_data_paths
+            else [],
+            neox_args.neg_train_label_data_paths
+            if neox_args.neg_train_label_data_paths
+            else [],
+            neox_args.pos_valid_data_paths if neox_args.pos_valid_data_paths else [],
+            neox_args.neg_valid_data_paths if neox_args.neg_valid_data_paths else [],
+            neox_args.pos_valid_label_data_paths
+            if neox_args.pos_valid_label_data_paths
+            else [],
+            neox_args.neg_valid_label_data_paths
+            if neox_args.neg_valid_label_data_paths
+            else [],
+            neox_args.pos_test_data_paths if neox_args.pos_test_data_paths else [],
+            neox_args.neg_test_data_paths if neox_args.neg_test_data_paths else [],
+            neox_args.pos_test_label_data_paths
+            if neox_args.pos_test_label_data_paths
+            else [],
+            neox_args.neg_test_label_data_paths
+            if neox_args.neg_test_label_data_paths
+            else [],
         )
     ):
         if train_path:
@@ -247,6 +345,11 @@ def build_weighted_datasets(
                     skip_warmup=(not neox_args.mmap_warmup),
                     build_index_mappings=build_index_mappings,
                     label_prefix=train_label_path,
+                    dataset_impl=neox_args.dataset_impl,
+                    pos_data_prefix=pos_train_path,
+                    neg_data_prefix=neg_train_path,
+                    pos_label_prefix=pos_train_label_path,
+                    neg_label_prefix=neg_train_label_path,
                 )
             )
 
@@ -264,6 +367,11 @@ def build_weighted_datasets(
                     skip_warmup=(not neox_args.mmap_warmup),
                     build_index_mappings=build_index_mappings,
                     label_prefix=valid_label_path,
+                    dataset_impl=neox_args.dataset_impl,
+                    pos_data_prefix=pos_valid_path,
+                    neg_data_prefix=neg_valid_path,
+                    pos_label_prefix=pos_valid_label_path,
+                    neg_label_prefix=neg_valid_label_path,
                 )
             )
 
@@ -281,6
+389,11 @@ def build_weighted_datasets( skip_warmup=(not neox_args.mmap_warmup), build_index_mappings=build_index_mappings, label_prefix=test_label_path, + dataset_impl=neox_args.dataset_impl, + pos_data_prefix=pos_test_path, + neg_data_prefix=neg_test_path, + pos_label_prefix=pos_test_label_path, + neg_label_prefix=neg_test_label_path, ) ) return train_datasets, valid_datasets, test_datasets diff --git a/megatron/data/pairwise_dataset.py b/megatron/data/pairwise_dataset.py new file mode 100644 index 000000000..b59218f08 --- /dev/null +++ b/megatron/data/pairwise_dataset.py @@ -0,0 +1,585 @@ +# Copyright (c) 2024, EleutherAI +# This file is based on code by the authors denoted below and has been modified from its original version. +# +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pairwise style dataset.""" + +import os +import time + +import numpy as np +import torch + +from megatron import mpu, print_rank_0 + + +class PairwiseDataset(torch.utils.data.Dataset): + def __init__( + self, + name, + pos_data_prefix, # Don't need neg since it's assumed you have paired the data already. 
+ documents, + pos_indexed_dataset, + neg_indexed_dataset, + num_samples, + seq_length, + seed, + pack_impl="unpacked", + build_index_mappings=True, + use_shared_fs=True, + pos_label_dataset=None, + pos_ref_dataset=None, + neg_label_dataset=None, + neg_ref_dataset=None, + allow_chopped=True, + ): + + self.name = name + self.pos_indexed_dataset = pos_indexed_dataset + self.pos_label_dataset = pos_label_dataset + self.pos_ref_dataset = pos_ref_dataset + self.neg_indexed_dataset = neg_indexed_dataset + self.neg_label_dataset = neg_label_dataset + self.neg_ref_dataset = neg_ref_dataset + self.pack_impl = pack_impl + self.seq_length = seq_length + # Checks + assert np.min(documents) >= 0 + assert (neg_label_dataset is not None and pos_label_dataset is not None) or ( + neg_label_dataset is None and pos_label_dataset is None + ), "Label datasets must be both None or both not None" + assert np.max(documents) < pos_indexed_dataset.sizes.shape[0] + assert pos_indexed_dataset.sizes.shape[0] == neg_indexed_dataset.sizes.shape[0] + assert ( + pack_impl != "packed" + ), "Packed implementation not supported for pairwise dataset" + + if build_index_mappings: + # Build index mappings. + self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings( + self.name, + pos_data_prefix, + documents, + self.pos_indexed_dataset.sizes, + self.neg_indexed_dataset.sizes, + self.pos_label_dataset, + self.neg_label_dataset, + num_samples, + seq_length, + seed, + pack_impl, + use_shared_fs=use_shared_fs, + allow_chopped=allow_chopped, + ) + self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1 + self.sample_idx_len = self.sample_idx.shape[0] - 1 + + if self.shuffle_idx_len != self.sample_idx_len - 1: + print( + f"WARNING: shuffle index length ({self.shuffle_idx_len}) is not equal to sample index length ({self.sample_idx_len})" + ) + + def __len__(self): + return min(self.shuffle_idx_len, self.sample_idx_len) + + def __getitem__(self, idx): + try: + # Get the shuffled index. 
+ idx = self.shuffle_idx[idx] + # Start and end documents and offsets. + doc_index_f = self.sample_idx[idx][0] + doc_index_l = self.sample_idx[idx + 1][0] + offset_f = self.sample_idx[idx][1] + offset_l = self.sample_idx[idx + 1][1] + # Labels and texts are supposed to be fully in sync. + datasets = ( + [self.pos_indexed_dataset, self.neg_indexed_dataset] + if self.pos_label_dataset is None + else [ + self.pos_indexed_dataset, + self.neg_indexed_dataset, + self.pos_label_dataset, + self.neg_label_dataset, + ] + ) + samples = [] + pos_ref_samples = [] + neg_ref_samples = [] + # If we are within the same document, just extract the chunk. + for n, dataset in enumerate(datasets): + if doc_index_f == doc_index_l: + samples.append( + dataset.get( + self.doc_idx[doc_index_f], + offset=offset_f, + length=offset_l - offset_f + 1, + ) + ) + if n == 0: + if self.pos_ref_dataset is not None: + pos_ref_samples.append( + self.pos_ref_dataset.get( + self.doc_idx[doc_index_f], + offset=offset_f, + length=offset_l - offset_f + 1, + ) + ) + neg_ref_samples.append( + self.neg_ref_dataset.get( + self.doc_idx[doc_index_f], + offset=offset_f, + length=offset_l - offset_f + 1, + ) + ) + + else: + # Otherwise, get the rest of the initial document. + sample_list = [ + dataset.get(self.doc_idx[doc_index_f], offset=offset_f) + ] + + if n == 0: + if self.pos_ref_dataset is not None: + pos_ref_sample_list = [ + self.pos_ref_dataset.get( + self.doc_idx[doc_index_f], + offset=offset_f, + ) + ] + neg_ref_sample_list = [ + self.neg_ref_dataset.get( + self.doc_idx[doc_index_f], + offset=offset_f, + ) + ] + # Loop over all in between documents and add the entire document. 
+ for i in range(doc_index_f + 1, doc_index_l): + sample_list.append(dataset.get(self.doc_idx[i])) + if n == 0: + if self.pos_ref_dataset is not None: + pos_ref_sample_list.append( + self.pos_ref_dataset.get( + self.doc_idx[i], + ) + ) + neg_ref_sample_list.append( + self.neg_ref_dataset.get( + self.doc_idx[i], + ) + ) + # And finally add the relevant portion of last document. + sample_list.append( + dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) + ) + samples.append(np.concatenate(sample_list)) + if n == 0: + if self.pos_ref_dataset is not None: + pos_ref_sample_list.append( + self.pos_ref_dataset.get( + self.doc_idx[doc_index_l], length=offset_l + 1 + ) + ) + pos_ref_samples.append(np.concatenate(pos_ref_sample_list)) + neg_ref_sample_list.append( + self.neg_ref_dataset.get( + self.doc_idx[doc_index_l], length=offset_l + 1 + ) + ) + neg_ref_samples.append(np.concatenate(neg_ref_sample_list)) + if self.pos_ref_dataset is not None: + if len(pos_ref_samples[0]) < (self.seq_length): + # Pad with 0s + pos_ref_samples[0] = np.pad( + pos_ref_samples[0], + (0, (self.seq_length) - len(pos_ref_samples[0])), + mode="constant", + constant_values=0, + ) + elif len(pos_ref_samples[0]) > (self.seq_length): + # Check for overflow and truncate. + pos_ref_samples[0] = pos_ref_samples[0][: (self.seq_length)] + if len(neg_ref_samples[0]) < (self.seq_length): + # Pad with 0s + neg_ref_samples[0] = np.pad( + neg_ref_samples[0], + (0, (self.seq_length) - len(neg_ref_samples[0])), + mode="constant", + constant_values=0, + ) + elif len(neg_ref_samples[0]) > (self.seq_length): + # Check for overflow and truncate. + neg_ref_samples[0] = neg_ref_samples[0][: (self.seq_length)] + if len(datasets) == 2: + # pos + if len(samples[0]) < (self.seq_length + 1): + # Pad with -100s so the masking function can ignore these. 
+ samples[0] = np.pad( + samples[0], + (0, (self.seq_length + 1) - len(samples[0])), + mode="constant", + constant_values=-100, + ) + elif len(samples[0]) > (self.seq_length + 1): + # Check for overflow and truncate. + samples[0] = samples[0][: (self.seq_length + 1)] + # neg + if len(samples[1]) < (self.seq_length + 1): + # Pad with -100s so the masking function can ignore these. + samples[1] = np.pad( + samples[1], + (0, (self.seq_length + 1) - len(samples[1])), + mode="constant", + constant_values=-100, + ) + elif len(samples[1]) > (self.seq_length + 1): + # Check for overflow and truncate. + samples[1] = samples[1][: (self.seq_length + 1)] + ret = { + "pos": np.array(samples[0], dtype=np.int64), + "neg": np.array(samples[1], dtype=np.int64), + } + if self.pos_ref_dataset is not None: + ret["pos_ref"] = np.array(pos_ref_samples[0], dtype=np.float32) + ret["neg_ref"] = np.array(neg_ref_samples[0], dtype=np.float32) + return ret + else: + # pos + if len(samples[0]) < (self.seq_length + 1): + # Pad with 0s, can use any number since it's masked. + samples[0] = np.pad( + samples[0], + (0, (self.seq_length + 1) - len(samples[0])), + mode="constant", + constant_values=0, + ) + # pad with -100s so we can mask it out + samples[2] = np.pad( + samples[2], + (0, (self.seq_length + 1) - len(samples[2])), + mode="constant", + constant_values=-100, + ) + elif len(samples[0]) > (self.seq_length + 1): + # Check for overflow and truncate. + samples[0] = samples[0][: (self.seq_length + 1)] + samples[2] = samples[2][: (self.seq_length + 1)] + # neg + if len(samples[1]) < (self.seq_length + 1): + # Pad with 0s, can use any number since it's masked. 
+ samples[1] = np.pad( + samples[1], + (0, (self.seq_length + 1) - len(samples[1])), + mode="constant", + constant_values=0, + ) + # pad with -100s so we can mask it out + samples[3] = np.pad( + samples[3], + (0, (self.seq_length + 1) - len(samples[3])), + mode="constant", + constant_values=-100, + ) + elif len(samples[1]) > (self.seq_length + 1): + # Check for overflow and truncate. + samples[1] = samples[1][: (self.seq_length + 1)] + samples[3] = samples[3][: (self.seq_length + 1)] + ret = { + "pos": np.array(samples[0], dtype=np.int64), + "neg": np.array(samples[1], dtype=np.int64), + "pos_label": np.array(samples[2], dtype=np.int64), + "neg_label": np.array(samples[3], dtype=np.int64), + } + if self.pos_ref_dataset is not None: + ret["pos_ref"] = np.array(pos_ref_samples[0], dtype=np.float32) + ret["neg_ref"] = np.array(neg_ref_samples[0], dtype=np.float32) + return ret + except IndexError: + new_idx = idx % len(self) + print( + f"WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx})" + ) + return self[new_idx] + + +def _build_index_mappings( + name, + pos_data_prefix, + documents, + pos_sizes, + neg_sizes, + pos_label_dataset, + neg_label_dataset, + num_samples, + seq_length, + seed, + packing_impl, + use_shared_fs=True, + allow_chopped=True, +): + """Build doc-idx, sample-idx, and shuffle-idx. + doc-idx: is an array (ordered) of documents to be used in training. + sample-idx: is the start document index and document offset for each + training sample. + shuffle-idx: maps the sample index into a random index into sample-idx. + """ + # Number of tokens in each epoch and number of required epochs. + tokens_per_epoch = _num_tokens(documents, pos_sizes) + num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) + # rng state + np_rng = np.random.RandomState(seed=seed) + + # Filename of the index mappings. 
+ _filename = pos_data_prefix + _filename += "_{}_indexmap".format(name) + _filename += "_{}ns".format(num_samples) + _filename += "_{}sl".format(seq_length) + _filename += "_{}s".format(seed) + _filename += "_{}pi".format(packing_impl) + doc_idx_filename = _filename + "_doc_idx.npy" + sample_idx_filename = _filename + "_sample_idx.npy" + shuffle_idx_filename = _filename + "_shuffle_idx.npy" + + if not use_shared_fs: + should_process_dataset = int(os.environ["LOCAL_RANK"]) == 0 + else: + should_process_dataset = torch.distributed.get_rank() == 0 + + # Build the indexed mapping if not exist. + if should_process_dataset: + if ( + (not os.path.isfile(doc_idx_filename)) + or (not os.path.isfile(sample_idx_filename)) + or (not os.path.isfile(shuffle_idx_filename)) + ): + print_rank_0( + " > WARNING: could not find index map files, building " + "the indices on rank 0 ..." + ) + # doc-idx. + start_time = time.time() + if packing_impl == "pack_until_overflow": + # Naively pack data until it overflows, then roll it over to a new one instead. + shuffle_idx = np.arange(num_samples) # Shuffle index around epochs + np_rng.shuffle(shuffle_idx) + sample_idx = [] + doc_idx = [] + # Iterate over files until we have enough samples. + temp_shuffle_idx = np.arange(len(documents)) + np_rng.shuffle(temp_shuffle_idx) + running_length = 0 + curr_shuffle_idx = 0 + while len(sample_idx) < num_samples: + # If not allow_chopped, skip this item if it's chopped. + if not allow_chopped: + if ( + pos_sizes[temp_shuffle_idx[curr_shuffle_idx]] + < seq_length + 1 + ): + curr_shuffle_idx += 1 + continue + if ( + neg_sizes[temp_shuffle_idx[curr_shuffle_idx]] + < seq_length + 1 + ): + curr_shuffle_idx += 1 + continue + # Then, check if we need to skip this item... 
+ if pos_label_dataset is not None: + if np.all( + pos_label_dataset.get(temp_shuffle_idx[curr_shuffle_idx])[ + : seq_length + 1 + ] + == -100 + ): + curr_shuffle_idx += 1 + continue + if np.all( + neg_label_dataset.get(temp_shuffle_idx[curr_shuffle_idx])[ + : seq_length + 1 + ] + == -100 + ): + curr_shuffle_idx += 1 + continue + doc_length = max( + pos_sizes[temp_shuffle_idx[curr_shuffle_idx]], + neg_sizes[temp_shuffle_idx[curr_shuffle_idx]], + ) + if running_length == 0: + sample_idx.append(np.array([len(doc_idx), 0])) + doc_idx.append(temp_shuffle_idx[curr_shuffle_idx]) + running_length += doc_length + else: + if running_length + doc_length > (seq_length + 1): + running_length = doc_length + sample_idx.append(np.array([len(doc_idx), 0])) + else: + running_length += doc_length + doc_idx.append(temp_shuffle_idx[curr_shuffle_idx]) + curr_shuffle_idx += 1 + if curr_shuffle_idx == len(documents): + curr_shuffle_idx = 0 + np_rng.shuffle(temp_shuffle_idx) + sample_idx.append(np.array([len(doc_idx), 0])) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + elif packing_impl == "unpacked": + # Unpacked data, one sample per document. + shuffle_idx = np.array([i % len(documents) for i in range(num_samples)]) + np_rng.shuffle(shuffle_idx) + sample_idx = np.zeros((num_samples + 1, 2), dtype=np.int64) + sample_idx[:, 0] = np.array([i for i in range(num_samples + 1)]) + sample_idx[:, 1] = 0 + doc_idx = list() + doc_i = 0 + while len(doc_idx) <= num_samples: + # Check if we need to skip this item... 
+ if not allow_chopped: + # +1 since we shift left/right by 1 + if pos_sizes[doc_i] > seq_length + 1: + doc_i = (doc_i + 1) % len(documents) + continue + if neg_sizes[doc_i] > seq_length + 1: + doc_i = (doc_i + 1) % len(documents) + continue + # In theory if we don't allow chopped we should be able to skip it, but the warm fuzzies I get + # from this are worth the extra bool check + if np.all(pos_label_dataset.get(doc_i)[:seq_length] == -100): + doc_i = (doc_i + 1) % len(documents) + continue + if np.all(neg_label_dataset.get(doc_i)[:seq_length] == -100): + doc_i = (doc_i + 1) % len(documents) + continue + doc_idx.append(doc_i) + doc_i = (doc_i + 1) % len(documents) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=mpu.get_io_parallel_group()) + assert counts[0].item() == torch.distributed.get_world_size( + group=mpu.get_io_parallel_group() + ) + + # Load mappings. 
+ start_time = time.time() + print_rank_0(" > loading doc-idx mapping from {}".format(doc_idx_filename)) + doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode="r") + print_rank_0(" > loading sample-idx mapping from {}".format(sample_idx_filename)) + sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode="r") + print_rank_0(" > loading shuffle-idx mapping from {}".format(shuffle_idx_filename)) + shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode="r") + print_rank_0( + " loaded indexed file in {:3.3f} seconds".format(time.time() - start_time) + ) + print_rank_0(" total number of samples: {}".format(sample_idx.shape[0])) + print_rank_0(" total number of epochs: {}".format(num_epochs)) + + return doc_idx, sample_idx, shuffle_idx + + +def _num_tokens(documents, sizes): + """Total number of tokens in the dataset.""" + return np.sum(sizes[documents]) + + +def _num_epochs(tokens_per_epoch, seq_length, num_samples): + """Based on number of samples and sequence length, calculate how many + epochs will be needed.""" + num_epochs = 0 + total_tokens = 0 + while True: + num_epochs += 1 + total_tokens += tokens_per_epoch + # -1 is because we need to retrieve seq_length + 1 token each time + # but the last token will overlap with the first token of the next + # sample except for the last sample. + if ((total_tokens - 1) // seq_length) >= num_samples: + return num_epochs + + +def _build_doc_idx(documents, num_epochs, np_rng): + """Build an array with length = number-of-epochs * number-of-documents. 
+ Each index is mapped to a corresponding document.""" + doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1] + doc_idx[:] = documents + doc_idx = doc_idx.reshape(-1) + doc_idx = doc_idx.astype(np.int32) + np_rng.shuffle(doc_idx) + return doc_idx + + +def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch): + """Sample index mapping is a 2D array with sizes + [number-of-samples + 1, 2] where [..., 0] contains + the index into `doc_idx` and [..., 1] is the + starting offset in that document.""" + + # Total number of samples. For -1 see comments in `_num_epochs`. + num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length + sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int64) + + # Index into sample_idx. + sample_index = 0 + # Index into doc_idx. + doc_idx_index = 0 + # Beginning offset for each document. + doc_offset = 0 + # Start with first document and no offset. + sample_idx[sample_index][0] = doc_idx_index + sample_idx[sample_index][1] = doc_offset + sample_index += 1 + while sample_index <= num_samples: + # Start with a fresh sequence. + remaining_seq_length = seq_length + 1 + while remaining_seq_length != 0: + # Get the document length. + doc_id = doc_idx[doc_idx_index] + doc_length = sizes[doc_id] - doc_offset + # And add it to the current sequence. + remaining_seq_length -= doc_length + # If we have more than a full sequence, adjust offset and set + # remaining length to zero so we return from the while loop. + # Note that -1 here is for the same reason we have -1 in + # `_num_epochs` calculations. + if remaining_seq_length <= 0: + doc_offset += remaining_seq_length + doc_length - 1 + remaining_seq_length = 0 + else: + # Otherwise, start from the beginning of the next document. + doc_idx_index += 1 + doc_offset = 0 + # Record the sequence. 
+ sample_idx[sample_index][0] = doc_idx_index + sample_idx[sample_index][1] = doc_offset + sample_index += 1 + + return sample_idx + + +def _build_shuffle_idx(size, np_rng): + """Build the range [0, size) and shuffle.""" + dtype_ = np.uint32 + if size >= (np.iinfo(np.uint32).max - 1): + dtype_ = np.int64 + shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx) + return shuffle_idx diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 6878c79eb..7b1a60d46 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -873,6 +873,42 @@ class NeoXArgsTraining(NeoXArgsTemplate): List of paths to validation label datasets (not shifted by 1 yet!). """ + pos_train_data_paths: list = None + neg_train_data_paths: list = None + """ + List of paths to positive and negative training datasets. + """ + + pos_train_label_data_paths: list = None + neg_train_label_data_paths: list = None + """ + List of paths to positive and negative training label datasets (not shifted by 1 yet!). + """ + + pos_valid_data_paths: list = None + neg_valid_data_paths: list = None + """ + List of paths to positive and negative validation datasets. + """ + + pos_valid_label_data_paths: list = None + neg_valid_label_data_paths: list = None + """ + List of paths to positive and negative validation label datasets (not shifted by 1 yet!). + """ + + pos_test_data_paths: list = None + neg_test_data_paths: list = None + """ + List of paths to positive and negative test datasets. + """ + + pos_test_label_data_paths: list = None + neg_test_label_data_paths: list = None + """ + List of paths to positive and negative test label datasets (not shifted by 1 yet!). + """ + train_data_weights: list = None """ List of 'weights' that decide how often to sample from each training dataset when blending datasets. If None, defaults to equal weighting. 
@@ -929,6 +965,26 @@ class NeoXArgsTraining(NeoXArgsTemplate): warning: pack_until_overflow is very naive and will likely have issues with pretraining scale datasets """ + dataset_impl: Literal["gpt2", "pairwise"] = "gpt2" + """ + Dataset implementation, can be one of "gpt2" or "pairwise" + """ + + train_impl: Literal["normal", "dpo"] = "normal" + """ + Training implementation, can be one of "normal" or "dpo" + """ + + dpo_fp32: bool = True + """ + Whether to cast logits to fp32 for DPO loss calculation. + """ + + dpo_beta: float = 0.1 + """ + Beta value for DPO + """ + allow_chopped: bool = True """ WARNING: if your packing impl is packed, this is ignored. diff --git a/megatron/training.py b/megatron/training.py index 482b4154f..3c6a6b506 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -21,12 +21,14 @@ """Pretrain utilities.""" from datetime import datetime from functools import partial +from collections import defaultdict import math import sys from contextlib import nullcontext import torch +import torch.nn.functional as F import deepspeed from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler import numpy as np @@ -44,6 +46,7 @@ SoftEmbedding, get_params_for_weight_decay_optimization, ) +from megatron.mpu.mappings import gather_from_model_parallel_region from megatron.checkpointing import load_checkpoint, save_checkpoint from megatron.data.data_utils import build_train_valid_test_data_iterators from megatron.initialize import initialize_megatron @@ -136,7 +139,7 @@ def gen(): old_hidden_size = neox_args.hidden_size neox_args.hidden_size = hidden_size - model, optimizer, _ = setup_model_and_optimizer( + model, optimizer, _, _ = setup_model_and_optimizer( neox_args=neox_args, use_cache=False ) @@ -192,7 +195,7 @@ def pretrain(neox_args): # Model, optimizer, and learning rate. 
timers("model and optimizer").start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(
+    model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
         neox_args=neox_args, use_cache=False, iteration=neox_args.iteration
     )
     timers("model and optimizer").stop()
@@ -230,6 +233,7 @@ def pretrain(neox_args):
         neox_args=neox_args,
         timers=timers,
         model=model,
+        reference_model=reference_model,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         train_data_iterator=train_data_iterator,
@@ -310,7 +314,14 @@ def get_batch(neox_args, data_iterator):
     """Generate a batch"""
 
     # Items and their type.
-    keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]
+    if neox_args.train_impl == "normal":
+        keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]
+    elif neox_args.train_impl == "dpo":
+        keys = (
+            [["pos", "pos_label"], ["neg", "neg_label"]]
+            if neox_args.pos_train_label_data_paths
+            else [["pos"], ["neg"]]
+        )
     datatype = torch.int64
 
     # Broadcast data. 
@@ -318,13 +329,33 @@ def get_batch(neox_args, data_iterator): data = next(data_iterator) else: data = None - return _get_batch( - neox_args=neox_args, - tokenizer=neox_args.tokenizer, - keys=keys, - data=data, - datatype=datatype, - ) + if neox_args.train_type == "normal": + return _get_batch( + neox_args=neox_args, + tokenizer=neox_args.tokenizer, + keys=keys, + data=data, + datatype=datatype, + ) + elif neox_args.train_type == "dpo": + pos_tup = _get_batch( + neox_args=neox_args, + tokenizer=neox_args.tokenizer, + keys=keys[0], + data=data, + datatype=datatype, + ) + neg_tup = _get_batch( + neox_args=neox_args, + tokenizer=neox_args.tokenizer, + keys=keys[1], + data=data, + datatype=datatype, + ) + return [ + torch.cat((pos_item, neg_item), dim=0) + for pos_item, neg_item in zip(pos_tup, neg_tup) + ] def get_batch_pipe(data, neox_args, curr_scheduler=None): @@ -418,8 +449,23 @@ def mb_moe_loss_func(args, loss_mask, output_tensor=None): return averaged_lbl, loss_dict +def get_pos_neg_logp(logits, labels, force_fp32=False): + if force_fp32: + logits = logits.float() + logp = logits.log_softmax(dim=-1) + per_token_logp = torch.gather(logp, dim=2, index=labels.unsqueeze(2)).squeeze(2) + # Split to pos/neg... 
+ return torch.chunk(per_token_logp, 2, 0) + + def forward_step( - data_iterator, model, neox_args, timers, return_logits=False, is_train=False + data_iterator, + model, + neox_args, + timers, + return_logits=False, + is_train=False, + reference_model=None, ): """Forward step.""" if neox_args.is_pipe_parallel: @@ -441,38 +487,97 @@ def forward_step( if neox_args.memory_profiling: torch.cuda.nvtx.range_push(f"Forward pass") - # Sequential returns moe_losses, but this is not yet supported by pipe parallel - maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args) - if type(maybe_tuple) is tuple: - outputs, moe_losses = maybe_tuple - else: - outputs = maybe_tuple - moe_losses = [] - if ( - is_train - and neox_args.curriculum_learning - and neox_args.curriculum_seqlen < neox_args.seq_length - ): - loss_mask = loss_mask[:, : neox_args.curriculum_seqlen].contiguous() - labels = labels[:, : neox_args.curriculum_seqlen].contiguous() - main_loss = cross_entropy( - outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy - ) - if neox_args.moe_num_experts > 1: - if neox_args.moe_type == "deepspeed": - moe_loss = neox_args.moe_loss_coeff * sum(m.item() for m in moe_losses) - elif neox_args.moe_type == "megablocks": - moe_loss = mb_moe_loss_func(neox_args, loss_mask, outputs)[0] + metrics = {} + if neox_args.train_impl == "normal": + # Sequential returns moe_losses, but this is not yet supported by pipe parallel + maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args) + if type(maybe_tuple) is tuple: + outputs, moe_losses = maybe_tuple else: - raise ValueError(f"Unsupported moe_type: {neox_args.moe_type}") - else: - moe_loss = 0.0 - loss = main_loss + moe_loss + outputs = maybe_tuple + moe_losses = [] + if ( + is_train + and neox_args.curriculum_learning + and neox_args.curriculum_seqlen < neox_args.seq_length + ): + loss_mask = loss_mask[:, : neox_args.curriculum_seqlen].contiguous() + labels = labels[:, : 
neox_args.curriculum_seqlen].contiguous() + main_loss = cross_entropy( + outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy + ) + if neox_args.moe_num_experts > 1: + if neox_args.moe_type == "deepspeed": + moe_loss = neox_args.moe_loss_coeff * sum(m.item() for m in moe_losses) + elif neox_args.moe_type == "megablocks": + moe_loss = mb_moe_loss_func(neox_args, loss_mask, outputs)[0] + else: + raise ValueError(f"Unsupported moe_type: {neox_args.moe_type}") + else: + moe_loss = 0.0 + loss = main_loss + moe_loss + elif neox_args.train_type == "dpo": + # Based on https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90 + with torch.no_grad(): + # So we can gather token logps... + token_logp_labels = labels.clone() + token_logp_labels[token_logp_labels == -100] = 0 + pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0) + ref_maybe_tuple = reference_model( + (tokens, position_ids, attention_mask), neox_args=neox_args + ) + if type(ref_maybe_tuple) is tuple: + # We should ignore MoE losses yeah? + ref_outputs, _ = ref_maybe_tuple + else: + ref_outputs = ref_maybe_tuple + # gather across tensor parallel group + ref_outputs = gather_from_model_parallel_region(ref_outputs) + ref_pos, ref_neg = get_pos_neg_logp( + ref_outputs, token_logp_labels, neox_args.dpo_fp32 + ) + ref_pos = (ref_pos * pos_loss_mask).sum(-1) + ref_neg = (ref_neg * neg_loss_mask).sum(-1) + chosen_maybe_tuple = model( + (tokens, position_ids, attention_mask), neox_args=neox_args + ) + if type(chosen_maybe_tuple) is tuple: + # We should ignore MoE losses yeah? 
+ chosen_outputs, _ = chosen_maybe_tuple + else: + chosen_outputs = chosen_maybe_tuple + chosen_outputs = gather_from_model_parallel_region(chosen_outputs) + chosen_pos, chosen_neg = get_pos_neg_logp( + chosen_outputs, token_logp_labels, neox_args.dpo_fp32 + ) + chosen_pos = (chosen_pos * pos_loss_mask).sum(-1) + chosen_neg = (chosen_neg * neg_loss_mask).sum(-1) + with torch.no_grad(): + # Collect metrics... + metrics["ref_neg"] = ref_neg.clone().detach().mean() + metrics["ref_pos"] = ref_pos.clone().detach().mean() + metrics["chosen_neg"] = chosen_neg.clone().detach().mean() + metrics["chosen_pos"] = chosen_pos.clone().detach().mean() + chosen_rewards = neox_args.dpo_beta * ( + chosen_pos.clone().detach() - ref_pos.clone().detach() + ) + rejected_rewards = neox_args.dpo_beta * ( + chosen_neg.clone().detach() - ref_neg.clone().detach() + ) + reward_acc = (chosen_rewards > rejected_rewards).float() + metrics["reward_acc"] = reward_acc.mean() + metrics["chosen_rewards"] = chosen_rewards.mean() + metrics["rejected_rewards"] = rejected_rewards.mean() + metrics["margins"] = (chosen_rewards - rejected_rewards).mean() + pi_logrations = chosen_pos - chosen_neg + ref_logrations = ref_pos - ref_neg + logits = pi_logrations - ref_logrations + loss = -F.logsigmoid(neox_args.dpo_beta * logits).mean() if neox_args.memory_profiling: torch.cuda.nvtx.range_pop() if return_logits: - return loss, outputs - return loss + return loss, outputs, metrics + return loss, metrics def get_model(neox_args, use_cache=False): @@ -547,7 +652,7 @@ def get_model(neox_args, use_cache=False): raise ValueError("Must be using deepspeed to run neox") -def get_optimizer(model, neox_args): +def get_optimizer(model, neox_args, dummy=False): """Set up the optimizer.""" if neox_args.no_load_optim: return None, None @@ -583,8 +688,13 @@ def get_optimizer(model, neox_args): _param_groups = [] for param_group in param_groups: trainable_params = [p for p in param_group["params"] if p.requires_grad] + if dummy: + 
trainable_params = [trainable_params[0]] # just take the first one param_group["params"] = trainable_params _param_groups.append(param_group) + if dummy: + # Only need one. + break param_groups = _param_groups # If we're using mup, then the optimizer must be adam or sgd @@ -743,10 +853,24 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): ) """Setup model and optimizer.""" + needs_reference_model = neox_args.train_type == "dpo" model = get_model(neox_args=neox_args, use_cache=use_cache) + if needs_reference_model: + reference_model = get_model(neox_args=neox_args, use_cache=use_cache) + else: + reference_model = None optimizer, param_groups = get_optimizer(model=model, neox_args=neox_args) lr_scheduler = get_learning_rate_scheduler(optimizer=optimizer, neox_args=neox_args) - + if neox_args.deepspeed and needs_reference_model: + # Need an optimizer & lr_scheduler so make a very small one to keep deepspeed happy... + ref_optimizer, ref_param_groups = get_optimizer( + model=reference_model, neox_args=neox_args, dummy=True + ) + ref_lr_scheduler = get_learning_rate_scheduler( + optimizer=ref_optimizer, neox_args=neox_args + ) + else: + ref_optimizer, ref_param_groups, ref_lr_scheduler = None, None, None if neox_args.deepspeed: print_rank_0("DeepSpeed is enabled.") if neox_args.no_load_optim: @@ -768,6 +892,16 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): # config_params=neox_args.deepspeed_config, mpu=mpu if not neox_args.is_pipe_parallel else None, ) + if needs_reference_model: + reference_model, _, _, _ = deepspeed.initialize( + model=reference_model, + optimizer=ref_optimizer, + args=neox_args, + lr_scheduler=ref_lr_scheduler, + dist_init_required=False, + model_parameters=ref_param_groups, + mpu=mpu if not neox_args.is_pipe_parallel else None, + ) if neox_args.moe_num_experts > 1 and neox_args.moe_type == "megablocks": # We need to additionally set this flag to ensure DS parallelism properly handles this 
foreign MoE. model.has_moe_layers = True @@ -799,10 +933,19 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): neox_args.iteration = load_checkpoint( neox_args=neox_args, model=model, + reference_model=reference_model, optimizer=optimizer, lr_scheduler=lr_scheduler, iteration=iteration, ) + if needs_reference_model: + _ = load_checkpoint( + neox_args=neox_args, + model=reference_model, + optimizer=ref_optimizer, + lr_scheduler=ref_lr_scheduler, + iteration=iteration, + ) print_rank_0( f"Loading checkpoint and starting from iteration {neox_args.iteration}" ) @@ -814,7 +957,7 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): if lr_scheduler is not None: lr_scheduler.optimizer = model.optimizer - return model, optimizer, lr_scheduler + return model, optimizer, lr_scheduler, reference_model def backward_step(neox_args, timers, optimizer, model, loss): @@ -836,7 +979,15 @@ def backward_step(neox_args, timers, optimizer, model, loss): raise ValueError("Must be using deepspeed to run neox") -def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler): +def train_step( + neox_args, + timers, + data_iterator, + model, + optimizer, + lr_scheduler, + reference_model=None, +): """Single training step.""" # Pipeline parallelism schedules forward/backward/step @@ -844,6 +995,7 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) reduced_loss = train_step_pipe( neox_args=neox_args, timers=timers, model=model, data_iterator=data_iterator ) + reduced_metrics = {"lm_loss": reduced_loss} if ( neox_args.memory_profiling and neox_args.iteration >= neox_args.profile_step_start @@ -853,18 +1005,22 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) save_snapshot(neox_args) else: losses = [] + metric_dicts = defaultdict(list) for _ in range(neox_args.gradient_accumulation_steps): # Forward model for one step. 
timers("forward").start() - loss = forward_step( + loss, metric_dict = forward_step( neox_args=neox_args, timers=timers, data_iterator=data_iterator, model=model, is_train=True, + reference_model=reference_model, ) timers("forward").stop() losses.append(loss) + for key in metric_dict.keys(): + metric_dicts[key].append(metric_dict[key]) # Calculate gradients, reduce across processes, and clip. if ( neox_args.profile @@ -913,17 +1069,20 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) and torch.distributed.get_rank() == 0 ): save_snapshot(neox_args) - reduced_loss = { - "lm_loss": reduce_losses(losses).mean() - } # reduces losses across machines for logging + # reduces metrics across machines for logging + reduce_metrics = { + key: reduce_losses([metric_dicts[key]]).mean() + for key in metric_dicts.keys() + } + reduce_metrics["lm_loss"] = reduce_losses(losses).mean() if neox_args.precision == "fp16" and model.optimizer.overflow: skipped_iter = 1 else: skipped_iter = 0 - collect_loss_for_unit_test(reduced_loss["lm_loss"]) - return reduced_loss, skipped_iter + collect_loss_for_unit_test(reduce_metrics["lm_loss"]) + return reduce_metrics, skipped_iter def train_step_pipe(neox_args, timers, model, data_iterator): @@ -949,6 +1108,7 @@ def train( neox_args, timers, model, + reference_model, optimizer, lr_scheduler, train_data_iterator, @@ -1004,6 +1164,7 @@ def train( model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, + reference_model=reference_model, ) if neox_args.profile and iteration == neox_args.profile_step_stop: torch.cuda.cudart().cudaProfilerStop() @@ -1094,6 +1255,7 @@ def evaluate( # Turn on evaluation mode which disables dropout. 
model.eval() losses = [] + metric_dicts = defaultdict(list) if neox_args.char_level_ppl: data_iterator = CharCounter(data_iterator, neox_args.tokenizer) @@ -1115,14 +1277,15 @@ def evaluate( else neox_args.gradient_accumulation_steps ): # Forward evaluation - loss = forward_step_fn( + loss, metric_dict = forward_step_fn( model=model, data_iterator=data_iterator, neox_args=neox_args, timers=timers, ) losses.append(loss) - + for key in metric_dict.keys(): + metric_dicts[key].append(metric_dict[key]) # When contiguous memory optimizations are enabled, the buffers # allocated by the optimizations are deallocated during backward pass # in the absence of backward pass the buffers should be reset after each @@ -1132,6 +1295,8 @@ def evaluate( # reduces losses across processes for logging & run eval harness tasks eval_results = {"lm_loss": reduce_losses(losses).mean().item()} + for key in metric_dicts.keys(): + eval_results[key] = reduce_losses(metric_dicts[key]).mean().item() eval_results["lm_loss_ppl"] = math.exp(eval_results["lm_loss"]) if neox_args.char_level_ppl: diff --git a/megatron/utils.py b/megatron/utils.py index 26b4439bd..a64a8ba6c 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -449,7 +449,7 @@ def setup_for_inference_or_eval(use_cache=True, overwrite_values=None, input_arg initialize_megatron(neox_args) # set up model and load checkpoint. 
- model, _, _ = setup_model_and_optimizer( + model, _, _, _ = setup_model_and_optimizer( neox_args=neox_args, use_cache=use_cache, iteration=neox_args.iteration, From c0450064aff6ac15d59b2937d02a64b8801e7b08 Mon Sep 17 00:00:00 2001 From: dmahan93 <44207705+dmahan93@users.noreply.github.com> Date: Tue, 25 Jun 2024 10:07:47 -0500 Subject: [PATCH 06/28] Update arguments.py to use train_label_data_paths instead of label_data_paths --- megatron/neox_arguments/arguments.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index 9cad02c43..770ec50b4 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -1115,9 +1115,9 @@ def calculate_derived(self): if self.test_data_paths and (self.test_data_weights is None): self.test_data_weights = [1.0] * len(self.test_data_paths) - if self.label_data_paths: + if self.train_label_data_paths: err_str = ( - "Must use `label_data_paths` with `train_data_paths`, not `data_path`" + "Must use `train_label_data_paths` with `train_data_paths`, not `data_path`" ) assert self.train_data_paths and not self.data_path, err_str From 03920805a5a3af3a695c3e6fc2b33478f4d88d75 Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Tue, 25 Jun 2024 12:08:27 -0500 Subject: [PATCH 07/28] - Bugfixes from upstreaming.... 
--- megatron/data/data_utils.py | 15 +++++++++------ megatron/neox_arguments/arguments.py | 18 ++++++++++++------ megatron/training.py | 18 ++++++++---------- 3 files changed, 29 insertions(+), 22 deletions(-) diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index 2c548077d..58e7953a1 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -100,7 +100,10 @@ def build_the_dataset( neg_label_dataset = make_indexed_dataset( neg_label_prefix, data_impl, skip_warmup ) - if pos_ref_prefix is not None: + if pos_ref_prefix is None: + pos_ref_dataset = None + neg_ref_dataset = None + else: pos_ref_dataset = make_indexed_dataset( pos_ref_prefix, data_impl, skip_warmup ) @@ -303,7 +306,7 @@ def build_weighted_datasets( neox_args.valid_label_data_paths if neox_args.valid_label_data_paths else [], - neox_args.test_data_paths if neox_args.pos_train_data_paths else [], + neox_args.test_data_paths if neox_args.test_data_paths else [], neox_args.test_label_data_paths if neox_args.test_label_data_paths else [], neox_args.pos_train_data_paths if neox_args.pos_train_data_paths else [], neox_args.neg_train_data_paths if neox_args.neg_train_data_paths else [], @@ -331,7 +334,7 @@ def build_weighted_datasets( else [], ) ): - if train_path: + if train_path or pos_train_path: train_datasets.append( build_the_dataset( data_prefix=train_path, @@ -353,7 +356,7 @@ def build_weighted_datasets( ) ) - if valid_path: + if valid_path or pos_valid_path: valid_datasets.append( build_the_dataset( data_prefix=valid_path, @@ -375,7 +378,7 @@ def build_weighted_datasets( ) ) - if test_path: + if test_path or pos_test_path: test_datasets.append( build_the_dataset( data_prefix=test_path, @@ -465,7 +468,7 @@ def build_train_valid_test_data_iterators(neox_args): test_iters * neox_args.train_batch_size, ] - if neox_args.train_data_paths: + if (neox_args.train_data_paths) or (neox_args.pos_train_data_paths): # when individual train / valid / test data paths are 
provided # normalize weight values and get num samples for each dataset train_weights, train_num_samples = get_normalized_weights_and_num_samples( diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index 770ec50b4..a89aa04a6 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -794,7 +794,9 @@ def calculate_batch_parameters( # either none of the three parameters are provided or just gradient_accumulation_step is provided else: - assert False, "Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided" + assert ( + False + ), "Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided" return int(train_batch), int(micro_batch), int(grad_acc) @staticmethod @@ -1098,8 +1100,8 @@ def calculate_derived(self): if "flash" in self.attention_config: _flash_version = packaging.version.Version(version("flash-attn")) if self.sliding_window_width is not None: - assert ( - _flash_version >= packaging.version.Version("2.3.0") + assert _flash_version >= packaging.version.Version( + "2.3.0" ), f"Flash-Attention version ({str(_flash_version)}) must be >= 2.3.0 to support sliding window attention." 
if self.pos_emb == "alibi": if not _flash_version >= packaging.version.Version("2.4.0.post1"): @@ -1110,15 +1112,19 @@ def calculate_derived(self): # Adding equal dataset weights if none are provided if self.train_data_paths and (self.train_data_weights is None): self.train_data_weights = [1.0] * len(self.train_data_paths) + elif self.pos_train_data_paths and (self.train_data_weights is None): + self.train_data_weights = [1.0] * len(self.pos_train_data_paths) if self.valid_data_paths and (self.valid_data_weights is None): self.valid_data_weights = [1.0] * len(self.valid_data_paths) + elif self.pos_valid_data_paths and (self.valid_data_weights is None): + self.valid_data_weights = [1.0] * len(self.pos_valid_data_paths) if self.test_data_paths and (self.test_data_weights is None): self.test_data_weights = [1.0] * len(self.test_data_paths) + elif self.pos_test_data_paths and (self.test_data_weights is None): + self.test_data_weights = [1.0] * len(self.pos_test_data_paths) if self.train_label_data_paths: - err_str = ( - "Must use `train_label_data_paths` with `train_data_paths`, not `data_path`" - ) + err_str = "Must use `train_label_data_paths` with `train_data_paths`, not `data_path`" assert self.train_data_paths and not self.data_path, err_str # if a sample input file is provided, default text_gen_type type to input-file diff --git a/megatron/training.py b/megatron/training.py index 3c6a6b506..b578c4ad9 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -285,12 +285,12 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype): label_key = keys[1] if len(keys) > 1 else None # Unpack. 
tokens_ = data_b[token_key].long() - if "label" in data_b: + if label_key in data_b: label_mask = (data_b[label_key].long() >= 0)[:, 1:].contiguous() labels = torch.where( data_b[label_key].long() >= 0, data_b[label_key].long(), - torch.zeros_like(data_b["label"].long()), + torch.zeros_like(data_b[label_key].long()), )[:, 1:].contiguous() else: label_mask = (tokens_.long() >= 0)[:, 1:].contiguous() @@ -319,7 +319,7 @@ def get_batch(neox_args, data_iterator): elif neox_args.train_impl == "dpo": keys = ( [["pos", "pos_label"], ["neg", "neg_label"]] - if neox_args.pos_label_data_paths + if neox_args.pos_train_label_data_paths else [["pos"], ["neg"]] ) datatype = torch.int64 @@ -329,7 +329,7 @@ def get_batch(neox_args, data_iterator): data = next(data_iterator) else: data = None - if neox_args.train_type == "normal": + if neox_args.train_impl == "normal": return _get_batch( neox_args=neox_args, tokenizer=neox_args.tokenizer, @@ -337,7 +337,7 @@ def get_batch(neox_args, data_iterator): data=data, datatype=datatype, ) - elif neox_args.train_type == "dpo": + elif neox_args.train_impl == "dpo": pos_tup = _get_batch( neox_args=neox_args, tokenizer=neox_args.tokenizer, @@ -516,7 +516,7 @@ def forward_step( else: moe_loss = 0.0 loss = main_loss + moe_loss - elif neox_args.train_type == "dpo": + elif neox_args.train_impl == "dpo": # Based on https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90 with torch.no_grad(): # So we can gather token logps... 
@@ -853,7 +853,7 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): ) """Setup model and optimizer.""" - needs_reference_model = neox_args.train_type == "dpo" + needs_reference_model = neox_args.train_impl == "dpo" model = get_model(neox_args=neox_args, use_cache=use_cache) if needs_reference_model: reference_model = get_model(neox_args=neox_args, use_cache=use_cache) @@ -933,7 +933,6 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): neox_args.iteration = load_checkpoint( neox_args=neox_args, model=model, - reference_model=reference_model, optimizer=optimizer, lr_scheduler=lr_scheduler, iteration=iteration, @@ -1071,8 +1070,7 @@ def train_step( save_snapshot(neox_args) # reduces metrics across machines for logging reduce_metrics = { - key: reduce_losses([metric_dicts[key]]).mean() - for key in metric_dicts.keys() + key: reduce_losses(metric_dicts[key]).mean() for key in metric_dicts.keys() } reduce_metrics["lm_loss"] = reduce_losses(losses).mean() From 361f4597b61c92ca34cda76cf93bbd6ed69ddcea Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Tue, 25 Jun 2024 15:30:29 -0500 Subject: [PATCH 08/28] - add precompute logprobs... 
--- generate.py | 3 + megatron/neox_arguments/neox_args.py | 7 +- megatron/text_generation_utils.py | 192 ++++++++++++++++++++++++++- megatron/training.py | 18 +-- 4 files changed, 207 insertions(+), 13 deletions(-) diff --git a/generate.py b/generate.py index 743e350d0..e19ef2e0e 100755 --- a/generate.py +++ b/generate.py @@ -23,6 +23,7 @@ generate_samples_from_prompt, generate_samples_unconditional, generate_samples_interactive, + precompute_logits, ) @@ -83,6 +84,8 @@ def main(input_args=None, overwrite_values=None): top_p=neox_args.top_p, ) + elif neox_args.text_gen_type == "precompute": + precompute_logits(neox_args=neox_args, model=model) else: raise ValueError( f"`text_gen_type` either not specified or not recognised: {neox_args.text_gen_type}" diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 7b1a60d46..3ce8b881a 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -1281,7 +1281,12 @@ class NeoXArgsTextgen(NeoXArgsTemplate): text_gen_type: str = None """ How to generate text/sample the model. 
- Options: `unconditional`, `input-file`, `interactive` + Options: `unconditional`, `input-file`, `interactive`, `precompute` + """ + + precompute_model_name: str = None + """ + Model name to use for saving precomputed logprobs """ temperature: float = 0.0 diff --git a/megatron/text_generation_utils.py b/megatron/text_generation_utils.py index 7b7a390ab..02926c2c3 100644 --- a/megatron/text_generation_utils.py +++ b/megatron/text_generation_utils.py @@ -23,12 +23,15 @@ import time from typing import List, Union +import numpy as np import torch import torch.nn.functional as F from megatron import print_rank_0 from megatron import mpu from megatron.utils import get_ltor_masks_and_position_ids, is_mp_rank_0 +from megatron.data.indexed_dataset import make_builder, make_dataset +from megatron.mpu.mappings import gather_from_model_parallel_region def get_batch(neox_args, context_tokens: torch.Tensor): @@ -52,7 +55,9 @@ def get_batch(neox_args, context_tokens: torch.Tensor): return tokens, attention_mask, position_ids -def pad_batch(context_tokens: List[List[int]], pad_id: int, pad_len: int): +def pad_batch( + context_tokens: List[List[int]], pad_id: int, pad_len: int, truncate: bool = False +): """ pads context lengths in context_tokens with pad_id to equal neox_args.seq_length, and returns the padded batch and the new lengths. 
@@ -60,17 +65,21 @@ def pad_batch(context_tokens: List[List[int]], pad_id: int, pad_len: int): context_tokens: list of lists of tokens pad_id: int, integer to use as padding token pad_len: int, context length to be padded; all batch items will be padded to the same length + truncate: bool, if True, truncate context tokens to pad_len if they are longer than pad_len returns: tuple of padded context tokens and a list of unpadded token count """ context_lengths = [] - for tokens in context_tokens: + for i, tokens in enumerate(context_tokens): context_length = len(tokens) if context_length < pad_len: tokens.extend([pad_id] * (pad_len - context_length)) elif context_length > pad_len: - raise ValueError("context_length is bigger than to be padded length") + if not truncate: + raise ValueError("context_length is bigger than to be padded length") + context_tokens[i] = tokens[:pad_len] + context_length = pad_len context_lengths.append(context_length) return context_tokens, context_lengths @@ -807,3 +816,180 @@ def generate_samples_interactive( print_rank_0("Generated Text: " + generated_text) if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: _ = input("\n") + + +def get_logp(logits, labels, force_fp32=False): + if force_fp32: + logits = logits.float() + logp = logits.log_softmax(dim=-1) + return torch.gather(logp, dim=2, index=labels.unsqueeze(2)).squeeze(2) + + +def precompute_logits(neox_args, model): + """ + Precomputes logprobs from training/testing/validation datasets + + Saves it to the same directory as the dataset with the model name appended to it + + neox_args: NeoXArgs. 
+ model: a Megatron model + + """ + if neox_args.precompute_model_name is None: + mdl_name = str(hash(neox_args.load)) + else: + mdl_name = neox_args.precompute_model_name + print_rank_0("Precomputing logprobs...") + model.eval() + data_paths = list() + if neox_args.train_data_paths is not None: + for path in neox_args.train_data_paths: + data_paths.append(path) + for path in neox_args.test_data_paths: + data_paths.append(path) + for path in neox_args.valid_data_paths: + data_paths.append(path) + elif neox_args.pos_train_data_paths is not None: + # Pairwise data... + for path in neox_args.pos_train_data_paths: + data_paths.append(path) + for path in neox_args.neg_train_data_paths: + data_paths.append(path) + for path in neox_args.pos_valid_data_paths: + data_paths.append(path) + for path in neox_args.neg_valid_data_paths: + data_paths.append(path) + for path in neox_args.pos_test_data_paths: + data_paths.append(path) + for path in neox_args.neg_test_data_paths: + data_paths.append(path) + for path in data_paths: + print_rank_0(f"Precomputing logits for {path}") + # Add hash to path... 
+ out_path = path + f"_{mdl_name}" + if os.path.exists(out_path + ".idx"): + continue + dataset = make_dataset(path, neox_args.data_impl, not neox_args.mmap_warmup) + if is_mp_rank_0(): + out_dataset = make_builder(out_path + ".bin", neox_args.data_impl) + out_dataset._dtype = np.float32 + i = 0 + while i < len(dataset): + start = time.time() + model.module.clear_cache() # clear kv cache between batches + if is_mp_rank_0(): + offset = ( + mpu.get_data_parallel_rank() + * neox_args.train_micro_batch_size_per_gpu + ) + context_tokens = [ + [int(x) for x in dataset.get(j % len(dataset)).tolist()] + for j in range( + i + offset, + i + (neox_args.train_micro_batch_size_per_gpu + offset), + ) + ] + # grab microbatch + # pad batch in order to allow conversion to tensor + context_tokens, context_lengths = pad_batch( + copy.deepcopy(context_tokens), + pad_id=0, + pad_len=neox_args.seq_length + 1, + truncate=True, + ) + # print(context_tokens) + label_tokens = [tokens[1:] for tokens in context_tokens] + context_tokens = [tokens[:-1] for tokens in context_tokens] + else: + context_tokens = [ + [0 for _ in range(neox_args.seq_length)] + for _ in range(neox_args.batch_size) + ] + label_tokens = [ + [0 for _ in range(neox_args.seq_length)] + for _ in range(neox_args.batch_size) + ] + context_lengths = [0 for _ in range(neox_args.batch_size)] + i += ( + neox_args.train_micro_batch_size_per_gpu + * mpu.get_data_parallel_world_size() + ) + # print(context_tokens) + # convert to tensor and broadcast + context_tokens = torch.cuda.LongTensor(context_tokens) + label_tokens = torch.cuda.LongTensor(label_tokens) + # Make sure context tokens + start tokens are the same across all ranks + token_generation_start_index = torch.cuda.LongTensor(context_lengths) + torch.distributed.broadcast( + context_tokens, + mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group(), + ) + torch.distributed.broadcast( + token_generation_start_index, + mpu.get_model_parallel_src_rank(), + 
group=mpu.get_model_parallel_group(), + ) + torch.distributed.broadcast( + label_tokens, + mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group(), + ) + # context_tokens = context_tokens[:, :chop_len].contiguous() + # label_tokens = label_tokens[:, :chop_len].contiguous() + with torch.no_grad(): + # get attention mask / position ids + context_tokens, attention_mask, position_ids = get_batch( + neox_args, context_tokens + ) + model_inputs = ( + context_tokens, + position_ids, + attention_mask, + ) + maybe_tuple = forward_model( + model, model_inputs, neox_args.is_pipe_parallel + ) + if isinstance(maybe_tuple, tuple): + logits, _ = maybe_tuple + else: + logits = maybe_tuple + if logits is not None: # if pipe parallel, not all ranks return logits + logits = gather_from_model_parallel_region(logits) + logp = get_logp(logits, label_tokens, True).squeeze() + if neox_args.is_pipe_parallel: + # broadcast generated tokens to pipe parallel group + src_rank = model.grid.stage_to_global(model.num_stages - 1) + logp = ( + logp + if logits is not None + else torch.zeros( + neox_args.batch_size, dtype=torch.float32 + ).cuda() + ) + torch.distributed.broadcast( + tensor=logp, + src=src_rank, + group=mpu.get_pipe_parallel_group(), + ) + logp = logp.squeeze() + logp_list = [ + torch.zeros_like(logp) + for _ in range(mpu.get_data_parallel_world_size()) + ] + torch.distributed.all_gather( + logp_list, logp, group=mpu.get_data_parallel_group() + ) + logp = torch.cat(logp_list, dim=0).cpu().numpy() + if (mpu.get_model_parallel_rank() == 0) and ( + mpu.get_data_parallel_rank() == 0 + ): + for j in range(logp.shape[0]): + out_dataset.add_item(logp[j]) + out_dataset.end_document() + print_rank_0(f"Processed {i} / {len(dataset)} in {time.time() - start}") + if is_mp_rank_0(): + out_dataset.finalize( + out_path + ".idx", + ) + torch.distributed.barrier() diff --git a/megatron/training.py b/megatron/training.py index b578c4ad9..ed01996e5 100644 --- a/megatron/training.py 
+++ b/megatron/training.py @@ -654,7 +654,12 @@ def get_model(neox_args, use_cache=False): def get_optimizer(model, neox_args, dummy=False): """Set up the optimizer.""" - if neox_args.no_load_optim: + if neox_args.no_load_optim and neox_args.deepspeed: + # Required to have something so... + dummy = True + neox_args.optimizer = {"params": {"lr": 0.0}} + neox_args.optimizer_type = "adam" + elif neox_args.no_load_optim: return None, None if neox_args.optimizer is None: @@ -808,7 +813,7 @@ def get_optimizer(model, neox_args, dummy=False): def get_learning_rate_scheduler(optimizer, neox_args): """Build the learning rate scheduler.""" - if neox_args.no_load_optim: + if (neox_args.no_load_optim) and not neox_args.deepspeed: # TODO: this should be configured as a separate arg return None if neox_args.deepspeed and neox_args.optimizer_type.lower() == "onebitadam": @@ -873,13 +878,8 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): ref_optimizer, ref_param_groups, ref_lr_scheduler = None, None, None if neox_args.deepspeed: print_rank_0("DeepSpeed is enabled.") - if neox_args.no_load_optim: - assert optimizer is None - _model_params = None - _lr_scheduler = None - else: - _model_params = param_groups if optimizer is None else None - _lr_scheduler = lr_scheduler + _model_params = param_groups if optimizer is None else None + _lr_scheduler = lr_scheduler model, optimizer, _, lr_scheduler = deepspeed.initialize( model=model, From 7398e072ab5577e465df8c38b5ea247dbfe376da Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Tue, 25 Jun 2024 19:24:33 -0500 Subject: [PATCH 09/28] - Finishing up precompute logprobs... 
--- megatron/data/data_utils.py | 20 ++- megatron/data/pairwise_dataset.py | 212 ++++++------------------------ megatron/training.py | 52 +++++--- 3 files changed, 91 insertions(+), 193 deletions(-) diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index 58e7953a1..edf48b066 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -69,8 +69,7 @@ def build_the_dataset( label_prefix=None, pos_label_prefix=None, neg_label_prefix=None, - pos_ref_prefix=None, - neg_ref_prefix=None, + precompute_model_name=None, ): """Build train/valid/test datasets.""" if dataset_impl == "gpt2": @@ -79,6 +78,12 @@ def build_the_dataset( label_dataset = None else: label_dataset = make_indexed_dataset(label_prefix, data_impl, skip_warmup) + if precompute_model_name is not None: + # If we have the name, assume it exists. If it doesn't, it will just be None which is fine. + precompute_indexed_dataset = make_indexed_dataset( + data_prefix + "_" + precompute_model_name, data_impl, skip_warmup + ) + precompute_indexed_dataset = precompute_indexed_dataset elif dataset_impl == "pairwise": pos_indexed_dataset = make_indexed_dataset( pos_data_prefix, data_impl, skip_warmup @@ -100,17 +105,15 @@ def build_the_dataset( neg_label_dataset = make_indexed_dataset( neg_label_prefix, data_impl, skip_warmup ) - if pos_ref_prefix is None: + if precompute_model_name is None: pos_ref_dataset = None neg_ref_dataset = None else: pos_ref_dataset = make_indexed_dataset( - pos_ref_prefix, data_impl, skip_warmup + pos_data_prefix + "_" + precompute_model_name, data_impl, skip_warmup ) - # Also do neg here since they both must be the same - assert neg_ref_prefix is not None neg_ref_dataset = make_indexed_dataset( - neg_ref_prefix, data_impl, skip_warmup + neg_data_prefix + "_" + precompute_model_name, data_impl, skip_warmup ) else: raise NotImplementedError(f"dataset_impl={dataset_impl} not implemented") @@ -353,6 +356,7 @@ def build_weighted_datasets( 
neg_data_prefix=neg_train_path, pos_label_prefix=pos_train_label_path, neg_label_prefix=neg_train_label_path, + precompute_model_name=neox_args.precompute_model_name, ) ) @@ -375,6 +379,7 @@ def build_weighted_datasets( neg_data_prefix=neg_valid_path, pos_label_prefix=pos_valid_label_path, neg_label_prefix=neg_valid_label_path, + precompute_model_name=neox_args.precompute_model_name, ) ) @@ -397,6 +402,7 @@ def build_weighted_datasets( neg_data_prefix=neg_test_path, pos_label_prefix=pos_test_label_path, neg_label_prefix=neg_test_label_path, + precompute_model_name=neox_args.precompute_model_name, ) ) return train_datasets, valid_datasets, test_datasets diff --git a/megatron/data/pairwise_dataset.py b/megatron/data/pairwise_dataset.py index b59218f08..e39b4d626 100644 --- a/megatron/data/pairwise_dataset.py +++ b/megatron/data/pairwise_dataset.py @@ -105,16 +105,18 @@ def __getitem__(self, idx): offset_f = self.sample_idx[idx][1] offset_l = self.sample_idx[idx + 1][1] # Labels and texts are supposed to be fully in sync. 
- datasets = ( - [self.pos_indexed_dataset, self.neg_indexed_dataset] - if self.pos_label_dataset is None - else [ - self.pos_indexed_dataset, - self.neg_indexed_dataset, + datasets = [self.pos_indexed_dataset, self.neg_indexed_dataset] + + if self.pos_label_dataset is not None: + datasets += [ self.pos_label_dataset, self.neg_label_dataset, ] - ) + if self.pos_ref_dataset is not None: + datasets += [ + self.pos_ref_dataset, + self.neg_ref_dataset, + ] samples = [] pos_ref_samples = [] neg_ref_samples = [] @@ -128,184 +130,54 @@ def __getitem__(self, idx): length=offset_l - offset_f + 1, ) ) - if n == 0: - if self.pos_ref_dataset is not None: - pos_ref_samples.append( - self.pos_ref_dataset.get( - self.doc_idx[doc_index_f], - offset=offset_f, - length=offset_l - offset_f + 1, - ) - ) - neg_ref_samples.append( - self.neg_ref_dataset.get( - self.doc_idx[doc_index_f], - offset=offset_f, - length=offset_l - offset_f + 1, - ) - ) - else: # Otherwise, get the rest of the initial document. sample_list = [ dataset.get(self.doc_idx[doc_index_f], offset=offset_f) ] - - if n == 0: - if self.pos_ref_dataset is not None: - pos_ref_sample_list = [ - self.pos_ref_dataset.get( - self.doc_idx[doc_index_f], - offset=offset_f, - ) - ] - neg_ref_sample_list = [ - self.neg_ref_dataset.get( - self.doc_idx[doc_index_f], - offset=offset_f, - ) - ] # Loop over all in between documents and add the entire document. for i in range(doc_index_f + 1, doc_index_l): sample_list.append(dataset.get(self.doc_idx[i])) - if n == 0: - if self.pos_ref_dataset is not None: - pos_ref_sample_list.append( - self.pos_ref_dataset.get( - self.doc_idx[i], - ) - ) - neg_ref_sample_list.append( - self.neg_ref_dataset.get( - self.doc_idx[i], - ) - ) # And finally add the relevant portion of last document. 
sample_list.append( dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) ) samples.append(np.concatenate(sample_list)) - if n == 0: - if self.pos_ref_dataset is not None: - pos_ref_sample_list.append( - self.pos_ref_dataset.get( - self.doc_idx[doc_index_l], length=offset_l + 1 - ) - ) - pos_ref_samples.append(np.concatenate(pos_ref_sample_list)) - neg_ref_sample_list.append( - self.neg_ref_dataset.get( - self.doc_idx[doc_index_l], length=offset_l + 1 - ) - ) - neg_ref_samples.append(np.concatenate(neg_ref_sample_list)) - if self.pos_ref_dataset is not None: - if len(pos_ref_samples[0]) < (self.seq_length): - # Pad with 0s - pos_ref_samples[0] = np.pad( - pos_ref_samples[0], - (0, (self.seq_length) - len(pos_ref_samples[0])), - mode="constant", - constant_values=0, - ) - elif len(pos_ref_samples[0]) > (self.seq_length): - # Check for overflow and truncate. - pos_ref_samples[0] = pos_ref_samples[0][: (self.seq_length)] - if len(neg_ref_samples[0]) < (self.seq_length): - # Pad with 0s - neg_ref_samples[0] = np.pad( - neg_ref_samples[0], - (0, (self.seq_length) - len(neg_ref_samples[0])), - mode="constant", - constant_values=0, - ) - elif len(neg_ref_samples[0]) > (self.seq_length): - # Check for overflow and truncate. - neg_ref_samples[0] = neg_ref_samples[0][: (self.seq_length)] - if len(datasets) == 2: - # pos - if len(samples[0]) < (self.seq_length + 1): - # Pad with -100s so the masking function can ignore these. - samples[0] = np.pad( - samples[0], - (0, (self.seq_length + 1) - len(samples[0])), - mode="constant", - constant_values=-100, - ) - elif len(samples[0]) > (self.seq_length + 1): - # Check for overflow and truncate. - samples[0] = samples[0][: (self.seq_length + 1)] - # neg - if len(samples[1]) < (self.seq_length + 1): - # Pad with -100s so the masking function can ignore these. 
- samples[1] = np.pad( - samples[1], - (0, (self.seq_length + 1) - len(samples[1])), - mode="constant", - constant_values=-100, - ) - elif len(samples[1]) > (self.seq_length + 1): - # Check for overflow and truncate. - samples[1] = samples[1][: (self.seq_length + 1)] - ret = { - "pos": np.array(samples[0], dtype=np.int64), - "neg": np.array(samples[1], dtype=np.int64), - } - if self.pos_ref_dataset is not None: - ret["pos_ref"] = np.array(pos_ref_samples[0], dtype=np.float32) - ret["neg_ref"] = np.array(neg_ref_samples[0], dtype=np.float32) - return ret - else: - # pos - if len(samples[0]) < (self.seq_length + 1): - # Pad with 0s, can use any number since it's masked. - samples[0] = np.pad( - samples[0], - (0, (self.seq_length + 1) - len(samples[0])), - mode="constant", - constant_values=0, - ) - # pad with -100s so we can mask it out - samples[2] = np.pad( - samples[2], - (0, (self.seq_length + 1) - len(samples[2])), - mode="constant", - constant_values=-100, - ) - elif len(samples[0]) > (self.seq_length + 1): - # Check for overflow and truncate. - samples[0] = samples[0][: (self.seq_length + 1)] - samples[2] = samples[2][: (self.seq_length + 1)] - # neg - if len(samples[1]) < (self.seq_length + 1): - # Pad with 0s, can use any number since it's masked. - samples[1] = np.pad( - samples[1], - (0, (self.seq_length + 1) - len(samples[1])), - mode="constant", - constant_values=0, - ) - # pad with -100s so we can mask it out - samples[3] = np.pad( - samples[3], - (0, (self.seq_length + 1) - len(samples[3])), - mode="constant", - constant_values=-100, - ) - elif len(samples[1]) > (self.seq_length + 1): + for i in range(len(samples)): + if len(samples[i]) < (self.seq_length + 1): + if ((i == 2) or (i == 3)) and self.pos_label_dataset is not None: + # Labels... So pad with -100 + samples[i] = np.pad( + samples[i], + (0, (self.seq_length + 1) - len(samples[i])), + mode="constant", + constant_values=-100, + ) + else: + # Pad with 0s, can use any number since it's masked. 
+ samples[i] = np.pad( + samples[i], + (0, (self.seq_length + 1) - len(samples[i])), + mode="constant", + constant_values=0, + ) + elif len(samples[i]) > (self.seq_length + 1): # Check for overflow and truncate. - samples[1] = samples[1][: (self.seq_length + 1)] - samples[3] = samples[3][: (self.seq_length + 1)] - ret = { - "pos": np.array(samples[0], dtype=np.int64), - "neg": np.array(samples[1], dtype=np.int64), - "pos_label": np.array(samples[2], dtype=np.int64), - "neg_label": np.array(samples[3], dtype=np.int64), - } + samples[i] = samples[i][: (self.seq_length + 1)] + ret = {} + ret["pos"] = np.array(samples[0], dtype=np.int64) + ret["neg"] = np.array(samples[1], dtype=np.int64) + if self.pos_label_dataset is not None: + ret["pos_label"] = np.array(samples[2], dtype=np.int64) + ret["neg_label"] = np.array(samples[3], dtype=np.int64) if self.pos_ref_dataset is not None: - ret["pos_ref"] = np.array(pos_ref_samples[0], dtype=np.float32) - ret["neg_ref"] = np.array(neg_ref_samples[0], dtype=np.float32) - return ret + ret["pos_ref"] = np.array(samples[4], dtype=np.float32) + ret["neg_ref"] = np.array(samples[5], dtype=np.float32) + elif self.pos_ref_dataset is not None: + # Don't have labels... 
+ ret["pos_ref"] = np.array(samples[2], dtype=np.float32) + ret["neg_ref"] = np.array(samples[3], dtype=np.float32) + return ret except IndexError: new_idx = idx % len(self) print( diff --git a/megatron/training.py b/megatron/training.py index ed01996e5..ee691fb70 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -352,9 +352,19 @@ def get_batch(neox_args, data_iterator): data=data, datatype=datatype, ) + if neox_args.precompute_model_name: + ref_data = mpu.broadcast_data(["pos_ref", "neg_ref"], data, torch.float) + else: + ref_data = {"pos_ref": None} return [ torch.cat((pos_item, neg_item), dim=0) for pos_item, neg_item in zip(pos_tup, neg_tup) + ] + [ + torch.cat((ref_data["pos_ref"], ref_data["neg_ref"]), dim=0)[ + :, :-1 + ].contiguous() + if ref_data["pos_ref"] is not None + else None ] @@ -476,9 +486,14 @@ def forward_step( torch.cuda.nvtx.range_push(f"Get batch") if timers is not None: timers("batch generator").start() - tokens, labels, loss_mask, attention_mask, position_ids = get_batch( - neox_args=neox_args, data_iterator=data_iterator - ) + if neox_args.train_impl == "normal": + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + neox_args=neox_args, data_iterator=data_iterator + ) + if neox_args.train_impl == "dpo": + tokens, labels, loss_mask, attention_mask, position_ids, ref_logp = get_batch( + neox_args=neox_args, data_iterator=data_iterator + ) if timers is not None: timers("batch generator").stop() @@ -523,19 +538,22 @@ def forward_step( token_logp_labels = labels.clone() token_logp_labels[token_logp_labels == -100] = 0 pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0) - ref_maybe_tuple = reference_model( - (tokens, position_ids, attention_mask), neox_args=neox_args - ) - if type(ref_maybe_tuple) is tuple: - # We should ignore MoE losses yeah? 
- ref_outputs, _ = ref_maybe_tuple + if ref_logp is None: + ref_maybe_tuple = reference_model( + (tokens, position_ids, attention_mask), neox_args=neox_args + ) + if type(ref_maybe_tuple) is tuple: + # We should ignore MoE losses yeah? + ref_outputs, _ = ref_maybe_tuple + else: + ref_outputs = ref_maybe_tuple + # gather across tensor parallel group + ref_outputs = gather_from_model_parallel_region(ref_outputs) + ref_pos, ref_neg = get_pos_neg_logp( + ref_outputs, token_logp_labels, neox_args.dpo_fp32 + ) else: - ref_outputs = ref_maybe_tuple - # gather across tensor parallel group - ref_outputs = gather_from_model_parallel_region(ref_outputs) - ref_pos, ref_neg = get_pos_neg_logp( - ref_outputs, token_logp_labels, neox_args.dpo_fp32 - ) + ref_pos, ref_neg = torch.chunk(ref_logp, 2, 0) ref_pos = (ref_pos * pos_loss_mask).sum(-1) ref_neg = (ref_neg * neg_loss_mask).sum(-1) chosen_maybe_tuple = model( @@ -858,7 +876,9 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): ) """Setup model and optimizer.""" - needs_reference_model = neox_args.train_impl == "dpo" + needs_reference_model = (neox_args.train_impl == "dpo") and ( + neox_args.precompute_model_name is None + ) model = get_model(neox_args=neox_args, use_cache=use_cache) if needs_reference_model: reference_model = get_model(neox_args=neox_args, use_cache=use_cache) From 51af71487fd465eeb6a79c90df81a64d1facadb8 Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Tue, 25 Jun 2024 19:33:27 -0500 Subject: [PATCH 10/28] - update readme for DPO... --- configs/README.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/configs/README.md b/configs/README.md index e14274b56..3102a34d1 100644 --- a/configs/README.md +++ b/configs/README.md @@ -235,6 +235,33 @@ Additional DeepSpeed settings besides those mentioned above should be wrapped in "eval_iters": 10, ``` +However, if you want to use DPO style training you'll need to set pos/neg data paths instead of a single one, e.g. 
+ +```yaml + "dataset_impl": "pairwise", + "train_impl": "dpo", + "pack_impl": "unpacked", + "dpo_beta": 0.1, + "dpo_fp32": true, + "pos_train_data_path": "data/enwik8/enwik8_text_pos_document", + "pos_valid_data_path": "data/enwik8/enwik8_text_pos_document", + "pos_test_data_path": "data/enwik8/enwik8_text_pos_document", + "neg_train_data_path": "data/enwik8/enwik8_text_neg_document", + "neg_valid_data_path": "data/enwik8/enwik8_text_neg_document", + "neg_test_data_path": "data/enwik8/enwik8_text_neg_document", + ## If you have labels... (likely to mask out user turns) + "pos_train_label_data_path": "data/enwik8/enwik8_text_pos_label_document", + "pos_valid_label_data_path": "data/enwik8/enwik8_text_pos_label_document", + "pos_test_label_data_path": "data/enwik8/enwik8_text_pos_label_document", + "neg_train_label_data_path": "data/enwik8/enwik8_text_neg_label_document", + "neg_valid_label_data_path": "data/enwik8/enwik8_text_neg_label_document", + "neg_test_label_data_path": "data/enwik8/enwik8_text_neg_label_document", + ## If you want to precompute the logits over your dataset... 
+ "precompute_model_name": "gpt2", + ## Needed for the generation.py step, if precomputing + "text_gen_type": "precompute" +``` + ### LR Scheduler settings ```yaml From b7bc196dec5e3b880bfb38cd288eaeba6984991b Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Fri, 28 Jun 2024 09:55:23 -0500 Subject: [PATCH 11/28] - Add KTO --- megatron/data/data_utils.py | 22 +++ megatron/data/gpt2_dataset.py | 141 ++++++++------- megatron/neox_arguments/neox_args.py | 39 ++++- megatron/text_generation_utils.py | 4 +- megatron/training.py | 160 ++++++++++++++++-- .../preprocess_data_with_chat_template.py | 62 ++++++- 6 files changed, 349 insertions(+), 79 deletions(-) diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index edf48b066..f3c23dc4d 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -70,6 +70,7 @@ def build_the_dataset( pos_label_prefix=None, neg_label_prefix=None, precompute_model_name=None, + reward_prefix=None, ): """Build train/valid/test datasets.""" if dataset_impl == "gpt2": @@ -84,6 +85,10 @@ def build_the_dataset( data_prefix + "_" + precompute_model_name, data_impl, skip_warmup ) precompute_indexed_dataset = precompute_indexed_dataset + else: + precompute_indexed_dataset = None + if reward_prefix is not None: + reward_dataset = make_indexed_dataset(reward_prefix, data_impl, skip_warmup) elif dataset_impl == "pairwise": pos_indexed_dataset = make_indexed_dataset( pos_data_prefix, data_impl, skip_warmup @@ -140,6 +145,8 @@ def build_the_dataset( allow_chopped=allow_chopped, build_index_mappings=build_index_mappings, label_dataset=label_dataset, + reward_dataset=reward_dataset, + ref_dataset=precompute_indexed_dataset, ) elif dataset_impl == "pairwise": dataset = PairwiseDataset( @@ -283,10 +290,13 @@ def build_weighted_datasets( for i, ( train_path, train_label_path, + train_reward_path, valid_path, valid_label_path, + valid_reward_path, test_path, test_label_path, + test_reward_path, pos_train_path, neg_train_path, 
pos_train_label_path, @@ -305,12 +315,21 @@ def build_weighted_datasets( neox_args.train_label_data_paths if neox_args.train_label_data_paths else [], + neox_args.train_reward_data_paths + if neox_args.train_reward_data_paths + else [], neox_args.valid_data_paths if neox_args.valid_data_paths else [], neox_args.valid_label_data_paths if neox_args.valid_label_data_paths else [], + neox_args.valid_reward_data_paths + if neox_args.valid_reward_data_paths + else [], neox_args.test_data_paths if neox_args.test_data_paths else [], neox_args.test_label_data_paths if neox_args.test_label_data_paths else [], + neox_args.test_reward_data_paths + if neox_args.test_reward_data_paths + else [], neox_args.pos_train_data_paths if neox_args.pos_train_data_paths else [], neox_args.neg_train_data_paths if neox_args.neg_train_data_paths else [], neox_args.pos_train_label_data_paths @@ -357,6 +376,7 @@ def build_weighted_datasets( pos_label_prefix=pos_train_label_path, neg_label_prefix=neg_train_label_path, precompute_model_name=neox_args.precompute_model_name, + reward_prefix=train_reward_path, ) ) @@ -380,6 +400,7 @@ def build_weighted_datasets( pos_label_prefix=pos_valid_label_path, neg_label_prefix=neg_valid_label_path, precompute_model_name=neox_args.precompute_model_name, + reward_prefix=valid_reward_path, ) ) @@ -403,6 +424,7 @@ def build_weighted_datasets( pos_label_prefix=pos_test_label_path, neg_label_prefix=neg_test_label_path, precompute_model_name=neox_args.precompute_model_name, + reward_prefix=test_reward_path, ) ) return train_datasets, valid_datasets, test_datasets diff --git a/megatron/data/gpt2_dataset.py b/megatron/data/gpt2_dataset.py index edba57df2..5d200fd72 100644 --- a/megatron/data/gpt2_dataset.py +++ b/megatron/data/gpt2_dataset.py @@ -41,6 +41,8 @@ def __init__( build_index_mappings=True, use_shared_fs=True, label_dataset=None, + reward_dataset=None, + ref_dataset=None, ): self.name = name @@ -48,9 +50,13 @@ def __init__( self.allow_chopped = allow_chopped 
self.indexed_dataset = indexed_dataset self.label_dataset = label_dataset + self.reward_dataset = reward_dataset + self.ref_dataset = ref_dataset self.seq_length = seq_length - # Checks + assert self.reward_dataset is None or ( + pack_impl == "unpacked" + ), "Reward dataset only supported with unpacked data." assert np.min(documents) >= 0 assert np.max(documents) < indexed_dataset.sizes.shape[0] @@ -90,77 +96,98 @@ def __getitem__(self, idx): offset_f = self.sample_idx[idx][1] offset_l = self.sample_idx[idx + 1][1] # Labels and texts are supposed to be fully in sync. - datasets = ( - [self.indexed_dataset] - if self.label_dataset is None - else [self.indexed_dataset, self.label_dataset] - ) + datasets = [self.indexed_dataset] + rw_indx = 1 + if self.label_dataset is not None: + rw_indx += 1 + datasets.append(self.label_dataset) + if self.reward_dataset is not None: + datasets.append(self.reward_dataset) + if self.ref_dataset is not None: + datasets.append(self.ref_dataset) samples = [] + sample_lengths = [] # If we are within the same document, just extract the chunk. for n, dataset in enumerate(datasets): if doc_index_f == doc_index_l: - samples.append( - dataset.get( - self.doc_idx[doc_index_f], - offset=offset_f, - length=offset_l - offset_f + 1, + if rw_indx == n: + # If we are in the reward dataset, we only need the last token. + samples.append( + dataset.get( + self.doc_idx[doc_index_f], offset=offset_l, length=1 + ) + ) + else: + rw = dataset.get(self.doc_idx[doc_index_f]) + samples.append( + np.array([rw[0] for _ in range(len(samples[-1]))]) ) - ) else: + if n != rw_indx: + # reset + sample_lengths = [] # Otherwise, get the rest of the initial document. 
- sample_list = [ - dataset.get(self.doc_idx[doc_index_f], offset=offset_f) - ] + if n == rw_indx: + rw = dataset.get(self.doc_idx[doc_index_f]) + sample_list = [ + np.array([rw[0] for _ in range(sample_lengths[0])]) + ] + else: + sample_list = [ + dataset.get(self.doc_idx[doc_index_f], offset=offset_f) + ] + sample_lengths.append(len(sample_list[-1])) # Loop over all in between documents and add the entire document. for i in range(doc_index_f + 1, doc_index_l): - sample_list.append(dataset.get(self.doc_idx[i])) + if n == rw_indx: + rw = dataset.get(self.doc_idx[i]) + sample_list.append( + np.array([rw[0] for _ in range(sample_lengths[1 + i])]) + ) + else: + sample_list.append(dataset.get(self.doc_idx[i])) + sample_lengths.append(len(sample_list[-1])) # And finally add the relevant portion of last document. - sample_list.append( - dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) - ) - samples.append(np.concatenate(sample_list)) - if len(datasets) == 1: - if len(samples[0]) < (self.seq_length + 1): - # Pad with -100s so the masking function can ignore these. - samples[0] = np.pad( - samples[0], - (0, (self.seq_length + 1) - len(samples[0])), - mode="constant", - constant_values=-100, - ) - elif len(samples[0]) > (self.seq_length + 1): - # Check for overflow and truncate. - samples[0] = samples[0][: (self.seq_length + 1)] - return {"text": np.array(samples[0], dtype=np.int64)} - else: - if len(samples[0]) < (self.seq_length + 1): - # Pad with 0s, can use any number since it's masked. 
- samples[0] = np.pad( - samples[0], - (0, (self.seq_length + 1) - len(samples[0])), - mode="constant", - constant_values=0, - ) - # pad with -100s so we can mask it out - samples[1] = np.pad( - samples[1], - (0, (self.seq_length + 1) - len(samples[1])), + if n == rw_indx: + rw = dataset.get(self.doc_idx[doc_index_l]) + sample_list.append( + np.array([rw[0] for _ in range(sample_lengths[-1])]) + ) + else: + sample_list.append( + dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) + ) + sample_lengths.append(len(sample_list[-1])) + samples.append(np.concatenate(sample_list)) + for i in range(len(samples)): + mask = (self.label_dataset is not None) and (i == 1) + if len(samples[i]) < (self.seq_length + 1): + # Pad + samples[i] = np.pad( + samples[i], + (0, (self.seq_length + 1) - len(samples[i])), mode="constant", - constant_values=-100, + constant_values=-100 if mask else 0, ) - elif len(samples[0]) > (self.seq_length + 1): - # Check for overflow and truncate. - samples[0] = samples[0][: (self.seq_length + 1)] - samples[1] = samples[1][: (self.seq_length + 1)] - return { - "text": np.array(samples[0], dtype=np.int64), - "label": np.array(samples[1], dtype=np.int64), - } - except IndexError: + elif len(samples[i]) > (self.seq_length + 1): + # Truncate + samples[i] = samples[i][: (self.seq_length + 1)] + ret = {"text": np.array(samples[0], dtype=np.int64)} + next_idx = 1 + if self.label_dataset is not None: + ret["label"] = np.array(samples[next_idx], dtype=np.int64) + next_idx += 1 + if self.reward_dataset is not None: + ret["reward"] = np.array(samples[next_idx], dtype=np.float32) + next_idx += 1 + if self.ref_dataset is not None: + ret["ref"] = np.array(samples[next_idx], dtype=np.float32) + return ret + except IndexError as err: new_idx = idx % len(self) print( - f"WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx})" + f"WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead 
({new_idx}), error: {err}" ) return self[new_idx] diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 3ce8b881a..80f3de429 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -853,6 +853,11 @@ class NeoXArgsTraining(NeoXArgsTemplate): List of paths to train label datasets (not shifted by 1 yet!). """ + train_reward_data_paths: list = None + """ + List of paths to train reward datasets + """ + test_data_paths: list = None """ List of paths to test datasets. @@ -863,6 +868,11 @@ class NeoXArgsTraining(NeoXArgsTemplate): List of paths to test label datasets (not shifted by 1 yet!). """ + test_reward_data_paths: list = None + """ + List of paths to test reward datasets + """ + valid_data_paths: list = None """ List of paths to validation datasets. @@ -873,6 +883,11 @@ class NeoXArgsTraining(NeoXArgsTemplate): List of paths to validation label datasets (not shifted by 1 yet!). """ + valid_reward_data_paths: list = None + """ + List of paths to validation reward datasets + """ + pos_train_data_paths: list = None neg_train_data_paths: list = None """ @@ -970,9 +985,9 @@ class NeoXArgsTraining(NeoXArgsTemplate): Dataset implementation, can be one of "gpt2" or "pairwise" """ - train_impl: Literal["normal", "dpo"] = "normal" + train_impl: Literal["normal", "dpo", "kto"] = "normal" """ - Training implementation, can be one of "normal" or "dpo" + Training implementation, can be one of "normal", "dpo", or "kto" """ dpo_fp32: bool = True @@ -980,11 +995,31 @@ class NeoXArgsTraining(NeoXArgsTemplate): Whether to cast logits to fp32 for DPO loss calculation. """ + kto_fp32: bool = True + """ + Whether to cast logits to fp32 for KTO loss calculation. + """ + + kto_desirable_weight: float = 1.0 + """ + Weight for desirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes. + """ + + kto_undesirable_weight: float = 1.0 + """ + Weight for undesirable loss in KTO. 
Might help if you have unbalanced desirable and undesirable classes. + """ + dpo_beta: float = 0.1 """ Beta value for DPO """ + kto_beta: float = 0.1 + """ + Beta value for KTO + """ + allow_chopped: bool = True """ WARNING: if your packing impl is packed, this is ignored. diff --git a/megatron/text_generation_utils.py b/megatron/text_generation_utils.py index 02926c2c3..4ca86808a 100644 --- a/megatron/text_generation_utils.py +++ b/megatron/text_generation_utils.py @@ -19,6 +19,7 @@ import copy import json +import math import os import time from typing import List, Union @@ -874,7 +875,8 @@ def precompute_logits(neox_args, model): out_dataset = make_builder(out_path + ".bin", neox_args.data_impl) out_dataset._dtype = np.float32 i = 0 - while i < len(dataset): + # Not sure why this requires a multiple of 8 but... + while i < int(math.ceil(len(dataset) / 8.0) * 8): start = time.time() model.module.clear_cache() # clear kv cache between batches if is_mp_rank_0(): diff --git a/megatron/training.py b/megatron/training.py index ee691fb70..1570cf238 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -314,8 +314,8 @@ def get_batch(neox_args, data_iterator): """Generate a batch""" # Items and their type. 
- if neox_args.train_impl == "normal": - keys = ["text", "label"] if neox_args.label_data_paths else ["text"] + if neox_args.train_impl in ["normal", "kto"]: + keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] elif neox_args.train_impl == "dpo": keys = ( [["pos", "pos_label"], ["neg", "neg_label"]] @@ -337,6 +337,25 @@ def get_batch(neox_args, data_iterator): data=data, datatype=datatype, ) + elif neox_args.train_impl == "kto": + tup = _get_batch( + neox_args=neox_args, + tokenizer=neox_args.tokenizer, + keys=keys, + data=data, + datatype=datatype, + ) + # Remove the last token from the reward since we predict the next token, so + # Reward of will be based on the label of + rw_data = mpu.broadcast_data(["reward"], data, torch.float)["reward"][ + :, :-1 + ].contiguous() + ref_data = ( + mpu.broadcast_data(["ref"], data, torch.float)["ref"][:, :-1].contiguous() + if neox_args.precompute_model_name + else None + ) + return tup + (rw_data, ref_data) elif neox_args.train_impl == "dpo": pos_tup = _get_batch( neox_args=neox_args, @@ -459,13 +478,15 @@ def mb_moe_loss_func(args, loss_mask, output_tensor=None): return averaged_lbl, loss_dict -def get_pos_neg_logp(logits, labels, force_fp32=False): +def get_logp(logits, labels, force_fp32=False): if force_fp32: logits = logits.float() logp = logits.log_softmax(dim=-1) - per_token_logp = torch.gather(logp, dim=2, index=labels.unsqueeze(2)).squeeze(2) - # Split to pos/neg... 
- return torch.chunk(per_token_logp, 2, 0) + return torch.gather(logp, dim=2, index=labels.unsqueeze(2)).squeeze(2) + + +def get_pos_neg_logp(logits, labels, force_fp32=False): + return torch.chunk(get_logp(logits, labels, force_fp32), 2, 0) def forward_step( @@ -490,7 +511,17 @@ def forward_step( tokens, labels, loss_mask, attention_mask, position_ids = get_batch( neox_args=neox_args, data_iterator=data_iterator ) - if neox_args.train_impl == "dpo": + elif neox_args.train_impl == "kto": + ( + tokens, + labels, + loss_mask, + attention_mask, + position_ids, + rewards, + ref_logp, + ) = get_batch(neox_args=neox_args, data_iterator=data_iterator) + elif neox_args.train_impl == "dpo": tokens, labels, loss_mask, attention_mask, position_ids, ref_logp = get_batch( neox_args=neox_args, data_iterator=data_iterator ) @@ -591,6 +622,115 @@ def forward_step( ref_logrations = ref_pos - ref_neg logits = pi_logrations - ref_logrations loss = -F.logsigmoid(neox_args.dpo_beta * logits).mean() + elif neox_args.train_impl == "kto": + # Based on https://github.com/huggingface/trl/blob/main/trl/trainer/kto_trainer.py + # Except we don't have an extra input for KL logp, we just split the batch in half + with torch.no_grad(): + # So we can gather token logps... + token_logp_labels = labels.clone() + token_logp_labels[token_logp_labels == -100] = 0 + if ref_logp is None: + # Did not precompute logits.... + ref_maybe_tuple = reference_model( + (tokens, position_ids, attention_mask), neox_args=neox_args + ) + if type(ref_maybe_tuple) is tuple: + # We should ignore MoE losses yeah? 
+ ref_outputs, _ = ref_maybe_tuple + else: + ref_outputs = ref_maybe_tuple + # gather across tensor parallel group + ref_outputs = gather_from_model_parallel_region(ref_outputs) + + ref_logp = get_logp(ref_outputs, token_logp_labels, neox_args.kto_fp32) + else: + print(f"REF LOGP: {ref_logp.clone().detach().mean()}") + ref_logp = ref_logp * loss_mask + scaling = (rewards.sum(-1) > 0.001).float() * neox_args.kto_desirable_weight + scaling += ( + rewards.sum(-1) < -0.001 + ).float() * neox_args.kto_undesirable_weight + pos_mask = (rewards > 0.001).float() + neg_mask = (rewards < -0.001).float() + chosen_maybe_tuple = model( + (tokens, position_ids, attention_mask), neox_args=neox_args + ) + if type(chosen_maybe_tuple) is tuple: + # We should ignore MoE losses yeah? + chosen_outputs, _ = chosen_maybe_tuple + else: + chosen_outputs = chosen_maybe_tuple + chosen_outputs = gather_from_model_parallel_region(chosen_outputs) + chosen_logp = get_logp(chosen_outputs, token_logp_labels, neox_args.kto_fp32) + chosen_logp = chosen_logp * loss_mask + with torch.no_grad(): + # Collect metrics... 
+ metrics["ref_logp"] = ref_logp.clone().detach().sum(-1).mean() + metrics["policy_logp"] = chosen_logp.clone().detach().sum(-1).mean() + metrics["pos_ref_logp"] = ( + (ref_logp * pos_mask).clone().detach().sum(-1).mean() + ) + metrics["neg_ref_logp"] = ( + (ref_logp * neg_mask).clone().detach().sum(-1).mean() + ) + metrics["pos_policy_logp"] = ( + (chosen_logp * pos_mask).clone().detach().sum(-1).mean() + ) + metrics["neg_policy_logp"] = ( + (chosen_logp * neg_mask).clone().detach().sum(-1).mean() + ) + metrics["kl"] = ( + chosen_logp.clone().detach() - ref_logp.clone().detach() + ).sum() / loss_mask.sum() + policy_rewards = ( + neox_args.kto_beta + * rewards + * (chosen_logp.clone().detach() - ref_logp.clone().detach()) + ) + reward_acc = (policy_rewards.sum(-1) > 0.0).float() + metrics["reward_acc"] = reward_acc.mean() + metrics["policy_rewards"] = policy_rewards.sum() + print(metrics) + pol_logp1, pol_logp2 = torch.chunk(chosen_logp, 2, 0) + ref_logp1, ref_logp2 = torch.chunk(ref_logp, 2, 0) + reward1, reward2 = torch.chunk(rewards, 2, 0) + scaling1, scaling2 = torch.chunk(scaling, 2, 0) + kl1 = torch.clamp((pol_logp1 - ref_logp1).sum(-1), min=0).mean() + kl2 = torch.clamp((pol_logp2 - ref_logp2).sum(-1), min=0).mean() + log_ratio1 = pol_logp1 - ref_logp1 + log_ratio2 = pol_logp2 - ref_logp2 + + # TODO: Add pack_until_overflow sequence support + loss = ( + 0.5 + * scaling1.mean(-1) + * ( + 1 + - F.sigmoid( + ( + neox_args.kto_beta + * reward1.mean(-1) + * (log_ratio1.sum(-1) - kl2.clone().detach()) + ) + ) + ) + ) + ( + 0.5 + * scaling2.mean(-1) + * ( + 1 + - F.sigmoid( + ( + neox_args.kto_beta + * reward2.mean(-1) + * (log_ratio2.sum(-1) - kl1.clone().detach()) + ) + ) + ) + ) + # print(loss.shape) + loss = loss.mean() + # print(loss.shape) if neox_args.memory_profiling: torch.cuda.nvtx.range_pop() if return_logits: @@ -876,9 +1016,9 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): ) """Setup model and optimizer.""" - 
needs_reference_model = (neox_args.train_impl == "dpo") and ( - neox_args.precompute_model_name is None - ) + needs_reference_model = ( + (neox_args.train_impl == "dpo") and (neox_args.precompute_model_name is None) + ) or ((neox_args.train_impl == "kto") and (neox_args.precompute_model_name is None)) model = get_model(neox_args=neox_args, use_cache=use_cache) if needs_reference_model: reference_model = get_model(neox_args=neox_args, use_cache=use_cache) diff --git a/tools/datasets/preprocess_data_with_chat_template.py b/tools/datasets/preprocess_data_with_chat_template.py index 81770deff..cbef58eb0 100644 --- a/tools/datasets/preprocess_data_with_chat_template.py +++ b/tools/datasets/preprocess_data_with_chat_template.py @@ -137,7 +137,15 @@ def encode(self, text): Encoder.tokenizer, self.args.only_last, ) - ids[key] = (text_ids, label_ids) + if self.args.reward_key is not None: + reward = text[self.args.reward_key] + if self.args.binary_reward: + reward = [1] if reward else [-1] + elif type(reward) == float: + reward = [reward] + ids[key] = (text_ids, label_ids, reward) + else: + ids[key] = (text_ids, label_ids, None) return ids, len(text) @@ -173,6 +181,17 @@ def get_args(): help="If set, this will mask everything except the last turn in the chat.", action="store_true", ) + group.add_argument( + "--reward-key", + type=str, + default=None, + help="Optional: key to use for reward data in the input data.", + ) + group.add_argument( + "--binary-reward", + help="If set, this will treat the reward data as a boolean.", + action="store_true", + ) group.add_argument( "--num-docs", default=None, @@ -287,19 +306,36 @@ def main(): assert ( key + "_label" not in args.jsonl_keys ), "label should not be included as it will be generated according to the mask." 
- key += "_label" - output_bin_files[key] = "{}_{}_{}.bin".format( - args.output_prefix, key, "document" + label_key = key + "_label" + output_bin_files[label_key] = "{}_{}_{}.bin".format( + args.output_prefix, label_key, "document" + ) + output_idx_files[label_key] = "{}_{}_{}.idx".format( + args.output_prefix, label_key, "document" + ) + builders[label_key] = indexed_dataset.make_builder( + output_bin_files[label_key], + impl=args.dataset_impl, + vocab_size=tokenizer.vocab_size, + ) + builders[label_key]._dtype = np.int32 + if args.reward_key is not None: + assert ( + key + "_reward" not in args.jsonl_keys + ), "reward should not be included as it will be generated from the data." + reward_key = key + "_reward" + output_bin_files[reward_key] = "{}_{}_{}.bin".format( + args.output_prefix, reward_key, "document" ) - output_idx_files[key] = "{}_{}_{}.idx".format( - args.output_prefix, key, "document" + output_idx_files[reward_key] = "{}_{}_{}.idx".format( + args.output_prefix, reward_key, "document" ) - builders[key] = indexed_dataset.make_builder( - output_bin_files[key], + builders[reward_key] = indexed_dataset.make_builder( + output_bin_files[reward_key], impl=args.dataset_impl, vocab_size=tokenizer.vocab_size, ) - builders[key]._dtype = np.int32 + builders[reward_key]._dtype = np.int32 # actually do tokenization proc_start = time.time() @@ -315,17 +351,25 @@ def main(): for key, conv in doc.items(): tokens = conv[0] token_mask = conv[1] + reward = conv[2] builders[key].add_item(np.array(tokens, dtype=builders[key].dtype)) builders[key + "_label"].add_item( np.array(token_mask, dtype=builders[key + "_label"].dtype) ) + if args.reward_key is not None: + builders[key + "_reward"].add_item( + np.array(reward, dtype=builders[key + "_reward"].dtype) + ) # add indx... 
builders[key].end_document() builders[key + "_label"].end_document() + if args.reward_key is not None: + builders[key + "_reward"].end_document() if i == 1: print("key: ", key) print("tokens: ", tokens) print("token_mask: ", token_mask) + print("Reward: ", reward) # log progress if i % args.log_interval == 0: current = time.time() From 0cdcf2b840b19937d25d5e67a9de730f2e0bdb4c Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Fri, 28 Jun 2024 10:02:34 -0500 Subject: [PATCH 12/28] Update README.md --- configs/README.md | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/configs/README.md b/configs/README.md index 3102a34d1..71a09ebea 100644 --- a/configs/README.md +++ b/configs/README.md @@ -235,7 +235,32 @@ Additional DeepSpeed settings besides those mentioned above should be wrapped in "eval_iters": 10, ``` -However, if you want to use DPO style training you'll need to set pos/neg data paths instead of a single one, e.g. +For KTO style training, you'll need to add the reward & label data path, e.g.: + +```yaml + "data_impl": "mmap", + # Suggested data paths when using GPT-NeoX locally + "train_data_path": "data/enwik8/enwik8_text_document", + "train_label_data_path": "data/enwik8/enwik8_text_label_document", + "train_reward_data_path": "data/enwik8/enwik8_text_reward_document", + "test_data_path": "data/enwik8/enwik8_text_document", + "test_label_data_path": "data/enwik8/enwik8_text_label_document", + "test_reward_data_path": "data/enwik8/enwik8_text_reward_document", + "valid_data_path": "data/enwik8/enwik8_text_document", + "valid_label_data_path": "data/enwik8/enwik8_text_label_document", + "valid_reward_data_path": "data/enwik8/enwik8_text_reward_document", + "vocab_file": "data/gpt2-vocab.json", + "merge_file": "data/gpt2-merges.txt", + "save": "checkpoints", + "load": "checkpoints", + "tensorboard_dir": "tensorboard", + "log_dir": "logs", + "checkpoint_factor": 10000, + "eval_interval": 1000, + "eval_iters": 10, +``` + 
+For DPO style training, you'll need to set pos/neg data paths instead of a single one, e.g. ```yaml "dataset_impl": "pairwise", From daab1ab5e17d3c32b0588c70e90b09a55aa93599 Mon Sep 17 00:00:00 2001 From: Quentin Anthony Date: Sun, 8 Sep 2024 17:45:47 -0700 Subject: [PATCH 13/28] precommit --- megatron/data/data_utils.py | 2 +- megatron/text_generation_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index 41c9ac515..38d9f9c9f 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -167,7 +167,7 @@ def build_the_dataset( pos_ref_dataset=pos_ref_dataset, neg_ref_dataset=neg_ref_dataset, ) - + return dataset diff --git a/megatron/text_generation_utils.py b/megatron/text_generation_utils.py index 4fe18bb21..4d2714c53 100644 --- a/megatron/text_generation_utils.py +++ b/megatron/text_generation_utils.py @@ -875,7 +875,7 @@ def precompute_logits(neox_args, model): out_dataset = make_builder(out_path + ".bin", neox_args.data_impl) out_dataset._dtype = np.float32 i = 0 - + # TODO: Not sure why this requires a multiple of 8 but... while i < int(math.ceil(len(dataset) / 8.0) * 8): start = time.time() From 07601169b0356055a57b29e305694a9688509c0d Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Sun, 8 Sep 2024 21:32:37 -0500 Subject: [PATCH 14/28] - KTO implementation from main... 
--- megatron/data/data_utils.py | 23 ++- megatron/data/gpt2_dataset.py | 140 +++++++------ megatron/neox_arguments/neox_args.py | 40 +++- megatron/training.py | 192 ++++++++++++++---- .../preprocess_data_with_chat_template.py | 88 ++++---- 5 files changed, 346 insertions(+), 137 deletions(-) diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index 7c13131ad..4de1dae44 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -70,6 +70,7 @@ def build_the_dataset( pos_label_prefix=None, neg_label_prefix=None, precompute_model_name=None, + reward_prefix=None, ): """Build train/valid/test datasets.""" if dataset_impl == "gpt2": @@ -84,6 +85,10 @@ def build_the_dataset( data_prefix + "_" + precompute_model_name, data_impl, skip_warmup ) precompute_indexed_dataset = precompute_indexed_dataset + else: + precompute_indexed_dataset = None + if reward_prefix is not None: + reward_dataset = make_indexed_dataset(reward_prefix, data_impl, skip_warmup) elif dataset_impl == "pairwise": pos_indexed_dataset = make_indexed_dataset( pos_data_prefix, data_impl, skip_warmup @@ -141,6 +146,8 @@ def build_the_dataset( allow_chopped=allow_chopped, build_index_mappings=build_index_mappings, label_dataset=label_dataset, + reward_dataset=reward_dataset, + ref_dataset=precompute_indexed_dataset, ) elif dataset_impl == "pairwise": dataset = PairwiseDataset( @@ -285,10 +292,13 @@ def build_weighted_datasets( for i, ( train_path, train_label_path, + train_reward_path, valid_path, valid_label_path, + valid_reward_path, test_path, test_label_path, + test_reward_path, pos_train_path, neg_train_path, pos_train_label_path, @@ -307,12 +317,21 @@ def build_weighted_datasets( neox_args.train_label_data_paths if neox_args.train_label_data_paths else [], + neox_args.train_reward_data_paths + if neox_args.train_reward_data_paths + else [], neox_args.valid_data_paths if neox_args.valid_data_paths else [], neox_args.valid_label_data_paths if 
neox_args.valid_label_data_paths else [], + neox_args.valid_reward_data_paths + if neox_args.valid_reward_data_paths + else [], neox_args.test_data_paths if neox_args.test_data_paths else [], neox_args.test_label_data_paths if neox_args.test_label_data_paths else [], + neox_args.test_reward_data_paths + if neox_args.test_reward_data_paths + else [], neox_args.pos_train_data_paths if neox_args.pos_train_data_paths else [], neox_args.neg_train_data_paths if neox_args.neg_train_data_paths else [], neox_args.pos_train_label_data_paths @@ -359,6 +378,7 @@ def build_weighted_datasets( pos_label_prefix=pos_train_label_path, neg_label_prefix=neg_train_label_path, precompute_model_name=neox_args.precompute_model_name, + reward_prefix=train_reward_path, ) ) @@ -382,6 +402,7 @@ def build_weighted_datasets( pos_label_prefix=pos_valid_label_path, neg_label_prefix=neg_valid_label_path, precompute_model_name=neox_args.precompute_model_name, + reward_prefix=valid_reward_path, ) ) @@ -405,6 +426,7 @@ def build_weighted_datasets( pos_label_prefix=pos_test_label_path, neg_label_prefix=neg_test_label_path, precompute_model_name=neox_args.precompute_model_name, + reward_prefix=test_reward_path, ) ) return train_datasets, valid_datasets, test_datasets @@ -502,7 +524,6 @@ def build_train_valid_test_data_iterators(neox_args): ) if neox_args.weight_by_num_documents: - # gets the number of documents in each datapath get_num_docs_list = lambda datasets: [ dataset.indexed_dataset.sizes.shape[0] for dataset in datasets diff --git a/megatron/data/gpt2_dataset.py b/megatron/data/gpt2_dataset.py index edba57df2..86d0c5ad3 100644 --- a/megatron/data/gpt2_dataset.py +++ b/megatron/data/gpt2_dataset.py @@ -41,6 +41,8 @@ def __init__( build_index_mappings=True, use_shared_fs=True, label_dataset=None, + reward_dataset=None, + ref_dataset=None, ): self.name = name @@ -48,9 +50,13 @@ def __init__( self.allow_chopped = allow_chopped self.indexed_dataset = indexed_dataset self.label_dataset = 
label_dataset + self.reward_dataset = reward_dataset + self.ref_dataset = ref_dataset self.seq_length = seq_length - # Checks + assert self.reward_dataset is None or ( + pack_impl == "unpacked" + ), "Reward dataset only supported with unpacked data." assert np.min(documents) >= 0 assert np.max(documents) < indexed_dataset.sizes.shape[0] @@ -90,77 +96,97 @@ def __getitem__(self, idx): offset_f = self.sample_idx[idx][1] offset_l = self.sample_idx[idx + 1][1] # Labels and texts are supposed to be fully in sync. - datasets = ( - [self.indexed_dataset] - if self.label_dataset is None - else [self.indexed_dataset, self.label_dataset] - ) + datasets = [self.indexed_dataset] + rw_indx = 1 + if self.label_dataset is not None: + rw_indx += 1 + datasets.append(self.label_dataset) + if self.reward_dataset is not None: + datasets.append(self.reward_dataset) + if self.ref_dataset is not None: + datasets.append(self.ref_dataset) samples = [] + sample_lengths = [] # If we are within the same document, just extract the chunk. for n, dataset in enumerate(datasets): if doc_index_f == doc_index_l: - samples.append( - dataset.get( - self.doc_idx[doc_index_f], - offset=offset_f, - length=offset_l - offset_f + 1, + if rw_indx == n: + # If we are in the reward dataset, we only need the last token. + samples.append( + dataset.get( + self.doc_idx[doc_index_f], offset=offset_l, length=1 + ) + ) + else: + rw = dataset.get(self.doc_idx[doc_index_f]) + samples.append( + np.array([rw[0] for _ in range(len(samples[-1]))]) ) - ) else: + if n != rw_indx: + # reset + sample_lengths = [] # Otherwise, get the rest of the initial document. 
- sample_list = [ - dataset.get(self.doc_idx[doc_index_f], offset=offset_f) - ] + if n == rw_indx: + rw = dataset.get(self.doc_idx[doc_index_f]) + sample_list = [ + np.array([rw[0] for _ in range(sample_lengths[0])]) + ] + else: + sample_list = [ + dataset.get(self.doc_idx[doc_index_f], offset=offset_f) + ] + sample_lengths.append(len(sample_list[-1])) # Loop over all in between documents and add the entire document. for i in range(doc_index_f + 1, doc_index_l): - sample_list.append(dataset.get(self.doc_idx[i])) + if n == rw_indx: + rw = dataset.get(self.doc_idx[i]) + sample_list.append( + np.array([rw[0] for _ in range(sample_lengths[1 + i])]) + ) + else: + sample_list.append(dataset.get(self.doc_idx[i])) + sample_lengths.append(len(sample_list[-1])) # And finally add the relevant portion of last document. - sample_list.append( - dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) - ) + if n == rw_indx: + rw = dataset.get(self.doc_idx[doc_index_l]) + sample_list.append( + np.array([rw[0] for _ in range(sample_lengths[-1])]) + ) + else: + sample_list.append( + dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) + ) + sample_lengths.append(len(sample_list[-1])) samples.append(np.concatenate(sample_list)) - - if len(datasets) == 1: - if len(samples[0]) < (self.seq_length + 1): - # Pad with -100s so the masking function can ignore these. - samples[0] = np.pad( - samples[0], - (0, (self.seq_length + 1) - len(samples[0])), - mode="constant", - constant_values=-100, - ) - elif len(samples[0]) > (self.seq_length + 1): - # Check for overflow and truncate. - samples[0] = samples[0][: (self.seq_length + 1)] - return {"text": np.array(samples[0], dtype=np.int64)} - else: - if len(samples[0]) < (self.seq_length + 1): - # Pad with 0s, can use any number since it's masked. 
- samples[0] = np.pad( - samples[0], - (0, (self.seq_length + 1) - len(samples[0])), - mode="constant", - constant_values=0, - ) - # pad with -100s so we can mask it out - samples[1] = np.pad( - samples[1], - (0, (self.seq_length + 1) - len(samples[1])), + for i in range(len(samples)): + mask = (self.label_dataset is not None) and (i == 1) + if len(samples[i]) < (self.seq_length + 1): + # Pad + samples[i] = np.pad( + samples[i], + (0, (self.seq_length + 1) - len(samples[i])), mode="constant", - constant_values=-100, + constant_values=-100 if mask else 0, ) - elif len(samples[0]) > (self.seq_length + 1): - # Check for overflow and truncate. - samples[0] = samples[0][: (self.seq_length + 1)] - samples[1] = samples[1][: (self.seq_length + 1)] - return { - "text": np.array(samples[0], dtype=np.int64), - "label": np.array(samples[1], dtype=np.int64), - } - except IndexError: + elif len(samples[i]) > (self.seq_length + 1): + # Truncate + samples[i] = samples[i][: (self.seq_length + 1)] + ret = {"text": np.array(samples[0], dtype=np.int64)} + next_idx = 1 + if self.label_dataset is not None: + ret["label"] = np.array(samples[next_idx], dtype=np.int64) + next_idx += 1 + if self.reward_dataset is not None: + ret["reward"] = np.array(samples[next_idx], dtype=np.float32) + next_idx += 1 + if self.ref_dataset is not None: + ret["ref"] = np.array(samples[next_idx], dtype=np.float32) + return ret + except IndexError as err: new_idx = idx % len(self) print( - f"WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx})" + f"WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx}), error: {err}" ) return self[new_idx] diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index a87a573ae..a50f28a7a 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -880,6 +880,11 @@ class NeoXArgsTraining(NeoXArgsTemplate): List of 
paths to train label datasets (not shifted by 1 yet!). """ + train_reward_data_paths: list = None + """ + List of paths to train reward datasets + """ + test_data_paths: list = None """ List of paths to test datasets. @@ -890,6 +895,11 @@ class NeoXArgsTraining(NeoXArgsTemplate): List of paths to test label datasets (not shifted by 1 yet!). """ + test_reward_data_paths: list = None + """ + List of paths to test reward datasets + """ + valid_data_paths: list = None """ List of paths to validation datasets. @@ -900,6 +910,11 @@ class NeoXArgsTraining(NeoXArgsTemplate): List of paths to validation label datasets (not shifted by 1 yet!). """ + valid_reward_data_paths: list = None + """ + List of paths to validation reward datasets + """ + pos_train_data_paths: list = None neg_train_data_paths: list = None """ @@ -997,9 +1012,9 @@ class NeoXArgsTraining(NeoXArgsTemplate): Dataset implementation, can be one of "gpt2" or "pairwise" """ - train_impl: Literal["normal", "dpo", "rm"] = "normal" + train_impl: Literal["normal", "dpo", "kto"] = "normal" """ - Training implementation, can be one of "normal", "dpo", or "rm" + Training implementation, can be one of "normal", "dpo", or "kto" """ dpo_fp32: bool = True @@ -1007,16 +1022,29 @@ class NeoXArgsTraining(NeoXArgsTemplate): Whether to cast logits to fp32 for DPO loss calculation. """ + kto_fp32: bool = True + """ + Whether to cast logits to fp32 for KTO loss calculation. + """ + + kto_desirable_weight: float = 1.0 + """ + Weight for desirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes. + """ + + kto_undesirable_weight: float = 1.0 + """ + Weight for undesirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes. + """ + dpo_beta: float = 0.1 """ Beta value for DPO """ - z_loss: float = 0.0 + kto_beta: float = 0.1 """ - Z-loss parameter, only implemented for RM training currently. 
- https://arxiv.org/pdf/2204.02311 - https://arxiv.org/pdf/2309.10305 + Beta value for KTO """ allow_chopped: bool = True diff --git a/megatron/training.py b/megatron/training.py index 31e5c2bce..52372d9e8 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -315,9 +315,9 @@ def get_batch(neox_args, data_iterator): """Generate a batch""" # Items and their type. - if neox_args.train_impl == "normal": + if neox_args.train_impl in ["normal", "kto"]: keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] - elif neox_args.train_impl in ["dpo", "rm"]: + elif neox_args.train_impl == "dpo": keys = ( [["pos", "pos_label"], ["neg", "neg_label"]] if neox_args.pos_train_label_data_paths @@ -338,7 +338,26 @@ def get_batch(neox_args, data_iterator): data=data, datatype=datatype, ) - elif neox_args.train_impl in ["dpo", "rm"]: + elif neox_args.train_impl == "kto": + tup = _get_batch( + neox_args=neox_args, + tokenizer=neox_args.tokenizer, + keys=keys, + data=data, + datatype=datatype, + ) + # Remove the last token from the reward since we predict the next token, so + # Reward of will be based on the label of + rw_data = mpu.broadcast_data(["reward"], data, torch.float)["reward"][ + :, :-1 + ].contiguous() + ref_data = ( + mpu.broadcast_data(["ref"], data, torch.float)["ref"][:, :-1].contiguous() + if neox_args.precompute_model_name + else None + ) + return tup + (rw_data, ref_data) + elif neox_args.train_impl == "dpo": pos_tup = _get_batch( neox_args=neox_args, tokenizer=neox_args.tokenizer, @@ -353,7 +372,7 @@ def get_batch(neox_args, data_iterator): data=data, datatype=datatype, ) - if (neox_args.precompute_model_name) and (neox_args.train_impl == "dpo"): + if neox_args.precompute_model_name: ref_data = mpu.broadcast_data(["pos_ref", "neg_ref"], data, torch.float) else: ref_data = {"pos_ref": None} @@ -460,13 +479,15 @@ def mb_moe_loss_func(args, loss_mask, output_tensor=None): return averaged_lbl, loss_dict -def get_pos_neg_logp(logits, labels, 
force_fp32=False): +def get_logp(logits, labels, force_fp32=False): if force_fp32: logits = logits.float() logp = logits.log_softmax(dim=-1) - per_token_logp = torch.gather(logp, dim=2, index=labels.unsqueeze(2)).squeeze(2) - # Split to pos/neg... - return torch.chunk(per_token_logp, 2, 0) + return torch.gather(logp, dim=2, index=labels.unsqueeze(2)).squeeze(2) + + +def get_pos_neg_logp(logits, labels, force_fp32=False): + return torch.chunk(get_logp(logits, labels, force_fp32), 2, 0) def forward_step( @@ -491,7 +512,17 @@ def forward_step( tokens, labels, loss_mask, attention_mask, position_ids = get_batch( neox_args=neox_args, data_iterator=data_iterator ) - if neox_args.train_impl in ["dpo", "rm"]: + elif neox_args.train_impl == "kto": + ( + tokens, + labels, + loss_mask, + attention_mask, + position_ids, + rewards, + ref_logp, + ) = get_batch(neox_args=neox_args, data_iterator=data_iterator) + elif neox_args.train_impl == "dpo": tokens, labels, loss_mask, attention_mask, position_ids, ref_logp = get_batch( neox_args=neox_args, data_iterator=data_iterator ) @@ -532,32 +563,6 @@ def forward_step( else: moe_loss = 0.0 loss = main_loss + moe_loss - elif neox_args.train_impl == "rm": - maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args) - if type(maybe_tuple) is tuple: - outputs, _ = maybe_tuple - else: - outputs = maybe_tuple - pos, neg = torch.chunk(outputs, 2, 0) - pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0) - # We assume that each pos, neg pair occur in the same order - # e.g. second nonzero pos is the corresponding second nonzero neg - # and that there are also an equal number of pos and neg in each sequence. - pos_indx = pos_loss_mask.nonzero() - neg_indx = neg_loss_mask.nonzero() - # indx[:, 0] is the batch index, indx[:, 1] is the token index, we only care about the token index. 
- pos_indx = pos_indx[:, 1].unsqueeze(1) - neg_indx = neg_indx[:, 1].unsqueeze(1) - pos = torch.gather(pos.squeeze(), dim=1, index=pos_indx) - neg = torch.gather(neg.squeeze(), dim=1, index=neg_indx) - with torch.no_grad(): - metrics["pos_values"] = pos.clone().detach().mean() - metrics["neg_values"] = neg.clone().detach().mean() - metrics["margin"] = (pos - neg).clone().detach().mean() - metrics["accuracy"] = ((pos - neg) > 0).clone().detach().float().mean() - loss = (-F.logsigmoid(pos - neg).mean()) + ( - (neox_args.z_loss * (pos**2 + neg**2)).mean() - ) elif neox_args.train_impl == "dpo": # Based on https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90 with torch.no_grad(): @@ -618,6 +623,115 @@ def forward_step( ref_logrations = ref_pos - ref_neg logits = pi_logrations - ref_logrations loss = -F.logsigmoid(neox_args.dpo_beta * logits).mean() + elif neox_args.train_impl == "kto": + # Based on https://github.com/huggingface/trl/blob/main/trl/trainer/kto_trainer.py + # Except we don't have an extra input for KL logp, we just split the batch in half + with torch.no_grad(): + # So we can gather token logps... + token_logp_labels = labels.clone() + token_logp_labels[token_logp_labels == -100] = 0 + if ref_logp is None: + # Did not precompute logits.... + ref_maybe_tuple = reference_model( + (tokens, position_ids, attention_mask), neox_args=neox_args + ) + if type(ref_maybe_tuple) is tuple: + # We should ignore MoE losses yeah? 
+ ref_outputs, _ = ref_maybe_tuple + else: + ref_outputs = ref_maybe_tuple + # gather across tensor parallel group + ref_outputs = gather_from_model_parallel_region(ref_outputs) + + ref_logp = get_logp(ref_outputs, token_logp_labels, neox_args.kto_fp32) + else: + print(f"REF LOGP: {ref_logp.clone().detach().mean()}") + ref_logp = ref_logp * loss_mask + scaling = (rewards.sum(-1) > 0.001).float() * neox_args.kto_desirable_weight + scaling += ( + rewards.sum(-1) < -0.001 + ).float() * neox_args.kto_undesirable_weight + pos_mask = (rewards > 0.001).float() + neg_mask = (rewards < -0.001).float() + chosen_maybe_tuple = model( + (tokens, position_ids, attention_mask), neox_args=neox_args + ) + if type(chosen_maybe_tuple) is tuple: + # We should ignore MoE losses yeah? + chosen_outputs, _ = chosen_maybe_tuple + else: + chosen_outputs = chosen_maybe_tuple + chosen_outputs = gather_from_model_parallel_region(chosen_outputs) + chosen_logp = get_logp(chosen_outputs, token_logp_labels, neox_args.kto_fp32) + chosen_logp = chosen_logp * loss_mask + with torch.no_grad(): + # Collect metrics... 
+ metrics["ref_logp"] = ref_logp.clone().detach().sum(-1).mean() + metrics["policy_logp"] = chosen_logp.clone().detach().sum(-1).mean() + metrics["pos_ref_logp"] = ( + (ref_logp * pos_mask).clone().detach().sum(-1).mean() + ) + metrics["neg_ref_logp"] = ( + (ref_logp * neg_mask).clone().detach().sum(-1).mean() + ) + metrics["pos_policy_logp"] = ( + (chosen_logp * pos_mask).clone().detach().sum(-1).mean() + ) + metrics["neg_policy_logp"] = ( + (chosen_logp * neg_mask).clone().detach().sum(-1).mean() + ) + metrics["kl"] = ( + chosen_logp.clone().detach() - ref_logp.clone().detach() + ).sum() / loss_mask.sum() + policy_rewards = ( + neox_args.kto_beta + * rewards + * (chosen_logp.clone().detach() - ref_logp.clone().detach()) + ) + reward_acc = (policy_rewards.sum(-1) > 0.0).float() + metrics["reward_acc"] = reward_acc.mean() + metrics["policy_rewards"] = policy_rewards.sum() + print(metrics) + pol_logp1, pol_logp2 = torch.chunk(chosen_logp, 2, 0) + ref_logp1, ref_logp2 = torch.chunk(ref_logp, 2, 0) + reward1, reward2 = torch.chunk(rewards, 2, 0) + scaling1, scaling2 = torch.chunk(scaling, 2, 0) + kl1 = torch.clamp((pol_logp1 - ref_logp1).sum(-1), min=0).mean() + kl2 = torch.clamp((pol_logp2 - ref_logp2).sum(-1), min=0).mean() + log_ratio1 = pol_logp1 - ref_logp1 + log_ratio2 = pol_logp2 - ref_logp2 + + # TODO: Add pack_until_overflow sequence support + loss = ( + 0.5 + * scaling1.mean(-1) + * ( + 1 + - F.sigmoid( + ( + neox_args.kto_beta + * reward1.mean(-1) + * (log_ratio1.sum(-1) - kl2.clone().detach()) + ) + ) + ) + ) + ( + 0.5 + * scaling2.mean(-1) + * ( + 1 + - F.sigmoid( + ( + neox_args.kto_beta + * reward2.mean(-1) + * (log_ratio2.sum(-1) - kl1.clone().detach()) + ) + ) + ) + ) + # print(loss.shape) + loss = loss.mean() + # print(loss.shape) if neox_args.memory_profiling: torch.cuda.nvtx.range_pop() if return_logits: @@ -642,7 +756,7 @@ def get_model(neox_args, use_cache=False): model = GPT2ModelPipe( neox_args=neox_args, num_tokentypes=0, - 
parallel_output=True if neox_args.train_impl != "rm" else False, + parallel_output=True, topology=mpu.get_topology(), use_cache=use_cache, ) @@ -903,9 +1017,9 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): ) """Setup model and optimizer.""" - needs_reference_model = (neox_args.train_impl == "dpo") and ( - neox_args.precompute_model_name is None - ) + needs_reference_model = ( + (neox_args.train_impl == "dpo") and (neox_args.precompute_model_name is None) + ) or ((neox_args.train_impl == "kto") and (neox_args.precompute_model_name is None)) model = get_model(neox_args=neox_args, use_cache=use_cache) if needs_reference_model: reference_model = get_model(neox_args=neox_args, use_cache=use_cache) diff --git a/tools/datasets/preprocess_data_with_chat_template.py b/tools/datasets/preprocess_data_with_chat_template.py index 4d058127c..cbef58eb0 100644 --- a/tools/datasets/preprocess_data_with_chat_template.py +++ b/tools/datasets/preprocess_data_with_chat_template.py @@ -81,7 +81,6 @@ def build_chat( apply_mask: bool, tokenizer: PreTrainedTokenizer, only_last_turn: bool = False, - for_rm: bool = False, ) -> Tuple[List[int], List[int]]: """ Build a chat from a list of dictionaries. Each dictionary should have a "role" and "content" key, this follows the @@ -92,28 +91,12 @@ def build_chat( :param apply_mask: Whether to apply a loss mask to the chat, if False, all tokens will be included in the loss :param tokenizer: A HF tokenizer :param only_last_turn: Whether to only include the last turn in the chat, needed for some fine-tuning tasks - :param for_rm: Whether this is for a reward model or not, this will mask everything except EOS token. - If you need a more complicated setup, you can modify this function to suit your needs. 
""" tokens = [] mask = [] if apply_mask is False: tokens = tokenizer.apply_chat_template(chat) mask = tokens - if tokenizer.eos_token_id is not None: - mask.append(tokenizer.eos_token_id) - tokens.append(tokenizer.eos_token_id) - return tokens, mask - elif for_rm: - tokens = tokenizer.apply_chat_template(chat) - mask = [-100] * len(tokens) - if tokenizer.eos_token_id is not None: - mask.append(tokenizer.eos_token_id) - tokens.append(tokenizer.eos_token_id) - else: - raise ValueError( - "Tokenizer does not have an EOS token, unable to determine good mask, please edit and make your own." - ) return tokens, mask for i, turn in enumerate(chat): add_gen = ( @@ -121,8 +104,7 @@ def build_chat( ) chat_tokens = tokenizer.apply_chat_template( chat[: i + 1], add_generation_prompt=add_gen - )[len(tokens) :] - + ) # remove previous stuff... tokens.extend(chat_tokens) if only_last_turn and (i != len(chat) - 1): @@ -154,9 +136,16 @@ def encode(self, text): not self.args.no_mask, Encoder.tokenizer, self.args.only_last, - self.args.for_rm, ) - ids[key] = (text_ids, label_ids) + if self.args.reward_key is not None: + reward = text[self.args.reward_key] + if self.args.binary_reward: + reward = [1] if reward else [-1] + elif type(reward) == float: + reward = [reward] + ids[key] = (text_ids, label_ids, reward) + else: + ids[key] = (text_ids, label_ids, None) return ids, len(text) @@ -181,11 +170,6 @@ def get_args(): help="If set, this will not mask any tokens in the input data.", action="store_true", ) - group.add_argument( - "--for-rm", - help="If set, this will mask everything except the last token in the chat.", - action="store_true", - ) group.add_argument( "--generation-role", type=str, @@ -197,6 +181,17 @@ def get_args(): help="If set, this will mask everything except the last turn in the chat.", action="store_true", ) + group.add_argument( + "--reward-key", + type=str, + default=None, + help="Optional: key to use for reward data in the input data.", + ) + group.add_argument( + 
"--binary-reward", + help="If set, this will treat the reward data as a boolean.", + action="store_true", + ) group.add_argument( "--num-docs", default=None, @@ -311,19 +306,36 @@ def main(): assert ( key + "_label" not in args.jsonl_keys ), "label should not be included as it will be generated according to the mask." - key += "_label" - output_bin_files[key] = "{}_{}_{}.bin".format( - args.output_prefix, key, "document" + label_key = key + "_label" + output_bin_files[label_key] = "{}_{}_{}.bin".format( + args.output_prefix, label_key, "document" + ) + output_idx_files[label_key] = "{}_{}_{}.idx".format( + args.output_prefix, label_key, "document" + ) + builders[label_key] = indexed_dataset.make_builder( + output_bin_files[label_key], + impl=args.dataset_impl, + vocab_size=tokenizer.vocab_size, + ) + builders[label_key]._dtype = np.int32 + if args.reward_key is not None: + assert ( + key + "_reward" not in args.jsonl_keys + ), "reward should not be included as it will be generated from the data." 
+ reward_key = key + "_reward" + output_bin_files[reward_key] = "{}_{}_{}.bin".format( + args.output_prefix, reward_key, "document" ) - output_idx_files[key] = "{}_{}_{}.idx".format( - args.output_prefix, key, "document" + output_idx_files[reward_key] = "{}_{}_{}.idx".format( + args.output_prefix, reward_key, "document" ) - builders[key] = indexed_dataset.make_builder( - output_bin_files[key], + builders[reward_key] = indexed_dataset.make_builder( + output_bin_files[reward_key], impl=args.dataset_impl, vocab_size=tokenizer.vocab_size, ) - builders[key]._dtype = np.int32 + builders[reward_key]._dtype = np.int32 # actually do tokenization proc_start = time.time() @@ -339,17 +351,25 @@ def main(): for key, conv in doc.items(): tokens = conv[0] token_mask = conv[1] + reward = conv[2] builders[key].add_item(np.array(tokens, dtype=builders[key].dtype)) builders[key + "_label"].add_item( np.array(token_mask, dtype=builders[key + "_label"].dtype) ) + if args.reward_key is not None: + builders[key + "_reward"].add_item( + np.array(reward, dtype=builders[key + "_reward"].dtype) + ) # add indx... builders[key].end_document() builders[key + "_label"].end_document() + if args.reward_key is not None: + builders[key + "_reward"].end_document() if i == 1: print("key: ", key) print("tokens: ", tokens) print("token_mask: ", token_mask) + print("Reward: ", reward) # log progress if i % args.log_interval == 0: current = time.time() From b6f9d5c0d595309d091cd96621b47b817ec25925 Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Mon, 9 Sep 2024 11:30:49 -0500 Subject: [PATCH 15/28] initial changes... 
--- megatron/data/gpt2_dataset.py | 20 +++++++++++-------- .../preprocess_data_with_chat_template.py | 20 ++++++++++++++++++- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/megatron/data/gpt2_dataset.py b/megatron/data/gpt2_dataset.py index 86d0c5ad3..abe819089 100644 --- a/megatron/data/gpt2_dataset.py +++ b/megatron/data/gpt2_dataset.py @@ -103,6 +103,8 @@ def __getitem__(self, idx): datasets.append(self.label_dataset) if self.reward_dataset is not None: datasets.append(self.reward_dataset) + else: + rw_indx = -1 if self.ref_dataset is not None: datasets.append(self.ref_dataset) samples = [] @@ -112,15 +114,17 @@ def __getitem__(self, idx): if doc_index_f == doc_index_l: if rw_indx == n: # If we are in the reward dataset, we only need the last token. + rw = dataset.get(self.doc_idx[doc_index_f]) samples.append( - dataset.get( - self.doc_idx[doc_index_f], offset=offset_l, length=1 - ) + np.array([rw[0] for _ in range(len(samples[-1]))]) ) else: - rw = dataset.get(self.doc_idx[doc_index_f]) samples.append( - np.array([rw[0] for _ in range(len(samples[-1]))]) + dataset.get( + self.doc_idx[doc_index_f], + offset=offset_l, + length=offset_l - offset_f + 1, + ) ) else: if n != rw_indx: @@ -130,7 +134,7 @@ def __getitem__(self, idx): if n == rw_indx: rw = dataset.get(self.doc_idx[doc_index_f]) sample_list = [ - np.array([rw[0] for _ in range(sample_lengths[0])]) + np.array([rw[0] for _ in range(sample_lengths.pop(0))]) ] else: sample_list = [ @@ -142,7 +146,7 @@ def __getitem__(self, idx): if n == rw_indx: rw = dataset.get(self.doc_idx[i]) sample_list.append( - np.array([rw[0] for _ in range(sample_lengths[1 + i])]) + np.array([rw[0] for _ in range(sample_lengths.pop(0))]) ) else: sample_list.append(dataset.get(self.doc_idx[i])) @@ -151,7 +155,7 @@ def __getitem__(self, idx): if n == rw_indx: rw = dataset.get(self.doc_idx[doc_index_l]) sample_list.append( - np.array([rw[0] for _ in range(sample_lengths[-1])]) + np.array([rw[0] for _ in 
range(sample_lengths.pop(0))]) ) else: sample_list.append( diff --git a/tools/datasets/preprocess_data_with_chat_template.py b/tools/datasets/preprocess_data_with_chat_template.py index cbef58eb0..1083155fa 100644 --- a/tools/datasets/preprocess_data_with_chat_template.py +++ b/tools/datasets/preprocess_data_with_chat_template.py @@ -81,6 +81,7 @@ def build_chat( apply_mask: bool, tokenizer: PreTrainedTokenizer, only_last_turn: bool = False, + for_rm: bool = False, ) -> Tuple[List[int], List[int]]: """ Build a chat from a list of dictionaries. Each dictionary should have a "role" and "content" key, this follows the @@ -98,13 +99,24 @@ def build_chat( tokens = tokenizer.apply_chat_template(chat) mask = tokens return tokens, mask + elif for_rm: + tokens = tokenizer.apply_chat_template(chat) + mask = [-100] * len(tokens) + if tokenizer.eos_token_id is not None: + mask.append(tokenizer.eos_token_id) + tokens.append(tokenizer.eos_token_id) + else: + raise ValueError( + "Tokenizer does not have an EOS token, unable to determine good mask, please edit and make your own." + ) + return tokens, mask for i, turn in enumerate(chat): add_gen = ( False if i == len(chat) - 1 else chat[i + 1]["role"] == generation_role ) chat_tokens = tokenizer.apply_chat_template( chat[: i + 1], add_generation_prompt=add_gen - ) + )[len(tokens) :] # remove previous stuff... 
tokens.extend(chat_tokens) if only_last_turn and (i != len(chat) - 1): @@ -136,6 +148,7 @@ def encode(self, text): not self.args.no_mask, Encoder.tokenizer, self.args.only_last, + self.args.for_rm, ) if self.args.reward_key is not None: reward = text[self.args.reward_key] @@ -170,6 +183,11 @@ def get_args(): help="If set, this will not mask any tokens in the input data.", action="store_true", ) + group.add_argument( + "--for-rm", + help="If set, this will mask everything except the last token in the chat.", + action="store_true", + ) group.add_argument( "--generation-role", type=str, From 38aef24790ad3d9debffd573e89b7320f50d35e3 Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Mon, 9 Sep 2024 11:32:10 -0500 Subject: [PATCH 16/28] pre-commit... --- megatron/data/gpt2_dataset.py | 1 + tools/datasets/preprocess_data_with_chat_template.py | 1 + 2 files changed, 2 insertions(+) diff --git a/megatron/data/gpt2_dataset.py b/megatron/data/gpt2_dataset.py index abe819089..e37c558d2 100644 --- a/megatron/data/gpt2_dataset.py +++ b/megatron/data/gpt2_dataset.py @@ -53,6 +53,7 @@ def __init__( self.reward_dataset = reward_dataset self.ref_dataset = ref_dataset self.seq_length = seq_length + # Checks assert self.reward_dataset is None or ( pack_impl == "unpacked" diff --git a/tools/datasets/preprocess_data_with_chat_template.py b/tools/datasets/preprocess_data_with_chat_template.py index 1083155fa..3e44cf261 100644 --- a/tools/datasets/preprocess_data_with_chat_template.py +++ b/tools/datasets/preprocess_data_with_chat_template.py @@ -188,6 +188,7 @@ def get_args(): help="If set, this will mask everything except the last token in the chat.", action="store_true", ) + group.add_argument( "--generation-role", type=str, From 303ba8010b4cf9eb446432adc5659254a356f5a3 Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Mon, 9 Sep 2024 14:48:05 -0500 Subject: [PATCH 17/28] hotfix + data loader update --- megatron/data/data_utils.py | 2 ++ megatron/model/transformer.py | 5 ++- 
megatron/neox_arguments/neox_args.py | 9 +++++ tools/ckpts/convert_hf_llama_to_neox.py | 44 ++++++++++--------------- 4 files changed, 33 insertions(+), 27 deletions(-) diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index 4de1dae44..533311451 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -89,6 +89,8 @@ def build_the_dataset( precompute_indexed_dataset = None if reward_prefix is not None: reward_dataset = make_indexed_dataset(reward_prefix, data_impl, skip_warmup) + else: + reward_dataset = None elif dataset_impl == "pairwise": pos_indexed_dataset = make_indexed_dataset( pos_data_prefix, data_impl, skip_warmup diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 523cbe4cf..194fef6bf 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -139,6 +139,7 @@ def __init__( skip_bias_add=True, MOE=MOE, MoE_mp_size=MoE_mp_size, + bias=neox_args.use_bias_in_mlp, ) # Project back to h. self.linear2 = mpu.RowParallelLinear( @@ -151,6 +152,7 @@ def __init__( skip_bias_add=True, MOE=MOE, MoE_mp_size=MoE_mp_size, + bias=neox_args.use_bias_in_mlp, ) def forward(self, hidden_states): @@ -230,7 +232,7 @@ def __init__( # skip_bias_add=False, # mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here # ) - else: # Not using cross entropy loss for RMs + else: # Not using cross entropy loss for RMs self.rm_linear = mpu.RowParallelLinear( neox_args=neox_args, input_size=neox_args.hidden_size, @@ -1032,6 +1034,7 @@ def get_mlp(mlp_type, **kw): init_method=init_method, output_layer_init_method=output_layer_init_method, parallel_output=self.gpt_j_residual, + multiple_of=neox_args.mlp_multiple_of, **kw, ) diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index a50f28a7a..554b57ac0 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -124,6 
+124,11 @@ class NeoXArgsModel(NeoXArgsTemplate): Transformer intermediate size. Default = 4h """ + mlp_multiple_of: int = 256 + """ + force mlp size to be a multiple of this value + """ + expansion_factor: float = None """ Transformer intermediate size. Default = 4 @@ -438,6 +443,10 @@ class NeoXArgsModel(NeoXArgsTemplate): """ If false, attn_linear (e.g. QKVO) will not have bias terms """ + use_bias_in_mlp: bool = True + """ + If false, mlps will not have bias terms + """ mlp_type: str = "regular" """ diff --git a/tools/ckpts/convert_hf_llama_to_neox.py b/tools/ckpts/convert_hf_llama_to_neox.py index 2adddb19d..21249995b 100644 --- a/tools/ckpts/convert_hf_llama_to_neox.py +++ b/tools/ckpts/convert_hf_llama_to_neox.py @@ -85,38 +85,30 @@ def convert_model(hf_state_dict, hf_config, tp_ranks): # --- mlp --- # Do SwiGLU weights... # w1... - for i, chunk in enumerate( - torch.chunk( - hf_state_dict[f"model.layers.{layer_num}.mlp.gate_proj.weight"], - tp_ranks, - dim=0, + for i, (w1, w3) in enumerate( + zip( + torch.chunk( + hf_state_dict[f"model.layers.{layer_num}.mlp.gate_proj.weight"], + tp_ranks, + dim=0, + ), + torch.chunk( + hf_state_dict[f"model.layers.{layer_num}.mlp.up_proj.weight"], + tp_ranks, + dim=0, + ), ) ): conv_state_dicts[i][ - f"sequential.{layer_num+2}.mlp.w1.weight" - ] = chunk.clone().detach() + f"sequential.{layer_num+2}.mlp.linear1.weight" + ] = torch.cat([w3.clone().detach(), w1.clone().detach()], dim=0) print( f"model.layers.{layer_num}.mlp.gate_proj.weight", hf_state_dict[f"model.layers.{layer_num}.mlp.gate_proj.weight"].shape, - f"sequential.{layer_num+2}.mlp.w1.weight", - conv_state_dicts[0][f"sequential.{layer_num+2}.mlp.w1.weight"].shape, - ) - # w3... 
- for i, chunk in enumerate( - torch.chunk( - hf_state_dict[f"model.layers.{layer_num}.mlp.up_proj.weight"], - tp_ranks, - dim=0, - ) - ): - conv_state_dicts[i][ - f"sequential.{layer_num+2}.mlp.w3.weight" - ] = chunk.clone().detach() - print( f"model.layers.{layer_num}.mlp.up_proj.weight", hf_state_dict[f"model.layers.{layer_num}.mlp.up_proj.weight"].shape, f"sequential.{layer_num+2}.mlp.w3.weight", - conv_state_dicts[0][f"sequential.{layer_num+2}.mlp.w3.weight"].shape, + conv_state_dicts[0][f"sequential.{layer_num+2}.mlp.linear1.weight"].shape, ) # w2 (output)... for i, chunk in enumerate( @@ -127,13 +119,13 @@ def convert_model(hf_state_dict, hf_config, tp_ranks): ) ): conv_state_dicts[i][ - f"sequential.{layer_num+2}.mlp.w2.weight" + f"sequential.{layer_num+2}.mlp.linear2.weight" ] = chunk.clone().detach() print( f"model.layers.{layer_num}.mlp.down_proj.weight", hf_state_dict[f"model.layers.{layer_num}.mlp.down_proj.weight"].shape, - f"sequential.{layer_num+2}.mlp.w2.weight", - conv_state_dicts[0][f"sequential.{layer_num+2}.mlp.w2.weight"].shape, + f"sequential.{layer_num+2}.mlp.linear2.weight", + conv_state_dicts[0][f"sequential.{layer_num+2}.mlp.linear2.weight"].shape, ) # --- norm --- for i in range(tp_ranks): From 792570e12b6449f84c9f31748125df7b9a0e22b0 Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Fri, 21 Jun 2024 15:00:51 -0500 Subject: [PATCH 18/28] - add different packing impl (Unpacked, packing until overflow) - fix labels to also have valid/test implementations - fix label masking in _get_batch to also include anything from get_ltor_masks_and_position_ids --- megatron/data/data_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index 7c13131ad..aedffe063 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -59,7 +59,10 @@ def build_the_dataset( name, data_impl, pack_impl, +<<<<<<< HEAD dataset_impl, +======= +>>>>>>> 9ee4a8f6 (- add different packing impl 
(Unpacked, packing until overflow)) allow_chopped, num_samples, seq_length, From 4eca43f3c2d3932a0401533fbe094eeac83474be Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Mon, 24 Jun 2024 20:27:37 -0500 Subject: [PATCH 19/28] - Add metrics to forward step to add DPO specific metrics that are useful (accuracy, etc) - Add reference model setup for DPO - Add pairwise dataset for positive/negative pairs - Add DPO loss --- megatron/data/data_utils.py | 5 ----- megatron/training.py | 1 + 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index aedffe063..edf48b066 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -59,10 +59,7 @@ def build_the_dataset( name, data_impl, pack_impl, -<<<<<<< HEAD dataset_impl, -======= ->>>>>>> 9ee4a8f6 (- add different packing impl (Unpacked, packing until overflow)) allow_chopped, num_samples, seq_length, @@ -130,7 +127,6 @@ def build_the_dataset( print_rank_0(" no. of documents:{}".format(total_num_of_documents)) dataset = None documents = np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32) - if dataset_impl == "gpt2": dataset = GPT2Dataset( name, @@ -163,7 +159,6 @@ def build_the_dataset( pos_ref_dataset=pos_ref_dataset, neg_ref_dataset=neg_ref_dataset, ) - return dataset diff --git a/megatron/training.py b/megatron/training.py index 8f41fa594..280f96f99 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -1000,6 +1000,7 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): neox_args.iteration = load_checkpoint( neox_args=neox_args, model=model, + reference_model=reference_model, optimizer=optimizer, lr_scheduler=lr_scheduler, iteration=iteration, From 243b716acb287ad48d4e33ae6a652bb7e71baeb1 Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Tue, 25 Jun 2024 12:08:27 -0500 Subject: [PATCH 20/28] - Bugfixes from upstreaming.... 
--- megatron/training.py | 1 - 1 file changed, 1 deletion(-) diff --git a/megatron/training.py b/megatron/training.py index 280f96f99..8f41fa594 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -1000,7 +1000,6 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): neox_args.iteration = load_checkpoint( neox_args=neox_args, model=model, - reference_model=reference_model, optimizer=optimizer, lr_scheduler=lr_scheduler, iteration=iteration, From 3e966b0312d395961d80d03c5dcf65288630ae35 Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Fri, 28 Jun 2024 09:55:23 -0500 Subject: [PATCH 21/28] - Add KTO --- megatron/data/data_utils.py | 22 +++ megatron/data/gpt2_dataset.py | 141 +++++++++------- megatron/neox_arguments/neox_args.py | 25 ++- megatron/text_generation_utils.py | 4 +- megatron/training.py | 156 +++++++++++++++++- .../preprocess_data_with_chat_template.py | 62 ++++++- 6 files changed, 329 insertions(+), 81 deletions(-) diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index edf48b066..f3c23dc4d 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -70,6 +70,7 @@ def build_the_dataset( pos_label_prefix=None, neg_label_prefix=None, precompute_model_name=None, + reward_prefix=None, ): """Build train/valid/test datasets.""" if dataset_impl == "gpt2": @@ -84,6 +85,10 @@ def build_the_dataset( data_prefix + "_" + precompute_model_name, data_impl, skip_warmup ) precompute_indexed_dataset = precompute_indexed_dataset + else: + precompute_indexed_dataset = None + if reward_prefix is not None: + reward_dataset = make_indexed_dataset(reward_prefix, data_impl, skip_warmup) elif dataset_impl == "pairwise": pos_indexed_dataset = make_indexed_dataset( pos_data_prefix, data_impl, skip_warmup @@ -140,6 +145,8 @@ def build_the_dataset( allow_chopped=allow_chopped, build_index_mappings=build_index_mappings, label_dataset=label_dataset, + reward_dataset=reward_dataset, + 
ref_dataset=precompute_indexed_dataset, ) elif dataset_impl == "pairwise": dataset = PairwiseDataset( @@ -283,10 +290,13 @@ def build_weighted_datasets( for i, ( train_path, train_label_path, + train_reward_path, valid_path, valid_label_path, + valid_reward_path, test_path, test_label_path, + test_reward_path, pos_train_path, neg_train_path, pos_train_label_path, @@ -305,12 +315,21 @@ def build_weighted_datasets( neox_args.train_label_data_paths if neox_args.train_label_data_paths else [], + neox_args.train_reward_data_paths + if neox_args.train_reward_data_paths + else [], neox_args.valid_data_paths if neox_args.valid_data_paths else [], neox_args.valid_label_data_paths if neox_args.valid_label_data_paths else [], + neox_args.valid_reward_data_paths + if neox_args.valid_reward_data_paths + else [], neox_args.test_data_paths if neox_args.test_data_paths else [], neox_args.test_label_data_paths if neox_args.test_label_data_paths else [], + neox_args.test_reward_data_paths + if neox_args.test_reward_data_paths + else [], neox_args.pos_train_data_paths if neox_args.pos_train_data_paths else [], neox_args.neg_train_data_paths if neox_args.neg_train_data_paths else [], neox_args.pos_train_label_data_paths @@ -357,6 +376,7 @@ def build_weighted_datasets( pos_label_prefix=pos_train_label_path, neg_label_prefix=neg_train_label_path, precompute_model_name=neox_args.precompute_model_name, + reward_prefix=train_reward_path, ) ) @@ -380,6 +400,7 @@ def build_weighted_datasets( pos_label_prefix=pos_valid_label_path, neg_label_prefix=neg_valid_label_path, precompute_model_name=neox_args.precompute_model_name, + reward_prefix=valid_reward_path, ) ) @@ -403,6 +424,7 @@ def build_weighted_datasets( pos_label_prefix=pos_test_label_path, neg_label_prefix=neg_test_label_path, precompute_model_name=neox_args.precompute_model_name, + reward_prefix=test_reward_path, ) ) return train_datasets, valid_datasets, test_datasets diff --git a/megatron/data/gpt2_dataset.py 
b/megatron/data/gpt2_dataset.py index edba57df2..5d200fd72 100644 --- a/megatron/data/gpt2_dataset.py +++ b/megatron/data/gpt2_dataset.py @@ -41,6 +41,8 @@ def __init__( build_index_mappings=True, use_shared_fs=True, label_dataset=None, + reward_dataset=None, + ref_dataset=None, ): self.name = name @@ -48,9 +50,13 @@ def __init__( self.allow_chopped = allow_chopped self.indexed_dataset = indexed_dataset self.label_dataset = label_dataset + self.reward_dataset = reward_dataset + self.ref_dataset = ref_dataset self.seq_length = seq_length - # Checks + assert self.reward_dataset is None or ( + pack_impl == "unpacked" + ), "Reward dataset only supported with unpacked data." assert np.min(documents) >= 0 assert np.max(documents) < indexed_dataset.sizes.shape[0] @@ -90,77 +96,98 @@ def __getitem__(self, idx): offset_f = self.sample_idx[idx][1] offset_l = self.sample_idx[idx + 1][1] # Labels and texts are supposed to be fully in sync. - datasets = ( - [self.indexed_dataset] - if self.label_dataset is None - else [self.indexed_dataset, self.label_dataset] - ) + datasets = [self.indexed_dataset] + rw_indx = 1 + if self.label_dataset is not None: + rw_indx += 1 + datasets.append(self.label_dataset) + if self.reward_dataset is not None: + datasets.append(self.reward_dataset) + if self.ref_dataset is not None: + datasets.append(self.ref_dataset) samples = [] + sample_lengths = [] # If we are within the same document, just extract the chunk. for n, dataset in enumerate(datasets): if doc_index_f == doc_index_l: - samples.append( - dataset.get( - self.doc_idx[doc_index_f], - offset=offset_f, - length=offset_l - offset_f + 1, + if rw_indx == n: + # If we are in the reward dataset, we only need the last token. 
+ samples.append( + dataset.get( + self.doc_idx[doc_index_f], offset=offset_l, length=1 + ) + ) + else: + rw = dataset.get(self.doc_idx[doc_index_f]) + samples.append( + np.array([rw[0] for _ in range(len(samples[-1]))]) ) - ) else: + if n != rw_indx: + # reset + sample_lengths = [] # Otherwise, get the rest of the initial document. - sample_list = [ - dataset.get(self.doc_idx[doc_index_f], offset=offset_f) - ] + if n == rw_indx: + rw = dataset.get(self.doc_idx[doc_index_f]) + sample_list = [ + np.array([rw[0] for _ in range(sample_lengths[0])]) + ] + else: + sample_list = [ + dataset.get(self.doc_idx[doc_index_f], offset=offset_f) + ] + sample_lengths.append(len(sample_list[-1])) # Loop over all in between documents and add the entire document. for i in range(doc_index_f + 1, doc_index_l): - sample_list.append(dataset.get(self.doc_idx[i])) + if n == rw_indx: + rw = dataset.get(self.doc_idx[i]) + sample_list.append( + np.array([rw[0] for _ in range(sample_lengths[1 + i])]) + ) + else: + sample_list.append(dataset.get(self.doc_idx[i])) + sample_lengths.append(len(sample_list[-1])) # And finally add the relevant portion of last document. - sample_list.append( - dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) - ) - samples.append(np.concatenate(sample_list)) - if len(datasets) == 1: - if len(samples[0]) < (self.seq_length + 1): - # Pad with -100s so the masking function can ignore these. - samples[0] = np.pad( - samples[0], - (0, (self.seq_length + 1) - len(samples[0])), - mode="constant", - constant_values=-100, - ) - elif len(samples[0]) > (self.seq_length + 1): - # Check for overflow and truncate. - samples[0] = samples[0][: (self.seq_length + 1)] - return {"text": np.array(samples[0], dtype=np.int64)} - else: - if len(samples[0]) < (self.seq_length + 1): - # Pad with 0s, can use any number since it's masked. 
- samples[0] = np.pad( - samples[0], - (0, (self.seq_length + 1) - len(samples[0])), - mode="constant", - constant_values=0, - ) - # pad with -100s so we can mask it out - samples[1] = np.pad( - samples[1], - (0, (self.seq_length + 1) - len(samples[1])), + if n == rw_indx: + rw = dataset.get(self.doc_idx[doc_index_l]) + sample_list.append( + np.array([rw[0] for _ in range(sample_lengths[-1])]) + ) + else: + sample_list.append( + dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) + ) + sample_lengths.append(len(sample_list[-1])) + samples.append(np.concatenate(sample_list)) + for i in range(len(samples)): + mask = (self.label_dataset is not None) and (i == 1) + if len(samples[i]) < (self.seq_length + 1): + # Pad + samples[i] = np.pad( + samples[i], + (0, (self.seq_length + 1) - len(samples[i])), mode="constant", - constant_values=-100, + constant_values=-100 if mask else 0, ) - elif len(samples[0]) > (self.seq_length + 1): - # Check for overflow and truncate. - samples[0] = samples[0][: (self.seq_length + 1)] - samples[1] = samples[1][: (self.seq_length + 1)] - return { - "text": np.array(samples[0], dtype=np.int64), - "label": np.array(samples[1], dtype=np.int64), - } - except IndexError: + elif len(samples[i]) > (self.seq_length + 1): + # Truncate + samples[i] = samples[i][: (self.seq_length + 1)] + ret = {"text": np.array(samples[0], dtype=np.int64)} + next_idx = 1 + if self.label_dataset is not None: + ret["label"] = np.array(samples[next_idx], dtype=np.int64) + next_idx += 1 + if self.reward_dataset is not None: + ret["reward"] = np.array(samples[next_idx], dtype=np.float32) + next_idx += 1 + if self.ref_dataset is not None: + ret["ref"] = np.array(samples[next_idx], dtype=np.float32) + return ret + except IndexError as err: new_idx = idx % len(self) print( - f"WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx})" + f"WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead 
({new_idx}), error: {err}" ) return self[new_idx] diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 4fc43b11d..9ee9f1afd 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -1047,9 +1047,9 @@ class NeoXArgsTraining(NeoXArgsTemplate): Dataset implementation, can be one of "gpt2" or "pairwise" """ - train_impl: Literal["normal", "dpo", "rm"] = "normal" + train_impl: Literal["normal", "dpo", "rm", "kto"] = "normal" """ - Training implementation, can be one of "normal", "dpo", or "rm" + Training implementation, can be one of "normal", "dpo", "kto", or "rm" """ dpo_fp32: bool = True @@ -1057,16 +1057,29 @@ class NeoXArgsTraining(NeoXArgsTemplate): Whether to cast logits to fp32 for DPO loss calculation. """ + kto_fp32: bool = True + """ + Whether to cast logits to fp32 for KTO loss calculation. + """ + + kto_desirable_weight: float = 1.0 + """ + Weight for desirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes. + """ + + kto_undesirable_weight: float = 1.0 + """ + Weight for undesirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes. + """ + dpo_beta: float = 0.1 """ Beta value for DPO """ - z_loss: float = 0.0 + kto_beta: float = 0.1 """ - Z-loss parameter, only implemented for RM training currently. 
- https://arxiv.org/pdf/2204.02311 - https://arxiv.org/pdf/2309.10305 + Beta value for KTO """ allow_chopped: bool = True diff --git a/megatron/text_generation_utils.py b/megatron/text_generation_utils.py index f8d17cf10..1708ac0ab 100644 --- a/megatron/text_generation_utils.py +++ b/megatron/text_generation_utils.py @@ -19,6 +19,7 @@ import copy import json +import math import os import time from typing import List, Union @@ -876,7 +877,8 @@ def precompute_logits(neox_args, model): out_dataset = make_builder(out_path + ".bin", neox_args.data_impl) out_dataset._dtype = np.float32 i = 0 - while i < len(dataset): + # Not sure why this requires a multiple of 8 but... + while i < int(math.ceil(len(dataset) / 8.0) * 8): start = time.time() model.module.clear_cache() # clear kv cache between batches if is_mp_rank_0(): diff --git a/megatron/training.py b/megatron/training.py index 8f41fa594..8fe0c1acf 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -317,7 +317,7 @@ def get_batch(neox_args, data_iterator): """Generate a batch""" # Items and their type. 
- if neox_args.train_impl == "normal": + if neox_args.train_impl in ["normal", "kto"]: keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] elif neox_args.train_impl in ["dpo", "rm"]: keys = ( @@ -340,6 +340,25 @@ def get_batch(neox_args, data_iterator): data=data, datatype=datatype, ) + elif neox_args.train_impl == "kto": + tup = _get_batch( + neox_args=neox_args, + tokenizer=neox_args.tokenizer, + keys=keys, + data=data, + datatype=datatype, + ) + # Remove the last token from the reward since we predict the next token, so + # Reward of will be based on the label of + rw_data = mpu.broadcast_data(["reward"], data, torch.float)["reward"][ + :, :-1 + ].contiguous() + ref_data = ( + mpu.broadcast_data(["ref"], data, torch.float)["ref"][:, :-1].contiguous() + if neox_args.precompute_model_name + else None + ) + return tup + (rw_data, ref_data) elif neox_args.train_impl in ["dpo", "rm"]: pos_tup = _get_batch( neox_args=neox_args, @@ -462,13 +481,15 @@ def mb_moe_loss_func(args, loss_mask, output_tensor=None): return averaged_lbl, loss_dict -def get_pos_neg_logp(logits, labels, force_fp32=False): +def get_logp(logits, labels, force_fp32=False): if force_fp32: logits = logits.float() logp = logits.log_softmax(dim=-1) - per_token_logp = torch.gather(logp, dim=2, index=labels.unsqueeze(2)).squeeze(2) - # Split to pos/neg... 
- return torch.chunk(per_token_logp, 2, 0) + return torch.gather(logp, dim=2, index=labels.unsqueeze(2)).squeeze(2) + + +def get_pos_neg_logp(logits, labels, force_fp32=False): + return torch.chunk(get_logp(logits, labels, force_fp32), 2, 0) def forward_step( @@ -493,6 +514,16 @@ def forward_step( tokens, labels, loss_mask, attention_mask, position_ids = get_batch( neox_args=neox_args, data_iterator=data_iterator ) + elif neox_args.train_impl == "kto": + ( + tokens, + labels, + loss_mask, + attention_mask, + position_ids, + rewards, + ref_logp, + ) = get_batch(neox_args=neox_args, data_iterator=data_iterator) if neox_args.train_impl in ["dpo", "rm"]: tokens, labels, loss_mask, attention_mask, position_ids, ref_logp = get_batch( neox_args=neox_args, data_iterator=data_iterator @@ -620,6 +651,115 @@ def forward_step( ref_logrations = ref_pos - ref_neg logits = pi_logrations - ref_logrations loss = -F.logsigmoid(neox_args.dpo_beta * logits).mean() + elif neox_args.train_impl == "kto": + # Based on https://github.com/huggingface/trl/blob/main/trl/trainer/kto_trainer.py + # Except we don't have an extra input for KL logp, we just split the batch in half + with torch.no_grad(): + # So we can gather token logps... + token_logp_labels = labels.clone() + token_logp_labels[token_logp_labels == -100] = 0 + if ref_logp is None: + # Did not precompute logits.... + ref_maybe_tuple = reference_model( + (tokens, position_ids, attention_mask), neox_args=neox_args + ) + if type(ref_maybe_tuple) is tuple: + # We should ignore MoE losses yeah? 
+ ref_outputs, _ = ref_maybe_tuple + else: + ref_outputs = ref_maybe_tuple + # gather across tensor parallel group + ref_outputs = gather_from_model_parallel_region(ref_outputs) + + ref_logp = get_logp(ref_outputs, token_logp_labels, neox_args.kto_fp32) + else: + print(f"REF LOGP: {ref_logp.clone().detach().mean()}") + ref_logp = ref_logp * loss_mask + scaling = (rewards.sum(-1) > 0.001).float() * neox_args.kto_desirable_weight + scaling += ( + rewards.sum(-1) < -0.001 + ).float() * neox_args.kto_undesirable_weight + pos_mask = (rewards > 0.001).float() + neg_mask = (rewards < -0.001).float() + chosen_maybe_tuple = model( + (tokens, position_ids, attention_mask), neox_args=neox_args + ) + if type(chosen_maybe_tuple) is tuple: + # We should ignore MoE losses yeah? + chosen_outputs, _ = chosen_maybe_tuple + else: + chosen_outputs = chosen_maybe_tuple + chosen_outputs = gather_from_model_parallel_region(chosen_outputs) + chosen_logp = get_logp(chosen_outputs, token_logp_labels, neox_args.kto_fp32) + chosen_logp = chosen_logp * loss_mask + with torch.no_grad(): + # Collect metrics... 
+ metrics["ref_logp"] = ref_logp.clone().detach().sum(-1).mean() + metrics["policy_logp"] = chosen_logp.clone().detach().sum(-1).mean() + metrics["pos_ref_logp"] = ( + (ref_logp * pos_mask).clone().detach().sum(-1).mean() + ) + metrics["neg_ref_logp"] = ( + (ref_logp * neg_mask).clone().detach().sum(-1).mean() + ) + metrics["pos_policy_logp"] = ( + (chosen_logp * pos_mask).clone().detach().sum(-1).mean() + ) + metrics["neg_policy_logp"] = ( + (chosen_logp * neg_mask).clone().detach().sum(-1).mean() + ) + metrics["kl"] = ( + chosen_logp.clone().detach() - ref_logp.clone().detach() + ).sum() / loss_mask.sum() + policy_rewards = ( + neox_args.kto_beta + * rewards + * (chosen_logp.clone().detach() - ref_logp.clone().detach()) + ) + reward_acc = (policy_rewards.sum(-1) > 0.0).float() + metrics["reward_acc"] = reward_acc.mean() + metrics["policy_rewards"] = policy_rewards.sum() + print(metrics) + pol_logp1, pol_logp2 = torch.chunk(chosen_logp, 2, 0) + ref_logp1, ref_logp2 = torch.chunk(ref_logp, 2, 0) + reward1, reward2 = torch.chunk(rewards, 2, 0) + scaling1, scaling2 = torch.chunk(scaling, 2, 0) + kl1 = torch.clamp((pol_logp1 - ref_logp1).sum(-1), min=0).mean() + kl2 = torch.clamp((pol_logp2 - ref_logp2).sum(-1), min=0).mean() + log_ratio1 = pol_logp1 - ref_logp1 + log_ratio2 = pol_logp2 - ref_logp2 + + # TODO: Add pack_until_overflow sequence support + loss = ( + 0.5 + * scaling1.mean(-1) + * ( + 1 + - F.sigmoid( + ( + neox_args.kto_beta + * reward1.mean(-1) + * (log_ratio1.sum(-1) - kl2.clone().detach()) + ) + ) + ) + ) + ( + 0.5 + * scaling2.mean(-1) + * ( + 1 + - F.sigmoid( + ( + neox_args.kto_beta + * reward2.mean(-1) + * (log_ratio2.sum(-1) - kl1.clone().detach()) + ) + ) + ) + ) + # print(loss.shape) + loss = loss.mean() + # print(loss.shape) if neox_args.memory_profiling: torch.cuda.nvtx.range_pop() if return_logits: @@ -922,9 +1062,9 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): ) """Setup model and optimizer.""" - 
needs_reference_model = (neox_args.train_impl == "dpo") and ( - neox_args.precompute_model_name is None - ) + needs_reference_model = ( + (neox_args.train_impl == "dpo") and (neox_args.precompute_model_name is None) + ) or ((neox_args.train_impl == "kto") and (neox_args.precompute_model_name is None)) model = get_model(neox_args=neox_args, use_cache=use_cache) if needs_reference_model: reference_model = get_model(neox_args=neox_args, use_cache=use_cache) diff --git a/tools/datasets/preprocess_data_with_chat_template.py b/tools/datasets/preprocess_data_with_chat_template.py index 3db283ca4..f55561bd8 100644 --- a/tools/datasets/preprocess_data_with_chat_template.py +++ b/tools/datasets/preprocess_data_with_chat_template.py @@ -156,7 +156,15 @@ def encode(self, text): self.args.only_last, self.args.for_rm, ) - ids[key] = (text_ids, label_ids) + if self.args.reward_key is not None: + reward = text[self.args.reward_key] + if self.args.binary_reward: + reward = [1] if reward else [-1] + elif type(reward) == float: + reward = [reward] + ids[key] = (text_ids, label_ids, reward) + else: + ids[key] = (text_ids, label_ids, None) return ids, len(text) @@ -197,6 +205,17 @@ def get_args(): help="If set, this will mask everything except the last turn in the chat.", action="store_true", ) + group.add_argument( + "--reward-key", + type=str, + default=None, + help="Optional: key to use for reward data in the input data.", + ) + group.add_argument( + "--binary-reward", + help="If set, this will treat the reward data as a boolean.", + action="store_true", + ) group.add_argument( "--num-docs", default=None, @@ -311,19 +330,36 @@ def main(): assert ( key + "_label" not in args.jsonl_keys ), "label should not be included as it will be generated according to the mask." 
- key += "_label" - output_bin_files[key] = "{}_{}_{}.bin".format( - args.output_prefix, key, "document" + label_key = key + "_label" + output_bin_files[label_key] = "{}_{}_{}.bin".format( + args.output_prefix, label_key, "document" + ) + output_idx_files[label_key] = "{}_{}_{}.idx".format( + args.output_prefix, label_key, "document" + ) + builders[label_key] = indexed_dataset.make_builder( + output_bin_files[label_key], + impl=args.dataset_impl, + vocab_size=tokenizer.vocab_size, + ) + builders[label_key]._dtype = np.int32 + if args.reward_key is not None: + assert ( + key + "_reward" not in args.jsonl_keys + ), "reward should not be included as it will be generated from the data." + reward_key = key + "_reward" + output_bin_files[reward_key] = "{}_{}_{}.bin".format( + args.output_prefix, reward_key, "document" ) - output_idx_files[key] = "{}_{}_{}.idx".format( - args.output_prefix, key, "document" + output_idx_files[reward_key] = "{}_{}_{}.idx".format( + args.output_prefix, reward_key, "document" ) - builders[key] = indexed_dataset.make_builder( - output_bin_files[key], + builders[reward_key] = indexed_dataset.make_builder( + output_bin_files[reward_key], impl=args.dataset_impl, vocab_size=tokenizer.vocab_size, ) - builders[key]._dtype = np.int32 + builders[reward_key]._dtype = np.int32 # actually do tokenization proc_start = time.time() @@ -339,17 +375,25 @@ def main(): for key, conv in doc.items(): tokens = conv[0] token_mask = conv[1] + reward = conv[2] builders[key].add_item(np.array(tokens, dtype=builders[key].dtype)) builders[key + "_label"].add_item( np.array(token_mask, dtype=builders[key + "_label"].dtype) ) + if args.reward_key is not None: + builders[key + "_reward"].add_item( + np.array(reward, dtype=builders[key + "_reward"].dtype) + ) # add indx... 
builders[key].end_document() builders[key + "_label"].end_document() + if args.reward_key is not None: + builders[key + "_reward"].end_document() if i == 1: print("key: ", key) print("tokens: ", tokens) print("token_mask: ", token_mask) + print("Reward: ", reward) # log progress if i % args.log_interval == 0: current = time.time() From fadf3e6aaf3306c8d3423641bf473f1d7bfb247b Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Fri, 28 Jun 2024 10:02:34 -0500 Subject: [PATCH 22/28] Update README.md --- configs/README.md | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/configs/README.md b/configs/README.md index 3102a34d1..71a09ebea 100644 --- a/configs/README.md +++ b/configs/README.md @@ -235,7 +235,32 @@ Additional DeepSpeed settings besides those mentioned above should be wrapped in "eval_iters": 10, ``` -However, if you want to use DPO style training you'll need to set pos/neg data paths instead of a single one, e.g. +For KTO style training, you'll need to add the reward & label data path, e.g.: + +```yaml + "data_impl": "mmap", + # Suggested data paths when using GPT-NeoX locally + "train_data_path": "data/enwik8/enwik8_text_document", + "train_label_data_path": "data/enwik8/enwik8_text_label_document", + "train_reward_data_path": "data/enwik8/enwik8_text_reward_document", + "test_data_path": "data/enwik8/enwik8_text_document", + "test_label_data_path": "data/enwik8/enwik8_text_label_document", + "test_reward_data_path": "data/enwik8/enwik8_text_reward_document", + "valid_data_path": "data/enwik8/enwik8_text_document", + "valid_label_data_path": "data/enwik8/enwik8_text_label_document", + "valid_reward_data_path": "data/enwik8/enwik8_text_reward_document", + "vocab_file": "data/gpt2-vocab.json", + "merge_file": "data/gpt2-merges.txt", + "save": "checkpoints", + "load": "checkpoints", + "tensorboard_dir": "tensorboard", + "log_dir": "logs", + "checkpoint_factor": 10000, + "eval_interval": 1000, + "eval_iters": 10, +``` + 
+For DPO style training, you'll need to set pos/neg data paths instead of a single one, e.g. ```yaml "dataset_impl": "pairwise", From b4c5a91226c806010f0005341e02e4e8328d48ce Mon Sep 17 00:00:00 2001 From: Quentin Anthony Date: Sun, 8 Sep 2024 17:45:47 -0700 Subject: [PATCH 23/28] precommit --- megatron/text_generation_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/megatron/text_generation_utils.py b/megatron/text_generation_utils.py index 1708ac0ab..293cbaabc 100644 --- a/megatron/text_generation_utils.py +++ b/megatron/text_generation_utils.py @@ -877,7 +877,8 @@ def precompute_logits(neox_args, model): out_dataset = make_builder(out_path + ".bin", neox_args.data_impl) out_dataset._dtype = np.float32 i = 0 - # Not sure why this requires a multiple of 8 but... + + # TODO: Not sure why this requires a multiple of 8? Investigate later. while i < int(math.ceil(len(dataset) / 8.0) * 8): start = time.time() model.module.clear_cache() # clear kv cache between batches From d701aeaa2a0b5b9b35387ad356eb52a3a5fcd2c3 Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Sun, 8 Sep 2024 21:32:37 -0500 Subject: [PATCH 24/28] - KTO implementation from main... 
--- megatron/data/data_utils.py | 1 - megatron/data/gpt2_dataset.py | 1 - megatron/training.py | 32 ++----------------- .../preprocess_data_with_chat_template.py | 23 ------------- 4 files changed, 3 insertions(+), 54 deletions(-) diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index f3c23dc4d..50d3e8d63 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -522,7 +522,6 @@ def build_train_valid_test_data_iterators(neox_args): ) if neox_args.weight_by_num_documents: - # gets the number of documents in each datapath get_num_docs_list = lambda datasets: [ dataset.indexed_dataset.sizes.shape[0] for dataset in datasets diff --git a/megatron/data/gpt2_dataset.py b/megatron/data/gpt2_dataset.py index 5d200fd72..86d0c5ad3 100644 --- a/megatron/data/gpt2_dataset.py +++ b/megatron/data/gpt2_dataset.py @@ -148,7 +148,6 @@ def __getitem__(self, idx): sample_list.append(dataset.get(self.doc_idx[i])) sample_lengths.append(len(sample_list[-1])) # And finally add the relevant portion of last document. - if n == rw_indx: rw = dataset.get(self.doc_idx[doc_index_l]) sample_list.append( diff --git a/megatron/training.py b/megatron/training.py index 8fe0c1acf..486947624 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -319,7 +319,7 @@ def get_batch(neox_args, data_iterator): # Items and their type. 
if neox_args.train_impl in ["normal", "kto"]: keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] - elif neox_args.train_impl in ["dpo", "rm"]: + elif neox_args.train_impl == "dpo": keys = ( [["pos", "pos_label"], ["neg", "neg_label"]] if neox_args.pos_train_label_data_paths @@ -374,7 +374,7 @@ def get_batch(neox_args, data_iterator): data=data, datatype=datatype, ) - if (neox_args.precompute_model_name) and (neox_args.train_impl == "dpo"): + if neox_args.precompute_model_name: ref_data = mpu.broadcast_data(["pos_ref", "neg_ref"], data, torch.float) else: ref_data = {"pos_ref": None} @@ -565,32 +565,6 @@ def forward_step( else: moe_loss = 0.0 loss = main_loss + moe_loss - elif neox_args.train_impl == "rm": - maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args) - if type(maybe_tuple) is tuple: - outputs, _ = maybe_tuple - else: - outputs = maybe_tuple - pos, neg = torch.chunk(outputs, 2, 0) - pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0) - # We assume that each pos, neg pair occur in the same order - # e.g. second nonzero pos is the corresponding second nonzero neg - # and that there are also an equal number of pos and neg in each sequence. - pos_indx = pos_loss_mask.nonzero() - neg_indx = neg_loss_mask.nonzero() - # indx[:, 0] is the batch index, indx[:, 1] is the token index, we only care about the token index. 
- pos_indx = pos_indx[:, 1].unsqueeze(1) - neg_indx = neg_indx[:, 1].unsqueeze(1) - pos = torch.gather(pos.squeeze(), dim=1, index=pos_indx) - neg = torch.gather(neg.squeeze(), dim=1, index=neg_indx) - with torch.no_grad(): - metrics["pos_values"] = pos.clone().detach().mean() - metrics["neg_values"] = neg.clone().detach().mean() - metrics["margin"] = (pos - neg).clone().detach().mean() - metrics["accuracy"] = ((pos - neg) > 0).clone().detach().float().mean() - loss = (-F.logsigmoid(pos - neg).mean()) + ( - (neox_args.z_loss * (pos**2 + neg**2)).mean() - ) elif neox_args.train_impl == "dpo": # Based on https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90 with torch.no_grad(): @@ -801,7 +775,7 @@ def get_model(neox_args, use_cache=False): model = GPT2ModelPipe( neox_args=neox_args, num_tokentypes=0, - parallel_output=True if neox_args.train_impl != "rm" else False, + parallel_output=True, topology=mpu.get_topology(), use_cache=use_cache, ) diff --git a/tools/datasets/preprocess_data_with_chat_template.py b/tools/datasets/preprocess_data_with_chat_template.py index f55561bd8..91b16dcc2 100644 --- a/tools/datasets/preprocess_data_with_chat_template.py +++ b/tools/datasets/preprocess_data_with_chat_template.py @@ -81,7 +81,6 @@ def build_chat( apply_mask: bool, tokenizer: PreTrainedTokenizer, only_last_turn: bool = False, - for_rm: bool = False, ) -> Tuple[List[int], List[int]]: """ Build a chat from a list of dictionaries. Each dictionary should have a "role" and "content" key, this follows the @@ -92,28 +91,12 @@ def build_chat( :param apply_mask: Whether to apply a loss mask to the chat, if False, all tokens will be included in the loss :param tokenizer: A HF tokenizer :param only_last_turn: Whether to only include the last turn in the chat, needed for some fine-tuning tasks - :param for_rm: Whether this is for a reward model or not, this will mask everything except EOS token. 
- If you need a more complicated setup, you can modify this function to suit your needs. """ tokens = [] mask = [] if apply_mask is False: tokens = tokenizer.apply_chat_template(chat) mask = tokens - if tokenizer.eos_token_id is not None: - mask.append(tokenizer.eos_token_id) - tokens.append(tokenizer.eos_token_id) - return tokens, mask - elif for_rm: - tokens = tokenizer.apply_chat_template(chat) - mask = [-100] * len(tokens) - if tokenizer.eos_token_id is not None: - mask.append(tokenizer.eos_token_id) - tokens.append(tokenizer.eos_token_id) - else: - raise ValueError( - "Tokenizer does not have an EOS token, unable to determine good mask, please edit and make your own." - ) return tokens, mask for i, turn in enumerate(chat): add_gen = ( @@ -154,7 +137,6 @@ def encode(self, text): not self.args.no_mask, Encoder.tokenizer, self.args.only_last, - self.args.for_rm, ) if self.args.reward_key is not None: reward = text[self.args.reward_key] @@ -189,11 +171,6 @@ def get_args(): help="If set, this will not mask any tokens in the input data.", action="store_true", ) - group.add_argument( - "--for-rm", - help="If set, this will mask everything except the last token in the chat.", - action="store_true", - ) group.add_argument( "--generation-role", type=str, From 9daa1bced5c17718cd73373d00bd41b071d0b611 Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Mon, 9 Sep 2024 11:30:49 -0500 Subject: [PATCH 25/28] initial changes... 
--- megatron/data/gpt2_dataset.py | 20 +++++++++++-------- .../preprocess_data_with_chat_template.py | 19 +++++++++++++++++- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/megatron/data/gpt2_dataset.py b/megatron/data/gpt2_dataset.py index 86d0c5ad3..abe819089 100644 --- a/megatron/data/gpt2_dataset.py +++ b/megatron/data/gpt2_dataset.py @@ -103,6 +103,8 @@ def __getitem__(self, idx): datasets.append(self.label_dataset) if self.reward_dataset is not None: datasets.append(self.reward_dataset) + else: + rw_indx = -1 if self.ref_dataset is not None: datasets.append(self.ref_dataset) samples = [] @@ -112,15 +114,17 @@ def __getitem__(self, idx): if doc_index_f == doc_index_l: if rw_indx == n: # If we are in the reward dataset, we only need the last token. + rw = dataset.get(self.doc_idx[doc_index_f]) samples.append( - dataset.get( - self.doc_idx[doc_index_f], offset=offset_l, length=1 - ) + np.array([rw[0] for _ in range(len(samples[-1]))]) ) else: - rw = dataset.get(self.doc_idx[doc_index_f]) samples.append( - np.array([rw[0] for _ in range(len(samples[-1]))]) + dataset.get( + self.doc_idx[doc_index_f], + offset=offset_l, + length=offset_l - offset_f + 1, + ) ) else: if n != rw_indx: @@ -130,7 +134,7 @@ def __getitem__(self, idx): if n == rw_indx: rw = dataset.get(self.doc_idx[doc_index_f]) sample_list = [ - np.array([rw[0] for _ in range(sample_lengths[0])]) + np.array([rw[0] for _ in range(sample_lengths.pop(0))]) ] else: sample_list = [ @@ -142,7 +146,7 @@ def __getitem__(self, idx): if n == rw_indx: rw = dataset.get(self.doc_idx[i]) sample_list.append( - np.array([rw[0] for _ in range(sample_lengths[1 + i])]) + np.array([rw[0] for _ in range(sample_lengths.pop(0))]) ) else: sample_list.append(dataset.get(self.doc_idx[i])) @@ -151,7 +155,7 @@ def __getitem__(self, idx): if n == rw_indx: rw = dataset.get(self.doc_idx[doc_index_l]) sample_list.append( - np.array([rw[0] for _ in range(sample_lengths[-1])]) + np.array([rw[0] for _ in 
range(sample_lengths.pop(0))]) ) else: sample_list.append( diff --git a/tools/datasets/preprocess_data_with_chat_template.py b/tools/datasets/preprocess_data_with_chat_template.py index 91b16dcc2..1083155fa 100644 --- a/tools/datasets/preprocess_data_with_chat_template.py +++ b/tools/datasets/preprocess_data_with_chat_template.py @@ -81,6 +81,7 @@ def build_chat( apply_mask: bool, tokenizer: PreTrainedTokenizer, only_last_turn: bool = False, + for_rm: bool = False, ) -> Tuple[List[int], List[int]]: """ Build a chat from a list of dictionaries. Each dictionary should have a "role" and "content" key, this follows the @@ -98,6 +99,17 @@ def build_chat( tokens = tokenizer.apply_chat_template(chat) mask = tokens return tokens, mask + elif for_rm: + tokens = tokenizer.apply_chat_template(chat) + mask = [-100] * len(tokens) + if tokenizer.eos_token_id is not None: + mask.append(tokenizer.eos_token_id) + tokens.append(tokenizer.eos_token_id) + else: + raise ValueError( + "Tokenizer does not have an EOS token, unable to determine good mask, please edit and make your own." + ) + return tokens, mask for i, turn in enumerate(chat): add_gen = ( False if i == len(chat) - 1 else chat[i + 1]["role"] == generation_role @@ -105,7 +117,6 @@ def build_chat( chat_tokens = tokenizer.apply_chat_template( chat[: i + 1], add_generation_prompt=add_gen )[len(tokens) :] - # remove previous stuff... 
tokens.extend(chat_tokens) if only_last_turn and (i != len(chat) - 1): @@ -137,6 +148,7 @@ def encode(self, text): not self.args.no_mask, Encoder.tokenizer, self.args.only_last, + self.args.for_rm, ) if self.args.reward_key is not None: reward = text[self.args.reward_key] @@ -171,6 +183,11 @@ def get_args(): help="If set, this will not mask any tokens in the input data.", action="store_true", ) + group.add_argument( + "--for-rm", + help="If set, this will mask everything except the last token in the chat.", + action="store_true", + ) group.add_argument( "--generation-role", type=str, From 486840aac1e96fda6bfa055a8678f843a864ef2f Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Mon, 9 Sep 2024 11:32:10 -0500 Subject: [PATCH 26/28] pre-commit... --- megatron/data/gpt2_dataset.py | 1 + tools/datasets/preprocess_data_with_chat_template.py | 1 + 2 files changed, 2 insertions(+) diff --git a/megatron/data/gpt2_dataset.py b/megatron/data/gpt2_dataset.py index abe819089..e37c558d2 100644 --- a/megatron/data/gpt2_dataset.py +++ b/megatron/data/gpt2_dataset.py @@ -53,6 +53,7 @@ def __init__( self.reward_dataset = reward_dataset self.ref_dataset = ref_dataset self.seq_length = seq_length + # Checks assert self.reward_dataset is None or ( pack_impl == "unpacked" diff --git a/tools/datasets/preprocess_data_with_chat_template.py b/tools/datasets/preprocess_data_with_chat_template.py index 1083155fa..3e44cf261 100644 --- a/tools/datasets/preprocess_data_with_chat_template.py +++ b/tools/datasets/preprocess_data_with_chat_template.py @@ -188,6 +188,7 @@ def get_args(): help="If set, this will mask everything except the last token in the chat.", action="store_true", ) + group.add_argument( "--generation-role", type=str, From f39ccbcd6ba01ba5b909e60e7e96a10043110974 Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Mon, 9 Sep 2024 14:48:05 -0500 Subject: [PATCH 27/28] hotfix + data loader update --- megatron/data/data_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/megatron/data/data_utils.py b/megatron/data/data_utils.py index 50d3e8d63..335bda061 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -89,6 +89,8 @@ def build_the_dataset( precompute_indexed_dataset = None if reward_prefix is not None: reward_dataset = make_indexed_dataset(reward_prefix, data_impl, skip_warmup) + else: + reward_dataset = None elif dataset_impl == "pairwise": pos_indexed_dataset = make_indexed_dataset( pos_data_prefix, data_impl, skip_warmup From cd6b2f2735e7e23c8a208cc298ec9fe0048d272e Mon Sep 17 00:00:00 2001 From: Quentin Anthony Date: Sat, 14 Sep 2024 07:55:25 +0000 Subject: [PATCH 28/28] precommit --- megatron/neox_arguments/neox_args.py | 2 +- megatron/training.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 852f49ce5..5194047d5 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -1061,7 +1061,7 @@ class NeoXArgsTraining(NeoXArgsTemplate): """ Whether to cast logits to fp32 for DPO loss calculation. """ - + dpo_reference_free: bool = False """ Whether to use reference-free DPO. diff --git a/megatron/training.py b/megatron/training.py index 9b8a99816..cca3ab3f9 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -1048,7 +1048,9 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): """Setup model and optimizer.""" needs_reference_model = ( - (neox_args.train_impl == "dpo") and (neox_args.precompute_model_name is None) and (not neox_args.dpo_reference_free) + (neox_args.train_impl == "dpo") + and (neox_args.precompute_model_name is None) + and (not neox_args.dpo_reference_free) ) or ((neox_args.train_impl == "kto") and (neox_args.precompute_model_name is None)) model = get_model(neox_args=neox_args, use_cache=use_cache) if needs_reference_model: