[C/JAX] Support more mask types for the arbitrary seqlen kernels and minor changes of JAX bias (#469)

* Move bias to float32
* Enable varlen
* Increase neg infinity abs values
* Enable varlen tests
* Remove unnecessary code
* Fix lint
* Support variable sequence length after cuDNN 8.9.6
* Use unique_ptr instead of shared_ptr
* Add a new mask type: PADDING_CAUSAL_MASK (see the sketch after this message)
* Support flash padding mask after cuDNN 8.9.6
* Enhance the Max512 handling for causal masking and add the related tests
* Update the fused attn support lists
* Remove padding_aware from the caching
* Fix libtransformer.so issue
* Reduce the pad ratio tests
* Fix a bug with cuDNN 8.9.5
* Release backend resources after the module-level unit tests
* Clean the JAX live arrays before running the unit tests
* Fix too-few-public-methods lint

---------

Signed-off-by: Reese Wang <[email protected]>
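
The new PADDING_CAUSAL_MASK combines the causal (lower-triangular) constraint with a padding constraint derived from which token positions are valid. Below is a minimal JAX sketch of that combination in the spirit of the tests' make_decoder_mask helper; the use of flax.linen's make_causal_mask / make_attention_mask / combine_masks here is an assumption for illustration, not necessarily the exact repository implementation.

# Illustrative sketch (not from this commit): what a padding + causal mask means.
# Assumes flax.linen's mask helpers; the tests' make_decoder_mask is expected to be equivalent.
import jax.numpy as jnp
from flax.linen import combine_masks, make_attention_mask, make_causal_mask

def make_padded_causal_mask(tokens: jnp.ndarray) -> jnp.ndarray:
    """Attend only to earlier, non-padded positions (0 marks padding)."""
    causal = make_causal_mask(tokens)                       # (b, 1, s, s), lower-triangular
    padding = make_attention_mask(tokens > 0, tokens > 0)   # (b, 1, s, s), valid x valid
    return combine_masks(causal, padding)                   # elementwise AND of both masks

tokens = jnp.asarray([[7, 3, 9, 0, 0]])    # last two positions are padding
mask = make_padded_causal_mask(tokens)     # shape (1, 1, 5, 5)
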
zlsh80826 authored Nov 13, 2023
1 parent 64a3d1d commit bfaec64
Showing 13 changed files with 500 additions and 166 deletions.
10 changes: 10 additions & 0 deletions tests/jax/test_custom_call_compute.py
@@ -37,6 +37,16 @@
is_fp8_supported, reason = is_fp8_available()


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
    """
    Clear all live arrays to keep the resource clean
    """
    yield
    for arr in jax.live_arrays():
        arr.delete()


class TestFP8Dot:

@pytest.mark.skipif(not is_fp8_supported, reason=reason)
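
The autouse, function-scoped fixture added above deletes every live JAX array after each test so device memory does not accumulate across the test module. jax.live_arrays() and Array.delete() are the JAX APIs the fixture relies on; a small standalone illustration of the same cleanup (outside pytest) follows.

# Standalone illustration of the cleanup performed by clear_live_arrays.
import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024))            # allocates a device buffer
print(len(jax.live_arrays()))         # the buffer is tracked as a live array
for arr in jax.live_arrays():
    arr.delete()                      # frees the underlying device memory
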
11 changes: 11 additions & 0 deletions tests/jax/test_custom_call_shape.py
@@ -3,6 +3,7 @@
# See LICENSE for license information.

import pytest
import jax
import jax.numpy as jnp
from jax.core import ShapedArray

@@ -31,6 +32,16 @@
TRANSPOSE = [True, False]


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
    """
    Clear all live arrays to keep the resource clean
    """
    yield
    for arr in jax.live_arrays():
        arr.delete()


class TestGEMMShapeInfer:

@staticmethod
98 changes: 57 additions & 41 deletions tests/jax/test_fused_attn.py
@@ -21,12 +21,22 @@
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
from transformer_engine_jax import get_device_compute_capability
from transformer_engine_jax import get_device_compute_capability # pylint: disable=wrong-import-order

# Type annotations
Array = jnp.ndarray


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
    """
    Clear all live arrays to keep the resource clean
    """
    yield
    for arr in jax.live_arrays():
        arr.delete()


class Backend(Enum):
"""
Fused attn backend.
@@ -52,6 +62,13 @@ def fixture_backend(request):
DTYPES = [jnp.bfloat16, jnp.float16]


def is_causal_mask(mask: AttnMaskType):
    """
    Check if the mask is a causal mask
    """
    return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]


def make_decoder_mask(tokens: Array) -> Array:
    """
    Create padded causal mask
@@ -66,7 +83,7 @@ def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
Self attention with JAX native implementation
"""
attn_mask_type = kwargs['attn_mask_type']
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
if is_causal_mask(attn_mask_type):
mask = make_decoder_mask(q_token)
else:
mask = make_attention_mask(q_token > 0, kv_token > 0)
@@ -84,8 +101,8 @@ def jax_self_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
deterministic=not kwargs['is_training'],
dropout_rate=kwargs['dropout_probability'],
dropout_rng=dropout_rng,
dtype=qkv.dtype)
return output
dtype=jnp.float32)
return output.astype(qkv.dtype)


def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
@@ -95,7 +112,7 @@ def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
assert q.dtype == kv.dtype

attn_mask_type = kwargs['attn_mask_type']
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
if is_causal_mask(attn_mask_type):
raise NotImplementedError
mask = make_attention_mask(q_token > 0, kv_token > 0)

@@ -112,15 +129,16 @@ def jax_cross_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs):
deterministic=not kwargs['is_training'],
dropout_rate=kwargs['dropout_probability'],
dropout_rng=dropout_rng,
dtype=q.dtype)
return output
dtype=jnp.float32)
return output.astype(q.dtype)


def customcall_self_fused_attn(qkv, bias, q_token, kv_token, dropout_rng, **kwargs):
"""
Self fused attention
"""
if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK:
attn_mask_type = kwargs['attn_mask_type']
if is_causal_mask(attn_mask_type):
mask = make_decoder_mask(q_token)
else:
mask = make_attention_mask(q_token > 0, kv_token > 0)
@@ -137,7 +155,8 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs)
"""
assert q.dtype == kv.dtype

if kwargs['attn_mask_type'] == AttnMaskType.CAUSAL_MASK:
attn_mask_type = kwargs['attn_mask_type']
if is_causal_mask(attn_mask_type):
raise NotImplementedError
mask = make_attention_mask(q_token > 0, kv_token > 0)

@@ -149,41 +168,42 @@ def customcall_cross_fused_attn(q, kv, q_token, kv_token, dropout_rng, **kwargs)

@pytest.mark.parametrize('b, s, h, d', SELF_CASES)
@pytest.mark.parametrize('attn_bias_type', [AttnBiasType.NO_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type', [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('attn_mask_type', [
    AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK,
    AttnMaskType.PADDING_CAUSAL_MASK
])
@pytest.mark.parametrize('dropout_probability', [0., 0.1])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('is_training', [True, False])
@pytest.mark.parametrize('pad_ratio', [0, 0.3])
class TestSelfFusedAttn():
"""Tests for transformer_engine.jax.fused_attn.self_fused_attn"""

@staticmethod
def _check_inputs(s, *, attn_bias_type, attn_mask_type, backend, dropout_probability, dtype,
head_dim, pad_ratio):
if (s > 512 or backend == Backend.Arbitrary) and pad_ratio != 0:
pytest.skip("Arbitrary seqlen backend hasn't support padded input.")
head_dim):

assert isinstance(backend, Backend)

if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
attn_mask_type, dropout_probability, s, s, head_dim):
pytest.skip("Unsupported inputs combination or device compute capability.")

compute_capability = get_device_compute_capability(0)
if (backend == Backend.Max512
and not (compute_capability == 80 or compute_capability >= 90)):
pytest.skip("Unsupported compute capability for "
"fused attention with <=512 sequence length")

def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
dropout_probability, dtype, is_training, pad_ratio):
dropout_probability, dtype, is_training):
"""Setup the test inputs"""
self.__class__._check_inputs(s,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
backend=backend,
dropout_probability=dropout_probability,
dtype=dtype,
head_dim=d,
pad_ratio=pad_ratio)
head_dim=d)

if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
pad_ratio = 0.0
else:
pad_ratio = 0.3

key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)

@@ -212,7 +232,7 @@ def _set_inputs(self, b, s, h, d, *, attn_bias_type, attn_mask_type, backend,
self.is_training = is_training

def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, dropout_probability,
dtype, is_training, pad_ratio):
dtype, is_training):
"""
Test forward without using JIT
"""
@@ -225,8 +245,7 @@ def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, drop
backend=backend,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
is_training=is_training)

primitive_out = customcall_self_fused_attn(self.qkv,
self.bias,
@@ -265,7 +284,7 @@ def test_forward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend, drop
jnp.zeros_like(pri_invalid, jnp.float32))

def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, backend,
dropout_probability, dtype, is_training, pad_ratio):
dropout_probability, dtype, is_training):
"""
Test forward, backward, and autodiff by jax.value_and_grad
"""
@@ -281,13 +300,12 @@ def test_forward_backward(self, b, s, h, d, attn_bias_type, attn_mask_type, back
backend=backend,
dropout_probability=dropout_probability,
dtype=dtype,
is_training=is_training,
pad_ratio=pad_ratio)
is_training=is_training)

def grad_func(fused_attn_func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient
gradient_multiplier = 1000 if dtype == jnp.bfloat16 else 10000
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
if is_causal_mask(attn_mask_type):
gradient_multiplier = gradient_multiplier / 10
# Keep only valid result for the gradient
# fused_attn output has shape (b, s, h, d)
@@ -333,15 +351,15 @@ def grad_func(fused_attn_func, *args, **kwargs):
rtol=1e-4,
atol=1e-5)

valid_primitive_dqkv, invalid_primitive_dqkv = jnp.split(primitive_dqkv, (self.valid_len,),
axis=1)
valid_reference_dqkv, invalid_reference_dqkv = jnp.split(reference_dqkv, (self.valid_len,),
axis=1)
valid_primitive_dqkv, invalid_primitive_dqkv = \
jnp.split(primitive_dqkv.astype(jnp.float32), (self.valid_len,), axis=1)
valid_reference_dqkv, invalid_reference_dqkv = \
jnp.split(reference_dqkv.astype(jnp.float32), (self.valid_len,), axis=1)

valid_primitive_dq, valid_primitive_dk, valid_primitive_dv = jnp.split(
valid_primitive_dqkv.astype(jnp.float32), 3, axis=2)
valid_reference_dq, valid_reference_dk, valid_reference_dv = jnp.split(
valid_reference_dqkv.astype(jnp.float32), 3, axis=2)
valid_primitive_dq, valid_primitive_dk, valid_primitive_dv = \
jnp.split(valid_primitive_dqkv, 3, axis=2)
valid_reference_dq, valid_reference_dk, valid_reference_dv = \
jnp.split(valid_reference_dqkv, 3, axis=2)

np.testing.assert_allclose(valid_primitive_dq, valid_reference_dq, rtol=1e-4, atol=1e-5)
np.testing.assert_allclose(valid_primitive_dk, valid_reference_dk, rtol=1e-4, atol=1e-5)
@@ -482,9 +500,7 @@ def test_forward_backward(self, b, s_q, s_kv, h, d, attn_mask_type, dropout_prob

def grad_func(fused_attn_func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient
gradient_multiplier = 10000
if attn_mask_type == AttnMaskType.CAUSAL_MASK:
gradient_multiplier = gradient_multiplier / 10
gradient_multiplier = 1e4
# Keep only valid result for the gradient
# fused_attn output has shape (b, s_q, h, d)
valid_fused_attn_ret, _ = jnp.split(fused_attn_func(*args, **kwargs),
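
In the updated _set_inputs, the pad ratio is now derived from the mask type (0.0 for NO_MASK and CAUSAL_MASK, 0.3 for the padding variants) instead of being a test parameter. How the padded token ids and valid_len are built is not shown in this diff; the following is a hedged sketch of one plausible construction, consistent with the `token > 0` masking and the `jnp.split(..., (self.valid_len,), axis=1)` checks in the hunks above.

# Hedged sketch only: a plausible construction of padded tokens for a given pad_ratio.
# The exact code in tests/jax/test_fused_attn.py is elided from this diff.
import jax.numpy as jnp

b, s = 32, 512
pad_ratio = 0.3                                   # 0.0 for NO_MASK / CAUSAL_MASK
valid_len = int(s * (1 - pad_ratio))

# Non-zero ids mark valid positions; zeros mark padding (masks are built from `token > 0`).
q_token = jnp.concatenate([jnp.ones((b, valid_len), dtype=jnp.int32),
                           jnp.zeros((b, s - valid_len), dtype=jnp.int32)], axis=-1)

# Outputs are then compared only on the valid region, e.g.:
# valid_out, invalid_out = jnp.split(out, (valid_len,), axis=1)
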
10 changes: 10 additions & 0 deletions tests/jax/test_layer.py
@@ -19,6 +19,16 @@
is_fp8_supported, reason = is_fp8_available()


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
    """
    Clear all live arrays to keep the resource clean
    """
    yield
    for arr in jax.live_arrays():
        arr.delete()


def loss_fn(diff_xs, no_diff_xs, params, others, model, rngs):
output = model.apply({"params": params, **others}, *diff_xs, *no_diff_xs, rngs=rngs)
return jnp.mean(output)
10 changes: 10 additions & 0 deletions tests/jax/test_praxis_layers.py
@@ -38,6 +38,16 @@
FP8_FORMATS = [Format.E4M3, Format.HYBRID]


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
    """
    Clear all live arrays to keep the resource clean
    """
    yield
    for arr in jax.live_arrays():
        arr.delete()


def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
for key in ref_fd:
assert key in test_fd, \
8 changes: 7 additions & 1 deletion transformer_engine/common/fused_attn/fused_attn.cpp
@@ -87,6 +87,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
const int sm_arch_ = cuda::sm_arch(device_id);
NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
auto cudnn_runtime_version = cudnnGetVersion();
if ((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)
&& (sm_arch_ >= 90)
&& (max_seqlen_q == max_seqlen_kv)
@@ -111,6 +112,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
&& ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
|| (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS))
&& ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
@@ -131,7 +133,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
&& (max_seqlen_q == max_seqlen_kv)
&& ((head_dim == 64) || (head_dim == 128))
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
&& ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| ((cudnn_runtime_version >= 8906) &&
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)))
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD))) {
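
On the JAX side, the cuDNN-version gating above is surfaced through is_fused_attn_kernel_available, which the tests call with the argument order shown in _check_inputs. A hedged usage sketch for the new mask type follows; a boolean return value is assumed.

# Sketch: gating a padding-causal case on fused-attention kernel availability.
# Argument order mirrors _check_inputs in tests/jax/test_fused_attn.py; bool return assumed.
import jax.numpy as jnp
from transformer_engine.jax.fused_attn import (AttnBiasType, AttnMaskType, QKVLayout,
                                               is_fused_attn_kernel_available)

dtype, seqlen, head_dim = jnp.bfloat16, 2048, 64
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD,
                                      AttnBiasType.NO_BIAS,
                                      AttnMaskType.PADDING_CAUSAL_MASK,
                                      0.0,              # dropout_probability
                                      seqlen, seqlen,   # max_seqlen_q, max_seqlen_kv
                                      head_dim):
    print("PADDING_CAUSAL_MASK with the arbitrary-seqlen backend requires cuDNN >= 8.9.6")
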