Skip to content

Commit 62090b8

Browse files
committed
Fix the wrong order of parameters to LN/RMSN bwd in ln_mlp_fp8.
Signed-off-by: Ming Huang <[email protected]>
1 parent 784749a commit 62090b8

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

transformer_engine/jax/mlp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -415,16 +415,16 @@ def _fp8_mlp_bwd(
415415

416416
if layernorm_type == 'layernorm':
417417
grad_input, grad_gamma, grad_beta = layernorm_bwd(dgrad_1,
418+
inputs_,
418419
mu,
419420
rsigma,
420-
inputs_,
421421
gamma,
422422
zero_centered_gamma=zero_centered_gamma,
423423
epsilon=epsilon)
424424
else:
425425
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
426426
"if layernorm_type is 'rmsnorm'"
427-
grad_input, grad_gamma = rmsnorm_bwd(dgrad_1, rsigma, inputs_, gamma, epsilon=epsilon)
427+
grad_input, grad_gamma = rmsnorm_bwd(dgrad_1, inputs_, rsigma, gamma, epsilon=epsilon)
428428
grad_beta = None
429429

430430
amax = amax.at[gemm1_input_idx, 0].set(ln_out_amax[0])

0 commit comments

Comments
 (0)