
Commit

Merge branch 'return-type-sample-idx' into 'main'
sample index helper function, no unnecessary memory allocation, no unnecessary casting/copying

See merge request ADLR/megatron-lm!2381
jaredcasper committed Dec 8, 2024
2 parents 9dc7fef + e059614 commit 9665f2d
Showing 8 changed files with 183 additions and 72 deletions.
10 changes: 7 additions & 3 deletions megatron/core/datasets/Makefile
@@ -1,9 +1,13 @@
CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color
CPPFLAGS += $(shell python3 -m pybind11 --includes)
LIBNAME = helpers

LIBNAME = helpers_cpp
LIBEXT = $(shell python3-config --extension-suffix)

default: $(LIBNAME)$(LIBEXT)
OUT = $(LIBNAME)$(LIBEXT)
SRC = helpers.cpp

default: $(OUT)

%$(LIBEXT): %.cpp
$(OUT): $(SRC)
$(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@
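
The Makefile now builds the pybind11 extension under the new name helpers_cpp, using whatever suffix the local interpreter reports. As a quick sanity check (not part of the commit), the value of `python3-config --extension-suffix` is the same string Python exposes as the EXT_SUFFIX config variable:

import sysconfig

# Prints something like ".cpython-311-x86_64-linux-gnu.so"; the Makefile's OUT target
# is then helpers_cpp followed by this suffix.
print(sysconfig.get_config_var("EXT_SUFFIX"))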
62 changes: 47 additions & 15 deletions megatron/core/datasets/gpt_dataset.py
@@ -72,7 +72,8 @@ class GPTDataset(MegatronDataset):
indexed_indices (numpy.ndarray): The set of the documents indices to expose
num_samples (Optional[int]): The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch.
num_samples (Optional[int]): The number of samples to draw from the indexed dataset. When
None, build as many samples as correspond to one epoch.
index_split (Split): The indexed_indices Split
@@ -318,7 +319,8 @@ def _build_document_sample_shuffle_indices(
-- A random permutation of index range of the sample index
Returns:
Tuple[numpy.ndarray, numpy.ndarray]: The document index, the sample index, and the shuffle index
Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: The document index, the sample
index, and the shuffle index
"""
path_to_cache = self.config.path_to_cache
if path_to_cache is None and not self.config.mock:
@@ -327,10 +329,8 @@
)

if path_to_cache:
get_path_to = lambda suffix: os.path.join(
path_to_cache,
f"{self.unique_description_hash}-{type(self).__name__}-{self.index_split.name}-{suffix}",
)
base = f"{self.unique_description_hash}-{type(self).__name__}-{self.index_split.name}"
get_path_to = lambda affix: os.path.join(path_to_cache, f"{base}-{affix}")
path_to_description = get_path_to("description.txt")
path_to_document_index = get_path_to("document_index.npy")
path_to_sample_index = get_path_to("sample_index.npy")
@@ -427,11 +427,13 @@ def _build_document_sample_shuffle_indices(
assert document_index.dtype == numpy.int32
assert self.dataset.sequence_lengths.dtype == numpy.int32
if len(document_index) * 2 > len(self.dataset.sequence_lengths):
# Heuristic: if "access density" of sequence_lengths is relatively high,
# force loading the mmap-ed array into memory by taking a copy.
# If "access density" of sequence_lengths is high, force load the mmap-ed array
# into memory by making a copy.
#
# System performance benefits come from two aspects:
# 1. **sequentially** pre-loading the whole file if we're gonna read a large fraction anyways.
# 2. GIL is held when calling into c++ code; making the c++ func faster improves parallelism.
# 1. We sequentially pre-load the whole file, most of which we expect to read
# 2. The GIL is held while execution is inside the C++ code, so making that call
#    faster improves parallelism
sequence_lengths_for_cpp = self.dataset.sequence_lengths.copy()
else:
sequence_lengths_for_cpp = self.dataset.sequence_lengths
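
A minimal Python sketch (not the repository code) of the access-density heuristic in the hunk above, assuming sequence_lengths is a numpy memmap:

import numpy

def maybe_materialize(sequence_lengths: numpy.ndarray, document_index: numpy.ndarray) -> numpy.ndarray:
    # If the document index will touch a large fraction of the entries, copy the
    # mmap-ed array into memory: one sequential read up front, and a faster C++ call
    # while the GIL is held.
    if len(document_index) * 2 > len(sequence_lengths):
        return sequence_lengths.copy()
    return sequence_lengths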
@@ -467,7 +469,7 @@ def _build_document_sample_shuffle_indices(
log_single_rank(
logger,
logging.WARNING,
f"Unable to save the {type(self).__name__} indexes because path_to_cache is None",
f"Unable to save {type(self).__name__} indexes because path_to_cache is None",
)

t_end = time.time()
@@ -592,7 +594,8 @@ def _build_shuffle_index(
Args:
num_samples (int): The size of the first shuffle range [0, num_samples)
total_size (int): The size of the entire index. If larger than 'num_samples', it defines the second shuffle range [num_samples, total_size)
total_size (int): The size of the entire index. If larger than 'num_samples', it defines
the second shuffle range [num_samples, total_size)
numpy_random_state (numpy.random.RandomState): The NumPy random state
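
A sketch of the two-range shuffle the docstring describes, using int64 indices for simplicity (the repository's dtype selection may differ):

import numpy

def build_shuffle_index_sketch(num_samples: int, total_size: int,
                               rng: numpy.random.RandomState) -> numpy.ndarray:
    # Permute [0, num_samples) and, if present, [num_samples, total_size) separately,
    # then concatenate the two permutations.
    first = numpy.arange(num_samples, dtype=numpy.int64)
    rng.shuffle(first)
    if total_size == num_samples:
        return first
    second = numpy.arange(num_samples, total_size, dtype=numpy.int64)
    rng.shuffle(second)
    return numpy.concatenate((first, second))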
@@ -635,7 +638,8 @@ def _get_ltor_masks_and_position_ids(
eod_mask_loss (bool): Switch to enable the EOD mask loss
create_attention_mask (bool): Switch to enable the attention masks generation. Can be disabled if attention kernel generates masks by itself.
create_attention_mask (bool): Switch to enable the attention masks generation. Can be
disabled if attention kernel generates masks by itself.
Returns:
torch.Tensor: Attention mask needed to be used for Attention
@@ -691,10 +695,24 @@ def _get_ltor_masks_and_position_ids(


class MockGPTLowLevelDataset:
"""The mock GPT low level dataset
This class is meant to generate tokenized data in the classic "Megatron-LM" GPT style. Notably,
we add the end of document token to each element indexed in __getitem__
Args:
tokenizer (MegatronTokenizer): The tokenizer whose special token information we use
to augment the mock data.
"""

seed: int = 0
"""The hard-coded random seed to use to set the NumPy RNG"""

size: int = 100000
"""The hard-coded number of samples to generate"""

max_sequence_length: int = 4096
"""The hard-coded max sequence length to generate"""

def __init__(self, tokenizer: MegatronTokenizer) -> None:
self.tokenizer = tokenizer
@@ -714,6 +732,18 @@ def __getitem__(self, idx: int) -> numpy.number:
return sample

def get(self, idx: int, offset: int = 0, length: Optional[int] = None) -> numpy.ndarray:
"""This function is n abstraction over __getitem__ with support for slicing
Args:
idx (int): The index into the dataset
offset (int): The integer token offset in the sequence
length (Optional[int]): The number of tokens to grab from the sequence
Returns:
numpy.ndarray: The sequence tokens at the index
"""
if length is None:
length = self.sequence_lengths[idx] - offset
return self[idx][offset : offset + length]
@@ -723,7 +753,8 @@ class MockGPTDataset(GPTDataset):
"""The mock GPT dataset
Args:
indexed_dataset (MockGPTLowLevelDataset): The MockGPTLowLevelDataset around which to build the MockGPTDataset
indexed_dataset (MockGPTLowLevelDataset): The MockGPTLowLevelDataset around which to build
the MockGPTDataset
dataset_path (Optional[str]): This argument is of no consequence for the MockGPTDataset
@@ -768,7 +799,8 @@ def build_low_level_dataset(
"""Abstract method implementation
Args:
dataset_path (Optional[str]): This argument is of no consequence for the MockGPTLowLevelDataset
dataset_path (Optional[str]): This argument is of no consequence for the
MockGPTLowLevelDataset
config (GPTDatasetConfig): The config
107 changes: 57 additions & 50 deletions megatron/core/datasets/helpers.cpp
@@ -139,19 +139,22 @@ void build_blending_indices(py::array_t<int16_t> &dataset_index,
}
}

py::array build_sample_idx(const py::array_t<int32_t> &sizes_,
const py::array_t<int32_t> &doc_idx_,
const int32_t seq_length,
const int32_t num_epochs,
const int64_t tokens_per_epoch,
const bool drop_last_partial_sequence = true,
const int add_extra_token_to_sequence = 1)
{
/* Sample index (sample_idx) is used for gpt2 like dataset for which
the documents are flattened and the samples are built based on this
1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2]
where [..., 0] contains the index into `doc_idx` and [..., 1] is the
starting offset in that document.*/
template <typename T>
py::array_t<T> build_sample_idx(
const py::array_t<int32_t> &sizes_,
const py::array_t<int32_t> &document_idx_,
const int32_t seq_length,
const int32_t num_epochs,
const int64_t tokens_per_epoch,
const bool drop_last_partial_sequence = true,
const int add_extra_token_to_sequence = 1
){
/*
Sample index (sample_idx) is used for GPT-2-style datasets in which the documents are
flattened and the samples are built from this 1-D flattened array. It is a 2D array with
sizes [number-of-samples + 1, 2] where [..., 0] contains the index into `document_idx`
and [..., 1] is the starting offset in that document.
*/

// Consistency checks.
assert(seq_length > 1);
@@ -160,83 +163,86 @@ py::array build_sample_idx(const py::array_t<int32_t> &sizes_,

// Remove bound checks.
auto sizes = sizes_.unchecked<1>();
auto doc_idx = doc_idx_.unchecked<1>();
auto document_idx = document_idx_.unchecked<1>();

// Mapping and it's length (1D).
// Build the sample idx as a contiguous 1-D array of type T.
int64_t num_samples = 0;
if (drop_last_partial_sequence == true)
{
if (drop_last_partial_sequence == true) {
num_samples = (num_epochs * tokens_per_epoch - add_extra_token_to_sequence) / seq_length;
}
else
{
else {
num_samples = ceil(float(num_epochs * tokens_per_epoch - add_extra_token_to_sequence) / seq_length);
}
int64_t *sample_idx = new int64_t[2 * (num_samples + 1)];
T *sample_idx = new T[2 * (num_samples + 1)];

// Index into sample_idx.
int64_t sample_index = 0;
// Index into doc_idx.
int64_t doc_idx_index = 0;
int64_t sample_idx_index = 0;
// Index into document_idx.
T document_idx_index = 0;
// Beginning offset for each document.
int32_t doc_offset = 0;
T doc_offset = 0;
// Start with first document and no offset.
sample_idx[2 * sample_index] = doc_idx_index;
sample_idx[2 * sample_index + 1] = doc_offset;
++sample_index;
sample_idx[2 * sample_idx_index] = document_idx_index;
sample_idx[2 * sample_idx_index + 1] = doc_offset;
++sample_idx_index;

while (sample_index <= num_samples)
while (sample_idx_index <= num_samples)
{
// Start with a fresh sequence.
int32_t remaining_seq_length = seq_length + add_extra_token_to_sequence;
while (remaining_seq_length != 0)
{
// Get the document length.
auto doc_id = doc_idx[doc_idx_index];
auto doc_length = sizes[doc_id] - doc_offset;
auto document_index = document_idx[document_idx_index];
auto document_length = sizes[document_index] - doc_offset;
// And add it to the current sequence.
remaining_seq_length -= doc_length;
remaining_seq_length -= document_length;
// If we have more than a full sequence, adjust offset and set
// remaining length to zero so we return from the while loop.
// Note that -1 here is for the same reason we have -1 in
// `_num_epochs` calculations.
if (remaining_seq_length <= 0)
{
doc_offset += (remaining_seq_length + doc_length - add_extra_token_to_sequence);
doc_offset += (remaining_seq_length + document_length - add_extra_token_to_sequence);
remaining_seq_length = 0;
}
else
{
// Otherwise, start from the beginning of the next document.
if (doc_idx_index == (doc_idx_.shape(0) - 1))
if (document_idx_index == (document_idx_.shape(0) - 1))
{
// If we have reached the end of the documents, break.
assert(sample_index == num_samples);
doc_offset = sizes[doc_idx[doc_idx_index]] - add_extra_token_to_sequence;
assert(sample_idx_index == num_samples);
doc_offset = sizes[document_idx[document_idx_index]] - add_extra_token_to_sequence;
break;
}
++doc_idx_index;
++document_idx_index;
doc_offset = 0;
}
}
// Record the sequence.
sample_idx[2 * sample_index] = doc_idx_index;
sample_idx[2 * sample_index + 1] = doc_offset;
++sample_index;
sample_idx[2 * sample_idx_index] = document_idx_index;
sample_idx[2 * sample_idx_index + 1] = doc_offset;
++sample_idx_index;
}

// Method to deallocate memory.
py::capsule free_when_done(sample_idx, [](void *mem_)
{
int64_t *mem = reinterpret_cast<int64_t*>(mem_);
delete[] mem; });
py::capsule free_when_done(
sample_idx,
[](void *mem_){
T *mem = reinterpret_cast<T*>(mem_);
delete[] mem;
}
);

// Return the numpy array.
const auto byte_size = sizeof(int64_t);
return py::array(std::vector<int64_t>{num_samples + 1, 2}, // shape
{2 * byte_size, byte_size}, // C-style contiguous strides
sample_idx, // the data pointer
free_when_done); // numpy array references
const auto byte_size = sizeof(T);
return py::array_t<T>(
std::vector<int64_t>{num_samples + 1, 2}, // shape
{2 * byte_size, byte_size}, // C-style contiguous strides
sample_idx, // the data pointer
free_when_done // numpy array references
);
}

inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
@@ -829,11 +835,12 @@ py::array build_blocks_mapping(const py::array_t<int64_t> &docs_,
}
}

PYBIND11_MODULE(helpers, m)
PYBIND11_MODULE(helpers_cpp, m)
{
m.def("build_mapping", &build_mapping);
m.def("build_blocks_mapping", &build_blocks_mapping);
m.def("build_sample_idx", &build_sample_idx);
m.def("build_sample_idx_int32", &build_sample_idx<int32_t>);
m.def("build_sample_idx_int64", &build_sample_idx<int64_t>);
m.def("build_blending_indices", &build_blending_indices);
m.def("build_exhaustive_blending_indices", &build_exhaustive_blending_indices);
}
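
For orientation, a pure-Python sketch (not the pybind11 code) of what the templated build_sample_idx computes for the drop_last_partial_sequence=True path, with the end-of-documents edge case omitted: row i of the [num_samples + 1, 2] result holds the position in document_idx and the token offset at which sample i begins.

import numpy

def build_sample_idx_reference(sizes, document_idx, seq_length, num_epochs,
                               tokens_per_epoch, add_extra_token_to_sequence=1):
    num_samples = (num_epochs * tokens_per_epoch - add_extra_token_to_sequence) // seq_length
    sample_idx = numpy.zeros((num_samples + 1, 2), dtype=numpy.int64)
    doc, offset = 0, 0
    sample_idx[0] = (doc, offset)
    for i in range(1, num_samples + 1):
        remaining = seq_length + add_extra_token_to_sequence
        while remaining > 0:
            doc_length = sizes[document_idx[doc]] - offset
            remaining -= doc_length
            if remaining <= 0:
                # The sample ends inside this document; back off by the extra token so
                # the next sample re-reads it, mirroring the C++ loop.
                offset += remaining + doc_length - add_extra_token_to_sequence
            else:
                # The sample continues into the next document.
                doc += 1
                offset = 0
        sample_idx[i] = (doc, offset)
    return sample_idx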
64 changes: 64 additions & 0 deletions megatron/core/datasets/helpers.py
@@ -0,0 +1,64 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import numpy

# Implicit imports for backwards compatibility
# Explicit imports for readability
from megatron.core.datasets.helpers_cpp import *
from megatron.core.datasets.helpers_cpp import build_sample_idx_int32, build_sample_idx_int64


def build_sample_idx(
sizes: numpy.ndarray,
document_indices: numpy.ndarray,
sequence_length: int,
num_epochs: int,
tokens_per_epoch: int,
drop_last_partial_sequence: bool = True,
add_extra_token_to_sequence: bool = True,
):
"""Build the 2-D sample index using the properly typed templated C++ function from helpers.cpp
Args:
sizes (numpy.ndarray): The 1-D array of document lengths
document_indices (numpy.ndarray): The 1-D array of document indices
sequence_length (int): The sequence length
num_epochs (int): The number of epochs
tokens_per_epoch (int): The number of tokens per epoch
drop_last_partial_sequence (bool): Whether to omit the last partial sequence in the sample
index should it exist. Defaults to True.
add_extra_token_to_sequence (bool): Whether to build samples with sequence length
`sequence_length + 1`. Defaults to True.
Returns:
numpy.ndarray: The 2-D sample index
"""
sample_idx_max = max(document_indices.shape[0], sizes.max())
if sample_idx_max <= numpy.iinfo(numpy.int32).max:
sample_idx = build_sample_idx_int32(
sizes,
document_indices,
sequence_length,
num_epochs,
tokens_per_epoch,
drop_last_partial_sequence,
1 if add_extra_token_to_sequence else 0,
)
assert sample_idx.min() >= 0 and sample_idx.max() <= sample_idx_max
else:
sample_idx = build_sample_idx_int64(
sizes,
document_indices,
sequence_length,
num_epochs,
tokens_per_epoch,
drop_last_partial_sequence,
1 if add_extra_token_to_sequence else 0,
)
return sample_idx
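
A hypothetical usage sketch of the new wrapper (assuming the compiled helpers_cpp extension is importable); the result dtype is int32 whenever every entry fits, and int64 otherwise:

import numpy
from megatron.core.datasets.helpers import build_sample_idx

sizes = numpy.array([5, 7, 3], dtype=numpy.int32)              # tokens per document
document_indices = numpy.array([0, 1, 2], dtype=numpy.int32)   # one epoch, in order
sample_idx = build_sample_idx(
    sizes, document_indices, sequence_length=4, num_epochs=1, tokens_per_epoch=15
)
print(sample_idx.shape)   # (num_samples + 1, 2)
print(sample_idx.dtype)   # int32, since max(len(document_indices), sizes.max()) fits in int32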
