We read every piece of feedback and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d16843f · commit a63ec56 — Copy full SHA for a63ec56
transformer_engine/jax/flax/module.py
@@ -809,7 +809,7 @@ def is_geglu(acts):
809
if not isinstance(act, str):
810
return False
811
normalize_acts.append(act.lower())
812
- return normalize_acts in geglu_act_pool
+ return tuple(normalize_acts) in geglu_act_pool
813
814
use_fused_ln_mlp = fuse_layernorm \
815
and (not self.use_bias) and is_geglu(self.activations) \
0 commit comments