
Explore new pipeline that overlaps optimizer with emb_lookup #2916

Closed
wants to merge 1 commit into from
2 changes: 2 additions & 0 deletions torchrec/distributed/benchmark/benchmark_train_sparsenn.py
@@ -42,6 +42,7 @@
from torchrec.distributed.train_pipeline import (
TrainPipeline,
TrainPipelineBase,
TrainPipelineFusedSparseDist,
TrainPipelineSparseDist,
)
from torchrec.distributed.train_pipeline.train_pipelines import (
@@ -106,6 +107,7 @@ def generate_pipeline(
] = {
"base": TrainPipelineBase,
"sparse": TrainPipelineSparseDist,
"fused": TrainPipelineFusedSparseDist,
"semi": TrainPipelineSemiSync,
"prefetch": PrefetchTrainPipelineSparseDist,
}
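For context, here is a minimal sketch of how a name-to-class mapping like the one added above is typically consumed. The `build_pipeline` helper, its arguments, and the surrounding setup are illustrative assumptions, not code from this PR or from `generate_pipeline` itself.

```python
from typing import Dict, Type

import torch
from torchrec.distributed.train_pipeline import (
    TrainPipeline,
    TrainPipelineBase,
    TrainPipelineFusedSparseDist,
    TrainPipelineSparseDist,
)

# Hypothetical helper (not part of the PR): resolve a CLI-style pipeline name to a
# pipeline class and construct it, mirroring how the benchmark's mapping is used.
_PIPELINES: Dict[str, Type[TrainPipeline]] = {
    "base": TrainPipelineBase,
    "sparse": TrainPipelineSparseDist,
    "fused": TrainPipelineFusedSparseDist,
}


def build_pipeline(
    name: str,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> TrainPipeline:
    try:
        pipeline_cls = _PIPELINES[name]
    except KeyError as e:
        raise ValueError(f"unknown pipeline type: {name!r}") from e
    return pipeline_cls(model=model, optimizer=optimizer, device=device)
```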
1 change: 1 addition & 0 deletions torchrec/distributed/train_pipeline/__init__.py
@@ -15,6 +15,7 @@
TorchCompileConfig, # noqa
TrainPipeline, # noqa
TrainPipelineBase, # noqa
TrainPipelineFusedSparseDist, # noqa
TrainPipelinePT2, # noqa
TrainPipelineSparseDist, # noqa
TrainPipelineSparseDistCompAutograd, # noqa
174 changes: 172 additions & 2 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -47,6 +47,7 @@
EmbeddingPipelinedForward,
EmbeddingTrainPipelineContext,
In,
InSyncEmbeddingPipelinedForward,
Out,
PipelinedForward,
PipelinedPostproc,
@@ -540,8 +541,7 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:

# modify the (sharded) sparse module forward, and invoke the first part of input_dist
self._init_pipelined_modules(
# pyre-ignore [6]
self.batches[0],
self.batches[0], # pyre-ignore [6]
self.contexts[0],
self._pipelined_forward_type,
)
@@ -803,6 +803,176 @@ def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter)


class TrainPipelineFusedSparseDist(TrainPipelineSparseDist[In, Out]):
"""
This pipeline modifies TrainPipelineSparseDist by running embedding lookup in a
separate stream so that it can overlap with the optimizer step of the previous batch. The
assumption made here is that the embeddings are updated in the fused backward (fused TBE),
so the embedding lookup can start immediately after the backward completes, without any
dependency on the optimizer step.

NOTE: This assumption does not hold if there are feature processor(s).
NOTE: This pipeline is still experimental; users should always run NE parity tests.

batch i+0:
ShardedModule.compute_and_output_dist - uses emb_lookup CUDA stream
forward (without emb lookup)
backward and optimizer
batch i+1:
ShardedModule.input_dist() - uses data_dist CUDA stream
batch i+2:
copy batch to device

`ShardedModule.input_dist()` is only done for top-level modules in the call graph.
To be considered a top-level module, a module can only depend on 'getattr' calls on
input.

Input model must be symbolically traceable with the exception of `ShardedModule` and
`DistributedDataParallel` modules.

Args:
model (torch.nn.Module): model to pipeline.
optimizer (torch.optim.Optimizer): optimizer to use.
device (torch.device): device where device transfer, sparse data dist, and
forward/backward pass will happen.
execute_all_batches (bool): executes remaining batches in pipeline after
exhausting dataloader iterator.
apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
TODO: pipeline_postproc, custom_model_fwd, strict
use_emb_lookup_stream (bool): if True, invoke compute_and_output_dist (for batch
i+1) on a separate stream; otherwise reuse the data_dist stream.
"""

# The PipelinedForward class that is used in _rewrite_model
_pipelined_forward_type = InSyncEmbeddingPipelinedForward # pyre-ignore

def __init__(
self,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
device: torch.device,
execute_all_batches: bool = True,
apply_jit: bool = False,
pipeline_postproc: bool = True,
custom_model_fwd: Optional[
Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
] = None,
strict: bool = False,
use_emb_lookup_stream: bool = False, # default False explained below
) -> None:
super().__init__(
model=model,
optimizer=optimizer,
device=device,
execute_all_batches=execute_all_batches,
apply_jit=apply_jit,
context_type=EmbeddingTrainPipelineContext,
pipeline_postproc=pipeline_postproc,
custom_model_fwd=custom_model_fwd,
)
if use_emb_lookup_stream:
self._emb_lookup_stream: Optional[torch.Stream] = (
(torch.get_device_module(device).Stream())
if device.type in ["cuda", "mtia"]
else None
)
else:
# default to False: reuse the data_dist stream for the emb lookup to reduce the CUDA
# memory footprint, since the caching allocator reserves memory for each stream
self._emb_lookup_stream = self._data_dist_stream

def wait_embedding_lookup(self) -> None:
"""
Waits on the embedding lookup stream so that the embedding lookup output tensors are ready to use.
"""
current_stream = torch.get_device_module(self._device).current_stream()
current_stream.wait_stream(self._emb_lookup_stream)

def start_embedding_lookup(
self,
batch: Optional[In],
context: EmbeddingTrainPipelineContext,
) -> None:
"""
Starts the embedding lookup (compute_and_output_dist) for the given batch on the embedding lookup stream.
"""
if batch is None:
return

with record_function(f"## start_embedding_lookup {context.index} ##"):
current_stream = torch.get_device_module(self._device).current_stream()
with self._stream_context(self._emb_lookup_stream):
for module in self._pipelined_modules:
_start_embedding_lookup(
module,
context,
source_stream=self._emb_lookup_stream,
target_stream=current_stream,
stream_context=self._stream_context,
)

def progress(self, dataloader_iter: Iterator[In]) -> Out:
"""
For TrainPipelineSparseDist, we assume the max number of pipelined batches == 3 (capacity):
batches[0]: batch i+0, fwd/bwd/opt (expecting output_dist)
batches[1]: batch i+1, for input_dist (expected to already be copied to device) and compute_and_output_dist
batches[2]: batch i+2, for copy_batch_to_gpu (expecting a non-exhausted dataloader iter)
"""

# attach the model just in case the user forgot to call it, especially when the user
# pauses pipeline.progress and detaches the model for other purposes.
if not self._model_attached:
self.attach(self._model)

# filling the pipeline is only needed at the beginning, when the pipeline (batches) is empty
self.fill_pipeline(dataloader_iter)

# this is the expected stop after exhausting all batches
if not self.batches:
raise StopIteration

# TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
self._set_module_context(self.contexts[0])

# start embedding_lookup so it can overlap with the previous batch's optimizer step
# pyre-ignore [6]
self.start_embedding_lookup(self.batches[0], self.contexts[0])

if self._model.training:
with record_function("## zero_grad ##"):
self._optimizer.zero_grad()

# wait for batches[0] to be available on device; this should always be completed since
# the input_dist of batches[0] was invoked in the previous iteration. TODO: fact check
self._wait_for_batch()

if len(self.batches) >= 2:
# invoke splits all_to_all comms (first part of input_dist)
self.start_sparse_data_dist(self.batches[1], self.contexts[1])

# batch i+2: load data and copy to gpu; the dataloader iter will be exhausted here first
self.enqueue_batch(dataloader_iter)

# forward
with record_function("## forward ##"):
losses, output = self._model_fwd(self.batches[0])

if len(self.batches) >= 2:
# invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
self.wait_sparse_data_dist(self.contexts[1])

if self._model.training:
# backward
self._backward(losses)

# update
with record_function("## optimizer ##"):
self._optimizer.step()

self.dequeue_batch()
return output


class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
"""
Novel method for RecSys model training by leveraging "Semi-Synchronous" training,
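Below is a hedged usage sketch of the new pipeline in a plain training loop. The `model`, `dataloader`, and optimizer choice are assumptions for illustration; the sketch presumes a sharded TorchRec model whose embedding tables use fused TBE (as the class docstring requires), and it is not a verified recipe from this PR.

```python
import torch
from torchrec.distributed.train_pipeline import TrainPipelineFusedSparseDist

# Assumed to exist for this sketch: a sharded TorchRec model with fused-TBE embedding
# tables (`model`) and a dataloader yielding pipelineable batches (`dataloader`).
device = torch.device("cuda")
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.01)

pipeline = TrainPipelineFusedSparseDist(
    model=model,
    optimizer=optimizer,
    device=device,
    # default False: reuse the data_dist stream for the emb lookup to limit the
    # memory the caching allocator reserves per stream
    use_emb_lookup_stream=False,
)

dataloader_iter = iter(dataloader)
while True:
    try:
        # each call overlaps the emb lookup with the previous batch's optimizer step
        output = pipeline.progress(dataloader_iter)
    except StopIteration:
        break
```

Since `progress()` keeps the same three-slot batching described in its docstring, swapping in from `TrainPipelineSparseDist` should not require dataloader changes.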
15 changes: 15 additions & 0 deletions torchrec/distributed/train_pipeline/utils.py
@@ -670,6 +670,20 @@ def detach_embeddings(
self._context.detached_embedding_tensors.append(detached_tensors)


class InSyncEmbeddingPipelinedForward(EmbeddingPipelinedForward):
"""
This pipeline is used in TrainPipelineFusedSparseDist
"""

def detach_embeddings(
self,
embeddings: Union[Dict[str, JaggedTensor], KeyedTensor],
cur_stream: torch.Stream,
) -> None:
# intentionally a no-op
pass


class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]):
"""
This pipeline is used in PrefetchTrainPipelineSparseDist
@@ -853,6 +867,7 @@ def _start_data_dist(
PipelinedForward,
PrefetchPipelinedForward,
EmbeddingPipelinedForward,
InSyncEmbeddingPipelinedForward,
),
)

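To close, here is a minimal, standalone sketch of the producer/consumer stream pattern that `TrainPipelineFusedSparseDist` builds on (work launched on one stream, consumer waits on it before reading the result). It uses plain `torch.cuda` streams and is an illustration only, not code from this PR.

```python
import torch

# Minimal sketch (not TorchRec code) of the stream-overlap pattern: launch work on a
# side stream so it can overlap with the default stream, then have the consumer wait
# on the side stream before reading the result.
if torch.cuda.is_available():
    device = torch.device("cuda")
    default_stream = torch.cuda.current_stream(device)
    side_stream = torch.cuda.Stream(device)

    x = torch.randn(1024, 1024, device=device)

    # make sure the side stream sees x fully materialized before using it
    side_stream.wait_stream(default_stream)
    with torch.cuda.stream(side_stream):
        y = x @ x  # stand-in for the embedding lookup / output_dist

    # ... the default stream is free to run unrelated work here (e.g. the previous
    # batch's optimizer step), overlapping with the side-stream work above ...

    default_stream.wait_stream(side_stream)  # consumer waits for the lookup result
    y.record_stream(default_stream)          # keep y's allocation alive across streams
    loss = y.sum()                           # now safe to consume on the default stream
```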