[PyTorch] Distributed intermediate/activation tensors for FSDP #687
Conversation
Force-pushed from d18b49f to 71f696b
fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root)
for state, module in zip(fsdp_states, fsdp_modules):
    if _is_te_module(module):
        setattr(module, "fsdp_wrapepd", True)
typo
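For reference, a later commit in this PR ("Fixed typo in attribute name") corrects the flagged line; the attribute name below is inferred from context rather than quoted from the final diff:

```python
# Inferred correction of the line flagged above; "fsdp_wrapped" is the assumed final name.
if _is_te_module(module):
    setattr(module, "fsdp_wrapped", True)
```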
Force-pushed from 1b822f4 to 9d8a7f5
Force-pushed from 0267b08 to d62330d
/te-ci pytorch
Force-pushed from b260d50 to 8c6e9b7
/te-ci pytorch
Force-pushed from 7cb919b to 518df99
/te-ci pytorch
/te-ci pytorch
LGTM
@@ -856,3 +865,110 @@ def allreduce(
    handle = torch.distributed.all_reduce(input_, group=tp_group, async_op=async_op)

    return input_, handle


def _fsdp_scatter_tensors(
Interesting that the linter doesn't complain about the missing docstring. Is this due to the `__all__` decl, or just the internal-function naming convention with `_func_name`? Either way, I think this is good practice going forward, instead of adding filler docs!
Yes, I believe PyLint ignores some warnings by default for internal functions designated with the leading underscore.
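For context, a helper like `_fsdp_scatter_tensors` conceptually replaces each saved tensor's storage with just the calling rank's shard, remembering the full shapes so a matching gather can restore them before the backward pass. The sketch below is purely illustrative (a flatten-pad-chunk scheme with assumed names), not the implementation added in this PR:

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def _fsdp_scatter_tensors(fsdp_group, *tensors):
    # Illustrative sketch only: shrink each tensor's storage to this rank's shard
    # and return the original shapes so a matching gather can rebuild the full
    # tensors right before they are needed in the backward pass.
    shapes = []
    world = dist.get_world_size(fsdp_group)
    rank = dist.get_rank(fsdp_group)
    for t in tensors:
        if not isinstance(t, torch.Tensor):
            shapes.append(None)
            continue
        shapes.append(t.data.shape)
        flat = t.data.reshape(-1)
        pad = (world - flat.numel() % world) % world  # pad so the split is even
        if pad:
            flat = F.pad(flat, (0, pad))
        t.data = flat.chunk(world)[rank].clone()  # keep only the local shard
    return shapes
```

A matching gather counterpart would all-gather the shards over the same group and reshape them back to the recorded shapes.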
/te-ci pytorch
LGTM
te-ci
/te-ci
/te-ci pytorch
Pipeline 15637774
[PyTorch] Distributed intermediate/activation tensors for FSDP (…A#687)
* New TE wrapper for PyTorch FullyShardedDataParallel to make TE modules distribute their activations after the forward pass and gather them before the backward pass
* simplified TE module setup for FSDP comms
* FSDP scatter/gather for tensors saved into autograd ctx now working for base TE modules
* make sure activation recompute disables FSDP scatter/gather
* make sure Fp8 weight buffers are sharded at the end of the backward pass and gathered before forward
* Fixed typo in attribute name
* fixed bug in finding FSDP-wrapped TE modules
* fixed typo in fp8 weight tensor name
* fixed incorrect # of gradients
* Added fp8 amax gradient hook tensor to the parameter reset
* get rid of erroneous dummy tensor leftover from incorrect rebase
* Linting fixes
* fixing git snafu and removing debug statements

Signed-off-by: Alp Dener <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
`torch.distributed.fsdp.FullyShardedDataParallel` cannot scatter/gather the intermediate/activation tensors that TE modules pack into the autograd context at the end of their forward passes, so globally sized activation and Fp8 weight tensors stay in memory.

This PR provides a `te.distributed.prepare_te_modules_for_fsdp(fsdp_root)` API that inserts references to the correct FSDP process group into the FSDP-wrapped TE modules of a given model. The TE modules then use these process groups to scatter the intermediate/activation tensors at the end of the forward pass, before packing them into the autograd context. The same tensors are gathered at the beginning of the backward pass, before compute.
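As a rough usage sketch (the model, dimensions, and FSDP wrapping options below are assumptions for illustration, not taken from this PR):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import transformer_engine.pytorch as te

# Assumes torch.distributed is already initialized (e.g. launched via torchrun)
# and the current CUDA device has been set for this rank.
model = torch.nn.Sequential(*[te.LayerNormMLP(1024, 4096) for _ in range(3)]).cuda()
fsdp_root = FSDP(model, use_orig_params=True)

# Hand each FSDP-wrapped TE module a reference to its FSDP process group so it can
# scatter saved activations after the forward pass and gather them before backward.
te.distributed.prepare_te_modules_for_fsdp(fsdp_root)
```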
Using `te.distributed.checkpoint()` turns off these scatters/gathers, avoiding unnecessary communication for tensors that have to be recomputed anyway.

`nn.Sequential( 3 x te.LayerNormMLP )` before Fp8/intermediate sharding:

`nn.Sequential( 3 x te.LayerNormMLP )` after Fp8/intermediate sharding:
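A forward/backward step with the prepared model might then look like the sketch below (the batch shape and default fp8 recipe are assumptions). When the forward pass runs under activation recompute via `te.distributed.checkpoint()` instead, the scatter/gather is skipped as described above.

```python
# One illustrative step with the prepared FSDP model from the sketch above.
inp = torch.randn(2048, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True):  # default recipe assumed; requires fp8-capable GPUs
    out = fsdp_root(inp)             # saved activations are scattered after forward

loss = out.float().sum()
loss.backward()                      # activation shards are gathered back before compute
```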