Commit 110b651

Merge branch 'main' into float8tensor-pickle

timmoon10 authored Dec 5, 2023
2 parents cd60e0f + 5debfdb commit 110b651
Showing 8 changed files with 99 additions and 70 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/googletest
Submodule googletest updated 156 files
3 changes: 1 addition & 2 deletions docs/api/common.rst
@@ -8,5 +8,4 @@ Common API

.. autoapiclass:: transformer_engine.common.recipe.Format

.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, interval=1, fp8_format=Format.E4M3, amax_history_len=1, amax_compute_algo="most_recent", scaling_factor_compute_algo=None, override_linear_precision=(False, False, False))

.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, interval=1, fp8_format=Format.E4M3, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None, override_linear_precision=(False, False, False))
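For context, the documentation change above reflects updated DelayedScaling defaults (amax_history_len 1 → 1024, amax_compute_algo "most_recent" → "max"). A minimal sketch of building such a recipe and running a layer under FP8 autocast; the layer sizes and input below are placeholders, and FP8 execution assumes supported hardware:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Delayed-scaling recipe with the defaults documented above:
# a 1024-step amax history reduced with "max" rather than a
# single-step "most_recent" history.
fp8_recipe = DelayedScaling(
    margin=0,
    fp8_format=Format.E4M3,
    amax_history_len=1024,
    amax_compute_algo="max",
)

layer = te.Linear(1024, 1024).cuda()            # placeholder layer
inp = torch.randn(16, 1024, device="cuda")      # placeholder input

# Run the layer with FP8 compute governed by the recipe.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp)
```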
21 changes: 14 additions & 7 deletions tests/pytorch/test_numerics.py
@@ -318,13 +318,12 @@ def forward(self, x):


class TorchGPT(nn.Module):
def __init__(self, hidden_size: int, eps: float, num_attention_heads: int):
def __init__(self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool):
super().__init__()
self.ln = nn.LayerNorm(hidden_size, eps=eps)
self.causal_attn = TorchMHA(hidden_size, num_attention_heads)
self.ln_mlp = TorchLayerNormMLP(hidden_size, 4 * hidden_size, eps)
self.resid_attn_dropout = nn.Dropout(0.1)
self.resid_mlp_dropout = nn.Dropout(0.1)
self.parallel_attention_mlp = parallel_attention_mlp

def forward(
self,
@@ -333,12 +332,17 @@ def forward(
) -> torch.Tensor:
a = self.ln(x)
b = self.causal_attn(a, attn_mask)
x = x + self.resid_attn_dropout(b)
n = self.ln_mlp(x)
x = x + self.resid_mlp_dropout(n)
if self.parallel_attention_mlp:
n = self.ln_mlp(x)
x = x + nn.functional.dropout(b + n, p=0.1, training=self.training)
else:
x = x + nn.functional.dropout(b, p=0.1, training=self.training)
n = self.ln_mlp(x)
x = x + nn.functional.dropout(n, p=0.1, training=self.training)
return x



def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
reset_rng_states()
FP8GlobalStateManager.reset()
@@ -619,7 +623,8 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
def test_gpt_accuracy(dtype, bs, model):
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
config = model_configs[model]

te_gpt = (
@@ -632,6 +637,7 @@ def test_gpt_accuracy(dtype, bs, model):
hidden_dropout=0.1,
fuse_qkv_params=True,
qkv_weight_interleaved=False,
parallel_attention_mlp=parallel_attention_mlp,
)
.to(dtype=dtype)
.cuda()
@@ -643,6 +649,7 @@ def test_gpt_accuracy(dtype, bs, model):
config.hidden_size,
config.eps,
config.num_attention_heads,
parallel_attention_mlp=parallel_attention_mlp,
)
.to(dtype=dtype)
.cuda()
4 changes: 3 additions & 1 deletion tests/pytorch/test_sanity.py
@@ -441,9 +441,10 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad,
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad,
zero_centered_gamma, bias, activation,
normalization):
normalization, parallel_attention_mlp):
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)

@@ -473,6 +474,7 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad,
bias=bias,
activation=activation,
normalization=normalization,
parallel_attention_mlp=parallel_attention_mlp,
)
.to(dtype=dtype)
.cuda()
5 changes: 5 additions & 0 deletions transformer_engine/jax/flax/module.py
@@ -16,6 +16,7 @@
from jax import lax
from jax import nn as jax_nn
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name

from ..dot import type_safe_dot_general
from ..fp8 import FP8Helper, FP8MetaPackage
@@ -923,6 +924,8 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape
x += jnp.reshape(bias, bias_shape)

x = checkpoint_name(x, 'ffn1')

activations = []
if is_geglu(self.activations):
z = geglu(x)
@@ -957,4 +960,6 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
bias = bias.astype(self.dtype)
out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,))

out = checkpoint_name(out, 'ffn2')

return out, ln_output # Output, layer_norm_output
11 changes: 11 additions & 0 deletions transformer_engine/jax/flax/transformer.py
@@ -20,6 +20,7 @@
from jax import nn as jax_nn
from jax import random as jax_random
from jax import lax, vmap
from jax.ad_checkpoint import checkpoint_name

from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
from .module import LayerNorm, Softmax
@@ -211,6 +212,8 @@ def core_attention(query: Array,
else:
attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)

attn_weights = checkpoint_name(attn_weights, 'logits')

attn_weights = _with_sharding_constraint(attn_weights,
(BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES))

@@ -499,6 +502,7 @@ def _check_head_dim(head_dim):
bias_axes=(W_JOINED_AXES, W_TP_AXES),
name='qkv',
dtype=self.dtype)(inputs_q)
qkv_proj = checkpoint_name(qkv_proj, 'combined_qkv_proj')
if not use_fused_attn:
query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2)
else:
@@ -530,6 +534,7 @@ def _check_head_dim(head_dim):
bias_axes=(W_JOINED_AXES, W_TP_AXES),
name='kv',
dtype=self.dtype)(inputs_kv)
kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj')
if not use_fused_attn:
key, value = jnp.split(kv_proj, [1], axis=-2)
else:
@@ -574,6 +579,9 @@ def _check_head_dim(head_dim):
residual = ln_out

if not use_fused_attn:
query = checkpoint_name(query, 'query_proj')
key = checkpoint_name(key, 'key_proj')
value = checkpoint_name(value, 'value_proj')
query = query.reshape((query.shape[0], query.shape[1], self.num_heads, self.head_dim))
key = key.reshape((key.shape[0], key.shape[1], self.num_heads, self.head_dim))
value = value.reshape((value.shape[0], value.shape[1], self.num_heads, self.head_dim))
@@ -706,6 +714,8 @@ def convert_to_softmax_type(attn_mask_type, mask):
dtype=self.dtype,
float32_logits=self.float32_logits)

x = checkpoint_name(x, 'context')

x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

attn_context_sharding_constraint = \
@@ -724,6 +734,7 @@ def convert_to_softmax_type(attn_mask_type, mask):
bias_axes=(W_NO_SHARD_AXES,),
dtype=self.dtype,
name='out')(x)
out = checkpoint_name(out, 'out_proj')
return out, residual


4 changes: 4 additions & 0 deletions transformer_engine/jax/fused_attn.py
@@ -5,6 +5,7 @@

from enum import Enum
from functools import partial
from jax.ad_checkpoint import checkpoint_name
import jax
import jax.numpy as jnp

@@ -91,6 +92,9 @@ def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.nda
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training)
output = checkpoint_name(output, 'context')
softmax_aux = checkpoint_name(softmax_aux, 'context')
rng_state = checkpoint_name(rng_state, 'context')
return output, (qkv, bias, softmax_aux, rng_state, output, squeezed_mask)


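The `checkpoint_name` tags added in the JAX modules above ('ffn1', 'ffn2', 'logits', 'combined_qkv_proj', 'query_proj', 'key_proj', 'value_proj', 'context', 'out_proj') only label intermediate values; they take effect when the enclosing computation is wrapped in `jax.checkpoint` with a name-aware rematerialization policy. A minimal sketch under assumed shapes, with `transformer_block` standing in for any function that runs the tagged modules:

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Stand-in for a function that internally runs the tagged modules;
# two dense layers tag their outputs the way the modules above do.
def transformer_block(params, x):
    h = checkpoint_name(jnp.dot(x, params["w1"]), "ffn1")
    h = jax.nn.gelu(h)
    out = checkpoint_name(jnp.dot(h, params["w2"]), "ffn2")
    return out

# Save only the named activations for the backward pass; everything
# else is rematerialized.
policy = jax.checkpoint_policies.save_only_these_names("ffn1", "ffn2")
block = jax.checkpoint(transformer_block, policy=policy)

def loss_fn(params, x):
    return jnp.sum(block(params, x) ** 2)

params = {"w1": jnp.ones((64, 256)), "w2": jnp.ones((256, 64))}
x = jnp.ones((8, 64))
grads = jax.grad(loss_fn)(params, x)
```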
119 changes: 60 additions & 59 deletions transformer_engine/pytorch/transformer.py
@@ -115,6 +115,11 @@ class TransformerLayer(torch.nn.Module):
if set to `True`, layer normalization is applied on the output side,
after the final dropout-add. default behavior is to apply layer
normalization on the input side, before the QKV transformation.
parallel_attention_mlp: bool, default = `False`
if set to `True`, self-attention and feedforward network are computed
based on the same input (in parallel) instead of sequentially.
Both blocks have an independent normalization.
This architecture is used in `Falcon` models.
layer_type: {'encoder', 'decoder'}, default = `encoder`
if set to `decoder`, an additional cross-attn block is added after self-attn.
This can be used for structures like `T5` Transformer in conjunction with the
@@ -224,6 +229,7 @@ def __init__(
sequence_parallel: bool = False,
apply_residual_connection_post_layernorm: bool = False,
output_layernorm: bool = False,
parallel_attention_mlp: bool = False,
layer_type: str = "encoder",
drop_path_rate: float = 0.0,
set_parallel_mode: bool = False,
@@ -274,6 +280,18 @@ def __init__(
apply_residual_connection_post_layernorm
)

if parallel_attention_mlp:
assert self.layer_type == "encoder", "parallel_attention requires layer_type='encoder'"
assert (
not self.apply_residual_connection_post_layernorm
), "parallel_attention and apply_residual_connection_post_layernorm "\
"not supported simultaneously."
assert (
not self.output_layernorm
), "parallel_attention and output_layernorm not supported simultaneously"

self.parallel_attention_mlp = parallel_attention_mlp

assert layer_type in LayerTypes, f"layer_type {layer_type} not supported"

if not fuse_qkv_params:
@@ -336,7 +354,7 @@ def __init__(
input_layernorm=not output_layernorm,
attention_type="self",
bias=bias,
return_bias=True,
return_bias=not self.parallel_attention_mlp,
normalization=normalization,
device=device,
)
@@ -370,7 +388,7 @@ def __init__(
init_method=init_method,
output_layer_init_method=output_layer_init_method,
bias=bias,
return_bias=True,
return_bias=not self.parallel_attention_mlp,
sequence_parallel=self.sequence_parallel,
params_dtype=params_dtype,
return_layernorm_output=apply_residual_connection_post_layernorm,
@@ -578,41 +596,19 @@ def forward(

if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
attention_output, attention_bias, residual = self_attention_outputs
else:
hidden_states = self._bias_dropout_add(
attention_output, attention_bias, residual, self.drop_path
)
elif not self.parallel_attention_mlp:
attention_output, attention_bias = self_attention_outputs
residual = hidden_states

# Set BDA func.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)

# Bias dropoout add.
if self.drop_path is None and attention_bias.numel() != 0:
with self.bias_dropout_add_exec_handler():
bda_output = bias_dropout_add_func(
attention_output, attention_bias, residual, self.hidden_dropout
)
else:
if attention_bias.numel() != 0:
attention_output = attention_output + attention_bias
out = torch.nn.functional.dropout(
attention_output,
p=self.hidden_dropout,
training=self.training,
hidden_states = self._bias_dropout_add(
attention_output, attention_bias, hidden_states, self.drop_path
)
if self.drop_path is not None:
out = self.drop_path(out)
bda_output = residual + out

# Cross attention.
if self.layer_type == "decoder":
inter_attention_outputs = self.inter_attention(
bda_output,
hidden_states,
attention_mask=enc_dec_attn_mask,
attn_mask_type=self_attn_mask_type,
encoder_output=encoder_output,
@@ -626,49 +622,54 @@ def forward(
attention_output, attention_bias, residual = inter_attention_outputs
else:
attention_output, attention_bias = inter_attention_outputs
residual = bda_output
residual = hidden_states

hidden_states = self._bias_dropout_add(attention_output, attention_bias, residual)

if attention_bias.numel() != 0:
with self.bias_dropout_add_exec_handler():
bda_output = bias_dropout_add_func(
attention_output, attention_bias, residual, self.hidden_dropout
)
else:
out = torch.nn.functional.dropout(
attention_output,
p=self.hidden_dropout,
training=self.training,
)
bda_output = residual + out
# MLP.
mlp_outputs = self.layernorm_mlp(
bda_output, is_first_microbatch=is_first_microbatch
hidden_states, is_first_microbatch=is_first_microbatch
)
if self.apply_residual_connection_post_layernorm:
mlp_output, mlp_bias, residual = mlp_outputs
output = self._bias_dropout_add(mlp_output, mlp_bias, residual, self.drop_path)
elif self.parallel_attention_mlp:
output = self._bias_dropout_add(
self_attention_outputs, mlp_outputs, hidden_states, self.drop_path
)
else:
mlp_output, mlp_bias = mlp_outputs
residual = bda_output
output = self._bias_dropout_add(mlp_output, mlp_bias, hidden_states, self.drop_path)

# For BERT like architectures.
if self.output_layernorm:
output = self.layernorm(output)

# output: [s, b, h]
return output

def _bias_dropout_add(self, hidden_state, bias, residual, drop_path=None):
if drop_path is None and bias.numel() != 0:
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)

# Bias dropoout add.
if self.drop_path is None and mlp_bias.numel() != 0:
with self.bias_dropout_add_exec_handler():
output = bias_dropout_add_func(
mlp_output, mlp_bias, residual, self.hidden_dropout
hidden_state, bias, residual, self.hidden_dropout
)
else:
if mlp_bias.numel() != 0:
mlp_output = mlp_output + mlp_bias
if bias.numel() != 0:
hidden_state = hidden_state + bias
out = torch.nn.functional.dropout(
mlp_output, p=self.hidden_dropout, training=self.training
hidden_state, p=self.hidden_dropout, training=self.training
)
if self.drop_path is not None:
out = self.drop_path(out)
if drop_path is not None:
out = drop_path(out)
output = residual + out

# For BERT like architectures.
if self.output_layernorm:
output = self.layernorm(output)

# output: [s, b, h]
return output
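For illustration, a minimal sketch of the new flag through the public TransformerLayer API; sizes, dtype, and input are placeholders. With `parallel_attention_mlp=True` the block computes `x + dropout(attn(ln(x)) + mlp(ln_mlp(x)))`, both branches reading the same input (as in Falcon), instead of applying the attention and MLP residual adds one after the other:

```python
import torch
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,              # placeholder sizes
    ffn_hidden_size=4096,
    num_attention_heads=16,
    hidden_dropout=0.1,
    attention_dropout=0.1,
    fuse_qkv_params=True,
    parallel_attention_mlp=True,   # Falcon-style parallel attention + MLP
).to(dtype=torch.bfloat16).cuda()

# Inputs use the [sequence, batch, hidden] layout noted in the code above.
x = torch.randn(128, 4, 1024, dtype=torch.bfloat16, device="cuda")
y = layer(x)
```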
