From 92c1e500dd14608e54f75df8276baa1104c61d48 Mon Sep 17 00:00:00 2001
From: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Date: Fri, 1 Dec 2023 15:26:34 -0800
Subject: [PATCH] [PyTorch] Fix incorrect variable name in LayerNormMLP
 backward (#548)

Fix incorrect variable name in LayerNormMLP backward

Signed-off-by: Tim Moon
---
 transformer_engine/pytorch/module/layernorm_mlp.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index f38abec4a7..e4c5046d6e 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -914,7 +914,7 @@ def backward(
                 # Handle custom DDP from mcore.
                 if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, 'grad_added_to_main_grad'):
                     fc1_weight.grad_added_to_main_grad = True
-                    if getattr(weight, 'zero_out_wgrad', False):
+                    if getattr(fc1_weight, 'zero_out_wgrad', False):
                         fc1_wgrad = torch.zeros(fc1_weight.main_grad.shape,
                                                 dtype=fc1_weight.dtype,
                                                 device=torch.cuda.current_device(),
@@ -935,7 +935,7 @@ def backward(
                 # Handle custom DDP from mcore.
                 if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, 'grad_added_to_main_grad'):
                     fc2_weight.grad_added_to_main_grad = True
-                    if getattr(weight, 'zero_out_wgrad', False):
+                    if getattr(fc2_weight, 'zero_out_wgrad', False):
                         fc2_wgrad = torch.zeros(fc2_weight.main_grad.shape,
                                                 dtype=fc2_weight.dtype,
                                                 device=torch.cuda.current_device(),
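
Context for the change: both hunks set `grad_added_to_main_grad` on the FC1/FC2 weight but then read `zero_out_wgrad` from `weight`, a name that is not the parameter being processed in `LayerNormMLP`'s backward, so the flag was read from the wrong object. The sketch below is a minimal illustration of the pattern being fixed, not Transformer Engine code; the helper name `_handle_fused_wgrad_accumulation` and its standalone framing are assumptions, while the attribute names (`grad_added_to_main_grad`, `zero_out_wgrad`, `main_grad`) come from the diff above.

# Minimal sketch, assuming an mcore-style DDP wrapper that fuses wgrad
# accumulation into ``weight.main_grad``. Not Transformer Engine API.
from typing import Optional

import torch


def _handle_fused_wgrad_accumulation(weight: torch.nn.Parameter) -> Optional[torch.Tensor]:
    """Return a dummy wgrad tensor when accumulation is fused into main_grad.

    The key point of the patch: every flag is read from the *same* parameter
    whose gradient is being handled (fc1_weight or fc2_weight), never from an
    unrelated ``weight`` name.
    """
    if not hasattr(weight, "grad_added_to_main_grad"):
        return None  # no fused accumulation; the caller computes wgrad normally
    # Tell the DDP wrapper that the gradient already landed in main_grad.
    weight.grad_added_to_main_grad = True
    if getattr(weight, "zero_out_wgrad", False):
        # The wrapper expects an explicit zero tensor as the dummy gradient.
        return torch.zeros(weight.main_grad.shape, dtype=weight.dtype,
                           device=weight.device, requires_grad=False)
    # Otherwise an uninitialized tensor of the right shape is sufficient.
    return torch.empty(weight.main_grad.shape, dtype=weight.dtype,
                       device=weight.device, requires_grad=False)

With the old code, `zero_out_wgrad` was queried on whatever `weight` happened to name in that scope (if anything), so the zero-fill branch could be skipped or fail outright; the fix makes the lookup consistent with the `hasattr`/assignment on `fc1_weight` and `fc2_weight` directly above it.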