From e059614f0cd8b73b139c134ae6ecffce74965835 Mon Sep 17 00:00:00 2001 From: John Kamalu Date: Sat, 7 Dec 2024 19:54:01 -0800 Subject: [PATCH] ADLR/megatron-lm!2381 - sample index helper function, no unnecessary memory allocation, no unnecessary casting/copying Co-authored-by: Mcore Bot --- megatron/core/datasets/Makefile | 10 +- megatron/core/datasets/gpt_dataset.py | 62 +++++++--- megatron/core/datasets/helpers.cpp | 107 ++++++++++-------- megatron/core/datasets/helpers.py | 64 +++++++++++ setup.py | 2 +- tests/unit_tests/data/test_builder.py | 4 +- tests/unit_tests/data/test_gpt_dataset.py | 4 +- .../data/test_multimodal_dataset.py | 2 +- 8 files changed, 183 insertions(+), 72 deletions(-) create mode 100644 megatron/core/datasets/helpers.py diff --git a/megatron/core/datasets/Makefile b/megatron/core/datasets/Makefile index 8f9db76866..e745f52399 100644 --- a/megatron/core/datasets/Makefile +++ b/megatron/core/datasets/Makefile @@ -1,9 +1,13 @@ CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color CPPFLAGS += $(shell python3 -m pybind11 --includes) -LIBNAME = helpers + +LIBNAME = helpers_cpp LIBEXT = $(shell python3-config --extension-suffix) -default: $(LIBNAME)$(LIBEXT) +OUT = $(LIBNAME)$(LIBEXT) +SRC = helpers.cpp + +default: $(OUT) -%$(LIBEXT): %.cpp +$(OUT): $(SRC) $(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@ diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index 115727de92..2eb7702b54 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -72,7 +72,8 @@ class GPTDataset(MegatronDataset): indexed_indices (numpy.ndarray): The set of the documents indices to expose - num_samples (Optional[int]): The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch. + num_samples (Optional[int]): The number of samples to draw from the indexed dataset. When + None, build as many samples as correspond to one epoch. index_split (Split): The indexed_indices Split @@ -318,7 +319,8 @@ def _build_document_sample_shuffle_indices( -- A random permutation of index range of the sample index Returns: - Tuple[numpy.ndarray, numpy.ndarray]: The document index, the sample index, and the shuffle index + Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: The document index, the sample + index, and the shuffle index """ path_to_cache = self.config.path_to_cache if path_to_cache is None and not self.config.mock: @@ -327,10 +329,8 @@ def _build_document_sample_shuffle_indices( ) if path_to_cache: - get_path_to = lambda suffix: os.path.join( - path_to_cache, - f"{self.unique_description_hash}-{type(self).__name__}-{self.index_split.name}-{suffix}", - ) + base = f"{self.unique_description_hash}-{type(self).__name__}-{self.index_split.name}" + get_path_to = lambda affix: os.path.join(path_to_cache, f"{base}-{affix}") path_to_description = get_path_to("description.txt") path_to_document_index = get_path_to("document_index.npy") path_to_sample_index = get_path_to("sample_index.npy") @@ -427,11 +427,13 @@ def _build_document_sample_shuffle_indices( assert document_index.dtype == numpy.int32 assert self.dataset.sequence_lengths.dtype == numpy.int32 if len(document_index) * 2 > len(self.dataset.sequence_lengths): - # Heuristic: if "access density" of sequence_lengths is relatively high, - # force loading the mmap-ed array into memory by taking a copy. + # If "access density" of sequence_lengths is high, force load the mmap-ed array + # into memory by making a copy. 
+                #
                 # System performance benefits come from two aspects:
-                # 1. **sequentially** pre-loading the whole file if we're gonna read a large fraction anyways.
-                # 2. GIL is held when calling into c++ code; making the c++ func faster improves parallelism.
+                # 1. We sequentially pre-load the whole file, most of which we expect to read
+                # 2. The GIL is held while the C++ code runs, so making the C++ function faster
+                #    improves parallelism
                 sequence_lengths_for_cpp = self.dataset.sequence_lengths.copy()
             else:
                 sequence_lengths_for_cpp = self.dataset.sequence_lengths
@@ -467,7 +469,7 @@ def _build_document_sample_shuffle_indices(
                 log_single_rank(
                     logger,
                     logging.WARNING,
-                    f"Unable to save the {type(self).__name__} indexes because path_to_cache is None",
+                    f"Unable to save {type(self).__name__} indexes because path_to_cache is None",
                 )

         t_end = time.time()
@@ -592,7 +594,8 @@ def _build_shuffle_index(
     Args:
         num_samples (int): The size of the first shuffle range [0, num_samples)

-        total_size (int): The size of the entire index. If larger than 'num_samples', it defines the second shuffle range [num_samples, total_size)
+        total_size (int): The size of the entire index. If larger than 'num_samples', it defines
+            the second shuffle range [num_samples, total_size)

         numpy_random_state (numpy.random.RandomState): The NumPy random state

@@ -635,7 +638,8 @@ def _get_ltor_masks_and_position_ids(

         eod_mask_loss (bool): Switch to enable the EOD mask loss

-        create_attention_mask (bool): Switch to enable the attention masks generation. Can be disabled if attention kernel generates masks by itself.
+        create_attention_mask (bool): Switch to enable attention mask generation. Can be
+            disabled if the attention kernel generates masks itself.

     Returns:
         torch.Tensor: Attention mask needed to be used for Attention
@@ -691,10 +695,24 @@ def _get_ltor_masks_and_position_ids(


 class MockGPTLowLevelDataset:
+    """The mock GPT low level dataset
+
+    This class is meant to generate tokenized data in the classic "Megatron-LM" GPT style. Notably,
+    we add the end of document token to each element indexed in __getitem__.
+
+    Args:
+        tokenizer (MegatronTokenizer): The tokenizer, whose special token information we use to
+            augment the mock data.
+    """
+
     seed: int = 0
+    """The hard-coded random seed to use to set the NumPy RNG"""
+
     size: int = 100000
+    """The hard-coded number of samples to generate"""
+
     max_sequence_length: int = 4096
+    """The hard-coded max sequence length to generate"""

     def __init__(self, tokenizer: MegatronTokenizer) -> None:
         self.tokenizer = tokenizer
@@ -714,6 +732,18 @@ def __getitem__(self, idx: int) -> numpy.number:
         return sample

     def get(self, idx: int, offset: int = 0, length: Optional[int] = None) -> numpy.ndarray:
+        """This function is an abstraction over __getitem__ with support for slicing
+
+        Args:
+            idx (int): The index into the dataset
+
+            offset (int): The integer token offset in the sequence
+
+            length (Optional[int]): The number of tokens to grab from the sequence
+
+        Returns:
+            numpy.ndarray: The sequence tokens at the index
+        """
         if length is None:
             length = self.sequence_lengths[idx] - offset
         return self[idx][offset : offset + length]
@@ -723,7 +753,8 @@ class MockGPTDataset(GPTDataset):
     """The mock GPT dataset

     Args:
-        indexed_dataset (MockGPTLowLevelDataset): The MockGPTLowLevelDataset around which to build the MockGPTDataset
+        indexed_dataset (MockGPTLowLevelDataset): The MockGPTLowLevelDataset around which to build
+            the MockGPTDataset

         dataset_path (Optional[str]): This argument is of no consequence for the MockGPTDataset

@@ -768,7 +799,8 @@ def build_low_level_dataset(
         """Abstract method implementation

         Args:
-            dataset_path (Optional[str]): This argument is of no consequence for the MockGPTLowLevelDataset
+            dataset_path (Optional[str]): This argument is of no consequence for the
+                MockGPTLowLevelDataset

             config (GPTDatasetConfig): The config

diff --git a/megatron/core/datasets/helpers.cpp b/megatron/core/datasets/helpers.cpp
index 0b05f09d7a..1a3e8448f3 100644
--- a/megatron/core/datasets/helpers.cpp
+++ b/megatron/core/datasets/helpers.cpp
@@ -139,19 +139,22 @@ void build_blending_indices(py::array_t<int16_t> &dataset_index,
     }
 }

-py::array build_sample_idx(const py::array_t<int32_t> &sizes_,
-                           const py::array_t<int32_t> &doc_idx_,
-                           const int32_t seq_length,
-                           const int32_t num_epochs,
-                           const int64_t tokens_per_epoch,
-                           const bool drop_last_partial_sequence = true,
-                           const int add_extra_token_to_sequence = 1)
-{
-    /* Sample index (sample_idx) is used for gpt2 like dataset for which
-       the documents are flattened and the samples are built based on this
-       1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2]
-       where [..., 0] contains the index into `doc_idx` and [..., 1] is the
-       starting offset in that document.*/
+template <typename T>
+py::array_t<T> build_sample_idx(
+    const py::array_t<int32_t> &sizes_,
+    const py::array_t<int32_t> &document_idx_,
+    const int32_t seq_length,
+    const int32_t num_epochs,
+    const int64_t tokens_per_epoch,
+    const bool drop_last_partial_sequence = true,
+    const int add_extra_token_to_sequence = 1
+){
+    /*
+    Sample index (sample_idx) is used for gpt2 like dataset for which the documents are flattened
+    and the samples are built based on this 1-D flattened array. It is a 2D array with sizes
+    [number-of-samples + 1, 2] where [..., 0] contains the index into `doc_idx` and [..., 1] is
+    the starting offset in that document.
+    */

     // Consistency checks.
     assert(seq_length > 1);
@@ -160,83 +163,86 @@ py::array build_sample_idx(const py::array_t<int32_t> &sizes_,

     // Remove bound checks.
     auto sizes = sizes_.unchecked<1>();
-    auto doc_idx = doc_idx_.unchecked<1>();
+    auto document_idx = document_idx_.unchecked<1>();

-    // Mapping and it's length (1D).
+    // Build the sample idx as a contiguous 1-D array of type T.
     int64_t num_samples = 0;
-    if (drop_last_partial_sequence == true)
-    {
+    if (drop_last_partial_sequence == true) {
         num_samples = (num_epochs * tokens_per_epoch - add_extra_token_to_sequence) / seq_length;
     }
-    else
-    {
+    else {
         num_samples = ceil(float(num_epochs * tokens_per_epoch - add_extra_token_to_sequence) / seq_length);
     }
-    int64_t *sample_idx = new int64_t[2 * (num_samples + 1)];
+    T *sample_idx = new T[2 * (num_samples + 1)];

     // Index into sample_idx.
-    int64_t sample_index = 0;
-    // Index into doc_idx.
-    int64_t doc_idx_index = 0;
+    int64_t sample_idx_index = 0;
+    // Index into document_idx.
+    T document_idx_index = 0;
     // Begining offset for each document.
-    int32_t doc_offset = 0;
+    T doc_offset = 0;

     // Start with first document and no offset.
-    sample_idx[2 * sample_index] = doc_idx_index;
-    sample_idx[2 * sample_index + 1] = doc_offset;
-    ++sample_index;
+    sample_idx[2 * sample_idx_index] = document_idx_index;
+    sample_idx[2 * sample_idx_index + 1] = doc_offset;
+    ++sample_idx_index;

-    while (sample_index <= num_samples)
+    while (sample_idx_index <= num_samples)
     {
         // Start with a fresh sequence.
         int32_t remaining_seq_length = seq_length + add_extra_token_to_sequence;
         while (remaining_seq_length != 0)
         {
             // Get the document length.
-            auto doc_id = doc_idx[doc_idx_index];
-            auto doc_length = sizes[doc_id] - doc_offset;
+            auto document_index = document_idx[document_idx_index];
+            auto document_length = sizes[document_index] - doc_offset;
             // And add it to the current sequence.
-            remaining_seq_length -= doc_length;
+            remaining_seq_length -= document_length;
             // If we have more than a full sequence, adjust offset and set
             // remaining length to zero so we return from the while loop.
             // Note that -1 here is for the same reason we have -1 in
             // `_num_epochs` calculations.
             if (remaining_seq_length <= 0)
             {
-                doc_offset += (remaining_seq_length + doc_length - add_extra_token_to_sequence);
+                doc_offset += (remaining_seq_length + document_length - add_extra_token_to_sequence);
                 remaining_seq_length = 0;
             }
             else
             {
                 // Otherwise, start from the begining of the next document.
-                if (doc_idx_index == (doc_idx_.shape(0) - 1))
+                if (document_idx_index == (document_idx_.shape(0) - 1))
                 {
                     // If we have reached the end of the documents, break.
-                    assert(sample_index == num_samples);
-                    doc_offset = sizes[doc_idx[doc_idx_index]] - add_extra_token_to_sequence;
+                    assert(sample_idx_index == num_samples);
+                    doc_offset = sizes[document_idx[document_idx_index]] - add_extra_token_to_sequence;
                     break;
                 }
-                ++doc_idx_index;
+                ++document_idx_index;
                 doc_offset = 0;
             }
         }
         // Record the sequence.
-        sample_idx[2 * sample_index] = doc_idx_index;
-        sample_idx[2 * sample_index + 1] = doc_offset;
-        ++sample_index;
+        sample_idx[2 * sample_idx_index] = document_idx_index;
+        sample_idx[2 * sample_idx_index + 1] = doc_offset;
+        ++sample_idx_index;
     }

     // Method to deallocate memory.
-    py::capsule free_when_done(sample_idx, [](void *mem_)
-                               {
-                                   int64_t *mem = reinterpret_cast<int64_t *>(mem_);
-                                   delete[] mem; });
+    py::capsule free_when_done(
+        sample_idx,
+        [](void *mem_){
+            T *mem = reinterpret_cast<T *>(mem_);
+            delete[] mem;
+        }
+    );

     // Return the numpy array.
-    const auto byte_size = sizeof(int64_t);
-    return py::array(std::vector<int64_t>{num_samples + 1, 2}, // shape
-                     {2 * byte_size, byte_size},               // C-style contiguous strides
-                     sample_idx,                               // the data pointer
-                     free_when_done);                          // numpy array references
+    const auto byte_size = sizeof(T);
+    return py::array_t<T>(
+        std::vector<int64_t>{num_samples + 1, 2}, // shape
+        {2 * byte_size, byte_size},               // C-style contiguous strides
+        sample_idx,                               // the data pointer
+        free_when_done                            // numpy array references
+    );
 }

 inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
@@ -829,11 +835,12 @@ py::array build_blocks_mapping(const py::array_t<int64_t> &docs_,
     }
 }

-PYBIND11_MODULE(helpers, m)
+PYBIND11_MODULE(helpers_cpp, m)
 {
     m.def("build_mapping", &build_mapping);
     m.def("build_blocks_mapping", &build_blocks_mapping);
-    m.def("build_sample_idx", &build_sample_idx);
+    m.def("build_sample_idx_int32", &build_sample_idx<int32_t>);
+    m.def("build_sample_idx_int64", &build_sample_idx<int64_t>);
     m.def("build_blending_indices", &build_blending_indices);
     m.def("build_exhaustive_blending_indices", &build_exhaustive_blending_indices);
 }
diff --git a/megatron/core/datasets/helpers.py b/megatron/core/datasets/helpers.py
new file mode 100644
index 0000000000..9978a6050a
--- /dev/null
+++ b/megatron/core/datasets/helpers.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+
+import numpy
+
+# Implicit imports for backwards compatibility
+# Explicit imports for readability
+from megatron.core.datasets.helpers_cpp import *
+from megatron.core.datasets.helpers_cpp import build_sample_idx_int32, build_sample_idx_int64
+
+
+def build_sample_idx(
+    sizes: numpy.ndarray,
+    document_indices: numpy.ndarray,
+    sequence_length: int,
+    num_epochs: int,
+    tokens_per_epoch: int,
+    drop_last_partial_sequence: bool = True,
+    add_extra_token_to_sequence: bool = True,
+):
+    """Build the 2-D sample index using the properly typed C++ template function from helpers.cpp
+
+    Args:
+        sizes (numpy.ndarray): The 1-D array of document lengths
+
+        document_indices (numpy.ndarray): The 1-D array of document indices
+
+        sequence_length (int): The sequence length
+
+        num_epochs (int): The number of epochs
+
+        tokens_per_epoch (int): The number of tokens per epoch
+
+        drop_last_partial_sequence (bool): Whether to omit the last partial sequence in the sample
+            index, should one exist. Defaults to True.
+
+        add_extra_token_to_sequence (bool): Whether to build samples with sequence length
+            `sequence_length + 1`. Defaults to True.
+ + Returns: + numpy.ndarray: The 2-D sample index + """ + sample_idx_max = max(document_indices.shape[0], sizes.max()) + if sample_idx_max <= numpy.iinfo(numpy.int32).max: + sample_idx = build_sample_idx_int32( + sizes, + document_indices, + sequence_length, + num_epochs, + tokens_per_epoch, + drop_last_partial_sequence, + 1 if add_extra_token_to_sequence else 0, + ) + assert sample_idx.min() >= 0 and sample_idx.max() <= sample_idx_max + else: + sample_idx = build_sample_idx_int64( + sizes, + document_indices, + sequence_length, + num_epochs, + tokens_per_epoch, + drop_last_partial_sequence, + 1 if add_extra_token_to_sequence else 0, + ) + return sample_idx diff --git a/setup.py b/setup.py index 73f20775a7..756348beef 100644 --- a/setup.py +++ b/setup.py @@ -102,7 +102,7 @@ def req_file(filename, folder="requirements"): packages=setuptools.find_namespace_packages(include=["megatron.core", "megatron.core.*"]), ext_modules=[ Extension( - "megatron.core.datasets.helpers", + "megatron.core.datasets.helpers_cpp", sources=["megatron/core/datasets/helpers.cpp"], language="c++", extra_compile_args=( diff --git a/tests/unit_tests/data/test_builder.py b/tests/unit_tests/data/test_builder.py index 7f4caaa0f6..221eb4aabe 100644 --- a/tests/unit_tests/data/test_builder.py +++ b/tests/unit_tests/data/test_builder.py @@ -1,5 +1,7 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + ## -# Compile megatron.core.datasets.helpers dependencies before BlendedDataset import +# Compile megatron.core.datasets.helpers_cpp dependencies before BlendedDataset import ## import os diff --git a/tests/unit_tests/data/test_gpt_dataset.py b/tests/unit_tests/data/test_gpt_dataset.py index 42a8532b73..cc87c0f4be 100644 --- a/tests/unit_tests/data/test_gpt_dataset.py +++ b/tests/unit_tests/data/test_gpt_dataset.py @@ -1,5 +1,7 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + ## -# Compile megatron.core.datasets.helpers dependencies before BlendedDataset import +# Compile megatron.core.datasets.helpers_cpp dependencies before BlendedDataset import ## import random diff --git a/tests/unit_tests/data/test_multimodal_dataset.py b/tests/unit_tests/data/test_multimodal_dataset.py index a9a30c02ec..12f0f45eb5 100644 --- a/tests/unit_tests/data/test_multimodal_dataset.py +++ b/tests/unit_tests/data/test_multimodal_dataset.py @@ -1,7 +1,7 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. ## -# Compile megatron.core.datasets.helpers dependencies before BlendedDataset import +# Compile megatron.core.datasets.helpers_cpp dependencies before BlendedDataset import ## from types import SimpleNamespace
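
To make the layout of the sample index concrete: row i of the (num_samples + 1, 2) array records the position in the document index and the token offset at which sample i begins, so sample i is delimited by rows i and i + 1. Below is a minimal pure-NumPy sketch of the same bookkeeping, assuming drop_last_partial_sequence=True and ignoring the end-of-documents edge case; the function name is invented for illustration and this is not the shipped implementation.

import numpy


def build_sample_idx_reference(
    sizes, document_idx, seq_length, num_epochs, tokens_per_epoch, add_extra_token=1
):
    # Row 0 is always (0, 0): the first sample starts at the first document, offset 0.
    num_samples = (num_epochs * tokens_per_epoch - add_extra_token) // seq_length
    sample_idx = numpy.zeros((num_samples + 1, 2), dtype=numpy.int64)
    doc_pos, doc_offset = 0, 0
    for i in range(1, num_samples + 1):
        remaining = seq_length + add_extra_token
        while remaining > 0:
            doc_length = sizes[document_idx[doc_pos]] - doc_offset
            remaining -= doc_length
            if remaining <= 0:
                # The sample ends inside this document; note where the next sample starts.
                doc_offset += remaining + doc_length - add_extra_token
                break
            # Otherwise continue from the beginning of the next document.
            doc_pos += 1
            doc_offset = 0
        sample_idx[i] = (doc_pos, doc_offset)
    return sample_idx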
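
Usage sketch for the Python-level wrapper, assuming the helpers_cpp extension has been compiled (for example via the Makefile above or pip install -e .); the toy document sizes are invented for illustration. Every value stored in the sample index is bounded by max(len(document_indices), sizes.max()), so the wrapper dispatches to the int32 specialization whenever that bound fits in an int32, halving the memory footprint of the returned index.

import numpy
from megatron.core.datasets.helpers import build_sample_idx

# Four documents with 5, 3, 7, and 4 tokens, visited once and in order.
sizes = numpy.array([5, 3, 7, 4], dtype=numpy.int32)
document_indices = numpy.arange(len(sizes), dtype=numpy.int32)

sample_index = build_sample_idx(
    sizes=sizes,
    document_indices=document_indices,
    sequence_length=8,
    num_epochs=1,
    tokens_per_epoch=int(sizes.sum()),
    drop_last_partial_sequence=True,
    add_extra_token_to_sequence=True,
)

# (num_samples + 1, 2) rows of (document index position, token offset); int32 here
# because the small bound above fits comfortably within numpy.iinfo(numpy.int32).max.
print(sample_index.shape, sample_index.dtype)  # (3, 2) int32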
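
For context on the shuffle index described in the _build_shuffle_index docstring above: it concatenates two independent permutations, one over [0, num_samples) and, when total_size exceeds num_samples (for example when a separate final epoch appends extra samples), one over [num_samples, total_size), so the appended samples only shuffle among themselves. A rough sketch of that two-range behavior, with dtype selection and other details omitted; the function name is invented for illustration.

import numpy


def build_shuffle_index_sketch(num_samples, total_size, rng):
    # Permute the first range [0, num_samples).
    first = numpy.arange(num_samples, dtype=numpy.int64)
    rng.shuffle(first)
    if total_size == num_samples:
        return first
    # Permute the second range [num_samples, total_size) separately and append it.
    second = numpy.arange(num_samples, total_size, dtype=numpy.int64)
    rng.shuffle(second)
    return numpy.concatenate((first, second))


rng = numpy.random.RandomState(1234)
print(build_shuffle_index_sketch(5, 8, rng))  # a permutation of 0..4 followed by one of 5..7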