diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index 474f0a95b9..d28f3eea64 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -318,13 +318,12 @@ def forward(self, x):
 
 
 class TorchGPT(nn.Module):
-    def __init__(self, hidden_size: int, eps: float, num_attention_heads: int):
+    def __init__(self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool):
         super().__init__()
         self.ln = nn.LayerNorm(hidden_size, eps=eps)
         self.causal_attn = TorchMHA(hidden_size, num_attention_heads)
         self.ln_mlp = TorchLayerNormMLP(hidden_size, 4 * hidden_size, eps)
-        self.resid_attn_dropout = nn.Dropout(0.1)
-        self.resid_mlp_dropout = nn.Dropout(0.1)
+        self.parallel_attention_mlp = parallel_attention_mlp
 
     def forward(
         self,
@@ -333,12 +332,17 @@ def forward(
     ) -> torch.Tensor:
         a = self.ln(x)
         b = self.causal_attn(a, attn_mask)
-        x = x + self.resid_attn_dropout(b)
-        n = self.ln_mlp(x)
-        x = x + self.resid_mlp_dropout(n)
+        if self.parallel_attention_mlp:
+            n = self.ln_mlp(x)
+            x = x + nn.functional.dropout(b + n, p=0.1, training=self.training)
+        else:
+            x = x + nn.functional.dropout(b, p=0.1, training=self.training)
+            n = self.ln_mlp(x)
+            x = x + nn.functional.dropout(n, p=0.1, training=self.training)
         return x
 
 
+
 def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
     reset_rng_states()
     FP8GlobalStateManager.reset()
@@ -619,7 +623,8 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_gpt_accuracy(dtype, bs, model):
+@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
+def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
     config = model_configs[model]
 
     te_gpt = (
@@ -632,6 +637,7 @@ def test_gpt_accuracy(dtype, bs, model):
             hidden_dropout=0.1,
             fuse_qkv_params=True,
             qkv_weight_interleaved=False,
+            parallel_attention_mlp=parallel_attention_mlp,
         )
         .to(dtype=dtype)
         .cuda()
@@ -643,6 +649,7 @@ def test_gpt_accuracy(dtype, bs, model):
             config.hidden_size,
             config.eps,
             config.num_attention_heads,
+            parallel_attention_mlp=parallel_attention_mlp,
         )
         .to(dtype=dtype)
         .cuda()
diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py
index a5728a05fc..da4714c7ea 100644
--- a/tests/pytorch/test_sanity.py
+++ b/tests/pytorch/test_sanity.py
@@ -441,9 +441,10 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad,
 @pytest.mark.parametrize("bias", all_boolean)
 @pytest.mark.parametrize("activation", all_activations)
 @pytest.mark.parametrize("normalization", all_normalizations)
+@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
 def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad,
                     zero_centered_gamma, bias, activation,
-                    normalization):
+                    normalization, parallel_attention_mlp):
     if fp8_recipe is not None and not fp8_available:
         pytest.skip(reason_for_no_fp8)
 
@@ -473,6 +474,7 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad,
             bias=bias,
             activation=activation,
             normalization=normalization,
+            parallel_attention_mlp=parallel_attention_mlp,
         )
         .to(dtype=dtype)
         .cuda()
diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py
index c21be000e3..dd86260f9f 100644
--- a/transformer_engine/pytorch/transformer.py
+++ b/transformer_engine/pytorch/transformer.py
@@ -115,6 +115,11 @@ class TransformerLayer(torch.nn.Module):
                      if set to `True`, layer normalization is applied on the output side,
                      after the final dropout-add. default behavior is to apply layer
                      normalization on the input side, before the QKV transformation.
+    parallel_attention_mlp: bool, default = `False`
+                           if set to `True`, self-attention and feedforward network are computed
+                           based on the same input (in parallel) instead of sequentially.
+                           Both blocks have an independent normalization.
+                           This architecture is used in `Falcon` models.
     layer_type: {'encoder', 'decoder'}, default = `encoder`
                if set to `decoder`, an additional cross-attn block is added after self-attn.
                This can be used for structures like `T5` Transformer in conjunction with the
@@ -224,6 +229,7 @@ def __init__(
         sequence_parallel: bool = False,
         apply_residual_connection_post_layernorm: bool = False,
         output_layernorm: bool = False,
+        parallel_attention_mlp: bool = False,
         layer_type: str = "encoder",
         drop_path_rate: float = 0.0,
         set_parallel_mode: bool = False,
@@ -274,6 +280,18 @@ def __init__(
             apply_residual_connection_post_layernorm
         )
 
+        if parallel_attention_mlp:
+            assert self.layer_type == "encoder", "parallel_attention requires layer_type='encoder'"
+            assert (
+                not self.apply_residual_connection_post_layernorm
+            ), "parallel_attention and apply_residual_connection_post_layernorm "\
+               "not supported simultaneously."
+            assert (
+                not self.output_layernorm
+            ), "parallel_attention and output_layernorm not supported simultaneously"
+
+        self.parallel_attention_mlp = parallel_attention_mlp
+
         assert layer_type in LayerTypes, f"layer_type {layer_type} not supported"
 
         if not fuse_qkv_params:
@@ -336,7 +354,7 @@ def __init__(
             input_layernorm=not output_layernorm,
             attention_type="self",
             bias=bias,
-            return_bias=True,
+            return_bias=not self.parallel_attention_mlp,
             normalization=normalization,
             device=device,
         )
@@ -370,7 +388,7 @@ def __init__(
             init_method=init_method,
             output_layer_init_method=output_layer_init_method,
             bias=bias,
-            return_bias=True,
+            return_bias=not self.parallel_attention_mlp,
             sequence_parallel=self.sequence_parallel,
             params_dtype=params_dtype,
             return_layernorm_output=apply_residual_connection_post_layernorm,
@@ -578,41 +596,19 @@ def forward(
 
         if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
             attention_output, attention_bias, residual = self_attention_outputs
-        else:
+            hidden_states = self._bias_dropout_add(
+                attention_output, attention_bias, residual, self.drop_path
+            )
+        elif not self.parallel_attention_mlp:
             attention_output, attention_bias = self_attention_outputs
-            residual = hidden_states
-
-        # Set BDA func.
-        if self.bias_dropout_fusion:
-            if self.training:
-                bias_dropout_add_func = bias_dropout_add_fused_train
-            else:
-                bias_dropout_add_func = bias_dropout_add_fused_inference
-        else:
-            bias_dropout_add_func = get_bias_dropout_add(self.training)
-
-        # Bias dropoout add.
-        if self.drop_path is None and attention_bias.numel() != 0:
-            with self.bias_dropout_add_exec_handler():
-                bda_output = bias_dropout_add_func(
-                    attention_output, attention_bias, residual, self.hidden_dropout
-                )
-        else:
-            if attention_bias.numel() != 0:
-                attention_output = attention_output + attention_bias
-            out = torch.nn.functional.dropout(
-                attention_output,
-                p=self.hidden_dropout,
-                training=self.training,
+            hidden_states = self._bias_dropout_add(
+                attention_output, attention_bias, hidden_states, self.drop_path
             )
-            if self.drop_path is not None:
-                out = self.drop_path(out)
-            bda_output = residual + out
 
         # Cross attention.
         if self.layer_type == "decoder":
             inter_attention_outputs = self.inter_attention(
-                bda_output,
+                hidden_states,
                 attention_mask=enc_dec_attn_mask,
                 attn_mask_type=self_attn_mask_type,
                 encoder_output=encoder_output,
@@ -626,49 +622,54 @@ def forward(
                 attention_output, attention_bias, residual = inter_attention_outputs
             else:
                 attention_output, attention_bias = inter_attention_outputs
-                residual = bda_output
+                residual = hidden_states
+
+            hidden_states = self._bias_dropout_add(attention_output, attention_bias, residual)
-            if attention_bias.numel() != 0:
-                with self.bias_dropout_add_exec_handler():
-                    bda_output = bias_dropout_add_func(
-                        attention_output, attention_bias, residual, self.hidden_dropout
-                    )
-            else:
-                out = torch.nn.functional.dropout(
-                    attention_output,
-                    p=self.hidden_dropout,
-                    training=self.training,
-                )
-                bda_output = residual + out
 
         # MLP.
         mlp_outputs = self.layernorm_mlp(
-            bda_output, is_first_microbatch=is_first_microbatch
+            hidden_states, is_first_microbatch=is_first_microbatch
         )
         if self.apply_residual_connection_post_layernorm:
             mlp_output, mlp_bias, residual = mlp_outputs
+            output = self._bias_dropout_add(mlp_output, mlp_bias, residual, self.drop_path)
+        elif self.parallel_attention_mlp:
+            output = self._bias_dropout_add(
+                self_attention_outputs, mlp_outputs, hidden_states, self.drop_path
+            )
         else:
             mlp_output, mlp_bias = mlp_outputs
-            residual = bda_output
+            output = self._bias_dropout_add(mlp_output, mlp_bias, hidden_states, self.drop_path)
+
+        # For BERT like architectures.
+        if self.output_layernorm:
+            output = self.layernorm(output)
+
+        # output: [s, b, h]
+        return output
+
+    def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
+        if drop_path is None and bias.numel() != 0:
+            if self.bias_dropout_fusion:
+                if self.training:
+                    bias_dropout_add_func = bias_dropout_add_fused_train
+                else:
+                    bias_dropout_add_func = bias_dropout_add_fused_inference
+            else:
+                bias_dropout_add_func = get_bias_dropout_add(self.training)
 
-        # Bias dropoout add.
-        if self.drop_path is None and mlp_bias.numel() != 0:
             with self.bias_dropout_add_exec_handler():
                 output = bias_dropout_add_func(
-                    mlp_output, mlp_bias, residual, self.hidden_dropout
+                    hidden_state, bias, residual, self.hidden_dropout
                 )
         else:
-            if mlp_bias.numel() != 0:
-                mlp_output = mlp_output + mlp_bias
+            if bias.numel() != 0:
+                hidden_state = hidden_state + bias
             out = torch.nn.functional.dropout(
-                mlp_output, p=self.hidden_dropout, training=self.training
+                hidden_state, p=self.hidden_dropout, training=self.training
             )
-            if self.drop_path is not None:
-                out = self.drop_path(out)
+            if drop_path is not None:
+                out = drop_path(out)
             output = residual + out
 
-        # For BERT like architectures.
-        if self.output_layernorm:
-            output = self.layernorm(output)
-
-        # output: [s, b, h]
         return output
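For reviewers who want to exercise the new flag outside the test suite, here is a minimal usage sketch (not part of the patch). It assumes a CUDA device and a build of `transformer_engine.pytorch` that includes this change; the sizes and dtype below are illustrative only, and the construction mirrors the pattern used in `test_gpt_accuracy`.

```python
# Minimal usage sketch for the new parallel_attention_mlp flag (illustrative
# sizes/dtype, not part of the patch). With the flag enabled the layer computes
#   x + dropout(attn(ln_attn(x)) + mlp(ln_mlp(x)))
# instead of applying the attention and MLP residual branches sequentially.
import torch
import transformer_engine.pytorch as te

seq_len, batch_size, hidden_size, heads = 128, 2, 1024, 16

layer = (
    te.TransformerLayer(
        hidden_size,
        4 * hidden_size,                  # ffn_hidden_size
        heads,
        parallel_attention_mlp=True,      # flag added by this PR
    )
    .to(dtype=torch.bfloat16)
    .cuda()
)

# TransformerLayer expects [seq_len, batch_size, hidden_size] input.
x = torch.randn(seq_len, batch_size, hidden_size, dtype=torch.bfloat16, device="cuda")
y = layer(x)  # output: [seq_len, batch_size, hidden_size]
```

The `TorchGPT` reference in `test_numerics.py` implements the same parallel formulation in plain PyTorch, which is what the parametrized `test_gpt_accuracy` compares against.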