diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py
index bcb7bd7ecd..f076302e4e 100644
--- a/megatron/model/fused_layer_norm.py
+++ b/megatron/model/fused_layer_norm.py
@@ -19,9 +19,9 @@
 HAVE_PERSIST_LAYER_NORM = False
 
 try:
-    from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction
+    from apex.normalization.fused_layer_norm import fused_layer_norm_affine
 except:
-    FusedLayerNormAffineFunction = None
+    fused_layer_norm_affine = None
 
 global fused_layer_norm_cuda
 fused_layer_norm_cuda = None
@@ -79,9 +79,9 @@ def forward(self, input):
         weight = self.weight + 1 if self.apply_layernorm_1p else self.weight
 
         if self.no_persist_layer_norm:
-            assert FusedLayerNormAffineFunction is not None, \
-                "FusedLayerNormAffineFunction is not available, please install apex from https://github.com/NVIDIA/apex"
-            return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps, False)
+            assert fused_layer_norm_affine is not None, \
+                "fused_layer_norm_affine is not available, please install apex from https://github.com/NVIDIA/apex"
+            return fused_layer_norm_affine(input, weight, self.bias, self.normalized_shape, eps=self.eps)
         else:
             output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)
 