Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added reshard hook for frozen params in backward #1159

Open
wants to merge 5 commits into
base: ngoyal_changes_for_pp_fp8
Choose a base branch
from

Conversation

awgu
Copy link

@awgu awgu commented Jan 12, 2024

What does this PR do?

This PR reshards frozen parameters that got all-gathered in backward. The technical approach follows from pytorch/pytorch#101982. The idea is that we use a new feature in PyTorch, the multi-grad hook (register_multi_grad_hook) to enable us to insert logic that runs after the gradients with respect to a module's inputs have been computed. For modules with frozen parameters, that point is a safe place to reshard the parameters since they have already been used to compute input gradients.

I was not too familiar with the Fairscale testing infra, so I did not use dist_init() and set the env vars manually.

python -m pytest tests/nn/data_parallel/test_fsdp_freezing_weights.py -s -k test_reshard_frozen_weights

While working on this I noticed two possible minor bugs:

  1. Here, I think we may prefer if not fsdp_module._require_backward_grad_sync:. It is not an issue if every FSDP module atomically is under no_grad or not.
  2. Here, we call _use_fp32_param_shard() in _finalize_parameters() citing the case that the parameters get accidentally gathered after post-backward. I think we additionally need to check if we need to call _free_full_params(), or else the optimizer step could update the sharded parameters but the next forward will not re-all-gather the parameters, resulting in a correctness issue.

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 12, 2024
@awgu awgu force-pushed the ngoyal_changes_for_pp_fp8_awgu branch from 08816c8 to a4f02ef Compare January 12, 2024 19:27
@awgu
Copy link
Author

awgu commented Jan 12, 2024

I realized that given the way I wrote the unit test, it is possible to pass it even if all of the resharding happens at the very end of backward, not saving memory during backward.

However, I did check manually that the hooks are being called throughout backward and not only in _finalize_parameters(). (In fact, in the beginning, I did not even have the reshard hook call in _finalize_parameters() at all, and I confirmed the hook was firing.)

I will not update the unit test for now since it is a bit tricky on how to make it stricter.

Copy link
Contributor

@ngoyal2707 ngoyal2707 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good to me

@jspark1105
Copy link

@awgu didn't have a chance to go through the code changes but I have a high level question. If we set requires_grad=False for frozen parameters, we won't do any FSDP comms for them (maybe all-gather only at the first iteration and then keep them)?

@awgu
Copy link
Author

awgu commented Feb 7, 2024

@jspark1105 We will still all-gather and free them in forward every iteration. In backward, we may still all-gather them every iteration if their module forward outputs required gradients, in which case the parameters may be needed to compute gradients with respect to the module inputs. This is done via a pre-backward hook registered on those forward outputs.

Before this PR, after the parameters were all-gathered, there were never freed since the complementary post-backward hook did not run. This PR adds another way to run the post-backward hook so that it can run even when the parameters do not require gradients.

In other words, for frozen parameters (requires_grad=False), we can still all-gather and free them for both forward and backward every iteration.

@jspark1105
Copy link

Thanks @awgu for explanation! I guess this is optimizing mostly for ZeRO3 (reshard_after_forward=True). I guess we don't need to bother ZeRO2 for frozen since we can just use DDP (without all-gather and optimizers) although not sure how easy to use DDP for some parts of model and FSDP for the remain.

@awgu
Copy link
Author

awgu commented Feb 7, 2024

@jspark1105 For FSDP with reshard_after_forward=False, it could still be beneficial to reshard parameters in backward layer by layer as we finish using them for gradient computation. Otherwise, we would only reshard everything at the end of backward. Depending on where peak memory is, maybe resharding earlier could help.

Otherwise, I agree that if you have a sharded optimizer implementation with DDP, then you can avoid worrying about FSDP here. Composing that DDP and FSDP might be somewhat tricky with today's APIs though, as you mentioned 😞 .

@awgu awgu force-pushed the ngoyal_changes_for_pp_fp8_awgu branch from 5cbaffb to eebaa6e Compare March 27, 2024 23:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants