
Commit 9ab495f

TroyGarden authored and facebook-github-bot committed
Explore new pipeline that overlaps optimizer with emb_lookup
Summary:

# context

* This workstream started from a training QPS optimization initiated on the PG side (see the doc in the reference section), observing that the embedding lookup can overlap with the optimizer.
* Embedding table weights are updated in the fused backward (fused-TBE), so the embedding lookup can start immediately after the backward completes, without any dependency on the optimizer.
* We use a separate stream to run the embedding lookup so that it can overlap with the previous optimizer (changed, see below).
* There is also an option to use the data_dist stream for this embedding lookup: the output_dist won't be blocked but the start_sparse_data_dist will be, which results in a smaller memory footprint.

WARNING: This pipeline **DOES NOT** work for EBC/EC with feature processors, because the embedding lookup is started immediately after the TBE backward (where the embedding tables' weights have already been updated).

# benchmark readings

* runtime: SemiSync < FusedSparseDist (lookup after opt) < FusedSparseDist (lookup before opt) < SparseDist

```
TrainPipelineSemiSync         | Runtime (P90): 5447.42 ms | Peak Memory alloc (P90): 61.63 GB | Peak Memory reserved (P90): 64.31 GB
TrainPipelineFusedSparseDist  | Runtime (P90): 5605.63 ms | Peak Memory alloc (P90): 53.23 GB | Peak Memory reserved (P90): 68.61 GB
TrainPipelineFusedSparseDist* | Runtime (P90): 5661.92 ms | Peak Memory alloc (P90): 53.23 GB | Peak Memory reserved (P90): 68.67 GB
TrainPipelineSparseDist       | Runtime (P90): 6034.46 ms | Peak Memory alloc (P90): 51.80 GB | Peak Memory reserved (P90): 62.25 GB
```
  (*) `embedding_lookup_after_opt = False`

* traces show that:
  (1) the emb_lookup is right behind the TBE-bwd (on the same CUDA stream)
  (2) the output_dist is invoked right after each emb_lookup (there are two: one for the unweighted EBC, one for the weighted)
  (3) the optimizer does **NOT** seem to overlap with the emb_lookup kernel when `embedding_lookup_after_opt = False` {F1977309185}
  (4) the optimizer still does **NOT** overlap with the emb_lookup kernel, but it fills in the gap between the `KJTTensorAwaitable.wait()` and the embedding lookup kernel when `embedding_lookup_after_opt = True` {F1977309202}
  (5) if a separate stream is used for the embedding lookup, the following `start_sparse_data_dist` can start immediately; however, this causes extra memory consumption {F1977366363}
  (6) if the data_dist stream is reused for the embedding lookup, the following `start_sparse_data_dist` waits for the embedding lookup to complete; the measured memory footprint is smaller {F1977366349}

NOTE: based on (5) and (6), `use_emb_lookup_stream = False` is the default behavior.

# conclusions

* On a simple model (SparseNN), both the "Fused Sparse Dist" pipeline and the "Semi Sync" pipeline are faster than the current default (commonly used) "Sparse Dist" pipeline: -7% (fused sparse dist) and -10% (semi sync) in runtime, respectively.
* In a more realistic scenario, the optimizer step has a longer runtime footprint, which can amplify this optimization.
* The "Semi Sync" pipeline has the larger QPS win but produces slightly different numerical training results, while the "Fused Sparse Dist" pipeline, with a smaller QPS win, should be numerically identical to the default pipeline.
* Which one to use is the user's choice.

# reference

* https://dev-discuss.pytorch.org/t/fsdp-cudacachingallocator-an-outsider-newb-perspective/1486

Differential Revision: D64479105
1 parent 3f46701 commit 9ab495f
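For concreteness, here is a hedged usage sketch of the new pipeline (not part of this commit). The constructor keywords mirror the `__init__` added in `train_pipelines.py` below; the sharding/planner setup is omitted, and `dataloader` is a hypothetical placeholder.

```python
# Hedged usage sketch: drive TrainPipelineFusedSparseDist over a dataloader.
# `model` is assumed to be a sharded TorchRec model and `optimizer` its optimizer.
import torch

from torchrec.distributed.train_pipeline import TrainPipelineFusedSparseDist


def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, dataloader) -> None:
    pipeline = TrainPipelineFusedSparseDist(
        model=model,
        optimizer=optimizer,
        device=torch.device("cuda"),
        # Default False: reuse the data_dist stream for the embedding lookup,
        # trading some overlap for a smaller CUDA memory footprint (see traces (5)/(6)).
        use_emb_lookup_stream=False,
    )
    it = iter(dataloader)
    while True:
        try:
            # fwd/bwd/opt for batch i, input_dist for batch i+1, H2D copy for batch i+2
            pipeline.progress(it)
        except StopIteration:
            break
```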

File tree

4 files changed: +231 −0 lines changed


torchrec/distributed/benchmark/benchmark_train_sparsenn.py

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,7 @@
 from torchrec.distributed.train_pipeline import (
     TrainPipeline,
     TrainPipelineBase,
+    TrainPipelineFusedSparseDist,
     TrainPipelineSparseDist,
 )
 from torchrec.distributed.train_pipeline.train_pipelines import (
@@ -106,6 +107,7 @@ def generate_pipeline(
     ] = {
         "base": TrainPipelineBase,
         "sparse": TrainPipelineSparseDist,
+        "fused": TrainPipelineFusedSparseDist,
         "semi": TrainPipelineSemiSync,
         "prefetch": PrefetchTrainPipelineSparseDist,
     }
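The benchmark selects a pipeline class from a short name via the registry above. A hedged, standalone sketch of that dispatch pattern follows; `pipeline_cls_from_name` is a hypothetical helper, and the real `generate_pipeline` in `benchmark_train_sparsenn.py` also takes the model, optimizer, and other arguments.

```python
# Hedged sketch of the name-to-class registry pattern used by the benchmark.
from typing import Dict

from torchrec.distributed.train_pipeline import (
    TrainPipelineBase,
    TrainPipelineFusedSparseDist,
    TrainPipelineSparseDist,
)

_PIPELINES: Dict[str, type] = {
    "base": TrainPipelineBase,
    "sparse": TrainPipelineSparseDist,
    "fused": TrainPipelineFusedSparseDist,  # new in this commit
}


def pipeline_cls_from_name(name: str) -> type:
    """Look up a pipeline class by its short benchmark name."""
    try:
        return _PIPELINES[name]
    except KeyError as exc:
        raise ValueError(
            f"unknown pipeline {name!r}; choose from {sorted(_PIPELINES)}"
        ) from exc
```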

torchrec/distributed/train_pipeline/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
     TorchCompileConfig,  # noqa
     TrainPipeline,  # noqa
     TrainPipelineBase,  # noqa
+    TrainPipelineFusedSparseDist,  # noqa
     TrainPipelinePT2,  # noqa
     TrainPipelineSparseDist,  # noqa
     TrainPipelineSparseDistCompAutograd,  # noqa

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 213 additions & 0 deletions
@@ -47,6 +47,7 @@
     EmbeddingPipelinedForward,
     EmbeddingTrainPipelineContext,
     In,
+    InSyncEmbeddingPipelinedForward,
     Out,
     PipelinedForward,
     PipelinedPostproc,
@@ -803,6 +804,218 @@ def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
         self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter)


+class TrainPipelineFusedSparseDist(TrainPipelineSparseDist[In, Out]):
+    """
+    This pipeline modifies TrainPipelineSparseDist by running the embedding lookup in a
+    separate stream so that it can overlap with the previous optimizer. The assumption
+    made here is that the embedding tables are updated in the fused backward (fused-TBE),
+    so the embedding lookup can start immediately after the backward completes, without
+    any dependency on the optimizer.
+    NOTE: This assumption does not hold if there are feature processor(s).
+
+    batch i+0:
+        ShardedModule.compute_and_output_dist - uses emb_lookup CUDA stream
+        forward (without emb lookup)
+        backward and optimizer
+    batch i+1:
+        ShardedModule.input_dist() - uses data_dist CUDA stream
+    batch i+2:
+        copy batch to device
+
+    `ShardedModule.input_dist()` is only done for top-level modules in the call graph.
+    To be considered a top-level module, a module can only depend on 'getattr' calls on
+    input.
+
+    Input model must be symbolically traceable with the exception of `ShardedModule` and
+    `DistributedDataParallel` modules.
+
+    Args:
+        model (torch.nn.Module): model to pipeline.
+        optimizer (torch.optim.Optimizer): optimizer to use.
+        device (torch.device): device where device transfer, sparse data dist, and
+            forward/backward pass will happen.
+        execute_all_batches (bool): executes remaining batches in pipeline after
+            exhausting dataloader iterator.
+        apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        TODO: pipeline_postproc, custom_model_fwd, strict
+        use_emb_lookup_stream (bool): if True, invoke compute_and_output_dist
+            (for batch i+1) on a new stream; otherwise reuse the data_dist stream.
+    """
+
+    # The PipelinedForward class that is used in _rewrite_model
+    _pipelined_forward_type = InSyncEmbeddingPipelinedForward  # pyre-ignore
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        optimizer: torch.optim.Optimizer,
+        device: torch.device,
+        execute_all_batches: bool = True,
+        apply_jit: bool = False,
+        pipeline_postproc: bool = True,
+        custom_model_fwd: Optional[
+            Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
+        ] = None,
+        strict: bool = False,
+        use_emb_lookup_stream: bool = False,  # default is False; see comment below
+    ) -> None:
+        super().__init__(
+            model=model,
+            optimizer=optimizer,
+            device=device,
+            execute_all_batches=execute_all_batches,
+            apply_jit=apply_jit,
+            context_type=EmbeddingTrainPipelineContext,
+            pipeline_postproc=pipeline_postproc,
+            custom_model_fwd=custom_model_fwd,
+        )
+        if use_emb_lookup_stream:
+            self._emb_lookup_stream: Optional[torch.Stream] = (
+                (torch.get_device_module(device).Stream())
+                if device.type in ["cuda", "mtia"]
+                else None
+            )
+        else:
+            # default to False: reuse the data_dist stream for the emb lookup to reduce the
+            # CUDA memory footprint, since the Caching Allocator reserves memory per stream
+            self._emb_lookup_stream = self._data_dist_stream
+
+    def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
+        """
+        This function is called in self.progress (one of the main APIs for running the train pipeline).
+        Here we assume the max pipelined len(batches) == 2 (capacity), which is the most common
+        scenario during the full training job, in which case this function effectively does nothing.
+        There are only two other scenarios:
+        len(batches) == 0:
+            initialize the pipeline, fill in two batches, start input_dist for the first batch.
+        len(batches) == 1:
+            dataloader_iter has stopped; this is the last batch, do nothing.
+        """
+
+        # pipeline is already filled with max capacity (2)
+        if len(self.batches) >= 2:
+            return
+
+        # executes last batch in pipeline, when there is only one batch in the pipeline
+        # TODO: this _execute_all_batches doesn't really work here (D43546239); it will
+        # just throw an exception at copy_to_gpu when the dataloader is exhausted
+        if self.batches and self._execute_all_batches:
+            return
+
+        # batch i, data (batch) and context
+        if not self.enqueue_batch(dataloader_iter):
+            return
+
+        # modify the (sharded) sparse module forward, and invoke the first part of input_dist
+        self._init_pipelined_modules(
+            self.batches[0],  # pyre-ignore [6]
+            self.contexts[0],
+            self._pipelined_forward_type,  # pyre-ignore [6]
+        )
+        # doing the second part of input_dist, the first part is invoked in _init_pipelined_modules
+        self.wait_sparse_data_dist(self.contexts[0])
+
+        # this is the main difference from TrainPipelineSparseDist.fill_pipeline
+        self.start_embedding_lookup(self.batches[0], self.contexts[0])  # pyre-ignore [6]
+
+        # batch i+1
+        if not self.enqueue_batch(dataloader_iter):
+            return
+
+    def wait_embedding_lookup(self) -> None:
+        """
+        Waits on the embedding lookup stream so the embedding lookup tensors are ready.
+        """
+        current_stream = torch.get_device_module(self._device).current_stream()
+        current_stream.wait_stream(self._emb_lookup_stream)

+    def start_embedding_lookup(
+        self,
+        batch: Optional[In],
+        context: EmbeddingTrainPipelineContext,
+    ) -> None:
+        """
+        Starts the embedding lookup for the given batch on the embedding lookup stream.
+        """
+        if batch is None:
+            return
+
+        with record_function(f"## start_embedding_lookup {context.index} ##"):
+            current_stream = torch.get_device_module(self._device).current_stream()
+            with self._stream_context(self._emb_lookup_stream):
+                for module in self._pipelined_modules:
+                    _start_embedding_lookup(
+                        module,
+                        context,
+                        source_stream=self._emb_lookup_stream,
+                        target_stream=current_stream,
+                        stream_context=self._stream_context,
+                    )
+
+    def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        """
+        For TrainPipelineSparseDist, we assume the max pipelined batches == 3 (capacity):
+            batches[0]: i+0 batch, fwd/bwd/opt (expecting output_dist)
+            batches[1]: i+1 batch, for input_dist (expecting copied to device), and compute_and_output_dist
+            batches[2]: i+2 batch, for copy_batch_to_gpu (expecting non-exhausted dataloader iter)
+        """
+
+        # attach the model in case the user forgot to call it, especially when the user
+        # pauses pipeline.progress and detaches the model for other purposes.
+        if not self._model_attached:
+            self.attach(self._model)
+
+        # filling the pipeline is only needed at the beginning, when the pipeline (batches) is empty
+        self.fill_pipeline(dataloader_iter)
+
+        # here is the expected stop after exhausting all batches
+        if not self.batches:
+            raise StopIteration
+
+        # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
+        self._set_module_context(self.contexts[0])
+
+        # start embedding_lookup so it can overlap with the previous optimizer
+        self.start_embedding_lookup(self.batches[0], self.contexts[0])  # pyre-ignore [6]
+
+        if self._model.training:
+            with record_function("## zero_grad ##"):
+                self._optimizer.zero_grad()
+
+        # wait for batches[0] to be available on device; this should always be complete since
+        # the input_dist of batches[0] was invoked in the previous iter. TODO: fact check
+        self._wait_for_batch()
+
+        if len(self.batches) >= 2:
+            # invoke splits all_to_all comms (first part of input_dist)
+            self.start_sparse_data_dist(self.batches[1], self.contexts[1])
+
+        # batch i+2: load data and copy to gpu; the dataloader iter will first exhaust here
+        self.enqueue_batch(dataloader_iter)
+
+        # wait for embedding lookup before forward
+        self.wait_embedding_lookup()
+
+        # forward
+        with record_function("## forward ##"):
+            losses, output = self._model_fwd(self.batches[0])
+
+        if len(self.batches) >= 2:
+            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
+            self.wait_sparse_data_dist(self.contexts[1])
+
+        if self._model.training:
+            # backward
+            self._backward(losses)
+
+            # update
+            with record_function("## optimizer ##"):
+                self._optimizer.step()
+
+        self.dequeue_batch()
+        return output
+
+
 class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
     """
     Novel method for RecSys model training by leveraging "Semi-Synchronous" training,
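The `start_embedding_lookup` / `wait_embedding_lookup` pair above relies on ordinary CUDA stream semantics: work enqueued on a side stream runs concurrently with the current stream until someone calls `wait_stream`. Below is a minimal, self-contained sketch of that pattern, not TorchRec code; the `index_select` "lookup" and the `mul_` "optimizer" are stand-in kernels, and the overlap is only safe because the stand-in optimizer never touches the embedding table, mirroring the fused-TBE assumption in the class docstring.

```python
# Minimal sketch of the stream-overlap pattern; all kernels are placeholders.
import torch


def overlapped_step(
    emb_weights: torch.Tensor,    # already updated by the fused backward
    dense_weights: torch.Tensor,  # still pending the (dense) optimizer step
    ids: torch.Tensor,
) -> torch.Tensor:
    side_stream = torch.cuda.Stream()           # plays the role of _emb_lookup_stream
    main_stream = torch.cuda.current_stream()

    # The lookup must not start before work already queued on the main stream
    # (e.g. the fused TBE backward that updated emb_weights).
    side_stream.wait_stream(main_stream)
    with torch.cuda.stream(side_stream):
        emb = emb_weights.index_select(0, ids)  # stand-in for the embedding lookup

    # Meanwhile the main stream is free to run the dense "optimizer" kernels.
    dense_weights.mul_(0.999)                   # stand-in for optimizer.step()

    # Before the dense forward consumes `emb`, synchronize (wait_embedding_lookup()).
    main_stream.wait_stream(side_stream)
    emb.record_stream(main_stream)              # emb was allocated on side_stream, used on main_stream
    return emb


# Usage sketch:
# w = torch.randn(1000, 64, device="cuda")
# d = torch.randn(512, device="cuda")
# out = overlapped_step(w, d, torch.randint(0, 1000, (128,), device="cuda"))
```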

torchrec/distributed/train_pipeline/utils.py

Lines changed: 15 additions & 0 deletions
@@ -670,6 +670,20 @@ def detach_embeddings(
         self._context.detached_embedding_tensors.append(detached_tensors)


+class InSyncEmbeddingPipelinedForward(EmbeddingPipelinedForward):
+    """
+    This pipeline is used in TrainPipelineFusedSparseDist
+    """
+
+    def detach_embeddings(
+        self,
+        embeddings: Union[Dict[str, JaggedTensor], KeyedTensor],
+        cur_stream: torch.Stream,
+    ) -> None:
+        # doing nothing
+        pass
+
+
 class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]):
     """
     This pipeline is used in PrefetchTrainPipelineSparseDist
@@ -850,6 +864,7 @@ def _start_data_dist(
             PipelinedForward,
             PrefetchPipelinedForward,
             EmbeddingPipelinedForward,
+            InSyncEmbeddingPipelinedForward,
         ),
     )
