diff --git a/setup.py b/setup.py
index 4a344191de..5959c2b941 100644
--- a/setup.py
+++ b/setup.py
@@ -290,7 +290,7 @@ def add_unique(l: List[str], vals: Union[str, List[str]]) -> None:
 
 # Framework-specific requirements
 if "pytorch" in frameworks():
-    add_unique(install_reqs, ["torch", "flash-attn>=1.0.6"])
+    add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.0.4"])
     add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
 if "jax" in frameworks():
     if not found_pybind11():
diff --git a/tests/pytorch/test_fused_attn.py b/tests/pytorch/test_fused_attn.py
index 99a82eb6e1..57148d8846 100644
--- a/tests/pytorch/test_fused_attn.py
+++ b/tests/pytorch/test_fused_attn.py
@@ -17,6 +17,7 @@
 from pkg_resources import packaging
 from importlib.metadata import version
+from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states
 
 fp8_available, reason_for_no_fp8 = is_fp8_available()
 _flash_attn_version = packaging.version.Version(version("flash-attn"))
 _flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
@@ -58,29 +59,32 @@ def __init__(
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_dot_product_attention(dtype, bs, model):
+@pytest.mark.parametrize("ckpt_attn", [True, False])
+@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
+def test_dot_product_attention(dtype, bs, model, ckpt_attn, bias_type):
     """Test DotProductAttention module with three backends,
     FlashAttention, FusedAttention and UnfusedDotProductAttention"""
 
     config = model_configs[model]
 
-    flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
-            dtype, bs, config, "FlashAttention")
+    if bias_type == "no_bias":
+        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
+                dtype, bs, config, "FlashAttention", ckpt_attn, bias_type)
     fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
-            dtype, bs, config, "FusedAttention")
+            dtype, bs, config, "FusedAttention", ckpt_attn, bias_type)
     unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
-            dtype, bs, config, "UnfusedDotProductAttention")
+            dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type)
 
-    atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (2.5e-3, 2.5e-3)
-    assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
-    assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
+    atol, rtol = (2.5e-2, 2.5e-2) if dtype == torch.bfloat16 else (5e-3, 5e-3)
+    if bias_type == "no_bias":
+        assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
+        assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
     assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
     assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
 
-def _run_dot_product_attention(dtype, bs, config, backend):
+def _run_dot_product_attention(dtype, bs, config, backend, ckpt_attn, bias_type):
 
-    torch.manual_seed(1234)
-    torch.cuda.manual_seed(1234)
+    reset_rng_states()
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
     if backend == "FlashAttention":
@@ -88,7 +92,7 @@ def _run_dot_product_attention(dtype, bs, config, backend):
     if backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
 
-    inp = 0.1 * torch.randn(
+    inp = torch.randn(
            config.seq_len, bs, 3, config.num_attention_heads, config.head_dim,
            dtype = dtype).cuda()
     inp.requires_grad=True
@@ -96,9 +100,14 @@ def _run_dot_product_attention(dtype, bs, config, backend):
     seqlens.fill_(config.seq_len)
     cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
     cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
-    op_grad = 0.001 * torch.randint(0, 200, (
-        config.seq_len, bs, config.num_attention_heads * config.head_dim
-        ), dtype = dtype).cuda()
+    op_grad = torch.randn(
+            config.seq_len, bs, config.num_attention_heads * config.head_dim,
+            dtype = dtype).cuda()
+    if bias_type != "no_bias":
+        bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
+                dtype = dtype).cuda()
+    else:
+        bias = None
 
     block = (
         DotProductAttention(
@@ -108,7 +117,7 @@ def _run_dot_product_attention(dtype, bs, config, backend):
                 attn_mask_type = config.attn_mask_type,
                 sequence_parallel = False,
                 tp_size = 1,
-                get_rng_state_tracker = None,
+                get_rng_state_tracker = get_dummy_cuda_rng_tracker,
                 tp_group = None,
                 layer_number = 1,
                 attention_type = "self"
@@ -118,7 +127,10 @@ def _run_dot_product_attention(dtype, bs, config, backend):
     q = inp[:, :,0,:,:]
     k = inp[:, :,1,:,:]
     v = inp[:, :,2,:,:]
-    op = block(q, k, v)
+    op = block(q, k, v,
+            checkpoint_core_attention = ckpt_attn,
+            core_attention_bias_type = bias_type,
+            core_attention_bias = bias)
     op.backward(op_grad)
 
     return op, inp.grad
@@ -128,29 +140,32 @@ def _run_dot_product_attention(dtype, bs, config, backend):
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", model_configs.keys())
-def test_transformer_layer(dtype, bs, model):
+@pytest.mark.parametrize("ckpt_attn", [False])
+@pytest.mark.parametrize("bias_type", ["no_bias", "post_scale_bias"])
+def test_transformer_layer(dtype, bs, model, ckpt_attn, bias_type):
     """Test TransformerLayer module when its DotProductAttention is enabled with
     FlashAttention, FusedAttention, or UnfusedDotProductAttention backend"""
 
     config = model_configs[model]
 
-    flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
-            dtype, bs, config, "FlashAttention")
+    if bias_type == "no_bias":
+        flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
+                dtype, bs, config, "FlashAttention", ckpt_attn, bias_type)
     fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
-            dtype, bs, config, "FusedAttention")
+            dtype, bs, config, "FusedAttention", ckpt_attn, bias_type)
     unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
-            dtype, bs, config, "UnfusedDotProductAttention")
+            dtype, bs, config, "UnfusedDotProductAttention", ckpt_attn, bias_type)
 
-    atol, rtol = (5e-1, 5e-1) if dtype == torch.bfloat16 else (5e-1, 5e-1)
-    assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
-    assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
+    atol, rtol = (5e-1, 5e-2)
+    if bias_type == "no_bias":
+        assert torch.allclose(fused_attn_fwd, flash_attn_fwd, atol = atol, rtol = rtol)
+        assert torch.allclose(fused_attn_bwd, flash_attn_bwd, atol = atol, rtol = rtol)
     assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
     assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
 
-def _run_transformer_layer(dtype, bs, config, backend):
+def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type):
 
-    torch.manual_seed(1234)
-    torch.cuda.manual_seed(1234)
+    reset_rng_states()
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
     if backend == "FlashAttention":
@@ -158,7 +173,7 @@ def _run_transformer_layer(dtype, bs, config, backend):
     if backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
 
-    inp = 0.1 * torch.randn(
+    inp = torch.randn(
            config.seq_len, bs, config.num_attention_heads * config.head_dim,
            dtype = dtype).cuda()
     inp.requires_grad=True
@@ -166,9 +181,9 @@ def _run_transformer_layer(dtype, bs, config, backend):
     seqlens.fill_(config.seq_len)
     cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
     cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
-    op_grad = 0.001 * torch.randint(0, 200, (
-        config.seq_len, bs, config.num_attention_heads * config.head_dim
-        ), dtype = dtype).cuda()
+    op_grad = torch.randn(
+            config.seq_len, bs, config.num_attention_heads * config.head_dim,
+            dtype = dtype).cuda()
 
     sigma = 0.02
     init_method = init_method_normal(sigma)
@@ -178,6 +193,11 @@ def _run_transformer_layer(dtype, bs, config, backend):
     drop_path_rate = 0.0
     drop_path_rates = [
             rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
+    if bias_type != "no_bias":
+        bias = torch.randn(1, config.num_attention_heads, config.seq_len, config.seq_len,
+                dtype = dtype).cuda()
+    else:
+        bias = None
 
     block = (
         TransformerLayer(
@@ -215,8 +235,13 @@ def _run_transformer_layer(dtype, bs, config, backend):
         .cuda()
     )
 
-    op = block(inp)
-    op.backward(op_grad)
+    num_iters = 10
+    for i in range(num_iters):
+        op = block(inp,
+            checkpoint_core_attention = ckpt_attn,
+            core_attention_bias_type = bias_type,
+            core_attention_bias = bias)
+        op.backward(op_grad)
 
     return op, inp.grad
@@ -246,19 +271,18 @@ def find_factors(x):
     unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer_gqa(
             dtype, bs, config, "UnfusedDotProductAttention", num_q_per_gqa_group)
 
-    atol, rtol = 5e-1, 5e-1
+    atol, rtol = 5e-1, 5e-2
     assert torch.allclose(flash_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
     assert torch.allclose(flash_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
 
 def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group):
 
-    torch.manual_seed(1234)
-    torch.cuda.manual_seed(1234)
+    reset_rng_states()
     os.environ["NVTE_FLASH_ATTN"] = "0"
     if backend == "FlashAttention":
         os.environ["NVTE_FLASH_ATTN"] = "1"
 
-    inp = 0.1 * torch.randn(
+    inp = torch.randn(
            config.seq_len, bs, config.num_attention_heads * config.head_dim,
            dtype = dtype).cuda()
     inp.requires_grad=True
@@ -266,9 +290,9 @@ def _run_transformer_layer_gqa(dtype, bs, config, backend, num_querys_per_gqa_group):
     seqlens.fill_(config.seq_len)
     cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
     cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
-    op_grad = 0.001 * torch.randint(0, 200, (
-        config.seq_len, bs, config.num_attention_heads * config.head_dim
-        ), dtype = dtype).cuda()
+    op_grad = torch.randn(
+            config.seq_len, bs, config.num_attention_heads * config.head_dim,
+            dtype = dtype).cuda()
 
     sigma = 0.02
     init_method = init_method_normal(sigma)
@@ -342,14 +366,13 @@ def test_dpa_fp8(dtype, bs, model):
     unfused_attn_fwd, unfused_attn_bwd = _run_dpa_fp8_ref(
             dtype, bs, config, "UnfusedDotProductAttention")
 
-    atol, rtol = (5e-2, 1e-1)
+    atol, rtol = (2.5e-2, 2.5e-2)
     assert torch.allclose(fused_attn_fwd, unfused_attn_fwd, atol = atol, rtol = rtol)
     assert torch.allclose(fused_attn_bwd, unfused_attn_bwd, atol = atol, rtol = rtol)
 
 def _run_dpa_fp8(dtype, bs, config, backend):
 
-    torch.manual_seed(1234)
-    torch.cuda.manual_seed(1234)
+    reset_rng_states()
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
@@ -361,9 +384,9 @@ def _run_dpa_fp8(dtype, bs, config, backend):
     seqlens.fill_(config.seq_len)
     cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
     cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
-    op_grad = 0.001 * torch.randint(0, 200, (
-        bs * config.seq_len, config.num_attention_heads * config.head_dim
-        ), dtype = dtype).cuda()
+    op_grad = 0.01 * torch.randn(
+            bs * config.seq_len, config.num_attention_heads * config.head_dim,
+            dtype = dtype).cuda()
     torch.save(op_grad, 'op_grad.pt')
 
     fp8_recipe = recipe.DelayedScaling(
diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index dec8d9e107..6260c291c4 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -25,7 +25,6 @@
 )
 from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
 
-
 seed = 1234
 rng_str = "rng_state"
 torch.manual_seed(seed)
diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp
index 25f62cad09..957c0b4735 100644
--- a/transformer_engine/common/fused_attn/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn/fused_attn.cpp
@@ -32,9 +32,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
           && (max_seqlen_q <= 512)
           && (head_dim == 64)
           && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
-          && (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)
+          && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
           && (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) {
+#if (CUDNN_VERSION >= 8900)
     backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
+#else
+    backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
+    std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+."
+                 " Please upgrade your cuDNN version if possible." << std::endl;
+#endif
   } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) {
     bool flag_m512 = false;
     bool flag_arb = false;
@@ -76,6 +82,20 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
                 NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen))) {
       backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
     }
+#if (CUDNN_VERSION < 8901)
+    if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
+      backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
+      std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+."
+                   " Please upgrade your cuDNN version if possible." << std::endl;
+    }
+#endif
+#if (CUDNN_VERSION < 8900)
+    if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
+      backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
+      std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.0+."
+                   " Please upgrade your cuDNN version if possible." << std::endl;
+    }
+#endif
   } else {
     backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
   }
diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h
index 447b1f9d6a..b71573ec1b 100644
--- a/transformer_engine/common/include/transformer_engine/fused_attn.h
+++ b/transformer_engine/common/include/transformer_engine/fused_attn.h
@@ -136,10 +136,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
  *
  * Support Matrix:
    \verbatim
-  | backend | precision |    qkv layout   |           bias          |      mask      | dropout | sequence length | head_dim |
-  |    0    | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL |   Yes   |     <= 512      |    64    |
-  |    1    | FP16/BF16 | QKV_INTERLEAVED |         NO_BIAS         |     CAUSAL     |   Yes   |      > 512      | 64, 128  |
-  |    2    |    FP8    | QKV_INTERLEAVED |         NO_BIAS         |    PADDING     |   Yes   |     <= 512      |    64    |
+  | backend | precision |    qkv layout   |        bias        |          mask          | dropout | sequence length | head_dim |
+  |    0    | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK |   Yes   |     <= 512      |    64    |
+  |    1    | FP16/BF16 | QKV_INTERLEAVED |       NO_BIAS      |       CAUSAL_MASK      |   Yes   |      > 512      | 64, 128  |
+  |    2    |    FP8    | QKV_INTERLEAVED |       NO_BIAS      |      PADDING_MASK      |   Yes   |     <= 512      |    64    |
    \endverbatim
  *
  * \param[in] QKV The QKV tensor in packed format,
@@ -181,10 +181,10 @@ void nvte_fused_attn_fwd_qkvpacked(
  *
  * Support Matrix:
    \verbatim
-  | backend | precision |    qkv layout   |           bias          |      mask      | dropout | sequence length | head_dim |
-  |    0    | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL |   Yes   |     <= 512      |    64    |
-  |    1    | FP16/BF16 | QKV_INTERLEAVED |         NO_BIAS         |     CAUSAL     |   Yes   |      > 512      | 64, 128  |
-  |    2    |    FP8    | QKV_INTERLEAVED |         NO_BIAS         |    PADDING     |   Yes   |     <= 512      |    64    |
+  | backend | precision |    qkv layout   |        bias        |          mask          | dropout | sequence length | head_dim |
+  |    0    | FP16/BF16 | QKV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK |   Yes   |     <= 512      |    64    |
+  |    1    | FP16/BF16 | QKV_INTERLEAVED |       NO_BIAS      |       CAUSAL_MASK      |   Yes   |      > 512      | 64, 128  |
+  |    2    |    FP8    | QKV_INTERLEAVED |       NO_BIAS      |      PADDING_MASK      |   Yes   |     <= 512      |    64    |
    \endverbatim
  *
  * \param[in] QKV The QKV tensor in packed format,
@@ -235,8 +235,8 @@ void nvte_fused_attn_bwd_qkvpacked(
  *
  * Support Matrix:
    \verbatim
-  | backend | precision |    qkv layout   |           bias          |      mask      | dropout | sequence length | head_dim |
-  |    0    | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL |   Yes   |     <= 512      |    64    |
+  | backend | precision |    qkv layout   |        bias        |          mask          | dropout | sequence length | head_dim |
+  |    0    | FP16/BF16 |  KV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK |   Yes   |     <= 512      |    64    |
    \endverbatim
  *
  * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
@@ -283,8 +283,8 @@ void nvte_fused_attn_fwd_kvpacked(
  *
  * Support Matrix:
    \verbatim
-  | backend | precision |    qkv layout   |           bias          |      mask      | dropout | sequence length | head_dim |
-  |    0    | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL |   Yes   |     <= 512      |    64    |
+  | backend | precision |    qkv layout   |        bias        |          mask          | dropout | sequence length | head_dim |
+  |    0    | FP16/BF16 |  KV_INTERLEAVED | NO/POST_SCALE_BIAS | PADDING/CAUSAL/NO_MASK |   Yes   |     <= 512      |    64    |
    \endverbatim
  *
  * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 56040fc490..58a4279067 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -8,7 +8,7 @@
 import math
 from importlib.metadata import version
 from contextlib import nullcontext
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Tuple, Union, Dict
 from pkg_resources import packaging
 
 import torch
@@ -34,6 +34,7 @@
 from transformer_engine.pytorch.constants import (
     AttnMaskTypes,
     AttnTypes,
+    AttnBiasTypes,
     dist_group_type,
     TE_DType,
 )
@@ -227,6 +228,8 @@ def forward(
         key_layer: torch.Tensor,
         value_layer: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
+        core_attention_bias_type: str = "no_bias",
+        core_attention_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """core attention fprop"""
         batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
@@ -275,13 +278,42 @@ def forward(
             scale *= self.layer_number
 
         # Raw attention scores. [b * np, sq, sk]
-        matmul_result = torch.baddbmm(
-            matmul_result,
-            query_layer.transpose(0, 1),  # [b * np, sq, hn]
-            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
-            beta=0.0,
-            alpha=(1.0 / scale),
-        )
+        if core_attention_bias_type == "no_bias":
+            matmul_result = torch.baddbmm(
+                matmul_result,
+                query_layer.transpose(0, 1),  # [b * np, sq, hn]
+                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
+                beta=0.0,
+                alpha=(1.0 / scale),
+            )
+
+        elif core_attention_bias_type == "pre_scale_bias":
+            assert core_attention_bias is not None, "core_attention_bias should not be None!"
+            assert (core_attention_bias.shape == torch.Size([1, *output_size[1:]])
+                    ), "core_attention_bias must be in [1, h, sq, skv] shape!"
+            matmul_result = torch.bmm(
+                query_layer.transpose(0, 1),  # [b * np, sq, hn]
+                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
+            )
+            matmul_result = (matmul_result.view(
+                output_size[0], output_size[1], output_size[2], output_size[3])
+                + core_attention_bias).view(-1, output_size[2], output_size[3])
+            matmul_result /= scale
+
+        elif core_attention_bias_type == "post_scale_bias":
+            assert core_attention_bias is not None, "core_attention_bias should not be None!"
+            assert (core_attention_bias.shape == torch.Size([1, *output_size[1:]])
+                    ), "core_attention_bias must be in [1, h, sq, skv] shape!"
+            matmul_result = torch.baddbmm(
+                matmul_result,
+                query_layer.transpose(0, 1),  # [b * np, sq, hn]
+                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
+                beta=0.0,
+                alpha=(1.0 / scale),
+            )
+            matmul_result = (matmul_result.view(
+                output_size[0], output_size[1], output_size[2], output_size[3])
+                + core_attention_bias).view(-1, output_size[2], output_size[3])
 
         # change view to [b, np, sq, sk]
         attention_scores = matmul_result.view(*output_size)
@@ -689,13 +721,17 @@ def forward(
         query_layer: torch.Tensor,
         key_layer: torch.Tensor,
         value_layer: torch.Tensor,
-        fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
+        fused_attention_backend:
+            tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
         core_attention_bias_type: str = "no_bias",
         core_attention_bias: Optional[torch.Tensor] = None,
         fast_zero_fill: bool = True,
     ) -> torch.Tensor:
         """fused attention fprop"""
+        assert (fused_attention_backend
+                != tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend
+                ), 'No fused attention backend supports this input combination!'
         assert (
             (query_layer.dtype in [torch.float16, torch.bfloat16])
             and (key_layer.dtype in [torch.float16, torch.bfloat16])
@@ -865,7 +901,7 @@ class DotProductAttention(torch.nn.Module):
                     is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.0
                      dropout probability for the dropout op during multi-head attention.
-   attn_mask_type: {'causal', 'padding'}, default = `causal`
+   attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
                   type of attention mask passed into softmax operation.
    layer_number: int, default = `None`
                 layer number of the current `DotProductAttention` when multiple such modules
@@ -964,11 +1000,12 @@ def _checkpointed_attention_forward(
         self,
         attention_func: Callable,
         *forward_args: Tuple[torch.Tensor, ...],
+        **forward_kwargs: Dict[str, Any],
     ) -> torch.Tensor:
         """Forward method with activation checkpointing."""
 
-        def custom_forward(*inputs):
-            return attention_func(*inputs)
+        def custom_forward(*input_args, **input_kwargs):
+            return attention_func(*input_args, **input_kwargs)
 
         hidden_states = checkpoint(
             custom_forward,
@@ -976,6 +1013,7 @@ def custom_forward(*inputs):
             self.get_rng_state_tracker,
             self.tp_group,
             *forward_args,
+            **forward_kwargs,
         )
 
         return hidden_states
@@ -1067,33 +1105,38 @@ def forward(
         use_flash_attention = False
         use_fused_attention = False
 
+        if core_attention_bias_type != "no_bias" or core_attention_bias is not None:
+            use_flash_attention = False
+
         if is_in_onnx_export_mode():
             use_flash_attention = False
             use_fused_attention = False
 
         qkv_layout = "qkv_interleaved" if self.attention_type == "self" else "kv_interleaved"
-        fused_attention_backend = tex.get_fused_attn_backend(
-            TE_DType[query_layer.dtype],
-            TE_DType[key_layer.dtype],
-            QKVLayout[qkv_layout],
-            AttnBiasType[core_attention_bias_type],
-            AttnMaskType[self.attn_mask_type],
-            self.attention_dropout,
-            query_layer.shape[0], key_layer.shape[0],
-            query_layer.shape[-1])
-        # DPA does not support FP8; for FP8, use cpp_extensions modules directly
-        is_backend_avail = (fused_attention_backend in
-            [FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]])
-        use_fused_attention = (use_fused_attention
-                              and is_backend_avail
-                              and self.num_gqa_groups == self.num_attention_heads)
-        if (self.deterministic
-            and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]):
-            use_fused_attention = False
-            warnings.warn(
-                "Disabling usage of FusedAttention since this FusedAttention"
-                "backend does not support deterministic execution."
-            )
+
+        if use_fused_attention:
+            fused_attention_backend = tex.get_fused_attn_backend(
+                TE_DType[query_layer.dtype],
+                TE_DType[key_layer.dtype],
+                QKVLayout[qkv_layout],
+                AttnBiasType[core_attention_bias_type],
+                AttnMaskType[self.attn_mask_type],
+                self.attention_dropout,
+                query_layer.shape[0], key_layer.shape[0],
+                query_layer.shape[-1])
+            # DPA does not support FP8; for FP8, use cpp_extensions modules directly
+            is_backend_avail = (fused_attention_backend in
+                [FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]])
+            use_fused_attention = (use_fused_attention
+                                  and is_backend_avail
+                                  and self.num_gqa_groups == self.num_attention_heads)
+            if (self.deterministic
+                and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]):
+                use_fused_attention = False
+                warnings.warn(
+                    "Disabling usage of FusedAttention since the FusedAttention "
+                    "backend does not support deterministic execution."
+                )
 
         if use_flash_attention:
             if checkpoint_core_attention:
@@ -1106,18 +1149,18 @@ def forward(
 
         if use_fused_attention:
             if checkpoint_core_attention:
                 return self._checkpointed_attention_forward(self.fused_attention,
-                                                            query_layer,
-                                                            key_layer,
-                                                            value_layer,
-                                                            fused_attention_backend,
-                                                            core_attention_bias_type,
-                                                            core_attention_bias,
-                                                            fast_zero_fill)
+                                              query_layer,
+                                              key_layer,
+                                              value_layer,
+                                              fused_attention_backend = fused_attention_backend,
+                                              core_attention_bias_type = core_attention_bias_type,
+                                              core_attention_bias = core_attention_bias,
+                                              fast_zero_fill = fast_zero_fill)
             return self.fused_attention(query_layer, key_layer, value_layer,
-                                        fused_attention_backend,
-                                        core_attention_bias_type,
-                                        core_attention_bias,
-                                        fast_zero_fill)
+                                        fused_attention_backend = fused_attention_backend,
+                                        core_attention_bias_type = core_attention_bias_type,
+                                        core_attention_bias = core_attention_bias,
+                                        fast_zero_fill = fast_zero_fill)
 
         if checkpoint_core_attention:
             return self._checkpointed_attention_forward(
@@ -1125,9 +1168,17 @@
                 query_layer,
                 key_layer,
                 value_layer,
-                attention_mask,
+                attention_mask = attention_mask,
+                core_attention_bias_type = core_attention_bias_type,
+                core_attention_bias = core_attention_bias,
             )
-        return self.unfused_attention(query_layer, key_layer, value_layer, attention_mask)
+        return self.unfused_attention(query_layer,
+                    key_layer,
+                    value_layer,
+                    attention_mask = attention_mask,
+                    core_attention_bias_type = core_attention_bias_type,
+                    core_attention_bias = core_attention_bias,
+        )
 
 
 class MultiHeadAttention(torch.nn.Module):
@@ -1350,6 +1401,8 @@ def forward(
                 attention_mask.dtype == torch.bool
             ), "Attention mask must be a boolean tensor"
 
+        assert (core_attention_bias_type in AttnBiasTypes
+                ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
         # =================================================
         # Pre-allocate memory for key-values for inference.
         # =================================================
diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py
index 8d109026fb..ee43fa10d9 100644
--- a/transformer_engine/pytorch/constants.py
+++ b/transformer_engine/pytorch/constants.py
@@ -26,6 +26,8 @@
 
 AttnTypes = ("self", "cross")
 
+AttnBiasTypes = ("pre_scale_bias", "post_scale_bias", "no_bias")
+
 LayerTypes = ("encoder", "decoder")
 
 GemmParallelModes = ("row", "column", None)
diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py
index 35a1fa72f3..dd6fb3e2f8 100644
--- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py
+++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py
@@ -254,7 +254,7 @@ def fused_attn_fwd_qkvpacked(
     if attn_bias_type != "no_bias":
         assert (attn_bias is not None
                 ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias."
-        assert (attn_bias.shape == [1, h, max_seqlen, max_seqlen]
+        assert (attn_bias.shape == torch.Size([1, h, max_seqlen, max_seqlen])
                 ), "attn_bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
         assert (attn_bias.dtype == qkv.dtype
                 ), "attn_bias tensor must be in the same dtype as qkv."
@@ -599,7 +599,7 @@ def fused_attn_fwd_kvpacked(
     if attn_bias_type != "no_bias":
         assert (attn_bias is not None
                 ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias."
-        assert (attn_bias.shape == [1, h, max_seqlen_q, max_seqlen_kv]
+        assert (attn_bias.shape == torch.Size([1, h, max_seqlen_q, max_seqlen_kv])
                 ), "attn_bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
         assert (attn_bias.dtype == q.dtype
                 ), "attn_bias tensor must be in the same dtype as q and kv."
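
Usage sketch (not part of the patch): a minimal example of how the new bias arguments exposed by this change might be exercised end to end, mirroring the updated test_fused_attn.py. The shapes, dtype, and config values below are illustrative assumptions; only the keyword names and the expected [1, num_heads, seq_q, seq_kv] bias shape come from the diff, and the constructor/forward signatures are assumed to match the tests above.

# A minimal sketch, assuming a CUDA build of transformer_engine with this patch applied.
import torch
from transformer_engine.pytorch import DotProductAttention

seq_len, batch, heads, head_dim = 128, 2, 16, 64   # hypothetical sizes
dtype = torch.bfloat16

# Packed QKV input in [seq, batch, 3, heads, head_dim] layout, as in the tests.
qkv = torch.randn(seq_len, batch, 3, heads, head_dim,
                  dtype=dtype, device="cuda", requires_grad=True)
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

# Post-scale bias is expected in [1, num_heads, seq_q, seq_kv] shape.
bias = torch.randn(1, heads, seq_len, seq_len, dtype=dtype, device="cuda")

block = DotProductAttention(heads, head_dim, attention_dropout=0.0)
out = block(q, k, v,
            checkpoint_core_attention=False,
            core_attention_bias_type="post_scale_bias",
            core_attention_bias=bias)
out.backward(torch.randn_like(out))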
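For reference, the two bias modes handled by the new UnfusedDotProductAttention branches differ only in whether the bias is added before or after the 1/scale factor is applied to the raw Q*K^T scores. A short sketch of that semantics follows; the helper and tensor names are hypothetical, not part of the library.

import torch

def unfused_bias_scores(q, k, bias, scale, bias_type):
    # q: [b*np, sq, hn], k: [b*np, sk, hn], bias: [1, np, sq, sk]
    logits = torch.bmm(q, k.transpose(1, 2))          # raw scores, [b*np, sq, sk]
    _, sq, sk = logits.shape
    np_ = bias.shape[1]
    logits = logits.view(-1, np_, sq, sk)             # [b, np, sq, sk] for broadcasting
    if bias_type == "pre_scale_bias":
        logits = (logits + bias) / scale              # bias added, then scaled
    elif bias_type == "post_scale_bias":
        logits = logits / scale + bias                # scaled, then bias added
    else:                                             # "no_bias"
        logits = logits / scale
    return logits.view(-1, sq, sk)                    # back to [b*np, sq, sk]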