[FSDPv1] Only perform cat() during last microbatch backward() within FlattenParamsWrapper #1180
base: ngoyal_changes_for_pp_fp8
Conversation
…t for last microbatch
… flatten_parameter.unsharded_main_grad in last microbatch backward()
This approach makes sense to me!
If True, only let backward pass propagate to self.params, which will
invoke the _post_backward_hook() and concat() op, when self._require_backward_grad_sync
is True (e.g. last microbatch)
NOTE: this likely will incur more GPU memory usage
Could you explain why there will be more GPU memory usage?
hi @awgu, current test results show the GPU memory overhead can be non-trivial (~20% of 80 GB); we will follow up on reducing the memory usage
if self.fp32_grads[param_index] is None:
    self.fp32_grads[param_index] = grad.to(torch.float32)
else:
    self.fp32_grads[param_index].add_(grad.data)
nit: I think grad.data can just be grad (saves one aten.detach call)
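For reference, a minimal standalone sketch of the per-parameter fp32 accumulation pattern discussed above, using grad directly as the nit suggests. This is not the FSDP code itself; the tensor-hook registration and the fp32_grads list here are illustrative only.

    import torch

    # Two toy bf16 "parameters" standing in for the per-param views inside FSDP.
    params = [torch.randn(4, 4, dtype=torch.bfloat16, requires_grad=True) for _ in range(2)]
    fp32_grads = [None] * len(params)

    def make_hook(param_index):
        def hook(grad):
            # Accumulate in fp32; using `grad` instead of `grad.data` avoids an
            # extra aten.detach call, as suggested in the review comment above.
            if fp32_grads[param_index] is None:
                fp32_grads[param_index] = grad.to(torch.float32)
            else:
                fp32_grads[param_index].add_(grad)
        return hook

    for i, p in enumerate(params):
        p.register_hook(make_hook(i))

    # Two "microbatches": per-parameter grads accumulate into fp32_grads.
    for _ in range(2):
        loss = sum((p.float() ** 2).sum() for p in params)
        loss.backward()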
* Changed to only run reshard hook if all gradients computed
* Fix decreasing it/s with multi-grad hook
Co-authored-by: Jie Wang <[email protected]>
Hi @chrisxcai! Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention. You currently have a record in our system, but the CLA is no longer valid and will need to be resubmitted.
Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations; afterwards, the pull request will be tagged accordingly.
If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
If optimize_backward_concat is set to True, the backward() pass propagates to FSDP.flat_params, and thereby invokes FSDP._post_backward_hook() and the concat() op, only when FSDP._require_backward_grad_sync is True (e.g. on the last microbatch).
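To make the intended usage concrete, here is a hedged sketch of driving the flag with microbatched gradient accumulation. The optimize_backward_concat kwarg comes from this PR's description; the process-group setup, toy model, and the use of no_sync() to keep _require_backward_grad_sync False on non-final microbatches are assumptions for illustration, not part of this change.

    import os
    import torch
    import torch.distributed as dist
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    # Single-rank process group purely for illustration.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=0, world_size=1)

    model = FSDP(
        torch.nn.Linear(1024, 1024).cuda(),
        flatten_parameters=True,
        optimize_backward_concat=True,  # the flag introduced by this PR
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    num_microbatches = 4
    microbatches = [torch.randn(8, 1024, device="cuda") for _ in range(num_microbatches)]

    for i, mb in enumerate(microbatches):
        is_last = i == num_microbatches - 1
        # Inside no_sync(), _require_backward_grad_sync is False, so with the flag
        # enabled backward() accumulates per-parameter grads and defers the
        # flat-param cat() / _post_backward_hook() work to the last microbatch.
        if is_last:
            model(mb).sum().backward()
        else:
            with model.no_sync():
                model(mb).sum().backward()

    optimizer.step()
    optimizer.zero_grad()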
Trace comparison
trace before change (SplitWithSizesBackward triggered every microbatch per FSDP module):
https://fburl.com/perfdoctor/qdt32ibh
trace with the change applied (SplitWithSizesBackward triggered only on the last microbatch per FSDP module):
https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.229652302632210.json.gz&bucket=acadia
Numerics verification
Local runs with deterministic mode
TP=2, PP=2, num_layers_per_virtual_pipeline_stage=4, 8 GPUs, batch_size 2, DP = 2, fp8 (no 1F1B) (loss bitwise on par)
baseline
https://www.internalfb.com/intern/paste/P1363180533/
test
https://www.internalfb.com/intern/paste/P1363177870/
TP=2, GPU=8, DP = 4, BF16, non-PP microbatching (loss bitwise on par)
baseline:
https://www.internalfb.com/intern/paste/P1322976356/
test:
https://www.internalfb.com/intern/paste/P1322871976/
TP=2, PP=2, num_layers_per_virtual_pipeline_stage=4, 8 GPUs, batch_size 2, DP = 2, BF16 (no 1F1B) (loss bitwise on par)
baseline
https://www.internalfb.com/intern/paste/P1358660231/
test
https://www.internalfb.com/intern/paste/P1358659328/
TP=2, PP=2, num_layers_per_virtual_pipeline_stage=4, 8 GPUs, batch_size 4, DP = 2, BF16 (1F1B) (loss bitwise on par)
baseline
https://www.internalfb.com/intern/paste/P1358780690
test
https://www.internalfb.com/intern/paste/P1358786994/
E2E MAST tests:
model = small, TP = 2, PP = 2, DP = 2 (loss on par)
baseline:
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-tl66r0qd
test:
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-km46966
Perf evaluation
model = llama3_kv8_balance2_ffn12, n_layers = 1, non-PP microbatching, bs = 128, fp8, TP = 4, CP = 8
baseline:
e2e TFLOPS/s: 339.53
comp TFLOPS/s: 625.64
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-f7cdn9q
trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.27299292624533.json.gz&bucket=acadia
test:
e2e TFLOPS/s: 387.98 (~15% improvement over baseline)
comp TFLOPS/s: 817.5 (~30% improvement over baseline)
https://www.internalfb.com/mlhub/pipelines/runs/mast/conda-xlformers-t56xpf
trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.71951644521316.json.gz&bucket=acadia