@@ -883,6 +883,10 @@ class TransformerLayer(nn.Module):
         Dimensions that will share the same dropout mask for hidden
     attention_dropout: float, default = 0.1
         Dropout probability for the dropout op during multi-head attention.
+    intermediate_dropout: float, default = 0.1
+        Dropout probability for the dropout op after the FC1 layer.
+    intermediate_dropout_dims: Sequence[int], default = ()
+        Dimensions that will share the same dropout mask for hidden after the FC1 layer.
     dropout_rng_name: str, default = 'dropout'
         The key in given RNGs via flax.linen.Module.apply that is used for
         generating Dropout masks in the Multi-Head Attention.
@@ -963,6 +967,8 @@ class TransformerLayer(nn.Module):
     hidden_dropout: float = 0.1
     hidden_dropout_dims: Sequence[int] = ()
     attention_dropout: float = 0.1
+    intermediate_dropout: float = 0.1
+    intermediate_dropout_dims: Sequence[int] = ()
     dropout_rng_name: str = 'dropout'
     mha_kernel_init: Initializer = None
     mlp_kernel_init: Initializer = None
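
For context, a minimal self-contained Flax sketch (not Transformer Engine's actual code; the module name and sizes are made up) of where the new intermediate dropout sits, after FC1 and its activation and before FC2, and how `intermediate_dropout_dims` corresponds to Flax's `broadcast_dims`:

```python
# Illustrative sketch only, not Transformer Engine code. Shows where
# `intermediate_dropout` applies (after FC1 + activation, before FC2) and how
# `intermediate_dropout_dims` maps onto Flax's `broadcast_dims`.
from typing import Sequence

import jax
import jax.numpy as jnp
import flax.linen as nn


class MlpWithIntermediateDropout(nn.Module):
    hidden_size: int = 64
    intermediate_size: int = 256
    intermediate_dropout: float = 0.1
    # e.g. (-2,) shares one dropout mask along the sequence axis
    intermediate_dropout_dims: Sequence[int] = ()

    @nn.compact
    def __call__(self, x, deterministic: bool = False):
        y = nn.Dense(self.intermediate_size)(x)            # FC1
        y = nn.gelu(y)
        y = nn.Dropout(rate=self.intermediate_dropout,
                       broadcast_dims=self.intermediate_dropout_dims)(
                           y, deterministic=deterministic)  # dropout after FC1
        return nn.Dense(self.hidden_size)(y)                # FC2


x = jnp.ones((2, 8, 64))  # [batch, length, hidden]
mlp = MlpWithIntermediateDropout(intermediate_dropout=0.1,
                                 intermediate_dropout_dims=(-2,))
params = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
out = mlp.apply(params, x, deterministic=False,
                rngs={'dropout': jax.random.PRNGKey(1)})
```
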
@@ -1078,6 +1084,8 @@ def __call__(self,
         else:
             mha_name = 'self_attention'
 
+        inputs = _with_sharding_constraint(inputs, (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES))
+
         # [batch, length, emb_dim] -> [batch, length, emb_dim]
         x, residual = MultiHeadAttention(
             num_heads=self.num_attention_heads,
@@ -1113,14 +1121,15 @@ def hidden_dropout(x, deterministic):
                 assert -x_shape_len <= dims < x_shape_len
 
             return nn.Dropout(rate=self.hidden_dropout,
-                              broadcast_dims=self.hidden_dropout_dims)(x,
-                                                                       deterministic=deterministic)
+                              broadcast_dims=self.hidden_dropout_dims,
+                              rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)
 
         x = hidden_dropout(x, deterministic)
         if self.drop_path > 0.0:
             drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
             x = nn.Dropout(rate=self.drop_path,
-                           broadcast_dims=drop_path_shape)(x, deterministic=deterministic)
+                           broadcast_dims=drop_path_shape,
+                           rng_collection=self.dropout_rng_name)(x, deterministic=deterministic)
         x = x + residual
 
         mlp_input = x
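
The `rng_collection` argument threaded through these dropout calls is assumed to behave as in recent Flax releases, where `nn.Dropout` draws its mask from a named RNG stream. A small standalone sketch (names here are made up, analogous to `dropout_rng_name` above):

```python
# Standalone sketch of `rng_collection` on flax.linen.Dropout (assumes a
# recent Flax release). The dropout mask is drawn from the RNG stream named
# by `dropout_rng_name`, supplied via `apply(..., rngs={...})`.
import jax
import jax.numpy as jnp
import flax.linen as nn


class Block(nn.Module):
    dropout_rng_name: str = 'my_dropout'  # hypothetical stream name

    @nn.compact
    def __call__(self, x, deterministic: bool = False):
        return nn.Dropout(rate=0.1,
                          rng_collection=self.dropout_rng_name)(
                              x, deterministic=deterministic)


x = jnp.ones((4, 16))
blk = Block()
variables = blk.init(jax.random.PRNGKey(0), x, deterministic=True)
y = blk.apply(variables, x, deterministic=False,
              rngs={'my_dropout': jax.random.PRNGKey(1)})
```
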
@@ -1156,6 +1165,8 @@ def hidden_dropout(x, deterministic):
             y = hidden_dropout(y, deterministic)
             mlp_input = y + residual
 
+        mlp_input = _with_sharding_constraint(mlp_input, (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES))
+
         # MlpBlock
         residual = mlp_input
         z, ln_out = LayerNormMLP(
@@ -1167,8 +1178,9 @@ def hidden_dropout(x, deterministic):
             return_layernorm_output=self.apply_residual_connection_post_layernorm,
             intermediate_dim=self.mlp_hidden_size,
             activations=self.mlp_activations,
-            intermediate_dropout_rate=self.hidden_dropout,
-            intermediate_hidden_dropout_dims=self.hidden_dropout_dims,
+            intermediate_dropout_rng_name=self.dropout_rng_name,
+            intermediate_dropout_rate=self.intermediate_dropout,
+            intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
             dtype=self.dtype,
             scale_axes=(W_NO_SHARD_AXES,),
             ln_bias_axes=(W_NO_SHARD_AXES,),
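
`_with_sharding_constraint` and the `BATCH_AXES` / `SEQLEN_AXES` / `HIDDEN_AXES` names are Transformer Engine internals not shown in this diff. As a rough analogue only (an assumption, not TE's implementation), pinning a `[batch, length, hidden]` activation layout in plain JAX could look like:

```python
# Rough analogue only; Transformer Engine's `_with_sharding_constraint` maps
# logical axis names to mesh axes internally. Here a single 'data' mesh axis
# is assumed and the batch dimension is sharded over it.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('data',))  # 1-D mesh over all local devices


@jax.jit
def constrain(x):
    # Pin the [batch, length, hidden] layout between blocks: batch sharded
    # over 'data', sequence and hidden replicated.
    return jax.lax.with_sharding_constraint(
        x, NamedSharding(mesh, P('data', None, None)))


x = jnp.ones((4 * jax.device_count(), 16, 32))  # [batch, length, hidden]
y = constrain(x)
```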