[RFC] Polylithic: Enabling multi-threaded DataLoading through non-monolithic parallelism #1334

andrewkho opened this issue Oct 10, 2024 · 9 comments


andrewkho commented Oct 10, 2024

This RFC was re-created due to a problem with the original. Summary of comments from previous issue below.

🚀 The feature

TL;DR - We want to lean into modular multi-threading/multi-processing instead of the current monolithic multi-processing, and steer users away from monolithic Dataset parallelism towards composable DataSources and composable IterableDatasets for pre-proc operations, with parallelism configured within each operation. This will enable multi-threaded dataloading (with NoGIL support), auto-tunable parallelism, torch.compilable and GPU-enabled pre-proc operations, more efficient loading of mixed modalities, and composable dataloading and pre-proc graphs.

Motivation, pitch

Working name for the project: Polylithic (non-monolithic)
Where it will live: torchdata

Multimodal DataLoading is here and torch.utils.data doesn’t support it well

Multi-Modal LLMs are here. Tasks like fine-tuning, alignment, and distillation will require multi-modal dataloading for our users. LLM training often requires reading from 10s-100s of multi-modal datasets, tokenizing them, and packing them into a “token-buffer” where tokens from individual datasets are shuffled and combined into training examples for the model.

Audio, Image, and Video datasets may also require heavy-weight decoding operations to be performed before tokenization, and the difference in the data sizes between text, image, and video may be orders of magnitude. GPU decoding of images and video is an option for users as well, and libraries like Nvidia DALI will compile the entire pre-proc pipeline into GPU operations, minimizing the overhead of transfers between CPU and GPU memory.

torch.utils.data’s Dataset and DataLoader abstractions are extremely popular with users; however, they are not well equipped to handle multi-modal dataloading and accelerated pre-proc, because of the monolithic, black-box way in which they treat parallelism with multiprocessing; i.e. running GPU pre-proc under multiprocessing is not currently realistic. While the abstractions are extremely flexible and very easy to experiment with, users are often required to write bespoke classes to create pre-proc pipelines, handle data sharding, and combine multiple datasets. Optimizing performance is also a challenge because of the lack of control over parallelism.

Existing Context and definitions

torch.utils.data contains the following abstractions today:

  • Dataset (aka Map-Dataset): an interface that users subclass, which defines the `__getitem__(i) -> sample` method and is responsible for loading data and performing pre-proc.
    • Eg load encoded image i into memory, perform decoding, tensorification, cropping, and random rotations, and return the sample for training
  • Sampler: typically of type `Iterable[int]` that defines the iteration order over a Map-Dataset. Numerous built-in samplers exist which handle in-order iteration, shuffling, weighted sampling (eg if every sample has a sampling weight), and data sharding for distributed training.
  • IterableDataset: an interface that users subclass, which defines the `__iter__() -> Iterator[sample]` method and is responsible for loading data and performing pre-proc, but is also responsible for shuffling and data sharding for distributed training. IterableDatasets are not used with Samplers.
    • Eg for LLM training, holds iterators to 5 text datasets, performs weighted sampling between datasets, loads and tokenizes text, fills a token buffer and yields sets of tokens.
  • DataLoader/StatefulDataLoader: a class which takes either a) a Dataset + Sampler, or b) an IterableDataset, and may create multiple worker processes which each hold a copy of the Dataset/IterableDataset object instance. The DataLoader requests data from each individual worker (through either a Sampler-provided index or next()).
    • Multi-processing is currently the only available option provided by the DataLoader. The Python GIL prevents true thread-based parallelism, however the NoGIL PEP 703 is hoping to change that and enable true free-threaded parallelism in Python. Caveat: even with the GIL, there are likely use cases which would still benefit from multi-threading over multi-processing, though the pool of use-cases is probably smaller.
    • StatefulDataLoader is a drop-in replacement for DataLoader that has state_dict/load_state_dict methods.
# Example usage of torch.utils.data.DataLoader, Sampler, and Dataset, with multiprocess parallelism
dl = torch.utils.data.DataLoader(my_dataset, sampler=maybe_my_sampler, batch_size=batch_size, num_workers=multiprocessing_num_workers)
for batch in dl:
  ...  # model forward/backward

“Monolithic” parallelism

Currently users have a single lever to control parallelism: num_workers. When num_workers > 0, the DataLoader creates background processes, each of which holds a copy of the entire Dataset object in its own memory, treating it as a “monolithic” object to be parallelized.

Consider the scenario in the figure below, where a user has defined an iterable dataset which combines two text datasets and one image dataset. There is no parallelism in this example.
[Figure: an IterableDataset that combines two text datasets and one image dataset, with no parallelism]

Now consider the common case where only the image decoding and tokenization is a bottleneck causing GPU starvation. With today’s tooling, users simply set the DataLoader's num_workers > 1. The image below depicts how this is done today, by treating the entire IterableDataset as a monolith that is forked/spawned to another process.
[Figure: the entire IterableDataset treated as a monolith and copied into each DataLoader worker process]

Pain-points with Monolithic Parallelism for Multi-Modal LLM training

Multimodal data loading may require different levels of parallelism for different modalities, e.g. text tokenization may require only a single worker, while image decoding may benefit from 4+. The “monolithic” approach needlessly parallelizes operators that don’t need it, increasing memory and CPU utilization for things like token buffers. Tuning parallelism for performance is difficult because there is only one knob (num_workers) available.

Enabling GPU pre-proc pipelines (see Nvidia DALI) may improve total training throughput for many users; however, combining multiprocessing (eg to parallelize blob-fetching) and GPU pre-proc (eg for image decoding/cropping) in the same Dataset is not currently possible.

Tensor and Pipeline parallelism offer opportunities for more efficient and more resilient/correct dataloading, however the current torch.utils.data.DataLoader is not well equipped to take advantage of this.

As we gradually move to a NoGIL world and multi-threading becomes a viable method to parallelize, the current monolithic approach requires the entire Dataset (dataloading and preproc) and its dependencies to be thread-safe, which may cause problems with adoption.

We also suffer from the usual multi-processing pain points:

  • Multiprocessing is heavy-weight in memory and startup time, and introduces IPC overhead.
  • Sharding and shuffling IterableDatasets are very easy to get wrong
  • Large images and videos suffer from IPC overhead when passing through multiprocessing queues, and having many small tensors can incur significant ser/de overhead.
  • UDFs need to be picklable or serializable

A granular parallelism approach

To fix the monolithic parallelism problem, we want to introduce abstractions and tooling that expose more granular parallelism controls to users. This implies a solution where users construct their dataloading and pre-proc pipelines by defining and stitching together datasource and pre-proc nodes into a graph, in a similar fashion to tf.data and datapipes, with data passing between the nodes. The root of the graph is the node which produces batches that are passed to the model. The leaves are data-sources which produce data by reading from local disk, remote storage, or eg random number generators. Intermediate nodes may transform data, perform pre-fetching, combine data from multiple nodes, perform “enrichments” by eg fetching images from blob stores, perform decoding, schedule GPU operations etc.
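
As a rough, plain-Python illustration of this structure (not the proposed API; all names below are made up for the example), a graph of generators might be wired together like this, with leaves producing raw samples, intermediate nodes transforming and combining them, and the root yielding batches:

import itertools
import random

def text_source():                 # leaf node: a data source
  for i in itertools.count():
    yield {"text": f"sample {i}"}

def tokenize(source):              # intermediate node: per-sample transform
  for sample in source:
    yield {"tokens": sample["text"].split()}

def mix(sources, weights):         # intermediate node: weighted combination of upstream nodes
  iters = [iter(s) for s in sources]
  while True:
    yield next(random.choices(iters, weights=weights)[0])

def batcher(source, batch_size):   # root node: yields batches to the training loop
  it = iter(source)
  while True:
    yield [next(it) for _ in range(batch_size)]

root = batcher(mix([tokenize(text_source()), tokenize(text_source())], weights=[0.7, 0.3]), batch_size=4)
first_batch = next(iter(root))     # the trainer only ever talks to the root node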

Requirements and Constraints

To adequately support Multi Modal LLM training for PyTorch users, address the above pain points, and give us the best chance for wide-adoption, we want our solution to meet the following requirements and constraints:

  • Eager execution is the default behaviour. Ease of experimentation, flexibility, and debugging of experimental python code are critical to PyTorch’s success. Ensuring our solution has an “eager mode” which produces useful stack-traces will make developing and debugging easy for users.
  • Construct your graph with Python. Giving users the flexibility to write their pre-proc pipelines with a general purpose language will maximize experimentation and expressivity, lower barriers for entry, and match PyTorch conventions.
    • Example pseudo-code block: LLM training with 50 datasets, randomly sampling from datasets on each iteration
import itertools
import random
from typing import Iterable, List

class DatasetSampler:
  def __init__(self, sources: List[Iterable], sampling_weights: List[float]):
    self.sources = sources
    self.sampling_weights = sampling_weights

  def __iter__(self):
    base_iters = [itertools.cycle(src) for src in self.sources]
    n = len(base_iters)
    while True:
      # Sample a dataset index according to the per-dataset weights
      ds_idx = random.choices(range(n), weights=self.sampling_weights)[0]
      yield next(base_iters[ds_idx])
  • Backwards Compatibility with torch.utils.data
    • We want to minimize the number of new concepts/classes we introduce to users
    • We also want to provide an easy path to adoption by allowing users to reuse their existing Datasets (eg WebDS, Mosaic, HuggingFace, etc) as much as possible.
  • Support multi-process, multi-threaded, and NoGIL multi-threaded based parallelism at the node level
    • Some users may not want to move to multi-threading, may be stuck with GIL Python, or non-thread-safe code and libraries, where multi-processing gives the best performance.
    • With NoGIL, we will hopefully be in a world where Thread-based parallelism is a viable alternative to process-based parallelism, and we want to lean into this aspect as much as possible.
  • Enables GPU Pre-Proc pipelines to be defined, and enables compilability
    • Our solution should provide a path to torch.compile compatibility to enable GPU Pre-Proc pipelines
    • One potential solution is to have an “Accelerate” node that takes a sequence of operations and runs torch.compile on them.
    • We don’t require that every possible graph configuration (eg multiprocess parallelism) is torch.compilable, however the following example should be possible:
      • MultiThreaded reading -> torch.compile(GPU decoding -> Crop -> mirror) -> training loop
  • Support for in-order and out-of-order iteration, and support for random transform reproducibility
    • The current dataloader provides guarantees on ordering; we should continue to support this by default, as it’s an important requirement for many researchers.
    • We will provide the option to relax these constraints, which may improve throughput
    • Needs feedback: if it’s too difficult to guarantee eg random-transform reproducibility (which is more challenging in a multi-threaded environment), we might want to relax that constraint, but we should do everything we can to ensure reproducibility of iteration order.
    • Bottom line: reproducibility is an important variable to control for in experimentation, and we should do everything we can to ensure reproducibility.
  • Support for automatic tuning of workers
    • Introducing more granular parallelism controls creates a large dataloader-parameter space for users to optimize performance. Our solution should enable tf.data-style AUTOTUNE capabilities (see section 3.3) and provide good-enough results for most users.
    • Depending on how hairy this gets for multi-process workers, we might limit auto-tuning to multi-threaded workers only
  • Support for mid-epoch checkpointing and resuming
    • Training an epoch can take days or even weeks (or more) for some models and datasets. Mid-epoch checkpoint/resuming is essential to these types of workloads.
    • [caveat] We’ll need to think through how this would work for out-of-order execution.
  • Nodes will be iterable only, with no indexing support
    • Map-style datasets + samplers will be supported, but we won’t support index-based access between nodes
      • [what we won’t do] Datapipe’s MapDataPipe allowed users to pass indexes to the root of the graph and retrieve specific examples, however this requires two directions of communication between nodes, and also does not work at all for the more general case of eg sampling from multiple datasets
    • Users may still use Map Datasets + samplers by wrapping them into an Iterable which produces samples, with the indices from the sampler generated inside the iterable rather than passed in from outside.
    • Alternatively, a Sampler can be used as a source node and passed to a Mapper node which does something like “yield from (self.dataset[i] for i in self.source)” (see the sketch below this list)
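
To make the last bullet concrete, here is a minimal sketch of such a wrapper; the class name MapAndSample is an illustrative assumption, not part of the proposed API:

import torch.utils.data

# Minimal sketch (hypothetical name): drive a map-style Dataset with a Sampler from *inside*
# an iterable, so no index-based access needs to cross node boundaries.
class MapAndSample(torch.utils.data.IterableDataset):
  def __init__(self, dataset, sampler):
    self.dataset = dataset  # any map-style Dataset implementing __getitem__
    self.sampler = sampler  # any Iterable[int], e.g. a RandomSampler

  def __iter__(self):
    yield from (self.dataset[i] for i in self.sampler)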

How will we achieve this/what will we build? Plan of Record

We will introduce a new base class, working name PolylithicNode(torch.utils.data.IterableDataset). Nodes in the graph will be instances of subclasses of PolylithicNode. Nodes will define a .iterator() method instead of overriding __iter__(). This is inspired by nn.Module, where users define .forward() instead of __call__(). This will allow PolylithicNode to instantiate user-defined iterators and wrap them, insert queues for pipeline parallelism, and measure latency. For backwards compatibility, we’ll provide a wrapper which takes an existing IterableDataset. Users can compose their datasets by composing PolylithicNodes (i.e. through iter() and next()).
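
For illustration only, a minimal sketch of what this base class might look like; the details here are assumptions, not the committed design:

import torch.utils.data

# Minimal sketch (assumed, not final): the base class owns __iter__ and delegates to the
# user-defined .iterator(), so it can wrap the iterator with queues, instrumentation, etc.
class PolylithicNode(torch.utils.data.IterableDataset):
  def iterator(self):
    raise NotImplementedError  # subclasses implement this instead of overriding __iter__

  def __iter__(self):
    # Hook point: prefetch queues, latency measurement, and state tracking could be inserted here.
    yield from self.iterator()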

Example of composing iterable datasets to create a multimodal dataloader. [Note that we are open to ideas on syntactic sugar]

import json  # used by the example IterableDataset below

from torchdata.polylithic.nodes import PolylithicNode, Batcher, Mapper, MultiThreadedMapper, PinMemory, Prefetcher, AcceleratedMapper # Note that all of these classes subclass PolylithicNode

# Note: PolylithicNode is an abstract class which provides common code for state_dict, graph traversal,
#   autotuning, error propagation, etc.
# class PolylithicNode(torch.utils.data.IterableDataset): ...
#   def __iter__(self):  # PolylithicNode is still an IterableDataset
#     ...

# Some existing IterableDataset, perhaps generated through eg HuggingFace
class MyIterableDataset(torch.utils.data.IterableDataset):
  def __init__(self, json_l_file):
    self.json_l_file = json_l_file
  def __iter__(self):
    while True: # Loop forever
      with open(self.json_l_file, "r") as f:
        for line in f.readlines():
          yield json.loads(line)

# Define a Token Packer
class MyTokenPacker(PolylithicNode):
  def __init__(self, tokens_per_sample: int, sources: List[PolylithicNode], weights: List[float]):
    super().__init__()
    self.n = tokens_per_sample
    self.sources = sources
    self.weights = weights

  def iterator(self):
    self.source_iters = [iter(src) for src in self.sources]
    sample = []
    while True:
      while len(sample) < self.n:
        src_idx = weighted_sample_int(len(self.weights), self.weights)  # pick a source index by weight
        tokens = next(self.source_iters[src_idx])["tokens"]
        sample.extend(tokens)
      yield sample[:self.n]
      sample = sample[self.n:]

# Set up Tokenizer UDFs
def tokenize(data):
  data["tokens"] = Tokenizer()(data["text"])
  return data

def tokenize_img_and_text(data):
  data["tokens"] = DecodeAndTokenize()(data["image"]) + Tokenizer()(data["caption"])
  return data

# Set up text reader
text_src = PolylithicNode.from_iterable(MyIterableDataset("text_data.jsonl"))
text_src = MultiThreadedMapper(text_src, udf=tokenize, num_workers="AUTOTUNE")

# Set up Text and Image dataset, with GPU Decoding 
img_src = PolylithicNode.from_iterable(MyIterableDataset("img_caption_data.jsonl"))
img_src = Mapper(img_src, udf=GpuImageDecoder(...)) # single threaded in main process
img_src = MultiThreadedMapper(img_src, udf=tokenize_img_and_text, num_workers="AUTOTUNE")
# Rest of pipeline
node = MyTokenPacker(tokens_per_sample, [img_src, text_src], [0.25, 0.75])
node = Batcher(node, batch_size)
node = PinMemory(node)
node = Prefetcher(node, 2)

for tokens in node:
  ...

More complex diagram:
[Figure: a more complex dataloading graph]

  • We will define the PolylithicNode base class and tooling and utilities to support composing multiple PolylithicNodes into pipelines/preproc graphs (DAGs), with node-level parallelism controls.
  • As a programming model, users would chain together iterators similar to how nn.Modules are composed: by having a base class (i.e. PolylithicNode) whose dependencies are its member variables, which are also PolylithicNodes.
    • We can traverse the graph with reflection, inspecting instance fields to find ancestors in the graph (i.e. the current datapipes approach)
    • PolylithicNode is itself a subclass of IterableDataset
  • Users will be able to define their own nodes with Python
  • Define a constructor to wrap existing implementations of IterableDatasets into PolylithicNodes
  • Build out a library of useful nodes/operators that will provide users with the same functionality they expect, some examples:
    • MapToIterable operator that takes a map-style dataset + sampler to create an iterable PolylithicNode
    • Batcher
    • Pin Memory
    • ParallelMap operator which supports thread- or process-based parallelism for UDFs (a minimal sketch follows this list)
      • Create input/output queues, and workers
        • input_queue -> [worker, worker] -> output_queue
          • A single thread reads from the source node and puts data into the input queue
          • Workers put data on the output queue
          • __iter__() yields from the output queue
      • We could have a mode which disables parallelism and prefetching, eg:
        • with NoParallelism():
          • for batch in my_polylithic_dag: …
    • TorchCompile’d Map operator
      • We’d want to make sure this runs in the main process and single-threaded, and tensors passed in / created are on the same device
    • Prefetch operator
    • Caching capabilities (to memory or disk)
    • Broadcast node for TensorParallel consistency
  • Include tooling to traverse graph (see datapipe’s graph traversal method for an example)
    • [Needs feedback/investigation on feasibility] Provide graph-optimizations (eg node fusion) and pipeline parallelism (in dataloading graph)
  • torch.compile compatibility - Accelerated pre-proc pipelines will be supported through a special “Mapper” node which does things like ensuring data is on the correct device and that it is not running under multiprocessing
  • Add auto-tuning capabilities a la tf.data, and include tooling to measure throughput through each node to identify bottlenecks.
  • Include an “eager mode” that will disable parallelism and prefetching, to ease debugging and development
  • Deterministic iteration order by default, with out-of-order possible
  • Checkpointing support for iteration order with Deterministic ordering, and some limited support when iterating out-of-order
  • Provide a migration plan for users coming from datapipes who rely on composability
  • Backwards compatibility: folks can re-use existing Map and IterableDatasets, and Samplers
    • Existing IterableDatasets (or any iterable) can be wrapped/converted to a PolylithicNode
    • While we won’t support map-style access (e.g. MapDataPipe), users can control iteration order at the source, by combining Map + Sampler into a PolylithicNode.
    • Debugging will be more clunky, but can be achieved by passing a list of specified indexes as the sampler.
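
As referenced in the ParallelMap item above, here is a minimal thread-based sketch of the queue structure it describes; the class name and parameters are illustrative assumptions, not the final API:

import queue
import threading

import torch.utils.data

_DONE = object()  # sentinel: the source iterator is exhausted

class ThreadedParallelMap(torch.utils.data.IterableDataset):
  """Sketch of input_queue -> [worker, worker] -> output_queue; results may be out-of-order."""

  def __init__(self, source, udf, num_workers=2, prefetch=8):
    self.source = source
    self.udf = udf
    self.num_workers = num_workers
    self.prefetch = prefetch

  def __iter__(self):
    in_q = queue.Queue(maxsize=self.prefetch)
    out_q = queue.Queue(maxsize=self.prefetch)

    def feeder():
      # A single thread reads from the source node and puts data into the input queue
      for item in self.source:
        in_q.put(item)
      for _ in range(self.num_workers):
        in_q.put(_DONE)

    def worker():
      # Workers apply the UDF and put results on the output queue
      while (item := in_q.get()) is not _DONE:
        out_q.put(self.udf(item))
      out_q.put(_DONE)

    threads = [threading.Thread(target=feeder, daemon=True)]
    threads += [threading.Thread(target=worker, daemon=True) for _ in range(self.num_workers)]
    for t in threads:
      t.start()

    finished = 0
    while finished < self.num_workers:
      result = out_q.get()
      if result is _DONE:
        finished += 1
      else:
        yield result

A process-based variant would follow the same shape with multiprocessing queues and picklable UDFs; this sketch yields results out-of-order, while an in-order variant would need to tag items with sequence numbers and reorder them on the way out.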

What about DataPipes and DL v2?

DataPipes and DL v2 were designed to address issues like composability, and there is a lot of value in what they built; however, their parallelism and sharding structure is still based on a monolithic approach (eg plug a datapipe into DL v1, or DL v2 + a multiprocess reading service). They required migration/rewrite of datasets, often with no improvement in performance, identifying dataloading/pre-proc bottlenecks was a challenge, and shuffling/sharding pain points weren’t adequately addressed.

The proposed approach improves upon DataPipes + DLv2 in the following ways:

  • Support for more granular parallelism
  • Reduced resource utilization through granular parallelism
  • Improved throughput/performance through NoGIL multi-threading
    • [Risk] if Python NoGIL does not gain adoption, users may need to fallback on granular process-based parallelism
  • Simpler migration path: by supporting existing IterableDatasets and Map Datasets, users can reuse as much code as they like

We want to maintain the composable aspects of datapipes, the eager-execution, and continue our partnerships with storage and cloud providers (AWS, Azure, GCP) where they provide high-performance clients, share customer pain points, and provide recommended solutions and examples to their users.

Alternatives

No response

Additional context

No response

@andrewkho commented:

Some comments / discussions from earlier:

  • we won't be supporting both datapipes and polylithic; datapipes users will need to migrate or pin.
  • we will look into setting things like torch.num_threads and OMP_NUM_THREADS as well as work with libraries like torchcodec to minimize context-switching and thread contention from "greedy" multi-threaded libraries
  • We'll look at providing tooling for investigating bottlenecks / throughput


knoopx commented Oct 12, 2024

I was perfectly happy with datapipes; they provided me with simple building blocks that allowed me to optimize heavy-weight processes. I don't understand the need to kill them with no replacement but a promise of a better solution which addresses a completely different problem.

@andrewkho commented:

Thanks @knoopx for the comment. Is there something particular that you are doing with datapipes that wouldn't be possible with this proposal?


knoopx commented Oct 12, 2024

@andrewkho I mostly use iterable-style datapipes; I like the simplicity and being able to easily chain them together and defer execution. I use them for all sorts of things, not just for ML stuff. Iterable datapipes feel like python-esque observables/streams/deferables/futures/promises to me. The problems the proposal tries to solve are novel, and I'm pretty sure I could accomplish the same things, but imho this new api looks like a step backwards in developer experience and I'm not sure it will solve all the existing pitfalls (like "debuggability"); after all, parallelism is intrinsically a hard problem (plus python gotchas) and adding more lower-level abstractions won't make it easier for regular users. Just hoping an alternative higher-level api comes later, after you figure out all the necessary building blocks.

@andrewkho commented:

@knoopx thanks for the feedback! Yes, we definitely want dev-ex to be the primary thing we optimize alongside efficiency. We do eventually want to have higher-level building blocks built on the lower-level foundational building blocks. We'd also welcome any contributions if you see gaps! One question in terms of dev-ex: we are currently trying to optimize for eager-first execution so things like debugging and experimenting are easier. Can you point to examples of how you're currently using datapipes with futures/promises?


knoopx commented Oct 16, 2024

you got me wrong, I use iter-datapipes as a more convenient replacement for promises; it makes it easy to turn existing sync code into async. This example also shows that auto-tuning is not necessarily useful for every scenario, as the primary bottleneck here is the network.

# async def fetch(url):
await asyncio.gather(*[fetch(url) for url in urls])

# def fetch(url):
list(IterableWrapper(urls).threadpool_map(fetch))

I would also strongly suggest you get some inspiration from some of the concepts/api design of ReactiveX (hybrid observable+iterable pattern) (https://reactivex.io/)

@andrewkho commented:

@knoopx would something like this work for you?

# urls = ["https://...", ...]
# def fetch(url): ...
src = IterableNode(urls)
list(ParallelMap(src, udf=fetch, num_workers=8))  # num_workers could also be auto-tuned

We can iterate on ideas for syntax sugar to make it look more like the IterableWrapper().threadpool_map example you've given, eg

IterableNode(urls).parallel_map(fetch, num_workers=8)

@sehoffmann commented:

Hey @andrewkho, this is very cool! The design makes a lot of sense for me and the focus on modularity and fine-grained control over parallelism and sharding is very appreciated!

I haven't read the RFC in detail yet so I can't give thorough feedback yet. However, one point I am missing is a set_epoch() method that propagates through the whole graph for shuffling. For instance, my own ad-hoc replacement for torch.data has this:

class DownstreamDataset(IterableDataset):
    def __init__(self, source_ds: Iterable):
        self.source_ds = source_ds

    def set_epoch(self, epoch: int):
        if hasattr(self.source_ds, 'set_epoch'):
            self.source_ds.set_epoch(epoch)

    def __len__(self):
        return len(self.source_ds)

which is then used here:


class ShardedSequenceDataset(IterableDataset):
    def __init__(
        self,
        sequence: Sequence,
        shuffle: bool = False,
        even_shards: bool = True,
        seed: int = 0,
        rank: int | None = None,
        world_size: int | None = None,
    ):
        self.sequence = sequence
        self.shuffle = shuffle
        self.even_shards = even_shards
        self.seed = seed
        self.rank = rank if rank is not None else dist.get_rank()
        self.world_size = world_size if world_size is not None else dist.get_world_size()
        self.epoch = 0

    def set_epoch(self, epoch: int):
        self.epoch = epoch

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is None:
            rank = self.rank
            world_size = self.world_size
        else:
            rank = self.rank * worker_info.num_workers + worker_info.id
            world_size = self.world_size * worker_info.num_workers
        shards = shard_sequence(
            self.sequence,
            rank,
            world_size,
            shuffle=self.shuffle,
            even_shards=self.even_shards,
            seed=self.seed + self.epoch,
        )
        return iter(shards)

Otherwise, I am seeing very similar ideas to my own design, so I really like this new direction. E.g. something I had to implement ad-hoc as well:

class PrefetchDataset(DownstreamDataset):
    def __init__(self, source_ds: Iterable, num_elements: int):
        super().__init__(source_ds)
        self.num_elements = num_elements

    def __iter__(self):
        pool = ThreadPoolExecutor(max_workers=1)
        iter_ = iter(self.source_ds)

        with pool:
            futures = [pool.submit(next, iter_) for _ in range(self.num_elements)]
            while True:
                future = futures.pop(0)
                try:
                    element = future.result()
                except StopIteration:
                    return
                futures += [pool.submit(next, iter_)]
                yield element

It sounds like you plan to implement very similar building blocks (nodes) if I read the RFC correctly? Very cool.

@andrewkho commented:

Thanks for the comment @sehoffmann ! Yes we've actually landed some of this already, for prefetcher see: https://github.com/pytorch/data/blob/main/torchdata/nodes/prefetch.py#L16

For epochs, there's an open PR to add this functionality, PTAL and let me know your thoughts: #1357

One thing I'm contemplating is some global callback mechanism that would traverse the dag, check hasattr on each node/operator, and call methods like set_epoch for things like Shuffle, but this may lead to unexpected behaviour if one does not know everything that's happening in the dag. Eg something seemingly simple like shuffling may be hard to reason about globally if you include things like streaming/map-style, multi-dataset mixing, and nested multi-dataset mixing.
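
A minimal sketch of what such a traversal callback could look like, assuming each node exposes its upstream nodes as instance attributes; the function name and traversal details are illustrative assumptions, not a committed API:

import torch.utils.data

def apply_set_epoch(node, epoch, _seen=None):
  # Walk the DAG by inspecting instance attributes that are themselves nodes,
  # and call set_epoch wherever it is defined (illustrative only).
  _seen = _seen if _seen is not None else set()
  if id(node) in _seen:
    return
  _seen.add(id(node))
  if hasattr(node, "set_epoch"):
    node.set_epoch(epoch)
  for value in vars(node).values():
    children = value if isinstance(value, (list, tuple)) else [value]
    for child in children:
      if isinstance(child, torch.utils.data.IterableDataset):
        apply_set_epoch(child, epoch, _seen)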
