Skip to content

Commit

Permalink
Merge branch 'apex_ln_fix' into 'main'
Browse files · Browse the repository at this point in the history
Fix for newer apex version

See merge request ADLR/megatron-lm!1021
Branch information:
jaredcasper committed Dec 20, 2023
2 parents f64f91e + 1524ddc commit 2bc6cd3
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions megatron/model/fused_layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
HAVE_PERSIST_LAYER_NORM = False

try:
from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction
from apex.normalization.fused_layer_norm import fused_layer_norm_affine
except:
FusedLayerNormAffineFunction = None
fused_layer_norm_affine = None

global fused_layer_norm_cuda
fused_layer_norm_cuda = None
Expand Down Expand Up @@ -79,9 +79,9 @@ def forward(self, input):
weight = self.weight + 1 if self.apply_layernorm_1p else self.weight

if self.no_persist_layer_norm:
assert FusedLayerNormAffineFunction is not None, \
"FusedLayerNormAffineFunction is not available, please install apex from https://github.com/NVIDIA/apex"
return FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps)
assert fused_layer_norm_affine is not None, \
"fused_layer_norm_affine is not available, please install apex from https://github.com/NVIDIA/apex"
return fused_layer_norm_affine(input, weight, self.bias, self.normalized_shape, eps=self.eps)
else:
output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)

Expand Down

0 comments on commit 2bc6cd3

Please sign in to comment.