diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md
index f0ea55eeb..0ab7a2e5c 100644
--- a/configs/neox_arguments.md
+++ b/configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

 - **git_hash**: str

-    Default = 11a5537
+    Default = 621ab25

     current git hash of repository

@@ -1378,6 +1378,122 @@ Training Arguments

+- **replay_config**: dict
+
+    Default = None
+
+    Dictionary storing the replay config.
+
+
+
+- **is_replay_enabled**: bool
+
+    Default = False
+
+    Triggers the logic for replay. It is important to deal with replay separately from the general "train_data_paths" logic, as replay
+    requires reusing the same idx files to know what data was seen when a dataset was originally trained on.
+    If one attempts to do replay by simply putting the datasets to be replayed in train_data_paths instead of the replay params:
+    - If the exact same dataset files are used as when the data was first seen, and the number of iterations on the replay buffer
+    corresponds to as many epochs on a replay dataset as the non-replay training, the data will be seen in exactly the same order as
+    the first time, provided the seed and sequence length are the same.
+    - For similar reasons, replaying multiple times on the same dataset (e.g. across multiple tasks) with the same number of epochs
+    on the replay dataset will lead to seeing the same data in the same order.
+    - If a different dataset is used for replay (e.g. a different shard of Pile), the shuffling will produce completely different
+    indices, which can leave significant proportions of data unseen if the original training on the replay dataset did not see all
+    of it, e.g. when training on 300B tokens of the GPT2-tokenised Pile (which contains a few dozen billion more tokens) and then
+    sharding the full dataset into smaller ones.
+
+
+
+- **replay_idx_paths_prefixes**: list
+
+    Default = None
+
+    List of path prefixes used to retrieve the replay dataset idx files. Those idx files should have been generated when originally
+    training on the dataset being used for replay. Their filenames contain the number of samples potentially seen during pretraining,
+    the sequence length and the seed. The exact files (shuffle_idx, sample_idx and doc_idx) will be automatically derived from the
+    prefix. Similarly, the data paths will be generated from the prefixes.
+    The *_idx files are important as they record what data was seen in the dataset during training. If those files are missing, you can
+    regenerate them by relaunching the same training script (most importantly, config) used originally to pretrain on a given dataset. You
+    can add an exit(0) statement in training.py in pretrain() after the call to build_train_valid_test_data_iterators(neox_args=neox_args).
+    It is crucial to use the same dataset shard, sequence length, number of iterations, seed, and potentially batch size, or the indices
+    generated may not be the same.
+    For a single replay data source, the value passed looks like ["data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s"] and
+    the files at the following paths (constructed during execution from the prefix) must exist:
+    "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_doc_idx.npy"
+    "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_sample_idx.npy"
+    "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_shuffle_idx.npy"
+    "data/mydataset/train/mydataset"
+
+
+
+- **replay_data_to_idx_paths**: dict
+
+    Default = None
+
+    As indicated above, gets automatically built from the replay_idx_paths_prefixes by appending "_doc_idx.npy", "_sample_idx.npy" and
+    "_shuffle_idx.npy" to each prefix. The result is a dict of dicts, with the data paths as keys, each mapping to the relevant
+    doc_idx, sample_idx and shuffle_idx file paths. Note that these files must exist at the relevant paths.
+
+
+
+- **replay_data_paths**: list
+
+    Default = None
+
+    As indicated above, gets automatically built from the replay_idx_paths_prefixes by stripping the idx file information to retain
+    only the path to the dataset itself.
+
+
+
+- **replay_data_weights**: list
+
+    Default = None
+
+    List of 'weights' that decide how often to sample from each replay dataset when building the replay buffer.
+
+
+
+- **replay_idx_offsets**: list
+
+    Default = None
+
+    List of indices that decide where to start in the list of indices seen during pretraining on each replay dataset when building
+    the replay buffer. For example, when a dataset originally saw 10000 samples during training, this allows one to start looking at the
+    RESHUFFLED indices from idx replay_idx_offsets[i] for replay dataset i.
+    If not set, the offsets default to 0 for every replay dataset.
+
+
+
+- **replay_fraction**: float
+
+    Default = 0.05
+
+    Fraction of a batch dedicated to doing replay. For example, 0.1 means that in a batch of 100, 10 samples will come from the replay
+    buffer. Note that this means that if we train on 100B tokens, we will have only used 90B tokens from the datasets specified in
+    train_data_paths.
+
+
+
+- **replay_reshuffle_idx**: bool
+
+    Default = True
+
+    If set to False, the index files loaded from the original pretraining run will replay samples in the exact order they were
+    seen the first time. If True, the indices are reshuffled to prevent that.
+
+
+
+- **replay_seed**: int
+
+    Default = 1234
+
+    Seed used to reshuffle the indices recorded when originally training on a dataset, which are reused to do replay. This matters
+    when replay is done twice on the same dataset (e.g. across two passes): if the same seed is used, the replay buffers in both
+    cases will be exactly the same.
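+
+    As a worked end-to-end sketch combining the replay options above (all paths and sample counts here are
+    hypothetical; see tests/config/example_replay_config.yml added in this change for a runnable example):
+
+    "replay_config": {
+      "enabled": true,
+      "replay_idx_paths_prefixes": ["data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s"],
+      "replay_data_weights": [1.0],
+      "replay_idx_offsets": [0],
+      "replay_fraction": 0.05,
+      "replay_seed": 1234,
+      "replay_reshuffle_idx": true,
+    }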
+
+
+
 - **weight_by_num_documents**: bool

     Default = False

diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py
index bc5754cdb..65347e24f 100644
--- a/megatron/data/data_utils.py
+++ b/megatron/data/data_utils.py
@@ -61,6 +61,9 @@ def build_the_dataset(
     skip_warmup,
     build_index_mappings=True,
     label_prefix=None,
+    index_mapping_paths=None,
+    index_offset=0,
+    reshuffle_when_loading=True,
 ):
     """Build train/valid/test datasets."""

@@ -85,6 +88,9 @@ def build_the_dataset(
         seed,
         build_index_mappings=build_index_mappings,
         label_dataset=label_dataset,
+        index_mapping_paths=index_mapping_paths,
+        index_offset=index_offset,
+        reshuffle_when_loading=reshuffle_when_loading,
     )
     return dataset

@@ -191,6 +197,32 @@ def get_normalized_weights_and_num_samples(
         weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005)))
     return weights, weighted_num_samples

+def get_normalized_weights_and_num_samples_with_replay(
+    weights: List[float], replay_weights: List[float], replay_fraction, num_samples: int
+) -> Tuple[List[float], List[int]]:
+    # Normalize weights. `weights` come from the training data and `replay_weights` come from
+    # the replay data. The idea is that we merge the weights provided for training data and
+    # replay data into the same array. Since replay should contribute replay_fraction of all
+    # weights, we normalize the replay weights by replay_fraction and the rest by (1 - replay_fraction).
+    weight_sum = sum(weights)
+    assert weight_sum > 0.0
+    weights = [(weight / weight_sum) * (1 - replay_fraction) for weight in weights]
+
+    replay_weights_sum = sum(replay_weights)
+    assert replay_weights_sum > 0.0
+    replay_weights = [(replay_weight / replay_weights_sum) * replay_fraction for replay_weight in replay_weights]
+
+    # merge weights with the replay weights given the replay_fraction
+    weights = weights + replay_weights
+
+    # Add 0.5% (the 1.005 factor) so in case the blending dataset does
+    # not uniformly distribute the number of samples, we still have
+    # samples left to feed to the network.
+    weighted_num_samples = []
+    for weight in weights:
+        weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005)))
+    return weights, weighted_num_samples
+

 def build_weighted_datasets(
     neox_args,
@@ -201,7 +233,21 @@ def build_weighted_datasets(
     valid_weights,
     test_weights,
     build_index_mappings=True,
+    concatenate_train_replay_paths=False,
 ):
+
+    # The concatenate_train_replay_paths bool is necessary to avoid issues when this function gets called a second time.
+    if neox_args.is_replay_enabled and concatenate_train_replay_paths:
+        # Merge the replay data paths into the train data paths logic, but keep track of
+        # which paths in train_data_paths came from replay.
+        num_replay_data_paths = len(neox_args.replay_data_paths)
+        num_non_replay_data_paths = len(neox_args.train_data_paths)
+        neox_args.train_data_paths += neox_args.replay_data_paths
+    else:
+        num_replay_data_paths = 0
+
+    assert not (neox_args.label_data_paths and neox_args.is_replay_enabled), (
+        "Simultaneous use of label data and replay is untested. Remove this assert at your own risk. "
+        "You might want to add a replay_label_data_paths arg too if relevant."
+    )
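+
+    # Illustrative note (hypothetical counts): with 3 train paths and 2 replay paths, the concatenation
+    # above leaves train_data_paths with 5 entries; in the loop below, i = 0..2 build regular "train_{i}"
+    # datasets while i = 3..4 build "replay_{i_replay}" datasets, where i_replay = i - 3 indexes the
+    # replay-specific arguments (replay_idx_offsets, replay_data_to_idx_paths).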
     # build individual datasets
     train_datasets, valid_datasets, test_datasets = [], [], []
     for i, (train_path, label_path, valid_path, test_path) in enumerate(
@@ -213,19 +259,39 @@ def build_weighted_datasets(
         )
     ):
         if train_path:
-            train_datasets.append(
-                build_the_dataset(
-                    data_prefix=train_path,
-                    name=f"train_{i}",
-                    data_impl=neox_args.data_impl,
-                    num_samples=train_num_samples[i],
-                    seq_length=neox_args.seq_length,
-                    seed=neox_args.seed,
-                    skip_warmup=(not neox_args.mmap_warmup),
-                    build_index_mappings=build_index_mappings,
-                    label_prefix=label_path,
+            if i < len(neox_args.train_data_paths) - num_replay_data_paths:
+                train_datasets.append(
+                    build_the_dataset(
+                        data_prefix=train_path,
+                        name=f"train_{i}",
+                        data_impl=neox_args.data_impl,
+                        num_samples=train_num_samples[i],
+                        seq_length=neox_args.seq_length,
+                        seed=neox_args.seed,
+                        skip_warmup=(not neox_args.mmap_warmup),
+                        build_index_mappings=build_index_mappings,
+                        label_prefix=label_path,
+                    )
                 )
-            )
+
+            # when dealing with a replay dataset, pass neox_args so the idx files are loaded instead of built.
+            else:
+                i_replay = i - (len(neox_args.train_data_paths) - num_replay_data_paths)
+                train_datasets.append(
+                    build_the_dataset(
+                        data_prefix=train_path,
+                        name=f"replay_{i_replay}",
+                        data_impl=neox_args.data_impl,
+                        num_samples=train_num_samples[i],
+                        seq_length=neox_args.seq_length,
+                        seed=neox_args.replay_seed,
+                        skip_warmup=(not neox_args.mmap_warmup),
+                        build_index_mappings=False,
+                        index_mapping_paths=neox_args.replay_data_to_idx_paths[train_path],
+                        index_offset=neox_args.replay_idx_offsets[i_replay],
+                        reshuffle_when_loading=neox_args.replay_reshuffle_idx,
+                    )
+                )

         if valid_path:
             valid_datasets.append(
@@ -326,9 +392,15 @@ def build_train_valid_test_data_iterators(neox_args):
     if neox_args.train_data_paths:
         # when individual train / valid / test data paths are provided
         # normalize weight values and get num samples for each dataset
-        train_weights, train_num_samples = get_normalized_weights_and_num_samples(
-            neox_args.train_data_weights, train_val_test_num_samples[0]
-        )
+        if neox_args.is_replay_enabled:
+            train_weights, train_num_samples = get_normalized_weights_and_num_samples_with_replay(
+                neox_args.train_data_weights, neox_args.replay_data_weights,
+                neox_args.replay_fraction, train_val_test_num_samples[0]
+            )
+        else:
+            train_weights, train_num_samples = get_normalized_weights_and_num_samples(
+                neox_args.train_data_weights, train_val_test_num_samples[0]
+            )
         valid_weights, valid_num_samples = get_normalized_weights_and_num_samples(
             neox_args.valid_data_weights, train_val_test_num_samples[1]
         )
@@ -346,10 +418,13 @@ def build_train_valid_test_data_iterators(neox_args):
             valid_weights,
             test_weights,
             build_index_mappings=not neox_args.weight_by_num_documents,
+            concatenate_train_replay_paths=True,
         )

         if neox_args.weight_by_num_documents:
-
+            assert not neox_args.is_replay_enabled, (
+                "Replay is untested with autoweighting; remove this assert at your own risk. "
+                "The concatenation of train and replay paths might break if it happened twice due to the "
+                "second call of build_weighted_datasets, hence that call uses concatenate_train_replay_paths=False."
+            )
             # gets the number of documents in each datapath
             get_num_docs_list = lambda datasets: [
                 dataset.indexed_dataset.sizes.shape[0] for dataset in datasets
@@ -394,6 +469,7 @@ def build_train_valid_test_data_iterators(neox_args):
                 train_weights,
                 valid_weights,
                 test_weights,
+                concatenate_train_replay_paths=False,
             )

         if train_datasets:
@@ -403,6 +479,7 @@ def build_train_valid_test_data_iterators(neox_args):
         if test_datasets:
             test_ds = BlendableDataset(test_datasets, test_weights)
     else:
+        assert not neox_args.is_replay_enabled, "Replay is not implemented when autosplitting into train/val/test datasets."
         # when just data_path is provided
         # split dataset into train, valid and test from data_path
         train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
diff --git a/megatron/data/gpt2_dataset.py b/megatron/data/gpt2_dataset.py
index 75e601fda..4d4986d01 100644
--- a/megatron/data/gpt2_dataset.py
+++ b/megatron/data/gpt2_dataset.py
@@ -39,6 +39,9 @@ def __init__(
         build_index_mappings=True,
         use_shared_fs=True,
         label_dataset=None,
+        index_mapping_paths=None,
+        index_offset=0,
+        reshuffle_when_loading=True,
     ):

         self.name = name
@@ -48,6 +51,8 @@ def __init__(
         # Checks
         assert np.min(documents) >= 0
         assert np.max(documents) < indexed_dataset.sizes.shape[0]
+        if not build_index_mappings:
+            assert index_mapping_paths, "If not building index mappings, the path to existing ones must be provided."

         if build_index_mappings:
             # Build index mappings.
@@ -61,13 +66,32 @@ def __init__(
                 seed,
                 use_shared_fs=use_shared_fs,
             )
-            self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1
-            self.sample_idx_len = self.sample_idx.shape[0] - 1

+        else:
+            # If not building the index mappings, we need to load them
+            self.doc_idx, self.sample_idx, self.shuffle_idx = _load_index_mappings(
+                self.name,
+                data_prefix,
+                documents,
+                self.indexed_dataset.sizes,
+                num_samples,
+                seq_length,
+                seed,
+                use_shared_fs=use_shared_fs,
+                index_mapping_paths=index_mapping_paths,
+                index_offset=index_offset,
+                reshuffle_when_loading=reshuffle_when_loading,
+            )
+
+        self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1
+        self.sample_idx_len = self.sample_idx.shape[0] - 1
+
+        if self.shuffle_idx_len != self.sample_idx_len:
+            print(
+                f"WARNING: shuffle index length ({self.shuffle_idx_len}) is not equal to sample index length ({self.sample_idx_len})"
+            )

-        if self.shuffle_idx_len != self.sample_idx_len - 1:
-            print(
-                f"WARNING: shuffle index length ({self.shuffle_idx_len}) is not equal to sample index length ({self.sample_idx_len})"
-            )

     def __len__(self):
         return min(self.shuffle_idx_len, self.sample_idx_len)
@@ -242,6 +266,134 @@ def _build_index_mappings(
     return doc_idx, sample_idx, shuffle_idx


+# Warning: only implemented with replay in mind; some issues may arise when doing more than 1 epoch over the dataset.
+def _load_index_mappings(
+    name,
+    data_prefix,
+    documents,
+    sizes,
+    num_samples,
+    seq_length,
+    seed,
+    index_mapping_paths,
+    use_shared_fs=True,
+    index_offset=0,
+    reshuffle_when_loading=True,
+):
+    """Load doc-idx, sample-idx and shuffle-idx from previously built ones.
+    doc-idx: is an array (ordered) of documents to be used in training.
+    sample-idx: is the start document index and document offset for each
+        training sample.
+    shuffle-idx: maps the sample index into a random index into sample-idx.
+    """
+    # Number of tokens in each epoch and number of required epochs.
+    tokens_per_epoch = _num_tokens(documents, sizes)
+    num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
+    # rng state
+    np_rng = np.random.RandomState(seed=seed)
+    is_replay = name.split("_")[0] == "replay"
+
+    # Filename of the index mappings.
+    _filename = data_prefix
+    _filename += "_{}_indexmap".format(name)
+    _filename += "_{}ns".format(num_samples)
+    _filename += "_{}sl".format(seq_length)
+    _filename += "_{}s".format(seed)
+    doc_idx_filename = _filename + "_doc_idx.npy"
+    sample_idx_filename = _filename + "_sample_idx.npy"
+    shuffle_idx_filename = _filename + "_shuffle_idx.npy"
+
+    if not use_shared_fs:
+        should_process_dataset = int(os.environ["LOCAL_RANK"]) == 0
+    else:
+        should_process_dataset = torch.distributed.get_rank() == 0
+
+    # Build the indexed mapping if it does not exist.
+    if should_process_dataset:
+        start_time = time.time()
+        print_rank_0(" > loading shuffle-idx mapping from {}".format(index_mapping_paths["shuffle_idx_path"]))
+        shuffle_idx = np.load(index_mapping_paths["shuffle_idx_path"], allow_pickle=True, mmap_mode="r")
+        print_rank_0(
+            "    loaded indexed file in {:3.3f} seconds".format(time.time() - start_time)
+        )
+
+        idx_prefix = index_mapping_paths["shuffle_idx_path"].split("_")[:-2]
+
+        ## restrict to samples seen during original pretraining if this is replay
+        ## careful, this is hardcoded based on idx filenames looking like dataset_train_4_indexmap_5781ns_2048sl_1234s_doc_idx.npy
+
+        # Remove the 1.005 buffer that was added when estimating the number of samples in
+        # get_normalized_weights_and_num_samples(). Note that this may of course miss a few of the seen samples.
+        num_samples_originally_seen = np.int64(np.int64(idx_prefix[-3][:-2]) / 1.005)
+        seq_length_originally_seen = int(idx_prefix[-2][:-2])
+        assert seq_length_originally_seen == seq_length, "Current seq len {} does not match the seq len the indices were built for ({}).".format(
+            seq_length,
+            seq_length_originally_seen,
+        )
+        # Useful if we want to support extending the idx when more than 1 epoch is necessary, but for now we assert 1 epoch.
+        num_epochs_in_replay = _num_epochs(tokens_per_epoch, seq_length, num_samples_originally_seen)
+        assert num_epochs_in_replay == 1, "Enough samples requested from replay dataset {} for more than one epoch; this is currently untested.".format(data_prefix)
+        if is_replay:
+            shuffle_idx = shuffle_idx[:num_samples_originally_seen]
+            # get a sufficient length if needed
+
+        # apply offset, but add the removed elements back at the end.
+        # TODO: do we want to shuffle the shuffle_idx[:index_offset] term again?
+        index_offset = index_offset % len(shuffle_idx)
+        if reshuffle_when_loading:
+            # Numpy can throw errors if the array isn't writeable, which it apparently is not when memory-mapped from disk.
+            if not shuffle_idx.flags.writeable:
+                try:
+                    shuffle_idx.setflags(write=True)
+                except ValueError:
+                    # copy trick if we couldn't set it to writeable
+                    print("Loaded shuffle_idx at {} is not writeable, need to copy it as a workaround...".format("_".join(idx_prefix)))
+                    temp = shuffle_idx.copy()
+                    shuffle_idx = temp
+                    assert shuffle_idx.flags.writeable, "Failed to make shuffle_idx writeable; the shuffling will not work."
+                    print("successfully copied!")
+            # Shuffling a fresh index array and gathering through it is faster here than shuffling the loaded array in place.
+            temp_random_idx = np.array(range(len(shuffle_idx)))
+            np_rng.shuffle(temp_random_idx)
+            shuffle_idx = shuffle_idx[temp_random_idx]
+            del temp_random_idx
+
+        shuffle_idx = np.concatenate([shuffle_idx[index_offset:], shuffle_idx[:index_offset]])
+        np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
+        print_rank_0(
+            " > elapsed time to build and save shuffle-idx mapping"
+            " (seconds): {:4f}".format(time.time() - start_time)
+        )

+    # This should be a barrier but nccl barrier assumes
+    # device_index=rank which is not the case for model
+    # parallel case
+    counts = torch.cuda.LongTensor([1])
+    torch.distributed.all_reduce(counts, group=mpu.get_io_parallel_group())
+    assert counts[0].item() == torch.distributed.get_world_size(
+        group=mpu.get_io_parallel_group()
+    )
+
+    # Load mappings.
+    start_time = time.time()
+    print_rank_0(" > loading doc-idx mapping from {}".format(index_mapping_paths["doc_idx_path"]))
+    doc_idx = np.load(index_mapping_paths["doc_idx_path"], allow_pickle=True, mmap_mode="r")
+    print_rank_0(" > loading sample-idx mapping from {}".format(index_mapping_paths["sample_idx_path"]))
+    sample_idx = np.load(index_mapping_paths["sample_idx_path"], allow_pickle=True, mmap_mode="r")
+    print_rank_0(" > loading shuffle-idx mapping from {}".format(shuffle_idx_filename))
+    shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode="r")
+    print_rank_0(
+        "    loaded indexed file in {:3.3f} seconds".format(time.time() - start_time)
+    )
+    print_rank_0("    total number of samples: {}".format(shuffle_idx.shape[0] + 1))
+    print_rank_0("    total number of epochs: {}".format(num_epochs))
+
+    return doc_idx, sample_idx, shuffle_idx
+
+
 def _num_tokens(documents, sizes):
     """Total number of tokens in the dataset."""
     return np.sum(sizes[documents])
diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py
index bf6e3f3e8..f27e895c6 100644
--- a/megatron/neox_arguments/arguments.py
+++ b/megatron/neox_arguments/arguments.py
@@ -822,6 +822,31 @@ def check_batch_parameters(dp_world_size, train_batch, micro_batch, grad_acc):
             f"{train_batch} != {micro_batch} * {grad_acc} * {dp_world_size}"
         )

+
+    def generate_data_and_idx_paths_from_idx_paths_prefixes(self, idx_paths_prefixes):
+        # All idx filenames contain "_indexmap_" followed by info we don't need. The two "_"-separated
+        # words before it are the dataset index and the train/valid/test split name, as seen in
+        # build_weighted_datasets(). Therefore the dataset name spans everything up to
+        # index("indexmap") - 2 when the idx path is split on "_".
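+        # Example (hypothetical prefix): "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s"
+        # splits on "_" with "indexmap" at index 3, so the words from index 3 - 2 = 1 onwards
+        # ("train", "4", "indexmap", ...) are dropped and the recovered data path is
+        # "data/mydataset/train/mydataset".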
+        data_paths = [None] * len(idx_paths_prefixes)
+        data_to_idx_paths = {}
+        for i, idx_path_prefix in enumerate(idx_paths_prefixes):
+            data_path_prefix_words = idx_path_prefix.split("_")
+            maximum_index = [idx for idx, data_path_prefix_word in enumerate(data_path_prefix_words)
+                             if data_path_prefix_word == "indexmap"]
+            assert maximum_index, "Error: idx path prefix {} unsupported; does not contain the word \"indexmap\"".format(idx_path_prefix)
+            maximum_index = maximum_index[-1] - 2
+            data_path = "_".join(data_path_prefix_words[:maximum_index])
+            data_paths[i] = data_path
+
+            data_to_idx_paths[data_path] = {
+                "doc_idx_path": idx_path_prefix + "_doc_idx.npy",
+                "sample_idx_path": idx_path_prefix + "_sample_idx.npy",
+                "shuffle_idx_path": idx_path_prefix + "_shuffle_idx.npy",
+            }
+
+        return data_paths, data_to_idx_paths
+
+
     def calculate_derived(self):
         """
         Derives additional configuration values necessary for training from the current config
@@ -1104,6 +1129,22 @@ def calculate_derived(self):
                     f"Warning: Flash-Attention version ({str(_flash_version)}) must be >= 2.4.0.post1 to support AliBi. Falling back to flash-attn triton backend, but version 2.4.0.post1 or later will be required in future."
                 )

+
+        # Replay config
+        if self.replay_config is not None:
+            self.update_value("is_replay_enabled", self.replay_config.get("enabled", False))
+
+            for k, v in self.replay_config.items():
+                if k == "enabled":
+                    continue
+                self.update_value(k, v)
+
+            # If no idx offsets were provided, set them automatically to 0
+            if self.replay_idx_offsets is None:
+                self.replay_idx_offsets = [0] * len(self.replay_idx_paths_prefixes)
+
+            self.replay_data_paths, self.replay_data_to_idx_paths = self.generate_data_and_idx_paths_from_idx_paths_prefixes(self.replay_idx_paths_prefixes)
+
         # Adding equal dataset weights if none are provided
         if self.train_data_paths and (self.train_data_weights is None):
             self.train_data_weights = [1.0] * len(self.train_data_paths)
@@ -1111,6 +1152,9 @@ def calculate_derived(self):
             self.valid_data_weights = [1.0] * len(self.valid_data_paths)
         if self.test_data_paths and (self.test_data_weights is None):
             self.test_data_weights = [1.0] * len(self.test_data_paths)
+        if self.is_replay_enabled:
+            if self.replay_idx_paths_prefixes and (self.replay_data_weights is None):
+                self.replay_data_weights = [1.0] * len(self.replay_idx_paths_prefixes)

         if self.label_data_paths:
             err_str = (
@@ -1290,6 +1334,37 @@ def validate_values(self):
             if self.test_data_paths is not None:
                 assert len(self.test_data_paths) == len(self.test_data_weights)

+            # assert that if replay is enabled, the mandatory replay args have been passed and we have as many weights as data sources
+            if self.is_replay_enabled:
+                required_replay_args = [
+                    "replay_idx_paths_prefixes",
+                    "replay_data_paths",
+                    "replay_data_to_idx_paths",
+                ]
+                for required_replay_arg in required_replay_args:
+                    assert getattr(self, required_replay_arg) is not None, "Missing required replay attribute: {}.".format(required_replay_arg)
+
+                # assert that idx files exist at the specified paths
+                for replay_data_to_idx_dict in self.replay_data_to_idx_paths.values():
+                    for k, v in replay_data_to_idx_dict.items():
+                        assert os.path.exists(v), "No replay {} file found at path {}".format(k, v)
+
+                # assert that the path prefixes and the weights have equal length
+                assert len(self.replay_idx_paths_prefixes) == len(self.replay_data_weights), "Number of replay datasets and weights does not match."
+
+                # assert that the path prefixes and the idx offsets have equal length
+                assert len(self.replay_idx_paths_prefixes) == len(self.replay_idx_offsets), "Number of replay datasets and offsets does not match."
+
+                # assert that the path prefixes and the replay data paths have equal length
+                assert len(self.replay_idx_paths_prefixes) == len(self.replay_data_paths), "Number of replay datasets and data paths does not match."
+
+                # assert that replay_fraction is between 0 and 1
+                assert 0 <= self.replay_fraction <= 1, "replay_fraction needs to be set to a number in [0, 1], was set to {}".format(self.replay_fraction)
+
+                # replay is only supported when train_data_paths, valid_data_paths and test_data_paths are provided
+                assert all(has_separate_path), "Replay is currently only supported when train_data_paths, valid_data_paths and test_data_paths are provided."
+
         return True

     def validate_types(self):
@@ -1385,7 +1460,7 @@ def validate_types(self):
                 )
                 return False

-        for field_name in ["fp16", "amp", "flops_profiler"]:
+        for field_name in ["fp16", "amp", "flops_profiler", "replay_config"]:
             value = getattr(self, field_name)
             if isinstance(value, dict):
                 if not "enabled" in value:
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index 16d6456b4..037fe2079 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -880,6 +880,92 @@ class NeoXArgsTraining(NeoXArgsTemplate):
     Should be a list the same length as `test_data_paths`
     """

+    replay_config: dict = None
+    """
+    Dictionary storing the replay config.
+    """
+
+    is_replay_enabled: bool = False
+    """
+    Triggers the logic for replay. It is important to deal with replay separately from the general "train_data_paths" logic, as replay
+    requires reusing the same idx files to know what data was seen when a dataset was originally trained on.
+    If one attempts to do replay by simply putting the datasets to be replayed in train_data_paths instead of the replay params:
+    - If the exact same dataset files are used as when the data was first seen, and the number of iterations on the replay buffer
+    corresponds to as many epochs on a replay dataset as the non-replay training, the data will be seen in exactly the same order as
+    the first time, provided the seed and sequence length are the same.
+    - For similar reasons, replaying multiple times on the same dataset (e.g. across multiple tasks) with the same number of epochs
+    on the replay dataset will lead to seeing the same data in the same order.
+    - If a different dataset is used for replay (e.g. a different shard of Pile), the shuffling will produce completely different
+    indices, which can leave significant proportions of data unseen if the original training on the replay dataset did not see all
+    of it, e.g. when training on 300B tokens of the GPT2-tokenised Pile (which contains a few dozen billion more tokens) and then
+    sharding the full dataset into smaller ones.
+    """
+
+    replay_idx_paths_prefixes: list = None
+    """
+    List of path prefixes used to retrieve the replay dataset idx files. Those idx files should have been generated when originally
+    training on the dataset being used for replay. Their filenames contain the number of samples potentially seen during pretraining,
+    the sequence length and the seed. The exact files (shuffle_idx, sample_idx and doc_idx) will be automatically derived from the
+    prefix. Similarly, the data paths will be generated from the prefixes.
+    The *_idx files are important as they record what data was seen in the dataset during training. If those files are missing, you can
+    regenerate them by relaunching the same training script (most importantly, config) used originally to pretrain on a given dataset. You
+    can add an exit(0) statement in training.py in pretrain() after the call to build_train_valid_test_data_iterators(neox_args=neox_args).
+    It is crucial to use the same dataset shard, sequence length, number of iterations, seed, and potentially batch size, or the indices
+    generated may not be the same.
+    For a single replay data source, the value passed looks like ["data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s"] and
+    the files at the following paths (constructed during execution from the prefix) must exist:
+    "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_doc_idx.npy"
+    "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_sample_idx.npy"
+    "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_shuffle_idx.npy"
+    "data/mydataset/train/mydataset"
+    """
+
+    replay_data_to_idx_paths: dict = None
+    """
+    As indicated above, gets automatically built from the replay_idx_paths_prefixes by appending "_doc_idx.npy", "_sample_idx.npy" and
+    "_shuffle_idx.npy" to each prefix. The result is a dict of dicts, with the data paths as keys, each mapping to the relevant
+    doc_idx, sample_idx and shuffle_idx file paths. Note that these files must exist at the relevant paths.
+    """
+
+    replay_data_paths: list = None
+    """
+    As indicated above, gets automatically built from the replay_idx_paths_prefixes by stripping the idx file information to retain
+    only the path to the dataset itself.
+    """
+
+    replay_data_weights: list = None
+    """
+    List of 'weights' that decide how often to sample from each replay dataset when building the replay buffer.
+    """
+
+    replay_idx_offsets: list = None
+    """
+    List of indices that decide where to start in the list of indices seen during pretraining on each replay dataset when building
+    the replay buffer. For example, when a dataset originally saw 10000 samples during training, this allows one to start looking at the
+    RESHUFFLED indices from idx replay_idx_offsets[i] for replay dataset i.
+    If not set, the offsets default to 0 for every replay dataset.
+    """
+
+    replay_fraction: float = 0.05
+    """
+    Fraction of a batch dedicated to doing replay. For example, 0.1 means that in a batch of 100, 10 samples will come from the replay
+    buffer. Note that this means that if we train on 100B tokens, we will have only used 90B tokens from the datasets specified in
+    train_data_paths.
+    """
+
+    replay_reshuffle_idx: bool = True
+    """
+    If set to False, the index files loaded from the original pretraining run will replay samples in the exact order they were
+    seen the first time. If True, the indices are reshuffled to prevent that.
+    """
+
+    replay_seed: int = 1234
+    """
+    Seed used to reshuffle the indices recorded when originally training on a dataset, which are reused to do replay. This matters
+    when replay is done twice on the same dataset (e.g. across two passes): if the same seed is used, the replay buffers in both
+    cases will be exactly the same.
+ """ + weight_by_num_documents: bool = False """ If True, Builds dataset weights from a multinomial distribution over groups of data according to the number of diff --git a/tests/config/example_replay_config.yml b/tests/config/example_replay_config.yml new file mode 100644 index 000000000..30bca72a1 --- /dev/null +++ b/tests/config/example_replay_config.yml @@ -0,0 +1,45 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/train_300B/ArXiv/ArXiv', + 'data/slim_pajama/train_300B/Book/Book', + 'data/slim_pajama/train_300B/C4/C4', + 'data/slim_pajama/train_300B/Wikipedia/Wikipedia', + 'data/slim_pajama/train_300B/Github/Github', + 'data/slim_pajama/train_300B/StackExchange/StackExchange', + 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl',], + "train-data-weights": [ + 4.428184641, + 4.203131326, + 26.688499472, + 3.997293125, + 5.224141056, + 3.371078725, + 52.087671651 + ], + "train-iters": 132366, + "lr-decay-iters": 132366, + "train-dataset-name": 'slim_pajama_300B', + + "replay_config": { + "enabled": true, + # Have to specify idx filenames from original pretraining on tasks, as they contain the num iterations + # and seen indices assuming we're using the same (non-replay) seed as during pretraining + "replay_idx_paths_prefixes": [ + "data/pile/train/pile_train_train_0_indexmap_146862725ns_2048sl_1234s", + ], + "replay_data_weights":[ + 1.00, + ], + "replay_idx_offsets": [ + 1, + ], + # Fraction of samples coming from the replay buffer, between 0 and 1. + "replay_fraction": 0.5, + # Seed and reshuffle go hand in hand. They control whether you want to see the replay data in the same order + # as you've seen it (done by setting reshuffle to false), and if you decide to reshuffle, what seed you should + # use to reshuffle the seen data. + "replay_seed": 1234, + "replay_reshuffle_idx": false, + }, +} \ No newline at end of file diff --git a/tests/model/test_batch_replicability.py b/tests/model/test_batch_replicability.py new file mode 100644 index 000000000..bfe38f938 --- /dev/null +++ b/tests/model/test_batch_replicability.py @@ -0,0 +1,66 @@ +import torch +import numpy as np + +""" +Idea for the verification: save batches per rank on two datasets, one being a mix of say data sources +A+B, and the other being just A. To debug things, we used +A = pile +B = pile+slimp +and B being 50/50 mix of pile and slimp. We save 3 batches for pile+slimp to ensure that even with noise, the sampling +is as deterministic as we might think, the first batch of pile should be in the first 3 batches of pile+slimp. +To save the files used to debug, add in training.py below: + # Data stuff. + timers("train/valid/test data iterators").start() + ( + train_data_iterator, + valid_data_iterator, + test_data_iterator, + ) = build_train_valid_test_data_iterators(neox_args=neox_args) +the following + num_batches_to_test_replicability = 3 + save_dummy = torch.stack([next(train_data_iterator)["text"] for _ in range(num_batches_to_test_replicability)]) + torch.save(save_dummy[0], "pile_rank_{}.my_repro_batchs".format(neox_args.rank)) + print(save_dummy[0]) + # print(next(train_data_iterator), "\npile_slimp1") + # torch.distributed.barrier() + # print(next(train_data_iterator), "\npile_slimp2") + # torch.distributed.barrier() + # print(next(train_data_iterator), "\npile_slimp3") + torch.distributed.barrier() + exit(0) +you need 2 runs of this with the two datasets you consider. 
+You need two runs of this, one per dataset considered. Don't forget to change the filename in torch.save().
+"""
+
+num_ranks = 6
+dataset_A = [None] * num_ranks
+dataset_B = [None] * num_ranks
+dataset_A_name = "pile_replay"
+dataset_B_name = "pile+slimp"
+# Use 0 to use all saved batches. If only one batch was saved, this param gets ignored.
+num_batches_A = 3
+num_batches_B = 3
+
+
+# concatenate all ranks' batches, keeping at most num_batches batches
+def cat_only_process_num_batches(dataset, num_ranks, num_batches=0):
+    dim = 1 if len(dataset[0].shape) == 3 else 0
+    # only use num_batches batches, if the tensor has shape [num_batches, sample_idx_in_batch, seq_len]
+    if num_batches and dim:
+        new_shape_format = [-1, dataset[0].shape[-1]]
+        return (torch.cat([dataset[i][:num_batches] for i in range(num_ranks)], dim=dim)).view(new_shape_format)
+    else:
+        return torch.cat([dataset[i] for i in range(num_ranks)], dim=dim)
+
+
+for i in range(num_ranks):
+    dataset_A[i] = torch.load("gpt-neox/{}_rank_{}.adam".format(dataset_A_name, i))
+    dataset_B[i] = torch.load("gpt-neox/{}_rank_{}.adam".format(dataset_B_name, i))
+
+dataset_A_cat = cat_only_process_num_batches(dataset_A, num_ranks, num_batches=num_batches_A)
+dataset_B_cat = cat_only_process_num_batches(dataset_B, num_ranks, num_batches=num_batches_B)
+# dataset_B_cat = dataset_B_cat.reshape([-1, dataset_B_cat.shape[-1]])
+# dataset_A_cat = torch.cat([dataset_A[i][:num_batches_A] for i in range(num_ranks)], dim=0)
+
+# check, for each sample of dataset A, whether all of its tokens appear somewhere in dataset B's batches
+checks = [False] * dataset_A_cat.shape[0]
+for i in range(len(checks)):
+    checks[i] = torch.all(torch.isin(dataset_A_cat[i], dataset_B_cat)).item()
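+
+# Report the result; this assumes the intended pass condition is that every sample of
+# dataset A also occurs in dataset B's batches.
+print("{}/{} samples of {} found in {}".format(sum(checks), len(checks), dataset_A_name, dataset_B_name))
+assert all(checks), "Some samples of {} were not found in {}".format(dataset_A_name, dataset_B_name)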