[PyTorch] TransformerLayer: add support for Falcon architecture #513

Merged: 7 commits, Dec 4, 2023
21 changes: 14 additions & 7 deletions tests/pytorch/test_numerics.py
@@ -318,13 +318,12 @@ def forward(self, x):


class TorchGPT(nn.Module):
def __init__(self, hidden_size: int, eps: float, num_attention_heads: int):
def __init__(self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool):
super().__init__()
self.ln = nn.LayerNorm(hidden_size, eps=eps)
self.causal_attn = TorchMHA(hidden_size, num_attention_heads)
self.ln_mlp = TorchLayerNormMLP(hidden_size, 4 * hidden_size, eps)
self.resid_attn_dropout = nn.Dropout(0.1)
self.resid_mlp_dropout = nn.Dropout(0.1)
self.parallel_attention_mlp = parallel_attention_mlp

def forward(
self,
@@ -333,12 +332,17 @@ def forward(
) -> torch.Tensor:
a = self.ln(x)
b = self.causal_attn(a, attn_mask)
x = x + self.resid_attn_dropout(b)
n = self.ln_mlp(x)
x = x + self.resid_mlp_dropout(n)
if self.parallel_attention_mlp:
n = self.ln_mlp(x)
x = x + nn.functional.dropout(b + n, p=0.1, training=self.training)
else:
x = x + nn.functional.dropout(b, p=0.1, training=self.training)
n = self.ln_mlp(x)
x = x + nn.functional.dropout(n, p=0.1, training=self.training)
return x
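Aside (not part of the diff): the two branches of the reference `forward` above mirror the sequential (GPT-style) and parallel (Falcon-style) residual layouts. A minimal sketch, where `attn` and `mlp` are assumed to be callables that already include their own input normalization (names are illustrative, not from the PR):

```python
import torch.nn.functional as F

def sequential_residual(x, attn, mlp, p=0.1, training=True):
    # GPT-style: the MLP reads the post-attention residual stream,
    # so there are two dropout-add steps in series (the `else` branch above).
    x = x + F.dropout(attn(x), p=p, training=training)
    return x + F.dropout(mlp(x), p=p, training=training)

def parallel_residual(x, attn, mlp, p=0.1, training=True):
    # Falcon-style: attention and MLP both read the same input x, and their sum
    # is folded into the residual with a single dropout-add (the `if` branch above).
    return x + F.dropout(attn(x) + mlp(x), p=p, training=training)
```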



def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
reset_rng_states()
FP8GlobalStateManager.reset()
@@ -619,7 +623,8 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_accuracy(dtype, bs, model):
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
config = model_configs[model]

te_gpt = (
@@ -632,6 +637,7 @@ def test_gpt_accuracy(dtype, bs, model):
hidden_dropout=0.1,
fuse_qkv_params=True,
qkv_weight_interleaved=False,
parallel_attention_mlp=parallel_attention_mlp,
)
.to(dtype=dtype)
.cuda()
@@ -643,6 +649,7 @@ def test_gpt_accuracy(dtype, bs, model):
config.hidden_size,
config.eps,
config.num_attention_heads,
parallel_attention_mlp=parallel_attention_mlp,
)
.to(dtype=dtype)
.cuda()
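Usage note (not from the PR): with this change, the layer under test can be built in Falcon mode simply by passing the new flag. A hedged sketch with made-up sizes; it assumes a CUDA device and an installed Transformer Engine:

```python
import torch
import transformer_engine.pytorch as te

# Hypothetical sizes; parallel_attention_mlp=True selects the Falcon-style
# parallel attention/MLP path added by this PR.
layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    hidden_dropout=0.1,
    parallel_attention_mlp=True,
).to(dtype=torch.bfloat16).cuda()

x = torch.randn(128, 2, 1024, dtype=torch.bfloat16, device="cuda")  # [seq, batch, hidden]
y = layer(x)  # same shape as x
```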
4 changes: 3 additions & 1 deletion tests/pytorch/test_sanity.py
@@ -441,9 +441,10 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad,
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad,
zero_centered_gamma, bias, activation,
normalization):
normalization, parallel_attention_mlp):
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)

@@ -473,6 +474,7 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad,
bias=bias,
activation=activation,
normalization=normalization,
parallel_attention_mlp=parallel_attention_mlp,
)
.to(dtype=dtype)
.cuda()
119 changes: 60 additions & 59 deletions transformer_engine/pytorch/transformer.py
@@ -115,6 +115,11 @@ class TransformerLayer(torch.nn.Module):
if set to `True`, layer normalization is applied on the output side,
after the final dropout-add. default behavior is to apply layer
normalization on the input side, before the QKV transformation.
parallel_attention_mlp: bool, default = `False`
if set to `True`, self-attention and feedforward network are computed
based on the same input (in parallel) instead of sequentially.
Each block applies its own normalization to that input.
This architecture is used in `Falcon` models.
layer_type: {'encoder', 'decoder'}, default = `encoder`
if set to `decoder`, an additional cross-attn block is added after self-attn.
This can be used for structures like `T5` Transformer in conjunction with the
@@ -224,6 +229,7 @@ def __init__(
sequence_parallel: bool = False,
apply_residual_connection_post_layernorm: bool = False,
output_layernorm: bool = False,
parallel_attention_mlp: bool = False,
layer_type: str = "encoder",
drop_path_rate: float = 0.0,
set_parallel_mode: bool = False,
@@ -274,6 +280,18 @@ def __init__(
apply_residual_connection_post_layernorm
)

if parallel_attention_mlp:
            assert self.layer_type == "encoder", "parallel_attention_mlp requires layer_type='encoder'"
            assert (
                not self.apply_residual_connection_post_layernorm
            ), "parallel_attention_mlp and apply_residual_connection_post_layernorm "\
                "not supported simultaneously."
            assert (
                not self.output_layernorm
            ), "parallel_attention_mlp and output_layernorm not supported simultaneously."

self.parallel_attention_mlp = parallel_attention_mlp

assert layer_type in LayerTypes, f"layer_type {layer_type} not supported"

if not fuse_qkv_params:
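Aside (illustrative, not part of the PR's tests): the asserts above mean that combining `parallel_attention_mlp` with `apply_residual_connection_post_layernorm`, `output_layernorm`, or a decoder layer fails fast at construction time. A minimal sketch of that behavior, assuming hypothetical sizes:

```python
import pytest
import transformer_engine.pytorch as te

# The parallel path cannot be combined with apply_residual_connection_post_layernorm,
# so construction should raise an AssertionError before any modules are built.
with pytest.raises(AssertionError):
    te.TransformerLayer(
        hidden_size=256,
        ffn_hidden_size=1024,
        num_attention_heads=4,
        parallel_attention_mlp=True,
        apply_residual_connection_post_layernorm=True,
    )
```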
@@ -336,7 +354,7 @@ def __init__(
input_layernorm=not output_layernorm,
attention_type="self",
bias=bias,
return_bias=True,
return_bias=not self.parallel_attention_mlp,
normalization=normalization,
device=device,
)
@@ -370,7 +388,7 @@ def __init__(
init_method=init_method,
output_layer_init_method=output_layer_init_method,
bias=bias,
return_bias=True,
return_bias=not self.parallel_attention_mlp,
sequence_parallel=self.sequence_parallel,
params_dtype=params_dtype,
return_layernorm_output=apply_residual_connection_post_layernorm,
@@ -578,41 +596,19 @@ def forward(

        if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
            attention_output, attention_bias, residual = self_attention_outputs
            hidden_states = self._bias_dropout_add(
                attention_output, attention_bias, residual, self.drop_path
            )
        elif not self.parallel_attention_mlp:
            attention_output, attention_bias = self_attention_outputs
            hidden_states = self._bias_dropout_add(
                attention_output, attention_bias, hidden_states, self.drop_path
            )

# Cross attention.
if self.layer_type == "decoder":
inter_attention_outputs = self.inter_attention(
                hidden_states,
attention_mask=enc_dec_attn_mask,
attn_mask_type=self_attn_mask_type,
encoder_output=encoder_output,
@@ -626,49 +622,54 @@ def forward(
attention_output, attention_bias, residual = inter_attention_outputs
else:
attention_output, attention_bias = inter_attention_outputs
            residual = hidden_states

            hidden_states = self._bias_dropout_add(attention_output, attention_bias, residual)

# MLP.
        mlp_outputs = self.layernorm_mlp(
            hidden_states, is_first_microbatch=is_first_microbatch
        )
        if self.apply_residual_connection_post_layernorm:
            mlp_output, mlp_bias, residual = mlp_outputs
            output = self._bias_dropout_add(mlp_output, mlp_bias, residual, self.drop_path)
        elif self.parallel_attention_mlp:
            output = self._bias_dropout_add(
                self_attention_outputs, mlp_outputs, hidden_states, self.drop_path
            )
        else:
            mlp_output, mlp_bias = mlp_outputs
            output = self._bias_dropout_add(mlp_output, mlp_bias, hidden_states, self.drop_path)

# For BERT like architectures.
if self.output_layernorm:
output = self.layernorm(output)

# output: [s, b, h]
return output

    def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
        if drop_path is None and bias.numel() != 0:
            if self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
                else:
                    bias_dropout_add_func = bias_dropout_add_fused_inference
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)

            with self.bias_dropout_add_exec_handler():
                output = bias_dropout_add_func(
                    hidden_state, bias, residual, self.hidden_dropout
                )
        else:
            if bias.numel() != 0:
                hidden_state = hidden_state + bias
            out = torch.nn.functional.dropout(
                hidden_state, p=self.hidden_dropout, training=self.training
            )
            if drop_path is not None:
                out = drop_path(out)
            output = residual + out

        return output
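For reference (not part of the diff), the unfused semantics that `_bias_dropout_add` wraps; the fused `bias_dropout_add_fused_*` paths compute the same expression with a jitted kernel. A small sketch in plain PyTorch:

```python
import torch.nn.functional as F

def bias_dropout_add_reference(hidden_state, bias, residual, p, training):
    # Add the projection bias, apply dropout, then add the residual stream.
    return residual + F.dropout(hidden_state + bias, p=p, training=training)

# In the parallel_attention_mlp path, return_bias=False makes each submodule fold its own
# bias into its output, so `hidden_state` is the attention output and `bias` is the MLP
# output: a single dropout-add merges both branches into the residual, as in Falcon.
```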