From 84d497973bc82fe5ca17c4ca0f20471b239b1783 Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Wed, 30 Apr 2025 23:01:58 -0700
Subject: [PATCH] Explore new pipeline that overlaps optimizer with emb_lookup (#2916)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2916

# context
* This workstream started from a training QPS optimization initiated on the PG side (see the doc in the reference section), which observed that the embedding lookup can overlap with the optimizer.
* Embedding table weights are updated in the fused backward (fused-TBE), so the embedding lookup can start immediately after the backward completes, with no dependency on the optimizer.
* We use a separate stream to run the embedding lookup so that it can overlap with the previous optimizer step (changed, see below).
* There is also an option to use the data_dist stream for this embedding lookup: the output_dist won't be blocked but the start_sparse_data_dist will be, which results in a smaller memory footprint.

WARNING: This pipeline **DOES NOT** work for EBC/EC with feature processors, because the embedding lookup is started immediately after the TBE backward (where the embedding tables' weights have already been updated).

# benchmark readings
* runtime: SemiSync < FusedSparseDist (lookup after opt) < FusedSparseDist (lookup before opt) < SparseDist
```
TrainPipelineSemiSync         | Runtime (P90): 5447.42 ms | Peak Memory alloc (P90): 61.63 GB | Peak Memory reserved (P90): 64.31 GB
TrainPipelineFusedSparseDist  | Runtime (P90): 5605.63 ms | Peak Memory alloc (P90): 53.23 GB | Peak Memory reserved (P90): 68.61 GB
TrainPipelineFusedSparseDist* | Runtime (P90): 5661.92 ms | Peak Memory alloc (P90): 53.23 GB | Peak Memory reserved (P90): 68.67 GB
TrainPipelineSparseDist       | Runtime (P90): 6034.46 ms | Peak Memory alloc (P90): 51.80 GB | Peak Memory reserved (P90): 62.25 GB
* embedding_lookup_after_opt = False
```
* traces show that:
  (1) the emb_lookup runs right behind the TBE-bwd (on the same CUDA stream)
  (2) the output_dist is invoked right after each emb_lookup (there are two: one for the unweighted EBC, one for the weighted)
  (3) the optimizer does **NOT** seem to overlap with the emb_lookup kernel when `embedding_lookup_after_opt = False` {F1977309185}
  (4) the optimizer still does **NOT** overlap with the emb_lookup kernel, but it fills the gap between the `KJTTensorAwaitable.wait()` and the embedding lookup kernel when `embedding_lookup_after_opt = True` {F1977309202}
  (5) if a separate stream is used for the embedding lookup, the following `start_sparse_data_dist` can start immediately; however, this causes extra memory consumption {F1977366363}
  (6) if the data_dist stream is re-used for the embedding lookup, the following `start_sparse_data_dist` waits for the embedding lookup to complete; the measured memory footprint is smaller {F1977366349}

NOTE: Based on (5) and (6), `use_emb_lookup_stream = False` is the default behavior.

# conclusions
* Based on a simple model (SparseNN), both the "Fused Sparse Dist" pipeline and the "Semi Sync" pipeline are faster than the current default (commonly used) "Sparse Dist" pipeline: -7% (fused sparse dist) and -10% (semi sync) in runtime, respectively.
* In a more realistic scenario, the optimizer step has a longer runtime footprint, which can amplify this optimization.
* The "Semi Sync" pipeline has the larger QPS win but produces slightly different numerical training results, while the "Fused Sparse Dist" pipeline, with a smaller QPS win, should be numerically identical to the default pipeline.
* Which one to use is the user's choice.
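
# usage (illustrative)
A minimal usage sketch of the new pipeline. `sharded_model`, `optimizer`, and `train_dataloader` are placeholders for the usual TorchRec setup (sharded model with fused-TBE) and are not part of this diff; only `TrainPipelineFusedSparseDist` and its constructor arguments come from this change.
```
import torch

from torchrec.distributed.train_pipeline import TrainPipelineFusedSparseDist

# sharded_model / optimizer / train_dataloader are placeholders for the usual
# TorchRec setup; with fused-TBE, the embedding weights are updated inside the
# fused backward, which is what allows the early embedding lookup.
pipeline = TrainPipelineFusedSparseDist(
    model=sharded_model,
    optimizer=optimizer,
    device=torch.device("cuda"),
    # default: re-use the data_dist stream for the embedding lookup, which keeps
    # the CUDA Caching Allocator footprint smaller (see readings (5) and (6))
    use_emb_lookup_stream=False,
)

dataloader_iter = iter(train_dataloader)
while True:
    try:
        output = pipeline.progress(dataloader_iter)
    except StopIteration:
        break
```
Setting `use_emb_lookup_stream=True` runs the lookup on its own stream so the following `start_sparse_data_dist` is not blocked, at the cost of extra reserved memory, per the trace readings above.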
# reference
* https://dev-discuss.pytorch.org/t/fsdp-cudacachingallocator-an-outsider-newb-perspective/1486

Reviewed By: dstaay-fb

Differential Revision: D64479105
---
 .../benchmark/benchmark_train_sparsenn.py    |   2 +
 .../distributed/train_pipeline/__init__.py   |   1 +
 .../train_pipeline/train_pipelines.py        | 174 +++++++++++++++++-
 torchrec/distributed/train_pipeline/utils.py |  15 ++
 4 files changed, 190 insertions(+), 2 deletions(-)

diff --git a/torchrec/distributed/benchmark/benchmark_train_sparsenn.py b/torchrec/distributed/benchmark/benchmark_train_sparsenn.py
index c7d8e76e4..65fda4a85 100644
--- a/torchrec/distributed/benchmark/benchmark_train_sparsenn.py
+++ b/torchrec/distributed/benchmark/benchmark_train_sparsenn.py
@@ -42,6 +42,7 @@ from torchrec.distributed.train_pipeline import (
     TrainPipeline,
     TrainPipelineBase,
+    TrainPipelineFusedSparseDist,
     TrainPipelineSparseDist,
 )
 from torchrec.distributed.train_pipeline.train_pipelines import (
@@ -106,6 +107,7 @@ def generate_pipeline(
     ] = {
         "base": TrainPipelineBase,
         "sparse": TrainPipelineSparseDist,
+        "fused": TrainPipelineFusedSparseDist,
         "semi": TrainPipelineSemiSync,
         "prefetch": PrefetchTrainPipelineSparseDist,
     }
diff --git a/torchrec/distributed/train_pipeline/__init__.py b/torchrec/distributed/train_pipeline/__init__.py
index d7b38d2b0..5d7349160 100644
--- a/torchrec/distributed/train_pipeline/__init__.py
+++ b/torchrec/distributed/train_pipeline/__init__.py
@@ -15,6 +15,7 @@
     TorchCompileConfig,  # noqa
     TrainPipeline,  # noqa
     TrainPipelineBase,  # noqa
+    TrainPipelineFusedSparseDist,  # noqa
     TrainPipelinePT2,  # noqa
     TrainPipelineSparseDist,  # noqa
     TrainPipelineSparseDistCompAutograd,  # noqa
diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index 4685fae9c..64828aa70 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -47,6 +47,7 @@
     EmbeddingPipelinedForward,
     EmbeddingTrainPipelineContext,
     In,
+    InSyncEmbeddingPipelinedForward,
     Out,
     PipelinedForward,
     PipelinedPostproc,
@@ -540,8 +541,7 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
 
         # modify the (sharded) sparse module forward, and invoke the first part of input_dist
         self._init_pipelined_modules(
-            # pyre-ignore [6]
-            self.batches[0],
+            self.batches[0],  # pyre-ignore [6]
             self.contexts[0],
             self._pipelined_forward_type,
         )
@@ -803,6 +803,176 @@ def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
             self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter)
 
 
+class TrainPipelineFusedSparseDist(TrainPipelineSparseDist[In, Out]):
+    """
+    This pipeline modifies TrainPipelineSparseDist by running the embedding lookup in a
+    separate stream so that it can overlap with the previous optimizer. The assumption
+    made here is that the embedding is updated in the fused backward (fused-TBE), so the
+    embedding lookup can start immediately after the backward completes, without any
+    dependency on the optimizer.
+
+    NOTE: This assumption does not hold if there are feature processors.
+    NOTE: This pipeline is still experimental; users should always run NE parity tests.
+
+    batch i+0:
+        ShardedModule.compute_and_output_dist - uses emb_lookup CUDA stream
+        forward (without emb lookup)
+        backward and optimizer
+    batch i+1:
+        ShardedModule.input_dist() - uses data_dist CUDA stream
+    batch i+2:
+        copy batch to device
+
+    `ShardedModule.input_dist()` is only done for top-level modules in the call graph.
+    To be considered a top-level module, a module can only depend on 'getattr' calls on
+    input.
+
+    Input model must be symbolically traceable with the exception of `ShardedModule` and
+    `DistributedDataParallel` modules.
+
+    Args:
+        model (torch.nn.Module): model to pipeline.
+        optimizer (torch.optim.Optimizer): optimizer to use.
+        device (torch.device): device where device transfer, sparse data dist, and
+            forward/backward pass will happen.
+        execute_all_batches (bool): executes remaining batches in pipeline after
+            exhausting dataloader iterator.
+        apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        TODO: pipeline_postproc, custom_model_fwd, strict
+        use_emb_lookup_stream (bool): if True, invoke compute_and_output_dist
+            (for batch i+0) on a separate stream; otherwise re-use the data_dist stream.
+    """
+
+    # The PipelinedForward class that is used in _rewrite_model
+    _pipelined_forward_type = InSyncEmbeddingPipelinedForward  # pyre-ignore
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        optimizer: torch.optim.Optimizer,
+        device: torch.device,
+        execute_all_batches: bool = True,
+        apply_jit: bool = False,
+        pipeline_postproc: bool = True,
+        custom_model_fwd: Optional[
+            Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
+        ] = None,
+        strict: bool = False,
+        use_emb_lookup_stream: bool = False,  # default False explained below
+    ) -> None:
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            device=device,
+            execute_all_batches=execute_all_batches,
+            apply_jit=apply_jit,
+            context_type=EmbeddingTrainPipelineContext,
+            pipeline_postproc=pipeline_postproc,
+            custom_model_fwd=custom_model_fwd,
+        )
+        if use_emb_lookup_stream:
+            self._emb_lookup_stream: Optional[torch.Stream] = (
+                (torch.get_device_module(device).Stream())
+                if device.type in ["cuda", "mtia"]
+                else None
+            )
+        else:
+            # default to False: re-use the data_dist stream for the emb lookup to reduce the
+            # CUDA memory footprint, since the Caching Allocator reserves memory per stream
+            self._emb_lookup_stream = self._data_dist_stream
+
+    def wait_embedding_lookup(self) -> None:
+        """
+        Waits on the embedding lookup stream so that the embedding lookup tensors are ready to use.
+        """
+        current_stream = torch.get_device_module(self._device).current_stream()
+        current_stream.wait_stream(self._emb_lookup_stream)
+
+    def start_embedding_lookup(
+        self,
+        batch: Optional[In],
+        context: EmbeddingTrainPipelineContext,
+    ) -> None:
+        """
+        Starts the embedding lookup (compute_and_output_dist) for the given batch on the emb_lookup stream.
+        """
+        if batch is None:
+            return
+
+        with record_function(f"## start_embedding_lookup {context.index} ##"):
+            current_stream = torch.get_device_module(self._device).current_stream()
+            with self._stream_context(self._emb_lookup_stream):
+                for module in self._pipelined_modules:
+                    _start_embedding_lookup(
+                        module,
+                        context,
+                        source_stream=self._emb_lookup_stream,
+                        target_stream=current_stream,
+                        stream_context=self._stream_context,
+                    )
+
+    def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        """
+        As in TrainPipelineSparseDist, we assume at most 3 pipelined batches (capacity):
+            batches[0]: i+0 batch, fwd/bwd/opt (expecting output_dist)
+            batches[1]: i+1 batch, for input_dist (expecting copied to device), and compute_and_output_dist
+            batches[2]: i+2 batch, for copy_batch_to_gpu (expecting non-exhausted dataloader iter)
+        """
+
+        # attach the model just in case the user forgets to call it, especially when the user
+        # pauses pipeline.progress and detaches the model for other purposes.
+        if not self._model_attached:
+            self.attach(self._model)
+
+        # filling the pipeline is only needed at the beginning, when the pipeline (batches) is empty
+        self.fill_pipeline(dataloader_iter)
+
+        # this is the expected stop after exhausting all batches
+        if not self.batches:
+            raise StopIteration
+
+        # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
+        self._set_module_context(self.contexts[0])
+
+        # start embedding_lookup so it can overlap with the previous optimizer
+        # pyre-ignore [6]
+        self.start_embedding_lookup(self.batches[0], self.contexts[0])
+
+        if self._model.training:
+            with record_function("## zero_grad ##"):
+                self._optimizer.zero_grad()
+
+        # wait for batches[0] to be available on the device; this should always be completed since
+        # the input_dist of batches[0] was invoked in the previous iteration. TODO: fact check
+        self._wait_for_batch()
+
+        if len(self.batches) >= 2:
+            # invoke splits all_to_all comms (first part of input_dist)
+            self.start_sparse_data_dist(self.batches[1], self.contexts[1])
+
+        # batch i+2: load data and copy to gpu; the dataloader iter will be exhausted here first
+        self.enqueue_batch(dataloader_iter)
+
+        # forward
+        with record_function("## forward ##"):
+            losses, output = self._model_fwd(self.batches[0])
+
+        if len(self.batches) >= 2:
+            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
+            self.wait_sparse_data_dist(self.contexts[1])
+
+        if self._model.training:
+            # backward
+            self._backward(losses)
+
+            # update
+            with record_function("## optimizer ##"):
+                self._optimizer.step()
+
+        self.dequeue_batch()
+        return output
+
+
 class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
     """
     Novel method for RecSys model training by leveraging "Semi-Synchronous" training,
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index f75a25af4..fb20fa38a 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -670,6 +670,20 @@ def detach_embeddings(
         self._context.detached_embedding_tensors.append(detached_tensors)
 
 
+class InSyncEmbeddingPipelinedForward(EmbeddingPipelinedForward):
+    """
+    This pipeline is used in TrainPipelineFusedSparseDist
+    """
+
+    def detach_embeddings(
+        self,
+        embeddings: Union[Dict[str, JaggedTensor], KeyedTensor],
+        cur_stream: torch.Stream,
+    ) -> None:
+        # intentionally a no-op: embeddings are not detached for the in-sync pipeline
+        pass
+
+
 class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]):
     """
     This pipeline is used in PrefetchTrainPipelineSparseDist
@@ -853,6 +867,7 @@ def _start_data_dist(
             PipelinedForward,
             PrefetchPipelinedForward,
             EmbeddingPipelinedForward,
+            InSyncEmbeddingPipelinedForward,
         ),
     )