
Commit

Merge branch 'NVIDIA:main' into fused_attn/graph_api_v1
cyanguwa authored Dec 5, 2023
2 parents ed3358c + 5debfdb commit 8c4379c
Showing 5 changed files with 79 additions and 70 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/googletest
Submodule googletest updated 156 files
3 changes: 1 addition & 2 deletions docs/api/common.rst
@@ -8,5 +8,4 @@ Common API

.. autoapiclass:: transformer_engine.common.recipe.Format

.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, interval=1, fp8_format=Format.E4M3, amax_history_len=1, amax_compute_algo="most_recent", scaling_factor_compute_algo=None, override_linear_precision=(False, False, False))

.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, interval=1, fp8_format=Format.E4M3, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None, override_linear_precision=(False, False, False))
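The documentation change above reflects the new documented defaults for `DelayedScaling` (a 1024-entry amax history reduced with `max`, instead of a single `most_recent` value). For reference, a minimal, hedged sketch of how such a recipe is typically constructed and passed to FP8 autocasting in Transformer Engine; the layer and tensor sizes below are illustrative and not part of this commit.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Recipe matching the newly documented defaults.
fp8_recipe = DelayedScaling(
    fp8_format=Format.E4M3,
    amax_history_len=1024,
    amax_compute_algo="max",
)

# Any TE module can run under this recipe via fp8_autocast.
layer = te.Linear(1024, 1024, bias=True).cuda()
x = torch.randn(16, 1024, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)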
21 changes: 14 additions & 7 deletions tests/pytorch/test_numerics.py
@@ -318,13 +318,12 @@ def forward(self, x):


class TorchGPT(nn.Module):
def __init__(self, hidden_size: int, eps: float, num_attention_heads: int):
def __init__(self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool):
super().__init__()
self.ln = nn.LayerNorm(hidden_size, eps=eps)
self.causal_attn = TorchMHA(hidden_size, num_attention_heads)
self.ln_mlp = TorchLayerNormMLP(hidden_size, 4 * hidden_size, eps)
self.resid_attn_dropout = nn.Dropout(0.1)
self.resid_mlp_dropout = nn.Dropout(0.1)
self.parallel_attention_mlp = parallel_attention_mlp

def forward(
self,
@@ -333,12 +332,17 @@ def forward(
) -> torch.Tensor:
a = self.ln(x)
b = self.causal_attn(a, attn_mask)
x = x + self.resid_attn_dropout(b)
n = self.ln_mlp(x)
x = x + self.resid_mlp_dropout(n)
if self.parallel_attention_mlp:
n = self.ln_mlp(x)
x = x + nn.functional.dropout(b + n, p=0.1, training=self.training)
else:
x = x + nn.functional.dropout(b, p=0.1, training=self.training)
n = self.ln_mlp(x)
x = x + nn.functional.dropout(n, p=0.1, training=self.training)
return x



def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
reset_rng_states()
FP8GlobalStateManager.reset()
@@ -624,7 +628,8 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_accuracy(dtype, bs, model):
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
config = model_configs[model]

te_gpt = (
@@ -637,6 +642,7 @@ def test_gpt_accuracy(dtype, bs, model):
hidden_dropout=0.1,
fuse_qkv_params=True,
qkv_weight_interleaved=False,
parallel_attention_mlp=parallel_attention_mlp,
)
.to(dtype=dtype)
.cuda()
@@ -648,6 +654,7 @@ def test_gpt_accuracy(dtype, bs, model):
config.hidden_size,
config.eps,
config.num_attention_heads,
parallel_attention_mlp=parallel_attention_mlp,
)
.to(dtype=dtype)
.cuda()
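The updated `TorchGPT` reference above now switches between the classic sequential residual layout and the Falcon-style parallel layout via `parallel_attention_mlp`. As a stripped-down illustration of the two residual structures (the `attn` and `ln_mlp` callables below are placeholders that each apply their own input normalization, not the test's classes):

import torch.nn.functional as F

def sequential_block(x, attn, ln_mlp, p=0.1, training=True):
    # attention -> residual add -> LayerNorm+MLP -> residual add
    x = x + F.dropout(attn(x), p=p, training=training)
    x = x + F.dropout(ln_mlp(x), p=p, training=training)
    return x

def parallel_block(x, attn, ln_mlp, p=0.1, training=True):
    # attention and LayerNorm+MLP both read the same input x;
    # their outputs share a single dropout and residual add (Falcon-style)
    return x + F.dropout(attn(x) + ln_mlp(x), p=p, training=training)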
4 changes: 3 additions & 1 deletion tests/pytorch/test_sanity.py
@@ -441,9 +441,10 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad,
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad,
zero_centered_gamma, bias, activation,
normalization):
normalization, parallel_attention_mlp):
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)

@@ -473,6 +474,7 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad,
bias=bias,
activation=activation,
normalization=normalization,
parallel_attention_mlp=parallel_attention_mlp,
)
.to(dtype=dtype)
.cuda()
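Both test files expose the new flag through an extra `parametrize` decorator, so every existing test configuration is exercised with the flag both on and off. Assuming `all_boolean` is the usual `(True, False)` helper these test modules already use, the pattern reduces to:

import pytest

all_boolean = [True, False]  # assumption: mirrors the helper defined in the test modules

@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_example(parallel_attention_mlp):
    # each existing case now runs twice: sequential and parallel (Falcon-style) layouts
    assert isinstance(parallel_attention_mlp, bool)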
119 changes: 60 additions & 59 deletions transformer_engine/pytorch/transformer.py
@@ -115,6 +115,11 @@ class TransformerLayer(torch.nn.Module):
if set to `True`, layer normalization is applied on the output side,
after the final dropout-add. default behavior is to apply layer
normalization on the input side, before the QKV transformation.
parallel_attention_mlp: bool, default = `False`
if set to `True`, the self-attention and feedforward network are computed
from the same input (in parallel) instead of sequentially.
Each block applies its own normalization.
This architecture is used in `Falcon` models.
layer_type: {'encoder', 'decoder'}, default = `encoder`
if set to `decoder`, an additional cross-attn block is added after self-attn.
This can be used for structures like `T5` Transformer in conjunction with the
@@ -225,6 +230,7 @@ def __init__(
sequence_parallel: bool = False,
apply_residual_connection_post_layernorm: bool = False,
output_layernorm: bool = False,
parallel_attention_mlp: bool = False,
layer_type: str = "encoder",
drop_path_rate: float = 0.0,
set_parallel_mode: bool = False,
@@ -275,6 +281,18 @@ def __init__(
apply_residual_connection_post_layernorm
)

if parallel_attention_mlp:
assert self.layer_type == "encoder", "parallel_attention requires layer_type='encoder'"
assert (
not self.apply_residual_connection_post_layernorm
), "parallel_attention and apply_residual_connection_post_layernorm "\
"not supported simultaneously."
assert (
not self.output_layernorm
), "parallel_attention and output_layernorm not supported simultaneously"

self.parallel_attention_mlp = parallel_attention_mlp

assert layer_type in LayerTypes, f"layer_type {layer_type} not supported"

if not fuse_qkv_params:
@@ -337,7 +355,7 @@ def __init__(
input_layernorm=not output_layernorm,
attention_type="self",
bias=bias,
return_bias=True,
return_bias=not self.parallel_attention_mlp,
normalization=normalization,
device=device,
)
@@ -371,7 +389,7 @@ def __init__(
init_method=init_method,
output_layer_init_method=output_layer_init_method,
bias=bias,
return_bias=True,
return_bias=not self.parallel_attention_mlp,
sequence_parallel=self.sequence_parallel,
params_dtype=params_dtype,
return_layernorm_output=apply_residual_connection_post_layernorm,
@@ -587,41 +605,19 @@ def forward(

if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
attention_output, attention_bias, residual = self_attention_outputs
else:
hidden_states = self._bias_dropout_add(
attention_output, attention_bias, residual, self.drop_path
)
elif not self.parallel_attention_mlp:
attention_output, attention_bias = self_attention_outputs
residual = hidden_states

# Set BDA func.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)

# Bias dropout add.
if self.drop_path is None and attention_bias.numel() != 0:
with self.bias_dropout_add_exec_handler():
bda_output = bias_dropout_add_func(
attention_output, attention_bias, residual, self.hidden_dropout
)
else:
if attention_bias.numel() != 0:
attention_output = attention_output + attention_bias
out = torch.nn.functional.dropout(
attention_output,
p=self.hidden_dropout,
training=self.training,
hidden_states = self._bias_dropout_add(
attention_output, attention_bias, hidden_states, self.drop_path
)
if self.drop_path is not None:
out = self.drop_path(out)
bda_output = residual + out

# Cross attention.
if self.layer_type == "decoder":
inter_attention_outputs = self.inter_attention(
bda_output,
hidden_states,
attention_mask=enc_dec_attn_mask,
attn_mask_type=self_attn_mask_type,
encoder_output=encoder_output,
@@ -635,49 +631,54 @@ def forward(
attention_output, attention_bias, residual = inter_attention_outputs
else:
attention_output, attention_bias = inter_attention_outputs
residual = bda_output
residual = hidden_states

hidden_states = self._bias_dropout_add(attention_output, attention_bias, residual)

if attention_bias.numel() != 0:
with self.bias_dropout_add_exec_handler():
bda_output = bias_dropout_add_func(
attention_output, attention_bias, residual, self.hidden_dropout
)
else:
out = torch.nn.functional.dropout(
attention_output,
p=self.hidden_dropout,
training=self.training,
)
bda_output = residual + out
# MLP.
mlp_outputs = self.layernorm_mlp(
bda_output, is_first_microbatch=is_first_microbatch
hidden_states, is_first_microbatch=is_first_microbatch
)
if self.apply_residual_connection_post_layernorm:
mlp_output, mlp_bias, residual = mlp_outputs
output = self._bias_dropout_add(mlp_output, mlp_bias, residual, self.drop_path)
elif self.parallel_attention_mlp:
output = self._bias_dropout_add(
self_attention_outputs, mlp_outputs, hidden_states, self.drop_path
)
else:
mlp_output, mlp_bias = mlp_outputs
residual = bda_output
output = self._bias_dropout_add(mlp_output, mlp_bias, hidden_states, self.drop_path)

# For BERT like architectures.
if self.output_layernorm:
output = self.layernorm(output)

# output: [s, b, h]
return output

def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
if drop_path is None and bias.numel() != 0:
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)

# Bias dropout add.
if self.drop_path is None and mlp_bias.numel() != 0:
with self.bias_dropout_add_exec_handler():
output = bias_dropout_add_func(
mlp_output, mlp_bias, residual, self.hidden_dropout
hidden_state, bias, residual, self.hidden_dropout
)
else:
if mlp_bias.numel() != 0:
mlp_output = mlp_output + mlp_bias
if bias.numel() != 0:
hidden_state = hidden_state + bias
out = torch.nn.functional.dropout(
mlp_output, p=self.hidden_dropout, training=self.training
hidden_state, p=self.hidden_dropout, training=self.training
)
if self.drop_path is not None:
out = self.drop_path(out)
if drop_path is not None:
out = drop_path(out)
output = residual + out

# For BERT like architectures.
if self.output_layernorm:
output = self.layernorm(output)

# output: [s, b, h]
return output
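To exercise the new option end to end, a hedged usage sketch of the updated layer: the arguments shown (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`) are the existing constructor signature, the input follows the `[s, b, h]` convention noted in the code above, and the sizes are illustrative only.

import torch
import transformer_engine.pytorch as te

layer = (
    te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        hidden_dropout=0.1,
        parallel_attention_mlp=True,  # attention and MLP share the residual input
    )
    .to(dtype=torch.bfloat16)
    .cuda()
)

# [sequence, batch, hidden], matching the "# output: [s, b, h]" comment above
x = torch.randn(128, 2, 1024, dtype=torch.bfloat16, device="cuda")
y = layer(x)  # output has the same [s, b, h] shape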
