
Commit a63ec56

Fix the bug to enable layernorm_geglu_fp8_dot in LayernormMlp
Signed-off-by: Ming Huang <[email protected]>
1 parent d16843f · commit a63ec56

File tree

1 file changed: +1 −1 lines changed

transformer_engine/jax/flax/module.py

Lines changed: 1 addition & 1 deletion
@@ -809,7 +809,7 @@ def is_geglu(acts):
                 if not isinstance(act, str):
                     return False
                 normalize_acts.append(act.lower())
-            return normalize_acts in geglu_act_pool
+            return tuple(normalize_acts) in geglu_act_pool
 
         use_fused_ln_mlp = fuse_layernorm \
             and (not self.use_bias) and is_geglu(self.activations) \
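
Why the one-line change matters: normalize_acts is built up as a Python list, while the entries in geglu_act_pool are presumably tuples of activation names, and a list never compares equal to a tuple. The old membership test therefore always returned False, so is_geglu() never passed, use_fused_ln_mlp stayed False, and the fused layernorm_geglu_fp8_dot path was never taken. A minimal sketch of the behavior, with illustrative pool contents (an assumption, mirroring the names in the diff above):

    # Sketch only: the contents of geglu_act_pool here are assumed for illustration.
    geglu_act_pool = [("gelu", "linear"), ("linear", "gelu")]

    normalize_acts = ["gelu", "linear"]  # built by appending act.lower(), per the diff

    print(normalize_acts in geglu_act_pool)         # False: a list never equals a tuple (the bug)
    print(tuple(normalize_acts) in geglu_act_pool)  # True: matches after the fix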
