
Commit 6b2f509

TroyGarden authored and facebook-github-bot committed
Explore new pipeline that overlaps optimizer with emb_lookup (#2916)
Summary:

# context
* This workstream started from a training QPS optimization initiated on the PG side (see doc in the reference section), observing that the embedding lookup can overlap with the optimizer.
* Embedding table weights are updated in the fused backward (fused-TBE), so the embedding lookup can start immediately after the backward completes, with no dependency on the optimizer.
* We use a separate stream to run the embedding lookup so that it can overlap with the previous optimizer (changed, see below).
* There is also an option to use the data_dist stream for this embedding lookup: the output_dist won't be blocked, but start_sparse_data_dist would be, which results in a smaller memory footprint.

WARNING: This pipeline **DOES NOT** work for EBC/EC with feature processors, because the embedding lookup is started immediately after the TBE backward (at which point the embedding tables' weights have already been updated).

# benchmark readings
* runtime: SemiSync < FusedSparseDist (lookup after opt) < FusedSparseDist (lookup before opt) < SparseDist
```
TrainPipelineSemiSync         | Runtime (P90): 5447.42 ms | Peak Memory alloc (P90): 61.63 GB | Peak Memory reserved (P90): 64.31 GB
TrainPipelineFusedSparseDist  | Runtime (P90): 5605.63 ms | Peak Memory alloc (P90): 53.23 GB | Peak Memory reserved (P90): 68.61 GB
TrainPipelineFusedSparseDist* | Runtime (P90): 5661.92 ms | Peak Memory alloc (P90): 53.23 GB | Peak Memory reserved (P90): 68.67 GB
TrainPipelineSparseDist       | Runtime (P90): 6034.46 ms | Peak Memory alloc (P90): 51.80 GB | Peak Memory reserved (P90): 62.25 GB

* embedding_lookup_after_opt = False
```
* traces show that:
  (1) the emb_lookup runs right behind the TBE backward (on the same CUDA stream)
  (2) the output_dist is invoked right after each emb_lookup (there are two: one for the unweighted EBC, one for the weighted)
  (3) the optimizer does **NOT** overlap with the emb_lookup kernel when `embedding_lookup_after_opt = False` {F1977309185}
  (4) the optimizer still does **NOT** overlap with the emb_lookup kernel, but it fills the gap between the `KJTTensorAwaitable.wait()` and the embedding lookup kernel when `embedding_lookup_after_opt = True` {F1977309202}
  (5) if a separate stream is used for the embedding lookup, the following `start_sparse_data_dist` can start immediately; however, this causes extra memory consumption {F1977366363}
  (6) if the data_dist stream is re-used for the embedding lookup, the follow-up `start_sparse_data_dist` waits for the embedding lookup to complete, and the measured memory footprint is smaller {F1977366349}

NOTE: Based on (5) and (6), `use_emb_lookup_stream = False` is the default behavior.

# conclusions
* On a simple model (SparseNN), both the "Fused Sparse Dist" pipeline and the "Semi Sync" pipeline are faster than the current default (commonly used) "Sparse Dist" pipeline: -7% (fused sparse dist) and -10% (semi sync) in runtime, respectively.
* In a more realistic scenario, the optimizer step has a longer runtime footprint, which can amplify this optimization.
* The "Semi Sync" pipeline has the larger QPS win but produces slightly different numerical training results, while the "Fused Sparse Dist" pipeline, with a smaller QPS win, should be numerically identical to the default pipeline.
* Which one to use is the user's choice.

# reference
* https://dev-discuss.pytorch.org/t/fsdp-cudacachingallocator-an-outsider-newb-perspective/1486

Differential Revision: D64479105
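For orientation, a minimal usage sketch of the new pipeline (illustrative only, not part of this commit). The constructor arguments and the `progress()` loop follow the signatures added in train_pipelines.py below; `sharded_model`, `fused_optimizer`, and `train_dataloader` are assumed placeholders for a sharded TorchRec model with a fused-TBE optimizer and a dataloader of pipelineable batches.

```python
# Hypothetical usage sketch for TrainPipelineFusedSparseDist; the model,
# optimizer, and dataloader setup is assumed and not shown in this commit.
import torch

from torchrec.distributed.train_pipeline import TrainPipelineFusedSparseDist

device = torch.device("cuda")
pipeline = TrainPipelineFusedSparseDist(
    model=sharded_model,        # placeholder: sharded model using fused-TBE
    optimizer=fused_optimizer,  # placeholder: optimizer fused into the TBE backward
    device=device,
    # Default False: re-use the data_dist stream for the embedding lookup,
    # trading a later start_sparse_data_dist for a smaller memory footprint.
    use_emb_lookup_stream=False,
)

dataloader_iter = iter(train_dataloader)  # placeholder dataloader
while True:
    try:
        # progress() starts the embedding lookup for batches[0] first (so it can
        # overlap with work still queued from the previous iteration), then runs
        # forward/backward/optimizer for batch i, input_dist for batch i+1, and
        # the host-to-device copy for batch i+2.
        output = pipeline.progress(dataloader_iter)
    except StopIteration:
        break
```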
1 parent 33aeafa commit 6b2f509

File tree: 5 files changed, +191 −5 lines

torchrec/distributed/benchmark/benchmark_train_sparsenn.py

Lines changed: 2 additions & 0 deletions

@@ -42,6 +42,7 @@
 from torchrec.distributed.train_pipeline import (
     TrainPipeline,
     TrainPipelineBase,
+    TrainPipelineFusedSparseDist,
     TrainPipelineSparseDist,
 )
 from torchrec.distributed.train_pipeline.train_pipelines import (
@@ -106,6 +107,7 @@ def generate_pipeline(
     ] = {
         "base": TrainPipelineBase,
         "sparse": TrainPipelineSparseDist,
+        "fused": TrainPipelineFusedSparseDist,
         "semi": TrainPipelineSemiSync,
         "prefetch": PrefetchTrainPipelineSparseDist,
     }

torchrec/distributed/test_utils/test_input.py

Lines changed: 3 additions & 3 deletions

@@ -442,9 +442,9 @@ def _assemble_kjt(
         lengths = None
         if pin_memory:
             indices = indices.pin_memory()
-            lengths = lengths.pin_memory() if lengths else None
-            weights = weights.pin_memory() if weights else None
-            offsets = offsets.pin_memory() if offsets else None
+            lengths = lengths.pin_memory() if lengths is not None else None
+            weights = weights.pin_memory() if weights is not None else None
+            offsets = offsets.pin_memory() if offsets is not None else None
         return KeyedJaggedTensor(features, indices, weights, lengths, offsets)

     @staticmethod
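The three-line change above is subtle: `if lengths:` invokes `Tensor.__bool__`, which raises for a tensor with more than one element and is falsy for a single-element zero tensor, so pinning could crash or be silently skipped even though a tensor was provided. `is not None` tests only presence. A standalone illustration in plain PyTorch (not TorchRec code):

```python
# Why `if tensor:` is the wrong presence check for optional tensors.
import torch

lengths = torch.tensor([2, 0, 1])

# Old check: Tensor.__bool__ raises for a tensor with more than one element.
try:
    if lengths:
        lengths = lengths.pin_memory()
except RuntimeError as err:
    print(f"truthiness check raised: {err}")

# Fixed check: only ask whether a tensor was provided at all.
if lengths is not None:
    if torch.cuda.is_available():  # pin_memory() requires an accelerator build
        lengths = lengths.pin_memory()
```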

torchrec/distributed/train_pipeline/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@
     TorchCompileConfig,  # noqa
     TrainPipeline,  # noqa
     TrainPipelineBase,  # noqa
+    TrainPipelineFusedSparseDist,  # noqa
     TrainPipelinePT2,  # noqa
     TrainPipelineSparseDist,  # noqa
     TrainPipelineSparseDistCompAutograd,  # noqa

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 170 additions & 2 deletions

@@ -47,6 +47,7 @@
     EmbeddingPipelinedForward,
     EmbeddingTrainPipelineContext,
     In,
+    InSyncEmbeddingPipelinedForward,
     Out,
     PipelinedForward,
     PipelinedPostproc,
@@ -540,8 +541,7 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:

         # modify the (sharded) sparse module forward, and invoke the first part of input_dist
         self._init_pipelined_modules(
-            # pyre-ignore [6]
-            self.batches[0],
+            self.batches[0],  # pyre-ignore [6]
             self.contexts[0],
             self._pipelined_forward_type,
         )
@@ -803,6 +803,174 @@ def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
         self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter)


+class TrainPipelineFusedSparseDist(TrainPipelineSparseDist[In, Out]):
+    """
+    This pipeline modifies TrainPipelineSparseDist by running the embedding lookup in a
+    separate stream so that it can overlap with the previous optimizer. The assumption
+    made here is that the embedding is updated in the fused backward (fused-TBE), so the
+    embedding lookup can start immediately after the backward completes, without any
+    dependency on the optimizer.
+    NOTE: This assumption is not true if there are feature processors.
+
+    batch i+0:
+        ShardedModule.compute_and_output_dist - uses emb_lookup CUDA stream
+        forward (without emb lookup)
+        backward and optimizer
+    batch i+1:
+        ShardedModule.input_dist() - uses data_dist CUDA stream
+    batch i+2:
+        copy batch to device
+
+    `ShardedModule.input_dist()` is only done for top-level modules in the call graph.
+    To be considered a top-level module, a module can only depend on 'getattr' calls on
+    input.
+
+    Input model must be symbolically traceable with the exception of `ShardedModule` and
+    `DistributedDataParallel` modules.
+
+    Args:
+        model (torch.nn.Module): model to pipeline.
+        optimizer (torch.optim.Optimizer): optimizer to use.
+        device (torch.device): device where device transfer, sparse data dist, and
+            forward/backward pass will happen.
+        execute_all_batches (bool): executes remaining batches in pipeline after
+            exhausting dataloader iterator.
+        apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        TODO: pipeline_postproc, custom_model_fwd, strict
+        use_emb_lookup_stream (bool): if True, invoke compute_and_output_dist
+            (for batch i+1) on a new stream; otherwise re-use the data_dist stream.
+    """
+
+    # The PipelinedForward class that is used in _rewrite_model
+    _pipelined_forward_type = InSyncEmbeddingPipelinedForward  # pyre-ignore
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        optimizer: torch.optim.Optimizer,
+        device: torch.device,
+        execute_all_batches: bool = True,
+        apply_jit: bool = False,
+        pipeline_postproc: bool = True,
+        custom_model_fwd: Optional[
+            Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
+        ] = None,
+        strict: bool = False,
+        use_emb_lookup_stream: bool = False,  # default False, explained below
+    ) -> None:
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            device=device,
+            execute_all_batches=execute_all_batches,
+            apply_jit=apply_jit,
+            context_type=EmbeddingTrainPipelineContext,
+            pipeline_postproc=pipeline_postproc,
+            custom_model_fwd=custom_model_fwd,
+        )
+        if use_emb_lookup_stream:
+            self._emb_lookup_stream: Optional[torch.Stream] = (
+                (torch.get_device_module(device).Stream())
+                if device.type in ["cuda", "mtia"]
+                else None
+            )
+        else:
+            # default to False: re-use the data_dist stream for the emb lookup to reduce
+            # the CUDA memory footprint, due to the caching allocator reserving memory per stream
+            self._emb_lookup_stream = self._data_dist_stream
+
+    def wait_embedding_lookup(self) -> None:
+        """
+        Waits on the embedding lookup requests to get the embedding lookup tensors
+        """
+        current_stream = torch.get_device_module(self._device).current_stream()
+        current_stream.wait_stream(self._emb_lookup_stream)
+
+    def start_embedding_lookup(
+        self,
+        batch: Optional[In],
+        context: EmbeddingTrainPipelineContext,
+    ) -> None:
+        """
+        Waits for the batch to finish getting copied to the GPU, then starts the
+        embedding lookup on the emb_lookup stream.
+        """
+        if batch is None:
+            return
+
+        with record_function(f"## start_embedding_lookup {context.index} ##"):
+            current_stream = torch.get_device_module(self._device).current_stream()
+            with self._stream_context(self._emb_lookup_stream):
+                for module in self._pipelined_modules:
+                    _start_embedding_lookup(
+                        module,
+                        context,
+                        source_stream=self._emb_lookup_stream,
+                        target_stream=current_stream,
+                        stream_context=self._stream_context,
+                    )
+
+    def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        """
+        For TrainPipelineSparseDist, we assume the max pipelined batches == 3 (capacity):
+            batches[0]: i+0 batch, fwd/bwd/opt (expecting output_dist)
+            batches[1]: i+1 batch, for input_dist (expecting copied to device), and compute_and_output_dist
+            batches[2]: i+2 batch, for copy_batch_to_gpu (expecting non-exhausted dataloader iter)
+        """
+
+        # attach the model just in case the user forgets to call it, especially when the user
+        # pauses pipeline.progress and detaches the model for other purposes.
+        if not self._model_attached:
+            self.attach(self._model)
+
+        # filling the pipeline is only needed at the beginning, when the pipeline (batches) is empty
+        self.fill_pipeline(dataloader_iter)
+
+        # here is the expected stop after exhausting all batches
+        if not self.batches:
+            raise StopIteration
+
+        # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
+        self._set_module_context(self.contexts[0])
+
+        # start embedding_lookup so it can overlap with the previous optimizer
+        # pyre-ignore [6]
+        self.start_embedding_lookup(self.batches[0], self.contexts[0])
+
+        if self._model.training:
+            with record_function("## zero_grad ##"):
+                self._optimizer.zero_grad()
+
+        # wait for batches[0] to be available on device; this should always be completed since
+        # the input_dist of batches[0] has been invoked in the previous iteration. TODO: fact check
+        self._wait_for_batch()
+
+        if len(self.batches) >= 2:
+            # invoke splits all_to_all comms (first part of input_dist)
+            self.start_sparse_data_dist(self.batches[1], self.contexts[1])
+
+        # batch i+2: load data and copy to gpu; the dataloader iter will first exhaust here
+        self.enqueue_batch(dataloader_iter)
+
+        # forward
+        with record_function("## forward ##"):
+            losses, output = self._model_fwd(self.batches[0])
+
+        if len(self.batches) >= 2:
+            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
+            self.wait_sparse_data_dist(self.contexts[1])
+
+        if self._model.training:
+            # backward
+            self._backward(losses)
+
+            # update
+            with record_function("## optimizer ##"):
+                self._optimizer.step()
+
+        self.dequeue_batch()
+        return output
+
+
 class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
     """
     Novel method for RecSys model training by leveraging "Semi-Synchronous" training,
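The `use_emb_lookup_stream=True` path in the new class reduces to a standard side-stream overlap pattern: launch the lookup on its own stream, let the current stream keep running (e.g. the previous optimizer step), and synchronize before the result is consumed. A generic sketch of that pattern in plain PyTorch; the kernels are stand-ins, not the TorchRec helpers used in this diff:

```python
# Generic side-stream overlap sketch (illustrative; not TorchRec code).
import torch


def overlapped_step(dense_work: torch.Tensor, sparse_input: torch.Tensor) -> torch.Tensor:
    current = torch.cuda.current_stream()
    side = torch.cuda.Stream()

    # The side stream must not start before its inputs are ready on the current stream.
    side.wait_stream(current)
    with torch.cuda.stream(side):
        lookup_out = sparse_input.relu()  # stand-in for the embedding lookup

    # Meanwhile the current stream keeps running, e.g. the optimizer step.
    dense_work.mul_(0.999)  # stand-in for optimizer math

    # Before the forward pass consumes lookup_out, wait for the side stream and
    # tell the caching allocator the tensor is used on the current stream so its
    # memory is not recycled too early.
    current.wait_stream(side)
    lookup_out.record_stream(current)
    return lookup_out


if torch.cuda.is_available():
    dense = torch.randn(1024, device="cuda")
    sparse = torch.randn(1024, device="cuda")
    out = overlapped_step(dense, sparse)
```

Re-using the data_dist stream (the default here) avoids the extra per-stream reservations the caching allocator would make for a dedicated lookup stream, which matches the memory readings in the summary.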

torchrec/distributed/train_pipeline/utils.py

Lines changed: 15 additions & 0 deletions

@@ -670,6 +670,20 @@ def detach_embeddings(
         self._context.detached_embedding_tensors.append(detached_tensors)


+class InSyncEmbeddingPipelinedForward(EmbeddingPipelinedForward):
+    """
+    This pipeline is used in TrainPipelineFusedSparseDist
+    """
+
+    def detach_embeddings(
+        self,
+        embeddings: Union[Dict[str, JaggedTensor], KeyedTensor],
+        cur_stream: torch.Stream,
+    ) -> None:
+        # doing nothing
+        pass
+
+
 class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]):
     """
     This pipeline is used in PrefetchTrainPipelineSparseDist
@@ -850,6 +864,7 @@ def _start_data_dist(
             PipelinedForward,
             PrefetchPipelinedForward,
             EmbeddingPipelinedForward,
+            InSyncEmbeddingPipelinedForward,
         ),
     )