diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 73dfa23d9a..afc2081752 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -77,12 +77,13 @@ def __init__( batch_size: int, num_heads: int, num_gqa_groups: int, - head_dim: int, + head_dim_qk: int, max_seqlen_q: int, max_seqlen_kv: int, dropout_p: float, attn_mask_type: str, attn_bias_type: str, + head_dim_v: int = None, alibi_type: str = "none", num_layers: int = 1, bias_shape: str = "1hss", @@ -91,9 +92,10 @@ def __init__( self.batch_size = batch_size self.num_heads = num_heads self.num_gqa_groups = num_gqa_groups - self.head_dim = head_dim - self.hidden_size = num_heads * head_dim - self.hidden_size_kv = num_gqa_groups * head_dim + self.head_dim_qk = head_dim_qk + self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v + self.hidden_size = num_heads * head_dim_qk + self.hidden_size_kv = num_gqa_groups * self.head_dim_v self.max_seqlen_q = max_seqlen_q self.max_seqlen_kv = max_seqlen_kv self.dropout_p = dropout_p @@ -137,7 +139,11 @@ def _get_attention_backends( ) core_attention_bias_requires_grad = False # d=256 is supported by cuDNN 9.0+ for inference but not training - if config.attn_bias_type == "post_scale_bias" and config.head_dim <= 128: + if ( + config.attn_bias_type == "post_scale_bias" + and config.head_dim_qk <= 128 + and config.head_dim_v <= 128 + ): core_attention_bias_requires_grad = True fused_attn_backends = [] @@ -153,7 +159,8 @@ def test(): num_gqa_groups=config.num_gqa_groups, max_seqlen_q=config.max_seqlen_q, max_seqlen_kv=config.max_seqlen_kv, - head_dim=config.head_dim, + head_dim_qk=config.head_dim_qk, + head_dim_v=config.head_dim_v, attn_mask_type=config.attn_mask_type, window_size=window_size, alibi_slopes_shape=alibi_slopes_shape, @@ -218,11 +225,12 @@ def test_dot_product_attention( if dtype == torch.bfloat16: tols = dict(atol=2.5e-2, rtol=2.5e-2) config = model_configs[model] + is_mla = config.head_dim_qk != config.head_dim_v if qkv_layout is None: if config.attn_type == "self": - qkv_layout = "sb3hd" + qkv_layout = "sb3hd" if not is_mla else "sbhd_sbhd_sbhd" else: - qkv_layout = "sbhd_sb2hd" + qkv_layout = "bshd_bs2hd" if not is_mla else "bshd_bshd_bshd" if "3" in qkv_layout and config.attn_type == "cross": pytest.skip("No need to test this layout for cross attention") @@ -241,14 +249,17 @@ def test_dot_product_attention( flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention # mannually pads and unpads the input and output of FlashAttention for testing purposes - if pad_between_seqs: + if pad_between_seqs and not ( + config.max_seqlen_q != config.max_seqlen_kv + and config.attn_mask_type in ["causal", "padding_causal"] + ): flash_attn_supported = True # Skip if only unfused backend is supported if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: pytest.skip("Less than two backends to compare.") - is_training = config.head_dim <= 128 + is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128 # UnfusedDotProductAttention backend if unfused_attn_supported: unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention( @@ -343,6 +354,38 @@ def test_dpa_checkpoint(dtype, model_configs, model): test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False) +model_configs_mla = { + # test: b, h, hg, dqk, sq, 
skv, p, mask, bias # attn , backend + "mla_1_0": ModelConfig( + 8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128 + ), # self , 0 + "mla_1_1": ModelConfig( + 4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128 + ), # cross, 0 + "mla_2_0": ModelConfig( + 2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64 + ), # self , 1 + "mla_2_1": ModelConfig( + 1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64 + ), # cross, 1 + "mla_3_0": ModelConfig( + 8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64 + ), # inference + "mla_3_1": ModelConfig( + 8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128 + ), # inference +} + + +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("model_configs", [model_configs_mla]) +@pytest.mark.parametrize("model", model_configs_mla.keys()) +def test_dpa_mla(dtype, model_configs, model): + """Test DotProductAttention module with Multi-Latent Attention (MLA)""" + test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False) + + model_configs_mask = { # test: b, h, hg, d, sq, skv, p, mask, bias "mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"), @@ -586,14 +629,16 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): @pytest.mark.parametrize("qkv_layout", qkv_layouts_thd) def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout): """Test DotProductAttention module with different QKV layouts""" - pad_between_seqs = False - test_dot_product_attention( - dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs - ) pad_between_seqs = True test_dot_product_attention( dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs ) + if get_cudnn_version() >= (9, 3, 0): + # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run + pad_between_seqs = False + test_dot_product_attention( + dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs + ) def _run_dot_product_attention( @@ -736,7 +781,8 @@ def _run_dot_product_attention( "skv": config.max_seqlen_kv, "h": config.num_heads, "hg": config.num_gqa_groups, - "d": config.head_dim, + "dqk": config.head_dim_qk, + "dv": config.head_dim_v, "t": cu_seqlens_q_after_pad[-1], "tg": cu_seqlens_kv_after_pad[-1], "3": 3, @@ -753,12 +799,16 @@ def _run_dot_product_attention( layout = layout.replace("s", "skv") layout = layout.replace("h", "hg") layout = layout.replace("t", "tg") + if i == 2: + layout = layout.replace("d", "dv") + else: + layout = layout.replace("d", "dqk") tensor_shape = [dim_to_num[j] for j in layout.split("_")] tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda") tensor_orig = tensor if qkv_format == "thd" and pad_between_seqs: tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) - if layout in ["t_h_d", "t_3_h_d", "t_h_3_d"]: + if layout in ["t_h_dqk", "t_3_h_dqk", "t_h_3_dqk"]: for i in range(1, config.batch_size + 1): valid_range = ( cu_seqlens_q_after_pad[i - 1], @@ -772,7 +822,7 @@ def _run_dot_product_attention( tensor_orig = torch.cat( [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0 ) - if layout in ["tg_hg_d", "tg_2_hg_d", "tg_hg_2_d"]: + if layout in ["tg_hg_dqk", "tg_2_hg_dqk", "tg_hg_2_dqk", "tg_hg_dv"]: for i in range(1, config.batch_size + 1): valid_range = ( cu_seqlens_kv_after_pad[i - 1], @@ -811,13 +861,14 @@ def 
_run_dot_product_attention( # Create output gradient qkv_format_kv = "_".join(qkv_format) qkv_format_kv = qkv_format_kv.replace("s", "sq") + qkv_format_kv = qkv_format_kv.replace("d", "dv") out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")] out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]] out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda") out_grad_orig = out_grad if qkv_format == "thd" and pad_between_seqs: out_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) - if qkv_format_kv == "t_h_d": + if qkv_format_kv == "t_h_dv": for i in range(1, config.batch_size + 1): valid_range = ( cu_seqlens_q_after_pad[i - 1], @@ -851,7 +902,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: # Set up model block = DotProductAttention( config.num_heads, - config.head_dim, + config.head_dim_qk, num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, qkv_format=qkv_format, @@ -906,9 +957,10 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: if backend == "FusedAttention": if qkv_format == "thd" and pad_between_seqs: out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) - q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) - k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) - v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) + if is_training: + q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) + k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) + v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) for i in range(1, config.batch_size + 1): valid_range_q = ( cu_seqlens_q_after_pad[i - 1], @@ -919,15 +971,16 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: cu_seqlens_kv_after_pad[i] - pad_len[i - 1], ) out_orig = torch.cat([out_orig, out[valid_range_q[0] : valid_range_q[1]]], dim=0) - q_grad_orig = torch.cat( - [q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0 - ) - k_grad_orig = torch.cat( - [k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 - ) - v_grad_orig = torch.cat( - [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 - ) + if is_training: + q_grad_orig = torch.cat( + [q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0 + ) + k_grad_orig = torch.cat( + [k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 + ) + v_grad_orig = torch.cat( + [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 + ) if is_training: return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig) else: @@ -1168,7 +1221,7 @@ def _run_transformer_layer( # Create RoPE rotary_pos_emb = None if RoPE: - PE = RotaryPositionEmbedding(dim=config.head_dim) + PE = RotaryPositionEmbedding(dim=config.head_dim_qk) rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda") # Set up model @@ -1183,7 +1236,7 @@ def _run_transformer_layer( init_method=init_method, output_layer_init_method=output_layer_init_method, layer_number=layer_number, - kv_channels=config.head_dim, + kv_channels=config.head_dim_qk, self_attn_mask_type=config.attn_mask_type, tp_group=None, tp_size=1, @@ -1356,7 +1409,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: mha = MultiheadAttention( hidden_size=config.hidden_size, num_attention_heads=config.num_heads, - kv_channels=config.head_dim, + kv_channels=config.head_dim_qk, num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, layer_number=1, @@ -1387,7 +1440,7 @@ def 
get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: "skv": config.max_seqlen_kv, "h": config.num_heads, "hg": config.num_gqa_groups, - "d": config.head_dim, + "d": config.head_dim_qk, "t": cu_seqlens_q[-1], "tg": cu_seqlens_kv[-1], "3": 3, @@ -1531,7 +1584,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: with fp8_model_init(enabled=fp8_dpa): dpa = DotProductAttention( config.num_heads, - config.head_dim, + config.head_dim_qk, num_gqa_groups=config.num_gqa_groups, attention_dropout=config.dropout_p, sequence_parallel=False, @@ -1560,7 +1613,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: "skv": config.max_seqlen_kv, "h": config.num_heads, "hg": config.num_gqa_groups, - "d": config.head_dim, + "d": config.head_dim_qk, "t": cu_seqlens_q[-1], "tg": cu_seqlens_kv[-1], "3": 3, @@ -1732,7 +1785,7 @@ def _run_custom_mha_fp8(dtype, config, backend): inp = 0.0001 * torch.randint( -100, 100, - (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim), + (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim_qk), dtype=dtype, device="cuda", requires_grad=True, @@ -1743,7 +1796,7 @@ def _run_custom_mha_fp8(dtype, config, backend): out_grad = 0.01 * torch.randn( config.batch_size * config.max_seqlen_q, - config.num_heads * config.head_dim, + config.num_heads * config.head_dim_qk, dtype=dtype, device="cuda", ) @@ -1766,7 +1819,7 @@ def _run_custom_mha_fp8(dtype, config, backend): return ( out.view(config.batch_size, config.max_seqlen_q, -1), dqkv.view( - config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim + config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk ).contiguous(), ) @@ -1809,7 +1862,7 @@ def get_dummy_cuda_rng_tracker(): block = DotProductAttention( config.num_heads, - config.head_dim, + config.head_dim_qk, attention_dropout=config.dropout_p, sequence_parallel=False, tp_size=1, @@ -2105,7 +2158,7 @@ def __init__(self, config, params_dtype: torch.dtype = torch.float32): self.p_dropout = config.dropout_p self.h = config.num_heads self.hidden_size = config.hidden_size - self.head_dim = config.head_dim + self.head_dim = config.head_dim_qk self.fast_zero_fill = True self.mask_type = config.attn_mask_type diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index bdc459cdcc..e8361a2190 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -1083,7 +1083,7 @@ def test_export_core_attention( model = te.attention.DotProductAttention( num_attention_heads=num_attention_heads, - kv_channels=kv_channels, + k_channels=kv_channels, attention_dropout=0.5, qkv_format=qkv_format, attn_mask_type=attn_mask_type, diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 895baea789..0fe62f8cb4 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -72,8 +72,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, - int64_t window_size_right) { + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, + int64_t window_size_left, 
int64_t window_size_right) { using namespace transformer_engine; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); @@ -84,10 +84,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)) && (sm_arch_ >= 90) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) && (((cudnn_runtime_version >= 8900) && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) && - (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim == 64) && - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) || + (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim_qk == 64) && + (head_dim_v == 64) && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) || ((cudnn_runtime_version >= 90201) && (max_seqlen_q % 128 == 0) && - (max_seqlen_kv % 128 == 0) && (head_dim == 128) && + (max_seqlen_kv % 128 == 0) && (head_dim_qk == 128) && (head_dim_v == 128) && ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -104,8 +104,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool flag_m512 = false; bool flag_arb = false; if ((sm_arch_ == 80 || sm_arch_ == 90) && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) && - (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim == 64) && - (num_attn_heads == num_gqa_groups) && + (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim_qk == 64) && + (head_dim_v == 64) && (num_attn_heads == num_gqa_groups) && ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -131,11 +131,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) || (cudnn_runtime_version >= 8907)) && // head dimension - ((head_dim <= 128 && head_dim % 8 == 0) || + ((head_dim_qk <= 128 && head_dim_qk % 8 == 0 && head_dim_v <= 128 && head_dim_v % 8 == 0) || // TODO (cyang): add is_training to nvte_get_fused_attn_backend // d=256 only supported for forward - (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim <= 256 && - head_dim % 8 == 0)) && + (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim_qk <= 256 && + head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) && // bias type ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || ((cudnn_runtime_version >= 8906) && @@ -155,6 +155,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || ((cudnn_runtime_version >= 90300) && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && + max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && max_seqlen_q <= max_seqlen_kv && dropout == 0.0)) && @@ -259,7 +260,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, - max_seqlen, d, window_size_left, window_size_right); + max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if 
(CUDNN_VERSION >= 8901) @@ -336,7 +337,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, - max_seqlen, d, window_size_left, window_size_right); + max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -430,7 +431,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, window_size_left, window_size_right); + max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -514,7 +515,7 @@ void nvte_fused_attn_bwd_kvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, window_size_left, window_size_right); + max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -595,7 +596,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t b = input_cu_seqlens_q->data.shape[0] - 1; size_t h_q = input_Q->data.shape[ndim - 2]; size_t h_kv = input_K->data.shape[ndim - 2]; - size_t d = input_Q->data.shape[ndim - 1]; + size_t d_qk = input_Q->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim - 1]; auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -603,13 +605,13 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, window_size_left, window_size_right); + max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, - input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, + input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. 
\n"); @@ -617,18 +619,18 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_K, input_V, - input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, - handle); + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, + input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); @@ -674,7 +676,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t b = input_cu_seqlens_q->data.shape[0] - 1; size_t h_q = input_Q->data.shape[ndim - 2]; size_t h_kv = input_K->data.shape[ndim - 2]; - size_t d = input_Q->data.shape[ndim - 1]; + size_t d_qk = input_Q->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim - 1]; auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -682,15 +685,15 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, window_size_left, window_size_right); + max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); - fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_Q, input_K, input_V, input_dO, output_S, - output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, - input_cu_seqlens_kv, wkspace, stream, handle); + fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, + input_dO, output_S, output_dQ, output_dK, output_dV, output_dBias, + input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. 
\n"); #endif @@ -705,9 +708,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); } fused_attn_arbitrary_seqlen_bwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, input_K, - input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, + input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else @@ -721,7 +724,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const Tensor *input_M = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O, input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, output_dK, output_dV, input_cu_seqlens_q, input_cu_seqlens_kv, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 7ee7ba33bd..42fb779717 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -48,11 +48,11 @@ namespace transformer_engine { namespace fused_attn { void fused_attn_arbitrary_seqlen_fwd_impl( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, int64_t bias_b, - int64_t bias_h, bool is_training, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, - void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO, + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + int64_t bias_b, int64_t bias_h, bool is_training, float scaling_factor, + float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void *devPtrQ, + void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, @@ -86,7 +86,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( hg, s_q, s_kv, - d, + d_qk, + d_v, bias_b, bias_h, scaling_factor, @@ -167,41 +168,41 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, 
q_stride.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); if (is_ragged) { Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride) .set_ragged_offset(offset_q)); K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_qk}) .set_stride(k_stride) .set_ragged_offset(offset_k)); V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_v}) .set_stride(v_stride) .set_ragged_offset(offset_v)); } else { Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride)); K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_qk}) .set_stride(k_stride)); V = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_v}) .set_stride(v_stride)); } @@ -265,15 +266,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto [O, Stats] = mha_graph->sdpa(Q, K, V, sdpa_options); std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); if (is_ragged) { O->set_output(true) - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride) .set_ragged_offset(offset_o); } else { - O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride); + O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride); } Stats->set_output(true) @@ -360,7 +361,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void *devOffsetsO = static_cast(devOffsetsV) + (b + 1) * sizeof(int32_t); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( - layout_group, b, h, hg, d, static_cast(devPtrSeqOffsetsQ), + layout_group, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), static_cast(devOffsetsQ), static_cast(devOffsetsK), static_cast(devOffsetsV), static_cast(devOffsetsO)); @@ -381,13 +382,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } void fused_attn_arbitrary_seqlen_bwd_impl( - int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d, int64_t bias_b, - int64_t bias_h, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose, - void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, - void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, - void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, + int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + int64_t bias_b, int64_t bias_h, float scaling_factor, float dropout_probability, + NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int64_t 
window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ, + void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, + void *devPtrBias, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, + void *devPtrdBias, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { @@ -419,7 +420,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( hg, s_q, s_kv, - d, + d_qk, + d_v, bias_b, bias_h, scaling_factor, @@ -505,61 +507,61 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::vector k_stride(4); std::vector v_stride(4); std::vector o_stride(4); - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, k_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d, v_stride.data(), layout, + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, + generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); if (is_ragged) { q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride) .set_ragged_offset(offset_q)); k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_qk}) .set_stride(k_stride) .set_ragged_offset(offset_k)); v = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_v}) .set_stride(v_stride) .set_ragged_offset(offset_v)); o = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride) .set_ragged_offset(offset_o)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride) .set_ragged_offset(offset_o)); } else { q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride)); k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_qk}) .set_stride(k_stride)); v = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("V") - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_v}) .set_stride(v_stride)); o = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_v}) .set_stride(o_stride)); } stats = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -644,21 +646,21 @@ void fused_attn_arbitrary_seqlen_bwd_impl( if (is_ragged) { dQ->set_output(true) - .set_dim({b, h, s_q, d}) + .set_dim({b, h, s_q, d_qk}) .set_stride(q_stride) .set_ragged_offset(offset_q); dK->set_output(true) - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_qk}) 
.set_stride(k_stride) .set_ragged_offset(offset_k); dV->set_output(true) - .set_dim({b, hg, s_kv, d}) + .set_dim({b, hg, s_kv, d_v}) .set_stride(v_stride) .set_ragged_offset(offset_v); } else { - dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride(q_stride); - dK->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(k_stride); - dV->set_output(true).set_dim({b, hg, s_kv, d}).set_stride(v_stride); + dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride); + dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(k_stride); + dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(v_stride); } std::tuple, // q @@ -758,7 +760,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( void *devOffsetsO = static_cast(devOffsetsV) + (b + 1) * sizeof(int32_t); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( - layout_group, b, h, hg, d, static_cast(devPtrSeqOffsetsQ), + layout_group, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), static_cast(devOffsetsQ), static_cast(devOffsetsK), static_cast(devOffsetsV), static_cast(devOffsetsO)); @@ -865,11 +867,12 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h, - is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, bias_b, + bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, + window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, + devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, + devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, + handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -941,11 +944,11 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, bias_b, bias_h, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, - devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, + batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, bias_b, + bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, + window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, + devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { @@ -1051,12 +1054,12 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_fwd_impl( 
- batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, - is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, + bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, + window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, + devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1131,12 +1134,13 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, - devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, + bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, + window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, + devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, + devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1155,8 +1159,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, @@ -1233,12 +1237,12 @@ void fused_attn_arbitrary_seqlen_fwd( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, - is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - 
devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, + bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, + window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, + devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1257,7 +1261,7 @@ void fused_attn_arbitrary_seqlen_fwd( void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, @@ -1302,12 +1306,13 @@ void fused_attn_arbitrary_seqlen_bwd( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, bias_b, bias_h, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, - devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, + bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, + window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, + devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, + devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 27a2dd37ea..4b523cca1a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -58,8 +58,8 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, 
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, @@ -68,7 +68,7 @@ void fused_attn_arbitrary_seqlen_fwd( void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index fcce30d6a1..bda3f5beba 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1679,6 +1679,7 @@ void fused_attn_fp8_fwd_impl_v1( s_q, s_kv, d, + d, bias_b, bias_h, scaling_factor, @@ -1976,6 +1977,7 @@ void fused_attn_fp8_bwd_impl_v1( s_q, s_kv, d, + d, bias_b, bias_h, scaling_factor, diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 7467462d2a..56dbb278b4 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -363,29 +363,30 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu // convert cu_seqlens_padded to offsets __global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h, - size_t hg, size_t d, int32_t *cu_seqlens_q_padded, + size_t hg, size_t d_qk, size_t d_v, + int32_t *cu_seqlens_q_padded, int32_t *cu_seqlens_kv_padded, int32_t *offsets_q, int32_t *offsets_k, int32_t *offsets_v, int32_t *offsets_o) { size_t tid = blockIdx.x * blockDim.x + threadIdx.x; if (tid < b + 1) { - offsets_o[tid] = h * d * cu_seqlens_q_padded[tid]; + offsets_o[tid] = h * d_v * cu_seqlens_q_padded[tid]; switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - offsets_q[tid] = h * d * cu_seqlens_q_padded[tid]; - offsets_k[tid] = hg * d * cu_seqlens_kv_padded[tid]; - offsets_v[tid] = offsets_k[tid]; + offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[tid]; + offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[tid]; + offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[tid]; break; case NVTE_QKV_Layout_Group::NVTE_3HD: case NVTE_QKV_Layout_Group::NVTE_H3D: - offsets_q[tid] = 3 * h * d * cu_seqlens_q_padded[tid]; + offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[tid]; offsets_k[tid] = offsets_q[tid]; offsets_v[tid] = offsets_q[tid]; break; case NVTE_QKV_Layout_Group::NVTE_HD_2HD: case NVTE_QKV_Layout_Group::NVTE_HD_H2D: - offsets_q[tid] = h * d * cu_seqlens_q_padded[tid]; - offsets_k[tid] = 2 * hg * d * cu_seqlens_kv_padded[tid]; + offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[tid]; + offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[tid]; offsets_v[tid] = offsets_k[tid]; break; } diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 74d1628a33..d5cf450181 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -91,7 +91,8 @@ struct 
FADescriptor_v1 { std::int64_t hg; std::int64_t s_q; std::int64_t s_kv; - std::int64_t d; + std::int64_t d_qk; + std::int64_t d_v; std::int64_t bias_b; std::int64_t bias_h; float attnScale; @@ -107,11 +108,11 @@ struct FADescriptor_v1 { cudnn_frontend::DataType_t bwd_tensor_type; bool operator<(const FADescriptor_v1 &rhs) const { - return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h, attnScale, isTraining, + return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, bias_b, bias_h, attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left, window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < - std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d, rhs.bias_b, rhs.bias_h, - rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, + std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.bias_b, + rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type); } @@ -126,7 +127,8 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu int32_t *kv_seqlens); __global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h, - size_t hg, size_t d, int32_t *cu_seqlens_q_padded, + size_t hg, size_t d_qk, size_t d_v, + int32_t *cu_seqlens_q_padded, int32_t *cu_seqlens_kv_padded, int32_t *offsets_q, int32_t *offsets_k, int32_t *offsets_v, int32_t *offsets_o); diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 342c53bc7f..fa358bc86c 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -147,15 +147,16 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout); * \param[in] num_gqa_groups The number of heads in K, V. * \param[in] max_seqlen_q The sequence length of Q. * \param[in] max_seqlen_kv The sequence length of K, V. - * \param[in] head_dim The head dimension of Q, K, V. + * \param[in] head_dim_qk The head dimension of Q, K. + * \param[in] head_dim_v The head dimension of V. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, - int64_t window_size_right); + size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, + int64_t window_size_left, int64_t window_size_right); /*! \brief Compute dot product attention with packed QKV input. 
* diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 640869ac36..382b17d207 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -19,7 +19,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, auto backend = nvte_get_fused_attn_backend( static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, - head_dim, -1, -1); + head_dim, head_dim, -1, -1); return backend; } @@ -255,10 +255,10 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s /* Prepare RNG state */ auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); - auto backend = - nvte_get_fused_attn_backend(static_cast(dtype), static_cast(dtype), - qkv_layout, bias_type, mask_type, dropout_probability, attn_heads, - num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, -1, -1); + auto backend = nvte_get_fused_attn_backend( + static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, + mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, + head_dim, head_dim, -1, -1); PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -486,10 +486,10 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, /* Auxiliary tensors (propagated from the forward pass) */ NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); - auto backend = - nvte_get_fused_attn_backend(static_cast(dtype), static_cast(dtype), - qkv_layout, bias_type, mask_type, dropout_probability, attn_heads, - num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, -1, -1); + auto backend = nvte_get_fused_attn_backend( + static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, + mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, + head_dim, head_dim, -1, -1); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux, rng_state, bias); diff --git a/transformer_engine/paddle/csrc/common.h b/transformer_engine/paddle/csrc/common.h index 60f06a2188..6ce250432a 100644 --- a/transformer_engine/paddle/csrc/common.h +++ b/transformer_engine/paddle/csrc/common.h @@ -131,10 +131,10 @@ inline NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim) { - NVTE_Fused_Attn_Backend fused_attention_backend = - nvte_get_fused_attn_backend(static_cast(q_dtype), static_cast(kv_dtype), - qkv_layout, bias_type, attn_mask_type, p_dropout, num_attn_heads, - num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, -1, -1); + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( + static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, + attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, + head_dim, head_dim, -1, -1); return fused_attention_backend; } diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index d899934d76..0790315400 100644 --- a/transformer_engine/pytorch/attention.py +++ 
b/transformer_engine/pytorch/attention.py @@ -142,8 +142,10 @@ class AttentionParams: Maximum sequence length of the query tensor. max_seqlen_kv: int, default = 128 Maximum sequence length of the key and value tensors. - head_dim: int, default = 64 - The size of each attention head. + head_dim_qk: int, default = 64 + The size of each attention head in query and key tensors. + head_dim_v: int, default = 64 + The size of each attention head in the value tensor. attn_mask_type: str, default = `no_mask` Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`, `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`} @@ -182,7 +184,8 @@ class AttentionParams: num_gqa_groups: int = 16 max_seqlen_q: int = 128 max_seqlen_kv: int = 128 - head_dim: int = 64 + head_dim_qk: int = 64 + head_dim_v: int = 64 attn_mask_type: str = "no_mask" window_size: Union[Tuple[int, int], None] = None alibi_slopes_shape: Union[torch.Size, List, None] = None @@ -245,7 +248,8 @@ def get_attention_backend( num_gqa_groups = attention_params.num_gqa_groups max_seqlen_q = attention_params.max_seqlen_q max_seqlen_kv = attention_params.max_seqlen_kv - head_dim = attention_params.head_dim + head_dim_qk = attention_params.head_dim_qk + head_dim_v = attention_params.head_dim_v attn_mask_type = attention_params.attn_mask_type window_size = attention_params.window_size alibi_slopes_shape = attention_params.alibi_slopes_shape @@ -352,19 +356,31 @@ def get_attention_backend( use_unfused_attention = False # Filter: Head dimension + if use_flash_attention and head_dim_qk != head_dim_v: + logger.debug("Disabling FlashAttention as it does not support MLA.") + use_flash_attention = False if use_flash_attention and ( - head_dim > 256 - or head_dim % 8 != 0 - or (head_dim > 192 and device_compute_capability not in ((8, 0), (9, 0))) + head_dim_qk > 256 + or head_dim_qk % 8 != 0 + or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0))) ): logger.debug( - "Disabling FlashAttention due to unsupported head_dim. " - "Supported: head_dim %%8 = 0, head_dim <= 256 (>192 requires sm80/90). " - "Found: head_dim = %s on sm%s.", - head_dim, + "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. " + "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, " + "head_dim_qk <= 256 (>192 requires sm80/90). 
" + "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.", + head_dim_qk, + head_dim_v, ".".join([str(i) for i in device_compute_capability]), ) use_flash_attention = False + qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") + if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd": + logger.debug( + "Disabling FusedAttention as MLA is not supported with qkv_layout = %s", + qkv_layout, + ) + use_fused_attention = False # Filter: QKV layout qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) @@ -557,7 +573,8 @@ def get_attention_backend( num_gqa_groups, max_seqlen_q, max_seqlen_kv, - head_dim, + head_dim_qk, + head_dim_v, window_size[0], window_size[1], ) @@ -3132,12 +3149,14 @@ def run_iteratively(q, k, v): stride = q.stride() check_strides_qkv = all(stride == x.stride() for x in [q, k, v]) stride = k.stride() - check_strides_kv = all(stride == x.stride() for x in [k, v]) + check_strides_kv = torch.equal( + torch.Tensor(stride[:-1]) / k.shape[-1], torch.Tensor(v.stride()[:-1]) / v.shape[-1] + ) shape = q.shape check_shapes_qkv = all(shape == x.shape for x in [q, k, v]) shape = k.shape - check_shapes_kv = all(shape == x.shape for x in [k, v]) + check_shapes_kv = shape[:-1] == v.shape[:-1] last_dim_size = q.shape[-1] check_last_dim_offsets_qkv = all( @@ -5177,8 +5196,10 @@ class DotProductAttention(TransformerEngineBaseModule): ---------- num_attention_heads : int number of attention heads in the transformer layer. - kv_channels : int - number of key-query-value channels per attention head. + k_channels : int + number of channels per attention head in key. + v_channels : Optional[int] = None + number of channels per attention head in value. num_gqa_groups : Optional[int] = None number of GQA groups in the transformer layer. Grouped Query Attention is described in @@ -5264,7 +5285,8 @@ class DotProductAttention(TransformerEngineBaseModule): def __init__( self, num_attention_heads: int, - kv_channels: int, + k_channels: int, + v_channels: Optional[int] = None, num_gqa_groups: Optional[int] = None, attention_dropout: float = 0.0, qkv_format: str = "sbhd", @@ -5304,7 +5326,8 @@ def __init__( self.cp_global_ranks = cp_global_ranks self.cp_stream = cp_stream - self.hidden_size_per_attention_head = kv_channels + self.hidden_size_per_attention_head = k_channels + self.v_channels = k_channels if v_channels is None else v_channels self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size) @@ -5322,7 +5345,7 @@ def __init__( attention_dropout_ctx = self.rng_states_tracker.fork if softmax_scale is None: - softmax_scale = 1.0 / math.sqrt(kv_channels) + softmax_scale = 1.0 / math.sqrt(k_channels) self.deterministic = ( not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) @@ -5469,16 +5492,6 @@ def forward( Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type` includes '"padding"' or `"arbitrary"`. - .. note:: - - Input tensor :attr:`query_layer` must be of shape - (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`, - :attr:`kv_channels`) and the tensors :attr:`key_layer` and :attr:`value_layer` - must each be of shape (:attr:`sequence_length`, :attr:`batch_size`, - :attr:`num_gqa_groups`, :attr:`kv_channels`). Output of shape - (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads` - * :attr:`kv_channels`) is returned. - .. 
@@ -5177,8 +5196,10 @@ class DotProductAttention(TransformerEngineBaseModule):
     ----------
     num_attention_heads : int
                 number of attention heads in the transformer layer.
-    kv_channels : int
-                number of key-query-value channels per attention head.
+    k_channels : int
+                number of channels per attention head in key.
+    v_channels : Optional[int] = None
+                number of channels per attention head in value.
     num_gqa_groups : Optional[int] = None
                 number of GQA groups in the transformer layer.
                 Grouped Query Attention is described in
@@ -5264,7 +5285,8 @@ class DotProductAttention(TransformerEngineBaseModule):
     def __init__(
         self,
         num_attention_heads: int,
-        kv_channels: int,
+        k_channels: int,
+        v_channels: Optional[int] = None,
         num_gqa_groups: Optional[int] = None,
         attention_dropout: float = 0.0,
         qkv_format: str = "sbhd",
@@ -5304,7 +5326,8 @@ def __init__(
         self.cp_global_ranks = cp_global_ranks
         self.cp_stream = cp_stream
 
-        self.hidden_size_per_attention_head = kv_channels
+        self.hidden_size_per_attention_head = k_channels
+        self.v_channels = k_channels if v_channels is None else v_channels
         self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
         self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)
 
@@ -5322,7 +5345,7 @@ def __init__(
             attention_dropout_ctx = self.rng_states_tracker.fork
 
         if softmax_scale is None:
-            softmax_scale = 1.0 / math.sqrt(kv_channels)
+            softmax_scale = 1.0 / math.sqrt(k_channels)
 
         self.deterministic = (
             not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
@@ -5469,16 +5492,6 @@ def forward(
         Argument :attr:`attention_mask` is only used when :attr:`attn_mask_type`
         includes '"padding"' or `"arbitrary"`.
 
-        .. note::
-
-            Input tensor :attr:`query_layer` must be of shape
-            (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`,
-            :attr:`kv_channels`) and the tensors :attr:`key_layer` and :attr:`value_layer`
-            must each be of shape (:attr:`sequence_length`, :attr:`batch_size`,
-            :attr:`num_gqa_groups`, :attr:`kv_channels`). Output of shape
-            (:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`
-            * :attr:`kv_channels`) is returned.
-
         .. note::
 
             DotProductAttention supports three backends: 1) FlashAttention which calls
@@ -5628,7 +5641,9 @@ def forward(
         assert (
             query_layer.dtype == key_layer.dtype and query_layer.dtype == value_layer.dtype
         ), "Queries, keys and values must have the same data type!"
-        assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!"
+        assert (
+            key_layer.shape[:-1] == value_layer.shape[:-1]
+        ), "Keys and values must have the same batch size, sequence length and number of heads!"
 
         if attn_mask_type is None:
             attn_mask_type = self.attn_mask_type
@@ -5861,7 +5876,8 @@ def forward(
                 num_gqa_groups=key_layer.shape[-2],
                 max_seqlen_q=max_seqlen_q,
                 max_seqlen_kv=max_seqlen_kv,
-                head_dim=query_layer.shape[-1],
+                head_dim_qk=query_layer.shape[-1],
+                head_dim_v=value_layer.shape[-1],
                 attn_mask_type=attn_mask_type,
                 window_size=window_size,
                 alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
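A hedged usage sketch for the renamed constructor arguments (shapes and dtypes are illustrative; it assumes a CUDA device and a backend that accepts this head-size combination):

    import torch
    from transformer_engine.pytorch import DotProductAttention

    # MLA-style attention: 192-dim query/key heads, 128-dim value heads.
    attn = DotProductAttention(
        num_attention_heads=16,
        k_channels=192,   # previously `kv_channels`
        v_channels=128,
        attn_mask_type="causal",
    )

    s, b = 512, 2
    q = torch.randn(s, b, 16, 192, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(s, b, 16, 192, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(s, b, 16, 128, dtype=torch.bfloat16, device="cuda")
    out = attn(q, k, v)   # last dim is num_attention_heads * v_channels = 16 * 128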
diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py
index 4dc169da00..d0ba644621 100644
--- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py
+++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py
@@ -140,7 +140,7 @@ def fused_attn_fwd_qkvpacked(
                 output tensor, amax of O, used by the next iteration in FP8 computations
     attn_scale: float, default = None
                 if not None, use attn_scale as the attention scale for Q*K.T BMM;
-                if None, use 1.0/sqrt(head_dim) as the default
+                if None, use 1.0/sqrt(head_dim_qk) as the default
     dropout: float, default = 0.0
                 dropout probability, 0.0 means no dropout, 1.0 means no output;
                 dropout must be 0.0 if is_training is False
@@ -342,7 +342,7 @@ def fused_attn_bwd_qkvpacked(
                 output tensor, amax of dQKV, used by the next iteration in FP8 computations
     attn_scale: float, default = None
                 if not None, use attn_scale as the attention scale for Q*K.T BMM;
-                if None, use 1.0/sqrt(head_dim) as the default
+                if None, use 1.0/sqrt(head_dim_qk) as the default
     dropout: float, default = 0.0
                 dropout probability, 0.0 means no dropout, 1.0 means no output;
                 dropout must be 0.0 if is_training is False
@@ -508,7 +508,7 @@ def fused_attn_fwd_kvpacked(
                 output tensor, amax of O, used by the next iteration in FP8 computations
     attn_scale: float, default = None
                 if not None, use attn_scale as the attention scale for Q*K.T BMM;
-                if None, use 1.0/sqrt(head_dim) as the default
+                if None, use 1.0/sqrt(head_dim_qk) as the default
     dropout: float, default = 0.0
                 dropout probability, 0.0 means no dropout, 1.0 means no output;
                 dropout must be 0.0 if is_training is False
@@ -729,7 +729,7 @@ def fused_attn_bwd_kvpacked(
                 output tensor, amax of dQKV, used by the next iteration in FP8 computations
     attn_scale: float, default = None
                 if not None, use attn_scale as the attention scale for Q*K.T BMM;
-                if None, use 1.0/sqrt(head_dim) as the default
+                if None, use 1.0/sqrt(head_dim_qk) as the default
     dropout: float, default = 0.0
                 dropout probability, 0.0 means no dropout, 1.0 means no output;
                 dropout must be 0.0 if is_training is False
@@ -907,7 +907,7 @@ def fused_attn_fwd(
                 output tensor, amax of O, used by the next iteration in FP8 computations
     attn_scale: float, default = None
                 if not None, use attn_scale as the attention scale for Q*K.T BMM;
-                if None, use 1.0/sqrt(head_dim) as the default
+                if None, use 1.0/sqrt(head_dim_qk) as the default
     dropout: float, default = 0.0
                 dropout probability, 0.0 means no dropout, 1.0 means no output;
                 dropout must be 0.0 if is_training is False
@@ -1135,7 +1135,7 @@ def fused_attn_bwd(
                 output tensor, amax of dQ, dK and dV, used by the next iteration in FP8 computations
     attn_scale: float, default = None
                 if not None, use attn_scale as the attention scale for Q*K.T BMM;
-                if None, use 1.0/sqrt(head_dim) as the default
+                if None, use 1.0/sqrt(head_dim_qk) as the default
     dropout: float, default = 0.0
                 dropout probability, 0.0 means no dropout, 1.0 means no output;
                 dropout must be 0.0 if is_training is False
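The docstring updates above only rename the quantity in the default; as a one-line reminder (sketch), the softmax scale is derived from the query/key head size and is unaffected by head_dim_v:

    import math

    head_dim_qk, head_dim_v = 192, 128
    attn_scale = 1.0 / math.sqrt(head_dim_qk)   # head_dim_v plays no role in scaling Q*K^T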
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index f06b0cb197..bd908e9336 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -14,11 +14,14 @@
  * Attention
  **************************************************************************************************/
 
-NVTE_Fused_Attn_Backend get_fused_attn_backend(
-    const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype,
-    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
-    size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, int64_t window_size_right);
+NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype,
+                                               const transformer_engine::DType kv_dtype,
+                                               NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+                                               NVTE_Mask_Type attn_mask_type, float p_dropout,
+                                               size_t num_attn_heads, size_t num_gqa_groups,
+                                               size_t max_seqlen_q, size_t max_seqlen_kv,
+                                               size_t head_dim_qk, size_t head_dim_v,
+                                               int64_t window_size_left, int64_t window_size_right);
 
 std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
     size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero,
diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu
index 9cdc79ed64..50eb7b830f 100644
--- a/transformer_engine/pytorch/csrc/extensions/attention.cu
+++ b/transformer_engine/pytorch/csrc/extensions/attention.cu
@@ -14,11 +14,12 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
     const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
     float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
-    size_t max_seqlen_kv, size_t head_dim, int64_t window_size_left, int64_t window_size_right) {
+    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
+    int64_t window_size_right) {
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
      attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
-      head_dim, window_size_left, window_size_right);
+      head_dim_qk, head_dim_v, window_size_left, window_size_right);
   return fused_attention_backend;
 }
@@ -761,7 +762,11 @@ std::vector<at::Tensor> fused_attn_fwd(
   std::vector<size_t> v_shape{v_sizes.begin(), v_sizes.end()};
 
   // create output tensor O
-  auto O = torch::empty_like(Q);
+  auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA);
+  auto o_shape = std::vector<size_t>{q_sizes.begin(), q_sizes.end()};
+  o_shape[o_shape.size() - 1] = v_sizes[v_sizes.size() - 1];
+  std::vector<int64_t> o_shape_tmp{o_shape.begin(), o_shape.end()};
+  auto O = torch::empty(c10::IntArrayRef(o_shape_tmp), options);
 
   // construct NVTE tensors
   TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias;
@@ -790,7 +795,7 @@ std::vector<at::Tensor> fused_attn_fwd(
                                        descale_QKV.value().data_ptr());
     te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_S.value().data_ptr(),
                                        scale_S.value().data_ptr(), descale_S.value().data_ptr());
-    te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, amax_O.value().data_ptr(),
+    te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, amax_O.value().data_ptr(),
                                        scale_O.value().data_ptr(), nullptr);
   } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) {
     if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) {
@@ -801,7 +806,7 @@ std::vector<at::Tensor> fused_attn_fwd(
     te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr);
     te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr);
     te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr);
-    te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr);
+    te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, nullptr);
   } else {
     NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n");
   }
@@ -839,8 +844,7 @@ std::vector<at::Tensor> fused_attn_fwd(
   auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
       rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
   at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
-  auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
-  auto rng_state = torch::empty({2}, options);
+  auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
   unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
       philox_args, static_cast<int64_t *>(rng_state.data_ptr()));
   auto te_rng_state = makeTransformerEngineTensor(rng_state);
@@ -935,8 +939,11 @@ std::vector<at::Tensor> fused_attn_bwd(
   std::vector<size_t> v_shape{v_sizes.begin(), v_sizes.end()};
   auto h_q = q_shape[q_shape.size() - 2];
   auto h_kv = k_shape[k_shape.size() - 2];
-  auto d = q_shape[q_shape.size() - 1];
+  auto d_qk = q_shape[q_shape.size() - 1];
+  auto d_v = v_shape[v_shape.size() - 1];
   auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA);
+  std::vector<size_t> o_shape{q_sizes.begin(), q_sizes.end()};
+  o_shape[o_shape.size() - 1] = d_v;
 
   at::Tensor dQ;
   at::Tensor dK;
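A small sketch of the output-shape rule implemented above, where O (and dO in the backward pass) take their leading dimensions from Q and their last dimension from V:

    import torch

    s, b, h, d_qk, d_v = 512, 2, 16, 192, 128
    q = torch.empty(s, b, h, d_qk)
    v = torch.empty(s, b, h, d_v)

    # Python mirror of: o_shape[o_shape.size() - 1] = v_sizes[v_sizes.size() - 1]
    o_shape = list(q.shape)
    o_shape[-1] = v.shape[-1]
    o = torch.empty(o_shape)
    assert o.shape == (s, b, h, d_v)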
@@ -1015,7 +1022,7 @@ std::vector<at::Tensor> fused_attn_bwd(
   TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV;
   if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
     // FP8
-    if (set_zero && ((h_q * d) % block_size == 0) && ((h_kv * d) % block_size == 0) &&
+    if (set_zero && ((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) &&
         dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous() &&
         (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) {
       mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)}));
@@ -1041,9 +1048,9 @@ std::vector<at::Tensor> fused_attn_bwd(
                                        descale_QKV.value().data_ptr());
     te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr,
                                        descale_QKV.value().data_ptr());
-    te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr,
+    te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr,
                                        descale_O.value().data_ptr());
-    te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr,
+    te_dO = makeTransformerEngineTensor(dO.data_ptr(), o_shape, dqkv_type, nullptr, nullptr,
                                         descale_dO.value().data_ptr());
     te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr,
                                        scale_S.value().data_ptr(), descale_S.value().data_ptr());
@@ -1068,9 +1075,9 @@ std::vector<at::Tensor> fused_attn_bwd(
     te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr);
     te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr);
     te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr);
-    te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr);
+    te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, nullptr);
     te_dO =
-        makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr);
+        makeTransformerEngineTensor(dO.data_ptr(), o_shape, dqkv_type, nullptr, nullptr, nullptr);
     te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr);
     te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr);
     te_dQ =