From 838345eba4fdd2a169dd9e087d39c30a360e684a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 19 Dec 2024 21:32:41 -0800 Subject: [PATCH] [common/PyTorch] Add cuDNN SWA (left, 0) + padding + bottom right causal (#1378) * add swa (left,0) + padding + brcm support Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * final fixes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * upgrade to FE 1.9-rc Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix jax tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * skip thd + CP + fused attn tests for cuDNN 9.6+ due to different stats shapes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- qa/L0_pytorch_unittest/test.sh | 2 +- tests/jax/test_fused_attn.py | 18 +- tests/pytorch/fused_attn/test_fused_attn.py | 186 ++++++++++++------ .../fused_attn/test_fused_attn_with_cp.py | 2 + .../common/fused_attn/fused_attn.cpp | 49 +++-- .../fused_attn_f16_arbitrary_seqlen.cu | 6 +- transformer_engine/pytorch/attention.py | 31 ++- 8 files changed, 195 insertions(+), 101 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 936021bfed..cc5632eda7 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 936021bfed8c91dc416af1588b2c4eca631a9e45 +Subproject commit cc5632eda70bbdac34455c2d94066d27d10e2699 diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 17307574a9..61dd15d015 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -13,7 +13,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py pytest -v -s $TE_PATH/tests/pytorch/test_jit.py -NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py @@ -22,3 +21,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py +NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 759ea893ef..10da7486cf 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -170,8 +170,7 @@ def make_mask( max_seqlen_kv = inv_mask.shape[-1] 
inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type) inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape) - # In inv_swa_mask and inv_mask 0 is masked out - inv_mask = jnp.where(inv_mask != 0, inv_swa_mask, inv_mask) + inv_mask = combine_masks(inv_mask, inv_swa_mask) mask = jnp.logical_not(inv_mask) return mask @@ -315,6 +314,13 @@ def _get_max_segments_per_sequence(self): return self.num_segments_per_seq + 1 def _check_configs(self): + # TODO(rewang): Fix THD + PADDING_CAUSAL + SWA reference + if ( + self.qkv_layout.is_thd() + and self.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK + and self.window_size is not None + ): + pytest.skip("THD + PADDING_CAUSAL + SWA reference is not implemented.") # TODO(rewang): probably adds this in is_fused_attn_available if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding(): pytest.skip("THD format requires padding masks.") @@ -504,7 +510,13 @@ def generate_random_segment_ids( if self.qkv_layout.is_thd(): self.mask_for_customcall = None # THD format doesn't support mask else: - self.mask_for_customcall = self.mask + self.mask_for_customcall = make_mask( + self.segment_ids_q, + self.segment_ids_kv, + self.segment_pos_q, + self.segment_pos_kv, + self.attn_mask_type, + ) self.dropout_rng = dropout_key if self.dropout_prob > 0 else None self.scaling_factor = 1.0 / sqrt(self.head_dim) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index dea31b5971..588e6e4ecd 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -237,19 +237,18 @@ def test_dot_product_attention( tols = dict(atol=1.5e-2, rtol=1.5e-2) config = model_configs[model] is_mla = config.head_dim_qk != config.head_dim_v + is_mqa_gqa = config.num_heads != config.num_gqa_groups if qkv_layout is None: if config.attn_type == "self": - qkv_layout = "sb3hd" if not is_mla else "sbhd_sbhd_sbhd" + qkv_layout = "sb3hd" if not is_mla and not is_mqa_gqa else "sbhd_sbhd_sbhd" else: - qkv_layout = "bshd_bs2hd" if not is_mla else "bshd_bshd_bshd" + qkv_layout = "bshd_bs2hd" if not is_mla and not is_mqa_gqa else "bshd_bshd_bshd" if "3" in qkv_layout and config.attn_type == "cross": pytest.skip("No need to test this layout for cross attention") - # Test backend availability - window_size = (-1, -1) - if swa: - window_size = [2, 2] - config.window_size = check_set_window_size(config.attn_mask_type, window_size) + if config.window_size == (-1, -1) and swa: + config.window_size = [2, 2] + config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) available_backends, fused_attn_backends = _get_attention_backends( config, qkv_dtype=dtype, @@ -334,16 +333,16 @@ def test_dot_product_attention( is_training, ) - if unfused_attn_supported and fused_attn_supported: - logging.info("[test_dot_product_attention]: unfused attn vs fused attn") - torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) - for i, _ in enumerate(unfused_attn_bwd): - torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols) if unfused_attn_supported and flash_attn_supported: logging.info("[test_dot_product_attention]: unfused attn vs flash attn") torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols) for i, _ in enumerate(flash_attn_bwd): torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols) + if unfused_attn_supported and fused_attn_supported: + logging.info("[test_dot_product_attention]: unfused 
attn vs fused attn") + torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) + for i, _ in enumerate(unfused_attn_bwd): + torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols) if fused_attn_supported and flash_attn_supported: logging.info("[test_dot_product_attention]: fused attn vs flash attn") torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols) @@ -399,30 +398,41 @@ def test_dpa_mla(dtype, model_configs, model): model_configs_mask = { # test: b, h, hg, d, sq, skv, p, mask, bias - "mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"), - "mask_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"), - "mask_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "mask_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), - "mask_3_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"), - "mask_3_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), - "mask_4_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding", "no_bias"), - "mask_4_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), - "mask_5_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), - "mask_5_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding_causal", "no_bias"), - "mask_6_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), - "mask_7_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_7_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), - "mask_8_0": ModelConfig( - 2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + "mask_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), + "mask_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "mask_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), + "mask_2_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "mask_2_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "mask_2_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), + "mask_3_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), + "mask_3_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"), + "mask_3_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), + "mask_4_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "mask_4_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "mask_4_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), + "mask_5_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "mask_5_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "mask_5_2": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "mask_6_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"), + "mask_6_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"), + "mask_7_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), + "mask_7_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), + "mask_8_0": ModelConfig(2, 24, 24, 128, 1, 
2048, 0.0, "padding", "no_bias"), + "mask_8_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding", "no_bias"), + "mask_9_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding_causal", "no_bias"), + "mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding_causal", "no_bias"), + "mask_10_0": ModelConfig( + 2, 24, 24, 128, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias" ), - "mask_8_1": ModelConfig( - 1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + "mask_10_1": ModelConfig( + 2, 16, 16, 256, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias" ), - "mask_9_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"), - "mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"), - "mask_10_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_10_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), } @@ -531,20 +541,28 @@ def test_dpa_bias_shapes(dtype, model_configs, model): model_configs_swa = { # test: b, h, hg, d, sq, skv, p, mask, bias - "swa_1_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), - "swa_1_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), - "swa_2_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), - "swa_2_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), - "swa_3_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), - "swa_4_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "swa_4_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), - "swa_5_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "swa_5_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), - "swa_6_0": ModelConfig( - 4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" - ), + "swa_1_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), + "swa_1_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), + "swa_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), + "swa_2_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), + "swa_3_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "swa_3_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), + "swa_4_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), + "swa_4_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding", "no_bias"), + "swa_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), + "swa_5_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "swa_5_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "swa_5_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), "swa_6_1": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "swa_6_2": ModelConfig( + 2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "swa_6_3": ModelConfig( 2, 24, 24, 128, 2048, 4096, 
0.0, "padding_causal_bottom_right", "no_bias" ), } @@ -623,18 +641,57 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"] model_configs_layout_thd = { # test: b, h, hg, d, sq, skv, p, mask, bias - "layout_0_1": ModelConfig(3, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"), - "layout_0_2": ModelConfig(8, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"), - "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), - "layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), - "layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "layout_1_3": ModelConfig(1, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_1_4": ModelConfig(8, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"), - "layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), - "layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"), + "layout_0_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), + "layout_0_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"), + "layout_0_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), + "layout_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "layout_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "layout_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), + "layout_2_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "layout_2_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "layout_2_2": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "layout_3_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4) + ), + "layout_3_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4) + ), + "layout_3_2": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias", window_size=(4, 4) + ), + "layout_4_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + ), + "layout_4_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + ), + "layout_4_2": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + ), + "layout_5_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0) + ), + "layout_5_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0) + ), + "layout_5_2": ModelConfig( + 2, + 24, + 24, + 128, + 2048, + 4096, + 0.0, + "padding_causal_bottom_right", + "no_bias", + window_size=(4, 0), + ), } @@ -651,11 +708,13 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout): config = model_configs[model] if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: pytest.skip("qkv_layout not applicable for MQA/GQA") + logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = 
True") pad_between_seqs = True test_dot_product_attention( dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs ) if get_cudnn_version() >= (9, 3, 0): + logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = False") # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run pad_between_seqs = False test_dot_product_attention( @@ -695,9 +754,12 @@ def _run_dot_product_attention( ) seqlens_kv = seqlens_q if config.attn_type == "cross": - seqlens_q = torch.randint( - 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" - ) + if config.max_seqlen_q > 1: + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) + else: + seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda") seqlens_kv = torch.randint( 1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda" ) diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 1007d6aa34..fd8e543adc 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -121,6 +121,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") if dtype == "fp8" and get_device_compute_capability() < (9, 0): pytest.skip("FP8 attention is only supported on sm90+!") + if qkv_format == "thd" and get_cudnn_version() >= (9, 6, 0): + pytest.skip("THD format is not supported for cuDNN 9.6+!") config = model_configs_fused_attn[model] if qkv_format == "thd" and config.num_heads != config.num_gqa_groups: diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 9cde765401..32e6d4df8f 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -135,6 +135,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( !requires_64bit_ragged_offset) { flag_m512 = true; } + // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging if ( // architecture ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) || (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) && @@ -152,7 +153,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) && // bias type ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || - ((cudnn_runtime_version >= 8906) && + (cudnn_runtime_version >= 8906 && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || (bias_type == NVTE_Bias_Type::NVTE_ALIBI && attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK && @@ -161,43 +162,67 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && sm_arch_ >= 90) || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) || - ((cudnn_runtime_version >= 90000) && + (cudnn_runtime_version >= 90000 && (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) && // mask type + // pre-8.9.6: causal ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || - ((cudnn_runtime_version >= 8906) && + // 8.9.6: {bshd, sbhd} + {no_mask, causal, padding, padding_causal} + (cudnn_runtime_version >= 8906 && + (qkv_format == 
NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || - ((cudnn_runtime_version >= 90300) && - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + // 9.1: adds thd + {padding, padding_causal} + (cudnn_runtime_version >= 90100 && qkv_format == NVTE_QKV_Format::NVTE_THD && + (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || + // 9.3: adds {bshd, sbhd} + causal_bottom_right + self/cross-attn (sq <= skv) + (cudnn_runtime_version >= 90300 && (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && - max_seqlen_q <= max_seqlen_kv && dropout == 0.0)) && + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && + max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || + // 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv) + (cudnn_runtime_version >= 90600 && + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)) && // bias + mask combination (!(cudnn_runtime_version >= 8906 && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && // qkv format - ((qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) || + (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD || (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || - (cudnn_runtime_version >= 90600)))) && + cudnn_runtime_version >= 90600))) && // sliding window + // pre-9.2: full attn, causal ((cudnn_runtime_version < 90200 && window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || + // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd} (cudnn_runtime_version >= 90200 && ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && max_seqlen_q == max_seqlen_kv)) && - dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + max_seqlen_q <= max_seqlen_kv && dropout == 0.0 && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || - qkv_format == NVTE_QKV_Format::NVTE_SBHD))))) && + qkv_format == NVTE_QKV_Format::NVTE_SBHD)))) || + // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd} + (cudnn_runtime_version >= 90600 && + ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || + ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || + attn_mask_type == 
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && + max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + dropout == 0.0)))) && // check 64-bit ragged offset support (supported_ragged_offset_size)) { flag_arb = true; diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index b706eadace..cade624c8d 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -71,7 +71,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( is_bottom_right = false; } bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); bool is_dropout = (is_training && dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); @@ -451,7 +452,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( is_bottom_right = false; } bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); bool is_dropout = (dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index be0d176520..9268b9636e 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -602,6 +602,12 @@ def get_attention_backend( "Disabling FusedAttention as it does not support context parallelism with MLA" ) use_fused_attention = False + elif cudnn_version >= (9, 6, 0) and qkv_format == "thd": + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with THD for" + " cuDNN 9.6+" + ) + use_fused_attention = False # Filter: Attention mask # attn_mask_type | attention_mask | supported backends @@ -618,9 +624,7 @@ def get_attention_backend( # self-attention | | All # cross-attention | | FusedAttention, UnfusedDotProductAttention # causal_bottom_right | None | All - # padding_causal_bottom_right | Same as "padding" | - # self-attention | | All - # cross-attention | | FlashAttention, UnfusedDotProductAttention + # padding_causal_bottom_right | Same as "padding" | All # arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention # | [b, h, sq, skv] | if attn_mask_type == "arbitrary": @@ -697,29 +701,16 @@ def get_attention_backend( " for FP8" ) use_fused_attention = False - elif window_size[1] != 0 or attention_dropout != 0.0 or qkv_format == "thd": + elif window_size[1] != 0 or attention_dropout != 0.0: logger.debug( "Disabling FusedAttention as it only supports sliding window attention " - "with causal mask, no dropout, and qkv_format = bshd/sbhd" - ) - use_fused_attention = False - elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [ - "no_mask", - "padding", - "causal_bottom_right", - "padding_causal_bottom_right", - ]: - logger.debug( - "Disabling FusedAttention as it does not 
support sliding window attention " - "with attn_mask_type = %s for cross-attention", - attn_mask_type, + "with (left, 0) and no dropout" ) use_fused_attention = False - elif "padding" in attn_mask_type: + elif max_seqlen_q > max_seqlen_kv: logger.debug( "Disabling FusedAttention as it does not support sliding window attention " - "with attn_mask_type = %s", - attn_mask_type, + "with s_q > s_kv for cross-attention" ) use_fused_attention = False if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
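
Note (illustrative, not part of the patch): the mask semantics that the new cuDNN 9.6 path for padding + bottom-right causal + sliding window (left, 0) is expected to follow can be written out as a small NumPy reference. The helper name and the per-sequence actual-length convention below are assumptions for illustration only, not code from this PR.

    import numpy as np

    def build_swa_padding_brcm_mask(seqlen_q, seqlen_kv, max_seqlen_q, max_seqlen_kv, window_left):
        """True = masked out; one [max_seqlen_q, max_seqlen_kv] score matrix for a single sequence."""
        q = np.arange(max_seqlen_q)[:, None]
        k = np.arange(max_seqlen_kv)[None, :]
        diag = seqlen_kv - seqlen_q                   # bottom-right causal alignment (actual lengths)
        causal_br = k > q + diag                      # keys beyond the bottom-right diagonal
        swa = k < q + diag - window_left              # keys outside the (left, 0) window
        padding = (q >= seqlen_q) | (k >= seqlen_kv)  # padded query/key positions
        return causal_br | swa | padding

    mask = build_swa_padding_brcm_mask(seqlen_q=5, seqlen_kv=6, max_seqlen_q=8, max_seqlen_kv=8, window_left=2)
    print(mask.astype(int))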
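
Note (illustrative, not part of the patch): a minimal usage sketch of the newly enabled combination from the PyTorch side, assuming the existing transformer_engine.pytorch.DotProductAttention constructor/forward arguments (attn_mask_type, window_size, and attention_mask as a (pad_q, pad_kv) tuple of boolean tensors, True = masked) behave as in current releases; shapes use qkv_format="bshd" and are chosen to satisfy the new cuDNN 9.6 constraints (window_size[1] == 0, s_q <= s_kv, seqlens divisible by 64, no bias, no dropout).

    import torch
    import transformer_engine.pytorch as te

    b, sq, skv, h, d = 2, 2048, 4096, 16, 64
    dpa = te.DotProductAttention(
        num_attention_heads=h,
        kv_channels=d,
        qkv_format="bshd",
        attn_mask_type="padding_causal_bottom_right",  # new mask + SWA combination
        window_size=(4, 0),                            # sliding window (left, 0)
        attention_dropout=0.0,
    )
    q = torch.randn(b, sq, h, d, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(b, skv, h, d, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(b, skv, h, d, dtype=torch.bfloat16, device="cuda")
    # Per-sequence valid lengths drive the padding part of the mask.
    valid_q = torch.tensor([2048, 1024], device="cuda")
    valid_kv = torch.tensor([4096, 3000], device="cuda")
    pad_q = torch.arange(sq, device="cuda")[None, :] >= valid_q[:, None]
    pad_kv = torch.arange(skv, device="cuda")[None, :] >= valid_kv[:, None]
    out = dpa(q, k, v, attention_mask=(pad_q.view(b, 1, 1, sq), pad_kv.view(b, 1, 1, skv)))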