-
Notifications
You must be signed in to change notification settings - Fork 346
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix cuda graph capture for grouped gemm #1345
Conversation
Signed-off-by: Xiaowei Ren <[email protected]>
Signed-off-by: Xiaowei Ren <[email protected]>
/te-ci pytorch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't we expect this to increase memory usage?
I see that torch.cuda.make_graphed_callables
doesn't set retain_graph=True
:
https://github.com/pytorch/pytorch/blob/c25b201583fc28243b87c460a2f18e2531a676e7/torch/cuda/graphs.py#L326-L336
We want to match plain PyTorch as much as possible unless there is a good reason to introduce divergence. If this is MoE-specific, perhaps we could add a kwarg like retain_graph_in_backward
that is False
by default.
Signed-off-by: Xiaowei Ren <[email protected]>
/te-ci pytorch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Signed-off-by: Xiaowei Ren <[email protected]>
/te-ci pytorch |
Description
Cuda graph does not work with Grouped GEMM.
The saved forward activations are corrupted before bwd_graph is replayed. Explicitly setting
retain_graph=True
can hold the activations and fix the issue.Type of change
Checklist: