-
Notifications
You must be signed in to change notification settings - Fork 346
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[PyTorch] Non-reentrant mode for activation recompute (#670)
* added non-reentrant mode support to TE checkpoint Signed-off-by: Alp Dener <[email protected]> * updated get_cuda_rng_tracker kwarg to get_rng_state_tracker to remain consistent with other TE API Signed-off-by: Alp Dener <[email protected]> * docstring cleanup Signed-off-by: Alp Dener <[email protected]> * added mechanism to disable bias_gelu_nvfusion in LayerNormMLP when checkpointing in non-reentrant mode Signed-off-by: Alp Dener <[email protected]> * refactored checkpoint and recompute hook names to match PyTorch implementation Signed-off-by: Alp Dener <[email protected]> * Fixed incorrect reference before assignment Signed-off-by: Alp Dener <[email protected]> * fixed argument error in calling native PyTorch checkpoint Signed-off-by: Alp Dener <[email protected]> * fixed linting errors for missing docstrings Signed-off-by: Alp Dener <[email protected]> * Fix lint Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * bias GELU fusion consistency between checkpoint test and reference comparison Signed-off-by: Alp Dener <[email protected]> --------- Signed-off-by: Alp Dener <[email protected]> Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
- Loading branch information
Showing
4 changed files
with
358 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.