[JAX] Expose sliding window attn to TE-JAX API (#1205)
* Expose JAX sliding window attn API

Signed-off-by: Hua Huang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* No SWA in context parallel; fix RNG seed in test

Signed-off-by: Hua Huang <[email protected]>

* Handle SWA API discrepancy in cuDNN and Python

Signed-off-by: Hua Huang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add SWA API for flax, all tests passed

Will update tests/jax/test_praxis_layers.py next

Signed-off-by: Hua Huang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_praxis_layers.py for SWA, test passed

Signed-off-by: Hua Huang <[email protected]>

* Use tuple window_size; update for PR #1212

Signed-off-by: Hua Huang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add and adjust some pytest.skip

Signed-off-by: Hua Huang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revised following Reese Wang's comments

Still need further debugging:
FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-KV_PACKED-NO_MASK-NO_BIAS] - AssertionError:
FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-KV_PACKED-NO_MASK-POST_SCALE_BIAS-1HSS] - AssertionError:
FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-SEPARATE-NO_MASK-NO_BIAS] - AssertionError:
FAILED test_fused_attn.py::TestFusedAttn::test_backward[NO_SWA-DROP_0.0-4-128-256-16-16-64-BF16-CROSS-SEPARATE-NO_MASK-POST_SCALE_BIAS-1HSS] - AssertionError:

These errors do not exist in the previous commit

Signed-off-by: Hua Huang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix no-SWA test case errors in previous commit

Signed-off-by: Hua Huang <[email protected]>

* Add Padding mask w/ sliding windows sanity tests

Signed-off-by: Reese Wang <[email protected]>

* Use float32 for the reference code softmax calculation

Signed-off-by: Reese Wang <[email protected]>

---------

Signed-off-by: Hua Huang <[email protected]>
Signed-off-by: Reese Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Reese Wang <[email protected]>
3 people authored Oct 10, 2024
1 parent 5b6546c commit 85e60e6
Showing 11 changed files with 390 additions and 69 deletions.
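
For context before the per-file diffs, a minimal usage sketch of the new argument at the flax layer level. This snippet is not part of the commit; it assumes the constructor parameters exercised by the tests below (head_dim, num_attention_heads, num_gqa_groups, attn_mask_type, window_size), and the exact public API may differ.

from transformer_engine.jax.flax import DotProductAttention

# Causal attention where each query may additionally look back at most 64 tokens.
# window_size is (left, right); -1 on either side means unbounded.
attn = DotProductAttention(
    head_dim=64,
    num_attention_heads=16,
    num_gqa_groups=16,
    attn_mask_type="causal",
    window_size=(64, 0),
)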
43 changes: 41 additions & 2 deletions tests/jax/test_fused_attn.py
@@ -6,6 +6,7 @@
from dataclasses import dataclass
from functools import partial
from math import sqrt
from typing import Tuple, Optional

import jax
import jax.numpy as jnp
@@ -27,6 +28,7 @@
fused_attn,
fused_attn_thd,
get_qkv_format,
make_swa_mask,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import (
@@ -123,6 +125,7 @@ def make_mask(
segment_pad_q: ArrayLike,
segment_pad_kv: ArrayLike,
attn_mask_type: AttnMaskType,
window_size: Optional[Tuple[int, int]] = None,
) -> Array:
"""
Create attention mask based on mask type. A `True` value in the mask means
@@ -140,6 +143,15 @@
segment_pad_q, segment_pad_kv, lambda x, y: jnp.logical_and(x != 1, y != 1)
)
inv_mask = combine_masks(inv_pad_mask, inv_mask)

if window_size is not None:
max_seqlen_q = inv_mask.shape[-2]
max_seqlen_kv = inv_mask.shape[-1]
inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type)
inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape)
# In inv_swa_mask and inv_mask 0 is masked out
inv_mask = jnp.where(inv_mask != 0, inv_swa_mask, inv_mask)

mask = jnp.logical_not(inv_mask)
return mask
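
For reference (not part of the diff), the window convention that make_swa_mask is expected to follow here: key position j is kept for query position i when i - left <= j <= i + right, with -1 meaning unbounded on that side. A self-contained sketch under that assumption, for the self-attention case:

import jax.numpy as jnp

def swa_band_mask(seqlen_q, seqlen_kv, window_size):
    # 1 = may attend, 0 = masked out, matching the inverse-mask convention above.
    left, right = window_size
    q_pos = jnp.arange(seqlen_q)[:, None]
    kv_pos = jnp.arange(seqlen_kv)[None, :]
    keep = jnp.ones((seqlen_q, seqlen_kv), dtype=bool)
    if left >= 0:
        keep = keep & (kv_pos >= q_pos - left)
    if right >= 0:
        keep = keep & (kv_pos <= q_pos + right)
    return keep.astype(jnp.uint8)

# Example: window_size=(2, 0) on length 5 keeps the diagonal plus the two
# positions before it in every row.
print(swa_band_mask(5, 5, (2, 0)))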

@@ -274,6 +286,7 @@ class FusedAttnRunner:
is_training: bool
qkv_layout: QKVLayout
bias_shape: BiasShape
window_size: Optional[Tuple[int, int]] = None

# See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
# generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
@@ -298,6 +311,11 @@ def _check_configs(self):
if self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip("QKVPACKED layout requires max_seqlen_q and max_seqlen_kv to be equal.")

if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None:
pytest.skip(
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)

self.backend = FusedAttnHelper(
self.dtype,
self.dtype,
@@ -310,6 +328,7 @@ def _check_configs(self):
self.max_seqlen_q,
self.max_seqlen_kv,
self.head_dim,
(-1, -1) if self.window_size is None else self.window_size,
).get_fused_attn_backend()
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
pytest.skip("Unsupported inputs combination or device compute capability.")
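
Note the (-1, -1) placeholder passed to FusedAttnHelper above: the cuDNN/C++ side always receives a concrete (left, right) tuple, with (-1, -1) meaning no sliding window (full attention), while the Python-facing API also accepts window_size=None. A one-line sketch of that normalization (the helper name is hypothetical, not a TE function):

def normalize_window_size(window_size):
    # None on the Python side maps to cuDNN's "unbounded on both sides".
    return (-1, -1) if window_size is None else tuple(window_size)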
@@ -456,6 +475,7 @@ def generate_random_segment_ids(
self.segment_pad_q,
self.segment_pad_kv,
self.attn_mask_type,
self.window_size,
)

if get_qkv_format(self.qkv_layout) == QKVFormat.THD:
@@ -500,6 +520,7 @@ def test_forward(self):
"is_training": self.is_training,
"qkv_layout": self.qkv_layout,
"max_segments_per_seq": self._get_max_segments_per_sequence(),
"window_size": self.window_size,
}

# Convert the outputs to float32 for the elementwise comparison
@@ -557,6 +578,7 @@ def grad_func(func, *args, **kwargs):
"is_training": self.is_training,
"qkv_layout": self.qkv_layout,
"max_segments_per_seq": self._get_max_segments_per_sequence(),
"window_size": self.window_size,
}

# We can compute dBias only for the [1, h, s, s] layout
@@ -668,7 +690,7 @@ def check_dqkv(primitive, reference, pad):
pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"),
pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
pytest.param(4, 512, 128, 16, 16, 64, jnp.bfloat16, id="4-512-128-16-16-64-BF16-CROSS"),
pytest.param(4, 128, 256, 16, 16, 64, jnp.bfloat16, id="4-128-256-16-16-64-BF16-CROSS"),
pytest.param(
2,
2048,
@@ -677,7 +699,7 @@ def check_dqkv(primitive, reference, pad):
12,
64,
jnp.bfloat16,
id="2-2048-1048-12-12-64-BF16-CROSS",
id="2-2048-1024-12-12-64-BF16-CROSS",
),
pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
@@ -690,6 +712,13 @@ def check_dqkv(primitive, reference, pad):
pytest.param(0.1, id="DROP_0.1"),
],
)
@pytest.mark.parametrize(
"swa",
[
pytest.param(False, id="NO_SWA"),
pytest.param(True, id="SWA"),
],
)
class TestFusedAttn:
"""
Fused attention tester
@@ -717,12 +746,16 @@ def _test_forward(
is_training,
qkv_layout,
bias_shape,
swa,
):
"""
Test forward with parameterized configs
This test is not intended to run automatically during CI as it is time-consuming
It is kept for development and debugging
"""
window_size = None
if swa:
window_size = (s_kv // 10, 0)
runner = FusedAttnRunner(
b,
s_q,
@@ -737,6 +770,7 @@
is_training,
qkv_layout,
bias_shape,
window_size,
)
runner.test_forward()

@@ -754,10 +788,14 @@ def test_backward(
dtype,
qkv_layout,
bias_shape,
swa,
):
"""
Test backward with parameterized configs
"""
window_size = None
if swa:
window_size = (s_kv // 10, 0)
runner = FusedAttnRunner(
b,
s_q,
@@ -772,5 +810,6 @@
True,
qkv_layout,
bias_shape,
window_size,
)
runner.test_backward()
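
The swa=True cases above derive the window from the KV sequence length as window_size = (s_kv // 10, 0). For the parameterized shapes this works out as follows (illustrative arithmetic, not test output):

for s_kv in (128, 256, 1024, 2048):
    print(s_kv, "->", (s_kv // 10, 0))
# 128 -> (12, 0), 256 -> (25, 0), 1024 -> (102, 0), 2048 -> (204, 0)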
21 changes: 18 additions & 3 deletions tests/jax/test_layer.py
@@ -4,7 +4,7 @@
"""Test transformer_engine.jax.flax.TransformerLayer"""
import os
from functools import partial
from typing import Dict
from typing import Dict, Tuple

import flax
import jax
@@ -61,6 +61,7 @@ def enable_fused_attn():
_KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits"
_KEY_OF_USE_BIAS = "use_bias"
_KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding"
_KEY_OF_WINDOW_SIZE = "window_size"

BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True,
@@ -70,6 +71,7 @@
_KEY_OF_INTERMEDIATE_DROPOUT: 0,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
_KEY_OF_LAYERNORM_TYPE: "layernorm",
_KEY_OF_WINDOW_SIZE: (-1, -1),
}

ATTRS = [
@@ -193,6 +195,19 @@ def enable_fused_attn():
{
_KEY_OF_MLP_ACTIVATIONS: (("relu", "relu")),
},
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
_KEY_OF_WINDOW_SIZE: (64, 0), # Left size must < DATA_SHAPE seqlen
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
},
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
_KEY_OF_WINDOW_SIZE: (2, 2),
},
]

ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
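
Two of the new configurations above exercise sliding windows: a causal (64, 0) window and a symmetric (2, 2) window under a plain padding mask. A quick sanity sketch (not test code) of what the symmetric case admits, i.e. each token attending to itself and two neighbours on either side:

import jax.numpy as jnp

seqlen, left, right = 8, 2, 2
q = jnp.arange(seqlen)[:, None]
k = jnp.arange(seqlen)[None, :]
local = ((k >= q - left) & (k <= q + right)).astype(jnp.int32)
print(local)  # banded matrix of width left + right + 1 around the diagonal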
@@ -326,7 +341,7 @@ def generate_inputs(self, data_shape, dtype):

padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["casual", "padding_causal"]:
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
mask = causal_mask
else:
mask = padded_mask
@@ -379,7 +394,7 @@ def generate_inputs(self, data_shape, dtype):

padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["casual", "padding_causal"]:
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
self_mask = causal_mask
else:
self_mask = padded_mask
53 changes: 52 additions & 1 deletion tests/jax/test_praxis_layers.py
@@ -4,7 +4,7 @@

import os
from functools import partial
from typing import Dict
from typing import Dict, Tuple

import flax
import jax
@@ -645,6 +645,7 @@ class DotProductAttnAttr:
NUM_GQA_GROUPS = "num_gqa_groups"
TRANSPOSE_BS = "transpose_batch_sequence"
SCALE_FACTOR = "scale_factor"
WINDOW_SIZE = "window_size"
ATTRS = [
{
ATTN_MASK_TYPE: "padding",
@@ -681,6 +682,12 @@ class DotProductAttnAttr:
TRANSPOSE_BS: False,
SCALE_FACTOR: 1.0,
},
{
ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: False,
SCALE_FACTOR: 1.0,
WINDOW_SIZE: (64, 0), # Left size must <= S in DATA_SHAPE
},
]


@@ -707,6 +714,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
num_gqa_groups = num_attention_heads
attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE]
transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS]
window_size = attrs.get(DotProductAttnAttr.WINDOW_SIZE, None)

praxis_p = pax_fiddle.Config(
DotProductAttention,
@@ -717,6 +725,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type,
transpose_batch_sequence=transpose_batch_sequence,
window_size=window_size,
)
flax_cls = partial(
flax_DotProductAttention,
@@ -726,6 +735,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type,
transpose_batch_sequence=transpose_batch_sequence,
window_size=window_size,
)

return praxis_p, flax_cls
Expand All @@ -750,6 +760,7 @@ class MultiHeadAttnAttr:
ENABLE_ROPE = "enable_rotary_pos_emb"
ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
LORA_SCOPE = "low_rank_adaptation_scope"
WINDOW_SIZE = "window_size"
ATTRS = [
{
USE_BIAS: True,
@@ -858,6 +869,17 @@ class MultiHeadAttnAttr:
LORA_SCOPE: "all",
TRANSPOSE_BS: True,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: "causal",
LORA_SCOPE: "all",
TRANSPOSE_BS: True,
WINDOW_SIZE: (64, 0), # Left size must <= S in DATA_SHAPE
},
]


@@ -899,6 +921,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
scale_attn_logits = False
scaled_query_init = True
float32_logits = False
window_size = attrs.get(MultiHeadAttnAttr.WINDOW_SIZE, None)

praxis_p = pax_fiddle.Config(
MultiHeadAttention,
@@ -923,6 +946,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init,
float32_logits=float32_logits,
window_size=window_size,
)
flax_cls = partial(
flax_MultiHeadAttention,
@@ -946,6 +970,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init,
float32_logits=float32_logits,
window_size=window_size,
)

return praxis_p, flax_cls
@@ -983,6 +1008,7 @@ class TransformerLayerAttr:
ENABLE_ROPE = "enable_rotary_pos_emb"
ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
LORA_SCOPE = "low_rank_adaptation_scope"
WINDOW_SIZE = "window_size"
ATTRS = [
{
USE_BIAS: True,
@@ -1246,6 +1272,28 @@ class TransformerLayerAttr:
TRANSPOSE_BS: False,
LORA_SCOPE: "all",
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
WINDOW_SIZE: (64, 0), # Left size must <= S in DATA_SHAPE
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
WINDOW_SIZE: (64, 0), # Left size must <= S in DATA_SHAPE
},
]


@@ -1289,6 +1337,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
)
drop_path = 0.0
transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS]
window_size = attrs.get(TransformerLayerAttr.WINDOW_SIZE, None)

rel_embedding_init = RelativePositionBiases.generate_embedding_init(
relative_embedding.embedding_init,
@@ -1330,6 +1379,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
relative_embedding=relative_embedding,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence,
window_size=window_size,
)
flax_cls = partial(
flax_TransformerLayer,
@@ -1358,6 +1408,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
low_rank_adaptation_scope=low_rank_adaptation_scope,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence,
window_size=window_size,
)

return praxis_p, flax_cls
