[PyTorch] Distributed intermediate/activation tensors for FSDP #687

Merged: 15 commits merged on Jun 7, 2024

denera
Collaborator

@denera denera commented Feb 28, 2024

torch.distributed.fsdp.FullyShardedDataParallel cannot scatter/gather the intermediate/activation tensors that TE modules pack into the autograd context at the end of their forward passes. As a result, globally sized activation and Fp8 weight tensors stay in memory.

This PR provides a te.distributed.prepare_te_modules_for_fsdp(fsdp_root) API that inserts references to the correct FSDP process group into the FSDP-wrapped TE modules of a given model. The TE modules then use these process groups to scatter the intermediate/activation tensors at the end of the forward pass, before packing them into the autograd context. The same tensors are gathered at the beginning of the backward pass, before any computation.
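
To make the scatter-then-gather round trip concrete, here is a minimal single-process sketch in plain Python. It is not TE's implementation: the helper names `scatter_flat` and `gather_flat` are hypothetical, a list stands in for a flattened tensor, and the list of shards stands in for what each FSDP rank would hold. It also assumes equal-sized shards with zero padding, which may differ from what the PR actually does.

```python
from typing import List, Tuple


def scatter_flat(values: List[float], world_size: int) -> Tuple[List[List[float]], int]:
    """Split a flattened tensor into equal-sized shards, one per simulated rank.

    Hypothetical helper: pads with zeros so the length divides evenly and
    returns the shards together with the original element count.
    """
    numel = len(values)
    shard_len = -(-numel // world_size)  # ceiling division
    padded = values + [0.0] * (shard_len * world_size - numel)
    shards = [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]
    return shards, numel


def gather_flat(shards: List[List[float]], numel: int) -> List[float]:
    """Concatenate all shards and drop the padding added by scatter_flat."""
    full = [v for shard in shards for v in shard]
    return full[:numel]


# End of forward: keep only a shard per rank instead of the full activation.
activation = [float(i) for i in range(10)]
shards, numel = scatter_flat(activation, world_size=4)

# Start of backward: rebuild the full activation before any computation.
restored = gather_flat(shards, numel)
```

Between the two calls each simulated rank holds only `len(shards[0])` elements rather than the full tensor, which is the memory saving the PR description is after.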

Using te.distributed.checkpoint() turns off the scatters/gathers to avoid unnecessary communication for tensors that need to be recomputed anyway.
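
One way to picture that switch is a flag that is raised while activation recompute is active and consulted before sharding the saved tensors. The sketch below is only an illustration under that assumption: `activation_recompute`, `should_shard_saved_tensors`, and the module-level flag are hypothetical names, not TE's API.

```python
from contextlib import contextmanager

# Hypothetical module-level flag; TE's real bookkeeping may look different.
_RECOMPUTE_ACTIVE = False


@contextmanager
def activation_recompute():
    """Mark the enclosed region as running under activation recompute."""
    global _RECOMPUTE_ACTIVE
    previous = _RECOMPUTE_ACTIVE
    _RECOMPUTE_ACTIVE = True
    try:
        yield
    finally:
        # Restore the earlier state even if the body raises.
        _RECOMPUTE_ACTIVE = previous


def should_shard_saved_tensors(fsdp_group) -> bool:
    """Shard only when an FSDP group is attached and recompute is off."""
    return fsdp_group is not None and not _RECOMPUTE_ACTIVE


group = object()  # stand-in for a real process group
outside = should_shard_saved_tensors(group)
with activation_recompute():
    inside = should_shard_saved_tensors(group)
```

Here `outside` is `True` and `inside` is `False`: tensors that will be recomputed are never scattered, so no gather is needed for them either.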

nn.Sequential( 3 x te.LayerNormMLP ) before Fp8/intermediate sharding:

[Image: no_fp8_sharding]

nn.Sequential( 3 x te.LayerNormMLP ) after Fp8/intermediate sharding:

[Image: with_fp8_sharding]

@denera denera requested review from ptrendx and ksivaman February 28, 2024 01:29
@denera denera self-assigned this Feb 28, 2024
@denera denera force-pushed the databricks/distribute-fp8-weights-fsdp branch from d18b49f to 71f696b Compare February 28, 2024 15:24
@denera denera marked this pull request as ready for review February 28, 2024 15:33
@denera denera changed the title from "[PyTorch] Distributed intermediate/activation tensors for FSDP -- WIP" to "[PyTorch] Distributed intermediate/activation tensors for FSDP" Mar 6, 2024
fsdp_states, fsdp_modules = _get_fsdp_states_with_modules(fsdp_root)
for state, module in zip(fsdp_states, fsdp_modules):
    if _is_te_module(module):
        setattr(module, "fsdp_wrapepd", True)
Contributor

typo

@denera denera force-pushed the databricks/distribute-fp8-weights-fsdp branch from 1b822f4 to 9d8a7f5 Compare March 11, 2024 19:41
@denera denera force-pushed the databricks/distribute-fp8-weights-fsdp branch 2 times, most recently from 0267b08 to d62330d Compare April 16, 2024 00:32
@ptrendx ptrendx added the 1.7.0 label May 3, 2024
@denera
Collaborator Author

denera commented May 21, 2024

/te-ci pytorch

@denera denera force-pushed the databricks/distribute-fp8-weights-fsdp branch 2 times, most recently from b260d50 to 8c6e9b7 Compare May 22, 2024 16:01
@denera
Collaborator Author

denera commented May 22, 2024

/te-ci pytorch

transformer_engine/pytorch/distributed.py (resolved)
transformer_engine/pytorch/module/linear.py (outdated, resolved)
examples/pytorch/fsdp/fsdp.py (outdated, resolved)
examples/pytorch/fsdp/fsdp.py (outdated, resolved)
examples/pytorch/fsdp/fsdp.py (outdated, resolved)
denera added 13 commits May 23, 2024 19:23
…s distribute their activations after the forward pass and gather them before the backward pass

Signed-off-by: Alp Dener <[email protected]>
…ass and gathered before forward

Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Alp Dener <[email protected]>
@denera denera force-pushed the databricks/distribute-fp8-weights-fsdp branch from 7cb919b to 518df99 Compare May 23, 2024 19:24
@denera
Collaborator Author

denera commented May 23, 2024

/te-ci pytorch

@ptrendx ptrendx removed the 1.7.0 label May 30, 2024
@denera
Collaborator Author

denera commented Jun 4, 2024

/te-ci pytorch

@denera denera requested review from timmoon10 and rahul003 June 4, 2024 23:51
Collaborator

@timmoon10 timmoon10 left a comment

LGTM

@@ -856,3 +865,110 @@ def allreduce(
    handle = torch.distributed.all_reduce(input_, group=tp_group, async_op=async_op)

    return input_, handle


def _fsdp_scatter_tensors(
Member

Interesting that the linter doesn't complain about the missing docstring, is this due to the __all__ decl or just using internal function convention with _func_name? Either way I think this is good practice going forward as well instead of adding filler docs!

Collaborator Author

Yes, I believe PyLint ignores some warnings by default for internal functions designated with the leading underscore.
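
For context, the behavior described here most likely comes from Pylint's `no-docstring-rgx` option, whose documented default is `^_`: names matching it are exempt from the missing-docstring checks. The short sketch below only mimics that name matching with Python's `re` module. The helper `is_exempt_from_docstring_check` is a hypothetical name, and this repository's own Pylint configuration may override the default.

```python
import re

# Pylint's documented default for `no-docstring-rgx`; a project's pylintrc
# may set a different pattern (assumption: the default applies here).
NO_DOCSTRING_RGX = re.compile(r"^_")


def is_exempt_from_docstring_check(name: str) -> bool:
    """Return True if a function name matches the exemption pattern."""
    return NO_DOCSTRING_RGX.search(name) is not None


names = ["_fsdp_scatter_tensors", "allreduce", "__init__"]
exempt = {name: is_exempt_from_docstring_check(name) for name in names}
```

Under the default pattern, `_fsdp_scatter_tensors` is exempt while a public function such as `allreduce` still needs a docstring.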

@ksivaman
Member

ksivaman commented Jun 6, 2024

/te-ci pytorch

Member

@ksivaman ksivaman left a comment

LGTM

@ksivaman
Member

ksivaman commented Jun 7, 2024

te-ci

@ksivaman
Member

ksivaman commented Jun 7, 2024

/te-ci

@ksivaman
Member

ksivaman commented Jun 7, 2024

/te-ci pytorch

@ksivaman
Member

ksivaman commented Jun 7, 2024

Pipeline 15637774

@ksivaman ksivaman merged commit 0edf30b into NVIDIA:main Jun 7, 2024
9 of 20 checks passed
phu0ngng pushed a commit to phu0ngng/TransformerEngine that referenced this pull request Jun 11, 2024
…A#687)

* New TE wrapper for PyTorch FullyShardedDataParallel to make TE modules distribute their activations after the forward pass and gather them before the backward pass

Signed-off-by: Alp Dener <[email protected]>

* simplified TE module setup for FSDP comms

Signed-off-by: Alp Dener <[email protected]>

* FSDP scatter/gather for tensors saved into autograd ctx now working for base TE modules

Signed-off-by: Alp Dener <[email protected]>

* make sure activation recompute disables FSDP scatter/gather

Signed-off-by: Alp Dener <[email protected]>

* make sure Fp8 weight buffers are sharded at the end of the backward pass and gathered before forward

Signed-off-by: Alp Dener <[email protected]>

* Fixed typo in attribute name

Signed-off-by: Alp Dener <[email protected]>

* fixed bug in finding FSDP-wrapped TE modules

Signed-off-by: Alp Dener <[email protected]>

* fixed typo in fp8 weight tensor name

Signed-off-by: Alp Dener <[email protected]>

* fixed incorrect # of gradients

Signed-off-by: Alp Dener <[email protected]>

* Added fp8 amax gradient hook tensor to the parameter reset

Signed-off-by: Alp Dener <[email protected]>

* get rid of erroneous dummy tensor leftover from incorrect rebase

Signed-off-by: Alp Dener <[email protected]>

* Linting fixes

Signed-off-by: Alp Dener <[email protected]>

* fixing git snafu and removing debug statements

Signed-off-by: Alp Dener <[email protected]>

---------

Signed-off-by: Alp Dener <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
5 participants