Added reshard hook for frozen params in backward #1159
base: ngoyal_changes_for_pp_fp8
Conversation
Force-pushed from 08816c8 to a4f02ef
I realized that, given the way I wrote the unit test, it is possible to pass it even if all of the resharding happens at the very end of backward, which would not save memory during backward. However, I did check manually that the hooks are being called throughout backward and not only at the very end. I will not update the unit test for now since it is a bit tricky to make it stricter.
looks good to me
@awgu I didn't have a chance to go through the code changes, but I have a high-level question. If we set requires_grad=False for frozen parameters, we won't do any FSDP comms for them (maybe all-gather only at the first iteration and then keep them)?
@jspark1105 We will still all-gather and free them in forward every iteration. In backward, we may still all-gather them every iteration if their module's forward outputs required gradients, in which case the parameters may be needed to compute gradients with respect to the module inputs. This is done via a pre-backward hook registered on those forward outputs. Before this PR, after the parameters were all-gathered in backward, they were never freed, since the complementary post-backward hook did not run. This PR adds another way to run the post-backward hook so that it can run even when the parameters do not require gradients. In other words, for frozen parameters (requires_grad=False), the all-gathered parameters now get resharded during backward instead of staying in memory until the end of backward.
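For context, here is a small, self-contained example (plain PyTorch, not FairScale code) showing why a frozen module's parameters can still be needed in backward: the gradient with respect to the input is computed from the weight, which is why FSDP must keep (or re-all-gather) the parameters until the input gradients are done.

```python
# Minimal illustration: a frozen linear layer's weight is still used in
# backward to compute the gradient w.r.t. the layer's input, even though
# the weight itself receives no gradient.
import torch
import torch.nn as nn

lin = nn.Linear(3, 3, bias=False)
lin.weight.requires_grad_(False)      # frozen parameter

x = torch.randn(2, 3, requires_grad=True)
y = lin(x)                            # y = x @ W^T
y.sum().backward()                    # dL/dx = dL/dy @ W needs W

print(lin.weight.grad)                # None: frozen weight gets no gradient
print(x.grad)                         # populated: computing it required W
```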
Thanks @awgu for the explanation! I guess this is optimizing mostly for ZeRO-3 (reshard_after_forward=True). I guess we don't need to bother with ZeRO-2 for frozen parameters, since we could just use DDP (without all-gather and optimizers), although I'm not sure how easy it is to use DDP for some parts of the model and FSDP for the rest.
@jspark1105 For FSDP with […]. Otherwise, I agree that if you have a sharded optimizer implementation with DDP, then you can avoid worrying about FSDP here. Composing DDP and FSDP might be somewhat tricky with today's APIs though, as you mentioned 😞.
Force-pushed from 5cbaffb to eebaa6e
What does this PR do?
This PR reshards frozen parameters that got all-gathered in backward. The technical approach follows from pytorch/pytorch#101982: we use a new feature in PyTorch, the multi-grad hook (`register_multi_grad_hook`), to insert logic that runs after the gradients with respect to a module's inputs have been computed. For modules with frozen parameters, that point is a safe place to reshard the parameters, since they have already been used to compute the input gradients. (A minimal sketch of this pattern appears after the list below.)

I was not too familiar with the FairScale testing infra, so I did not use `dist_init()` and instead set the env vars manually.

While working on this, I noticed two possible minor bugs:
1. The check `if not fsdp_module._require_backward_grad_sync:` […]. It is not an issue if every FSDP module is atomically either under `no_grad` or not.
2. There is a call to `_use_fp32_param_shard()` in `_finalize_parameters()` citing the case that the parameters get accidentally gathered after post-backward. I think we additionally need to check whether we need to call `_free_full_params()`, or else the optimizer step could update the sharded parameters but the next forward will not re-all-gather the parameters, resulting in a correctness issue.
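As a rough illustration of the approach described above, here is a minimal sketch using PyTorch's `torch.autograd.graph.register_multi_grad_hook`. It is not FairScale's actual implementation; `install_frozen_reshard_hook` and `_reshard` are hypothetical names standing in for FSDP's real resharding logic.

```python
# Minimal sketch of the register_multi_grad_hook pattern described above.
# NOT FairScale's implementation; install_frozen_reshard_hook and _reshard
# are hypothetical names used only for this illustration.
import torch
import torch.nn as nn
from torch.autograd.graph import register_multi_grad_hook


def _reshard(module: nn.Module) -> None:
    # Stand-in for FSDP's "free the unsharded parameters" step.
    print(f"resharding frozen params of {type(module).__name__}")


def install_frozen_reshard_hook(module: nn.Module, inputs) -> None:
    # Hook only the inputs that participate in autograd; the multi-grad hook
    # fires once gradients for all of them have been computed in backward,
    # i.e. once the frozen parameters are no longer needed.
    tensors = [t for t in inputs if torch.is_tensor(t) and t.requires_grad]
    if tensors:
        register_multi_grad_hook(tensors, lambda grads: _reshard(module))


# Usage: a frozen layer whose output still requires grad via its input.
frozen = nn.Linear(4, 4)
for p in frozen.parameters():
    p.requires_grad_(False)

x = torch.randn(2, 4, requires_grad=True)
install_frozen_reshard_hook(frozen, (x,))
out = frozen(x)
out.sum().backward()   # the hook runs once the grad w.r.t. x is computed
```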
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.