
Commit cfe1f94

Merge branch 'main' into add-parallel-attention-mlp
2 parents: 50baf6a + 7976bd0

30 files changed: +766 −300 lines

README.rst

Lines changed: 27 additions & 17 deletions
@@ -135,37 +135,47 @@ Installation
 ----------
 .. installation
-In the NGC container
+Pre-requisites
 ^^^^^^^^^^^^^^^^^^^^
+* Linux x86_64
+* CUDA 11.8+ for Hopper and CUDA 12.1+ for Ada
+* NVIDIA Driver supporting CUDA 11.8 or later
+* cuDNN 8.1 or later
+* For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later.
 
-The quickest way to get started with Transformer Engine is the NGC PyTorch container on
-`NVIDIA GPU Cloud Catalog <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`_ (versions 22.09 and later).
+Docker
+^^^^^^^^^^^^^^^^^^^^
+
+The quickest way to get started with Transformer Engine is by using Docker images on
+`NVIDIA GPU Cloud (NGC) Catalog <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`_. For example to use the NGC PyTorch container interactively,
 
 .. code-block:: bash
 
-    docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.04-py3
+    docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3
 
-Where 23.04 is the container version. For example, 23.04 for the April 2023 release.
+Where 23.10 is the container version. For example, 23.10 for the October 2023 release.
 
-Pre-requisites
+pip
 ^^^^^^^^^^^^^^^^^^^^
-* Linux x86_64
-* CUDA 11.8 or later
-* NVIDIA Driver supporting CUDA 11.8 or later
-* cuDNN 8.1 or later
-* For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later.
+To install the latest stable version of Transformer Engine,
+
+.. code-block:: bash
+
+    pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
+
+This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch).
 
 From source
 ^^^^^^^^^^^
+`See the installation guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/installation.html#installation-from-source>`_.
 
-`See the installation guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/installation.html>`_.
-
-Compiling with Flash Attention 2
+Compiling with FlashAttention-2
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Transformer Engine release v0.11.0 adds support for FlashAttention-2 in PyTorch for improved performance.
+
+It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug <https://github.com/Dao-AILab/flash-attention/issues/358>`_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue. If the errors persist, install a supported version of FlashAttention-1 (v1.0.6 to v1.0.9).
 
-TransformerEngine release v0.11.0 adds support for Flash Attention 2.0 for improved performance. It is a known issue that Flash Attention 2.0 compilation is
-resource-intensive and requires a large amount of RAM (see `bug <https://github.com/Dao-AILab/flash-attention/issues/358>`_), which may lead to out of memory
-errors during the installation of TransformerEngine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue. If the errors persist, install a supported version of Flash Attention 1 (v1.0.6 to v1.0.9).
+Note that NGC PyTorch 23.08+ containers include FlashAttention-2.
 
 Model Support
 ----------
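The pip section above leaves framework selection to auto-detection via NVTE_FRAMEWORK. A minimal post-install sanity check is to try importing the framework bindings that were requested; this is only a sketch: transformer_engine.jax is the import path used by the tests changed in this commit, and transformer_engine.pytorch is assumed to be the corresponding PyTorch extension module.

# Sketch: report which framework extensions a Transformer Engine install provides.
# Assumes an install such as:
#     NVTE_FRAMEWORK=jax,pytorch pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
import importlib

for module in ("transformer_engine.jax", "transformer_engine.pytorch"):
    try:
        importlib.import_module(module)
        print(f"{module}: available")
    except ImportError as err:
        print(f"{module}: not built ({err})")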

tests/jax/test_custom_call_compute.py

Lines changed: 10 additions & 0 deletions
@@ -37,6 +37,16 @@
 is_fp8_supported, reason = is_fp8_available()
 
 
+@pytest.fixture(autouse=True, scope='function')
+def clear_live_arrays():
+    """
+    Clear all live arrays to keep the resource clean
+    """
+    yield
+    for arr in jax.live_arrays():
+        arr.delete()
+
+
 class TestFP8Dot:
 
     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
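The same autouse clear_live_arrays fixture is added to each of the JAX test modules in this commit. As a standalone illustration of the pattern (a minimal sketch; test_example is a hypothetical test, not one from the repository), the fixture yields to the test body and afterwards deletes every live device array, so buffers from one parameterized case do not accumulate into the next:

import jax
import jax.numpy as jnp
import pytest


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
    """Run the test, then free every live device buffer."""
    yield    # hand control to the test function
    for arr in jax.live_arrays():
        arr.delete()    # release the underlying device memory


def test_example():    # hypothetical test used only to exercise the fixture
    x = jnp.ones((1024, 1024), dtype=jnp.float32)
    assert float(x.sum()) == 1024 * 1024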

tests/jax/test_custom_call_shape.py

Lines changed: 11 additions & 0 deletions
@@ -3,6 +3,7 @@
 # See LICENSE for license information.
 
 import pytest
+import jax
 import jax.numpy as jnp
 from jax.core import ShapedArray
 
@@ -31,6 +32,16 @@
 TRANSPOSE = [True, False]
 
 
+@pytest.fixture(autouse=True, scope='function')
+def clear_live_arrays():
+    """
+    Clear all live arrays to keep the resource clean
+    """
+    yield
+    for arr in jax.live_arrays():
+        arr.delete()
+
+
 class TestGEMMShapeInfer:
 
     @staticmethod

tests/jax/test_fused_attn.py

Lines changed: 60 additions & 44 deletions
@@ -18,15 +18,25 @@
 from flax.linen import make_causal_mask
 from jax import value_and_grad, jit
 
-from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType
+from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
 from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
 from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
-from transformer_engine_jax import get_device_compute_capability
+from transformer_engine_jax import get_device_compute_capability    # pylint: disable=wrong-import-order
 
 # Type annotations
 Array = jnp.ndarray
 
 
+@pytest.fixture(autouse=True, scope='function')
+def clear_live_arrays():
+    """
+    Clear all live arrays to keep the resource clean
+    """
+    yield
+    for arr in jax.live_arrays():
+        arr.delete()
+
+
 class Backend(Enum):
     """
     Fused attn backend.
@@ -52,6 +62,13 @@ def fixture_backend(request):
 DTYPES = [jnp.bfloat16, jnp.float16]
 
 
+def is_causal_mask(mask: AttnMaskType):
+    """
+    Check if the mask is a causal mask
+    """
+    return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]
+
+
 def make_decoder_mask(tokens: Array) -> Array:
     """
     Create padded causal mask
@@ -66,7 +83,7 @@ def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
     Self attention with JAX native implementation
     """
     attn_mask_type = kwargs['attn_mask_type']
-    if attn_mask_type == AttnMaskType.CAUSAL_MASK:
+    if is_causal_mask(attn_mask_type):
         mask = make_decoder_mask(q_token)
     else:
         mask = make_attention_mask(q_token > 0, kv_token > 0)
@@ -84,8 +101,8 @@ def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
         deterministic=not kwargs['is_training'],
         dropout_rate=kwargs['dropout_probability'],
         dropout_rng=dropout_rng,
-        dtype=qkv.dtype)
-    return output
+        dtype=jnp.float32)
+    return output.astype(qkv.dtype)
 
 
 def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
@@ -95,7 +112,7 @@ def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
     assert q.dtype == kv.dtype
 
     attn_mask_type = kwargs['attn_mask_type']
-    if attn_mask_type == AttnMaskType.CAUSAL_MASK:
+    if is_causal_mask(attn_mask_type):
         raise NotImplementedError
     mask = make_attention_mask(q_token > 0, kv_token > 0)
 
@@ -112,15 +129,16 @@ def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
         deterministic=not kwargs['is_training'],
         dropout_rate=kwargs['dropout_probability'],
         dropout_rng=dropout_rng,
-        dtype=q.dtype)
-    return output
+        dtype=jnp.float32)
+    return output.astype(q.dtype)
 
 
 def customcall_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
     """
     Self fused attention
     """
-    if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK:
+    attn_mask_type = kwargs['attn_mask_type']
+    if is_causal_mask(attn_mask_type):
         mask = make_decoder_mask(q_token)
     else:
         mask = make_attention_mask(q_token > 0, kv_token > 0)
@@ -137,7 +155,8 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
     """
     assert q.dtype == kv.dtype
 
-    if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK:
+    attn_mask_type = kwargs['attn_mask_type']
+    if is_causal_mask(attn_mask_type):
         raise NotImplementedError
     mask = make_attention_mask(q_token > 0, kv_token > 0)
 
@@ -149,41 +168,42 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
 
 @pytest.mark.parametrize('b, s, h, d', SELF_CASES)
 @pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
-@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
+@pytest.mark.parametrize('attn_mask_type', [
+    AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK,
+    AttnMaskType.PADDING_CAUSAL_MASK
+])
 @pytest.mark.parametrize('dropout_probability', [0., 0.1])
 @pytest.mark.parametrize('dtype', DTYPES)
 @pytest.mark.parametrize('is_training', [True, False])
-@pytest.mark.parametrize('pad_ratio', [0, 0.3])
 class TestSelfFusedAttn():
     """Tests for transformer_engine.jax.fused_attn.self_fused_attn"""
 
     @staticmethod
     def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, dropout_probability, dtype,
-                      head_dim, pad_ratio):
-        if (s > 512 or backend == Backend.Arbitrary) and pad_ratio != 0:
-            pytest.skip("Arbitrary seqlen backend hasn't support padded input.")
+                      head_dim):
 
-        if not is_fused_attn_kernel_available(dtype, dtype, attn_bias_type, attn_mask_type,
-                                              dropout_probability, s, s, head_dim):
-            pytest.skip("Unsupported inputs combination or device compute capability.")
+        assert isinstance(backend, Backend)
 
-        compute_capability = get_device_compute_capability(0)
-        if (backend == Backend.Max512
-                and not (compute_capability == 80 or compute_capability >= 90)):
-            pytest.skip("Unsupported compute capability for "
-                        "fused attention with <=512 sequence length")
+        if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
+                                              attn_mask_type, dropout_probability, s, s, head_dim):
            pytest.skip("Unsupported inputs combination or device compute capability.")
 
     def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
-                    dropout_probability, dtype, is_training, pad_ratio):
+                    dropout_probability, dtype, is_training):
         """Setup the test inputs"""
         self.__class__._check_inputs(s,
                                      attn_bias_type=attn_bias_type,
                                      attn_mask_type=attn_mask_type,
                                      backend=backend,
                                      dropout_probability=dropout_probability,
                                      dtype=dtype,
-                                     head_dim=d,
-                                     pad_ratio=pad_ratio)
+                                     head_dim=d)
+
+        if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
+            pad_ratio = 0.0
+        else:
+            pad_ratio = 0.3
+
         key = jax.random.PRNGKey(0)
         subkeys = jax.random.split(key, 2)
 
@@ -212,7 +232,7 @@ def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
         self.is_training = is_training
 
     def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, dropout_probability,
-                     dtype, is_training, pad_ratio):
+                     dtype, is_training):
         """
         Test forward without using JIT
         """
@@ -225,8 +245,7 @@ def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, drop
                          backend=backend,
                          dropout_probability=dropout_probability,
                          dtype=dtype,
-                         is_training=is_training,
-                         pad_ratio=pad_ratio)
+                         is_training=is_training)
 
         primitive_out = customcall_self_fused_attn(self.qkv,
                                                    self.bias,
@@ -265,7 +284,7 @@ def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, drop
                                    jnp.zeros_like(pri_invalid, jnp.float32))
 
     def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend,
-                              dropout_probability, dtype, is_training, pad_ratio):
+                              dropout_probability, dtype, is_training):
         """
         Test forward, backward, and autodiff by jax.value_and_grad
        """
@@ -281,13 +300,12 @@ def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, back
                          backend=backend,
                          dropout_probability=dropout_probability,
                          dtype=dtype,
-                         is_training=is_training,
-                         pad_ratio=pad_ratio)
+                         is_training=is_training)
 
         def grad_func(fused_attn_func, *args, **kwargs):
             # Gradient is small, use a gradient multiplier to amplify the graident
             gradient_multiplier = 1000 if dtype == jnp.bfloat16 else 10000
-            if attn_mask_type == AttnMaskType.CAUSAL_MASK:
+            if is_causal_mask(attn_mask_type):
                 gradient_multiplier = gradient_multiplier / 10
             # Keep only valid result for the gradient
             # fused_attn output has shape (b, s, h, d)
@@ -333,15 +351,15 @@ def grad_func(fused_attn_func, *args, **kwargs):
                                    rtol=1e-4,
                                    atol=1e-5)
 
-        valid_primitive_dqkv, invalid_primitive_dqkv = jnp.split(primitive_dqkv, (self.valid_len,),
-                                                                 axis=1)
-        valid_reference_dqkv, invalid_reference_dqkv = jnp.split(reference_dqkv, (self.valid_len,),
-                                                                 axis=1)
+        valid_primitive_dqkv, invalid_primitive_dqkv = \
+            jnp.split(primitive_dqkv.astype(jnp.float32), (self.valid_len,), axis=1)
+        valid_reference_dqkv, invalid_reference_dqkv = \
+            jnp.split(reference_dqkv.astype(jnp.float32), (self.valid_len,), axis=1)
 
-        valid_primitive_dq, valid_primitive_dk, valid_primitive_dv = jnp.split(
-            valid_primitive_dqkv.astype(jnp.float32), 3, axis=2)
-        valid_reference_dq, valid_reference_dk, valid_reference_dv = jnp.split(
-            valid_reference_dqkv.astype(jnp.float32), 3, axis=2)
+        valid_primitive_dq, valid_primitive_dk, valid_primitive_dv = \
+            jnp.split(valid_primitive_dqkv, 3, axis=2)
+        valid_reference_dq, valid_reference_dk, valid_reference_dv = \
+            jnp.split(valid_reference_dqkv, 3, axis=2)
 
         np.testing.assert_allclose(valid_primitive_dq, valid_reference_dq, rtol=1e-4, atol=1e-5)
         np.testing.assert_allclose(valid_primitive_dk, valid_reference_dk, rtol=1e-4, atol=1e-5)
@@ -482,9 +500,7 @@ def test_forward_backward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_prob
 
         def grad_func(fused_attn_func, *args, **kwargs):
             # Gradient is small, use a gradient multiplier to amplify the graident
-            gradient_multiplier = 10000
-            if attn_mask_type == AttnMaskType.CAUSAL_MASK:
-                gradient_multiplier = gradient_multiplier / 10
+            gradient_multiplier = 1e4
             # Keep only valid result for the gradient
             # fused_attn output has shape (b, s_q, h, d)
             valid_fused_attn_ret, _ = jnp.split(fused_attn_func(*args, **kwargs),
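Two recurring changes above are the new is_causal_mask() helper, which treats PADDING_CAUSAL_MASK the same as CAUSAL_MASK, and the extra QKVLayout argument threaded into is_fused_attn_kernel_available. Condensed into a standalone sketch (the pad_ratio_for helper name and the concrete dtype, sequence length, and head_dim values are illustrative assumptions mirroring the _set_inputs logic above):

import jax.numpy as jnp

from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available


def is_causal_mask(mask: AttnMaskType):
    """Causal masks now come in plain and padded variants."""
    return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]


def pad_ratio_for(mask: AttnMaskType):
    """Hypothetical helper mirroring _set_inputs: only padded masks exercise padding."""
    return 0.0 if mask in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK] else 0.3


# Availability check with the new QKVLayout argument (values chosen for illustration).
dtype, seqlen, head_dim = jnp.bfloat16, 512, 64
mask = AttnMaskType.PADDING_CAUSAL_MASK
available = is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, AttnBiasType.NO_BIAS,
                                           mask, 0.1, seqlen, seqlen, head_dim)
print(f"fused kernel available: {available}, causal: {is_causal_mask(mask)}, "
      f"pad ratio: {pad_ratio_for(mask)}")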

tests/jax/test_layer.py

Lines changed: 10 additions & 0 deletions
@@ -19,6 +19,16 @@
 is_fp8_supported, reason = is_fp8_available()
 
 
+@pytest.fixture(autouse=True, scope='function')
+def clear_live_arrays():
+    """
+    Clear all live arrays to keep the resource clean
+    """
+    yield
+    for arr in jax.live_arrays():
+        arr.delete()
+
+
 def loss_fn(diff_xs, no_diff_xs, params, others, model, rngs):
     output = model.apply({"params": params, **others}, *diff_xs, *no_diff_xs, rngs=rngs)
     return jnp.mean(output)

tests/jax/test_praxis_layers.py

Lines changed: 10 additions & 0 deletions
@@ -38,6 +38,16 @@
 FP8_FORMATS = [Format.E4M3, Format.HYBRID]
 
 
+@pytest.fixture(autouse=True, scope='function')
+def clear_live_arrays():
+    """
+    Clear all live arrays to keep the resource clean
+    """
+    yield
+    for arr in jax.live_arrays():
+        arr.delete()
+
+
 def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
     for key in ref_fd:
         assert key in test_fd, \
