
Commit 80222dc

[JAX] Enhance Dropout in TransformerLayer. (#444)

Authored by mingxu1067 and ksivaman.

* [JAX] Enhance Dropout in TransformerLayer.
  1. Fix the missing setup of the dropout RNG key in TransformerLayer and LayerNormMLP.
  2. Allow a separate dropout rate for FC1's output, independent of the other hidden dropouts.
  Signed-off-by: Ming Huang <[email protected]>

* Fix wrong fp8 scale in _update_fp8_metas_impl
  Signed-off-by: Ming Huang <[email protected]>

* Fix typo
  Signed-off-by: Ming Huang <[email protected]>

Co-authored-by: Kirthi Shankar Sivamani <[email protected]>

1 parent: 958e188

File tree: 7 files changed (+44, −14 lines)

tests/jax/test_helper.py

Lines changed: 4 additions & 5 deletions
@@ -72,11 +72,10 @@ def get_fp8_scale(fp8_max, amax, scale):
         amax = np.array(amax)
         scale = np.array(scale)
 
-        exp = np.floor(np.log2(fp8_max / amax)) - FP8Helper.MARGIN
-        sf = np.round(np.power(2, np.abs(exp)))
-        sf = np.where(amax > 0.0, sf, scale)
-        sf = np.where(np.isfinite(amax), sf, scale)
-        return np.where(exp < 0, 1 / sf, sf)
+        sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
+        sf = jnp.where(amax > 0.0, sf, scale)
+        sf = jnp.where(jnp.isfinite(amax), sf, scale)
+        return sf
 
     amax_meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_LEN)
     scale_meta_shape = (num_of_meta, 1)
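
For reference, the updated test helper now mirrors the delayed-scaling update in transformer_engine/jax/fp8.py: the scale factor is derived directly from fp8_max / amax (divided by 2**MARGIN), falling back to the previous scale when amax is zero or non-finite. A minimal self-contained sketch of that logic; MARGIN is a stand-in for FP8Helper.MARGIN and is assumed to be 0 here:

import numpy as np
import jax.numpy as jnp

MARGIN = 0.0  # assumption: stand-in for FP8Helper.MARGIN

def get_fp8_scale(fp8_max, amax, scale):
    """Reference scale update mirroring the hunk above."""
    amax = np.array(amax)
    scale = np.array(scale)

    sf = (fp8_max / amax) / (2**MARGIN)            # target scale from the current amax
    sf = jnp.where(amax > 0.0, sf, scale)          # keep the old scale when amax == 0
    sf = jnp.where(jnp.isfinite(amax), sf, scale)  # ... or when amax is inf/nan
    return sf

# Example: fp8_max = 448.0 (E4M3 maximum), amax = 2.0 -> scale factor 224.0
print(get_fp8_scale(448.0, [[2.0]], [[1.0]]))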

tests/jax/test_layer.py

Lines changed: 8 additions & 0 deletions
@@ -167,13 +167,15 @@ def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
             if k == 'dropout_rate':
                 te_layer_attrs['attention_dropout'] = v
                 te_layer_attrs['hidden_dropout'] = v
+                te_layer_attrs['intermediate_dropout'] = v
             elif k == 'fuse_mlp_wi':
                 continue
             else:
                 te_layer_attrs[k] = v
         ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs)
         layer_cls = partial(TransformerLayer,
                             hidden_dropout_dims=(sequence_dim,),
+                            intermediate_dropout_dims=(sequence_dim,),
                             layer_type=TransformerLayerType.ENCODER,
                             self_attn_mask_type='padding',
                             dtype=dtype,
@@ -212,13 +214,15 @@ def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-
             if k == 'dropout_rate':
                 te_layer_attrs['attention_dropout'] = v
                 te_layer_attrs['hidden_dropout'] = v
+                te_layer_attrs['intermediate_dropout'] = v
             elif k == 'fuse_mlp_wi':
                 continue
             else:
                 te_layer_attrs[k] = v
         ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs)
         layer_cls = partial(TransformerLayer,
                             hidden_dropout_dims=(sequence_dim,),
+                            intermediate_dropout_dims=(sequence_dim,),
                             layer_type=TransformerLayerType.ENCODER,
                             self_attn_mask_type='padding',
                             dtype=dtype,
@@ -381,13 +385,15 @@ def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
             if k == 'dropout_rate':
                 te_layer_attrs['attention_dropout'] = v
                 te_layer_attrs['hidden_dropout'] = v
+                te_layer_attrs['intermediate_dropout'] = v
             elif k == 'fuse_mlp_wi':
                 continue
             else:
                 te_layer_attrs[k] = v
         ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs)
         layer_cls = partial(TransformerLayer,
                             hidden_dropout_dims=(sequence_dim,),
+                            intermediate_dropout_dims=(sequence_dim,),
                             layer_type=TransformerLayerType.DECODER,
                             dtype=dtype,
                             **te_layer_attrs)
@@ -426,13 +432,15 @@ def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-
             if k == 'dropout_rate':
                 te_layer_attrs['attention_dropout'] = v
                 te_layer_attrs['hidden_dropout'] = v
+                te_layer_attrs['intermediate_dropout'] = v
             elif k == 'fuse_mlp_wi':
                 continue
             else:
                 te_layer_attrs[k] = v
         ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs)
         layer_cls = partial(TransformerLayer,
                             hidden_dropout_dims=(sequence_dim,),
+                            intermediate_dropout_dims=(sequence_dim,),
                             layer_type=TransformerLayerType.DECODER,
                             dtype=dtype,
                             **te_layer_attrs)
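
The attribute mapping these tests exercise is simple: the reference layer's single dropout_rate now fans out to all three TE dropout knobs, while fuse_mlp_wi stays a reference-only flag. A standalone sketch of that mapping (the attrs dict below is a hypothetical example):

attrs = {'dropout_rate': 0.1, 'use_bias': True}  # hypothetical test attrs
te_layer_attrs = {}
for k, v in attrs.items():
    if k == 'dropout_rate':
        # One reference rate drives attention, hidden and FC1 (intermediate) dropout.
        te_layer_attrs['attention_dropout'] = v
        te_layer_attrs['hidden_dropout'] = v
        te_layer_attrs['intermediate_dropout'] = v
    elif k == 'fuse_mlp_wi':
        continue  # reference-only knob, not forwarded to the TE layer
    else:
        te_layer_attrs[k] = v
print(te_layer_attrs)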

tests/jax/test_praxis_layers.py

Lines changed: 3 additions & 0 deletions
@@ -957,6 +957,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
         layernorm_type = attrs[TransformerLayerAttr.LN_TYPE]
         hidden_dropout = 0.0
         attention_dropout = 0.0
+        intermediate_dropout = 0.0
         mlp_activations = attrs[TransformerLayerAttr.ACTIVATION]
         kernel_init = WeightInit.Gaussian(1.0)
         use_bias = attrs[TransformerLayerAttr.USE_BIAS]
@@ -991,6 +992,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
             layernorm_type=layernorm_type,
             hidden_dropout=hidden_dropout,
             attention_dropout=attention_dropout,
+            intermediate_dropout=intermediate_dropout,
             mlp_activations=mlp_activations,
             use_bias=use_bias,
             bias_init=bias_init,
@@ -1007,6 +1009,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
             layernorm_type=layernorm_type,
             hidden_dropout=hidden_dropout,
             attention_dropout=attention_dropout,
+            intermediate_dropout=intermediate_dropout,
             mlp_activations=mlp_activations,
             mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                 "mha_kernel", kernel_init),

transformer_engine/jax/flax/module.py

Lines changed: 5 additions & 1 deletion
@@ -739,6 +739,8 @@ class LayerNormMLP(TransformerEngineBase):
     activations: Sequence[Union[str, Callable]], default = ('relu',)
         The sequence of activation functions to apply after the first linear transformation.
         Each activation has its own transformation layer.
+    intermediate_dropout_rng_name: str, default = 'dropout'
+        The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks.
     intermediate_dropout_rate: float, default = 0.1
         Dropout probability for the dropout op after the :attr:`activations`.
     intermediate_hidden_dropout_dims: Sequence[int], default = ()
@@ -779,6 +781,7 @@ class LayerNormMLP(TransformerEngineBase):
     bias_axes_2: Tuple[str, ...] = ('embed',)
     return_layernorm_output: bool = True
     activations: Sequence[Union[str, Callable]] = ('relu',)
+    intermediate_dropout_rng_name: str = 'dropout'
     intermediate_dropout_rate: float = 0.1
     intermediate_hidden_dropout_dims: Sequence[int] = ()
     axis: Union[Iterable[int], int] = -1
@@ -985,7 +988,8 @@ def fp8_meta_generator():
         z = jnp.reshape(z, (*z.shape[:-2], -1))
 
         z = nn.Dropout(rate=self.intermediate_dropout_rate,
-                       broadcast_dims=self.intermediate_hidden_dropout_dims)(
+                       broadcast_dims=self.intermediate_hidden_dropout_dims,
+                       rng_collection=self.intermediate_dropout_rng_name)(
                            z, deterministic=deterministic)
 
         # DenseGeneral 2
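
The new intermediate_dropout_rng_name is wired straight into flax.linen.Dropout's rng_collection argument, so the FC1 dropout mask is drawn from a named RNG stream that the caller supplies to Module.apply. A toy stand-in (not the TE module) showing that the name must match the key in the rngs dict; assumes a Flax release that supports rng_collection:

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyMLP(nn.Module):
    """Toy module illustrating the rng_collection plumbing."""
    dropout_rng_name: str = 'dropout'

    @nn.compact
    def __call__(self, x, deterministic=False):
        x = nn.Dense(8)(x)
        # The mask comes from the RNG stream named by rng_collection, so apply()
        # must receive that name in its `rngs` dict.
        return nn.Dropout(rate=0.1, rng_collection=self.dropout_rng_name)(
            x, deterministic=deterministic)

m = TinyMLP(dropout_rng_name='my_dropout')
x = jnp.ones((2, 8))
variables = m.init(jax.random.PRNGKey(0), x, deterministic=True)
y = m.apply(variables, x, deterministic=False,
            rngs={'my_dropout': jax.random.PRNGKey(1)})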

transformer_engine/jax/flax/transformer.py

Lines changed: 17 additions & 5 deletions
@@ -883,6 +883,10 @@ class TransformerLayer(nn.Module):
         Dimensions that will share the same dropout mask for hidden
     attention_dropout: float, default = 0.1
         Dropout probability for the dropout op during multi-head attention.
+    intermediate_dropout: float, default = 0.1
+        Dropout probability for the dropout op after FC1 layer.
+    intermediate_dropout_dims: Sequence[int], default = ()
+        Dimensions that will share the same dropout mask for hidden after FC1 layer.
     dropout_rng_name: str, default = 'dropout'
         The key in given RNGs via flax.linen.Module.apply that for
         generating Dropout masks in the Multi-Head Attention.
@@ -963,6 +967,8 @@ class TransformerLayer(nn.Module):
     hidden_dropout: float = 0.1
     hidden_dropout_dims: Sequence[int] = ()
     attention_dropout: float = 0.1
+    intermediate_dropout: float = 0.1
+    intermediate_dropout_dims: Sequence[int] = ()
     dropout_rng_name: str = 'dropout'
     mha_kernel_init: Initializer = None
     mlp_kernel_init: Initializer = None
@@ -1078,6 +1084,8 @@ def __call__(self,
         else:
             mha_name = 'self_attention'
 
+        inputs = _with_sharding_constraint(inputs, (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES))
+
         # [batch, length, emb_dim] -> [batch, length, emb_dim]
         x, residual = MultiHeadAttention(
             num_heads=self.num_attention_heads,
@@ -1113,14 +1121,15 @@ def hidden_dropout(x, deterministic):
                 assert -x_shape_len <= dims < x_shape_len
 
             return nn.Dropout(rate=self.hidden_dropout,
-                              broadcast_dims=self.hidden_dropout_dims)(x,
-                                                                       deterministic=deterministic)
+                              broadcast_dims=self.hidden_dropout_dims,
+                              rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)
 
         x = hidden_dropout(x, deterministic)
         if self.drop_path > 0.0:
             drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
             x = nn.Dropout(rate=self.drop_path,
-                           broadcast_dims=drop_path_shape)(x, deterministic=deterministic)
+                           broadcast_dims=drop_path_shape,
+                           rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)
         x = x + residual
 
         mlp_input = x
@@ -1156,6 +1165,8 @@ def hidden_dropout(x, deterministic):
         y = hidden_dropout(y, deterministic)
         mlp_input = y + residual
 
+        mlp_input = _with_sharding_constraint(mlp_input, (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES))
+
         # MlpBlock
         residual = mlp_input
         z, ln_out = LayerNormMLP(
@@ -1167,8 +1178,9 @@ def hidden_dropout(x, deterministic):
             return_layernorm_output=self.apply_residual_connection_post_layernorm,
             intermediate_dim=self.mlp_hidden_size,
             activations=self.mlp_activations,
-            intermediate_dropout_rate=self.hidden_dropout,
-            intermediate_hidden_dropout_dims=self.hidden_dropout_dims,
+            intermediate_dropout_rng_name=self.dropout_rng_name,
+            intermediate_dropout_rate=self.intermediate_dropout,
+            intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
             dtype=self.dtype,
             scale_axes=(W_NO_SHARD_AXES,),
             ln_bias_axes=(W_NO_SHARD_AXES,),
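
Taken together, the layer now exposes intermediate_dropout / intermediate_dropout_dims independently of hidden_dropout, and all of its dropout ops (hidden, drop path, and the FC1 dropout inside LayerNormMLP) draw from the RNG stream named by dropout_rng_name. A hedged usage sketch; the import path, default sizes, and the exact __call__ signature are assumptions based on the fields visible in this diff and the tests above:

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType  # assumed path

layer = TransformerLayer(
    hidden_dropout=0.1,
    hidden_dropout_dims=(1,),          # share the hidden mask along the sequence dim
    attention_dropout=0.1,
    intermediate_dropout=0.3,          # FC1-output dropout, now decoupled from hidden_dropout
    intermediate_dropout_dims=(1,),
    dropout_rng_name='dropout',        # stream name expected in apply()'s rngs dict
    layer_type=TransformerLayerType.ENCODER,
)

x = jnp.zeros((2, 16, 512))            # [batch, seq, hidden]; toy sizes assumed to match defaults
variables = layer.init(jax.random.PRNGKey(0), x, deterministic=True)
y = layer.apply(variables, x, deterministic=False,
                rngs={'dropout': jax.random.PRNGKey(1)})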

transformer_engine/jax/fp8.py

Lines changed: 3 additions & 3 deletions
@@ -310,11 +310,11 @@ def _update_fp8_metas_impl(fp8_metas: Collection) -> Collection:
     amax = fp8_meta_arrays[fp8_amax_idx][..., 0:1]
     scale = fp8_meta_arrays[fp8_scale_idx]
 
-    sf = (fp8_max / amax) / (2 ** FP8Helper.MARGIN)
+    sf = (fp8_max / amax) / (2**FP8Helper.MARGIN)
     sf = jnp.where(amax > 0.0, sf, scale)
     sf = jnp.where(jnp.isfinite(amax), sf, scale)
-    fp8_meta_arrays[fp8_scale_idx] = scale
-    fp8_meta_arrays[fp8_scale_inv_idx] = 1 / scale
+    fp8_meta_arrays[fp8_scale_idx] = sf
+    fp8_meta_arrays[fp8_scale_inv_idx] = 1 / sf
 
     return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays)
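
The actual bug fixed here is the write-back: the freshly computed sf has to be stored (and inverted) instead of the stale scale, which previously left the scale metadata unchanged. A small numeric sketch of the corrected update; E4M3_MAX and MARGIN are illustrative assumptions:

import jax.numpy as jnp

E4M3_MAX = 448.0   # assumption: fp8_max for the E4M3 format
MARGIN = 0.0       # assumption: stand-in for FP8Helper.MARGIN

amax = jnp.array([[2.0], [0.0]])    # second row: degenerate amax
scale = jnp.array([[1.0], [3.0]])   # previously stored scales

sf = (E4M3_MAX / amax) / (2**MARGIN)
sf = jnp.where(amax > 0.0, sf, scale)          # amax == 0 -> keep the old scale
sf = jnp.where(jnp.isfinite(amax), sf, scale)  # inf/nan amax -> keep the old scale

new_scale, new_scale_inv = sf, 1 / sf          # the fix: store sf, not the old scale
# new_scale     -> [[224.], [3.]]
# new_scale_inv -> [[1/224.], [1/3.]]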

transformer_engine/jax/praxis/transformer.py

Lines changed: 4 additions & 0 deletions
@@ -137,6 +137,8 @@ class TransformerLayer(TransformerEngineBaseLayer):
     hidden_dropout: float = 0.1
     hidden_dropout_dims: Sequence[int] = ()
     attention_dropout: float = 0.1
+    intermediate_dropout: float = 0.1
+    intermediate_dropout_dims: Sequence[int] = ()
     dropout_rng_name: str = 'dropout'
     mlp_activations: Sequence[str] = ('relu',)
     use_bias: bool = False
@@ -190,6 +192,8 @@ def setup(self) -> None:
             hidden_dropout=self.hidden_dropout,
             hidden_dropout_dims=self.hidden_dropout_dims,
             attention_dropout=self.attention_dropout,
+            intermediate_dropout=self.intermediate_dropout,
+            intermediate_dropout_dims=self.intermediate_dropout_dims,
             dropout_rng_name=self.dropout_rng_name,
             mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                 "mha_kernel", self.params_init),
