From defa0a45d1055a7f8f1749e5d80d55a56f7904fa Mon Sep 17 00:00:00 2001 From: Adam Ibrahim Date: Sat, 13 Apr 2024 01:35:33 +0200 Subject: [PATCH 1/8] Initial commit of replay --- megatron/data/data_utils.py | 109 +++++++++++++--- megatron/data/gpt2_dataset.py | 164 +++++++++++++++++++++++- megatron/neox_arguments/arguments.py | 77 ++++++++++- megatron/neox_arguments/neox_args.py | 86 +++++++++++++ tests/config/example_replay_config.yml | 45 +++++++ tests/model/test_batch_replicability.py | 66 ++++++++++ 6 files changed, 524 insertions(+), 23 deletions(-) create mode 100644 tests/config/example_replay_config.yml create mode 100644 tests/model/test_batch_replicability.py diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index bc5754cdb..65347e24f 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -61,6 +61,9 @@ def build_the_dataset( skip_warmup, build_index_mappings=True, label_prefix=None, + index_mapping_paths=None, + index_offset=0, + reshuffle_when_loading=True, ): """Build train/valid/test datasets.""" @@ -85,6 +88,9 @@ def build_the_dataset( seed, build_index_mappings=build_index_mappings, label_dataset=label_dataset, + index_mapping_paths=index_mapping_paths, + index_offset=index_offset, + reshuffle_when_loading=reshuffle_when_loading, ) return dataset @@ -191,6 +197,32 @@ def get_normalized_weights_and_num_samples( weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005))) return weights, weighted_num_samples +def get_normalized_weights_and_num_samples_with_replay( + weights: List[float], replay_weights: List[float], replay_fraction, num_samples: int +) -> Tuple[List[float], List[int]]: + # Normalize weights. weights correspond to the weights from the training data and replay_weights correspond + # to weights from the replay data. The idea is that we will be merge the weights provided for training data + # and replay data into the same array. We know that replay_weights should contribute replay_fraction of all + # weights, so we also need to normalise replay weights by replay_fraction and the rest by (1-replay_fraction). + weight_sum = sum(weights) + assert weight_sum > 0.0 + weights = [(weight / weight_sum) * (1-replay_fraction) for weight in weights] + + replay_weights_sum = sum(replay_weights) + assert replay_weights_sum > 0.0 + replay_weights = [(replay_weight / replay_weights_sum) * replay_fraction for replay_weight in replay_weights] + + # merge weights with the replay weights given the replay_fraction + weights = weights + replay_weights + + # Add 0.5% (the 1.005 factor) so in case the blending dataset does + # not uniformly distribute the number of samples, we still have + # samples left to feed to the network. + weighted_num_samples = [] + for weight in weights: + weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005))) + return weights, weighted_num_samples + def build_weighted_datasets( neox_args, @@ -201,7 +233,21 @@ def build_weighted_datasets( valid_weights, test_weights, build_index_mappings=True, + concatenate_train_replay_paths=False, ): + + # The concatenate_train_replay_paths bool is necessary to avoid issues when this function gets called a second time. 
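+    # Illustrative (hypothetical) example of the merge below: with
+    # train_data_paths=["data/A", "data/B"] and replay_data_paths=["data/pile_replay"],
+    # train_data_paths becomes ["data/A", "data/B", "data/pile_replay"], with
+    # num_non_replay_data_paths == 2 and num_replay_data_paths == 1, so only the
+    # trailing entry is routed through the replay branch further down.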
+ if neox_args.is_replay_enabled and concatenate_train_replay_paths: + # Merge replay data paths into train data paths logic, but need to keep track of + # what paths in train_data_paths came from replay + num_replay_data_paths = len(neox_args.replay_data_paths) + num_non_replay_data_paths = len(neox_args.train_data_paths) + neox_args.train_data_paths += neox_args.replay_data_paths + else: + num_replay_data_paths = 0 + + assert not (neox_args.label_data_paths and neox_args.is_replay_enabled), "Simultaneous use of label data and replay is untested.\ + Remove assert at your own risk. You might want to add a replay_label_data_paths arg too if relevant." # build individual datasets train_datasets, valid_datasets, test_datasets = [], [], [] for i, (train_path, label_path, valid_path, test_path) in enumerate( @@ -213,19 +259,39 @@ def build_weighted_datasets( ) ): if train_path: - train_datasets.append( - build_the_dataset( - data_prefix=train_path, - name=f"train_{i}", - data_impl=neox_args.data_impl, - num_samples=train_num_samples[i], - seq_length=neox_args.seq_length, - seed=neox_args.seed, - skip_warmup=(not neox_args.mmap_warmup), - build_index_mappings=build_index_mappings, - label_prefix=label_path, + if i < len(neox_args.train_data_paths) - num_replay_data_paths: + train_datasets.append( + build_the_dataset( + data_prefix=train_path, + name=f"train_{i}", + data_impl=neox_args.data_impl, + num_samples=train_num_samples[i], + seq_length=neox_args.seq_length, + seed=neox_args.seed, + skip_warmup=(not neox_args.mmap_warmup), + build_index_mappings=build_index_mappings, + label_prefix=label_path, + ) ) - ) + + # when dealing with replay dataset, will need to pass neox_args to load idx files instead of building them. + else: + i_replay = i - (len(neox_args.train_data_paths) - num_replay_data_paths) + train_datasets.append( + build_the_dataset( + data_prefix=train_path, + name=f"replay_{i_replay}", + data_impl=neox_args.data_impl, + num_samples=train_num_samples[i], + seq_length=neox_args.seq_length, + seed=neox_args.replay_seed, + skip_warmup=(not neox_args.mmap_warmup), + build_index_mappings=False, + index_mapping_paths=neox_args.replay_data_to_idx_paths[train_path], + index_offset=neox_args.replay_idx_offsets[i_replay], + reshuffle_when_loading=neox_args.replay_reshuffle_idx, + ) + ) if valid_path: valid_datasets.append( @@ -326,9 +392,15 @@ def build_train_valid_test_data_iterators(neox_args): if neox_args.train_data_paths: # when individual train / valid / test data paths are provided # normalize weight values and get num samples for each dataset - train_weights, train_num_samples = get_normalized_weights_and_num_samples( - neox_args.train_data_weights, train_val_test_num_samples[0] - ) + if neox_args.is_replay_enabled: + train_weights, train_num_samples = get_normalized_weights_and_num_samples_with_replay( + neox_args.train_data_weights, neox_args.replay_data_weights, + neox_args.replay_fraction, train_val_test_num_samples[0] + ) + else: + train_weights, train_num_samples = get_normalized_weights_and_num_samples( + neox_args.train_data_weights, train_val_test_num_samples[0] + ) valid_weights, valid_num_samples = get_normalized_weights_and_num_samples( neox_args.valid_data_weights, train_val_test_num_samples[1] ) @@ -346,10 +418,13 @@ def build_train_valid_test_data_iterators(neox_args): valid_weights, test_weights, build_index_mappings=not neox_args.weight_by_num_documents, + concatenate_train_replay_paths=True, ) if neox_args.weight_by_num_documents: - + assert not 
neox_args.is_replay_enabled, "Replay not tested in the case of autoweighting, remove assert at your own risk.\ + I suspect that something might break with the concatenation of the train and replay happening twice due to a second call\ + of build_weighted_datasets, so setting it to False with concatenate_train_replay_paths=False." # gets the number of documents in each datapath get_num_docs_list = lambda datasets: [ dataset.indexed_dataset.sizes.shape[0] for dataset in datasets @@ -394,6 +469,7 @@ def build_train_valid_test_data_iterators(neox_args): train_weights, valid_weights, test_weights, + concatenate_train_replay_paths=False, ) if train_datasets: @@ -403,6 +479,7 @@ def build_train_valid_test_data_iterators(neox_args): if test_datasets: test_ds = BlendableDataset(test_datasets, test_weights) else: + assert not neox_args.is_replay_enabled, "Replay not implemented in the case of autosplitting into train/val/test datasets." # when just data_path is provided # split dataset into train, valid and test from data_path train_ds, valid_ds, test_ds = build_train_valid_test_datasets( diff --git a/megatron/data/gpt2_dataset.py b/megatron/data/gpt2_dataset.py index 75e601fda..4d4986d01 100644 --- a/megatron/data/gpt2_dataset.py +++ b/megatron/data/gpt2_dataset.py @@ -39,6 +39,9 @@ def __init__( build_index_mappings=True, use_shared_fs=True, label_dataset=None, + index_mapping_paths=None, + index_offset=0, + reshuffle_when_loading=True, ): self.name = name @@ -48,6 +51,8 @@ def __init__( # Checks assert np.min(documents) >= 0 assert np.max(documents) < indexed_dataset.sizes.shape[0] + if not build_index_mappings: + assert index_mapping_paths, "If not building index mappings, the path to existing ones must be provided." if build_index_mappings: # Build index mappings. @@ -61,13 +66,32 @@ def __init__( seed, use_shared_fs=use_shared_fs, ) - self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1 - self.sample_idx_len = self.sample_idx.shape[0] - 1 + + else: + # If not building the index mappings, we need to load them + self.doc_idx, self.sample_idx, self.shuffle_idx = _load_index_mappings( + self.name, + data_prefix, + documents, + self.indexed_dataset.sizes, + num_samples, + seq_length, + seed, + use_shared_fs=use_shared_fs, + index_mapping_paths=index_mapping_paths, + index_offset=index_offset, + reshuffle_when_loading=reshuffle_when_loading, + ) + + + self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1 + self.sample_idx_len = self.sample_idx.shape[0] - 1 + + if self.shuffle_idx_len != self.sample_idx_len: + print( + f"WARNING: shuffle index length ({self.shuffle_idx_len}) is not equal to sample index length ({self.sample_idx_len})" + ) - if self.shuffle_idx_len != self.sample_idx_len - 1: - print( - f"WARNING: shuffle index length ({self.shuffle_idx_len}) is not equal to sample index length ({self.sample_idx_len})" - ) def __len__(self): return min(self.shuffle_idx_len, self.sample_idx_len) @@ -242,6 +266,134 @@ def _build_index_mappings( return doc_idx, sample_idx, shuffle_idx +# Warning: only implemented with replay in mind, some issues may arise when dealing with more than 1 epoch over the dataset +def _load_index_mappings( + name, + data_prefix, + documents, + sizes, + num_samples, + seq_length, + seed, + index_mapping_paths, + use_shared_fs=True, + index_offset=0, + reshuffle_when_loading=True, +): + """Build doc-idx, sample-idx, and shuffle-idx from ones loaded. + doc-idx: is an array (ordered) of documents to be used in training. 
+ sample-idx: is the start document index and document offset for each + training sample. + shuffle-idx: maps the sample index into a random index into sample-idx. + """ + # Number of tokens in each epoch and number of required epochs. + tokens_per_epoch = _num_tokens(documents, sizes) + num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) + # rng state + np_rng = np.random.RandomState(seed=seed) + is_replay = name.split("_")[0] == "replay" + + + # Filename of the index mappings. + _filename = data_prefix + _filename += "_{}_indexmap".format(name) + _filename += "_{}ns".format(num_samples) + _filename += "_{}sl".format(seq_length) + _filename += "_{}s".format(seed) + doc_idx_filename = _filename + "_doc_idx.npy" + sample_idx_filename = _filename + "_sample_idx.npy" + shuffle_idx_filename = _filename + "_shuffle_idx.npy" + + + if not use_shared_fs: + should_process_dataset = int(os.environ["LOCAL_RANK"]) == 0 + else: + should_process_dataset = torch.distributed.get_rank() == 0 + + # Build the indexed mapping if not exist. + if should_process_dataset: + start_time = time.time() + print_rank_0(" > loading shuffle-idx mapping from {}".format(index_mapping_paths["shuffle_idx_path"])) + shuffle_idx = np.load(index_mapping_paths["shuffle_idx_path"], allow_pickle=True, mmap_mode="r") + print_rank_0( + " loaded indexed file in {:3.3f} seconds".format(time.time() - start_time) + ) + + idx_prefix = index_mapping_paths["shuffle_idx_path"].split("_")[:-2] + + ## restrict to samples seen during original pretraining if this is replay + ## careful, this is hardcoded based on the number on idx filenames looking like dataset_train_4_indexmap_5781ns_2048sl_1234s_doc_idx.npy + + # remove 1.005 buffer that was added when estimating numbers of samples in get_normalized_weights_and_num_samples(). Note that of course + # this may be missing a few of the seen samples. + num_samples_originally_seen = np.int64(np.int64(idx_prefix[-3][:-2]) / 1.005) + seq_length_originally_seen = int(idx_prefix[-2][:-2]) + assert seq_length_originally_seen == seq_length, "Current seq len {} does not match seq len the indices were built for ({}).".format( + seq_length, + seq_length_originally_seen, + ) + # Useful if we want to support extending the idx if more than 1 epoch is necessary, but for now will assert 1 epoch + num_epochs_in_replay = _num_epochs(tokens_per_epoch, seq_length, num_samples_originally_seen) + assert num_epochs_in_replay == 1, "Enough samples from replay dataset {} for more than one epoch; this is currently\ + untested.".format(data_prefix) + if is_replay: + shuffle_idx = shuffle_idx[:num_samples_originally_seen] + # get a sufficient length if needed + + # apply offset, but add back the removed elements. + # TODO Do we want to shuffle again the shuffle_idx[:index_offset] term ? + index_offset = index_offset % len(shuffle_idx) + if reshuffle_when_loading: + # Numpy can throw errors if the array isn't writeable, which it is not apparently when it is loaded + if not shuffle_idx.flags.writeable: + try: + shuffle_idx.setflags(write=True) + except: + # copy trick if we couldn't set it to writeable + print("Loaded shuffle_idx at {} is not writeable, need to copy it as workaround...".format("_".join(idx_prefix))) + temp = shuffle_idx.copy() + shuffle_idx = temp + assert shuffle_idx.flags.writeable, "Failed to make shuffle_idx writeable, the shuffling will not work." + print("succesfully copied !") + # For some reason this is faster than shuffling the copied array directly. 
I'm sure there is a good reason for it. + temp_random_idx = np.array(range(len(shuffle_idx))) + np_rng.shuffle(temp_random_idx) + shuffle_idx = shuffle_idx[temp_random_idx] + del temp_random_idx + + shuffle_idx = np.concatenate([shuffle_idx[index_offset:], shuffle_idx[:index_offset]]) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + print_rank_0( + " > elapsed time to build and save shuffle-idx mapping" + " (seconds): {:4f}".format(time.time() - start_time) + ) + + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=mpu.get_io_parallel_group()) + assert counts[0].item() == torch.distributed.get_world_size( + group=mpu.get_io_parallel_group() + ) + + # Load mappings. + start_time = time.time() + print_rank_0(" > loading doc-idx mapping from {}".format(index_mapping_paths["doc_idx_path"])) + doc_idx = np.load(index_mapping_paths["doc_idx_path"], allow_pickle=True, mmap_mode="r") + print_rank_0(" > loading sample-idx mapping from {}".format(index_mapping_paths["sample_idx_path"])) + sample_idx = np.load(index_mapping_paths["sample_idx_path"], allow_pickle=True, mmap_mode="r") + print_rank_0(" > loading shuffle-idx mapping from {}".format(shuffle_idx_filename)) + shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode="r") + print_rank_0( + " loaded indexed file in {:3.3f} seconds".format(time.time() - start_time) + ) + print_rank_0(" total number of samples: {}".format(shuffle_idx.shape[0] + 1)) + print_rank_0(" total number of epochs: {}".format(num_epochs)) + + return doc_idx, sample_idx, shuffle_idx + + def _num_tokens(documents, sizes): """Total number of tokens in the dataset.""" return np.sum(sizes[documents]) diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index bf6e3f3e8..f27e895c6 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -822,6 +822,31 @@ def check_batch_parameters(dp_world_size, train_batch, micro_batch, grad_acc): f"{train_batch} != {micro_batch} * {grad_acc} * {dp_world_size}" ) + + def generate_data_and_idx_paths_from_idx_paths_prefixes(self, idx_paths_prefixes): + # All idx filenames will contain "_indexmap_" followed by info we don't need. The two words separated + # by "_" before are the training index and whether it's train/valid/test as seen in build_weighted_datasets(). + # Therefore the dataset name ranges up to index of "indexmap" - 2 if we split the idx path by "_". 
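+        # Example (taken from the replay_idx_paths_prefixes docstring): the prefix
+        # "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s"
+        # maps to the data path "data/mydataset/train/mydataset" and to the idx files
+        # "<prefix>_doc_idx.npy", "<prefix>_sample_idx.npy" and "<prefix>_shuffle_idx.npy".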
+ data_paths = [None] * len(idx_paths_prefixes) + data_to_idx_paths = {} + for i, idx_path_prefix in enumerate(idx_paths_prefixes): + data_path_prefix_words = idx_path_prefix.split("_") + maximum_index = [idx for idx, data_path_prefix_word in enumerate(data_path_prefix_words) + if data_path_prefix_word == "indexmap"] + assert maximum_index, "Error: idx path prefix {} unsupported; does not contain the word \"indexmap\"".format(idx_path_prefix) + maximum_index = maximum_index[-1] - 2 + data_path = "_".join(data_path_prefix_words[:maximum_index]) + data_paths[i] = data_path + + data_to_idx_paths[data_path] = { + "doc_idx_path": idx_path_prefix + "_doc_idx.npy", + "sample_idx_path": idx_path_prefix + "_sample_idx.npy", + "shuffle_idx_path": idx_path_prefix + "_shuffle_idx.npy", + } + + return data_paths, data_to_idx_paths + + def calculate_derived(self): """ Derives additional configuration values necessary for training from the current config @@ -1104,6 +1129,22 @@ def calculate_derived(self): f"Warning: Flash-Attention version ({str(_flash_version)}) must be >= 2.4.0.post1 to support AliBi. Falling back to flash-attn triton backend, but version 2.4.0.post1 or later will be required in future." ) + + # Replay config + if self.replay_config is not None: + self.update_value("is_replay_enabled", self.replay_config.get("enabled", False)) + + for k, v in self.replay_config.items(): + if k == "enabled": + continue + self.update_value(k, v) + + # If no idx offsets were provided, set them automatically to 0 + if self.replay_idx_offsets is None: + self.replay_idx_offsets = [0] * len(self.replay_idx_paths_prefixes) + + self.replay_data_paths, self.replay_data_to_idx_paths = self.generate_data_and_idx_paths_from_idx_paths_prefixes(self.replay_idx_paths_prefixes) + # Adding equal dataset weights if none are provided if self.train_data_paths and (self.train_data_weights is None): self.train_data_weights = [1.0] * len(self.train_data_paths) @@ -1111,6 +1152,9 @@ def calculate_derived(self): self.valid_data_weights = [1.0] * len(self.valid_data_paths) if self.test_data_paths and (self.test_data_weights is None): self.test_data_weights = [1.0] * len(self.test_data_paths) + if self.is_replay_enabled: + if self.replay_idx_paths_prefixes and (self.replay_data_weights is None): + self.replay_data_weights = [1.0] * len(self.replay_idx_paths_prefixes) if self.label_data_paths: err_str = ( @@ -1290,6 +1334,37 @@ def validate_values(self): if self.test_data_paths is not None: assert len(self.test_data_paths) == len(self.test_data_weights) + # assert that if replay is enabled, replay mandatory args have been passed and we have as many weights as data sources + if self.is_replay_enabled: + required_replay_args = [ + "replay_idx_paths_prefixes", + "replay_data_paths", + "replay_data_to_idx_paths", + ] + for required_replay_arg in required_replay_args: + assert getattr(self, required_replay_arg) is not None, "Missing required replay attribute: {} .".format(required_replay_arg) + + # assert that idx files exist at path specified + for replay_data_to_idx_dict in self.replay_data_to_idx_paths.values(): + for k, v in replay_data_to_idx_dict.items(): + assert os.path.exists(v), "No replay {} file found at path {}".format(k, v) + + # assert that the paths prefixes and the weights have equal length + assert len(self.replay_idx_paths_prefixes) == len(self.replay_data_weights), "Number of replay datasets and weights does not match." 
+ + # assert that the paths prefixes and idx offsets have equal length + assert len(self.replay_idx_paths_prefixes) == len(self.replay_idx_offsets), "Number of replay datasets and offsets does not match." + + # assert that the paths prefixes and replay data paths have equal length + assert len(self.replay_idx_paths_prefixes) == len(self.replay_data_paths), "Number of replay datasets and offsets does not match." + + # assert that replay_fraction is between 0 and 1 + assert 0 <= self.replay_fraction <= 1, "replay_fraction needs to be set to a number in [0, 1], was set to {}".format(self.replay_fraction) + + # replay is only supported when train_data_paths, valid_data_paths and test_data_paths are provided + assert all(has_separate_path), "Replay is only currently supported when train_data_paths, valid_data_paths and test_data_paths are provided." + + return True def validate_types(self): @@ -1385,7 +1460,7 @@ def validate_types(self): ) return False - for field_name in ["fp16", "amp", "flops_profiler"]: + for field_name in ["fp16", "amp", "flops_profiler", "replay_config"]: value = getattr(self, field_name) if isinstance(value, dict): if not "enabled" in value: diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 16d6456b4..037fe2079 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -880,6 +880,92 @@ class NeoXArgsTraining(NeoXArgsTemplate): Should be a list the same length as `test_data_paths` """ + replay_config: dict = None + """ + Dictionary storing the replay config. + """ + + is_replay_enabled: bool = False + """ + Triggers the logic for replay. It is important to deal with replay separately from the general "train_data_paths" logic, as replay + requires reusing the same idx files to know what data was seen the first time a dataset was originally trained on. + If one attempts to do replay by just putting the datasets to be replayed in the train_data_paths instead of the replay params: + - If the exact same dataset files are used as during the 1st time it was seen, and the number of iterations on the replay buffer + corresponds to as many epochs on a replay dataset as the non-replay training, the data will be seen in exactly the same order as + the first time if the seed and sequence length is the same. + - For similar reasons, replaying multiple times on the same dataset (e.g. across multiple tasks) with the same number of epochs + on the replay dataset will lead to seeing the same data in the same order. + - If a different dataset is used for replay (e.g. different shard of Pile), then the shuffling will lead to completely different + indices, which will lead to potentially significant proportions of data being unseen if the original training on the replay dataset + did not see all of it, e.g. when training on 300B tokens of the GPT2-tokenised Pile which contains a few dozen billion more tokens, + then sharding the full dataset into smaller ones. + """ + + replay_idx_paths_prefixes: list = None + """ + List of paths prefixes to retrieve replay dataset idx files. Those idx files should have been generated when originally training on the dataset + being used for replay. They contain in the filename the number of samples potentially seen during pretraining, the sequence length and the + seed. The exact files (shuffle_idx, sample_idx and doc_idx) will be automatically derived from the prefix. Similarly, the data paths will + be generated from the prefixes. 
+ The *_idx files are important as it allows one to know what data was seen in the dataset during training. If those files are missing, you can + regenerate them by relaunching the same training script (most importantly, config) used originally to pretrain on a given dataset. You + can add an exit(0) statement in training.py in pretrain() after the call to build_train_valid_test_data_iterators(neox_args=neox_args). + It is crucial to use the same dataset shard, sequence length, number of iterations, seed, and potentially batch size, or the indices + generated may not be the same. + For a single replay data source, the value passed looks like ["data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s"] and + the files at the following paths (the paths will be constructed during execution from the prefix), must exist: + "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_doc_idx.npy" + "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_sample_idx.npy" + "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_shuffle_idx.npy" + "data/mydataset/train/mydataset" + """ + + replay_data_to_idx_paths: dict = None + """ + As indicated above, gets automatically built from the replay_idx_paths_prefixes by appending to it "_doc_idx.npy", "_sample_idx.npy" and + "_shuffle_idx.npy". It generates a dict of dict, with the data paths as keys, and dictionaries mapping each data path to the relevant + doc_idx, sample_idx and shuffle_idx file paths. Note that these files must exist at the relevant paths. + """ + + replay_data_paths: list = None + """ + As indicated above, gets automatically built from the replay_idx_paths_prefixes by removing the information about the idx files to retain + only the path to the dataset itself. + """ + + replay_data_weights: list = None + """ + List of 'weights' that decide how often to sample from each replay dataset when building the replay buffer. + """ + + replay_idx_offsets: list = None + """ + List of indices that decide where to start in the list of seen indices during pretraining on each replay dataset when building + the replay buffer. For example, when training originally on a dataset seeing 10000 samples, this allows to start looking at the + RESHUFFLED indices starting from idx replay_idx_offsets[i] for replay dataset i. + If not set, this will uniformly sample among all replay datasets. + """ + + replay_fraction: float = 0.05 + """ + Fraction of a batch dedicated to doing replay. For example, 0.1 means that in a batch of 100, 19 samples will come from the replay + buffer. Note that this means that if we train on 100B tokens, we will have only used 90B tokens from the datasets specified in + train_data_paths. + """ + + replay_reshuffle_idx: bool = True + """ + When index files are loaded from those the dataset was originally pretrained on, they will follow the exact same sequence of samples + seen when training on that dataset the first time if this is set to False. If True, the indices are reshuffled to prevent that. + """ + + replay_seed: int = 1234 + """ + Seed used to reshuffle indices accessed when originally training on a dataset, that are used to do replay. This is useful in the case + where replay is done twice on as many passes over the dataset, in which case if the same seed is used, the replay buffers in both case + will be exactly the same. 
+ """ + weight_by_num_documents: bool = False """ If True, Builds dataset weights from a multinomial distribution over groups of data according to the number of diff --git a/tests/config/example_replay_config.yml b/tests/config/example_replay_config.yml new file mode 100644 index 000000000..30bca72a1 --- /dev/null +++ b/tests/config/example_replay_config.yml @@ -0,0 +1,45 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/train_300B/ArXiv/ArXiv', + 'data/slim_pajama/train_300B/Book/Book', + 'data/slim_pajama/train_300B/C4/C4', + 'data/slim_pajama/train_300B/Wikipedia/Wikipedia', + 'data/slim_pajama/train_300B/Github/Github', + 'data/slim_pajama/train_300B/StackExchange/StackExchange', + 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl',], + "train-data-weights": [ + 4.428184641, + 4.203131326, + 26.688499472, + 3.997293125, + 5.224141056, + 3.371078725, + 52.087671651 + ], + "train-iters": 132366, + "lr-decay-iters": 132366, + "train-dataset-name": 'slim_pajama_300B', + + "replay_config": { + "enabled": true, + # Have to specify idx filenames from original pretraining on tasks, as they contain the num iterations + # and seen indices assuming we're using the same (non-replay) seed as during pretraining + "replay_idx_paths_prefixes": [ + "data/pile/train/pile_train_train_0_indexmap_146862725ns_2048sl_1234s", + ], + "replay_data_weights":[ + 1.00, + ], + "replay_idx_offsets": [ + 1, + ], + # Fraction of samples coming from the replay buffer, between 0 and 1. + "replay_fraction": 0.5, + # Seed and reshuffle go hand in hand. They control whether you want to see the replay data in the same order + # as you've seen it (done by setting reshuffle to false), and if you decide to reshuffle, what seed you should + # use to reshuffle the seen data. + "replay_seed": 1234, + "replay_reshuffle_idx": false, + }, +} \ No newline at end of file diff --git a/tests/model/test_batch_replicability.py b/tests/model/test_batch_replicability.py new file mode 100644 index 000000000..bfe38f938 --- /dev/null +++ b/tests/model/test_batch_replicability.py @@ -0,0 +1,66 @@ +import torch +import numpy as np + +""" +Idea for the verification: save batches per rank on two datasets, one being a mix of say data sources +A+B, and the other being just A. To debug things, we used +A = pile +B = pile+slimp +and B being 50/50 mix of pile and slimp. We save 3 batches for pile+slimp to ensure that even with noise, the sampling +is as deterministic as we might think, the first batch of pile should be in the first 3 batches of pile+slimp. +To save the files used to debug, add in training.py below: + # Data stuff. + timers("train/valid/test data iterators").start() + ( + train_data_iterator, + valid_data_iterator, + test_data_iterator, + ) = build_train_valid_test_data_iterators(neox_args=neox_args) +the following + num_batches_to_test_replicability = 3 + save_dummy = torch.stack([next(train_data_iterator)["text"] for _ in range(num_batches_to_test_replicability)]) + torch.save(save_dummy[0], "pile_rank_{}.my_repro_batchs".format(neox_args.rank)) + print(save_dummy[0]) + # print(next(train_data_iterator), "\npile_slimp1") + # torch.distributed.barrier() + # print(next(train_data_iterator), "\npile_slimp2") + # torch.distributed.barrier() + # print(next(train_data_iterator), "\npile_slimp3") + torch.distributed.barrier() + exit(0) +you need 2 runs of this with the two datasets you consider. 
Don't forget to change the filename of the torch.save() +""" + +num_ranks = 6 +dataset_A = [None] * num_ranks +dataset_B = [None] * num_ranks +dataset_A_name = "pile_replay" +dataset_B_name = "pile+slimp" +# Use 0 to use all batches saved. If only one batch was saved, this param gets ignored. +num_batches_A = 3 +num_batches_B = 3 + + +# cat all batches in a limit of num_batches +def cat_only_process_num_batches(dataset, num_ranks, num_batches=0): + dim = 1 if len(dataset[0].shape)==3 \ + else 0 + # only use num_batches batches, if tensor has shape [num_batches, sample_idx_in_batch, seq_len] + if num_batches and dim: + new_shape_format = [-1, dataset[0].shape[-1]] + return (torch.cat([dataset[i][:num_batches] for i in range(num_ranks)], dim=dim)).view(new_shape_format) + else: + return torch.cat([dataset[i] for i in range(num_ranks)], dim=dim) + +for i in range(num_ranks): + dataset_A[i] = torch.load("gpt-neox/{}_rank_{}.adam".format(dataset_A_name, i)) + dataset_B[i] = torch.load("gpt-neox/{}_rank_{}.adam".format(dataset_B_name, i)) + +dataset_A_cat = cat_only_process_num_batches(dataset_A, num_ranks, num_batches=num_batches_A) +dataset_B_cat = cat_only_process_num_batches(dataset_B, num_ranks, num_batches=num_batches_B) +# dataset_B_cat = dataset_B_cat.reshape([-1, dataset_B_cat.shape[-1]]) +# dataset_A_cat = torch.cat([dataset_A[i][:num_batches_A] for i in range(num_ranks)], dim=0) + +checks = [False] * dataset_A_cat.shape[0] +for i in range(len(checks)): + checks[i] = torch.all(torch.isin(dataset_A_cat[i], dataset_B_cat)).item() From 5cdff7646da9c84065de7f5b9dd54e8f816d3487 Mon Sep 17 00:00:00 2001 From: github-actions Date: Sat, 13 Apr 2024 00:09:06 +0000 Subject: [PATCH 2/8] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 118 +++++++++++++++++++++++++++++++++++++- 1 file changed, 117 insertions(+), 1 deletion(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index f0ea55eeb..e40e39dad 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 11a5537 + Default = defa0a4 current git hash of repository @@ -1378,6 +1378,122 @@ Training Arguments +- **replay_config**: dict + + Default = None + + Dictionary storing the replay config. + + + +- **is_replay_enabled**: bool + + Default = False + + Triggers the logic for replay. It is important to deal with replay separately from the general "train_data_paths" logic, as replay + requires reusing the same idx files to know what data was seen the first time a dataset was originally trained on. + If one attempts to do replay by just putting the datasets to be replayed in the train_data_paths instead of the replay params: + - If the exact same dataset files are used as during the 1st time it was seen, and the number of iterations on the replay buffer + corresponds to as many epochs on a replay dataset as the non-replay training, the data will be seen in exactly the same order as + the first time if the seed and sequence length is the same. + - For similar reasons, replaying multiple times on the same dataset (e.g. across multiple tasks) with the same number of epochs + on the replay dataset will lead to seeing the same data in the same order. + - If a different dataset is used for replay (e.g. 
different shard of Pile), then the shuffling will lead to completely different + indices, which will lead to potentially significant proportions of data being unseen if the original training on the replay dataset + did not see all of it, e.g. when training on 300B tokens of the GPT2-tokenised Pile which contains a few dozen billion more tokens, + then sharding the full dataset into smaller ones. + + + +- **replay_idx_paths_prefixes**: list + + Default = None + + List of paths prefixes to retrieve replay dataset idx files. Those idx files should have been generated when originally training on the dataset + being used for replay. They contain in the filename the number of samples potentially seen during pretraining, the sequence length and the + seed. The exact files (shuffle_idx, sample_idx and doc_idx) will be automatically derived from the prefix. Similarly, the data paths will + be generated from the prefixes. + The *_idx files are important as it allows one to know what data was seen in the dataset during training. If those files are missing, you can + regenerate them by relaunching the same training script (most importantly, config) used originally to pretrain on a given dataset. You + can add an exit(0) statement in training.py in pretrain() after the call to build_train_valid_test_data_iterators(neox_args=neox_args). + It is crucial to use the same dataset shard, sequence length, number of iterations, seed, and potentially batch size, or the indices + generated may not be the same. + For a single replay data source, the value passed looks like ["data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s"] and + the files at the following paths (the paths will be constructed during execution from the prefix), must exist: + "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_doc_idx.npy" + "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_sample_idx.npy" + "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_shuffle_idx.npy" + "data/mydataset/train/mydataset" + + + +- **replay_data_to_idx_paths**: dict + + Default = None + + As indicated above, gets automatically built from the replay_idx_paths_prefixes by appending to it "_doc_idx.npy", "_sample_idx.npy" and + "_shuffle_idx.npy". It generates a dict of dict, with the data paths as keys, and dictionaries mapping each data path to the relevant + doc_idx, sample_idx and shuffle_idx file paths. Note that these files must exist at the relevant paths. + + + +- **replay_data_paths**: list + + Default = None + + As indicated above, gets automatically built from the replay_idx_paths_prefixes by removing the information about the idx files to retain + only the path to the dataset itself. + + + +- **replay_data_weights**: list + + Default = None + + List of 'weights' that decide how often to sample from each replay dataset when building the replay buffer. + + + +- **replay_idx_offsets**: list + + Default = None + + List of indices that decide where to start in the list of seen indices during pretraining on each replay dataset when building + the replay buffer. For example, when training originally on a dataset seeing 10000 samples, this allows to start looking at the + RESHUFFLED indices starting from idx replay_idx_offsets[i] for replay dataset i. + If not set, this will uniformly sample among all replay datasets. + + + +- **replay_fraction**: float + + Default = 0.05 + + Fraction of a batch dedicated to doing replay. 
For example, 0.1 means that in a batch of 100, 19 samples will come from the replay + buffer. Note that this means that if we train on 100B tokens, we will have only used 90B tokens from the datasets specified in + train_data_paths. + + + +- **replay_reshuffle_idx**: bool + + Default = True + + When index files are loaded from those the dataset was originally pretrained on, they will follow the exact same sequence of samples + seen when training on that dataset the first time if this is set to False. If True, the indices are reshuffled to prevent that. + + + +- **replay_seed**: int + + Default = 1234 + + Seed used to reshuffle indices accessed when originally training on a dataset, that are used to do replay. This is useful in the case + where replay is done twice on as many passes over the dataset, in which case if the same seed is used, the replay buffers in both case + will be exactly the same. + + + - **weight_by_num_documents**: bool Default = False From 6ff3ae6ce5ffb4d5cac95de77f5e686a18b1c8bd Mon Sep 17 00:00:00 2001 From: bentherien Date: Sun, 14 Apr 2024 16:24:21 -0400 Subject: [PATCH 3/8] added CPT code --- .../train/pile+slim_pajama_300B_each.yml | 25 +++++ configs/datasets/train/pile_shard0.yml | 33 ++++++ configs/datasets/train/pile_train.yml | 11 ++ configs/datasets/train/rp.yml | 32 ++++++ configs/datasets/train/slim_pajama_100B_1.yml | 24 +++++ .../train/slim_pajama_100B_1_replay5.yml | 28 +++++ configs/datasets/train/slim_pajama_100B_2.yml | 24 +++++ .../train/slim_pajama_100B_2_replay5.yml | 45 ++++++++ configs/datasets/train/slim_pajama_100B_3.yml | 24 +++++ .../train/slim_pajama_100B_3_replay5.yml | 61 +++++++++++ configs/datasets/train/slim_pajama_150B.yml | 23 ++++ configs/datasets/train/slim_pajama_200B_1.yml | 24 +++++ .../train/slim_pajama_200B_1_replay5.yml | 23 ++++ configs/datasets/train/slim_pajama_200B_2.yml | 24 +++++ .../train/slim_pajama_200B_2_replay5.yml | 45 ++++++++ configs/datasets/train/slim_pajama_200B_3.yml | 24 +++++ .../train/slim_pajama_200B_3_replay5.yml | 62 +++++++++++ configs/datasets/train/slim_pajama_300B.yml | 23 ++++ .../train/slim_pajama_300B_50_replay.yml | 45 ++++++++ .../train/slim_pajama_300B_replay0-5.yml | 23 ++++ .../train/slim_pajama_300B_replay1.yml | 23 ++++ .../train/slim_pajama_300B_replay10.yml | 23 ++++ .../train/slim_pajama_300B_replay5.yml | 23 ++++ .../train/slim_pajama_300B_replay50.yml | 24 +++++ configs/datasets/train/slim_pajama_606B.yml | 23 ++++ configs/datasets/train/slim_pajama_75B.yml | 24 +++++ .../datasets/train/slim_pajama_workshop.yml | 55 ++++++++++ configs/datasets/train/taskwise/ArXiv.yml | 11 ++ configs/datasets/train/taskwise/Book.yml | 12 +++ configs/datasets/train/taskwise/C4.yml | 12 +++ .../datasets/train/taskwise/CommonCrawl.yml | 11 ++ configs/datasets/train/taskwise/Github.yml | 12 +++ .../datasets/train/taskwise/StackExchange.yml | 11 ++ configs/datasets/train/taskwise/Wikipedia.yml | 11 ++ configs/datasets/val/pile_german.yml | 15 +++ configs/datasets/val/pile_rp-no-se.yml | 36 +++++++ configs/datasets/val/pile_rp.yml | 40 +++++++ configs/datasets/val/pile_rp_subsets.yml | 62 +++++++++++ configs/datasets/val/pile_slimp.yml | 15 +++ configs/datasets/val/pile_slimp_domains.yml | 30 ++++++ configs/datasets/val/pile_slimp_workshop.yml | 32 ++++++ configs/datasets/val/pile_val.yml | 13 +++ configs/datasets/val/pile_val_shard0.yml | 13 +++ configs/iclr_models/3b_test.yml | 101 ++++++++++++++++++ configs/iclr_models/410M.yml | 99 +++++++++++++++++ configs/iclr_models/410M_ckpt_100.yml | 100 
+++++++++++++++++ configs/iclr_models/49M.yml | 100 +++++++++++++++++ configs/iclr_models/7_1B.yml | 94 ++++++++++++++++ configs/load/3e-5const_0_410M_143_CPT.yml | 3 + configs/load/none.yml | 3 + configs/load/pythia_2-8B_143000.yml | 3 + configs/load/pythia_410m.yml | 3 + configs/load/pythia_410m_10000.yml | 3 + configs/load/pythia_410m_143000.yml | 3 + configs/load/pythia_410m_27000.yml | 3 + configs/load/pythia_6-9B_143000.yml | 3 + configs/load/pythia_deduped_410m_10000.yml | 3 + configs/load/pythia_deduped_410m_143000.yml | 3 + configs/load/pythia_deduped_410m_27000.yml | 3 + .../load/resume_1-2e-4_001_7-1B_pile_PT.yml | 3 + ...resume_1-2e-4_001_7-1B_slim_pajama_CPT.yml | 3 + .../resume_1-2e-4_001_7-1B_slim_pajama_PT.yml | 3 + .../load/resume_1-5e-4_001_410M_143_CPT.yml | 3 + configs/load/resume_3e-4_001_410M_143_CPT.yml | 3 + ...sume_3e-4_001_410M_slim_pajama_CPT_r05.yml | 3 + ...esume_3e-4_001_410M_slim_pajama_CPT_r1.yml | 3 + ...sume_3e-4_001_410M_slim_pajama_CPT_r10.yml | 3 + ...esume_3e-4_001_410M_slim_pajama_CPT_r5.yml | 3 + configs/load/resume_3e-4_001_7-1B_pile_PT.yml | 3 + configs/load/resume_6e-4_001_410M_143_CPT.yml | 3 + configs/load/scratch.yml | 3 + .../load/test_3e-5const_0_410M_143_CPT.yml | 3 + configs/load/wu_001_lr1-5e-4_pile.yml | 3 + configs/load/wu_001_lr3e-4_pile.yml | 3 + configs/load/wu_001_lr6e-4_pile.yml | 3 + configs/pythia_410m_llama_setup_finetune.yml | 24 +++++ configs/pythia_410m_llama_setup_resume.yml | 24 +++++ ...B_adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml | 16 +++ .../adam_constant_lr3e-4_3e-4_wu-001.yml | 13 +++ .../adam_constant_lr3e-5_3e-5_wu-0.yml | 13 +++ .../adam_cosine-inf_lr3e-4_3e-5_wu-001.yml | 14 +++ .../adam_cosine_lr1-2e-4_1-2e-5_wu-001.yml | 13 +++ .../adam_cosine_lr1-5e-4_1-5e-5_wu-0.yml | 14 +++ .../adam_cosine_lr1-5e-4_1-5e-5_wu-0005.yml | 13 +++ .../adam_cosine_lr1-5e-4_1-5e-5_wu-001.yml | 13 +++ .../adam_cosine_lr1-5e-4_1-5e-5_wu-002.yml | 13 +++ .../adam_cosine_lr3e-4_3e-5_wu-0.yml | 13 +++ .../adam_cosine_lr3e-4_3e-5_wu-0005.yml | 13 +++ .../adam_cosine_lr3e-4_3e-5_wu-001.yml | 13 +++ .../adam_cosine_lr3e-4_3e-5_wu-002.yml | 13 +++ .../adam_cosine_lr6e-4_6e-5_wu-0.yml | 14 +++ .../adam_cosine_lr6e-4_6e-5_wu-0005.yml | 13 +++ .../adam_cosine_lr6e-4_6e-5_wu-001.yml | 14 +++ .../adam_cosine_lr6e-4_6e-5_wu-002.yml | 13 +++ .../adam_infcos_lr3e-4_3e-5_wu-001.yml | 17 +++ .../adam_infinv_lr3e-4_3e-5_wu-001.yml | 17 +++ .../adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml | 16 +++ megatron/data/data_utils.py | 92 ++++++++++++++++ megatron/neox_arguments/neox_args.py | 20 ++++ megatron/training.py | 48 ++++++--- megatron_config_1.json | 1 + train.py | 40 +++++++ 102 files changed, 2244 insertions(+), 12 deletions(-) create mode 100644 configs/datasets/train/pile+slim_pajama_300B_each.yml create mode 100644 configs/datasets/train/pile_shard0.yml create mode 100644 configs/datasets/train/pile_train.yml create mode 100644 configs/datasets/train/rp.yml create mode 100644 configs/datasets/train/slim_pajama_100B_1.yml create mode 100644 configs/datasets/train/slim_pajama_100B_1_replay5.yml create mode 100644 configs/datasets/train/slim_pajama_100B_2.yml create mode 100644 configs/datasets/train/slim_pajama_100B_2_replay5.yml create mode 100644 configs/datasets/train/slim_pajama_100B_3.yml create mode 100644 configs/datasets/train/slim_pajama_100B_3_replay5.yml create mode 100644 configs/datasets/train/slim_pajama_150B.yml create mode 100644 configs/datasets/train/slim_pajama_200B_1.yml create mode 100644 
configs/datasets/train/slim_pajama_200B_1_replay5.yml create mode 100644 configs/datasets/train/slim_pajama_200B_2.yml create mode 100644 configs/datasets/train/slim_pajama_200B_2_replay5.yml create mode 100644 configs/datasets/train/slim_pajama_200B_3.yml create mode 100644 configs/datasets/train/slim_pajama_200B_3_replay5.yml create mode 100644 configs/datasets/train/slim_pajama_300B.yml create mode 100644 configs/datasets/train/slim_pajama_300B_50_replay.yml create mode 100644 configs/datasets/train/slim_pajama_300B_replay0-5.yml create mode 100644 configs/datasets/train/slim_pajama_300B_replay1.yml create mode 100644 configs/datasets/train/slim_pajama_300B_replay10.yml create mode 100644 configs/datasets/train/slim_pajama_300B_replay5.yml create mode 100644 configs/datasets/train/slim_pajama_300B_replay50.yml create mode 100644 configs/datasets/train/slim_pajama_606B.yml create mode 100644 configs/datasets/train/slim_pajama_75B.yml create mode 100644 configs/datasets/train/slim_pajama_workshop.yml create mode 100644 configs/datasets/train/taskwise/ArXiv.yml create mode 100644 configs/datasets/train/taskwise/Book.yml create mode 100644 configs/datasets/train/taskwise/C4.yml create mode 100644 configs/datasets/train/taskwise/CommonCrawl.yml create mode 100644 configs/datasets/train/taskwise/Github.yml create mode 100644 configs/datasets/train/taskwise/StackExchange.yml create mode 100644 configs/datasets/train/taskwise/Wikipedia.yml create mode 100644 configs/datasets/val/pile_german.yml create mode 100644 configs/datasets/val/pile_rp-no-se.yml create mode 100644 configs/datasets/val/pile_rp.yml create mode 100644 configs/datasets/val/pile_rp_subsets.yml create mode 100644 configs/datasets/val/pile_slimp.yml create mode 100644 configs/datasets/val/pile_slimp_domains.yml create mode 100644 configs/datasets/val/pile_slimp_workshop.yml create mode 100644 configs/datasets/val/pile_val.yml create mode 100644 configs/datasets/val/pile_val_shard0.yml create mode 100644 configs/iclr_models/3b_test.yml create mode 100644 configs/iclr_models/410M.yml create mode 100644 configs/iclr_models/410M_ckpt_100.yml create mode 100644 configs/iclr_models/49M.yml create mode 100644 configs/iclr_models/7_1B.yml create mode 100644 configs/load/3e-5const_0_410M_143_CPT.yml create mode 100644 configs/load/none.yml create mode 100644 configs/load/pythia_2-8B_143000.yml create mode 100644 configs/load/pythia_410m.yml create mode 100644 configs/load/pythia_410m_10000.yml create mode 100644 configs/load/pythia_410m_143000.yml create mode 100644 configs/load/pythia_410m_27000.yml create mode 100644 configs/load/pythia_6-9B_143000.yml create mode 100644 configs/load/pythia_deduped_410m_10000.yml create mode 100644 configs/load/pythia_deduped_410m_143000.yml create mode 100644 configs/load/pythia_deduped_410m_27000.yml create mode 100644 configs/load/resume_1-2e-4_001_7-1B_pile_PT.yml create mode 100644 configs/load/resume_1-2e-4_001_7-1B_slim_pajama_CPT.yml create mode 100644 configs/load/resume_1-2e-4_001_7-1B_slim_pajama_PT.yml create mode 100644 configs/load/resume_1-5e-4_001_410M_143_CPT.yml create mode 100644 configs/load/resume_3e-4_001_410M_143_CPT.yml create mode 100644 configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r05.yml create mode 100644 configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r1.yml create mode 100644 configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r10.yml create mode 100644 configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r5.yml create mode 100644 
configs/load/resume_3e-4_001_7-1B_pile_PT.yml create mode 100644 configs/load/resume_6e-4_001_410M_143_CPT.yml create mode 100644 configs/load/scratch.yml create mode 100644 configs/load/test_3e-5const_0_410M_143_CPT.yml create mode 100644 configs/load/wu_001_lr1-5e-4_pile.yml create mode 100644 configs/load/wu_001_lr3e-4_pile.yml create mode 100644 configs/load/wu_001_lr6e-4_pile.yml create mode 100644 configs/pythia_410m_llama_setup_finetune.yml create mode 100644 configs/pythia_410m_llama_setup_resume.yml create mode 100644 configs/schedules/7_1B_adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml create mode 100644 configs/schedules/adam_constant_lr3e-4_3e-4_wu-001.yml create mode 100644 configs/schedules/adam_constant_lr3e-5_3e-5_wu-0.yml create mode 100644 configs/schedules/adam_cosine-inf_lr3e-4_3e-5_wu-001.yml create mode 100644 configs/schedules/adam_cosine_lr1-2e-4_1-2e-5_wu-001.yml create mode 100644 configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-0.yml create mode 100644 configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-0005.yml create mode 100644 configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-001.yml create mode 100644 configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-002.yml create mode 100644 configs/schedules/adam_cosine_lr3e-4_3e-5_wu-0.yml create mode 100644 configs/schedules/adam_cosine_lr3e-4_3e-5_wu-0005.yml create mode 100644 configs/schedules/adam_cosine_lr3e-4_3e-5_wu-001.yml create mode 100644 configs/schedules/adam_cosine_lr3e-4_3e-5_wu-002.yml create mode 100644 configs/schedules/adam_cosine_lr6e-4_6e-5_wu-0.yml create mode 100644 configs/schedules/adam_cosine_lr6e-4_6e-5_wu-0005.yml create mode 100644 configs/schedules/adam_cosine_lr6e-4_6e-5_wu-001.yml create mode 100644 configs/schedules/adam_cosine_lr6e-4_6e-5_wu-002.yml create mode 100644 configs/schedules/adam_infcos_lr3e-4_3e-5_wu-001.yml create mode 100644 configs/schedules/adam_infinv_lr3e-4_3e-5_wu-001.yml create mode 100644 configs/schedules/adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml create mode 100644 megatron_config_1.json diff --git a/configs/datasets/train/pile+slim_pajama_300B_each.yml b/configs/datasets/train/pile+slim_pajama_300B_each.yml new file mode 100644 index 000000000..94f14fd58 --- /dev/null +++ b/configs/datasets/train/pile+slim_pajama_300B_each.yml @@ -0,0 +1,25 @@ +{ + # This will sample with equal likelihood Pile and SlimPajama: + "train-data-paths": [ + "data/pile/train/pile_train", + 'data/slim_pajama/train_300B/ArXiv/ArXiv', + 'data/slim_pajama/train_300B/Book/Book', + 'data/slim_pajama/train_300B/C4/C4', + 'data/slim_pajama/train_300B/Wikipedia/Wikipedia', + 'data/slim_pajama/train_300B/Github/Github', + 'data/slim_pajama/train_300B/StackExchange/StackExchange', + 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl',], + "train-data-weights": [ + 50.0, + 2.2140923205, + 2.101565663, + 13.344249736, + 1.9986465625, + 2.612070528, + 1.6855393625, + 26.0438358255 + ], + "train-dataset-name": 'pile+slim_pajama_300B_each', + "train-iters": 264732, + "lr-decay-iters": 264732, +} \ No newline at end of file diff --git a/configs/datasets/train/pile_shard0.yml b/configs/datasets/train/pile_shard0.yml new file mode 100644 index 000000000..421fd2804 --- /dev/null +++ b/configs/datasets/train/pile_shard0.yml @@ -0,0 +1,33 @@ +{ + "train-data-paths": [ + "data/pile/shard_0/shard_0_text_document", + ], + "train-data-weights": [ + 1., + ], + "train-dataset-name": 'pile_shard0', + "train-iters": 1000, + "lr-decay-iters": 1000, + "is_replay_enabled": true, + "replay_config": { + "enabled": true, + # Have to 
specify idx filenames from original pretraining on tasks, as they contain the num iterations + # and seen indices assuming we're using the same (non-replay) seed as during pretraining + "replay_idx_paths_prefixes": [ + "data/pile/shard_0/shard_0_text_document_train_0_indexmap_32160ns_2048sl_1234s", + ], + "replay_data_weights":[ + 1.00, + ], + "replay_idx_offsets": [ + 1, + ], + # Fraction of samples coming from the replay buffer, between 0 and 1. + "replay_fraction": 0.5, + # Seed and reshuffle go hand in hand. They control whether you want to see the replay data in the same order + # as you've seen it (done by setting reshuffle to false), and if you decide to reshuffle, what seed you should + # use to reshuffle the seen data. + "replay_seed": 1234, + "replay_reshuffle_idx": false, + }, +} \ No newline at end of file diff --git a/configs/datasets/train/pile_train.yml b/configs/datasets/train/pile_train.yml new file mode 100644 index 000000000..ff37e2c74 --- /dev/null +++ b/configs/datasets/train/pile_train.yml @@ -0,0 +1,11 @@ +{ + "train-data-paths": [ + "data/pile/train/pile_train", + ], + "train-data-weights": [ + 1., + ], + "train-dataset-name": 'pile_train', + "train-iters": 132366, + "lr-decay-iters": 132366, +} \ No newline at end of file diff --git a/configs/datasets/train/rp.yml b/configs/datasets/train/rp.yml new file mode 100644 index 000000000..cfea01a64 --- /dev/null +++ b/configs/datasets/train/rp.yml @@ -0,0 +1,32 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/arxiv/folder_train/tokenized_text_document", + "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/book/folder_train/tokenized_text_document", + "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/c4/folder_train/tokenized_text_document", + "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/wikipedia/folder_train/tokenized_text_document", + "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/github/folder_train/tokenized_text_document", + "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/stackexchange/folder_train/tokenized_text_document", + "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/common_crawl/2019-30/folder_train/tokenized_text_document", + "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/common_crawl/2020-05/folder_train/tokenized_text_document", + "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/common_crawl/2021-04/folder_train/tokenized_text_document", + "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/common_crawl/2022-05/folder_train/tokenized_text_document", + "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/common_crawl/2023-06/folder_train/tokenized_text_document", + ], + "train-data-weights": [ + 2.5, + 4.5, + 15.0, + 4.5, + 4.5, + 2.0, + 13.4, + 13.4, + 13.4, + 13.4, + 13.4 + ], + "train-dataset-name": 'rp', + + +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_100B_1.yml b/configs/datasets/train/slim_pajama_100B_1.yml new file mode 100644 index 000000000..9bbf75249 --- /dev/null +++ b/configs/datasets/train/slim_pajama_100B_1.yml @@ -0,0 +1,24 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 
'data/slim_pajama/tokenized_train_0-100B/ArXiv/ArXiv', + 'data/slim_pajama/tokenized_train_0-100B/Book/Book', + 'data/slim_pajama/tokenized_train_0-100B/C4/C4', + 'data/slim_pajama/tokenized_train_0-100B/Wikipedia/Wikipedia', + 'data/slim_pajama/tokenized_train_0-100B/Github/Github', + 'data/slim_pajama/tokenized_train_0-100B/StackExchange/StackExchange', + 'data/slim_pajama/tokenized_train_0-100B/CommonCrawl/CommonCrawl', + ], + "train-data-weights": [ + 3.4703977435152775, + 3.904381603212791, + 25.641950653802013, + 3.804228253591696, + 4.9994643949282045, + 3.1815838172641993, + 49.99799353368582, + ], + "train-iters": 44229, + "lr-decay-iters": 44229, + "train-dataset-name": 'slim_pajama_100B_1', +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_100B_1_replay5.yml b/configs/datasets/train/slim_pajama_100B_1_replay5.yml new file mode 100644 index 000000000..2d957c3e5 --- /dev/null +++ b/configs/datasets/train/slim_pajama_100B_1_replay5.yml @@ -0,0 +1,28 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/tokenized_train_0-100B/ArXiv/ArXiv', + 'data/slim_pajama/tokenized_train_0-100B/Book/Book', + 'data/slim_pajama/tokenized_train_0-100B/C4/C4', + 'data/slim_pajama/tokenized_train_0-100B/Wikipedia/Wikipedia', + 'data/slim_pajama/tokenized_train_0-100B/Github/Github', + 'data/slim_pajama/tokenized_train_0-100B/StackExchange/StackExchange', + 'data/slim_pajama/tokenized_train_0-100B/CommonCrawl/CommonCrawl', + + 'data/pile_replay_shards/replay_10B_1/splits', + ], + "train-data-weights": [ + 3.4703977435152775, + 3.904381603212791, + 25.641950653802013, + 3.804228253591696, + 4.9994643949282045, + 3.1815838172641993, + 49.99799353368582, + + 5.0 + ], + "train-iters": 44229, + "lr-decay-iters": 44229, + "train-dataset-name": 'slim_pajama_100B_1_replay5', +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_100B_2.yml b/configs/datasets/train/slim_pajama_100B_2.yml new file mode 100644 index 000000000..f59f56860 --- /dev/null +++ b/configs/datasets/train/slim_pajama_100B_2.yml @@ -0,0 +1,24 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/tokenized_train_100B-200B/ArXiv/ArXiv', + 'data/slim_pajama/tokenized_train_100B-200B/Book/Book', + 'data/slim_pajama/tokenized_train_100B-200B/C4/C4', + 'data/slim_pajama/tokenized_train_100B-200B/Wikipedia/Wikipedia', + 'data/slim_pajama/tokenized_train_100B-200B/Github/Github', + 'data/slim_pajama/tokenized_train_100B-200B/StackExchange/StackExchange', + 'data/slim_pajama/tokenized_train_100B-200B/CommonCrawl/CommonCrawl', + ], + "train-data-weights": [ + 4.03666599074094, + 3.927523855378127, + 25.467175464208918, + 3.7984379710376293, + 4.990226864678155, + 3.1957646326079723, + 49.58420522134826, + ], + "train-iters": 44229, + "lr-decay-iters": 44229, + "train-dataset-name": 'slim_pajama_100B_2', +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_100B_2_replay5.yml b/configs/datasets/train/slim_pajama_100B_2_replay5.yml new file mode 100644 index 000000000..ce1ba6f62 --- /dev/null +++ b/configs/datasets/train/slim_pajama_100B_2_replay5.yml @@ -0,0 +1,45 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/tokenized_train_100B-200B/ArXiv/ArXiv', + 'data/slim_pajama/tokenized_train_100B-200B/Book/Book', + 'data/slim_pajama/tokenized_train_100B-200B/C4/C4', + 'data/slim_pajama/tokenized_train_100B-200B/Wikipedia/Wikipedia', + 'data/slim_pajama/tokenized_train_100B-200B/Github/Github', + 
'data/slim_pajama/tokenized_train_100B-200B/StackExchange/StackExchange', + 'data/slim_pajama/tokenized_train_100B-200B/CommonCrawl/CommonCrawl', + + 'data/pile_replay_shards/replay_10B_2/splits', + + 'data/sp_replay_shards/100B_1_shard1/ArXiv/ArXiv', + 'data/sp_replay_shards/100B_1_shard1/Book/Book', + 'data/sp_replay_shards/100B_1_shard1/C4/C4', + 'data/sp_replay_shards/100B_1_shard1/Wikipedia/Wikipedia', + 'data/sp_replay_shards/100B_1_shard1/Github/Github', + 'data/sp_replay_shards/100B_1_shard1/StackExchange/StackExchange', + 'data/sp_replay_shards/100B_1_shard1/CommonCrawl/CommonCrawl', + ], + "train-data-weights": [ + 4.03666599074094, + 3.927523855378127, + 25.467175464208918, + 3.7984379710376293, + 4.990226864678155, + 3.1957646326079723, + 49.58420522134826, + + 3.8125, + + # total: 1.1875, + 0.04337997179394097, + 0.04880477004015989, + 0.3205243831725252, + 0.0475528531698962, + 0.06249330493660256, + 0.03976979771580249, + 0.6249749191710727, + ], + "train-iters": 44229, + "lr-decay-iters": 44229, + "train-dataset-name": 'slim_pajama_100B_2_replay5', +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_100B_3.yml b/configs/datasets/train/slim_pajama_100B_3.yml new file mode 100644 index 000000000..6cf015267 --- /dev/null +++ b/configs/datasets/train/slim_pajama_100B_3.yml @@ -0,0 +1,24 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/tokenized_train_200B-300B/ArXiv/ArXiv', + 'data/slim_pajama/tokenized_train_200B-300B/Book/Book', + 'data/slim_pajama/tokenized_train_200B-300B/C4/C4', + 'data/slim_pajama/tokenized_train_200B-300B/Wikipedia/Wikipedia', + 'data/slim_pajama/tokenized_train_200B-300B/Github/Github', + 'data/slim_pajama/tokenized_train_200B-300B/StackExchange/StackExchange', + 'data/slim_pajama/tokenized_train_200B-300B/CommonCrawl/CommonCrawl', + ], + "train-data-weights": [ + 3.491756366873565, + 4.084283062119696, + 25.524317038754475, + 3.8109321899190314, + 4.89534056131328, + 3.254459546224121, + 49.93891123479581, + ], + "train-iters": 44229, + "lr-decay-iters": 44229, + "train-dataset-name": 'slim_pajama_100B_3', +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_100B_3_replay5.yml b/configs/datasets/train/slim_pajama_100B_3_replay5.yml new file mode 100644 index 000000000..9f537c532 --- /dev/null +++ b/configs/datasets/train/slim_pajama_100B_3_replay5.yml @@ -0,0 +1,61 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/tokenized_train_200B-300B/ArXiv/ArXiv', + 'data/slim_pajama/tokenized_train_200B-300B/Book/Book', + 'data/slim_pajama/tokenized_train_200B-300B/C4/C4', + 'data/slim_pajama/tokenized_train_200B-300B/Wikipedia/Wikipedia', + 'data/slim_pajama/tokenized_train_200B-300B/Github/Github', + 'data/slim_pajama/tokenized_train_200B-300B/StackExchange/StackExchange', + 'data/slim_pajama/tokenized_train_200B-300B/CommonCrawl/CommonCrawl', + + 'data/pile_replay_shards/replay_10B_3/splits', + + 'data/sp_replay_shards/100B_1_shard2/ArXiv/ArXiv', + 'data/sp_replay_shards/100B_1_shard2/Book/Book', + 'data/sp_replay_shards/100B_1_shard2/C4/C4', + 'data/sp_replay_shards/100B_1_shard2/Wikipedia/Wikipedia', + 'data/sp_replay_shards/100B_1_shard2/Github/Github', + 'data/sp_replay_shards/100B_1_shard2/StackExchange/StackExchange', + 'data/sp_replay_shards/100B_1_shard2/CommonCrawl/CommonCrawl', + + 'data/sp_replay_shards/100B_2_shard1/ArXiv/ArXiv', + 'data/sp_replay_shards/100B_2_shard1/Book/Book', + 'data/sp_replay_shards/100B_2_shard1/C4/C4', + 
'data/sp_replay_shards/100B_2_shard1/Wikipedia/Wikipedia', + 'data/sp_replay_shards/100B_2_shard1/Github/Github', + 'data/sp_replay_shards/100B_2_shard1/StackExchange/StackExchange', + 'data/sp_replay_shards/100B_2_shard1/CommonCrawl/CommonCrawl', + ], + "train-data-weights": [3.491756366873565, + 4.084283062119696, + 25.524317038754475, + 3.8109321899190314, + 4.89534056131328, + 3.254459546224121, + 49.93891123479581, + + 3.088125, + + # total: 0.961875, + 0.03513777715309219, + 0.03953186373252951, + 0.2596247503697454, + 0.03851781106761592, + 0.05061957699864807, + 0.03221353614980002, + 0.506229684528569, + + #total: 0.95, + 0.0403666599074094, + 0.03927523855378127, + 0.25467175464208913, + 0.03798437971037629, + 0.049902268646781545, + 0.03195764632607972, + 0.4958420522134826, + ], + "train-iters": 44229, + "lr-decay-iters": 44229, + "train-dataset-name": 'slim_pajama_100B_3_replay5', +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_150B.yml b/configs/datasets/train/slim_pajama_150B.yml new file mode 100644 index 000000000..477334a4a --- /dev/null +++ b/configs/datasets/train/slim_pajama_150B.yml @@ -0,0 +1,23 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/train_150B/ArXiv/ArXiv', + 'data/slim_pajama/train_150B/Book/Book', + 'data/slim_pajama/train_150B/C4/C4', + 'data/slim_pajama/train_150B/Wikipedia/Wikipedia', + 'data/slim_pajama/train_150B/Github/Github', + 'data/slim_pajama/train_150B/StackExchange/StackExchange', + 'data/slim_pajama/train_150B/CommonCrawl/CommonCrawl',], + "train-data-weights": [ + 4.576447650075095, + 4.198505982426652, + 26.62982374026485, + 3.9945183507095225, + 5.218824282422116, + 3.372167199706489, + 52.00971279439528 + ], + "train-dataset-name": 'slim_pajama_150B', + "train-iters": 66342, + "lr-decay-iters": 66342, +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_200B_1.yml b/configs/datasets/train/slim_pajama_200B_1.yml new file mode 100644 index 000000000..b242479a4 --- /dev/null +++ b/configs/datasets/train/slim_pajama_200B_1.yml @@ -0,0 +1,24 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/tokenized_train_0-200B/ArXiv/ArXiv', + 'data/slim_pajama/tokenized_train_0-200B/Book/Book', + 'data/slim_pajama/tokenized_train_0-200B/C4/C4', + 'data/slim_pajama/tokenized_train_0-200B/Wikipedia/Wikipedia', + 'data/slim_pajama/tokenized_train_0-200B/Github/Github', + 'data/slim_pajama/tokenized_train_0-200B/StackExchange/StackExchange', + 'data/slim_pajama/tokenized_train_0-200B/CommonCrawl/CommonCrawl', + ], + "train-data-weights": [ + 3.4703977435152775, + 3.904381603212791, + 25.641950653802013, + 3.804228253591696, + 4.9994643949282045, + 3.1815838172641993, + 49.99799353368582, + ], + "train-iters": 88457, + "lr-decay-iters": 88457, + "train-dataset-name": 'slim_pajama_200B_1', +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_200B_1_replay5.yml b/configs/datasets/train/slim_pajama_200B_1_replay5.yml new file mode 100644 index 000000000..518391cd4 --- /dev/null +++ b/configs/datasets/train/slim_pajama_200B_1_replay5.yml @@ -0,0 +1,23 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/tokenized_train_0-200B/ArXiv/ArXiv', + 'data/slim_pajama/tokenized_train_0-200B/Book/Book', + 'data/slim_pajama/tokenized_train_0-200B/C4/C4', + 'data/slim_pajama/tokenized_train_0-200B/Wikipedia/Wikipedia', + 'data/slim_pajama/tokenized_train_0-200B/Github/Github', + 
'data/slim_pajama/tokenized_train_0-200B/StackExchange/StackExchange', + 'data/slim_pajama/tokenized_train_0-200B/CommonCrawl/CommonCrawl', + 'data/pile_replay_shards/replay_10B_1/splits',], + "train-data-weights": [3.4703977435152775, + 3.904381603212791, + 25.641950653802013, + 3.804228253591696, + 4.9994643949282045, + 3.1815838172641993, + 49.99799353368582, + 5.0], + "train-iters": 88457, + "lr-decay-iters": 88457, + "train-dataset-name": 'slim_pajama_200B_1_replay5', +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_200B_2.yml b/configs/datasets/train/slim_pajama_200B_2.yml new file mode 100644 index 000000000..753831a72 --- /dev/null +++ b/configs/datasets/train/slim_pajama_200B_2.yml @@ -0,0 +1,24 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/tokenized_train_200B-400B/ArXiv/ArXiv', + 'data/slim_pajama/tokenized_train_200B-400B/Book/Book', + 'data/slim_pajama/tokenized_train_200B-400B/C4/C4', + 'data/slim_pajama/tokenized_train_200B-400B/Wikipedia/Wikipedia', + 'data/slim_pajama/tokenized_train_200B-400B/Github/Github', + 'data/slim_pajama/tokenized_train_200B-400B/StackExchange/StackExchange', + 'data/slim_pajama/tokenized_train_200B-400B/CommonCrawl/CommonCrawl', + ], + "train-data-weights": [ + 4.3217120887898215, + 4.0000865058486115, + 25.313223824892418, + 3.7875979441876595, + 4.916178735276899, + 3.205115989375923, + 49.45608491162867 + ], + "train-iters": 88457, + "lr-decay-iters": 88457, + "train-dataset-name": 'slim_pajama_200B_2', +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_200B_2_replay5.yml b/configs/datasets/train/slim_pajama_200B_2_replay5.yml new file mode 100644 index 000000000..b71420699 --- /dev/null +++ b/configs/datasets/train/slim_pajama_200B_2_replay5.yml @@ -0,0 +1,45 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/tokenized_train_200B-400B/ArXiv/ArXiv', + 'data/slim_pajama/tokenized_train_200B-400B/Book/Book', + 'data/slim_pajama/tokenized_train_200B-400B/C4/C4', + 'data/slim_pajama/tokenized_train_200B-400B/Wikipedia/Wikipedia', + 'data/slim_pajama/tokenized_train_200B-400B/Github/Github', + 'data/slim_pajama/tokenized_train_200B-400B/StackExchange/StackExchange', + 'data/slim_pajama/tokenized_train_200B-400B/CommonCrawl/CommonCrawl', + + 'data/pile_replay_shards/replay_10B_1/splits', + + 'data/sp_replay_shards/200B_1_shard1/ArXiv/ArXiv', + 'data/sp_replay_shards/200B_1_shard1/Book/Book', + 'data/sp_replay_shards/200B_1_shard1/C4/C4', + 'data/sp_replay_shards/200B_1_shard1/Wikipedia/Wikipedia', + 'data/sp_replay_shards/200B_1_shard1/Github/Github', + 'data/sp_replay_shards/200B_1_shard1/StackExchange/StackExchange', + 'data/sp_replay_shards/200B_1_shard1/CommonCrawl/CommonCrawl', + ], + "train-data-weights": [ + 4.3217120887898215, + 4.0000865058486115, + 25.313223824892418, + 3.7875979441876595, + 4.916178735276899, + 3.205115989375923, + 49.45608491162867, + + 3.8125, + + #total 1.1875 + 0.04337997179394097, + 0.04880477004015989, + 0.3205243831725252, + 0.0475528531698962, + 0.06249330493660256, + 0.03976979771580249, + 0.6249749191710727, + ], + "train-iters": 88457, + "lr-decay-iters": 88457, + "train-dataset-name": 'slim_pajama_200B_2_replay5', +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_200B_3.yml b/configs/datasets/train/slim_pajama_200B_3.yml new file mode 100644 index 000000000..9d36e7c4d --- /dev/null +++ b/configs/datasets/train/slim_pajama_200B_3.yml @@ -0,0 +1,24 @@ +{ + # or for 
weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/tokenized_train_400B-600B/ArXiv/ArXiv', + 'data/slim_pajama/tokenized_train_400B-600B/Book/Book', + 'data/slim_pajama/tokenized_train_400B-600B/C4/C4', + 'data/slim_pajama/tokenized_train_400B-600B/Wikipedia/Wikipedia', + 'data/slim_pajama/tokenized_train_400B-600B/Github/Github', + 'data/slim_pajama/tokenized_train_400B-600B/StackExchange/StackExchange', + 'data/slim_pajama/tokenized_train_400B-600B/CommonCrawl/CommonCrawl', + ], + "train-data-weights": [ + 4.3217120887898215, + 4.0000865058486115, + 25.313223824892418, + 3.7875979441876595, + 4.916178735276899, + 3.205115989375923, + 49.45608491162867 + ], + "train-iters": 88457, + "lr-decay-iters": 88457, + "train-dataset-name": 'slim_pajama_200B_3', +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_200B_3_replay5.yml b/configs/datasets/train/slim_pajama_200B_3_replay5.yml new file mode 100644 index 000000000..cdcfcefad --- /dev/null +++ b/configs/datasets/train/slim_pajama_200B_3_replay5.yml @@ -0,0 +1,62 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/tokenized_train_400B-600B/ArXiv/ArXiv', + 'data/slim_pajama/tokenized_train_400B-600B/Book/Book', + 'data/slim_pajama/tokenized_train_400B-600B/C4/C4', + 'data/slim_pajama/tokenized_train_400B-600B/Wikipedia/Wikipedia', + 'data/slim_pajama/tokenized_train_400B-600B/Github/Github', + 'data/slim_pajama/tokenized_train_400B-600B/StackExchange/StackExchange', + 'data/slim_pajama/tokenized_train_400B-600B/CommonCrawl/CommonCrawl', + + 'data/pile_replay_shards/replay_10B_3/splits', + + 'data/sp_replay_shards/200B_1_shard2/ArXiv/ArXiv', + 'data/sp_replay_shards/200B_1_shard2/Book/Book', + 'data/sp_replay_shards/200B_1_shard2/C4/C4', + 'data/sp_replay_shards/200B_1_shard2/Wikipedia/Wikipedia', + 'data/sp_replay_shards/200B_1_shard2/Github/Github', + 'data/sp_replay_shards/200B_1_shard2/StackExchange/StackExchange', + 'data/sp_replay_shards/200B_1_shard2/CommonCrawl/CommonCrawl', + + 'data/sp_replay_shards/200B_2_shard1/ArXiv/ArXiv', + 'data/sp_replay_shards/200B_2_shard1/Book/Book', + 'data/sp_replay_shards/200B_2_shard1/C4/C4', + 'data/sp_replay_shards/200B_2_shard1/Wikipedia/Wikipedia', + 'data/sp_replay_shards/200B_2_shard1/Github/Github', + 'data/sp_replay_shards/200B_2_shard1/StackExchange/StackExchange', + 'data/sp_replay_shards/200B_2_shard1/CommonCrawl/CommonCrawl', + ], + "train-data-weights": [ + 4.3217120887898215, + 4.0000865058486115, + 25.313223824892418, + 3.7875979441876595, + 4.916178735276899, + 3.205115989375923, + 49.45608491162867, + + 3.088125, + + #total: 0.961875, + 0.03513777715309219, + 0.03953186373252951, + 0.2596247503697454, + 0.03851781106761592, + 0.05061957699864807, + 0.03221353614980002, + 0.506229684528569, + + #total: 0.95, + 0.043217120887898204, + 0.0400008650584861, + 0.2531322382489241, + 0.03787597944187659, + 0.04916178735276898, + 0.03205115989375923, + 0.49456084911628667, + ], + "train-iters": 88457, + "lr-decay-iters": 88457, + "train-dataset-name": 'slim_pajama_200B_3_replay5', +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_300B.yml b/configs/datasets/train/slim_pajama_300B.yml new file mode 100644 index 000000000..1bc5be68d --- /dev/null +++ b/configs/datasets/train/slim_pajama_300B.yml @@ -0,0 +1,23 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/train_300B/ArXiv/ArXiv', + 'data/slim_pajama/train_300B/Book/Book', + 'data/slim_pajama/train_300B/C4/C4', + 
'data/slim_pajama/train_300B/Wikipedia/Wikipedia', + 'data/slim_pajama/train_300B/Github/Github', + 'data/slim_pajama/train_300B/StackExchange/StackExchange', + 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl',], + "train-data-weights": [ + 4.428184641, + 4.203131326, + 26.688499472, + 3.997293125, + 5.224141056, + 3.371078725, + 52.087671651 + ], + "train-iters": 132366, + "lr-decay-iters": 132366, + "train-dataset-name": 'slim_pajama_300B', +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_300B_50_replay.yml b/configs/datasets/train/slim_pajama_300B_50_replay.yml new file mode 100644 index 000000000..841230d2a --- /dev/null +++ b/configs/datasets/train/slim_pajama_300B_50_replay.yml @@ -0,0 +1,45 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/train_300B/ArXiv/ArXiv', + 'data/slim_pajama/train_300B/Book/Book', + 'data/slim_pajama/train_300B/C4/C4', + 'data/slim_pajama/train_300B/Wikipedia/Wikipedia', + 'data/slim_pajama/train_300B/Github/Github', + 'data/slim_pajama/train_300B/StackExchange/StackExchange', + 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl',], + "train-data-weights": [ + 4.428184641, + 4.203131326, + 26.688499472, + 3.997293125, + 5.224141056, + 3.371078725, + 52.087671651 + ], + "train-iters": 132366, + "lr-decay-iters": 132366, + "train-dataset-name": 'slim_pajama_300B', + + "replay_config": { + "enabled": true, + # Have to specify idx filenames from original pretraining on tasks, as they contain the num iterations + # and seen indices assuming we're using the same (non-replay) seed as during pretraining + "replay_idx_paths_prefixes": [ + "data/pile/train/pile_train_train_0_indexmap_146862725ns_2048sl_1234s", + ], + "replay_data_weights":[ + 1.00, + ], + "replay_idx_offsets": [ + 0, + ], + # Fraction of samples coming from the replay buffer, between 0 and 1. + "replay_fraction": 0.5, + # Will need to reshuffle the shuffled indices. If you have replay multiple times on the same task, don't + # forget to change it every time if not manually managing offsets ! Otherwise you will see the same replay + # buffer in the same order. 
+ "replay_seed": 1234, + "replay_reshuffle_idx": true, + }, +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_300B_replay0-5.yml b/configs/datasets/train/slim_pajama_300B_replay0-5.yml new file mode 100644 index 000000000..6069b4cd9 --- /dev/null +++ b/configs/datasets/train/slim_pajama_300B_replay0-5.yml @@ -0,0 +1,23 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/train_300B/ArXiv/ArXiv', + 'data/slim_pajama/train_300B/Book/Book', + 'data/slim_pajama/train_300B/C4/C4', + 'data/slim_pajama/train_300B/Wikipedia/Wikipedia', + 'data/slim_pajama/train_300B/Github/Github', + 'data/slim_pajama/train_300B/StackExchange/StackExchange', + 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl', + 'data/pile_replay_shards/replay_1-5B/splits',], + "train-data-weights": [4.406043717971242, + 4.182115669537286, + 26.555056975702207, + 3.977306659534093, + 5.198020350927921, + 3.354223331509169, + 51.82723329481809, + 0.5], + "train-iters": 132366, + "lr-decay-iters": 132366, + "train-dataset-name": 'slim_pajama_300B_replay05', +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_300B_replay1.yml b/configs/datasets/train/slim_pajama_300B_replay1.yml new file mode 100644 index 000000000..6983c3a27 --- /dev/null +++ b/configs/datasets/train/slim_pajama_300B_replay1.yml @@ -0,0 +1,23 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/train_300B/ArXiv/ArXiv', + 'data/slim_pajama/train_300B/Book/Book', + 'data/slim_pajama/train_300B/C4/C4', + 'data/slim_pajama/train_300B/Wikipedia/Wikipedia', + 'data/slim_pajama/train_300B/Github/Github', + 'data/slim_pajama/train_300B/StackExchange/StackExchange', + 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl', + 'data/pile_replay_shards/replay_3B/splits',], + "train-data-weights": [4.383902794765357, + 4.161100012906445, + 26.42161447833687, + 3.9573201939082936, + 5.171899645646876, + 3.337367937883495, + 51.566794936552675, + 1.0], + "train-iters": 132366, + "lr-decay-iters": 132366, + "train-dataset-name": 'slim_pajama_300B_replay1', +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_300B_replay10.yml b/configs/datasets/train/slim_pajama_300B_replay10.yml new file mode 100644 index 000000000..a85ede531 --- /dev/null +++ b/configs/datasets/train/slim_pajama_300B_replay10.yml @@ -0,0 +1,23 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/train_300B/ArXiv/ArXiv', + 'data/slim_pajama/train_300B/Book/Book', + 'data/slim_pajama/train_300B/C4/C4', + 'data/slim_pajama/train_300B/Wikipedia/Wikipedia', + 'data/slim_pajama/train_300B/Github/Github', + 'data/slim_pajama/train_300B/StackExchange/StackExchange', + 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl', + 'data/pile_replay_shards/replay_30B/splits',], + "train-data-weights": [3.985366177059415, + 3.7828181935513134, + 24.01964952576079, + 3.5975638126439033, + 4.701726950588069, + 3.033970852621359, + 46.87890448777516, + 10.0], + "train-iters": 132366, + "lr-decay-iters": 132366, + "train-dataset-name": 'slim_pajama_300B_replay10', +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_300B_replay5.yml b/configs/datasets/train/slim_pajama_300B_replay5.yml new file mode 100644 index 000000000..c3e1b5713 --- /dev/null +++ b/configs/datasets/train/slim_pajama_300B_replay5.yml @@ -0,0 +1,23 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/train_300B/ArXiv/ArXiv', + 
'data/slim_pajama/train_300B/Book/Book', + 'data/slim_pajama/train_300B/C4/C4', + 'data/slim_pajama/train_300B/Wikipedia/Wikipedia', + 'data/slim_pajama/train_300B/Github/Github', + 'data/slim_pajama/train_300B/StackExchange/StackExchange', + 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl', + 'data/pile_replay_shards/replay_15B/splits',], + "train-data-weights": [4.206775409118271, + 3.99297475985972, + 25.354074499414168, + 3.797428468901898, + 4.962934003398518, + 3.2025247888781014, + 49.48328807042934, + 5.0], + "train-iters": 132366, + "lr-decay-iters": 132366, + "train-dataset-name": 'slim_pajama_300B_replay5', +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_300B_replay50.yml b/configs/datasets/train/slim_pajama_300B_replay50.yml new file mode 100644 index 000000000..372a4ad7b --- /dev/null +++ b/configs/datasets/train/slim_pajama_300B_replay50.yml @@ -0,0 +1,24 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/train_300B/ArXiv/ArXiv', + 'data/slim_pajama/train_300B/Book/Book', + 'data/slim_pajama/train_300B/C4/C4', + 'data/slim_pajama/train_300B/Wikipedia/Wikipedia', + 'data/slim_pajama/train_300B/Github/Github', + 'data/slim_pajama/train_300B/StackExchange/StackExchange', + 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl', + 'data/pile/train/pile_train',], + "train-data-weights": [ + 2.2140923205885636, + 2.101565663084063, + 13.344249736533772, + 1.9986465625799463, + 2.612070528104483, + 1.6855393625674218, + 26.043835826541756, + 50.0], + "train-iters": 132366, + "lr-decay-iters": 132366, + "train-dataset-name": 'slim_pajama_300B_replay50', +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_606B.yml b/configs/datasets/train/slim_pajama_606B.yml new file mode 100644 index 000000000..159b382f5 --- /dev/null +++ b/configs/datasets/train/slim_pajama_606B.yml @@ -0,0 +1,23 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/train_606B/ArXiv/ArXiv', + 'data/slim_pajama/train_606B/Book/Book', + 'data/slim_pajama/train_606B/C4/C4', + 'data/slim_pajama/train_606B/Wikipedia/Wikipedia', + 'data/slim_pajama/train_606B/Github/Github', + 'data/slim_pajama/train_606B/StackExchange/StackExchange', + 'data/slim_pajama/train_606B/CommonCrawl/CommonCrawl',], + "train-data-weights": [ + 4.576447650075095, + 4.198505982426652, + 26.62982374026485, + 3.9945183507095225, + 5.218824282422116, + 3.372167199706489, + 52.00971279439528 + ], + "train-dataset-name": 'slim_pajama_606B', + "train-iters": 268023, + "lr-decay-iters": 268023, +} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_75B.yml b/configs/datasets/train/slim_pajama_75B.yml new file mode 100644 index 000000000..c8195e881 --- /dev/null +++ b/configs/datasets/train/slim_pajama_75B.yml @@ -0,0 +1,24 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/train_75B/ArXiv/ArXiv', + 'data/slim_pajama/train_75B/Book/Book', + 'data/slim_pajama/train_75B/C4/C4', + 'data/slim_pajama/train_75B/Wikipedia/Wikipedia', + 'data/slim_pajama/train_75B/Github/Github', + 'data/slim_pajama/train_75B/StackExchange/StackExchange', + 'data/slim_pajama/train_75B/CommonCrawl/CommonCrawl',], + "train-data-weights": [ + 4.576447650075095, + 4.198505982426652, + 26.62982374026485, + 3.9945183507095225, + 5.218824282422116, + 3.372167199706489, + 52.00971279439528 + ], + "train-dataset-name": 'slim_pajama_75B', + "train-iters": 33171, + "lr-decay-iters": 33171, +} + diff --git 
a/configs/datasets/train/slim_pajama_workshop.yml b/configs/datasets/train/slim_pajama_workshop.yml new file mode 100644 index 000000000..894be1f06 --- /dev/null +++ b/configs/datasets/train/slim_pajama_workshop.yml @@ -0,0 +1,55 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/tokenized300B/train_splits/ArXiv/tokenized_text_document', + 'data/tokenized300B/train_splits/C4/chunk1/tokenized_text_document', + 'data/tokenized300B/train_splits/C4/chunk2/tokenized_text_document', + 'data/tokenized300B/train_splits/C4/chunk3/tokenized_text_document', + 'data/tokenized300B/train_splits/C4/chunk4/tokenized_text_document', + 'data/tokenized300B/train_splits/C4/chunk5/tokenized_text_document', + 'data/tokenized300B/train_splits/C4/chunk6/tokenized_text_document', + 'data/tokenized300B/train_splits/C4/chunk7/tokenized_text_document', + 'data/tokenized300B/train_splits/C4/chunk8/tokenized_text_document', + 'data/tokenized300B/train_splits/C4/chunk9/tokenized_text_document', + 'data/tokenized300B/train_splits/C4/chunk10/tokenized_text_document', + 'data/tokenized300B/train_splits/Github/tokenized_text_document', + 'data/tokenized300B/train_splits/StackExchange/tokenized_text_document', + 'data/tokenized300B/train_splits/CommonCrawl/chunk1/tokenized_text_document', + 'data/tokenized300B/train_splits/CommonCrawl/chunk2/tokenized_text_document', + 'data/tokenized300B/train_splits/CommonCrawl/chunk3/tokenized_text_document', + 'data/tokenized300B/train_splits/CommonCrawl/chunk4/tokenized_text_document', + 'data/tokenized300B/train_splits/CommonCrawl/chunk5/tokenized_text_document', + 'data/tokenized300B/train_splits/CommonCrawl/chunk6/tokenized_text_document', + 'data/tokenized300B/train_splits/CommonCrawl/chunk7/tokenized_text_document', + 'data/tokenized300B/train_splits/CommonCrawl/chunk8/tokenized_text_document', + 'data/tokenized300B/train_splits/CommonCrawl/chunk9/tokenized_text_document', + 'data/tokenized300B/train_splits/CommonCrawl/chunk10/tokenized_text_document'], + "train-data-weights": [ + 2.5, + 1.5419773919518147, + 1.1151591815617388, + 1.5434825189166101, + 1.5436987632648516, + 1.5477789754805582, + 1.5415865562140247, + 1.5418730205678666, + 1.5429308066550182, + 1.5411863039403448, + 1.5403264814471718, + 4.5, + 2.0, + 6.879159444832248, + 4.995774267161819, + 6.872989472503661, + 6.874004825491571, + 6.892691875394247, + 6.889262861528587, + 6.869394427542272, + 6.909669621263436, + 6.915576906743681, + 6.90147629753848 + ], + "train-dataset-name": 'slim_pajama_workshop', + +} + diff --git a/configs/datasets/train/taskwise/ArXiv.yml b/configs/datasets/train/taskwise/ArXiv.yml new file mode 100644 index 000000000..83f68b22d --- /dev/null +++ b/configs/datasets/train/taskwise/ArXiv.yml @@ -0,0 +1,11 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/train_300B/ArXiv/ArXiv', + ], + "train-data-weights": [ + 1.0, + ], + # 13252597099 + "train-iters": 5861, +} \ No newline at end of file diff --git a/configs/datasets/train/taskwise/Book.yml b/configs/datasets/train/taskwise/Book.yml new file mode 100644 index 000000000..80311f835 --- /dev/null +++ b/configs/datasets/train/taskwise/Book.yml @@ -0,0 +1,12 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/train_300B/Book/Book', + ], + "train-data-weights": [ + 1.0, + ], + + # 12579061292 5563 + "train-iters": 5563, +} \ No newline at end of file diff --git a/configs/datasets/train/taskwise/C4.yml b/configs/datasets/train/taskwise/C4.yml new file mode 100644 index 
000000000..d67b30b15 --- /dev/null +++ b/configs/datasets/train/taskwise/C4.yml @@ -0,0 +1,12 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/train_300B/C4/C4', + ], + "train-data-weights": [ + 1.0, + ], + + # 79872895846 + "train-iters": 35326, +} \ No newline at end of file diff --git a/configs/datasets/train/taskwise/CommonCrawl.yml b/configs/datasets/train/taskwise/CommonCrawl.yml new file mode 100644 index 000000000..a86b86ddf --- /dev/null +++ b/configs/datasets/train/taskwise/CommonCrawl.yml @@ -0,0 +1,11 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl',], + "train-data-weights": [ + 1.0, + ], + + # 155887114486 + "train-iters": 68946, +} \ No newline at end of file diff --git a/configs/datasets/train/taskwise/Github.yml b/configs/datasets/train/taskwise/Github.yml new file mode 100644 index 000000000..309e885fe --- /dev/null +++ b/configs/datasets/train/taskwise/Github.yml @@ -0,0 +1,12 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/train_300B/Github/Github', + ], + "train-data-weights": [ + 1.0, + ], + + # 15634722173 + "train-iters": 6914, +} \ No newline at end of file diff --git a/configs/datasets/train/taskwise/StackExchange.yml b/configs/datasets/train/taskwise/StackExchange.yml new file mode 100644 index 000000000..16ea275e5 --- /dev/null +++ b/configs/datasets/train/taskwise/StackExchange.yml @@ -0,0 +1,11 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/train_300B/StackExchange/StackExchange', + ], + "train-data-weights": [ + 1.0, + ], + # 10088908154 + "train-iters": 4462, +} \ No newline at end of file diff --git a/configs/datasets/train/taskwise/Wikipedia.yml b/configs/datasets/train/taskwise/Wikipedia.yml new file mode 100644 index 000000000..9c087e8dc --- /dev/null +++ b/configs/datasets/train/taskwise/Wikipedia.yml @@ -0,0 +1,11 @@ +{ + # or for weighted datasets: + "train-data-paths": [ + 'data/slim_pajama/train_300B/Wikipedia/Wikipedia', + ], + "train-data-weights": [ + 1.0, + ], + # 11963032160 + "train-iters": 5291, +} \ No newline at end of file diff --git a/configs/datasets/val/pile_german.yml b/configs/datasets/val/pile_german.yml new file mode 100644 index 000000000..1496469b9 --- /dev/null +++ b/configs/datasets/val/pile_german.yml @@ -0,0 +1,15 @@ +{ + "test-data-paths": ["data/pile/test/pile_test_text_document"], + "test-data-weights": [ + 1. + ], + "valid-data-paths": [ + ["data/pile/val/pile_val_text_document"], + ["data/german/val/val"], + ], + "valid-data-weights": [ + [1.], + [1.], + ], + "val-dataset-name": 'pile_german', +} diff --git a/configs/datasets/val/pile_rp-no-se.yml b/configs/datasets/val/pile_rp-no-se.yml new file mode 100644 index 000000000..601b742fe --- /dev/null +++ b/configs/datasets/val/pile_rp-no-se.yml @@ -0,0 +1,38 @@ + +{ + "test-data-paths": ["data/red_pajama_400B/the_pile/test_tokenized_text_document"], + "test-data-weights": [ + 1.
+ ], + "valid-data-paths": [ + ["data/red_pajama_400B/the_pile/val_tokenized_text_document"], + [ + "data/red_pajama_400B/arxiv/folder_val/tokenized_text_document", + "data/red_pajama_400B/book/folder_val/tokenized_text_document", + "data/red_pajama_400B/c4/folder_val/tokenized_text_document", + "data/red_pajama_400B/wikipedia/folder_val/tokenized_text_document", + "data/red_pajama_400B/github/folder_val/tokenized_text_document", + "data/red_pajama_400B/common_crawl/2019-30/folder_val/tokenized_text_document", + "data/red_pajama_400B/common_crawl/2020-05/folder_val/tokenized_text_document", + "data/red_pajama_400B/common_crawl/2021-04/folder_val/tokenized_text_document", + "data/red_pajama_400B/common_crawl/2022-05/folder_val/tokenized_text_document", + "data/red_pajama_400B/common_crawl/2023-06/folder_val/tokenized_text_document", + ], + ], + "valid-data-weights": [ + [1.], + [ + 2.5, + 4.5, + 15.0, + 4.5, + 4.5, + 13.4, + 13.4, + 13.4, + 13.4, + 13.4 + ], + ], + "val-dataset-name": 'pile_rp-no-se', +} \ No newline at end of file diff --git a/configs/datasets/val/pile_rp.yml b/configs/datasets/val/pile_rp.yml new file mode 100644 index 000000000..16e368b48 --- /dev/null +++ b/configs/datasets/val/pile_rp.yml @@ -0,0 +1,40 @@ +{ + "test-data-paths": ["data/red_pajama_400B/the_pile/test_tokenized_text_document"], + "test-data-weights": [ + 1. + ], + "valid-data-paths": [ + ["data/red_pajama_400B/the_pile/val_tokenized_text_document"], + [ + "data/red_pajama_400B/arxiv/folder_val/tokenized_text_document", + "data/red_pajama_400B/book/folder_val/tokenized_text_document", + "data/red_pajama_400B/c4/folder_val/tokenized_text_document", + "data/red_pajama_400B/wikipedia/folder_val/tokenized_text_document", + "data/red_pajama_400B/github/folder_val/tokenized_text_document", + "data/red_pajama_400B/stackexchange/folder_val/tokenized_text_document", + "data/red_pajama_400B/common_crawl/2019-30/folder_val/tokenized_text_document", + "data/red_pajama_400B/common_crawl/2020-05/folder_val/tokenized_text_document", + "data/red_pajama_400B/common_crawl/2021-04/folder_val/tokenized_text_document", + "data/red_pajama_400B/common_crawl/2022-05/folder_val/tokenized_text_document", + "data/red_pajama_400B/common_crawl/2023-06/folder_val/tokenized_text_document", + ], + ], + "valid-data-weights": [ + [1.], + [ + 2.5, + 4.5, + 15.0, + 4.5, + 4.5, + 2.0, + 13.4, + 13.4, + 13.4, + 13.4, + 13.4 + ], + ], + "val-dataset-name": 'pile_rp', + + } \ No newline at end of file diff --git a/configs/datasets/val/pile_rp_subsets.yml b/configs/datasets/val/pile_rp_subsets.yml new file mode 100644 index 000000000..005417c49 --- /dev/null +++ b/configs/datasets/val/pile_rp_subsets.yml @@ -0,0 +1,62 @@ + +{ + "test-data-paths": ["data/red_pajama_400B/the_pile/test_tokenized_text_document"], + "test-data-weights": [ + 1.
+ ], + "valid-data-paths": [ + ["data/red_pajama_400B/the_pile/val_tokenized_text_document"], + [ + "data/red_pajama_400B/arxiv/folder_val/tokenized_text_document", + "data/red_pajama_400B/book/folder_val/tokenized_text_document", + "data/red_pajama_400B/c4/folder_val/tokenized_text_document", + "data/red_pajama_400B/wikipedia/folder_val/tokenized_text_document", + "data/red_pajama_400B/github/folder_val/tokenized_text_document", + "data/red_pajama_400B/stackexchange/folder_val/tokenized_text_document", + "data/red_pajama_400B/common_crawl/2019-30/folder_val/tokenized_text_document", + "data/red_pajama_400B/common_crawl/2020-05/folder_val/tokenized_text_document", + "data/red_pajama_400B/common_crawl/2021-04/folder_val/tokenized_text_document", + "data/red_pajama_400B/common_crawl/2022-05/folder_val/tokenized_text_document", + "data/red_pajama_400B/common_crawl/2023-06/folder_val/tokenized_text_document", + ], + ["data/red_pajama_400B/arxiv/folder_val/tokenized_text_document"], + ["data/red_pajama_400B/book/folder_val/tokenized_text_document"], + ["data/red_pajama_400B/c4/folder_val/tokenized_text_document"], + ["data/red_pajama_400B/wikipedia/folder_val/tokenized_text_document"], + ["data/red_pajama_400B/github/folder_val/tokenized_text_document"], + ["data/red_pajama_400B/stackexchange/folder_val/tokenized_text_document"], + ["data/red_pajama_400B/common_crawl/2019-30/folder_val/tokenized_text_document"], + ["data/red_pajama_400B/common_crawl/2020-05/folder_val/tokenized_text_document"], + ["data/red_pajama_400B/common_crawl/2021-04/folder_val/tokenized_text_document"], + ["data/red_pajama_400B/common_crawl/2022-05/folder_val/tokenized_text_document"], + ["data/red_pajama_400B/common_crawl/2023-06/folder_val/tokenized_text_document"], + ], + "valid-data-weights": [ + [1.], + [ + 2.5, + 4.5, + 15.0, + 4.5, + 4.5, + 2.0, + 13.4, + 13.4, + 13.4, + 13.4, + 13.4 + ], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + ], + "val-dataset-name": 'pile_rp_subsets', +} \ No newline at end of file diff --git a/configs/datasets/val/pile_slimp.yml b/configs/datasets/val/pile_slimp.yml new file mode 100644 index 000000000..6956b4a9d --- /dev/null +++ b/configs/datasets/val/pile_slimp.yml @@ -0,0 +1,15 @@ +{ + "test-data-paths": ["data/pile/test/pile_test_text_document"], + "test-data-weights": [ + 1. + ], + "valid-data-paths": [ + ["data/pile/val/pile_val_text_document"], + ["data/slim_pajama/val/all/sp_val"], + ], + "valid-data-weights": [ + [1.], + [1.], + ], + "val-dataset-name": 'pile_slimp', +} diff --git a/configs/datasets/val/pile_slimp_domains.yml b/configs/datasets/val/pile_slimp_domains.yml new file mode 100644 index 000000000..344abe95b --- /dev/null +++ b/configs/datasets/val/pile_slimp_domains.yml @@ -0,0 +1,30 @@ +{ + "test-data-paths": ["data/pile/test/pile_test_text_document"], + "test-data-weights": [ + 1. 
+ ], + "valid-data-paths": [ + ["data/pile/val/pile_val_text_document"], + ["data/slim_pajama/val/all/sp_val"], + ["data/slim_pajama/val/ArXiv/tokenized_arxiv_text_document"], + ["data/slim_pajama/val/Book/tokenized_book_text_document"], + ["data/slim_pajama/val/C4/tokenized_c4_text_document"], + ["data/slim_pajama/val/Wikipedia/tokenized_wikipedia_text_document"], + ["data/slim_pajama/val/Github/tokenized_github_text_document"], + ["data/slim_pajama/val/StackExchange/tokenized_stackexchange_text_document"], + ["data/slim_pajama/val/CommonCrawl/tokenized_commoncrawl_text_document"], + ], + "valid-data-weights": [ + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + ], + "val-dataset-name": 'pile_slimp_domains', + +} \ No newline at end of file diff --git a/configs/datasets/val/pile_slimp_workshop.yml b/configs/datasets/val/pile_slimp_workshop.yml new file mode 100644 index 000000000..93530eccd --- /dev/null +++ b/configs/datasets/val/pile_slimp_workshop.yml @@ -0,0 +1,32 @@ +{ + "test-data-paths": ["data/red_pajama_400B/the_pile/test_tokenized_text_document"], + "test-data-weights": [ + 1. + ], + "valid-data-paths": [ + ["data/red_pajama_400B/the_pile/val_tokenized_text_document"], + [ + "data/tokenized300B/val_splits/ArXiv/tokenized_text_document", + "data/tokenized300B/val_splits/Book/tokenized_text_document", + "data/tokenized300B/val_splits/C4/tokenized_text_document", + "data/tokenized300B/val_splits/Wikipedia/tokenized_text_document", + "data/tokenized300B/val_splits/Github/tokenized_text_document", + "data/tokenized300B/val_splits/StackExchange/tokenized_text_document", + "data/tokenized300B/val_splits/CommonCrawl/tokenized_text_document", + ], + ], + "valid-data-weights": [ + [1.], + [ + 2.5, + 4.5, + 15.0, + 4.5, + 4.5, + 2.0, + 67.0 + ], + ], + "val-dataset-name": 'pile_slimp_workshop', + +} diff --git a/configs/datasets/val/pile_val.yml b/configs/datasets/val/pile_val.yml new file mode 100644 index 000000000..bea136935 --- /dev/null +++ b/configs/datasets/val/pile_val.yml @@ -0,0 +1,13 @@ +{ + "test-data-paths": ["data/pile/test/pile_test_text_document"], + "test-data-weights": [ + 1. + ], + "valid-data-paths": [ + ["data/pile/val/pile_val_text_document"], + ], + "valid-data-weights": [ + [1.], + ], + "val-dataset-name": 'pile_val', +} diff --git a/configs/datasets/val/pile_val_shard0.yml b/configs/datasets/val/pile_val_shard0.yml new file mode 100644 index 000000000..704e1d2e9 --- /dev/null +++ b/configs/datasets/val/pile_val_shard0.yml @@ -0,0 +1,13 @@ +{ + "test-data-paths": ["data/pile/shard_0/shard_0_text_document"], + "test-data-weights": [ + 1. 
+ ], + "valid-data-paths": [ + ["data/pile/shard_0/shard_0_text_document"], + ], + "valid-data-weights": [ + [1.], + ], + "val-dataset-name": 'pile_val_shard0', +} diff --git a/configs/iclr_models/3b_test.yml b/configs/iclr_models/3b_test.yml new file mode 100644 index 000000000..3bec093d9 --- /dev/null +++ b/configs/iclr_models/3b_test.yml @@ -0,0 +1,101 @@ +# GPT-2 pretraining setup +{ + # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages + # across the node boundaries ) + "pipe-parallel-size": 1, + "model-parallel-size": 6, # one copy of the model per node + + # model settings + + + "num_layers": 32, + "hidden_size": 3072, + "num_attention_heads": 24, + "seq_length": 2048, + "max_position_embeddings": 2048, + "pos_emb": "rotary", + "rotary_pct": 0.25, + "no_weight_tying": true, + "gpt_j_residual": true, + "output_layer_parallelism": "column", + + "attention_config": [[["global"], 32]], + + "scaled_upper_triang_masked_softmax_fusion": true, + "bias_gelu_fusion": true, + + # init methods + "init_method": "small_init", + "output_layer_init_method": "wang_init", + + #optimizer settings + # "optimizer": { + # "type": "Adam", + # "params": { + # "lr": 0.00012, + # "betas": [0.9, 0.95], + # "eps": 1.0e-8, + # } + + # }, + # "min_lr": 0.000012, + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + + # batch / data settings + #"train_batch_size": 1, # across 1024 nodes... fingers crossed + "train_micro_batch_size_per_gpu": 8, + #"gradient_accumulation_steps": 2, + "gradient_accumulation_steps": 2, + # "gradient_accumulation_steps": 8, + "data-impl": "mmap", + "split": "949,50,1", + + # activation checkpointing + "checkpoint-activations": true, + "checkpoint-num-layers": 1, + "partition-activations": true, + "synchronize-each-layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight-decay": 0.1, + "hidden-dropout": 0.0, + "attention-dropout": 0.0, + + # precision settings + "fp16": { + "enabled": true, + # "type": "bfloat16", # set bf16 as precision + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + # "fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32 + # misc. 
training settings + # "train-iters": 250000, + # "lr-decay-iters": 250000, + "distributed-backend": "nccl", + # "lr-decay-style": "cosine", + # "warmup": 0.01, + "checkpoint-factor": 1000, + "eval-interval": 1000, + "eval-iters": 10, + + # logging + "log-interval": 1, + "steps_per_print": 1, + "keep-last-n-checkpoints": 1000, + "wall_clock_breakdown": true, + +} diff --git a/configs/iclr_models/410M.yml b/configs/iclr_models/410M.yml new file mode 100644 index 000000000..eb4b51994 --- /dev/null +++ b/configs/iclr_models/410M.yml @@ -0,0 +1,99 @@ +# GPT-2 pretraining setup +{ + #identifier string for this config used while logging + "identifier_string": "410M", + + # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages + # across the node boundaries ) + "pipe-parallel-size": 1, + "model-parallel-size": 1, # one copy of the model per node + + # model settings + "num-layers": 24, + "hidden-size": 1024, + "seq-length": 2048, + "num-attention-heads": 16, + "max-position-embeddings": 2048, + "pos-emb": "rotary", + "rotary-pct": 0.25, + "no-weight-tying": true, + "gpt-j-residual": true, + "output-layer-parallelism": "column", + + # these should provide some speedup but takes a while to build, set to true if desired + "scaled-upper-triang-masked-softmax-fusion": true, + "bias-gelu-fusion": true, + + # init methods + "init_method": "small_init", + "output_layer_init_method": "wang_init", + + # "optimizer": { + # "type": "Adam", + # "params": { + # "lr": 3.0e-4, + # "betas": [0.9, 0.95], + # "eps": 1.0e-8, + # } + # }, + # "min_lr": 3.0e-5, + + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + "cpu_offload": False + }, + + # LLAMA Config + # batch / data settings + "train_batch_size": 1104, #1104, #1104, #1104, #1104, #1104 # approximately 2.2M batch size across 46 nodes + "train_micro_batch_size_per_gpu": 4, + "data-impl": "mmap", + "split": "949,50,1", + + # activation checkpointing + "checkpoint-activations": true, + "checkpoint-num-layers": 1, + "partition-activations": true, + "synchronize-each-layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight-decay": 0.1, + "hidden-dropout": 0.0, + "attention-dropout": 0.0, + + # precision settings of LLaMa + "fp16": { + "enabled": true, + # "type": "bfloat16", # set bf16 as precision + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + # "fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32 + + + + "distributed-backend": "nccl", + # "lr-decay-style": "cosine", + # "warmup": 0.01, + "checkpoint-factor": 400, + "eval-interval": 100, + "warup-eval-interval": 50, + "eval-iters": 10, + + # logging + "log-interval": 1, + "steps_per_print": 1, + "keep-last-n-checkpoints": 1000, + "wall_clock_breakdown": true, + +} diff --git a/configs/iclr_models/410M_ckpt_100.yml b/configs/iclr_models/410M_ckpt_100.yml new file mode 100644 index 000000000..c2934729a --- /dev/null +++ b/configs/iclr_models/410M_ckpt_100.yml @@ -0,0 +1,100 @@ +# GPT-2 pretraining setup +{ + #identifier string for this config used while logging + "identifier_string": "410M", + + # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages + # across the node boundaries ) + "pipe-parallel-size": 1, + 
"model-parallel-size": 1, # one copy of the model per node + + # model settings + "num-layers": 24, + "hidden-size": 1024, + "seq-length": 2048, + "num-attention-heads": 16, + "max-position-embeddings": 2048, + "pos-emb": "rotary", + "rotary-pct": 0.25, + "no-weight-tying": true, + "gpt-j-residual": true, + "output-layer-parallelism": "column", + + # these should provide some speedup but takes a while to build, set to true if desired + "scaled-upper-triang-masked-softmax-fusion": true, + "bias-gelu-fusion": true, + + # init methods + "init_method": "small_init", + "output_layer_init_method": "wang_init", + + # "optimizer": { + # "type": "Adam", + # "params": { + # "lr": 3.0e-4, + # "betas": [0.9, 0.95], + # "eps": 1.0e-8, + # } + # }, + # "min_lr": 3.0e-5, + + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + "cpu_offload": False + }, + + # LLAMA Config + # batch / data settings + "train_batch_size": 1104, #1104, #1104, #1104, #1104, #1104 # approximately 2.2M batch size across 46 nodes + "train_micro_batch_size_per_gpu": 4, + "data-impl": "mmap", + "split": "949,50,1", + + # activation checkpointing + "checkpoint-activations": true, + "checkpoint-num-layers": 1, + "partition-activations": true, + "synchronize-each-layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight-decay": 0.1, + "hidden-dropout": 0.0, + "attention-dropout": 0.0, + + # precision settings of LLaMa + "fp16": { + "enabled": true, + # "type": "bfloat16", # set bf16 as precision + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + # "fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32 + + + + "distributed-backend": "nccl", + # "lr-decay-style": "cosine", + # "warmup": 0.01, + "checkpoint-factor": 100, + "extra_save_iters": [4462, 5291, 5563, 5861, 6914], + "eval-interval": 100, + "warup-eval-interval": 50, + "eval-iters": 10, + + # logging + "log-interval": 1, + "steps_per_print": 1, + "keep-last-n-checkpoints": 1000, + "wall_clock_breakdown": true, + +} \ No newline at end of file diff --git a/configs/iclr_models/49M.yml b/configs/iclr_models/49M.yml new file mode 100644 index 000000000..c07c1eca9 --- /dev/null +++ b/configs/iclr_models/49M.yml @@ -0,0 +1,100 @@ +# GPT-2 pretraining setup +{ + #identifier string for this config used while logging + "identifier_string": "410M", + + # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages + # across the node boundaries ) + "pipe-parallel-size": 1, + "model-parallel-size": 1, # one copy of the model per node + + # model settings + "num-layers": 10, + "hidden-size": 640, + "num-attention-heads": 10, + "seq-length": 2048, + "max-position-embeddings": 2048, + "pos-emb": "rotary", + "rotary-pct": 0.25, + "no-weight-tying": true, + "gpt-j-residual": true, + "output-layer-parallelism": "column", + + # these should provide some speedup but takes a while to build, set to true if desired + "scaled-upper-triang-masked-softmax-fusion": true, + "bias-gelu-fusion": true, + + # init methods + "init_method": "small_init", + "output_layer_init_method": "wang_init", + + # "optimizer": { + # "type": "Adam", + # "params": { + # "lr": 3.0e-4, + # "betas": [0.9, 0.95], + # "eps": 1.0e-8, + # } + # }, + # "min_lr": 3.0e-5, + + "zero_optimization": { + 
"stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + "cpu_offload": False + }, + + # LLAMA Config + # batch / data settings + # "train_batch_size": 1104, #1104, #1104, #1104, #1104, #1104 # approximately 2.2M batch size across 46 nodes + "train_micro_batch_size_per_gpu": 16, + 'gradient_accumulation_steps': 1, + "data-impl": "mmap", + "split": "949,50,1", + + # activation checkpointing + "checkpoint-activations": true, + "checkpoint-num-layers": 1, + "partition-activations": true, + "synchronize-each-layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight-decay": 0.1, + "hidden-dropout": 0.0, + "attention-dropout": 0.0, + + # precision settings of LLaMa + "fp16": { + "enabled": true, + "type": "float16", # set bf16 as precision + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + # "fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32 + + + + "distributed-backend": "nccl", + # "lr-decay-style": "cosine", + # "warmup": 0.01, + "checkpoint-factor": 20, + "eval-interval": 10, + "warup-eval-interval": 50, + "eval-iters": 10, + + # logging + "log-interval": 1, + "steps_per_print": 1, + "keep-last-n-checkpoints": 1000, + "wall_clock_breakdown": true, + +} diff --git a/configs/iclr_models/7_1B.yml b/configs/iclr_models/7_1B.yml new file mode 100644 index 000000000..2d5ebf615 --- /dev/null +++ b/configs/iclr_models/7_1B.yml @@ -0,0 +1,94 @@ +{ + #identifier string for this config used while logging + "identifier_string": "7-1B", + + "pipe-parallel-size": 4, + "model-parallel-size": 6, + + "num-layers": 36, + "hidden-size": 4608, + "seq-length": 2048, + "num-attention-heads": 36, + "max_position_embeddings": 2048, + "pos_emb": "rotary", + "rotary_pct": 0.25, + "no_weight_tying": true, + "gpt_j_residual": true, + "output_layer_parallelism": "column", + "attention_config": [[["global"], 36]], + + "scaled_upper_triang_masked_softmax_fusion": true, + "bias_gelu_fusion": true, + + + "init_method": "small_init", + "output_layer_init_method": "wang_init", + + #optimizer settings + # "optimizer": { + # "type": "Adam", + # "params": { + # "lr": 0.00012, + # "betas": [0.9, 0.95], + # "eps": 1.0e-8, + # } + + # "min_lr": 0.000012, + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + + # batch / data settings + #"train_batch_size": 1, # across 1024 nodes... 
fingers crossed + "train_micro_batch_size_per_gpu": 4, + #"gradient_accumulation_steps": 2, + # "gradient_accumulation_steps": 8, + "gradient_accumulation_steps": 4, + "data-impl": "mmap", + + + # activation checkpointing + "checkpoint-activations": true, + "checkpoint-num-layers": 1, + "partition-activations": true, + "synchronize-each-layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight-decay": 0.1, + "hidden-dropout": 0.0, + "attention-dropout": 0.0, + + # precision settings + "fp16": { + "enabled": true, + # "type": "bfloat16", # set bf16 as precision + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + # "fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32 + # misc. training settings + "distributed-backend": "nccl", + # "warmup": 0.01, + "checkpoint-factor": 100, + "extra-save-iters": [143051], # [0,1,2,4,8,16,32,64,128,256,512], + + "eval-interval": 100, + "eval-iters": 10, + + # logging + "log-interval": 10, + "steps_per_print": 10, + "wall_clock_breakdown": true, + +} \ No newline at end of file diff --git a/configs/load/3e-5const_0_410M_143_CPT.yml b/configs/load/3e-5const_0_410M_143_CPT.yml new file mode 100644 index 000000000..61d048e86 --- /dev/null +++ b/configs/load/3e-5const_0_410M_143_CPT.yml @@ -0,0 +1,3 @@ +{ + "load":"checkpoints/continued_slim_pajama/JOB-3061457_pythia-deduped-410M-iters-131296_warmup-0.0_max-lr-3e-05_min-lr-3e-05_pretrain_slim_pajama_resume" +} \ No newline at end of file diff --git a/configs/load/none.yml b/configs/load/none.yml new file mode 100644 index 000000000..f975cb0cb --- /dev/null +++ b/configs/load/none.yml @@ -0,0 +1,3 @@ +{ + "load": "none", +} \ No newline at end of file diff --git a/configs/load/pythia_2-8B_143000.yml b/configs/load/pythia_2-8B_143000.yml new file mode 100644 index 000000000..f975cb0cb --- /dev/null +++ b/configs/load/pythia_2-8B_143000.yml @@ -0,0 +1,3 @@ +{ + "load": "none", +} \ No newline at end of file diff --git a/configs/load/pythia_410m.yml b/configs/load/pythia_410m.yml new file mode 100644 index 000000000..df769c23f --- /dev/null +++ b/configs/load/pythia_410m.yml @@ -0,0 +1,3 @@ +{ + "load": "checkpoints/neox_converted/mp1_pp1/pythia", +} \ No newline at end of file diff --git a/configs/load/pythia_410m_10000.yml b/configs/load/pythia_410m_10000.yml new file mode 100644 index 000000000..76aba61b4 --- /dev/null +++ b/configs/load/pythia_410m_10000.yml @@ -0,0 +1,3 @@ +{ + "load": "checkpoints/neox_converted/mp1_pp1/pythia/410_10000", +} \ No newline at end of file diff --git a/configs/load/pythia_410m_143000.yml b/configs/load/pythia_410m_143000.yml new file mode 100644 index 000000000..2d703aa6c --- /dev/null +++ b/configs/load/pythia_410m_143000.yml @@ -0,0 +1,3 @@ +{ + "load": "checkpoints/neox_converted/mp1_pp1/pythia/410_143000", +} \ No newline at end of file diff --git a/configs/load/pythia_410m_27000.yml b/configs/load/pythia_410m_27000.yml new file mode 100644 index 000000000..266ccd51f --- /dev/null +++ b/configs/load/pythia_410m_27000.yml @@ -0,0 +1,3 @@ +{ + "load": "checkpoints/neox_converted/mp1_pp1/pythia/410_27000", +} \ No newline at end of file diff --git a/configs/load/pythia_6-9B_143000.yml b/configs/load/pythia_6-9B_143000.yml new file mode 100644 index 000000000..f975cb0cb --- /dev/null +++ b/configs/load/pythia_6-9B_143000.yml @@ -0,0 +1,3 @@ +{ + "load": "none", +} \ No newline at end of file diff --git a/configs/load/pythia_deduped_410m_10000.yml 
b/configs/load/pythia_deduped_410m_10000.yml new file mode 100644 index 000000000..4492ff72a --- /dev/null +++ b/configs/load/pythia_deduped_410m_10000.yml @@ -0,0 +1,3 @@ +{ + "load": "checkpoints/neox_converted/mp1_pp1/pythia/410m_deduped_step10000", +} \ No newline at end of file diff --git a/configs/load/pythia_deduped_410m_143000.yml b/configs/load/pythia_deduped_410m_143000.yml new file mode 100644 index 000000000..5b46b4eb5 --- /dev/null +++ b/configs/load/pythia_deduped_410m_143000.yml @@ -0,0 +1,3 @@ +{ + "load": "checkpoints/neox_converted/mp1_pp1/pythia/410m_deduped_step143000", +} \ No newline at end of file diff --git a/configs/load/pythia_deduped_410m_27000.yml b/configs/load/pythia_deduped_410m_27000.yml new file mode 100644 index 000000000..2957457a7 --- /dev/null +++ b/configs/load/pythia_deduped_410m_27000.yml @@ -0,0 +1,3 @@ +{ + "load": "checkpoints/neox_converted/mp1_pp1/pythia/410m_deduped_step27000", +} \ No newline at end of file diff --git a/configs/load/resume_1-2e-4_001_7-1B_pile_PT.yml b/configs/load/resume_1-2e-4_001_7-1B_pile_PT.yml new file mode 100644 index 000000000..bff295f9e --- /dev/null +++ b/configs/load/resume_1-2e-4_001_7-1B_pile_PT.yml @@ -0,0 +1,3 @@ +{ +"load":"checkpoints/cpt_iclr_2/JOB-3178176_7-1B_it-132366_wu-0.01_mxlr-0.00012_mnlr-1.2e-05_sch-cosine_tr-pile-train_scratch" +} \ No newline at end of file diff --git a/configs/load/resume_1-2e-4_001_7-1B_slim_pajama_CPT.yml b/configs/load/resume_1-2e-4_001_7-1B_slim_pajama_CPT.yml new file mode 100644 index 000000000..2cb78d8dc --- /dev/null +++ b/configs/load/resume_1-2e-4_001_7-1B_slim_pajama_CPT.yml @@ -0,0 +1,3 @@ +{ +"load":"checkpoints/cpt_iclr_2/JOB-3199708_7-1B_it-132366_wu-0.01_mxlr-0.00012_mnlr-1.2e-05_sch-cosine_tr-slim-pajama-300B_finetune" +} \ No newline at end of file diff --git a/configs/load/resume_1-2e-4_001_7-1B_slim_pajama_PT.yml b/configs/load/resume_1-2e-4_001_7-1B_slim_pajama_PT.yml new file mode 100644 index 000000000..ada5b3ead --- /dev/null +++ b/configs/load/resume_1-2e-4_001_7-1B_slim_pajama_PT.yml @@ -0,0 +1,3 @@ +{ +"load":"checkpoints/cpt_iclr_2/JOB-3199748_7-1B_it-132366_wu-0.01_mxlr-0.00012_mnlr-1.2e-05_sch-cosine_tr-slim-pajama-300B_scratch" +} \ No newline at end of file diff --git a/configs/load/resume_1-5e-4_001_410M_143_CPT.yml b/configs/load/resume_1-5e-4_001_410M_143_CPT.yml new file mode 100644 index 000000000..8ad24fba8 --- /dev/null +++ b/configs/load/resume_1-5e-4_001_410M_143_CPT.yml @@ -0,0 +1,3 @@ +{ + "load": "checkpoints/continued_slim_pajama/JOB-3046061_pythia-deduped-410M-iters-131296_warmup-0.01_max-lr-0.00015_min-lr-1.5e-05_pretrain_slim_pajama_resume", +} \ No newline at end of file diff --git a/configs/load/resume_3e-4_001_410M_143_CPT.yml b/configs/load/resume_3e-4_001_410M_143_CPT.yml new file mode 100644 index 000000000..87239b92f --- /dev/null +++ b/configs/load/resume_3e-4_001_410M_143_CPT.yml @@ -0,0 +1,3 @@ +{ + "load": "checkpoints/continued_slim_pajama/JOB-3051279_pythia-deduped-410M-iters-131296_warmup-0.01_max-lr-0.0003_min-lr-3e-05_pretrain_slim_pajama_resume", +} diff --git a/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r05.yml b/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r05.yml new file mode 100644 index 000000000..22b44ad8b --- /dev/null +++ b/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r05.yml @@ -0,0 +1,3 @@ +{ +"load":"checkpoints/cpt_iclr_2/JOB-3199939_410M_it-132366_wu-0.01_mxlr-0.0003_mnlr-3e-05_sch-cosine_tr-slim-pajama-300B-replay05_finetune" +} \ No newline at end of file diff --git 
a/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r1.yml b/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r1.yml new file mode 100644 index 000000000..9a0c93bc5 --- /dev/null +++ b/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r1.yml @@ -0,0 +1,3 @@ +{ +"load":"checkpoints/cpt_iclr_2/JOB-3200003_410M_it-132366_wu-0.01_mxlr-0.0003_mnlr-3e-05_sch-cosine_tr-slim-pajama-300B-replay1_finetune" +} \ No newline at end of file diff --git a/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r10.yml b/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r10.yml new file mode 100644 index 000000000..b1527a45d --- /dev/null +++ b/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r10.yml @@ -0,0 +1,3 @@ +{ +"load":"checkpoints/cpt_iclr_2/JOB-3199999_410M_it-132366_wu-0.01_mxlr-0.0003_mnlr-3e-05_sch-cosine_tr-slim-pajama-300B-replay10_finetune" +} \ No newline at end of file diff --git a/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r5.yml b/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r5.yml new file mode 100644 index 000000000..4788d814e --- /dev/null +++ b/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r5.yml @@ -0,0 +1,3 @@ +{ +"load":"checkpoints/cpt_iclr_2/JOB-3200000_410M_it-132366_wu-0.01_mxlr-0.0003_mnlr-3e-05_sch-cosine_tr-slim-pajama-300B-replay5_finetune" +} \ No newline at end of file diff --git a/configs/load/resume_3e-4_001_7-1B_pile_PT.yml b/configs/load/resume_3e-4_001_7-1B_pile_PT.yml new file mode 100644 index 000000000..a6ad9144d --- /dev/null +++ b/configs/load/resume_3e-4_001_7-1B_pile_PT.yml @@ -0,0 +1,3 @@ +{ +"load":"checkpoints/cpt_iclr/JOB-3156148_7-1B_it-132366_wu-0.01_mxlr-0.0003_mnlr-3e-05_sch-cosine_tr-pile-train_scratch" +} \ No newline at end of file diff --git a/configs/load/resume_6e-4_001_410M_143_CPT.yml b/configs/load/resume_6e-4_001_410M_143_CPT.yml new file mode 100644 index 000000000..b4fb5571e --- /dev/null +++ b/configs/load/resume_6e-4_001_410M_143_CPT.yml @@ -0,0 +1,3 @@ +{ + "load": "checkpoints/continued_slim_pajama/JOB-3047316_pythia-deduped-410M-iters-131296_warmup-0.01_max-lr-0.0006_min-lr-6e-05_pretrain_slim_pajama_resume", +} \ No newline at end of file diff --git a/configs/load/scratch.yml b/configs/load/scratch.yml new file mode 100644 index 000000000..dd4c8ec8e --- /dev/null +++ b/configs/load/scratch.yml @@ -0,0 +1,3 @@ +{ + "load": "checkpoints/continued_slim_pajama/JOB-3047440_pythia-deduped-410M-iters-131296_warmup-0.01_max-lr-0.0003_min-lr-3e-05_pretrain_slim_pajama_none", +} \ No newline at end of file diff --git a/configs/load/test_3e-5const_0_410M_143_CPT.yml b/configs/load/test_3e-5const_0_410M_143_CPT.yml new file mode 100644 index 000000000..fd41eeee3 --- /dev/null +++ b/configs/load/test_3e-5const_0_410M_143_CPT.yml @@ -0,0 +1,3 @@ +{ +"load":"checkpoints/continued_slim_pajama/JOB-3057769_pythia-deduped-410M-iters-131296_warmup-0.0_max-lr-3e-05_min-lr-3e-05_pretrain_slim_pajama_resume" +} \ No newline at end of file diff --git a/configs/load/wu_001_lr1-5e-4_pile.yml b/configs/load/wu_001_lr1-5e-4_pile.yml new file mode 100644 index 000000000..adfab6c49 --- /dev/null +++ b/configs/load/wu_001_lr1-5e-4_pile.yml @@ -0,0 +1,3 @@ +{ + "load": "checkpoints/continued_test8/JOB-2966423_pythia-c-410M-iters-181793_warmup-0.01_max-lr-0.00015_min-lr-1.5e-05_finetune_pile", +} \ No newline at end of file diff --git a/configs/load/wu_001_lr3e-4_pile.yml b/configs/load/wu_001_lr3e-4_pile.yml new file mode 100644 index 000000000..edc1bc825 --- /dev/null +++ b/configs/load/wu_001_lr3e-4_pile.yml @@ -0,0 +1,3 @@ +{ + "load": 
"checkpoints/continued_test8/JOB-2966421_pythia-c-410M-iters-181793_warmup-0.01_max-lr-0.0003_min-lr-3e-05_finetune_pile", +} \ No newline at end of file diff --git a/configs/load/wu_001_lr6e-4_pile.yml b/configs/load/wu_001_lr6e-4_pile.yml new file mode 100644 index 000000000..6d7357118 --- /dev/null +++ b/configs/load/wu_001_lr6e-4_pile.yml @@ -0,0 +1,3 @@ +{ + "load": "checkpoints/continued_test8/JOB-2966422_pythia-c-410M-iters-181793_warmup-0.01_max-lr-0.0006_min-lr-6e-05_finetune_pile", +} \ No newline at end of file diff --git a/configs/pythia_410m_llama_setup_finetune.yml b/configs/pythia_410m_llama_setup_finetune.yml new file mode 100644 index 000000000..c01771008 --- /dev/null +++ b/configs/pythia_410m_llama_setup_finetune.yml @@ -0,0 +1,24 @@ +# Suggested data paths when using GPT-NeoX locally +{ + # If weight_by_num_documents is True, Builds dataset weights from a multinomial distribution over groups of data according to the number of documents in each group. + # WARNING: setting this to True will override any user provided weights + # "weight_by_num_documents": false, + # "weighted_sampler_alpha": 0.3, + + "tokenizer-type": "HFTokenizer", + "vocab-file": "data/20B_tokenizer.json", + + "checkpoint_validation_with_forward_pass": False, + "use_wandb": False, + # "wandb_host": "https://api.wandb.ai", + + "launcher": "jsrun", + "deepspeed_jsrun": true, + "num_workers": 1, + "finetune": true, + + "save": "checkpoints/cpt_iclr_2", + "tensorboard-dir": "tensorboard/cpt_iclr_2", + "log-dir": "logs", + "wandb_project": "cpt_iclr_2", +} \ No newline at end of file diff --git a/configs/pythia_410m_llama_setup_resume.yml b/configs/pythia_410m_llama_setup_resume.yml new file mode 100644 index 000000000..e49d1db20 --- /dev/null +++ b/configs/pythia_410m_llama_setup_resume.yml @@ -0,0 +1,24 @@ +# Suggested data paths when using GPT-NeoX locally +{ + # If weight_by_num_documents is True, Builds dataset weights from a multinomial distribution over groups of data according to the number of documents in each group. 
+ # WARNING: setting this to True will override any user provided weights + # "weight_by_num_documents": false, + # "weighted_sampler_alpha": 0.3, + + "tokenizer-type": "HFTokenizer", + "vocab-file": "data/20B_tokenizer.json", + + "checkpoint_validation_with_forward_pass": False, + "use_wandb": False, + # "wandb_host": "https://api.wandb.ai", + + "launcher": "openmpi", + # "deepspeed_jsrun": true, + "num_workers": 2, + "finetune": false, + + "save": "checkpoints/cpt_iclr_2", + "tensorboard-dir": "tensorboard/cpt_iclr_2", + "log-dir": "logs", + "wandb_project": "cpt_iclr_2", +} \ No newline at end of file diff --git a/configs/schedules/7_1B_adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml b/configs/schedules/7_1B_adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml new file mode 100644 index 000000000..2f563a1a3 --- /dev/null +++ b/configs/schedules/7_1B_adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml @@ -0,0 +1,16 @@ +{ +"optimizer": { + "type": "Adam", + "params": { + "lr": 0.00012, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.000012, + "lr-decay-style": "inverse_sqrt_infinite", + "num_repeats": 1, + "warmup": 0.01, + "constant_iters_percent": 0.98, + "constant_lr": 0.000017, +} \ No newline at end of file diff --git a/configs/schedules/adam_constant_lr3e-4_3e-4_wu-001.yml b/configs/schedules/adam_constant_lr3e-4_3e-4_wu-001.yml new file mode 100644 index 000000000..638439c4d --- /dev/null +++ b/configs/schedules/adam_constant_lr3e-4_3e-4_wu-001.yml @@ -0,0 +1,13 @@ +{ +"optimizer": { + "type": "Adam", + "params": { + "lr": 3.0e-4, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 3.0e-4, + "lr-decay-style": "constant", # this will coincide with the else in AnnealingLR + "warmup": 0.01, +} \ No newline at end of file diff --git a/configs/schedules/adam_constant_lr3e-5_3e-5_wu-0.yml b/configs/schedules/adam_constant_lr3e-5_3e-5_wu-0.yml new file mode 100644 index 000000000..640728c5f --- /dev/null +++ b/configs/schedules/adam_constant_lr3e-5_3e-5_wu-0.yml @@ -0,0 +1,13 @@ +{ +"optimizer": { + "type": "Adam", + "params": { + "lr": 3.0e-5, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 3.0e-5, + "lr-decay-style": "constant", # this will coincide with the else in AnnealingLR + "warmup": 0., +} \ No newline at end of file diff --git a/configs/schedules/adam_cosine-inf_lr3e-4_3e-5_wu-001.yml b/configs/schedules/adam_cosine-inf_lr3e-4_3e-5_wu-001.yml new file mode 100644 index 000000000..383e476d1 --- /dev/null +++ b/configs/schedules/adam_cosine-inf_lr3e-4_3e-5_wu-001.yml @@ -0,0 +1,14 @@ +{ +"optimizer": { + "type": "Adam", + "params": { + "lr": 3.0e-4, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 3.0e-5, + "lr-decay-style": "cosine-inf", + "num_repeats": 3, + "warmup": 0.01, +} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr1-2e-4_1-2e-5_wu-001.yml b/configs/schedules/adam_cosine_lr1-2e-4_1-2e-5_wu-001.yml new file mode 100644 index 000000000..f4451118d --- /dev/null +++ b/configs/schedules/adam_cosine_lr1-2e-4_1-2e-5_wu-001.yml @@ -0,0 +1,13 @@ +{ +"optimizer": { + "type": "Adam", + "params": { + "lr": 1.2e-4, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 1.2e-5, + "lr-decay-style": "cosine", + "warmup": 0.01, +} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-0.yml b/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-0.yml new file mode 100644 index 000000000..72e42e239 --- /dev/null +++ b/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-0.yml @@ -0,0 +1,14 @@ +{ 
+"optimizer": { + "type": "Adam", + "params": { + "lr": 1.5e-4, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 1.5e-5, + "lr-decay-style": "cosine", + "warmup": 0., + # "load": "/gpfs/alpine/csc499/scratch/btherien/gpt-neox/checkpoints/continued_test8/JOB-2959220_pythia-c-410M-iters-181793_warmup-0.01_max-lr-0.00015_min-lr-1.5e-05_finetune", +} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-0005.yml b/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-0005.yml new file mode 100644 index 000000000..58ffad137 --- /dev/null +++ b/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-0005.yml @@ -0,0 +1,13 @@ +{ +"optimizer": { + "type": "Adam", + "params": { + "lr": 1.5e-4, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 1.5e-5, + "lr-decay-style": "cosine", + "warmup": 0.005, +} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-001.yml b/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-001.yml new file mode 100644 index 000000000..4780250d3 --- /dev/null +++ b/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-001.yml @@ -0,0 +1,13 @@ +{ +"optimizer": { + "type": "Adam", + "params": { + "lr": 1.5e-4, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 1.5e-5, + "lr-decay-style": "cosine", + "warmup": 0.01, +} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-002.yml b/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-002.yml new file mode 100644 index 000000000..653d47d0f --- /dev/null +++ b/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-002.yml @@ -0,0 +1,13 @@ +{ +"optimizer": { + "type": "Adam", + "params": { + "lr": 1.5e-4, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 1.5e-5, + "lr-decay-style": "cosine", + "warmup": 0.02, +} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-0.yml b/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-0.yml new file mode 100644 index 000000000..955708b5c --- /dev/null +++ b/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-0.yml @@ -0,0 +1,13 @@ +{ +"optimizer": { + "type": "Adam", + "params": { + "lr": 3.0e-4, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 3.0e-5, + "lr-decay-style": "cosine", + "warmup": 0.0, +} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-0005.yml b/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-0005.yml new file mode 100644 index 000000000..c86695af7 --- /dev/null +++ b/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-0005.yml @@ -0,0 +1,13 @@ +{ +"optimizer": { + "type": "Adam", + "params": { + "lr": 3.0e-4, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 3.0e-5, + "lr-decay-style": "cosine", + "warmup": 0.005, +} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-001.yml b/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-001.yml new file mode 100644 index 000000000..52944b4a5 --- /dev/null +++ b/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-001.yml @@ -0,0 +1,13 @@ +{ +"optimizer": { + "type": "Adam", + "params": { + "lr": 3.0e-4, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 3.0e-5, + "lr-decay-style": "cosine", + "warmup": 0.01, +} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-002.yml b/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-002.yml new file mode 100644 index 000000000..d8da20aa5 --- /dev/null +++ b/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-002.yml @@ -0,0 +1,13 @@ +{ +"optimizer": { 
+ "type": "Adam", + "params": { + "lr": 3.0e-4, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 3.0e-5, + "lr-decay-style": "cosine", + "warmup": 0.02, +} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-0.yml b/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-0.yml new file mode 100644 index 000000000..776225363 --- /dev/null +++ b/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-0.yml @@ -0,0 +1,14 @@ +{ +"optimizer": { + "type": "Adam", + "params": { + "lr": 6.0e-4, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 6.0e-5, + "lr-decay-style": "cosine", + "warmup": 0.0, + # "load": "/gpfs/alpine/csc499/scratch/btherien/gpt-neox/checkpoints/continued_test8/JOB-2959221_pythia-c-410M-iters-181793_warmup-0.01_max-lr-0.0006_min-lr-6e-05_finetune" +} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-0005.yml b/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-0005.yml new file mode 100644 index 000000000..73f38a6cc --- /dev/null +++ b/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-0005.yml @@ -0,0 +1,13 @@ +{ +"optimizer": { + "type": "Adam", + "params": { + "lr": 6.0e-4, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 6.0e-5, + "lr-decay-style": "cosine", + "warmup": 0.005, +} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-001.yml b/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-001.yml new file mode 100644 index 000000000..d3ae77ba9 --- /dev/null +++ b/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-001.yml @@ -0,0 +1,14 @@ +{ +"optimizer": { + "type": "Adam", + "params": { + "lr": 6.0e-4, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 6.0e-5, + "lr-decay-style": "cosine", + "warmup": 0.01, + # "load": "/gpfs/alpine/csc499/scratch/btherien/gpt-neox/checkpoints/continued_test8/JOB-2959221_pythia-c-410M-iters-181793_warmup-0.01_max-lr-0.0006_min-lr-6e-05_finetune" +} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-002.yml b/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-002.yml new file mode 100644 index 000000000..b01da6ed1 --- /dev/null +++ b/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-002.yml @@ -0,0 +1,13 @@ +{ +"optimizer": { + "type": "Adam", + "params": { + "lr": 6.0e-4, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 6.0e-5, + "lr-decay-style": "cosine", + "warmup": 0.02, +} \ No newline at end of file diff --git a/configs/schedules/adam_infcos_lr3e-4_3e-5_wu-001.yml b/configs/schedules/adam_infcos_lr3e-4_3e-5_wu-001.yml new file mode 100644 index 000000000..53838af53 --- /dev/null +++ b/configs/schedules/adam_infcos_lr3e-4_3e-5_wu-001.yml @@ -0,0 +1,17 @@ +{ +"optimizer": { + "type": "Adam", + "params": { + "lr": 3.0e-4, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 3.0e-5, + "lr-decay-style": "cosine_cooldown_infinite", + "warmup": 0.01, + "constant_lr": 0.000165, + "constant_iters_percent" : 0.85, + "cooldown_iters_percent" : 0.6, + "timescale" : 10, +} \ No newline at end of file diff --git a/configs/schedules/adam_infinv_lr3e-4_3e-5_wu-001.yml b/configs/schedules/adam_infinv_lr3e-4_3e-5_wu-001.yml new file mode 100644 index 000000000..9e833f4cc --- /dev/null +++ b/configs/schedules/adam_infinv_lr3e-4_3e-5_wu-001.yml @@ -0,0 +1,17 @@ +{ +"optimizer": { + "type": "Adam", + "params": { + "lr": 3.0e-4, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 3.0e-5, + "lr-decay-style": "inverse_sqrt_infinite", + "warmup": 0.01, + "constant_lr": 0.000165, + 
"constant_iters_percent" : 0.85, + "cooldown_iters_percent" : 0.6, + "timescale" : 10, +} \ No newline at end of file diff --git a/configs/schedules/adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml b/configs/schedules/adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml new file mode 100644 index 000000000..0e7134ec3 --- /dev/null +++ b/configs/schedules/adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml @@ -0,0 +1,16 @@ +{ +"optimizer": { + "type": "Adam", + "params": { + "lr": 3.0e-4, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 3.0e-5, + "lr-decay-style": "inverse_sqrt_infinite", + "num_repeats": 1, + "warmup": 0.01, + "constant_iters_percent": 0.98, + "constant_lr": 8.0e-5, +} \ No newline at end of file diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index 65347e24f..c1119c89c 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -565,6 +565,98 @@ def build_train_valid_test_data_iterators(neox_args): return train_data_iterator, valid_data_iterator, test_data_iterator +def build_validation_iterator(neox_args): + """XXX""" + + valid_dataloader = None + + print_rank_0("> building validation ...") + + # Ensure only the first/last pipeline stages have data loaders + if neox_args.is_pipe_parallel: + is_first_stage = mpu.get_pipe_parallel_rank() == 0 + is_last_stage = ( + mpu.get_pipe_parallel_rank() == mpu.get_pipe_parallel_world_size() - 1 + ) + pipe_load = is_first_stage or is_last_stage + else: + pipe_load = True + + # Data loader only on rank 0 of each model parallel group. + if mpu.get_model_parallel_rank() == 0 and pipe_load: + # Number of train/valid/test samples. + train_iters = neox_args.train_iters + eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters + test_iters = neox_args.eval_iters + num_samples = eval_iters * neox_args.train_batch_size + + valid_weights, valid_num_samples = get_normalized_weights_and_num_samples( + neox_args.valid_data_weights, num_samples + ) + + # build individual datasets + _, valid_datasets, _ = build_weighted_datasets( + neox_args, + 0, + valid_num_samples, + 0, + 0, + valid_weights, + 0, + build_index_mappings=not neox_args.weight_by_num_documents, + concatenate_train_replay_paths=False, + ) + + # if neox_args.weight_by_num_documents: # Not supported for now + + if valid_datasets: + valid_ds = BlendableDataset(valid_datasets, valid_weights) + valid_dataloader = make_data_loader(valid_ds, neox_args=neox_args) + + # Flags to know if we need to do training/validation/testing. + do_valid = valid_dataloader is not None and neox_args.eval_iters > 0 + + # Need to broadcast num_tokens and num_type_tokens. + flags = torch.cuda.LongTensor([0, int(do_valid), 0]) + else: + flags = torch.cuda.LongTensor([0, 0, 0]) + + # Broadcast num tokens. + if neox_args.is_pipe_parallel: + # Only first/last pipeline stages have data loaders, so pipeline parallelism should + # broadcast globally instead of just the model parallel group. 
+ torch.distributed.broadcast(flags, src=0) + else: + torch.distributed.broadcast( + flags, + mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group(), + ) + neox_args.do_train = flags[0].item() + neox_args.do_valid = flags[1].item() + neox_args.do_test = flags[2].item() + + if valid_dataloader is not None: + start_iter_val = ( + (neox_args.iteration * neox_args.gradient_accumulation_steps) + // neox_args.eval_interval + ) * neox_args.eval_iters + valid_dataloader.batch_sampler.start_iter = start_iter_val % len( + valid_dataloader + ) + print_rank_0( + "setting validation data start iteration to {}".format( + valid_dataloader.batch_sampler.start_iter + ) + ) + + if valid_dataloader is not None: + valid_data_iterator = iter(valid_dataloader) + else: + valid_data_iterator = None + + return valid_data_iterator + def compile_helper(): """Compile helper function at runtime. Make sure this is invoked on a single process.""" diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 037fe2079..ea13659a8 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -467,6 +467,16 @@ class NeoXArgsModel(NeoXArgsTemplate): Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column) """ + identifier_string: str = "" + """ + an identifier for the model, used for saving checkpoints, logging, etc. + """ + + warup_eval_interval: int = 50 + """ + the evaluation interval to use during warmup + """ + @dataclass class NeoXArgsOptimizer(NeoXArgsTemplate): @@ -1275,6 +1285,16 @@ class NeoXArgsTraining(NeoXArgsTemplate): What to scale width by when creating the delta model for mup """ + train_dataset_name: str = "no_train_dataset_name_given" + """ + An identifier for the training dataset, used for logging + """ + + val_dataset_name: str = "no_val_dataset_name_given" + """ + An identifier for the validation dataset, used for logging + """ + @dataclass class NeoXArgsTextgen(NeoXArgsTemplate): diff --git a/megatron/training.py b/megatron/training.py index 1ee5d256d..bfe8c90a8 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -25,6 +25,7 @@ import math import sys from contextlib import nullcontext +import copy import torch import deepspeed @@ -45,7 +46,7 @@ get_params_for_weight_decay_optimization, ) from megatron.checkpointing import load_checkpoint, save_checkpoint -from megatron.data.data_utils import build_train_valid_test_data_iterators +from megatron.data.data_utils import build_train_valid_test_data_iterators, build_validation_iterator from megatron.initialize import initialize_megatron from megatron.learning_rates import AnnealingLR from megatron.logging import tb_wandb_log, training_log @@ -197,6 +198,18 @@ def pretrain(neox_args): ) timers("model and optimizer").stop() + tensorboard_writer = neox_args.tensorboard_writer + neox_args.tensorboard_writer = None + neox_args_val = copy.deepcopy(neox_args) + neox_args.tensorboard_writer = tensorboard_writer + neox_args_val.train_data_paths = [None] + neox_args_val.test_data_paths = [None] + neox_args.valid_data_paths = neox_args.valid_data_paths[0] + neox_args.valid_data_weights = neox_args.valid_data_weights[0] + + + print(neox_args.is_replay_enabled) + + # Data stuff.
timers("train/valid/test data iterators").start() ( @@ -204,6 +217,14 @@ def pretrain(neox_args): valid_data_iterator, test_data_iterator, ) = build_train_valid_test_data_iterators(neox_args=neox_args) + val_iters = [valid_data_iterator] + if neox_args_val.valid_data_paths is not None and len(neox_args_val.valid_data_paths) > 1: + for i in range(1, len(neox_args_val.valid_data_paths)): + temp_copy = copy.deepcopy(neox_args_val) + temp_copy.valid_data_paths = temp_copy.valid_data_paths[i] + temp_copy.valid_data_weights = temp_copy.valid_data_weights[i] + temp_copy.num_workers = 0 + val_iters.append(build_validation_iterator(neox_args=temp_copy)) timers("train/valid/test data iterators").stop() if neox_args.use_mup and neox_args.coord_check: @@ -237,17 +258,20 @@ def pretrain(neox_args): ) if neox_args.do_valid: - prefix = "the end of training for val data" - evaluate_and_print_results( - neox_args=neox_args, - prefix=prefix, - forward_step_func=forward_step, - data_iterator=valid_data_iterator, - model=model, - iteration=iteration, - verbose=False, - timers=timers, - ) + prefix = "the start of training for val data" + for i in range(len(val_iters)): + print_rank_0("in if neox_args.do_valid for val_iters[i]",i, val_iters[i]) + evaluate_and_print_results( + neox_args=neox_args, + prefix=prefix, + forward_step_func=forward_step, + data_iterator=val_iters[i], + model=model, + iteration=iteration, + verbose=False, + timers=timers, + eval_name=f"val_{i}", + ) if neox_args.save and iteration != 0: save_checkpoint( diff --git a/megatron_config_1.json b/megatron_config_1.json new file mode 100644 index 000000000..e92fb185d --- /dev/null +++ b/megatron_config_1.json @@ -0,0 +1 @@ +{"launcher": "jsrun", "train_batch_size": 276, "train_micro_batch_size_per_gpu": 138, "optimizer": {"type": "Adam", "params": {"lr": 0.00012, "betas": [0.9, 0.95], "eps": 1e-08}}, "fp16": {"enabled": true, "loss_scale": 0, "loss_scale_window": 1000, "hysteresis": 2, "min_loss_scale": 1}, "gradient_clipping": 1.0, "zero_optimization": {"stage": 1, "allgather_partitions": true, "allgather_bucket_size": 500000000, "overlap_comm": true, "reduce_scatter": true, "reduce_bucket_size": 500000000, "contiguous_gradients": true, "cpu_offload": false}, "steps_per_print": 1, "wall_clock_breakdown": true, "precision": "fp16", "num_layers": 10, "hidden_size": 640, "num_attention_heads": 10, "seq_length": 2048, "max_position_embeddings": 2048, "pos_emb": "rotary", "no_weight_tying": true, "attention_config": ["global", "global", "global", "global", "global", "global", "global", "global", "global", "global"], "sparsity_config": {}, "scaled_upper_triang_masked_softmax_fusion": true, "bias_gelu_fusion": true, "rotary_pct": 0.25, "init_method": "small_init", "output_layer_init_method": "wang_init", "gpt_j_residual": true, "output_layer_parallelism": "column", "identifier_string": "410M", "lr_decay_style": "cosine", "lr_decay_iters": 132366, "min_lr": 1.2e-05, "optimizer_type": "Adam", "zero_stage": 1, "zero_reduce_scatter": true, "zero_contiguous_gradients": true, "zero_reduce_bucket_size": 500000000, "zero_allgather_bucket_size": 500000000, "lr": 0.00012, "tokenizer_type": "HFTokenizer", "train_data_paths": ["data/pile/train/pile_train"], "test_data_paths": ["data/pile/test/pile_test_text_document"], "valid_data_paths": [["data/pile/val/pile_val_text_document"], ["data/slim_pajama/val/all/sp_val"]], "train_data_weights": [1.0], "valid_data_weights": [[1.0], [1.0]], "test_data_weights": [1.0], "data_impl": "mmap", "save": 
"checkpoints/cpt_iclr_2", "config_files": {"pythia_410m_llama_setup_resume.yml": "# Suggested data paths when using GPT-NeoX locally\n{\n # If weight_by_num_documents is True, Builds dataset weights from a multinomial distribution over groups of data according to the number of documents in each group.\n # WARNING: setting this to True will override any user provided weights\n # \"weight_by_num_documents\": false,\n # \"weighted_sampler_alpha\": 0.3,\n\n \"tokenizer-type\": \"HFTokenizer\",\n \"vocab-file\": \"data/20B_tokenizer.json\",\n\n \"checkpoint_validation_with_forward_pass\": False,\n \"use_wandb\": False,\n # \"wandb_host\": \"https://api.wandb.ai\",\n\n \"launcher\": \"jsrun\",\n \"deepspeed_jsrun\": true,\n \"num_workers\": 1,\n \"finetune\": false,\n\n \"save\": \"checkpoints/cpt_iclr_2\",\n \"tensorboard-dir\": \"tensorboard/cpt_iclr_2\",\n \"log-dir\": \"logs\",\n \"wandb_project\": \"cpt_iclr_2\",\n}", "49M.yml": "# GPT-2 pretraining setup\n{\n #identifier string for this config used while logging\n \"identifier_string\": \"410M\",\n\n # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages\n # across the node boundaries )\n \"pipe-parallel-size\": 1,\n \"model-parallel-size\": 1, # one copy of the model per node\n\n # model settings\n \"num-layers\": 10,\n \"hidden-size\": 640,\n \"num-attention-heads\": 10,\n \"seq-length\": 2048,\n \"max-position-embeddings\": 2048,\n \"pos-emb\": \"rotary\",\n \"rotary-pct\": 0.25,\n \"no-weight-tying\": true,\n \"gpt-j-residual\": true,\n \"output-layer-parallelism\": \"column\",\n\n # these should provide some speedup but takes a while to build, set to true if desired\n \"scaled-upper-triang-masked-softmax-fusion\": true,\n \"bias-gelu-fusion\": true,\n\n # init methods\n \"init_method\": \"small_init\",\n \"output_layer_init_method\": \"wang_init\",\n\n # \"optimizer\": {\n # \"type\": \"Adam\",\n # \"params\": {\n # \"lr\": 3.0e-4,\n # \"betas\": [0.9, 0.95],\n # \"eps\": 1.0e-8,\n # }\n # },\n # \"min_lr\": 3.0e-5,\n\n \"zero_optimization\": {\n \"stage\": 1,\n \"allgather_partitions\": True,\n \"allgather_bucket_size\": 500000000,\n \"overlap_comm\": True,\n \"reduce_scatter\": True,\n \"reduce_bucket_size\": 500000000,\n \"contiguous_gradients\": True,\n \"cpu_offload\": False\n },\n\n # LLAMA Config\n # batch / data settings\n # \"train_batch_size\": 1104, #1104, #1104, #1104, #1104, #1104 # approximately 2.2M batch size across 46 nodes \n \"train_micro_batch_size_per_gpu\": 138,\n 'gas': 4,\n \"data-impl\": \"mmap\",\n \"split\": \"949,50,1\",\n\n # activation checkpointing\n \"checkpoint-activations\": true,\n \"checkpoint-num-layers\": 1,\n \"partition-activations\": true,\n \"synchronize-each-layer\": true,\n\n # regularization\n \"gradient_clipping\": 1.0,\n \"weight-decay\": 0.1,\n \"hidden-dropout\": 0.0,\n \"attention-dropout\": 0.0,\n\n # precision settings of LLaMa\n \"fp16\": {\n \"enabled\": true,\n # \"type\": \"bfloat16\", # set bf16 as precision\n \"loss_scale\": 0,\n \"loss_scale_window\": 1000,\n \"hysteresis\": 2,\n \"min_loss_scale\": 1\n },\n\n # \"fp32_allreduce\": True, # without a patch to torch, bf16 models have to do the allreduce in fp32\n\n\n\n \"distributed-backend\": \"nccl\",\n # \"lr-decay-style\": \"cosine\",\n # \"warmup\": 0.01,\n \"checkpoint-factor\": 400,\n \"eval-interval\": 100,\n \"warup-eval-interval\": 50,\n \"eval-iters\": 10,\n\n # logging\n \"log-interval\": 1,\n \"steps_per_print\": 1,\n \"keep-last-n-checkpoints\": 
1000,\n \"wall_clock_breakdown\": true,\n\n}\n", "pile_train.yml": "{ \n \"train-data-paths\": [\n \"data/pile/train/pile_train\",\n ],\n \"train-data-weights\": [\n 1.,\n ],\n \"train-dataset-name\": 'pile_train',\n \"train-iters\": 132366,\n \"lr-decay-iters\": 132366,\n}", "pile_slimp.yml": "{\n \"test-data-paths\": [\"data/pile/test/pile_test_text_document\"],\n \"test-data-weights\": [\n 1.\n ],\n \"valid-data-paths\": [\n [\"data/pile/val/pile_val_text_document\"],\n [\"data/slim_pajama/val/all/sp_val\"],\n ],\n \"valid-data-weights\": [\n [1.],\n [1.],\n ],\n \"val-dataset-name\": 'pile_slimp',\n}\n", "none.yml": "{\n \"load\": \"none\",\n}", "adam_cosine_lr1-2e-4_1-2e-5_wu-001.yml": "{\n\"optimizer\": {\n \"type\": \"Adam\",\n \"params\": {\n \"lr\": 1.2e-4,\n \"betas\": [0.9, 0.95],\n \"eps\": 1.0e-8,\n }\n },\n \"min_lr\": 1.2e-5,\n \"lr-decay-style\": \"cosine\",\n \"warmup\": 0.01,\n}"}, "load": "none", "checkpoint_factor": 400, "batch_size": 138, "train_iters": 132366, "eval_iters": 10, "keep_last_n_checkpoints": 1000, "eval_interval": 100, "split": "949,50,1", "vocab_file": "data/20B_tokenizer.json", "num_workers": 1, "attention_dropout": 0.0, "hidden_dropout": 0.0, "weight_decay": 0.1, "checkpoint_activations": true, "synchronize_each_layer": true, "partition_activations": true, "gas": 1, "clip_grad": 1.0, "dynamic_loss_scale": true, "train_dataset_name": "pile_train", "val_dataset_name": "pile_slimp", "pipe_parallel_size": 1, "world_size": 1, "is_pipe_parallel": true, "use_wandb": false, "wandb_project": "cpt_iclr_2", "log_dir": "logs", "tensorboard_dir": "tensorboard/cpt_iclr_2", "log_interval": 1, "text_gen_type": "unconditional", "local_rank": 0, "rank": 0, "deepspeed_jsrun": true, "user_script": "train.py", "save_iters": [400, 800, 1200, 1600, 2000, 2400, 2800, 3200, 3600, 4000, 4400, 4800, 5200, 5600, 6000, 6400, 6800, 7200, 7600, 8000, 8400, 8800, 9200, 9600, 10000, 10400, 10800, 11200, 11600, 12000, 12400, 12800, 13200, 13600, 14000, 14400, 14800, 15200, 15600, 16000, 16400, 16800, 17200, 17600, 18000, 18400, 18800, 19200, 19600, 20000, 20400, 20800, 21200, 21600, 22000, 22400, 22800, 23200, 23600, 24000, 24400, 24800, 25200, 25600, 26000, 26400, 26800, 27200, 27600, 28000, 28400, 28800, 29200, 29600, 30000, 30400, 30800, 31200, 31600, 32000, 32400, 32800, 33200, 33600, 34000, 34400, 34800, 35200, 35600, 36000, 36400, 36800, 37200, 37600, 38000, 38400, 38800, 39200, 39600, 40000, 40400, 40800, 41200, 41600, 42000, 42400, 42800, 43200, 43600, 44000, 44400, 44800, 45200, 45600, 46000, 46400, 46800, 47200, 47600, 48000, 48400, 48800, 49200, 49600, 50000, 50400, 50800, 51200, 51600, 52000, 52400, 52800, 53200, 53600, 54000, 54400, 54800, 55200, 55600, 56000, 56400, 56800, 57200, 57600, 58000, 58400, 58800, 59200, 59600, 60000, 60400, 60800, 61200, 61600, 62000, 62400, 62800, 63200, 63600, 64000, 64400, 64800, 65200, 65600, 66000, 66400, 66800, 67200, 67600, 68000, 68400, 68800, 69200, 69600, 70000, 70400, 70800, 71200, 71600, 72000, 72400, 72800, 73200, 73600, 74000, 74400, 74800, 75200, 75600, 76000, 76400, 76800, 77200, 77600, 78000, 78400, 78800, 79200, 79600, 80000, 80400, 80800, 81200, 81600, 82000, 82400, 82800, 83200, 83600, 84000, 84400, 84800, 85200, 85600, 86000, 86400, 86800, 87200, 87600, 88000, 88400, 88800, 89200, 89600, 90000, 90400, 90800, 91200, 91600, 92000, 92400, 92800, 93200, 93600, 94000, 94400, 94800, 95200, 95600, 96000, 96400, 96800, 97200, 97600, 98000, 98400, 98800, 99200, 99600, 100000, 100400, 100800, 101200, 101600, 102000, 102400, 102800, 
103200, 103600, 104000, 104400, 104800, 105200, 105600, 106000, 106400, 106800, 107200, 107600, 108000, 108400, 108800, 109200, 109600, 110000, 110400, 110800, 111200, 111600, 112000, 112400, 112800, 113200, 113600, 114000, 114400, 114800, 115200, 115600, 116000, 116400, 116800, 117200, 117600, 118000, 118400, 118800, 119200, 119600, 120000, 120400, 120800, 121200, 121600, 122000, 122400, 122800, 123200, 123600, 124000, 124400, 124800, 125200, 125600, 126000, 126400, 126800, 127200, 127600, 128000, 128400, 128800, 129200, 129600, 130000, 130400, 130800, 131200, 131600, 132000], "global_num_gpus": 2} \ No newline at end of file diff --git a/train.py b/train.py index 2e4b09954..17f6ef91c 100644 --- a/train.py +++ b/train.py @@ -19,6 +19,9 @@ from megatron.neox_arguments import NeoXArgs from megatron.training import pretrain +import os +import numpy as np + def main(input_args=None, overwrite_values=None): neox_args = NeoXArgs.consume_neox_args( @@ -26,6 +29,43 @@ def main(input_args=None, overwrite_values=None): ) neox_args.configure_distributed_args() neox_args.build_tokenizer() # tokenizer needs to be build in training in order to set the padding vocab + + if neox_args.load.split('/')[-1].startswith('JOB'): + if 'scratch' in neox_args.load: + training_mode = 'scratch' + elif 'finetune' in neox_args.load: + training_mode = 'finetune' + else: + training_mode = 'resume' + + elif neox_args.load == 'none': + training_mode = 'scratch' + elif neox_args.finetune: + training_mode = 'finetune' + else: + training_mode = 'resume' + + dir_str = "JOB-{}_{}_it-{}_wu-{}_mxlr-{}_mnlr-{}_sch-{}_tr-{}_{}".format( + "ENTER_YOUR_JOBID_IN_TRAIN.PY",# os.environ['LSB_JOBID'], + neox_args.identifier_string.replace('_',"-"), + neox_args.train_iters, + neox_args.warmup, + neox_args.optimizer['params']['lr'], + neox_args.min_lr, + neox_args.lr_decay_style, + neox_args.train_dataset_name.replace('_',"-"), + training_mode) + + + + neox_args.tensorboard_dir = os.path.join(neox_args.tensorboard_dir, dir_str) + neox_args.save = os.path.join(neox_args.save, dir_str) + print("NEOX ARGS tensorboard_dir: ", neox_args.tensorboard_dir) + print("NEOX ARGS save: ", neox_args.save) + # exit(0) + + + neox_args.initialize_tensorboard_writer() # is initialized if tensorboard directory is defined pretrain(neox_args=neox_args) From d66134087a2c64a8c8fca25fcbccf8d3d3d406ed Mon Sep 17 00:00:00 2001 From: github-actions Date: Sun, 14 Apr 2024 20:24:46 +0000 Subject: [PATCH 4/8] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index e40e39dad..53ffa37c1 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = defa0a4 + Default = 6ff3ae6 current git hash of repository @@ -762,6 +762,22 @@ Model Arguments +- **identifier_string**: str + + Default = + + an identifier for the model, used for saving checkpoints,logging, etc. 
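As a concrete illustration of how `identifier_string` is used: `train.py` above folds it, together with the schedule settings and the training mode, into the `dir_str` that names each run's checkpoint and TensorBoard directories, which is where the long `JOB-..._sch-cosine_..._scratch` paths in the `configs/load/*.yml` files come from. A minimal sketch of that naming scheme follows; the values are copied from one of the load configs and are otherwise assumptions, and the job id normally comes from the cluster scheduler rather than being hard-coded.

```python
# Sketch of the run-directory naming used by train.py above; values are assumed.
job_id = "3156148"                  # normally read from the scheduler (e.g. LSB_JOBID)
identifier_string = "7_1B"          # neox_args.identifier_string
train_iters = 132366
warmup = 0.01
max_lr, min_lr = 0.0003, 3e-05
lr_decay_style = "cosine"
train_dataset_name = "pile_train"
training_mode = "scratch"           # derived from `load` / `finetune`, as in train.py

dir_str = "JOB-{}_{}_it-{}_wu-{}_mxlr-{}_mnlr-{}_sch-{}_tr-{}_{}".format(
    job_id,
    identifier_string.replace("_", "-"),
    train_iters,
    warmup,
    max_lr,
    min_lr,
    lr_decay_style,
    train_dataset_name.replace("_", "-"),
    training_mode,
)
print(dir_str)
# JOB-3156148_7-1B_it-132366_wu-0.01_mxlr-0.0003_mnlr-3e-05_sch-cosine_tr-pile-train_scratch
```

The resulting string matches the directory component of the checkpoint referenced by `configs/load/resume_3e-4_001_7-1B_pile_PT.yml`, which is how the load configs stay in sync with the saving logic.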
+ + + +- **warup_eval_interval**: int + + Default = 50 + + the evaluation interval to use during warmup + + + ## NeoXArgsOptimizer Optimizer Arguments @@ -1971,6 +1987,22 @@ Training Arguments +- **train_dataset_name**: str + + Default = no_train_dataset_name_given + + An identified for the training dataset used for logging + + + +- **val_dataset_name**: str + + Default = no_val_dataset_name_given + + An identified for the training dataset used for logging + + + ## NeoXArgsDeepspeedConfig Args for deepspeed config From 2d33aaac78c6ba48f9f68ed11db361df25dc574a Mon Sep 17 00:00:00 2001 From: bentherien Date: Sun, 14 Apr 2024 16:35:54 -0400 Subject: [PATCH 5/8] Revert "Update NeoXArgs docs automatically" This reverts commit 5cdff7646da9c84065de7f5b9dd54e8f816d3487. --- configs/neox_arguments.md | 120 ++------------------------------------ 1 file changed, 4 insertions(+), 116 deletions(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 53ffa37c1..c5faa9fc0 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,11 @@ Logging Arguments - **git_hash**: str +<<<<<<< HEAD Default = 6ff3ae6 +======= + Default = 11a5537 +>>>>>>> parent of 5cdff764 (Update NeoXArgs docs automatically) current git hash of repository @@ -1394,122 +1398,6 @@ Training Arguments -- **replay_config**: dict - - Default = None - - Dictionary storing the replay config. - - - -- **is_replay_enabled**: bool - - Default = False - - Triggers the logic for replay. It is important to deal with replay separately from the general "train_data_paths" logic, as replay - requires reusing the same idx files to know what data was seen the first time a dataset was originally trained on. - If one attempts to do replay by just putting the datasets to be replayed in the train_data_paths instead of the replay params: - - If the exact same dataset files are used as during the 1st time it was seen, and the number of iterations on the replay buffer - corresponds to as many epochs on a replay dataset as the non-replay training, the data will be seen in exactly the same order as - the first time if the seed and sequence length is the same. - - For similar reasons, replaying multiple times on the same dataset (e.g. across multiple tasks) with the same number of epochs - on the replay dataset will lead to seeing the same data in the same order. - - If a different dataset is used for replay (e.g. different shard of Pile), then the shuffling will lead to completely different - indices, which will lead to potentially significant proportions of data being unseen if the original training on the replay dataset - did not see all of it, e.g. when training on 300B tokens of the GPT2-tokenised Pile which contains a few dozen billion more tokens, - then sharding the full dataset into smaller ones. - - - -- **replay_idx_paths_prefixes**: list - - Default = None - - List of paths prefixes to retrieve replay dataset idx files. Those idx files should have been generated when originally training on the dataset - being used for replay. They contain in the filename the number of samples potentially seen during pretraining, the sequence length and the - seed. The exact files (shuffle_idx, sample_idx and doc_idx) will be automatically derived from the prefix. Similarly, the data paths will - be generated from the prefixes. - The *_idx files are important as it allows one to know what data was seen in the dataset during training. 
If those files are missing, you can - regenerate them by relaunching the same training script (most importantly, config) used originally to pretrain on a given dataset. You - can add an exit(0) statement in training.py in pretrain() after the call to build_train_valid_test_data_iterators(neox_args=neox_args). - It is crucial to use the same dataset shard, sequence length, number of iterations, seed, and potentially batch size, or the indices - generated may not be the same. - For a single replay data source, the value passed looks like ["data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s"] and - the files at the following paths (the paths will be constructed during execution from the prefix), must exist: - "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_doc_idx.npy" - "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_sample_idx.npy" - "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_shuffle_idx.npy" - "data/mydataset/train/mydataset" - - - -- **replay_data_to_idx_paths**: dict - - Default = None - - As indicated above, gets automatically built from the replay_idx_paths_prefixes by appending to it "_doc_idx.npy", "_sample_idx.npy" and - "_shuffle_idx.npy". It generates a dict of dict, with the data paths as keys, and dictionaries mapping each data path to the relevant - doc_idx, sample_idx and shuffle_idx file paths. Note that these files must exist at the relevant paths. - - - -- **replay_data_paths**: list - - Default = None - - As indicated above, gets automatically built from the replay_idx_paths_prefixes by removing the information about the idx files to retain - only the path to the dataset itself. - - - -- **replay_data_weights**: list - - Default = None - - List of 'weights' that decide how often to sample from each replay dataset when building the replay buffer. - - - -- **replay_idx_offsets**: list - - Default = None - - List of indices that decide where to start in the list of seen indices during pretraining on each replay dataset when building - the replay buffer. For example, when training originally on a dataset seeing 10000 samples, this allows to start looking at the - RESHUFFLED indices starting from idx replay_idx_offsets[i] for replay dataset i. - If not set, this will uniformly sample among all replay datasets. - - - -- **replay_fraction**: float - - Default = 0.05 - - Fraction of a batch dedicated to doing replay. For example, 0.1 means that in a batch of 100, 19 samples will come from the replay - buffer. Note that this means that if we train on 100B tokens, we will have only used 90B tokens from the datasets specified in - train_data_paths. - - - -- **replay_reshuffle_idx**: bool - - Default = True - - When index files are loaded from those the dataset was originally pretrained on, they will follow the exact same sequence of samples - seen when training on that dataset the first time if this is set to False. If True, the indices are reshuffled to prevent that. - - - -- **replay_seed**: int - - Default = 1234 - - Seed used to reshuffle indices accessed when originally training on a dataset, that are used to do replay. This is useful in the case - where replay is done twice on as many passes over the dataset, in which case if the same seed is used, the replay buffers in both case - will be exactly the same. 
- - - - **weight_by_num_documents**: bool Default = False From dd6d832db98ffef595b0e051287dc12522f16bb2 Mon Sep 17 00:00:00 2001 From: github-actions Date: Sun, 14 Apr 2024 20:36:22 +0000 Subject: [PATCH 6/8] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 122 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 117 insertions(+), 5 deletions(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index c5faa9fc0..37b16455e 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,11 +111,7 @@ Logging Arguments - **git_hash**: str -<<<<<<< HEAD - Default = 6ff3ae6 -======= - Default = 11a5537 ->>>>>>> parent of 5cdff764 (Update NeoXArgs docs automatically) + Default = 2d33aaa current git hash of repository @@ -1398,6 +1394,122 @@ Training Arguments +- **replay_config**: dict + + Default = None + + Dictionary storing the replay config. + + + +- **is_replay_enabled**: bool + + Default = False + + Triggers the logic for replay. It is important to deal with replay separately from the general "train_data_paths" logic, as replay + requires reusing the same idx files to know what data was seen the first time a dataset was originally trained on. + If one attempts to do replay by just putting the datasets to be replayed in the train_data_paths instead of the replay params: + - If the exact same dataset files are used as during the 1st time it was seen, and the number of iterations on the replay buffer + corresponds to as many epochs on a replay dataset as the non-replay training, the data will be seen in exactly the same order as + the first time if the seed and sequence length is the same. + - For similar reasons, replaying multiple times on the same dataset (e.g. across multiple tasks) with the same number of epochs + on the replay dataset will lead to seeing the same data in the same order. + - If a different dataset is used for replay (e.g. different shard of Pile), then the shuffling will lead to completely different + indices, which will lead to potentially significant proportions of data being unseen if the original training on the replay dataset + did not see all of it, e.g. when training on 300B tokens of the GPT2-tokenised Pile which contains a few dozen billion more tokens, + then sharding the full dataset into smaller ones. + + + +- **replay_idx_paths_prefixes**: list + + Default = None + + List of paths prefixes to retrieve replay dataset idx files. Those idx files should have been generated when originally training on the dataset + being used for replay. They contain in the filename the number of samples potentially seen during pretraining, the sequence length and the + seed. The exact files (shuffle_idx, sample_idx and doc_idx) will be automatically derived from the prefix. Similarly, the data paths will + be generated from the prefixes. + The *_idx files are important as it allows one to know what data was seen in the dataset during training. If those files are missing, you can + regenerate them by relaunching the same training script (most importantly, config) used originally to pretrain on a given dataset. You + can add an exit(0) statement in training.py in pretrain() after the call to build_train_valid_test_data_iterators(neox_args=neox_args). + It is crucial to use the same dataset shard, sequence length, number of iterations, seed, and potentially batch size, or the indices + generated may not be the same. 
+ For a single replay data source, the value passed looks like ["data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s"] and + the files at the following paths (the paths will be constructed during execution from the prefix), must exist: + "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_doc_idx.npy" + "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_sample_idx.npy" + "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_shuffle_idx.npy" + "data/mydataset/train/mydataset" + + + +- **replay_data_to_idx_paths**: dict + + Default = None + + As indicated above, gets automatically built from the replay_idx_paths_prefixes by appending to it "_doc_idx.npy", "_sample_idx.npy" and + "_shuffle_idx.npy". It generates a dict of dicts, with the data paths as keys, and dictionaries mapping each data path to the relevant + doc_idx, sample_idx and shuffle_idx file paths. Note that these files must exist at the relevant paths. + + + +- **replay_data_paths**: list + + Default = None + + As indicated above, gets automatically built from the replay_idx_paths_prefixes by removing the information about the idx files to retain + only the path to the dataset itself. + + + +- **replay_data_weights**: list + + Default = None + + List of 'weights' that decide how often to sample from each replay dataset when building the replay buffer. + + + +- **replay_idx_offsets**: list + + Default = None + + List of indices that decide where to start in the list of seen indices during pretraining on each replay dataset when building + the replay buffer. For example, when training originally on a dataset seeing 10000 samples, this allows to start looking at the + RESHUFFLED indices starting from idx replay_idx_offsets[i] for replay dataset i. + If not set, this will uniformly sample among all replay datasets. + + + +- **replay_fraction**: float + + Default = 0.05 + + Fraction of a batch dedicated to doing replay. For example, 0.1 means that in a batch of 100, 10 samples will come from the replay + buffer. Note that this means that if we train on 100B tokens, we will have only used 90B tokens from the datasets specified in + train_data_paths. + + + +- **replay_reshuffle_idx**: bool + + Default = True + + When index files are loaded from those the dataset was originally pretrained on, they will follow the exact same sequence of samples + seen when training on that dataset the first time if this is set to False. If True, the indices are reshuffled to prevent that. + + + +- **replay_seed**: int + + Default = 1234 + + Seed used to reshuffle indices accessed when originally training on a dataset, that are used to do replay. This is useful in the case + where replay is done twice on as many passes over the dataset, in which case if the same seed is used, the replay buffers in both cases + will be exactly the same. + + + - **weight_by_num_documents**: bool Default = False From 621ab257d718b235fd3f1510352c54868dc09c70 Mon Sep 17 00:00:00 2001 From: bentherien Date: Sun, 14 Apr 2024 16:37:57 -0400 Subject: [PATCH 7/8] Revert "added CPT code" This reverts commit 6ff3ae6ce5ffb4d5cac95de77f5e686a18b1c8bd.
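For concreteness, the `replay_fraction` entry documented above implies the following split of each batch and of the overall token budget; a small illustrative calculation, with the batch size and token budget invented for the example:

```python
# Illustrative arithmetic for replay_fraction; numbers are made up.
replay_fraction = 0.1
batch_size = 100

replay_samples = int(replay_fraction * batch_size)   # 10 samples per batch from replay
new_samples = batch_size - replay_samples            # 90 samples from train_data_paths

total_tokens = 100e9
tokens_from_train_data = (1 - replay_fraction) * total_tokens  # 90B tokens from train_data_paths
```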
--- .../train/pile+slim_pajama_300B_each.yml | 25 ----- configs/datasets/train/pile_shard0.yml | 33 ------ configs/datasets/train/pile_train.yml | 11 -- configs/datasets/train/rp.yml | 32 ------ configs/datasets/train/slim_pajama_100B_1.yml | 24 ----- .../train/slim_pajama_100B_1_replay5.yml | 28 ----- configs/datasets/train/slim_pajama_100B_2.yml | 24 ----- .../train/slim_pajama_100B_2_replay5.yml | 45 -------- configs/datasets/train/slim_pajama_100B_3.yml | 24 ----- .../train/slim_pajama_100B_3_replay5.yml | 61 ----------- configs/datasets/train/slim_pajama_150B.yml | 23 ---- configs/datasets/train/slim_pajama_200B_1.yml | 24 ----- .../train/slim_pajama_200B_1_replay5.yml | 23 ---- configs/datasets/train/slim_pajama_200B_2.yml | 24 ----- .../train/slim_pajama_200B_2_replay5.yml | 45 -------- configs/datasets/train/slim_pajama_200B_3.yml | 24 ----- .../train/slim_pajama_200B_3_replay5.yml | 62 ----------- configs/datasets/train/slim_pajama_300B.yml | 23 ---- .../train/slim_pajama_300B_50_replay.yml | 45 -------- .../train/slim_pajama_300B_replay0-5.yml | 23 ---- .../train/slim_pajama_300B_replay1.yml | 23 ---- .../train/slim_pajama_300B_replay10.yml | 23 ---- .../train/slim_pajama_300B_replay5.yml | 23 ---- .../train/slim_pajama_300B_replay50.yml | 24 ----- configs/datasets/train/slim_pajama_606B.yml | 23 ---- configs/datasets/train/slim_pajama_75B.yml | 24 ----- .../datasets/train/slim_pajama_workshop.yml | 55 ---------- configs/datasets/train/taskwise/ArXiv.yml | 11 -- configs/datasets/train/taskwise/Book.yml | 12 --- configs/datasets/train/taskwise/C4.yml | 12 --- .../datasets/train/taskwise/CommonCrawl.yml | 11 -- configs/datasets/train/taskwise/Github.yml | 12 --- .../datasets/train/taskwise/StackExchange.yml | 11 -- configs/datasets/train/taskwise/Wikipedia.yml | 11 -- configs/datasets/val/pile_german.yml | 15 --- configs/datasets/val/pile_rp-no-se.yml | 36 ------- configs/datasets/val/pile_rp.yml | 40 ------- configs/datasets/val/pile_rp_subsets.yml | 62 ----------- configs/datasets/val/pile_slimp.yml | 15 --- configs/datasets/val/pile_slimp_domains.yml | 30 ------ configs/datasets/val/pile_slimp_workshop.yml | 32 ------ configs/datasets/val/pile_val.yml | 13 --- configs/datasets/val/pile_val_shard0.yml | 13 --- configs/iclr_models/3b_test.yml | 101 ------------------ configs/iclr_models/410M.yml | 99 ----------------- configs/iclr_models/410M_ckpt_100.yml | 100 ----------------- configs/iclr_models/49M.yml | 100 ----------------- configs/iclr_models/7_1B.yml | 94 ---------------- configs/load/3e-5const_0_410M_143_CPT.yml | 3 - configs/load/none.yml | 3 - configs/load/pythia_2-8B_143000.yml | 3 - configs/load/pythia_410m.yml | 3 - configs/load/pythia_410m_10000.yml | 3 - configs/load/pythia_410m_143000.yml | 3 - configs/load/pythia_410m_27000.yml | 3 - configs/load/pythia_6-9B_143000.yml | 3 - configs/load/pythia_deduped_410m_10000.yml | 3 - configs/load/pythia_deduped_410m_143000.yml | 3 - configs/load/pythia_deduped_410m_27000.yml | 3 - .../load/resume_1-2e-4_001_7-1B_pile_PT.yml | 3 - ...resume_1-2e-4_001_7-1B_slim_pajama_CPT.yml | 3 - .../resume_1-2e-4_001_7-1B_slim_pajama_PT.yml | 3 - .../load/resume_1-5e-4_001_410M_143_CPT.yml | 3 - configs/load/resume_3e-4_001_410M_143_CPT.yml | 3 - ...sume_3e-4_001_410M_slim_pajama_CPT_r05.yml | 3 - ...esume_3e-4_001_410M_slim_pajama_CPT_r1.yml | 3 - ...sume_3e-4_001_410M_slim_pajama_CPT_r10.yml | 3 - ...esume_3e-4_001_410M_slim_pajama_CPT_r5.yml | 3 - configs/load/resume_3e-4_001_7-1B_pile_PT.yml | 3 - 
configs/load/resume_6e-4_001_410M_143_CPT.yml | 3 - configs/load/scratch.yml | 3 - .../load/test_3e-5const_0_410M_143_CPT.yml | 3 - configs/load/wu_001_lr1-5e-4_pile.yml | 3 - configs/load/wu_001_lr3e-4_pile.yml | 3 - configs/load/wu_001_lr6e-4_pile.yml | 3 - configs/pythia_410m_llama_setup_finetune.yml | 24 ----- configs/pythia_410m_llama_setup_resume.yml | 24 ----- ...B_adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml | 16 --- .../adam_constant_lr3e-4_3e-4_wu-001.yml | 13 --- .../adam_constant_lr3e-5_3e-5_wu-0.yml | 13 --- .../adam_cosine-inf_lr3e-4_3e-5_wu-001.yml | 14 --- .../adam_cosine_lr1-2e-4_1-2e-5_wu-001.yml | 13 --- .../adam_cosine_lr1-5e-4_1-5e-5_wu-0.yml | 14 --- .../adam_cosine_lr1-5e-4_1-5e-5_wu-0005.yml | 13 --- .../adam_cosine_lr1-5e-4_1-5e-5_wu-001.yml | 13 --- .../adam_cosine_lr1-5e-4_1-5e-5_wu-002.yml | 13 --- .../adam_cosine_lr3e-4_3e-5_wu-0.yml | 13 --- .../adam_cosine_lr3e-4_3e-5_wu-0005.yml | 13 --- .../adam_cosine_lr3e-4_3e-5_wu-001.yml | 13 --- .../adam_cosine_lr3e-4_3e-5_wu-002.yml | 13 --- .../adam_cosine_lr6e-4_6e-5_wu-0.yml | 14 --- .../adam_cosine_lr6e-4_6e-5_wu-0005.yml | 13 --- .../adam_cosine_lr6e-4_6e-5_wu-001.yml | 14 --- .../adam_cosine_lr6e-4_6e-5_wu-002.yml | 13 --- .../adam_infcos_lr3e-4_3e-5_wu-001.yml | 17 --- .../adam_infinv_lr3e-4_3e-5_wu-001.yml | 17 --- .../adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml | 16 --- megatron/data/data_utils.py | 92 ---------------- megatron/neox_arguments/neox_args.py | 20 ---- megatron/training.py | 48 +++------ megatron_config_1.json | 1 - train.py | 40 ------- 102 files changed, 12 insertions(+), 2244 deletions(-) delete mode 100644 configs/datasets/train/pile+slim_pajama_300B_each.yml delete mode 100644 configs/datasets/train/pile_shard0.yml delete mode 100644 configs/datasets/train/pile_train.yml delete mode 100644 configs/datasets/train/rp.yml delete mode 100644 configs/datasets/train/slim_pajama_100B_1.yml delete mode 100644 configs/datasets/train/slim_pajama_100B_1_replay5.yml delete mode 100644 configs/datasets/train/slim_pajama_100B_2.yml delete mode 100644 configs/datasets/train/slim_pajama_100B_2_replay5.yml delete mode 100644 configs/datasets/train/slim_pajama_100B_3.yml delete mode 100644 configs/datasets/train/slim_pajama_100B_3_replay5.yml delete mode 100644 configs/datasets/train/slim_pajama_150B.yml delete mode 100644 configs/datasets/train/slim_pajama_200B_1.yml delete mode 100644 configs/datasets/train/slim_pajama_200B_1_replay5.yml delete mode 100644 configs/datasets/train/slim_pajama_200B_2.yml delete mode 100644 configs/datasets/train/slim_pajama_200B_2_replay5.yml delete mode 100644 configs/datasets/train/slim_pajama_200B_3.yml delete mode 100644 configs/datasets/train/slim_pajama_200B_3_replay5.yml delete mode 100644 configs/datasets/train/slim_pajama_300B.yml delete mode 100644 configs/datasets/train/slim_pajama_300B_50_replay.yml delete mode 100644 configs/datasets/train/slim_pajama_300B_replay0-5.yml delete mode 100644 configs/datasets/train/slim_pajama_300B_replay1.yml delete mode 100644 configs/datasets/train/slim_pajama_300B_replay10.yml delete mode 100644 configs/datasets/train/slim_pajama_300B_replay5.yml delete mode 100644 configs/datasets/train/slim_pajama_300B_replay50.yml delete mode 100644 configs/datasets/train/slim_pajama_606B.yml delete mode 100644 configs/datasets/train/slim_pajama_75B.yml delete mode 100644 configs/datasets/train/slim_pajama_workshop.yml delete mode 100644 configs/datasets/train/taskwise/ArXiv.yml delete mode 100644 configs/datasets/train/taskwise/Book.yml delete mode 
100644 configs/datasets/train/taskwise/C4.yml delete mode 100644 configs/datasets/train/taskwise/CommonCrawl.yml delete mode 100644 configs/datasets/train/taskwise/Github.yml delete mode 100644 configs/datasets/train/taskwise/StackExchange.yml delete mode 100644 configs/datasets/train/taskwise/Wikipedia.yml delete mode 100644 configs/datasets/val/pile_german.yml delete mode 100644 configs/datasets/val/pile_rp-no-se.yml delete mode 100644 configs/datasets/val/pile_rp.yml delete mode 100644 configs/datasets/val/pile_rp_subsets.yml delete mode 100644 configs/datasets/val/pile_slimp.yml delete mode 100644 configs/datasets/val/pile_slimp_domains.yml delete mode 100644 configs/datasets/val/pile_slimp_workshop.yml delete mode 100644 configs/datasets/val/pile_val.yml delete mode 100644 configs/datasets/val/pile_val_shard0.yml delete mode 100644 configs/iclr_models/3b_test.yml delete mode 100644 configs/iclr_models/410M.yml delete mode 100644 configs/iclr_models/410M_ckpt_100.yml delete mode 100644 configs/iclr_models/49M.yml delete mode 100644 configs/iclr_models/7_1B.yml delete mode 100644 configs/load/3e-5const_0_410M_143_CPT.yml delete mode 100644 configs/load/none.yml delete mode 100644 configs/load/pythia_2-8B_143000.yml delete mode 100644 configs/load/pythia_410m.yml delete mode 100644 configs/load/pythia_410m_10000.yml delete mode 100644 configs/load/pythia_410m_143000.yml delete mode 100644 configs/load/pythia_410m_27000.yml delete mode 100644 configs/load/pythia_6-9B_143000.yml delete mode 100644 configs/load/pythia_deduped_410m_10000.yml delete mode 100644 configs/load/pythia_deduped_410m_143000.yml delete mode 100644 configs/load/pythia_deduped_410m_27000.yml delete mode 100644 configs/load/resume_1-2e-4_001_7-1B_pile_PT.yml delete mode 100644 configs/load/resume_1-2e-4_001_7-1B_slim_pajama_CPT.yml delete mode 100644 configs/load/resume_1-2e-4_001_7-1B_slim_pajama_PT.yml delete mode 100644 configs/load/resume_1-5e-4_001_410M_143_CPT.yml delete mode 100644 configs/load/resume_3e-4_001_410M_143_CPT.yml delete mode 100644 configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r05.yml delete mode 100644 configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r1.yml delete mode 100644 configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r10.yml delete mode 100644 configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r5.yml delete mode 100644 configs/load/resume_3e-4_001_7-1B_pile_PT.yml delete mode 100644 configs/load/resume_6e-4_001_410M_143_CPT.yml delete mode 100644 configs/load/scratch.yml delete mode 100644 configs/load/test_3e-5const_0_410M_143_CPT.yml delete mode 100644 configs/load/wu_001_lr1-5e-4_pile.yml delete mode 100644 configs/load/wu_001_lr3e-4_pile.yml delete mode 100644 configs/load/wu_001_lr6e-4_pile.yml delete mode 100644 configs/pythia_410m_llama_setup_finetune.yml delete mode 100644 configs/pythia_410m_llama_setup_resume.yml delete mode 100644 configs/schedules/7_1B_adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml delete mode 100644 configs/schedules/adam_constant_lr3e-4_3e-4_wu-001.yml delete mode 100644 configs/schedules/adam_constant_lr3e-5_3e-5_wu-0.yml delete mode 100644 configs/schedules/adam_cosine-inf_lr3e-4_3e-5_wu-001.yml delete mode 100644 configs/schedules/adam_cosine_lr1-2e-4_1-2e-5_wu-001.yml delete mode 100644 configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-0.yml delete mode 100644 configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-0005.yml delete mode 100644 configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-001.yml delete mode 100644 
configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-002.yml delete mode 100644 configs/schedules/adam_cosine_lr3e-4_3e-5_wu-0.yml delete mode 100644 configs/schedules/adam_cosine_lr3e-4_3e-5_wu-0005.yml delete mode 100644 configs/schedules/adam_cosine_lr3e-4_3e-5_wu-001.yml delete mode 100644 configs/schedules/adam_cosine_lr3e-4_3e-5_wu-002.yml delete mode 100644 configs/schedules/adam_cosine_lr6e-4_6e-5_wu-0.yml delete mode 100644 configs/schedules/adam_cosine_lr6e-4_6e-5_wu-0005.yml delete mode 100644 configs/schedules/adam_cosine_lr6e-4_6e-5_wu-001.yml delete mode 100644 configs/schedules/adam_cosine_lr6e-4_6e-5_wu-002.yml delete mode 100644 configs/schedules/adam_infcos_lr3e-4_3e-5_wu-001.yml delete mode 100644 configs/schedules/adam_infinv_lr3e-4_3e-5_wu-001.yml delete mode 100644 configs/schedules/adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml delete mode 100644 megatron_config_1.json diff --git a/configs/datasets/train/pile+slim_pajama_300B_each.yml b/configs/datasets/train/pile+slim_pajama_300B_each.yml deleted file mode 100644 index 94f14fd58..000000000 --- a/configs/datasets/train/pile+slim_pajama_300B_each.yml +++ /dev/null @@ -1,25 +0,0 @@ -{ - # This will sample with equal likelihood Pile and SlimPajama: - "train-data-paths": [ - "data/pile/train/pile_train", - 'data/slim_pajama/train_300B/ArXiv/ArXiv', - 'data/slim_pajama/train_300B/Book/Book', - 'data/slim_pajama/train_300B/C4/C4', - 'data/slim_pajama/train_300B/Wikipedia/Wikipedia', - 'data/slim_pajama/train_300B/Github/Github', - 'data/slim_pajama/train_300B/StackExchange/StackExchange', - 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl',], - "train-data-weights": [ - 50.0, - 2.2140923205, - 2.101565663, - 13.344249736, - 1.9986465625, - 2.612070528, - 1.6855393625, - 26.0438358255 - ], - "train-dataset-name": 'pile+slim_pajama_300B_each', - "train-iters": 264732, - "lr-decay-iters": 264732, -} \ No newline at end of file diff --git a/configs/datasets/train/pile_shard0.yml b/configs/datasets/train/pile_shard0.yml deleted file mode 100644 index 421fd2804..000000000 --- a/configs/datasets/train/pile_shard0.yml +++ /dev/null @@ -1,33 +0,0 @@ -{ - "train-data-paths": [ - "data/pile/shard_0/shard_0_text_document", - ], - "train-data-weights": [ - 1., - ], - "train-dataset-name": 'pile_shard0', - "train-iters": 1000, - "lr-decay-iters": 1000, - "is_replay_enabled": true, - "replay_config": { - "enabled": true, - # Have to specify idx filenames from original pretraining on tasks, as they contain the num iterations - # and seen indices assuming we're using the same (non-replay) seed as during pretraining - "replay_idx_paths_prefixes": [ - "data/pile/shard_0/shard_0_text_document_train_0_indexmap_32160ns_2048sl_1234s", - ], - "replay_data_weights":[ - 1.00, - ], - "replay_idx_offsets": [ - 1, - ], - # Fraction of samples coming from the replay buffer, between 0 and 1. - "replay_fraction": 0.5, - # Seed and reshuffle go hand in hand. They control whether you want to see the replay data in the same order - # as you've seen it (done by setting reshuffle to false), and if you decide to reshuffle, what seed you should - # use to reshuffle the seen data. 
- "replay_seed": 1234, - "replay_reshuffle_idx": false, - }, -} \ No newline at end of file diff --git a/configs/datasets/train/pile_train.yml b/configs/datasets/train/pile_train.yml deleted file mode 100644 index ff37e2c74..000000000 --- a/configs/datasets/train/pile_train.yml +++ /dev/null @@ -1,11 +0,0 @@ -{ - "train-data-paths": [ - "data/pile/train/pile_train", - ], - "train-data-weights": [ - 1., - ], - "train-dataset-name": 'pile_train', - "train-iters": 132366, - "lr-decay-iters": 132366, -} \ No newline at end of file diff --git a/configs/datasets/train/rp.yml b/configs/datasets/train/rp.yml deleted file mode 100644 index cfea01a64..000000000 --- a/configs/datasets/train/rp.yml +++ /dev/null @@ -1,32 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/arxiv/folder_train/tokenized_text_document", - "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/book/folder_train/tokenized_text_document", - "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/c4/folder_train/tokenized_text_document", - "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/wikipedia/folder_train/tokenized_text_document", - "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/github/folder_train/tokenized_text_document", - "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/stackexchange/folder_train/tokenized_text_document", - "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/common_crawl/2019-30/folder_train/tokenized_text_document", - "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/common_crawl/2020-05/folder_train/tokenized_text_document", - "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/common_crawl/2021-04/folder_train/tokenized_text_document", - "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/common_crawl/2022-05/folder_train/tokenized_text_document", - "/gpfs/alpine/csc499/proj-shared/incite_datasets/SlimPajama/tokenized300B/train_splits/common_crawl/2023-06/folder_train/tokenized_text_document", - ], - "train-data-weights": [ - 2.5, - 4.5, - 15.0, - 4.5, - 4.5, - 2.0, - 13.4, - 13.4, - 13.4, - 13.4, - 13.4 - ], - "train-dataset-name": 'rp', - - -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_100B_1.yml b/configs/datasets/train/slim_pajama_100B_1.yml deleted file mode 100644 index 9bbf75249..000000000 --- a/configs/datasets/train/slim_pajama_100B_1.yml +++ /dev/null @@ -1,24 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/tokenized_train_0-100B/ArXiv/ArXiv', - 'data/slim_pajama/tokenized_train_0-100B/Book/Book', - 'data/slim_pajama/tokenized_train_0-100B/C4/C4', - 'data/slim_pajama/tokenized_train_0-100B/Wikipedia/Wikipedia', - 'data/slim_pajama/tokenized_train_0-100B/Github/Github', - 'data/slim_pajama/tokenized_train_0-100B/StackExchange/StackExchange', - 'data/slim_pajama/tokenized_train_0-100B/CommonCrawl/CommonCrawl', - ], - "train-data-weights": [ - 3.4703977435152775, - 3.904381603212791, - 25.641950653802013, - 3.804228253591696, - 4.9994643949282045, - 3.1815838172641993, - 49.99799353368582, - ], - "train-iters": 44229, - "lr-decay-iters": 44229, - "train-dataset-name": 'slim_pajama_100B_1', -} \ No newline 
at end of file diff --git a/configs/datasets/train/slim_pajama_100B_1_replay5.yml b/configs/datasets/train/slim_pajama_100B_1_replay5.yml deleted file mode 100644 index 2d957c3e5..000000000 --- a/configs/datasets/train/slim_pajama_100B_1_replay5.yml +++ /dev/null @@ -1,28 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/tokenized_train_0-100B/ArXiv/ArXiv', - 'data/slim_pajama/tokenized_train_0-100B/Book/Book', - 'data/slim_pajama/tokenized_train_0-100B/C4/C4', - 'data/slim_pajama/tokenized_train_0-100B/Wikipedia/Wikipedia', - 'data/slim_pajama/tokenized_train_0-100B/Github/Github', - 'data/slim_pajama/tokenized_train_0-100B/StackExchange/StackExchange', - 'data/slim_pajama/tokenized_train_0-100B/CommonCrawl/CommonCrawl', - - 'data/pile_replay_shards/replay_10B_1/splits', - ], - "train-data-weights": [ - 3.4703977435152775, - 3.904381603212791, - 25.641950653802013, - 3.804228253591696, - 4.9994643949282045, - 3.1815838172641993, - 49.99799353368582, - - 5.0 - ], - "train-iters": 44229, - "lr-decay-iters": 44229, - "train-dataset-name": 'slim_pajama_100B_1_replay5', -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_100B_2.yml b/configs/datasets/train/slim_pajama_100B_2.yml deleted file mode 100644 index f59f56860..000000000 --- a/configs/datasets/train/slim_pajama_100B_2.yml +++ /dev/null @@ -1,24 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/tokenized_train_100B-200B/ArXiv/ArXiv', - 'data/slim_pajama/tokenized_train_100B-200B/Book/Book', - 'data/slim_pajama/tokenized_train_100B-200B/C4/C4', - 'data/slim_pajama/tokenized_train_100B-200B/Wikipedia/Wikipedia', - 'data/slim_pajama/tokenized_train_100B-200B/Github/Github', - 'data/slim_pajama/tokenized_train_100B-200B/StackExchange/StackExchange', - 'data/slim_pajama/tokenized_train_100B-200B/CommonCrawl/CommonCrawl', - ], - "train-data-weights": [ - 4.03666599074094, - 3.927523855378127, - 25.467175464208918, - 3.7984379710376293, - 4.990226864678155, - 3.1957646326079723, - 49.58420522134826, - ], - "train-iters": 44229, - "lr-decay-iters": 44229, - "train-dataset-name": 'slim_pajama_100B_2', -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_100B_2_replay5.yml b/configs/datasets/train/slim_pajama_100B_2_replay5.yml deleted file mode 100644 index ce1ba6f62..000000000 --- a/configs/datasets/train/slim_pajama_100B_2_replay5.yml +++ /dev/null @@ -1,45 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/tokenized_train_100B-200B/ArXiv/ArXiv', - 'data/slim_pajama/tokenized_train_100B-200B/Book/Book', - 'data/slim_pajama/tokenized_train_100B-200B/C4/C4', - 'data/slim_pajama/tokenized_train_100B-200B/Wikipedia/Wikipedia', - 'data/slim_pajama/tokenized_train_100B-200B/Github/Github', - 'data/slim_pajama/tokenized_train_100B-200B/StackExchange/StackExchange', - 'data/slim_pajama/tokenized_train_100B-200B/CommonCrawl/CommonCrawl', - - 'data/pile_replay_shards/replay_10B_2/splits', - - 'data/sp_replay_shards/100B_1_shard1/ArXiv/ArXiv', - 'data/sp_replay_shards/100B_1_shard1/Book/Book', - 'data/sp_replay_shards/100B_1_shard1/C4/C4', - 'data/sp_replay_shards/100B_1_shard1/Wikipedia/Wikipedia', - 'data/sp_replay_shards/100B_1_shard1/Github/Github', - 'data/sp_replay_shards/100B_1_shard1/StackExchange/StackExchange', - 'data/sp_replay_shards/100B_1_shard1/CommonCrawl/CommonCrawl', - ], - "train-data-weights": [ - 4.03666599074094, - 3.927523855378127, - 25.467175464208918, - 3.7984379710376293, 
- 4.990226864678155, - 3.1957646326079723, - 49.58420522134826, - - 3.8125, - - # total: 1.1875, - 0.04337997179394097, - 0.04880477004015989, - 0.3205243831725252, - 0.0475528531698962, - 0.06249330493660256, - 0.03976979771580249, - 0.6249749191710727, - ], - "train-iters": 44229, - "lr-decay-iters": 44229, - "train-dataset-name": 'slim_pajama_100B_2_replay5', -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_100B_3.yml b/configs/datasets/train/slim_pajama_100B_3.yml deleted file mode 100644 index 6cf015267..000000000 --- a/configs/datasets/train/slim_pajama_100B_3.yml +++ /dev/null @@ -1,24 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/tokenized_train_200B-300B/ArXiv/ArXiv', - 'data/slim_pajama/tokenized_train_200B-300B/Book/Book', - 'data/slim_pajama/tokenized_train_200B-300B/C4/C4', - 'data/slim_pajama/tokenized_train_200B-300B/Wikipedia/Wikipedia', - 'data/slim_pajama/tokenized_train_200B-300B/Github/Github', - 'data/slim_pajama/tokenized_train_200B-300B/StackExchange/StackExchange', - 'data/slim_pajama/tokenized_train_200B-300B/CommonCrawl/CommonCrawl', - ], - "train-data-weights": [ - 3.491756366873565, - 4.084283062119696, - 25.524317038754475, - 3.8109321899190314, - 4.89534056131328, - 3.254459546224121, - 49.93891123479581, - ], - "train-iters": 44229, - "lr-decay-iters": 44229, - "train-dataset-name": 'slim_pajama_100B_3', -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_100B_3_replay5.yml b/configs/datasets/train/slim_pajama_100B_3_replay5.yml deleted file mode 100644 index 9f537c532..000000000 --- a/configs/datasets/train/slim_pajama_100B_3_replay5.yml +++ /dev/null @@ -1,61 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/tokenized_train_200B-300B/ArXiv/ArXiv', - 'data/slim_pajama/tokenized_train_200B-300B/Book/Book', - 'data/slim_pajama/tokenized_train_200B-300B/C4/C4', - 'data/slim_pajama/tokenized_train_200B-300B/Wikipedia/Wikipedia', - 'data/slim_pajama/tokenized_train_200B-300B/Github/Github', - 'data/slim_pajama/tokenized_train_200B-300B/StackExchange/StackExchange', - 'data/slim_pajama/tokenized_train_200B-300B/CommonCrawl/CommonCrawl', - - 'data/pile_replay_shards/replay_10B_3/splits', - - 'data/sp_replay_shards/100B_1_shard2/ArXiv/ArXiv', - 'data/sp_replay_shards/100B_1_shard2/Book/Book', - 'data/sp_replay_shards/100B_1_shard2/C4/C4', - 'data/sp_replay_shards/100B_1_shard2/Wikipedia/Wikipedia', - 'data/sp_replay_shards/100B_1_shard2/Github/Github', - 'data/sp_replay_shards/100B_1_shard2/StackExchange/StackExchange', - 'data/sp_replay_shards/100B_1_shard2/CommonCrawl/CommonCrawl', - - 'data/sp_replay_shards/100B_2_shard1/ArXiv/ArXiv', - 'data/sp_replay_shards/100B_2_shard1/Book/Book', - 'data/sp_replay_shards/100B_2_shard1/C4/C4', - 'data/sp_replay_shards/100B_2_shard1/Wikipedia/Wikipedia', - 'data/sp_replay_shards/100B_2_shard1/Github/Github', - 'data/sp_replay_shards/100B_2_shard1/StackExchange/StackExchange', - 'data/sp_replay_shards/100B_2_shard1/CommonCrawl/CommonCrawl', - ], - "train-data-weights": [3.491756366873565, - 4.084283062119696, - 25.524317038754475, - 3.8109321899190314, - 4.89534056131328, - 3.254459546224121, - 49.93891123479581, - - 3.088125, - - # total: 0.961875, - 0.03513777715309219, - 0.03953186373252951, - 0.2596247503697454, - 0.03851781106761592, - 0.05061957699864807, - 0.03221353614980002, - 0.506229684528569, - - #total: 0.95, - 0.0403666599074094, - 0.03927523855378127, - 0.25467175464208913, - 
0.03798437971037629, - 0.049902268646781545, - 0.03195764632607972, - 0.4958420522134826, - ], - "train-iters": 44229, - "lr-decay-iters": 44229, - "train-dataset-name": 'slim_pajama_100B_3_replay5', -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_150B.yml b/configs/datasets/train/slim_pajama_150B.yml deleted file mode 100644 index 477334a4a..000000000 --- a/configs/datasets/train/slim_pajama_150B.yml +++ /dev/null @@ -1,23 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/train_150B/ArXiv/ArXiv', - 'data/slim_pajama/train_150B/Book/Book', - 'data/slim_pajama/train_150B/C4/C4', - 'data/slim_pajama/train_150B/Wikipedia/Wikipedia', - 'data/slim_pajama/train_150B/Github/Github', - 'data/slim_pajama/train_150B/StackExchange/StackExchange', - 'data/slim_pajama/train_150B/CommonCrawl/CommonCrawl',], - "train-data-weights": [ - 4.576447650075095, - 4.198505982426652, - 26.62982374026485, - 3.9945183507095225, - 5.218824282422116, - 3.372167199706489, - 52.00971279439528 - ], - "train-dataset-name": 'slim_pajama_150B', - "train-iters": 66342, - "lr-decay-iters": 66342, -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_200B_1.yml b/configs/datasets/train/slim_pajama_200B_1.yml deleted file mode 100644 index b242479a4..000000000 --- a/configs/datasets/train/slim_pajama_200B_1.yml +++ /dev/null @@ -1,24 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/tokenized_train_0-200B/ArXiv/ArXiv', - 'data/slim_pajama/tokenized_train_0-200B/Book/Book', - 'data/slim_pajama/tokenized_train_0-200B/C4/C4', - 'data/slim_pajama/tokenized_train_0-200B/Wikipedia/Wikipedia', - 'data/slim_pajama/tokenized_train_0-200B/Github/Github', - 'data/slim_pajama/tokenized_train_0-200B/StackExchange/StackExchange', - 'data/slim_pajama/tokenized_train_0-200B/CommonCrawl/CommonCrawl', - ], - "train-data-weights": [ - 3.4703977435152775, - 3.904381603212791, - 25.641950653802013, - 3.804228253591696, - 4.9994643949282045, - 3.1815838172641993, - 49.99799353368582, - ], - "train-iters": 88457, - "lr-decay-iters": 88457, - "train-dataset-name": 'slim_pajama_200B_1', -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_200B_1_replay5.yml b/configs/datasets/train/slim_pajama_200B_1_replay5.yml deleted file mode 100644 index 518391cd4..000000000 --- a/configs/datasets/train/slim_pajama_200B_1_replay5.yml +++ /dev/null @@ -1,23 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/tokenized_train_0-200B/ArXiv/ArXiv', - 'data/slim_pajama/tokenized_train_0-200B/Book/Book', - 'data/slim_pajama/tokenized_train_0-200B/C4/C4', - 'data/slim_pajama/tokenized_train_0-200B/Wikipedia/Wikipedia', - 'data/slim_pajama/tokenized_train_0-200B/Github/Github', - 'data/slim_pajama/tokenized_train_0-200B/StackExchange/StackExchange', - 'data/slim_pajama/tokenized_train_0-200B/CommonCrawl/CommonCrawl', - 'data/pile_replay_shards/replay_10B_1/splits',], - "train-data-weights": [3.4703977435152775, - 3.904381603212791, - 25.641950653802013, - 3.804228253591696, - 4.9994643949282045, - 3.1815838172641993, - 49.99799353368582, - 5.0], - "train-iters": 88457, - "lr-decay-iters": 88457, - "train-dataset-name": 'slim_pajama_200B_1_replay5', -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_200B_2.yml b/configs/datasets/train/slim_pajama_200B_2.yml deleted file mode 100644 index 753831a72..000000000 --- a/configs/datasets/train/slim_pajama_200B_2.yml +++ 
/dev/null @@ -1,24 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/tokenized_train_200B-400B/ArXiv/ArXiv', - 'data/slim_pajama/tokenized_train_200B-400B/Book/Book', - 'data/slim_pajama/tokenized_train_200B-400B/C4/C4', - 'data/slim_pajama/tokenized_train_200B-400B/Wikipedia/Wikipedia', - 'data/slim_pajama/tokenized_train_200B-400B/Github/Github', - 'data/slim_pajama/tokenized_train_200B-400B/StackExchange/StackExchange', - 'data/slim_pajama/tokenized_train_200B-400B/CommonCrawl/CommonCrawl', - ], - "train-data-weights": [ - 4.3217120887898215, - 4.0000865058486115, - 25.313223824892418, - 3.7875979441876595, - 4.916178735276899, - 3.205115989375923, - 49.45608491162867 - ], - "train-iters": 88457, - "lr-decay-iters": 88457, - "train-dataset-name": 'slim_pajama_200B_2', -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_200B_2_replay5.yml b/configs/datasets/train/slim_pajama_200B_2_replay5.yml deleted file mode 100644 index b71420699..000000000 --- a/configs/datasets/train/slim_pajama_200B_2_replay5.yml +++ /dev/null @@ -1,45 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/tokenized_train_200B-400B/ArXiv/ArXiv', - 'data/slim_pajama/tokenized_train_200B-400B/Book/Book', - 'data/slim_pajama/tokenized_train_200B-400B/C4/C4', - 'data/slim_pajama/tokenized_train_200B-400B/Wikipedia/Wikipedia', - 'data/slim_pajama/tokenized_train_200B-400B/Github/Github', - 'data/slim_pajama/tokenized_train_200B-400B/StackExchange/StackExchange', - 'data/slim_pajama/tokenized_train_200B-400B/CommonCrawl/CommonCrawl', - - 'data/pile_replay_shards/replay_10B_1/splits', - - 'data/sp_replay_shards/200B_1_shard1/ArXiv/ArXiv', - 'data/sp_replay_shards/200B_1_shard1/Book/Book', - 'data/sp_replay_shards/200B_1_shard1/C4/C4', - 'data/sp_replay_shards/200B_1_shard1/Wikipedia/Wikipedia', - 'data/sp_replay_shards/200B_1_shard1/Github/Github', - 'data/sp_replay_shards/200B_1_shard1/StackExchange/StackExchange', - 'data/sp_replay_shards/200B_1_shard1/CommonCrawl/CommonCrawl', - ], - "train-data-weights": [ - 4.3217120887898215, - 4.0000865058486115, - 25.313223824892418, - 3.7875979441876595, - 4.916178735276899, - 3.205115989375923, - 49.45608491162867, - - 3.8125, - - #total 1.1875 - 0.04337997179394097, - 0.04880477004015989, - 0.3205243831725252, - 0.0475528531698962, - 0.06249330493660256, - 0.03976979771580249, - 0.6249749191710727, - ], - "train-iters": 88457, - "lr-decay-iters": 88457, - "train-dataset-name": 'slim_pajama_200B_2_replay5', -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_200B_3.yml b/configs/datasets/train/slim_pajama_200B_3.yml deleted file mode 100644 index 9d36e7c4d..000000000 --- a/configs/datasets/train/slim_pajama_200B_3.yml +++ /dev/null @@ -1,24 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/tokenized_train_400B-600B/ArXiv/ArXiv', - 'data/slim_pajama/tokenized_train_400B-600B/Book/Book', - 'data/slim_pajama/tokenized_train_400B-600B/C4/C4', - 'data/slim_pajama/tokenized_train_400B-600B/Wikipedia/Wikipedia', - 'data/slim_pajama/tokenized_train_400B-600B/Github/Github', - 'data/slim_pajama/tokenized_train_400B-600B/StackExchange/StackExchange', - 'data/slim_pajama/tokenized_train_400B-600B/CommonCrawl/CommonCrawl', - ], - "train-data-weights": [ - 4.3217120887898215, - 4.0000865058486115, - 25.313223824892418, - 3.7875979441876595, - 4.916178735276899, - 3.205115989375923, - 49.45608491162867 - ], - "train-iters": 88457, - 
"lr-decay-iters": 88457, - "train-dataset-name": 'slim_pajama_200B_3', -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_200B_3_replay5.yml b/configs/datasets/train/slim_pajama_200B_3_replay5.yml deleted file mode 100644 index cdcfcefad..000000000 --- a/configs/datasets/train/slim_pajama_200B_3_replay5.yml +++ /dev/null @@ -1,62 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/tokenized_train_400B-600B/ArXiv/ArXiv', - 'data/slim_pajama/tokenized_train_400B-600B/Book/Book', - 'data/slim_pajama/tokenized_train_400B-600B/C4/C4', - 'data/slim_pajama/tokenized_train_400B-600B/Wikipedia/Wikipedia', - 'data/slim_pajama/tokenized_train_400B-600B/Github/Github', - 'data/slim_pajama/tokenized_train_400B-600B/StackExchange/StackExchange', - 'data/slim_pajama/tokenized_train_400B-600B/CommonCrawl/CommonCrawl', - - 'data/pile_replay_shards/replay_10B_3/splits', - - 'data/sp_replay_shards/200B_1_shard2/ArXiv/ArXiv', - 'data/sp_replay_shards/200B_1_shard2/Book/Book', - 'data/sp_replay_shards/200B_1_shard2/C4/C4', - 'data/sp_replay_shards/200B_1_shard2/Wikipedia/Wikipedia', - 'data/sp_replay_shards/200B_1_shard2/Github/Github', - 'data/sp_replay_shards/200B_1_shard2/StackExchange/StackExchange', - 'data/sp_replay_shards/200B_1_shard2/CommonCrawl/CommonCrawl', - - 'data/sp_replay_shards/200B_2_shard1/ArXiv/ArXiv', - 'data/sp_replay_shards/200B_2_shard1/Book/Book', - 'data/sp_replay_shards/200B_2_shard1/C4/C4', - 'data/sp_replay_shards/200B_2_shard1/Wikipedia/Wikipedia', - 'data/sp_replay_shards/200B_2_shard1/Github/Github', - 'data/sp_replay_shards/200B_2_shard1/StackExchange/StackExchange', - 'data/sp_replay_shards/200B_2_shard1/CommonCrawl/CommonCrawl', - ], - "train-data-weights": [ - 4.3217120887898215, - 4.0000865058486115, - 25.313223824892418, - 3.7875979441876595, - 4.916178735276899, - 3.205115989375923, - 49.45608491162867, - - 3.088125, - - #total: 0.961875, - 0.03513777715309219, - 0.03953186373252951, - 0.2596247503697454, - 0.03851781106761592, - 0.05061957699864807, - 0.03221353614980002, - 0.506229684528569, - - #total: 0.95, - 0.043217120887898204, - 0.0400008650584861, - 0.2531322382489241, - 0.03787597944187659, - 0.04916178735276898, - 0.03205115989375923, - 0.49456084911628667, - ], - "train-iters": 88457, - "lr-decay-iters": 88457, - "train-dataset-name": 'slim_pajama_200B_3_replay5', -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_300B.yml b/configs/datasets/train/slim_pajama_300B.yml deleted file mode 100644 index 1bc5be68d..000000000 --- a/configs/datasets/train/slim_pajama_300B.yml +++ /dev/null @@ -1,23 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/train_300B/ArXiv/ArXiv', - 'data/slim_pajama/train_300B/Book/Book', - 'data/slim_pajama/train_300B/C4/C4', - 'data/slim_pajama/train_300B/Wikipedia/Wikipedia', - 'data/slim_pajama/train_300B/Github/Github', - 'data/slim_pajama/train_300B/StackExchange/StackExchange', - 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl',], - "train-data-weights": [ - 4.428184641, - 4.203131326, - 26.688499472, - 3.997293125, - 5.224141056, - 3.371078725, - 52.087671651 - ], - "train-iters": 132366, - "lr-decay-iters": 132366, - "train-dataset-name": 'slim_pajama_300B', -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_300B_50_replay.yml b/configs/datasets/train/slim_pajama_300B_50_replay.yml deleted file mode 100644 index 841230d2a..000000000 --- 
a/configs/datasets/train/slim_pajama_300B_50_replay.yml +++ /dev/null @@ -1,45 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/train_300B/ArXiv/ArXiv', - 'data/slim_pajama/train_300B/Book/Book', - 'data/slim_pajama/train_300B/C4/C4', - 'data/slim_pajama/train_300B/Wikipedia/Wikipedia', - 'data/slim_pajama/train_300B/Github/Github', - 'data/slim_pajama/train_300B/StackExchange/StackExchange', - 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl',], - "train-data-weights": [ - 4.428184641, - 4.203131326, - 26.688499472, - 3.997293125, - 5.224141056, - 3.371078725, - 52.087671651 - ], - "train-iters": 132366, - "lr-decay-iters": 132366, - "train-dataset-name": 'slim_pajama_300B', - - "replay_config": { - "enabled": true, - # Have to specify idx filenames from original pretraining on tasks, as they contain the num iterations - # and seen indices assuming we're using the same (non-replay) seed as during pretraining - "replay_idx_paths_prefixes": [ - "data/pile/train/pile_train_train_0_indexmap_146862725ns_2048sl_1234s", - ], - "replay_data_weights":[ - 1.00, - ], - "replay_idx_offsets": [ - 0, - ], - # Fraction of samples coming from the replay buffer, between 0 and 1. - "replay_fraction": 0.5, - # Will need to reshuffle the shuffled indices. If you have replay multiple times on the same task, don't - # forget to change it every time if not manually managing offsets ! Otherwise you will see the same replay - # buffer in the same order. - "replay_seed": 1234, - "replay_reshuffle_idx": true, - }, -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_300B_replay0-5.yml b/configs/datasets/train/slim_pajama_300B_replay0-5.yml deleted file mode 100644 index 6069b4cd9..000000000 --- a/configs/datasets/train/slim_pajama_300B_replay0-5.yml +++ /dev/null @@ -1,23 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/train_300B/ArXiv/ArXiv', - 'data/slim_pajama/train_300B/Book/Book', - 'data/slim_pajama/train_300B/C4/C4', - 'data/slim_pajama/train_300B/Wikipedia/Wikipedia', - 'data/slim_pajama/train_300B/Github/Github', - 'data/slim_pajama/train_300B/StackExchange/StackExchange', - 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl', - 'data/pile_replay_shards/replay_1-5B/splits',], - "train-data-weights": [4.406043717971242, - 4.182115669537286, - 26.555056975702207, - 3.977306659534093, - 5.198020350927921, - 3.354223331509169, - 51.82723329481809, - 0.5], - "train-iters": 132366, - "lr-decay-iters": 132366, - "train-dataset-name": 'slim_pajama_300B_replay05', -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_300B_replay1.yml b/configs/datasets/train/slim_pajama_300B_replay1.yml deleted file mode 100644 index 6983c3a27..000000000 --- a/configs/datasets/train/slim_pajama_300B_replay1.yml +++ /dev/null @@ -1,23 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/train_300B/ArXiv/ArXiv', - 'data/slim_pajama/train_300B/Book/Book', - 'data/slim_pajama/train_300B/C4/C4', - 'data/slim_pajama/train_300B/Wikipedia/Wikipedia', - 'data/slim_pajama/train_300B/Github/Github', - 'data/slim_pajama/train_300B/StackExchange/StackExchange', - 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl', - 'data/pile_replay_shards/replay_3B/splits',], - "train-data-weights": [4.383902794765357, - 4.161100012906445, - 26.42161447833687, - 3.9573201939082936, - 5.171899645646876, - 3.337367937883495, - 51.566794936552675, - 1.0], - "train-iters": 132366, - "lr-decay-iters": 
132366, - "train-dataset-name": 'slim_pajama_300B_replay1', -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_300B_replay10.yml b/configs/datasets/train/slim_pajama_300B_replay10.yml deleted file mode 100644 index a85ede531..000000000 --- a/configs/datasets/train/slim_pajama_300B_replay10.yml +++ /dev/null @@ -1,23 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/train_300B/ArXiv/ArXiv', - 'data/slim_pajama/train_300B/Book/Book', - 'data/slim_pajama/train_300B/C4/C4', - 'data/slim_pajama/train_300B/Wikipedia/Wikipedia', - 'data/slim_pajama/train_300B/Github/Github', - 'data/slim_pajama/train_300B/StackExchange/StackExchange', - 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl', - 'data/pile_replay_shards/replay_30B/splits',], - "train-data-weights": [3.985366177059415, - 3.7828181935513134, - 24.01964952576079, - 3.5975638126439033, - 4.701726950588069, - 3.033970852621359, - 46.87890448777516, - 10.0], - "train-iters": 132366, - "lr-decay-iters": 132366, - "train-dataset-name": 'slim_pajama_300B_replay10', -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_300B_replay5.yml b/configs/datasets/train/slim_pajama_300B_replay5.yml deleted file mode 100644 index c3e1b5713..000000000 --- a/configs/datasets/train/slim_pajama_300B_replay5.yml +++ /dev/null @@ -1,23 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/train_300B/ArXiv/ArXiv', - 'data/slim_pajama/train_300B/Book/Book', - 'data/slim_pajama/train_300B/C4/C4', - 'data/slim_pajama/train_300B/Wikipedia/Wikipedia', - 'data/slim_pajama/train_300B/Github/Github', - 'data/slim_pajama/train_300B/StackExchange/StackExchange', - 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl', - 'data/pile_replay_shards/replay_15B/splits',], - "train-data-weights": [4.206775409118271, - 3.99297475985972, - 25.354074499414168, - 3.797428468901898, - 4.962934003398518, - 3.2025247888781014, - 49.48328807042934, - 5.0], - "train-iters": 132366, - "lr-decay-iters": 132366, - "train-dataset-name": 'slim_pajama_300B_replay5', -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_300B_replay50.yml b/configs/datasets/train/slim_pajama_300B_replay50.yml deleted file mode 100644 index 372a4ad7b..000000000 --- a/configs/datasets/train/slim_pajama_300B_replay50.yml +++ /dev/null @@ -1,24 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/train_300B/ArXiv/ArXiv', - 'data/slim_pajama/train_300B/Book/Book', - 'data/slim_pajama/train_300B/C4/C4', - 'data/slim_pajama/train_300B/Wikipedia/Wikipedia', - 'data/slim_pajama/train_300B/Github/Github', - 'data/slim_pajama/train_300B/StackExchange/StackExchange', - 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl', - 'data/pile/train/pile_train',], - "train-data-weights": [ - 2.2140923205885636, - 2.101565663084063, - 13.344249736533772, - 1.9986465625799463, - 2.612070528104483, - 1.6855393625674218, - 26.043835826541756, - 50.0], - "train-iters": 132366, - "lr-decay-iters": 132366, - "train-dataset-name": 'slim_pajama_300B_replay50', -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_606B.yml b/configs/datasets/train/slim_pajama_606B.yml deleted file mode 100644 index 159b382f5..000000000 --- a/configs/datasets/train/slim_pajama_606B.yml +++ /dev/null @@ -1,23 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/train_606B/ArXiv/ArXiv', - 'data/slim_pajama/train_606B/Book/Book', - 
'data/slim_pajama/train_606B/C4/C4', - 'data/slim_pajama/train_606B/Wikipedia/Wikipedia', - 'data/slim_pajama/train_606B/Github/Github', - 'data/slim_pajama/train_606B/StackExchange/StackExchange', - 'data/slim_pajama/train_606B/CommonCrawl/CommonCrawl',], - "train-data-weights": [ - 4.576447650075095, - 4.198505982426652, - 26.62982374026485, - 3.9945183507095225, - 5.218824282422116, - 3.372167199706489, - 52.00971279439528 - ], - "train-dataset-name": 'slim_pajama_606B', - "train-iters": 268023, - "lr-decay-iters": 268023, -} \ No newline at end of file diff --git a/configs/datasets/train/slim_pajama_75B.yml b/configs/datasets/train/slim_pajama_75B.yml deleted file mode 100644 index c8195e881..000000000 --- a/configs/datasets/train/slim_pajama_75B.yml +++ /dev/null @@ -1,24 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/train_75B/ArXiv/ArXiv', - 'data/slim_pajama/train_75B/Book/Book', - 'data/slim_pajama/train_75B/C4/C4', - 'data/slim_pajama/train_75B/Wikipedia/Wikipedia', - 'data/slim_pajama/train_75B/Github/Github', - 'data/slim_pajama/train_75B/StackExchange/StackExchange', - 'data/slim_pajama/train_75B/CommonCrawl/CommonCrawl',], - "train-data-weights": [ - 4.576447650075095, - 4.198505982426652, - 26.62982374026485, - 3.9945183507095225, - 5.218824282422116, - 3.372167199706489, - 52.00971279439528 - ], - "train-dataset-name": 'slim_pajama_75B', - "train-iters": 33171, - "lr-decay-iters": 33171, -} - diff --git a/configs/datasets/train/slim_pajama_workshop.yml b/configs/datasets/train/slim_pajama_workshop.yml deleted file mode 100644 index 894be1f06..000000000 --- a/configs/datasets/train/slim_pajama_workshop.yml +++ /dev/null @@ -1,55 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/tokenized300B/train_splits/ArXiv/tokenized_text_document', - 'data/tokenized300B/train_splits/C4/chunk1/tokenized_text_document', - 'data/tokenized300B/train_splits/C4/chunk2/tokenized_text_document', - 'data/tokenized300B/train_splits/C4/chunk3/tokenized_text_document', - 'data/tokenized300B/train_splits/C4/chunk4/tokenized_text_document', - 'data/tokenized300B/train_splits/C4/chunk5/tokenized_text_document', - 'data/tokenized300B/train_splits/C4/chunk6/tokenized_text_document', - 'data/tokenized300B/train_splits/C4/chunk7/tokenized_text_document', - 'data/tokenized300B/train_splits/C4/chunk8/tokenized_text_document', - 'data/tokenized300B/train_splits/C4/chunk9/tokenized_text_document', - 'data/tokenized300B/train_splits/C4/chunk10/tokenized_text_document', - 'data/tokenized300B/train_splits/Github/tokenized_text_document', - 'data/tokenized300B/train_splits/StackExchange/tokenized_text_document', - 'data/tokenized300B/train_splits/CommonCrawl/chunk1/tokenized_text_document', - 'data/tokenized300B/train_splits/CommonCrawl/chunk2/tokenized_text_document', - 'data/tokenized300B/train_splits/CommonCrawl/chunk3/tokenized_text_document', - 'data/tokenized300B/train_splits/CommonCrawl/chunk4/tokenized_text_document', - 'data/tokenized300B/train_splits/CommonCrawl/chunk5/tokenized_text_document', - 'data/tokenized300B/train_splits/CommonCrawl/chunk6/tokenized_text_document', - 'data/tokenized300B/train_splits/CommonCrawl/chunk7/tokenized_text_document', - 'data/tokenized300B/train_splits/CommonCrawl/chunk8/tokenized_text_document', - 'data/tokenized300B/train_splits/CommonCrawl/chunk9/tokenized_text_document', - 'data/tokenized300B/train_splits/CommonCrawl/chunk10/tokenized_text_document'], - "train-data-weights": [ - 2.5, - 1.5419773919518147, - 
1.1151591815617388, - 1.5434825189166101, - 1.5436987632648516, - 1.5477789754805582, - 1.5415865562140247, - 1.5418730205678666, - 1.5429308066550182, - 1.5411863039403448, - 1.5403264814471718, - 4.5, - 2.0, - 6.879159444832248, - 4.995774267161819, - 6.872989472503661, - 6.874004825491571, - 6.892691875394247, - 6.889262861528587, - 6.869394427542272, - 6.909669621263436, - 6.915576906743681, - 6.90147629753848 - ], - "train-dataset-name": 'slim_pajama_workshop', - -} - diff --git a/configs/datasets/train/taskwise/ArXiv.yml b/configs/datasets/train/taskwise/ArXiv.yml deleted file mode 100644 index 83f68b22d..000000000 --- a/configs/datasets/train/taskwise/ArXiv.yml +++ /dev/null @@ -1,11 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/train_300B/ArXiv/ArXiv', - ], - "train-data-weights": [ - 1.0, - ], - # 13252597099 - "train-iters": 5861, -} \ No newline at end of file diff --git a/configs/datasets/train/taskwise/Book.yml b/configs/datasets/train/taskwise/Book.yml deleted file mode 100644 index 80311f835..000000000 --- a/configs/datasets/train/taskwise/Book.yml +++ /dev/null @@ -1,12 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/train_300B/Book/Book', - ], - "train-data-weights": [ - 1.0, - ], - - # 12579061292 5563 - "train-iters": 5563, -} \ No newline at end of file diff --git a/configs/datasets/train/taskwise/C4.yml b/configs/datasets/train/taskwise/C4.yml deleted file mode 100644 index d67b30b15..000000000 --- a/configs/datasets/train/taskwise/C4.yml +++ /dev/null @@ -1,12 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/train_300B/C4/C4', - ], - "train-data-weights": [ - 1.0, - ], - - # 79872895846 - "train-iters": 35326, -} \ No newline at end of file diff --git a/configs/datasets/train/taskwise/CommonCrawl.yml b/configs/datasets/train/taskwise/CommonCrawl.yml deleted file mode 100644 index a86b86ddf..000000000 --- a/configs/datasets/train/taskwise/CommonCrawl.yml +++ /dev/null @@ -1,11 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/train_300B/CommonCrawl/CommonCrawl',], - "train-data-weights": [ - 1.0, - ], - - # 155887114486 - "train-iters": 68946, -} \ No newline at end of file diff --git a/configs/datasets/train/taskwise/Github.yml b/configs/datasets/train/taskwise/Github.yml deleted file mode 100644 index 309e885fe..000000000 --- a/configs/datasets/train/taskwise/Github.yml +++ /dev/null @@ -1,12 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/train_300B/Github/Github', - ], - "train-data-weights": [ - 1.0, - ], - - # 15634722173 - "train-iters": 6914, -} \ No newline at end of file diff --git a/configs/datasets/train/taskwise/StackExchange.yml b/configs/datasets/train/taskwise/StackExchange.yml deleted file mode 100644 index 16ea275e5..000000000 --- a/configs/datasets/train/taskwise/StackExchange.yml +++ /dev/null @@ -1,11 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 'data/slim_pajama/train_300B/StackExchange/StackExchange', - ], - "train-data-weights": [ - 1.0, - ], - # 10088908154 - "train-iters": 4462, -} \ No newline at end of file diff --git a/configs/datasets/train/taskwise/Wikipedia.yml b/configs/datasets/train/taskwise/Wikipedia.yml deleted file mode 100644 index 9c087e8dc..000000000 --- a/configs/datasets/train/taskwise/Wikipedia.yml +++ /dev/null @@ -1,11 +0,0 @@ -{ - # or for weighted datasets: - "train-data-paths": [ - 
'data/slim_pajama/train_300B/Wikipedia/Wikipedia', - ], - "train-data-weights": [ - 1.0, - ], - # 11963032160 - "train-iters": 5291, -} \ No newline at end of file diff --git a/configs/datasets/val/pile_german.yml b/configs/datasets/val/pile_german.yml deleted file mode 100644 index 1496469b9..000000000 --- a/configs/datasets/val/pile_german.yml +++ /dev/null @@ -1,15 +0,0 @@ -{ - "test-data-paths": ["data/pile/test/pile_test_text_document"], - "test-data-weights": [ - 1. - ], - "valid-data-paths": [ - ["data/pile/val/pile_val_text_document"], - ["data/german/val/val"], - ], - "valid-data-weights": [ - [1.], - [1.], - ], - "val-dataset-name": 'pile_german', -} diff --git a/configs/datasets/val/pile_rp-no-se.yml b/configs/datasets/val/pile_rp-no-se.yml deleted file mode 100644 index 601b742fe..000000000 --- a/configs/datasets/val/pile_rp-no-se.yml +++ /dev/null @@ -1,36 +0,0 @@ - - "test-data-paths": ["data/red_pajama_400B/the_pile/test_tokenized_text_document"], - "test-data-weights": [ - 1. - ], - "valid-data-paths": [ - ["data/red_pajama_400B/the_pile/val_tokenized_text_document"], - [ - "data/red_pajama_400B/arxiv/folder_val/tokenized_text_document", - "data/red_pajama_400B/book/folder_val/tokenized_text_document", - "data/red_pajama_400B/c4/folder_val/tokenized_text_document", - "data/red_pajama_400B/wikipedia/folder_val/tokenized_text_document", - "data/red_pajama_400B/github/folder_val/tokenized_text_document", - "data/red_pajama_400B/common_crawl/2019-30/folder_val/tokenized_text_document", - "data/red_pajama_400B/common_crawl/2020-05/folder_val/tokenized_text_document", - "data/red_pajama_400B/common_crawl/2021-04/folder_val/tokenized_text_document", - "data/red_pajama_400B/common_crawl/2022-05/folder_val/tokenized_text_document", - "data/red_pajama_400B/common_crawl/2023-06/folder_val/tokenized_text_document", - ], - ], - "valid-data-weights": [ - [1.], - [ - 2.5, - 4.5, - 15.0, - 4.5, - 4.5, - 13.4, - 13.4, - 13.4, - 13.4, - 13.4 - ], - "val-dataset-name": 'pile_rp-no-se', - ], \ No newline at end of file diff --git a/configs/datasets/val/pile_rp.yml b/configs/datasets/val/pile_rp.yml deleted file mode 100644 index 16e368b48..000000000 --- a/configs/datasets/val/pile_rp.yml +++ /dev/null @@ -1,40 +0,0 @@ -{ - "test-data-paths": ["data/red_pajama_400B/the_pile/test_tokenized_text_document"], - "test-data-weights": [ - 1. 
- ], - "valid-data-paths": [ - ["data/red_pajama_400B/the_pile/val_tokenized_text_document"], - [ - "data/red_pajama_400B/arxiv/folder_val/tokenized_text_document", - "data/red_pajama_400B/book/folder_val/tokenized_text_document", - "data/red_pajama_400B/c4/folder_val/tokenized_text_document", - "data/red_pajama_400B/wikipedia/folder_val/tokenized_text_document", - "data/red_pajama_400B/github/folder_val/tokenized_text_document", - "data/red_pajama_400B/stackexchange/folder_val/tokenized_text_document", - "data/red_pajama_400B/common_crawl/2019-30/folder_val/tokenized_text_document", - "data/red_pajama_400B/common_crawl/2020-05/folder_val/tokenized_text_document", - "data/red_pajama_400B/common_crawl/2021-04/folder_val/tokenized_text_document", - "data/red_pajama_400B/common_crawl/2022-05/folder_val/tokenized_text_document", - "data/red_pajama_400B/common_crawl/2023-06/folder_val/tokenized_text_document", - ], - ], - "valid-data-weights": [ - [1.], - [ - 2.5, - 4.5, - 15.0, - 4.5, - 4.5, - 2.0, - 13.4, - 13.4, - 13.4, - 13.4, - 13.4 - ], - ], - "val-dataset-name": 'pile_rp', - - } \ No newline at end of file diff --git a/configs/datasets/val/pile_rp_subsets.yml b/configs/datasets/val/pile_rp_subsets.yml deleted file mode 100644 index 005417c49..000000000 --- a/configs/datasets/val/pile_rp_subsets.yml +++ /dev/null @@ -1,62 +0,0 @@ - -{ - "test-data-paths": ["data/red_pajama_400B/the_pile/test_tokenized_text_document"], - "test-data-weights": [ - 1. - ], - "valid-data-paths": [ - ["data/red_pajama_400B/the_pile/val_tokenized_text_document"], - [ - "data/red_pajama_400B/arxiv/folder_val/tokenized_text_document", - "data/red_pajama_400B/book/folder_val/tokenized_text_document", - "data/red_pajama_400B/c4/folder_val/tokenized_text_document", - "data/red_pajama_400B/wikipedia/folder_val/tokenized_text_document", - "data/red_pajama_400B/github/folder_val/tokenized_text_document", - "data/red_pajama_400B/stackexchange/folder_val/tokenized_text_document", - "data/red_pajama_400B/common_crawl/2019-30/folder_val/tokenized_text_document", - "data/red_pajama_400B/common_crawl/2020-05/folder_val/tokenized_text_document", - "data/red_pajama_400B/common_crawl/2021-04/folder_val/tokenized_text_document", - "data/red_pajama_400B/common_crawl/2022-05/folder_val/tokenized_text_document", - "data/red_pajama_400B/common_crawl/2023-06/folder_val/tokenized_text_document", - ], - ["data/red_pajama_400B/arxiv/folder_val/tokenized_text_document"], - ["data/red_pajama_400B/book/folder_val/tokenized_text_document"], - ["data/red_pajama_400B/c4/folder_val/tokenized_text_document"], - ["data/red_pajama_400B/wikipedia/folder_val/tokenized_text_document"], - ["data/red_pajama_400B/github/folder_val/tokenized_text_document"], - ["data/red_pajama_400B/stackexchange/folder_val/tokenized_text_document"], - ["data/red_pajama_400B/common_crawl/2019-30/folder_val/tokenized_text_document"], - ["data/red_pajama_400B/common_crawl/2020-05/folder_val/tokenized_text_document"], - ["data/red_pajama_400B/common_crawl/2021-04/folder_val/tokenized_text_document"], - ["data/red_pajama_400B/common_crawl/2022-05/folder_val/tokenized_text_document"], - ["data/red_pajama_400B/common_crawl/2023-06/folder_val/tokenized_text_document"], - ], - "valid-data-weights": [ - [1.], - [ - 2.5, - 4.5, - 15.0, - 4.5, - 4.5, - 2.0, - 13.4, - 13.4, - 13.4, - 13.4, - 13.4 - ], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - ], - "val-dataset-name": 'pile_rp_subsets', -} \ No newline at end of file diff --git 
a/configs/datasets/val/pile_slimp.yml b/configs/datasets/val/pile_slimp.yml deleted file mode 100644 index 6956b4a9d..000000000 --- a/configs/datasets/val/pile_slimp.yml +++ /dev/null @@ -1,15 +0,0 @@ -{ - "test-data-paths": ["data/pile/test/pile_test_text_document"], - "test-data-weights": [ - 1. - ], - "valid-data-paths": [ - ["data/pile/val/pile_val_text_document"], - ["data/slim_pajama/val/all/sp_val"], - ], - "valid-data-weights": [ - [1.], - [1.], - ], - "val-dataset-name": 'pile_slimp', -} diff --git a/configs/datasets/val/pile_slimp_domains.yml b/configs/datasets/val/pile_slimp_domains.yml deleted file mode 100644 index 344abe95b..000000000 --- a/configs/datasets/val/pile_slimp_domains.yml +++ /dev/null @@ -1,30 +0,0 @@ -{ - "test-data-paths": ["data/pile/test/pile_test_text_document"], - "test-data-weights": [ - 1. - ], - "valid-data-paths": [ - ["data/pile/val/pile_val_text_document"], - ["data/slim_pajama/val/all/sp_val"], - ["data/slim_pajama/val/ArXiv/tokenized_arxiv_text_document"], - ["data/slim_pajama/val/Book/tokenized_book_text_document"], - ["data/slim_pajama/val/C4/tokenized_c4_text_document"], - ["data/slim_pajama/val/Wikipedia/tokenized_wikipedia_text_document"], - ["data/slim_pajama/val/Github/tokenized_github_text_document"], - ["data/slim_pajama/val/StackExchange/tokenized_stackexchange_text_document"], - ["data/slim_pajama/val/CommonCrawl/tokenized_commoncrawl_text_document"], - ], - "valid-data-weights": [ - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - ], - "val-dataset-name": 'pile_slimp_domains', - -} \ No newline at end of file diff --git a/configs/datasets/val/pile_slimp_workshop.yml b/configs/datasets/val/pile_slimp_workshop.yml deleted file mode 100644 index 93530eccd..000000000 --- a/configs/datasets/val/pile_slimp_workshop.yml +++ /dev/null @@ -1,32 +0,0 @@ -{ - "test-data-paths": ["data/red_pajama_400B/the_pile/test_tokenized_text_document"], - "test-data-weights": [ - 1. - ], - "valid-data-paths": [ - ["data/red_pajama_400B/the_pile/val_tokenized_text_document"], - [ - "data/tokenized300B/val_splits/ArXiv/tokenized_text_document", - "data/tokenized300B/val_splits/Book/tokenized_text_document", - "data/tokenized300B/val_splits/C4/tokenized_text_document", - "data/tokenized300B/val_splits/Wikipedia/tokenized_text_document", - "data/tokenized300B/val_splits/Github/tokenized_text_document", - "data/tokenized300B/val_splits/StackExchange/tokenized_text_document", - "data/tokenized300B/val_splits/CommonCrawl/tokenized_text_document", - ], - ], - "valid-data-weights": [ - [1.], - [ - 2.5, - 4.5, - 15.0, - 4.5, - 4.5, - 2.0, - 67.0 - ], - ], - "val-dataset-name": 'pile_slimp_workshop', - -} diff --git a/configs/datasets/val/pile_val.yml b/configs/datasets/val/pile_val.yml deleted file mode 100644 index bea136935..000000000 --- a/configs/datasets/val/pile_val.yml +++ /dev/null @@ -1,13 +0,0 @@ -{ - "test-data-paths": ["data/pile/test/pile_test_text_document"], - "test-data-weights": [ - 1. - ], - "valid-data-paths": [ - ["data/pile/val/pile_val_text_document"], - ], - "valid-data-weights": [ - [1.], - ], - "val-dataset-name": 'pile_val', -} diff --git a/configs/datasets/val/pile_val_shard0.yml b/configs/datasets/val/pile_val_shard0.yml deleted file mode 100644 index 704e1d2e9..000000000 --- a/configs/datasets/val/pile_val_shard0.yml +++ /dev/null @@ -1,13 +0,0 @@ -{ - "test-data-paths": ["data/pile/shard_0/shard_0_text_document"], - "test-data-weights": [ - 1. 
- ], - "valid-data-paths": [ - ["data/pile/shard_0/shard_0_text_document"], - ], - "valid-data-weights": [ - [1.], - ], - "val-dataset-name": 'pile_val_shard0', -} diff --git a/configs/iclr_models/3b_test.yml b/configs/iclr_models/3b_test.yml deleted file mode 100644 index 3bec093d9..000000000 --- a/configs/iclr_models/3b_test.yml +++ /dev/null @@ -1,101 +0,0 @@ -# GPT-2 pretraining setup -{ - # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages - # across the node boundaries ) - "pipe-parallel-size": 1, - "model-parallel-size": 6, # one copy of the model per node - - # model settings - - - "num_layers": 32, - "hidden_size": 3072, - "num_attention_heads": 24, - "seq_length": 2048, - "max_position_embeddings": 2048, - "pos_emb": "rotary", - "rotary_pct": 0.25, - "no_weight_tying": true, - "gpt_j_residual": true, - "output_layer_parallelism": "column", - - "attention_config": [[["global"], 32]], - - "scaled_upper_triang_masked_softmax_fusion": true, - "bias_gelu_fusion": true, - - # init methods - "init_method": "small_init", - "output_layer_init_method": "wang_init", - - #optimizer settings - # "optimizer": { - # "type": "Adam", - # "params": { - # "lr": 0.00012, - # "betas": [0.9, 0.95], - # "eps": 1.0e-8, - # } - - # }, - # "min_lr": 0.000012, - # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training - "zero_optimization": { - "stage": 1, - "allgather_partitions": True, - "allgather_bucket_size": 500000000, - "overlap_comm": True, - "reduce_scatter": True, - "reduce_bucket_size": 500000000, - "contiguous_gradients": True, - }, - - # batch / data settings - #"train_batch_size": 1, # across 1024 nodes... fingers crossed - "train_micro_batch_size_per_gpu": 8, - #"gradient_accumulation_steps": 2, - "gradient_accumulation_steps": 2, - # "gradient_accumulation_steps": 8, - "data-impl": "mmap", - "split": "949,50,1", - - # activation checkpointing - "checkpoint-activations": true, - "checkpoint-num-layers": 1, - "partition-activations": true, - "synchronize-each-layer": true, - - # regularization - "gradient_clipping": 1.0, - "weight-decay": 0.1, - "hidden-dropout": 0.0, - "attention-dropout": 0.0, - - # precision settings - "fp16": { - "enabled": true, - # "type": "bfloat16", # set bf16 as precision - "loss_scale": 0, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, - - # "fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32 - # misc. 
training settings - # "train-iters": 250000, - # "lr-decay-iters": 250000, - "distributed-backend": "nccl", - # "lr-decay-style": "cosine", - # "warmup": 0.01, - "checkpoint-factor": 1000, - "eval-interval": 1000, - "eval-iters": 10, - - # logging - "log-interval": 1, - "steps_per_print": 1, - "keep-last-n-checkpoints": 1000, - "wall_clock_breakdown": true, - -} diff --git a/configs/iclr_models/410M.yml b/configs/iclr_models/410M.yml deleted file mode 100644 index eb4b51994..000000000 --- a/configs/iclr_models/410M.yml +++ /dev/null @@ -1,99 +0,0 @@ -# GPT-2 pretraining setup -{ - #identifier string for this config used while logging - "identifier_string": "410M", - - # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages - # across the node boundaries ) - "pipe-parallel-size": 1, - "model-parallel-size": 1, # one copy of the model per node - - # model settings - "num-layers": 24, - "hidden-size": 1024, - "seq-length": 2048, - "num-attention-heads": 16, - "max-position-embeddings": 2048, - "pos-emb": "rotary", - "rotary-pct": 0.25, - "no-weight-tying": true, - "gpt-j-residual": true, - "output-layer-parallelism": "column", - - # these should provide some speedup but takes a while to build, set to true if desired - "scaled-upper-triang-masked-softmax-fusion": true, - "bias-gelu-fusion": true, - - # init methods - "init_method": "small_init", - "output_layer_init_method": "wang_init", - - # "optimizer": { - # "type": "Adam", - # "params": { - # "lr": 3.0e-4, - # "betas": [0.9, 0.95], - # "eps": 1.0e-8, - # } - # }, - # "min_lr": 3.0e-5, - - "zero_optimization": { - "stage": 1, - "allgather_partitions": True, - "allgather_bucket_size": 500000000, - "overlap_comm": True, - "reduce_scatter": True, - "reduce_bucket_size": 500000000, - "contiguous_gradients": True, - "cpu_offload": False - }, - - # LLAMA Config - # batch / data settings - "train_batch_size": 1104, #1104, #1104, #1104, #1104, #1104 # approximately 2.2M batch size across 46 nodes - "train_micro_batch_size_per_gpu": 4, - "data-impl": "mmap", - "split": "949,50,1", - - # activation checkpointing - "checkpoint-activations": true, - "checkpoint-num-layers": 1, - "partition-activations": true, - "synchronize-each-layer": true, - - # regularization - "gradient_clipping": 1.0, - "weight-decay": 0.1, - "hidden-dropout": 0.0, - "attention-dropout": 0.0, - - # precision settings of LLaMa - "fp16": { - "enabled": true, - # "type": "bfloat16", # set bf16 as precision - "loss_scale": 0, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, - - # "fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32 - - - - "distributed-backend": "nccl", - # "lr-decay-style": "cosine", - # "warmup": 0.01, - "checkpoint-factor": 400, - "eval-interval": 100, - "warup-eval-interval": 50, - "eval-iters": 10, - - # logging - "log-interval": 1, - "steps_per_print": 1, - "keep-last-n-checkpoints": 1000, - "wall_clock_breakdown": true, - -} diff --git a/configs/iclr_models/410M_ckpt_100.yml b/configs/iclr_models/410M_ckpt_100.yml deleted file mode 100644 index c2934729a..000000000 --- a/configs/iclr_models/410M_ckpt_100.yml +++ /dev/null @@ -1,100 +0,0 @@ -# GPT-2 pretraining setup -{ - #identifier string for this config used while logging - "identifier_string": "410M", - - # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages - # across the node boundaries ) - 
"pipe-parallel-size": 1, - "model-parallel-size": 1, # one copy of the model per node - - # model settings - "num-layers": 24, - "hidden-size": 1024, - "seq-length": 2048, - "num-attention-heads": 16, - "max-position-embeddings": 2048, - "pos-emb": "rotary", - "rotary-pct": 0.25, - "no-weight-tying": true, - "gpt-j-residual": true, - "output-layer-parallelism": "column", - - # these should provide some speedup but takes a while to build, set to true if desired - "scaled-upper-triang-masked-softmax-fusion": true, - "bias-gelu-fusion": true, - - # init methods - "init_method": "small_init", - "output_layer_init_method": "wang_init", - - # "optimizer": { - # "type": "Adam", - # "params": { - # "lr": 3.0e-4, - # "betas": [0.9, 0.95], - # "eps": 1.0e-8, - # } - # }, - # "min_lr": 3.0e-5, - - "zero_optimization": { - "stage": 1, - "allgather_partitions": True, - "allgather_bucket_size": 500000000, - "overlap_comm": True, - "reduce_scatter": True, - "reduce_bucket_size": 500000000, - "contiguous_gradients": True, - "cpu_offload": False - }, - - # LLAMA Config - # batch / data settings - "train_batch_size": 1104, #1104, #1104, #1104, #1104, #1104 # approximately 2.2M batch size across 46 nodes - "train_micro_batch_size_per_gpu": 4, - "data-impl": "mmap", - "split": "949,50,1", - - # activation checkpointing - "checkpoint-activations": true, - "checkpoint-num-layers": 1, - "partition-activations": true, - "synchronize-each-layer": true, - - # regularization - "gradient_clipping": 1.0, - "weight-decay": 0.1, - "hidden-dropout": 0.0, - "attention-dropout": 0.0, - - # precision settings of LLaMa - "fp16": { - "enabled": true, - # "type": "bfloat16", # set bf16 as precision - "loss_scale": 0, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, - - # "fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32 - - - - "distributed-backend": "nccl", - # "lr-decay-style": "cosine", - # "warmup": 0.01, - "checkpoint-factor": 100, - "extra_save_iters": [4462, 5291, 5563, 5861, 6914], - "eval-interval": 100, - "warup-eval-interval": 50, - "eval-iters": 10, - - # logging - "log-interval": 1, - "steps_per_print": 1, - "keep-last-n-checkpoints": 1000, - "wall_clock_breakdown": true, - -} \ No newline at end of file diff --git a/configs/iclr_models/49M.yml b/configs/iclr_models/49M.yml deleted file mode 100644 index c07c1eca9..000000000 --- a/configs/iclr_models/49M.yml +++ /dev/null @@ -1,100 +0,0 @@ -# GPT-2 pretraining setup -{ - #identifier string for this config used while logging - "identifier_string": "410M", - - # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages - # across the node boundaries ) - "pipe-parallel-size": 1, - "model-parallel-size": 1, # one copy of the model per node - - # model settings - "num-layers": 10, - "hidden-size": 640, - "num-attention-heads": 10, - "seq-length": 2048, - "max-position-embeddings": 2048, - "pos-emb": "rotary", - "rotary-pct": 0.25, - "no-weight-tying": true, - "gpt-j-residual": true, - "output-layer-parallelism": "column", - - # these should provide some speedup but takes a while to build, set to true if desired - "scaled-upper-triang-masked-softmax-fusion": true, - "bias-gelu-fusion": true, - - # init methods - "init_method": "small_init", - "output_layer_init_method": "wang_init", - - # "optimizer": { - # "type": "Adam", - # "params": { - # "lr": 3.0e-4, - # "betas": [0.9, 0.95], - # "eps": 1.0e-8, - # } - # }, - # "min_lr": 3.0e-5, - - 
"zero_optimization": { - "stage": 1, - "allgather_partitions": True, - "allgather_bucket_size": 500000000, - "overlap_comm": True, - "reduce_scatter": True, - "reduce_bucket_size": 500000000, - "contiguous_gradients": True, - "cpu_offload": False - }, - - # LLAMA Config - # batch / data settings - # "train_batch_size": 1104, #1104, #1104, #1104, #1104, #1104 # approximately 2.2M batch size across 46 nodes - "train_micro_batch_size_per_gpu": 16, - 'gradient_accumulation_steps': 1, - "data-impl": "mmap", - "split": "949,50,1", - - # activation checkpointing - "checkpoint-activations": true, - "checkpoint-num-layers": 1, - "partition-activations": true, - "synchronize-each-layer": true, - - # regularization - "gradient_clipping": 1.0, - "weight-decay": 0.1, - "hidden-dropout": 0.0, - "attention-dropout": 0.0, - - # precision settings of LLaMa - "fp16": { - "enabled": true, - "type": "float16", # set bf16 as precision - "loss_scale": 0, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, - - # "fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32 - - - - "distributed-backend": "nccl", - # "lr-decay-style": "cosine", - # "warmup": 0.01, - "checkpoint-factor": 20, - "eval-interval": 10, - "warup-eval-interval": 50, - "eval-iters": 10, - - # logging - "log-interval": 1, - "steps_per_print": 1, - "keep-last-n-checkpoints": 1000, - "wall_clock_breakdown": true, - -} diff --git a/configs/iclr_models/7_1B.yml b/configs/iclr_models/7_1B.yml deleted file mode 100644 index 2d5ebf615..000000000 --- a/configs/iclr_models/7_1B.yml +++ /dev/null @@ -1,94 +0,0 @@ -{ - #identifier string for this config used while logging - "identifier_string": "7-1B", - - "pipe-parallel-size": 4, - "model-parallel-size": 6, - - "num-layers": 36, - "hidden-size": 4608, - "seq-length": 2048, - "num-attention-heads": 36, - "max_position_embeddings": 2048, - "pos_emb": "rotary", - "rotary_pct": 0.25, - "no_weight_tying": true, - "gpt_j_residual": true, - "output_layer_parallelism": "column", - "attention_config": [[["global"], 36]], - - "scaled_upper_triang_masked_softmax_fusion": true, - "bias_gelu_fusion": true, - - - "init_method": "small_init", - "output_layer_init_method": "wang_init", - - #optimizer settings - # "optimizer": { - # "type": "Adam", - # "params": { - # "lr": 0.00012, - # "betas": [0.9, 0.95], - # "eps": 1.0e-8, - # } - - # "min_lr": 0.000012, - # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training - "zero_optimization": { - "stage": 1, - "allgather_partitions": True, - "allgather_bucket_size": 500000000, - "overlap_comm": True, - "reduce_scatter": True, - "reduce_bucket_size": 500000000, - "contiguous_gradients": True, - }, - - # batch / data settings - #"train_batch_size": 1, # across 1024 nodes... 
fingers crossed - "train_micro_batch_size_per_gpu": 4, - #"gradient_accumulation_steps": 2, - # "gradient_accumulation_steps": 8, - "gradient_accumulation_steps": 4, - "data-impl": "mmap", - - - # activation checkpointing - "checkpoint-activations": true, - "checkpoint-num-layers": 1, - "partition-activations": true, - "synchronize-each-layer": true, - - # regularization - "gradient_clipping": 1.0, - "weight-decay": 0.1, - "hidden-dropout": 0.0, - "attention-dropout": 0.0, - - # precision settings - "fp16": { - "enabled": true, - # "type": "bfloat16", # set bf16 as precision - "loss_scale": 0, - "loss_scale_window": 1000, - "hysteresis": 2, - "min_loss_scale": 1 - }, - - # "fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32 - # misc. training settings - "distributed-backend": "nccl", - # "warmup": 0.01, - "checkpoint-factor": 100, - "extra-save-iters": [143051], # [0,1,2,4,8,16,32,64,128,256,512], - - "eval-interval": 100, - "eval-iters": 10, - - # logging - "log-interval": 10, - "steps_per_print": 10, - "wall_clock_breakdown": true, - -} \ No newline at end of file diff --git a/configs/load/3e-5const_0_410M_143_CPT.yml b/configs/load/3e-5const_0_410M_143_CPT.yml deleted file mode 100644 index 61d048e86..000000000 --- a/configs/load/3e-5const_0_410M_143_CPT.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ - "load":"checkpoints/continued_slim_pajama/JOB-3061457_pythia-deduped-410M-iters-131296_warmup-0.0_max-lr-3e-05_min-lr-3e-05_pretrain_slim_pajama_resume" -} \ No newline at end of file diff --git a/configs/load/none.yml b/configs/load/none.yml deleted file mode 100644 index f975cb0cb..000000000 --- a/configs/load/none.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ - "load": "none", -} \ No newline at end of file diff --git a/configs/load/pythia_2-8B_143000.yml b/configs/load/pythia_2-8B_143000.yml deleted file mode 100644 index f975cb0cb..000000000 --- a/configs/load/pythia_2-8B_143000.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ - "load": "none", -} \ No newline at end of file diff --git a/configs/load/pythia_410m.yml b/configs/load/pythia_410m.yml deleted file mode 100644 index df769c23f..000000000 --- a/configs/load/pythia_410m.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ - "load": "checkpoints/neox_converted/mp1_pp1/pythia", -} \ No newline at end of file diff --git a/configs/load/pythia_410m_10000.yml b/configs/load/pythia_410m_10000.yml deleted file mode 100644 index 76aba61b4..000000000 --- a/configs/load/pythia_410m_10000.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ - "load": "checkpoints/neox_converted/mp1_pp1/pythia/410_10000", -} \ No newline at end of file diff --git a/configs/load/pythia_410m_143000.yml b/configs/load/pythia_410m_143000.yml deleted file mode 100644 index 2d703aa6c..000000000 --- a/configs/load/pythia_410m_143000.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ - "load": "checkpoints/neox_converted/mp1_pp1/pythia/410_143000", -} \ No newline at end of file diff --git a/configs/load/pythia_410m_27000.yml b/configs/load/pythia_410m_27000.yml deleted file mode 100644 index 266ccd51f..000000000 --- a/configs/load/pythia_410m_27000.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ - "load": "checkpoints/neox_converted/mp1_pp1/pythia/410_27000", -} \ No newline at end of file diff --git a/configs/load/pythia_6-9B_143000.yml b/configs/load/pythia_6-9B_143000.yml deleted file mode 100644 index f975cb0cb..000000000 --- a/configs/load/pythia_6-9B_143000.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ - "load": "none", -} \ No newline at end of file diff --git a/configs/load/pythia_deduped_410m_10000.yml 
b/configs/load/pythia_deduped_410m_10000.yml deleted file mode 100644 index 4492ff72a..000000000 --- a/configs/load/pythia_deduped_410m_10000.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ - "load": "checkpoints/neox_converted/mp1_pp1/pythia/410m_deduped_step10000", -} \ No newline at end of file diff --git a/configs/load/pythia_deduped_410m_143000.yml b/configs/load/pythia_deduped_410m_143000.yml deleted file mode 100644 index 5b46b4eb5..000000000 --- a/configs/load/pythia_deduped_410m_143000.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ - "load": "checkpoints/neox_converted/mp1_pp1/pythia/410m_deduped_step143000", -} \ No newline at end of file diff --git a/configs/load/pythia_deduped_410m_27000.yml b/configs/load/pythia_deduped_410m_27000.yml deleted file mode 100644 index 2957457a7..000000000 --- a/configs/load/pythia_deduped_410m_27000.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ - "load": "checkpoints/neox_converted/mp1_pp1/pythia/410m_deduped_step27000", -} \ No newline at end of file diff --git a/configs/load/resume_1-2e-4_001_7-1B_pile_PT.yml b/configs/load/resume_1-2e-4_001_7-1B_pile_PT.yml deleted file mode 100644 index bff295f9e..000000000 --- a/configs/load/resume_1-2e-4_001_7-1B_pile_PT.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ -"load":"checkpoints/cpt_iclr_2/JOB-3178176_7-1B_it-132366_wu-0.01_mxlr-0.00012_mnlr-1.2e-05_sch-cosine_tr-pile-train_scratch" -} \ No newline at end of file diff --git a/configs/load/resume_1-2e-4_001_7-1B_slim_pajama_CPT.yml b/configs/load/resume_1-2e-4_001_7-1B_slim_pajama_CPT.yml deleted file mode 100644 index 2cb78d8dc..000000000 --- a/configs/load/resume_1-2e-4_001_7-1B_slim_pajama_CPT.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ -"load":"checkpoints/cpt_iclr_2/JOB-3199708_7-1B_it-132366_wu-0.01_mxlr-0.00012_mnlr-1.2e-05_sch-cosine_tr-slim-pajama-300B_finetune" -} \ No newline at end of file diff --git a/configs/load/resume_1-2e-4_001_7-1B_slim_pajama_PT.yml b/configs/load/resume_1-2e-4_001_7-1B_slim_pajama_PT.yml deleted file mode 100644 index ada5b3ead..000000000 --- a/configs/load/resume_1-2e-4_001_7-1B_slim_pajama_PT.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ -"load":"checkpoints/cpt_iclr_2/JOB-3199748_7-1B_it-132366_wu-0.01_mxlr-0.00012_mnlr-1.2e-05_sch-cosine_tr-slim-pajama-300B_scratch" -} \ No newline at end of file diff --git a/configs/load/resume_1-5e-4_001_410M_143_CPT.yml b/configs/load/resume_1-5e-4_001_410M_143_CPT.yml deleted file mode 100644 index 8ad24fba8..000000000 --- a/configs/load/resume_1-5e-4_001_410M_143_CPT.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ - "load": "checkpoints/continued_slim_pajama/JOB-3046061_pythia-deduped-410M-iters-131296_warmup-0.01_max-lr-0.00015_min-lr-1.5e-05_pretrain_slim_pajama_resume", -} \ No newline at end of file diff --git a/configs/load/resume_3e-4_001_410M_143_CPT.yml b/configs/load/resume_3e-4_001_410M_143_CPT.yml deleted file mode 100644 index 87239b92f..000000000 --- a/configs/load/resume_3e-4_001_410M_143_CPT.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ - "load": "checkpoints/continued_slim_pajama/JOB-3051279_pythia-deduped-410M-iters-131296_warmup-0.01_max-lr-0.0003_min-lr-3e-05_pretrain_slim_pajama_resume", -} diff --git a/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r05.yml b/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r05.yml deleted file mode 100644 index 22b44ad8b..000000000 --- a/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r05.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ -"load":"checkpoints/cpt_iclr_2/JOB-3199939_410M_it-132366_wu-0.01_mxlr-0.0003_mnlr-3e-05_sch-cosine_tr-slim-pajama-300B-replay05_finetune" -} \ No newline at end 
of file diff --git a/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r1.yml b/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r1.yml deleted file mode 100644 index 9a0c93bc5..000000000 --- a/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r1.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ -"load":"checkpoints/cpt_iclr_2/JOB-3200003_410M_it-132366_wu-0.01_mxlr-0.0003_mnlr-3e-05_sch-cosine_tr-slim-pajama-300B-replay1_finetune" -} \ No newline at end of file diff --git a/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r10.yml b/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r10.yml deleted file mode 100644 index b1527a45d..000000000 --- a/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r10.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ -"load":"checkpoints/cpt_iclr_2/JOB-3199999_410M_it-132366_wu-0.01_mxlr-0.0003_mnlr-3e-05_sch-cosine_tr-slim-pajama-300B-replay10_finetune" -} \ No newline at end of file diff --git a/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r5.yml b/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r5.yml deleted file mode 100644 index 4788d814e..000000000 --- a/configs/load/resume_3e-4_001_410M_slim_pajama_CPT_r5.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ -"load":"checkpoints/cpt_iclr_2/JOB-3200000_410M_it-132366_wu-0.01_mxlr-0.0003_mnlr-3e-05_sch-cosine_tr-slim-pajama-300B-replay5_finetune" -} \ No newline at end of file diff --git a/configs/load/resume_3e-4_001_7-1B_pile_PT.yml b/configs/load/resume_3e-4_001_7-1B_pile_PT.yml deleted file mode 100644 index a6ad9144d..000000000 --- a/configs/load/resume_3e-4_001_7-1B_pile_PT.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ -"load":"checkpoints/cpt_iclr/JOB-3156148_7-1B_it-132366_wu-0.01_mxlr-0.0003_mnlr-3e-05_sch-cosine_tr-pile-train_scratch" -} \ No newline at end of file diff --git a/configs/load/resume_6e-4_001_410M_143_CPT.yml b/configs/load/resume_6e-4_001_410M_143_CPT.yml deleted file mode 100644 index b4fb5571e..000000000 --- a/configs/load/resume_6e-4_001_410M_143_CPT.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ - "load": "checkpoints/continued_slim_pajama/JOB-3047316_pythia-deduped-410M-iters-131296_warmup-0.01_max-lr-0.0006_min-lr-6e-05_pretrain_slim_pajama_resume", -} \ No newline at end of file diff --git a/configs/load/scratch.yml b/configs/load/scratch.yml deleted file mode 100644 index dd4c8ec8e..000000000 --- a/configs/load/scratch.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ - "load": "checkpoints/continued_slim_pajama/JOB-3047440_pythia-deduped-410M-iters-131296_warmup-0.01_max-lr-0.0003_min-lr-3e-05_pretrain_slim_pajama_none", -} \ No newline at end of file diff --git a/configs/load/test_3e-5const_0_410M_143_CPT.yml b/configs/load/test_3e-5const_0_410M_143_CPT.yml deleted file mode 100644 index fd41eeee3..000000000 --- a/configs/load/test_3e-5const_0_410M_143_CPT.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ -"load":"checkpoints/continued_slim_pajama/JOB-3057769_pythia-deduped-410M-iters-131296_warmup-0.0_max-lr-3e-05_min-lr-3e-05_pretrain_slim_pajama_resume" -} \ No newline at end of file diff --git a/configs/load/wu_001_lr1-5e-4_pile.yml b/configs/load/wu_001_lr1-5e-4_pile.yml deleted file mode 100644 index adfab6c49..000000000 --- a/configs/load/wu_001_lr1-5e-4_pile.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ - "load": "checkpoints/continued_test8/JOB-2966423_pythia-c-410M-iters-181793_warmup-0.01_max-lr-0.00015_min-lr-1.5e-05_finetune_pile", -} \ No newline at end of file diff --git a/configs/load/wu_001_lr3e-4_pile.yml b/configs/load/wu_001_lr3e-4_pile.yml deleted file mode 100644 index edc1bc825..000000000 --- a/configs/load/wu_001_lr3e-4_pile.yml 
+++ /dev/null @@ -1,3 +0,0 @@ -{ - "load": "checkpoints/continued_test8/JOB-2966421_pythia-c-410M-iters-181793_warmup-0.01_max-lr-0.0003_min-lr-3e-05_finetune_pile", -} \ No newline at end of file diff --git a/configs/load/wu_001_lr6e-4_pile.yml b/configs/load/wu_001_lr6e-4_pile.yml deleted file mode 100644 index 6d7357118..000000000 --- a/configs/load/wu_001_lr6e-4_pile.yml +++ /dev/null @@ -1,3 +0,0 @@ -{ - "load": "checkpoints/continued_test8/JOB-2966422_pythia-c-410M-iters-181793_warmup-0.01_max-lr-0.0006_min-lr-6e-05_finetune_pile", -} \ No newline at end of file diff --git a/configs/pythia_410m_llama_setup_finetune.yml b/configs/pythia_410m_llama_setup_finetune.yml deleted file mode 100644 index c01771008..000000000 --- a/configs/pythia_410m_llama_setup_finetune.yml +++ /dev/null @@ -1,24 +0,0 @@ -# Suggested data paths when using GPT-NeoX locally -{ - # If weight_by_num_documents is True, Builds dataset weights from a multinomial distribution over groups of data according to the number of documents in each group. - # WARNING: setting this to True will override any user provided weights - # "weight_by_num_documents": false, - # "weighted_sampler_alpha": 0.3, - - "tokenizer-type": "HFTokenizer", - "vocab-file": "data/20B_tokenizer.json", - - "checkpoint_validation_with_forward_pass": False, - "use_wandb": False, - # "wandb_host": "https://api.wandb.ai", - - "launcher": "jsrun", - "deepspeed_jsrun": true, - "num_workers": 1, - "finetune": true, - - "save": "checkpoints/cpt_iclr_2", - "tensorboard-dir": "tensorboard/cpt_iclr_2", - "log-dir": "logs", - "wandb_project": "cpt_iclr_2", -} \ No newline at end of file diff --git a/configs/pythia_410m_llama_setup_resume.yml b/configs/pythia_410m_llama_setup_resume.yml deleted file mode 100644 index e49d1db20..000000000 --- a/configs/pythia_410m_llama_setup_resume.yml +++ /dev/null @@ -1,24 +0,0 @@ -# Suggested data paths when using GPT-NeoX locally -{ - # If weight_by_num_documents is True, Builds dataset weights from a multinomial distribution over groups of data according to the number of documents in each group. 
- # WARNING: setting this to True will override any user provided weights - # "weight_by_num_documents": false, - # "weighted_sampler_alpha": 0.3, - - "tokenizer-type": "HFTokenizer", - "vocab-file": "data/20B_tokenizer.json", - - "checkpoint_validation_with_forward_pass": False, - "use_wandb": False, - # "wandb_host": "https://api.wandb.ai", - - "launcher": "openmpi", - # "deepspeed_jsrun": true, - "num_workers": 2, - "finetune": false, - - "save": "checkpoints/cpt_iclr_2", - "tensorboard-dir": "tensorboard/cpt_iclr_2", - "log-dir": "logs", - "wandb_project": "cpt_iclr_2", -} \ No newline at end of file diff --git a/configs/schedules/7_1B_adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml b/configs/schedules/7_1B_adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml deleted file mode 100644 index 2f563a1a3..000000000 --- a/configs/schedules/7_1B_adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml +++ /dev/null @@ -1,16 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 0.00012, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 0.000012, - "lr-decay-style": "inverse_sqrt_infinite", - "num_repeats": 1, - "warmup": 0.01, - "constant_iters_percent": 0.98, - "constant_lr": 0.000017, -} \ No newline at end of file diff --git a/configs/schedules/adam_constant_lr3e-4_3e-4_wu-001.yml b/configs/schedules/adam_constant_lr3e-4_3e-4_wu-001.yml deleted file mode 100644 index 638439c4d..000000000 --- a/configs/schedules/adam_constant_lr3e-4_3e-4_wu-001.yml +++ /dev/null @@ -1,13 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 3.0e-4, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 3.0e-4, - "lr-decay-style": "constant", # this will coincide with the else in AnnealingLR - "warmup": 0.01, -} \ No newline at end of file diff --git a/configs/schedules/adam_constant_lr3e-5_3e-5_wu-0.yml b/configs/schedules/adam_constant_lr3e-5_3e-5_wu-0.yml deleted file mode 100644 index 640728c5f..000000000 --- a/configs/schedules/adam_constant_lr3e-5_3e-5_wu-0.yml +++ /dev/null @@ -1,13 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 3.0e-5, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 3.0e-5, - "lr-decay-style": "constant", # this will coincide with the else in AnnealingLR - "warmup": 0., -} \ No newline at end of file diff --git a/configs/schedules/adam_cosine-inf_lr3e-4_3e-5_wu-001.yml b/configs/schedules/adam_cosine-inf_lr3e-4_3e-5_wu-001.yml deleted file mode 100644 index 383e476d1..000000000 --- a/configs/schedules/adam_cosine-inf_lr3e-4_3e-5_wu-001.yml +++ /dev/null @@ -1,14 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 3.0e-4, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 3.0e-5, - "lr-decay-style": "cosine-inf", - "num_repeats": 3, - "warmup": 0.01, -} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr1-2e-4_1-2e-5_wu-001.yml b/configs/schedules/adam_cosine_lr1-2e-4_1-2e-5_wu-001.yml deleted file mode 100644 index f4451118d..000000000 --- a/configs/schedules/adam_cosine_lr1-2e-4_1-2e-5_wu-001.yml +++ /dev/null @@ -1,13 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 1.2e-4, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 1.2e-5, - "lr-decay-style": "cosine", - "warmup": 0.01, -} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-0.yml b/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-0.yml deleted file mode 100644 index 72e42e239..000000000 --- a/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-0.yml +++ /dev/null @@ 
-1,14 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 1.5e-4, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 1.5e-5, - "lr-decay-style": "cosine", - "warmup": 0., - # "load": "/gpfs/alpine/csc499/scratch/btherien/gpt-neox/checkpoints/continued_test8/JOB-2959220_pythia-c-410M-iters-181793_warmup-0.01_max-lr-0.00015_min-lr-1.5e-05_finetune", -} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-0005.yml b/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-0005.yml deleted file mode 100644 index 58ffad137..000000000 --- a/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-0005.yml +++ /dev/null @@ -1,13 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 1.5e-4, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 1.5e-5, - "lr-decay-style": "cosine", - "warmup": 0.005, -} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-001.yml b/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-001.yml deleted file mode 100644 index 4780250d3..000000000 --- a/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-001.yml +++ /dev/null @@ -1,13 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 1.5e-4, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 1.5e-5, - "lr-decay-style": "cosine", - "warmup": 0.01, -} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-002.yml b/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-002.yml deleted file mode 100644 index 653d47d0f..000000000 --- a/configs/schedules/adam_cosine_lr1-5e-4_1-5e-5_wu-002.yml +++ /dev/null @@ -1,13 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 1.5e-4, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 1.5e-5, - "lr-decay-style": "cosine", - "warmup": 0.02, -} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-0.yml b/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-0.yml deleted file mode 100644 index 955708b5c..000000000 --- a/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-0.yml +++ /dev/null @@ -1,13 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 3.0e-4, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 3.0e-5, - "lr-decay-style": "cosine", - "warmup": 0.0, -} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-0005.yml b/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-0005.yml deleted file mode 100644 index c86695af7..000000000 --- a/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-0005.yml +++ /dev/null @@ -1,13 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 3.0e-4, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 3.0e-5, - "lr-decay-style": "cosine", - "warmup": 0.005, -} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-001.yml b/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-001.yml deleted file mode 100644 index 52944b4a5..000000000 --- a/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-001.yml +++ /dev/null @@ -1,13 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 3.0e-4, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 3.0e-5, - "lr-decay-style": "cosine", - "warmup": 0.01, -} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-002.yml b/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-002.yml deleted file mode 100644 index d8da20aa5..000000000 --- a/configs/schedules/adam_cosine_lr3e-4_3e-5_wu-002.yml +++ 
/dev/null @@ -1,13 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 3.0e-4, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 3.0e-5, - "lr-decay-style": "cosine", - "warmup": 0.02, -} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-0.yml b/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-0.yml deleted file mode 100644 index 776225363..000000000 --- a/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-0.yml +++ /dev/null @@ -1,14 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 6.0e-4, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 6.0e-5, - "lr-decay-style": "cosine", - "warmup": 0.0, - # "load": "/gpfs/alpine/csc499/scratch/btherien/gpt-neox/checkpoints/continued_test8/JOB-2959221_pythia-c-410M-iters-181793_warmup-0.01_max-lr-0.0006_min-lr-6e-05_finetune" -} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-0005.yml b/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-0005.yml deleted file mode 100644 index 73f38a6cc..000000000 --- a/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-0005.yml +++ /dev/null @@ -1,13 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 6.0e-4, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 6.0e-5, - "lr-decay-style": "cosine", - "warmup": 0.005, -} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-001.yml b/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-001.yml deleted file mode 100644 index d3ae77ba9..000000000 --- a/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-001.yml +++ /dev/null @@ -1,14 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 6.0e-4, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 6.0e-5, - "lr-decay-style": "cosine", - "warmup": 0.01, - # "load": "/gpfs/alpine/csc499/scratch/btherien/gpt-neox/checkpoints/continued_test8/JOB-2959221_pythia-c-410M-iters-181793_warmup-0.01_max-lr-0.0006_min-lr-6e-05_finetune" -} \ No newline at end of file diff --git a/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-002.yml b/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-002.yml deleted file mode 100644 index b01da6ed1..000000000 --- a/configs/schedules/adam_cosine_lr6e-4_6e-5_wu-002.yml +++ /dev/null @@ -1,13 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 6.0e-4, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 6.0e-5, - "lr-decay-style": "cosine", - "warmup": 0.02, -} \ No newline at end of file diff --git a/configs/schedules/adam_infcos_lr3e-4_3e-5_wu-001.yml b/configs/schedules/adam_infcos_lr3e-4_3e-5_wu-001.yml deleted file mode 100644 index 53838af53..000000000 --- a/configs/schedules/adam_infcos_lr3e-4_3e-5_wu-001.yml +++ /dev/null @@ -1,17 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 3.0e-4, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 3.0e-5, - "lr-decay-style": "cosine_cooldown_infinite", - "warmup": 0.01, - "constant_lr": 0.000165, - "constant_iters_percent" : 0.85, - "cooldown_iters_percent" : 0.6, - "timescale" : 10, -} \ No newline at end of file diff --git a/configs/schedules/adam_infinv_lr3e-4_3e-5_wu-001.yml b/configs/schedules/adam_infinv_lr3e-4_3e-5_wu-001.yml deleted file mode 100644 index 9e833f4cc..000000000 --- a/configs/schedules/adam_infinv_lr3e-4_3e-5_wu-001.yml +++ /dev/null @@ -1,17 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 3.0e-4, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 3.0e-5, - "lr-decay-style": 
"inverse_sqrt_infinite", - "warmup": 0.01, - "constant_lr": 0.000165, - "constant_iters_percent" : 0.85, - "cooldown_iters_percent" : 0.6, - "timescale" : 10, -} \ No newline at end of file diff --git a/configs/schedules/adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml b/configs/schedules/adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml deleted file mode 100644 index 0e7134ec3..000000000 --- a/configs/schedules/adam_inv-inf_lr3e-4_8e-5_3e-5_wu-001.yml +++ /dev/null @@ -1,16 +0,0 @@ -{ -"optimizer": { - "type": "Adam", - "params": { - "lr": 3.0e-4, - "betas": [0.9, 0.95], - "eps": 1.0e-8, - } - }, - "min_lr": 3.0e-5, - "lr-decay-style": "inverse_sqrt_infinite", - "num_repeats": 1, - "warmup": 0.01, - "constant_iters_percent": 0.98, - "constant_lr": 8.0e-5, -} \ No newline at end of file diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index c1119c89c..65347e24f 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -565,98 +565,6 @@ def build_train_valid_test_data_iterators(neox_args): return train_data_iterator, valid_data_iterator, test_data_iterator -def build_validation_iterator(neox_args): - """XXX""" - - valid_dataloader = None - - print_rank_0("> building validation ...") - - # Ensure only the first/last pipeline stages have data loaders - if neox_args.is_pipe_parallel: - is_first_stage = mpu.get_pipe_parallel_rank() == 0 - is_last_stage = ( - mpu.get_pipe_parallel_rank() == mpu.get_pipe_parallel_world_size() - 1 - ) - pipe_load = is_first_stage or is_last_stage - else: - pipe_load = True - - # Data loader only on rank 0 of each model parallel group. - if mpu.get_model_parallel_rank() == 0 and pipe_load: - # Number of train/valid/test samples. - train_iters = neox_args.train_iters - eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters - test_iters = neox_args.eval_iters - num_samples = eval_iters * neox_args.train_batch_size - - valid_weights, valid_num_samples = get_normalized_weights_and_num_samples( - neox_args.valid_data_weights, num_samples - ) - - # build individual datasets - _, valid_datasets, _ = build_weighted_datasets( - neox_args, - 0, - valid_num_samples, - 0, - 0, - valid_weights, - 0, - build_index_mappings=not neox_args.weight_by_num_documents, - concatenate_train_replay_paths=False, - ) - - # if neox_args.weight_by_num_documents: # Not supported for now - - if valid_datasets: - valid_ds = BlendableDataset(valid_datasets, valid_weights) - valid_dataloader = make_data_loader(valid_ds, neox_args=neox_args) - - # Flags to know if we need to do training/validation/testing. - do_valid = valid_dataloader is not None and neox_args.eval_iters > 0 - - # Need to broadcast num_tokens and num_type_tokens. - flags = torch.cuda.LongTensor([0, int(do_valid), 0]) - else: - flags = torch.cuda.LongTensor([0, 0, 0]) - - # Broadcast num tokens. - if neox_args.is_pipe_parallel: - # Only first/last pipeline stages have data loaders, so pipeline parallelism should - # broadcast globally instead of just the model parallel group. 
- torch.distributed.broadcast(flags, src=0) - else: - torch.distributed.broadcast( - flags, - mpu.get_model_parallel_src_rank(), - group=mpu.get_model_parallel_group(), - ) - neox_args.do_train = flags[0].item() - neox_args.do_valid = flags[1].item() - neox_args.do_test = flags[2].item() - - if valid_dataloader is not None: - start_iter_val = ( - (neox_args.iteration * neox_args.gradient_accumulation_steps) - // neox_args.eval_interval - ) * neox_args.eval_iters - valid_dataloader.batch_sampler.start_iter = start_iter_val % len( - valid_dataloader - ) - print_rank_0( - "setting validation data start iteration to {}".format( - valid_dataloader.batch_sampler.start_iter - ) - ) - - if valid_dataloader is not None: - valid_data_iterator = iter(valid_dataloader) - else: - valid_data_iterator = None - - return valid_data_iterator - def compile_helper(): """Compile helper function at runtime. Make sure this is invoked on a single process.""" diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index ea13659a8..037fe2079 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -467,16 +467,6 @@ class NeoXArgsModel(NeoXArgsTemplate): Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column) """ - identifier_string: str = "" - """ - an identifier for the model, used for saving checkpoints,logging, etc. - """ - - warup_eval_interval: int = 50 - """ - the evaluation interval to use during warmup - """ - @dataclass class NeoXArgsOptimizer(NeoXArgsTemplate): @@ -1285,16 +1275,6 @@ class NeoXArgsTraining(NeoXArgsTemplate): What to scale width by when creating the delta model for mup """ - train_dataset_name: str = "no_train_dataset_name_given" - """ - An identified for the training dataset used for logging - """ - - val_dataset_name: str = "no_val_dataset_name_given" - """ - An identified for the training dataset used for logging - """ - @dataclass class NeoXArgsTextgen(NeoXArgsTemplate): diff --git a/megatron/training.py b/megatron/training.py index bfe8c90a8..1ee5d256d 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -25,7 +25,6 @@ import math import sys from contextlib import nullcontext -import copy import torch import deepspeed @@ -46,7 +45,7 @@ get_params_for_weight_decay_optimization, ) from megatron.checkpointing import load_checkpoint, save_checkpoint -from megatron.data.data_utils import build_train_valid_test_data_iterators, build_validation_iterator +from megatron.data.data_utils import build_train_valid_test_data_iterators from megatron.initialize import initialize_megatron from megatron.learning_rates import AnnealingLR from megatron.logging import tb_wandb_log, training_log @@ -198,18 +197,6 @@ def pretrain(neox_args): ) timers("model and optimizer").stop() - tensorboard_writer = neox_args.tensorboard_writer - neox_args.tensorboard_writer = None - neox_args_val = copy.deepcopy(neox_args) - neox_args.tensorboard_writer = tensorboard_writer - neox_args_val.train_data_paths = [None] - neox_args_val.test_data_paths = [None] - neox_args.valid_data_paths = neox_args.valid_data_paths[0] - neox_args.valid_data_weights = neox_args.valid_data_weights[0] - - - print(neox_args.is_replay_enabled) - # Data stuff. 
timers("train/valid/test data iterators").start() ( @@ -217,14 +204,6 @@ def pretrain(neox_args): valid_data_iterator, test_data_iterator, ) = build_train_valid_test_data_iterators(neox_args=neox_args) - val_iters = [valid_data_iterator] - if neox_args_val.valid_data_paths is not None and len(neox_args_val.valid_data_paths) > 1: - for i in range(1, len(neox_args_val.valid_data_paths)): - temp_copy = copy.deepcopy(neox_args_val) - temp_copy.valid_data_paths = temp_copy.valid_data_paths[i] - temp_copy.valid_data_weights = temp_copy.valid_data_weights[i] - temp_copy.num_workers = 0 - val_iters.append(build_validation_iterator(neox_args=temp_copy)) timers("train/valid/test data iterators").stop() if neox_args.use_mup and neox_args.coord_check: @@ -258,20 +237,17 @@ def pretrain(neox_args): ) if neox_args.do_valid: - prefix = "the start of training for val data" - for i in range(len(val_iters)): - print_rank_0("in if neox_args.do_valid for val_iters[i]",i, val_iters[i]) - evaluate_and_print_results( - neox_args=neox_args, - prefix=prefix, - forward_step_func=forward_step, - data_iterator=val_iters[i], - model=model, - iteration=iteration, - verbose=False, - timers=timers, - eval_name=f"val_{i}", - ) + prefix = "the end of training for val data" + evaluate_and_print_results( + neox_args=neox_args, + prefix=prefix, + forward_step_func=forward_step, + data_iterator=valid_data_iterator, + model=model, + iteration=iteration, + verbose=False, + timers=timers, + ) if neox_args.save and iteration != 0: save_checkpoint( diff --git a/megatron_config_1.json b/megatron_config_1.json deleted file mode 100644 index e92fb185d..000000000 --- a/megatron_config_1.json +++ /dev/null @@ -1 +0,0 @@ -{"launcher": "jsrun", "train_batch_size": 276, "train_micro_batch_size_per_gpu": 138, "optimizer": {"type": "Adam", "params": {"lr": 0.00012, "betas": [0.9, 0.95], "eps": 1e-08}}, "fp16": {"enabled": true, "loss_scale": 0, "loss_scale_window": 1000, "hysteresis": 2, "min_loss_scale": 1}, "gradient_clipping": 1.0, "zero_optimization": {"stage": 1, "allgather_partitions": true, "allgather_bucket_size": 500000000, "overlap_comm": true, "reduce_scatter": true, "reduce_bucket_size": 500000000, "contiguous_gradients": true, "cpu_offload": false}, "steps_per_print": 1, "wall_clock_breakdown": true, "precision": "fp16", "num_layers": 10, "hidden_size": 640, "num_attention_heads": 10, "seq_length": 2048, "max_position_embeddings": 2048, "pos_emb": "rotary", "no_weight_tying": true, "attention_config": ["global", "global", "global", "global", "global", "global", "global", "global", "global", "global"], "sparsity_config": {}, "scaled_upper_triang_masked_softmax_fusion": true, "bias_gelu_fusion": true, "rotary_pct": 0.25, "init_method": "small_init", "output_layer_init_method": "wang_init", "gpt_j_residual": true, "output_layer_parallelism": "column", "identifier_string": "410M", "lr_decay_style": "cosine", "lr_decay_iters": 132366, "min_lr": 1.2e-05, "optimizer_type": "Adam", "zero_stage": 1, "zero_reduce_scatter": true, "zero_contiguous_gradients": true, "zero_reduce_bucket_size": 500000000, "zero_allgather_bucket_size": 500000000, "lr": 0.00012, "tokenizer_type": "HFTokenizer", "train_data_paths": ["data/pile/train/pile_train"], "test_data_paths": ["data/pile/test/pile_test_text_document"], "valid_data_paths": [["data/pile/val/pile_val_text_document"], ["data/slim_pajama/val/all/sp_val"]], "train_data_weights": [1.0], "valid_data_weights": [[1.0], [1.0]], "test_data_weights": [1.0], "data_impl": "mmap", "save": 
"checkpoints/cpt_iclr_2", "config_files": {"pythia_410m_llama_setup_resume.yml": "# Suggested data paths when using GPT-NeoX locally\n{\n # If weight_by_num_documents is True, Builds dataset weights from a multinomial distribution over groups of data according to the number of documents in each group.\n # WARNING: setting this to True will override any user provided weights\n # \"weight_by_num_documents\": false,\n # \"weighted_sampler_alpha\": 0.3,\n\n \"tokenizer-type\": \"HFTokenizer\",\n \"vocab-file\": \"data/20B_tokenizer.json\",\n\n \"checkpoint_validation_with_forward_pass\": False,\n \"use_wandb\": False,\n # \"wandb_host\": \"https://api.wandb.ai\",\n\n \"launcher\": \"jsrun\",\n \"deepspeed_jsrun\": true,\n \"num_workers\": 1,\n \"finetune\": false,\n\n \"save\": \"checkpoints/cpt_iclr_2\",\n \"tensorboard-dir\": \"tensorboard/cpt_iclr_2\",\n \"log-dir\": \"logs\",\n \"wandb_project\": \"cpt_iclr_2\",\n}", "49M.yml": "# GPT-2 pretraining setup\n{\n #identifier string for this config used while logging\n \"identifier_string\": \"410M\",\n\n # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages\n # across the node boundaries )\n \"pipe-parallel-size\": 1,\n \"model-parallel-size\": 1, # one copy of the model per node\n\n # model settings\n \"num-layers\": 10,\n \"hidden-size\": 640,\n \"num-attention-heads\": 10,\n \"seq-length\": 2048,\n \"max-position-embeddings\": 2048,\n \"pos-emb\": \"rotary\",\n \"rotary-pct\": 0.25,\n \"no-weight-tying\": true,\n \"gpt-j-residual\": true,\n \"output-layer-parallelism\": \"column\",\n\n # these should provide some speedup but takes a while to build, set to true if desired\n \"scaled-upper-triang-masked-softmax-fusion\": true,\n \"bias-gelu-fusion\": true,\n\n # init methods\n \"init_method\": \"small_init\",\n \"output_layer_init_method\": \"wang_init\",\n\n # \"optimizer\": {\n # \"type\": \"Adam\",\n # \"params\": {\n # \"lr\": 3.0e-4,\n # \"betas\": [0.9, 0.95],\n # \"eps\": 1.0e-8,\n # }\n # },\n # \"min_lr\": 3.0e-5,\n\n \"zero_optimization\": {\n \"stage\": 1,\n \"allgather_partitions\": True,\n \"allgather_bucket_size\": 500000000,\n \"overlap_comm\": True,\n \"reduce_scatter\": True,\n \"reduce_bucket_size\": 500000000,\n \"contiguous_gradients\": True,\n \"cpu_offload\": False\n },\n\n # LLAMA Config\n # batch / data settings\n # \"train_batch_size\": 1104, #1104, #1104, #1104, #1104, #1104 # approximately 2.2M batch size across 46 nodes \n \"train_micro_batch_size_per_gpu\": 138,\n 'gas': 4,\n \"data-impl\": \"mmap\",\n \"split\": \"949,50,1\",\n\n # activation checkpointing\n \"checkpoint-activations\": true,\n \"checkpoint-num-layers\": 1,\n \"partition-activations\": true,\n \"synchronize-each-layer\": true,\n\n # regularization\n \"gradient_clipping\": 1.0,\n \"weight-decay\": 0.1,\n \"hidden-dropout\": 0.0,\n \"attention-dropout\": 0.0,\n\n # precision settings of LLaMa\n \"fp16\": {\n \"enabled\": true,\n # \"type\": \"bfloat16\", # set bf16 as precision\n \"loss_scale\": 0,\n \"loss_scale_window\": 1000,\n \"hysteresis\": 2,\n \"min_loss_scale\": 1\n },\n\n # \"fp32_allreduce\": True, # without a patch to torch, bf16 models have to do the allreduce in fp32\n\n\n\n \"distributed-backend\": \"nccl\",\n # \"lr-decay-style\": \"cosine\",\n # \"warmup\": 0.01,\n \"checkpoint-factor\": 400,\n \"eval-interval\": 100,\n \"warup-eval-interval\": 50,\n \"eval-iters\": 10,\n\n # logging\n \"log-interval\": 1,\n \"steps_per_print\": 1,\n \"keep-last-n-checkpoints\": 
1000,\n \"wall_clock_breakdown\": true,\n\n}\n", "pile_train.yml": "{ \n \"train-data-paths\": [\n \"data/pile/train/pile_train\",\n ],\n \"train-data-weights\": [\n 1.,\n ],\n \"train-dataset-name\": 'pile_train',\n \"train-iters\": 132366,\n \"lr-decay-iters\": 132366,\n}", "pile_slimp.yml": "{\n \"test-data-paths\": [\"data/pile/test/pile_test_text_document\"],\n \"test-data-weights\": [\n 1.\n ],\n \"valid-data-paths\": [\n [\"data/pile/val/pile_val_text_document\"],\n [\"data/slim_pajama/val/all/sp_val\"],\n ],\n \"valid-data-weights\": [\n [1.],\n [1.],\n ],\n \"val-dataset-name\": 'pile_slimp',\n}\n", "none.yml": "{\n \"load\": \"none\",\n}", "adam_cosine_lr1-2e-4_1-2e-5_wu-001.yml": "{\n\"optimizer\": {\n \"type\": \"Adam\",\n \"params\": {\n \"lr\": 1.2e-4,\n \"betas\": [0.9, 0.95],\n \"eps\": 1.0e-8,\n }\n },\n \"min_lr\": 1.2e-5,\n \"lr-decay-style\": \"cosine\",\n \"warmup\": 0.01,\n}"}, "load": "none", "checkpoint_factor": 400, "batch_size": 138, "train_iters": 132366, "eval_iters": 10, "keep_last_n_checkpoints": 1000, "eval_interval": 100, "split": "949,50,1", "vocab_file": "data/20B_tokenizer.json", "num_workers": 1, "attention_dropout": 0.0, "hidden_dropout": 0.0, "weight_decay": 0.1, "checkpoint_activations": true, "synchronize_each_layer": true, "partition_activations": true, "gas": 1, "clip_grad": 1.0, "dynamic_loss_scale": true, "train_dataset_name": "pile_train", "val_dataset_name": "pile_slimp", "pipe_parallel_size": 1, "world_size": 1, "is_pipe_parallel": true, "use_wandb": false, "wandb_project": "cpt_iclr_2", "log_dir": "logs", "tensorboard_dir": "tensorboard/cpt_iclr_2", "log_interval": 1, "text_gen_type": "unconditional", "local_rank": 0, "rank": 0, "deepspeed_jsrun": true, "user_script": "train.py", "save_iters": [400, 800, 1200, 1600, 2000, 2400, 2800, 3200, 3600, 4000, 4400, 4800, 5200, 5600, 6000, 6400, 6800, 7200, 7600, 8000, 8400, 8800, 9200, 9600, 10000, 10400, 10800, 11200, 11600, 12000, 12400, 12800, 13200, 13600, 14000, 14400, 14800, 15200, 15600, 16000, 16400, 16800, 17200, 17600, 18000, 18400, 18800, 19200, 19600, 20000, 20400, 20800, 21200, 21600, 22000, 22400, 22800, 23200, 23600, 24000, 24400, 24800, 25200, 25600, 26000, 26400, 26800, 27200, 27600, 28000, 28400, 28800, 29200, 29600, 30000, 30400, 30800, 31200, 31600, 32000, 32400, 32800, 33200, 33600, 34000, 34400, 34800, 35200, 35600, 36000, 36400, 36800, 37200, 37600, 38000, 38400, 38800, 39200, 39600, 40000, 40400, 40800, 41200, 41600, 42000, 42400, 42800, 43200, 43600, 44000, 44400, 44800, 45200, 45600, 46000, 46400, 46800, 47200, 47600, 48000, 48400, 48800, 49200, 49600, 50000, 50400, 50800, 51200, 51600, 52000, 52400, 52800, 53200, 53600, 54000, 54400, 54800, 55200, 55600, 56000, 56400, 56800, 57200, 57600, 58000, 58400, 58800, 59200, 59600, 60000, 60400, 60800, 61200, 61600, 62000, 62400, 62800, 63200, 63600, 64000, 64400, 64800, 65200, 65600, 66000, 66400, 66800, 67200, 67600, 68000, 68400, 68800, 69200, 69600, 70000, 70400, 70800, 71200, 71600, 72000, 72400, 72800, 73200, 73600, 74000, 74400, 74800, 75200, 75600, 76000, 76400, 76800, 77200, 77600, 78000, 78400, 78800, 79200, 79600, 80000, 80400, 80800, 81200, 81600, 82000, 82400, 82800, 83200, 83600, 84000, 84400, 84800, 85200, 85600, 86000, 86400, 86800, 87200, 87600, 88000, 88400, 88800, 89200, 89600, 90000, 90400, 90800, 91200, 91600, 92000, 92400, 92800, 93200, 93600, 94000, 94400, 94800, 95200, 95600, 96000, 96400, 96800, 97200, 97600, 98000, 98400, 98800, 99200, 99600, 100000, 100400, 100800, 101200, 101600, 102000, 102400, 102800, 
103200, 103600, 104000, 104400, 104800, 105200, 105600, 106000, 106400, 106800, 107200, 107600, 108000, 108400, 108800, 109200, 109600, 110000, 110400, 110800, 111200, 111600, 112000, 112400, 112800, 113200, 113600, 114000, 114400, 114800, 115200, 115600, 116000, 116400, 116800, 117200, 117600, 118000, 118400, 118800, 119200, 119600, 120000, 120400, 120800, 121200, 121600, 122000, 122400, 122800, 123200, 123600, 124000, 124400, 124800, 125200, 125600, 126000, 126400, 126800, 127200, 127600, 128000, 128400, 128800, 129200, 129600, 130000, 130400, 130800, 131200, 131600, 132000], "global_num_gpus": 2} \ No newline at end of file diff --git a/train.py b/train.py index 17f6ef91c..2e4b09954 100644 --- a/train.py +++ b/train.py @@ -19,9 +19,6 @@ from megatron.neox_arguments import NeoXArgs from megatron.training import pretrain -import os -import numpy as np - def main(input_args=None, overwrite_values=None): neox_args = NeoXArgs.consume_neox_args( @@ -29,43 +26,6 @@ def main(input_args=None, overwrite_values=None): ) neox_args.configure_distributed_args() neox_args.build_tokenizer() # tokenizer needs to be build in training in order to set the padding vocab - - if neox_args.load.split('/')[-1].startswith('JOB'): - if 'scratch' in neox_args.load: - training_mode = 'scratch' - elif 'finetune' in neox_args.load: - training_mode = 'finetune' - else: - training_mode = 'resume' - - elif neox_args.load == 'none': - training_mode = 'scratch' - elif neox_args.finetune: - training_mode = 'finetune' - else: - training_mode = 'resume' - - dir_str = "JOB-{}_{}_it-{}_wu-{}_mxlr-{}_mnlr-{}_sch-{}_tr-{}_{}".format( - "ENTER_YOUR_JOBID_IN_TRAIN.PY",# os.environ['LSB_JOBID'], - neox_args.identifier_string.replace('_',"-"), - neox_args.train_iters, - neox_args.warmup, - neox_args.optimizer['params']['lr'], - neox_args.min_lr, - neox_args.lr_decay_style, - neox_args.train_dataset_name.replace('_',"-"), - training_mode) - - - - neox_args.tensorboard_dir = os.path.join(neox_args.tensorboard_dir, dir_str) - neox_args.save = os.path.join(neox_args.save, dir_str) - print("NEOX ARGS tensorboard_dir: ", neox_args.tensorboard_dir) - print("NEOX ARGS save: ", neox_args.save) - # exit(0) - - - neox_args.initialize_tensorboard_writer() # is initialized if tensorboard directory is defined pretrain(neox_args=neox_args) From 96226421ced04e8625ded2b4e8119bd14a4178a6 Mon Sep 17 00:00:00 2001 From: github-actions Date: Sun, 14 Apr 2024 20:40:57 +0000 Subject: [PATCH 8/8] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 34 +--------------------------------- 1 file changed, 1 insertion(+), 33 deletions(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 37b16455e..0ab7a2e5c 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 2d33aaa + Default = 621ab25 current git hash of repository @@ -762,22 +762,6 @@ Model Arguments -- **identifier_string**: str - - Default = - - an identifier for the model, used for saving checkpoints,logging, etc. 
- - - -- **warup_eval_interval**: int - - Default = 50 - - the evaluation interval to use during warmup - - - ## NeoXArgsOptimizer Optimizer Arguments @@ -1987,22 +1971,6 @@ Training Arguments -- **train_dataset_name**: str - - Default = no_train_dataset_name_given - - An identified for the training dataset used for logging - - - -- **val_dataset_name**: str - - Default = no_val_dataset_name_given - - An identified for the training dataset used for logging - - - ## NeoXArgsDeepspeedConfig Args for deepspeed config