diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 1a651474bf..16218a4ab5 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -373,7 +373,7 @@ def forward( ub=ub_obj_lnout if ub_overlap_ag else None, extra_output_tensor=ln_out if ub_overlap_ag else None, ) - if not is_grad_enabled: + if not is_grad_enabled and not return_layernorm_output: clear_tensor_data(ln_out_total) if bias_gelu_nvfusion: