Merge remote-tracking branch 'personal/add-parallel-attention-mlp' into add-parallel-attention-mlp
Marks101 committed Nov 22, 2023
2 parents b7c908b + cfe1f94 commit af7d81a
Showing 30 changed files with 766 additions and 300 deletions.
44 changes: 27 additions & 17 deletions README.rst
@@ -135,37 +135,47 @@ Installation
----------
.. installation
In the NGC container
Pre-requisites
^^^^^^^^^^^^^^^^^^^^
* Linux x86_64
* CUDA 11.8+ for Hopper and CUDA 12.1+ for Ada
* NVIDIA Driver supporting CUDA 11.8 or later
* cuDNN 8.1 or later
* For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later.

The quickest way to get started with Transformer Engine is the NGC PyTorch container on
`NVIDIA GPU Cloud Catalog <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`_ (versions 22.09 and later).
Docker
^^^^^^^^^^^^^^^^^^^^

The quickest way to get started with Transformer Engine is by using Docker images on
`NVIDIA GPU Cloud (NGC) Catalog <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`_. For example, to use the NGC PyTorch container interactively,

.. code-block:: bash
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.04-py3
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3
Where 23.04 is the container version; for example, 23.04 corresponds to the April 2023 release.
Where 23.10 is the container version; for example, 23.10 corresponds to the October 2023 release.

Pre-requisites
pip
^^^^^^^^^^^^^^^^^^^^
* Linux x86_64
* CUDA 11.8 or later
* NVIDIA Driver supporting CUDA 11.8 or later
* cuDNN 8.1 or later
* For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later.
To install the latest stable version of Transformer Engine,

.. code-block:: bash
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch).
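
For example, to restrict the build to JAX and PyTorch, the variable can be set inline with the install command shown above (an illustrative sketch; adjust the framework list to your environment),

.. code-block:: bash
NVTE_FRAMEWORK=jax,pytorch pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable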

From source
^^^^^^^^^^^
`See the installation guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/installation.html#installation-from-source>`_.

`See the installation guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/installation.html>`_.

Compiling with Flash Attention 2
Compiling with FlashAttention-2
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Transformer Engine release v0.11.0 adds support for FlashAttention-2 in PyTorch for improved performance.

It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug <https://github.com/Dao-AILab/flash-attention/issues/358>`_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue. If the errors persist, install a supported version of FlashAttention-1 (v1.0.6 to v1.0.9).

TransformerEngine release v0.11.0 adds support for Flash Attention 2.0 for improved performance. It is a known issue that Flash Attention 2.0 compilation is
resource-intensive and requires a large amount of RAM (see `bug <https://github.com/Dao-AILab/flash-attention/issues/358>`_), which may lead to out of memory
errors during the installation of TransformerEngine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue. If the errors persist, install a supported version of Flash Attention 1 (v1.0.6 to v1.0.9).
Note that NGC PyTorch 23.08+ containers include FlashAttention-2.
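
A minimal sketch of the MAX_JOBS workaround described above, assuming installation from the stable branch (adjust to match your install command),

.. code-block:: bash
MAX_JOBS=1 pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable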

Model Support
----------
10 changes: 10 additions & 0 deletions tests/jax/test_custom_call_compute.py
@@ -37,6 +37,16 @@
is_fp8_supported, reason = is_fp8_available()


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()


class TestFP8Dot:

@pytest.mark.skipif(not is_fp8_supported, reason=reason)
11 changes: 11 additions & 0 deletions tests/jax/test_custom_call_shape.py
@@ -3,6 +3,7 @@
# See LICENSE for license information.

import pytest
import jax
import jax.numpy as jnp
from jax.core import ShapedArray

@@ -31,6 +32,16 @@
TRANSPOSE = [True, False]


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()


class TestGEMMShapeInfer:

@staticmethod
104 changes: 60 additions & 44 deletions tests/jax/test_fused_attn.py
@@ -18,15 +18,25 @@
from flax.linen import make_causal_mask
from jax import value_and_grad, jit

from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
from transformer_engine_jax import get_device_compute_capability
from transformer_engine_jax import get_device_compute_capability # pylint: disable=wrong-import-order

# Type annotations
Array = jnp.ndarray


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()


class Backend(Enum):
"""
Fused attn backend.
@@ -52,6 +62,13 @@ def fixture_backend(request):
DTYPES = [jnp.bfloat16, jnp.float16]


def is_causal_mask(mask: AttnMaskType):
"""
Check if the mask is a causal mask
"""
return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]


def make_decoder_mask(tokens: Array) -> Array:
"""
Create padded causal mask
@@ -66,7 +83,7 @@ def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
Self attention with JAX native implementation
"""
attn_mask_type = kwargs['attn_mask_type']
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
if is_causal_mask(attn_mask_type):
mask = make_decoder_mask(q_token)
else:
mask = make_attention_mask(q_token > 0, kv_token > 0)
@@ -84,8 +101,8 @@ def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
deterministic=not kwargs['is_training'],
dropout_rate=kwargs['dropout_probability'],
dropout_rng=dropout_rng,
dtype=qkv.dtype)
return output
dtype=jnp.float32)
return output.astype(qkv.dtype)


def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
Expand All @@ -95,7 +112,7 @@ def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
assert q.dtype == kv.dtype

attn_mask_type = kwargs['attn_mask_type']
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
if is_causal_mask(attn_mask_type):
raise NotImplementedError
mask = make_attention_mask(q_token > 0, kv_token > 0)

@@ -112,15 +129,16 @@ def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
deterministic=not kwargs['is_training'],
dropout_rate=kwargs['dropout_probability'],
dropout_rng=dropout_rng,
dtype=q.dtype)
return output
dtype=jnp.float32)
return output.astype(q.dtype)


def customcall_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
"""
Self fused attention
"""
if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK:
attn_mask_type = kwargs['attn_mask_type']
if is_causal_mask(attn_mask_type):
mask = make_decoder_mask(q_token)
else:
mask = make_attention_mask(q_token > 0, kv_token > 0)
@@ -137,7 +155,8 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs)
"""
assert q.dtype == kv.dtype

if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK:
attn_mask_type = kwargs['attn_mask_type']
if is_causal_mask(attn_mask_type):
raise NotImplementedError
mask = make_attention_mask(q_token > 0, kv_token > 0)

@@ -149,41 +168,42 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs)

@pytest.mark.parametrize('b, s, h, d', SELF_CASES)
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('attn_mask_type', [
AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK,
AttnMaskType.PADDING_CAUSAL_MASK
])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', [0, 0.3])
class TestSelfFusedAttn():
"""Tests for transformer_engine.jax.fused_attn.self_fused_attn"""

@staticmethod
def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, dropout_probability, dtype,
head_dim, pad_ratio):
if (s > 512 or backend == Backend.Arbitrary) and pad_ratio != 0:
pytest.skip("Arbitrary seqlen backend hasn't support padded input.")
head_dim):

if not is_fused_attn_kernel_available(dtype, dtype, attn_bias_type, attn_mask_type,
dropout_probability, s, s, head_dim):
pytest.skip("Unsupported inputs combination or device compute capability.")
assert isinstance(backend, Backend)

compute_capability = get_device_compute_capability(0)
if (backend == Backend.Max512
and not (compute_capability == 80 or compute_capability >= 90)):
pytest.skip("Unsupported compute capability for "
"fused attention with <=512 sequence length")
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
attn_mask_type, dropout_probability, s, s, head_dim):
pytest.skip("Unsupported inputs combination or device compute capability.")

def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
dropout_probability, dtype, is_training, pad_ratio):
dropout_probability, dtype, is_training):
"""Setup the test inputs"""
self.__class__._check_inputs(s,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
backend=backend,
dropout_probability=dropout_probability,
dtype=dtype,
head_dim=d,
pad_ratio=pad_ratio)
head_dim=d)

if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
pad_ratio = 0.0
else:
pad_ratio = 0.3

key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)

@@ -212,7 +232,7 @@ def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
self.is_training = is_training

def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, dropout_probability,
dtype, is_training, pad_ratio):
dtype, is_training):
"""
Test forward without using JIT
"""
@@ -225,8 +245,7 @@ def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, drop
backend=backend,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
is_training=is_training)

primitive_out = customcall_self_fused_attn(self.qkv,
self.bias,
@@ -265,7 +284,7 @@ def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, drop
jnp.zeros_like(pri_invalid, jnp.float32))

def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend,
dropout_probability, dtype, is_training, pad_ratio):
dropout_probability, dtype, is_training):
"""
Test forward, backward, and autodiff by jax.value_and_grad
"""
Expand All @@ -281,13 +300,12 @@ def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, back
backend=backend,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
is_training=is_training)

def grad_func(fused_attn_func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient
gradient_multiplier = 1000 if dtype == jnp.bfloat16 else 10000
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
if is_causal_mask(attn_mask_type):
gradient_multiplier = gradient_multiplier / 10
# Keep only valid result for the gradient
# fused_attn output has shape (b, s, h, d)
@@ -333,15 +351,15 @@ def grad_func(fused_attn_func, *args, **kwargs):
rtol=1e-4,
atol=1e-5)

valid_primitive_dqkv, invalid_primitive_dqkv = jnp.split(primitive_dqkv, (self.valid_len,),
axis=1)
valid_reference_dqkv, invalid_reference_dqkv = jnp.split(reference_dqkv, (self.valid_len,),
axis=1)
valid_primitive_dqkv, invalid_primitive_dqkv = \
jnp.split(primitive_dqkv.astype(jnp.float32), (self.valid_len,), axis=1)
valid_reference_dqkv, invalid_reference_dqkv = \
jnp.split(reference_dqkv.astype(jnp.float32), (self.valid_len,), axis=1)

valid_primitive_dq, valid_primitive_dk, valid_primitive_dv = jnp.split(
valid_primitive_dqkv.astype(jnp.float32), 3, axis=2)
valid_reference_dq, valid_reference_dk, valid_reference_dv = jnp.split(
valid_reference_dqkv.astype(jnp.float32), 3, axis=2)
valid_primitive_dq, valid_primitive_dk, valid_primitive_dv = \
jnp.split(valid_primitive_dqkv, 3, axis=2)
valid_reference_dq, valid_reference_dk, valid_reference_dv = \
jnp.split(valid_reference_dqkv, 3, axis=2)

np.testing.assert_allclose(valid_primitive_dq, valid_reference_dq, rtol=1e-4, atol=1e-5)
np.testing.assert_allclose(valid_primitive_dk, valid_reference_dk, rtol=1e-4, atol=1e-5)
@@ -482,9 +500,7 @@ def test_forward_backward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_prob

def grad_func(fused_attn_func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient
gradient_multiplier = 10000
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
gradient_multiplier = gradient_multiplier / 10
gradient_multiplier = 1e4
# Keep only valid result for the gradient
# fused_attn output has shape (b, s_q, h, d)
valid_fused_attn_ret, _ = jnp.split(fused_attn_func(*args, **kwargs),
10 changes: 10 additions & 0 deletions tests/jax/test_layer.py
@@ -19,6 +19,16 @@
is_fp8_supported, reason = is_fp8_available()


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()


def loss_fn(diff_xs, no_diff_xs, params, others, model, rngs):
output = model.apply({"params": params, **others}, *diff_xs, *no_diff_xs, rngs=rngs)
return jnp.mean(output)
10 changes: 10 additions & 0 deletions tests/jax/test_praxis_layers.py
@@ -38,6 +38,16 @@
FP8_FORMATS = [Format.E4M3, Format.HYBRID]


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()


def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
for key in ref_fd:
assert key in test_fd, \
