From a51815120c477e8b35dd00a7404079cd25c687eb Mon Sep 17 00:00:00 2001
From: Kirthi Shankar Sivamani
Date: Wed, 16 Oct 2024 08:27:45 -0700
Subject: [PATCH] [PyTorch] Fix FP8 activation recompute (#1254)

Fix FP8 activation recompute

Signed-off-by: Kirthi Shankar Sivamani
---
 transformer_engine/pytorch/distributed.py             | 11 +++++++++++
 transformer_engine/pytorch/module/layernorm_linear.py |  9 +++++----
 transformer_engine/pytorch/module/layernorm_mlp.py    |  4 ++++
 transformer_engine/pytorch/module/linear.py           |  9 +++++----
 4 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index 9ff7596669..490ac3b160 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -206,6 +206,8 @@ class activation_recompute_forward(AbstractContextManager, ContextDecorator):
     activations, followed by calculation of gradients using these values.
     """
 
+    _is_first_fp8_module: List = []
+
     def __init__(self, activation_recompute: bool = False, recompute_phase: bool = False):
         super().__init__()
         self.activation_recompute = activation_recompute
@@ -218,6 +220,15 @@ def __enter__(self):
         )
         _FP8_ACTIVATION_RECOMPUTE_PHASE = self.recompute_phase
 
+        if self.activation_recompute and not self.recompute_phase:
+            activation_recompute_forward._is_first_fp8_module.append(
+                FP8GlobalStateManager.IS_FIRST_FP8_MODULE
+            )
+        if self.activation_recompute and self.recompute_phase:
+            FP8GlobalStateManager.IS_FIRST_FP8_MODULE = (
+                activation_recompute_forward._is_first_fp8_module.pop(0)
+            )
+
     def __exit__(self, *exc_details):
         global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
         _FP8_ACTIVATION_RECOMPUTE_ENABLED = False
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index 5969432509..6dea806993 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -36,6 +36,7 @@
     allreduce,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
+    in_fp8_activation_recompute_phase,
     _fsdp_scatter_tensors,
     _fsdp_gather_tensors,
 )
@@ -361,10 +362,10 @@ def forward(
         ctx.normalization = normalization
         ctx.reduce_and_update_bwd_fp8_tensors = False
         if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
-            ctx.reduce_and_update_bwd_fp8_tensors = (
-                ctx.reduce_and_update_bwd_fp8_tensors
-                or FP8GlobalStateManager.is_first_fp8_module()
-            )
+            _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
+            ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
+            if in_fp8_activation_recompute_phase():
+                FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
 
         # Row Parallel Linear
         if parallel_mode == "row" and sequence_parallel:
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index 55a10fb666..6c1633111d 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -43,6 +43,7 @@
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
     use_reentrant_activation_recompute,
+    in_fp8_activation_recompute_phase,
     _fsdp_scatter_tensors,
     _fsdp_gather_tensors,
 )
@@ -516,7 +517,10 @@ def forward(
         if ctx.fp8 and requires_grad(
             inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias
         ):
+            _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
             ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
+            if in_fp8_activation_recompute_phase():
+                FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
 
         # Row Parallel Linear
         if ub_overlap_rs:
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 0d751caf03..f521cf4fb6 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -33,6 +33,7 @@
     allreduce,
     reduce_scatter_along_first_dim,
     gather_along_first_dim,
+    in_fp8_activation_recompute_phase,
     _fsdp_scatter_tensors,
     _fsdp_gather_tensors,
 )
@@ -349,10 +350,10 @@ def forward(
         ctx.is_input_fp8 = is_input_fp8
         ctx.reduce_and_update_bwd_fp8_tensors = False
         if ctx.fp8 and requires_grad(inp, weight, bias):
-            ctx.reduce_and_update_bwd_fp8_tensors = (
-                ctx.reduce_and_update_bwd_fp8_tensors
-                or FP8GlobalStateManager.is_first_fp8_module()
-            )
+            _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
+            ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
+            if in_fp8_activation_recompute_phase():
+                FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
 
         # Row Parallel Linear
         if ub_overlap_rs:
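
The sketch below is not part of the patch; it is a hedged illustration of the configuration the fix targets: an FP8 forward pass under activation recompute, where the module forward is replayed during backward. It assumes an FP8-capable GPU, the transformer_engine.pytorch exports Linear, fp8_autocast, and checkpoint, and the checkpoint(function, *args) call form; exact signatures and defaults may differ across Transformer Engine versions.

import torch
import transformer_engine.pytorch as te

# Hypothetical toy setup: one TE Linear layer checkpointed under FP8.
layer = te.Linear(1024, 1024).cuda()
inp = torch.randn(2048, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True):
    # te.checkpoint discards activations now and re-runs the forward during
    # backward (the recompute phase). This patch stashes and restores
    # FP8GlobalStateManager.IS_FIRST_FP8_MODULE across the two passes so the
    # "first FP8 module" bookkeeping that schedules the backward amax
    # reduction stays consistent between the original and recomputed forward.
    out = te.checkpoint(layer, inp)

out.sum().backward()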