diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend
index 936021bfed..6335448560 160000
--- a/3rdparty/cudnn-frontend
+++ b/3rdparty/cudnn-frontend
@@ -1 +1 @@
-Subproject commit 936021bfed8c91dc416af1588b2c4eca631a9e45
+Subproject commit 633544856000a2c8b37620c42a36f700f6c709f3
diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py
index 4ef56b1b10..b0a34f7e6c 100644
--- a/tests/jax/test_distributed_fused_attn.py
+++ b/tests/jax/test_distributed_fused_attn.py
@@ -20,7 +20,6 @@
 from utils import (
     make_causal_mask,
     make_self_mask,
-    assert_tree_like_allclose,
     assert_allclose,
     print_debug_tensor_stats,
 )
@@ -492,7 +491,7 @@ def grad_func(func, *args, **kwargs):
             # Gradient is small, use a gradient multiplier to amplify the gradient
             _, max_seq_len, num_heads, _ = data_shape
             gradient_multiplier = max_seq_len * num_heads
-            if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]:
+            if attn_mask_type.is_causal():
                 gradient_multiplier /= 10
             ret_valid = func(*args, **kwargs)
             return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype)
diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py
index a48a38e6a9..4f93e0b5e8 100644
--- a/tests/jax/test_fused_attn.py
+++ b/tests/jax/test_fused_attn.py
@@ -101,13 +101,6 @@ def general_dot_product_attention(
     return context
 
 
-def is_causal_mask(mask: AttnMaskType):
-    """
-    Check if the mask is a causal mask
-    """
-    return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]
-
-
 def make_causal_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
     """
     Create inverse padded causal mask where `True` means allowing the corresponding
@@ -135,7 +128,7 @@ def make_mask(
     inv_mask = make_attention_mask(
         q_token, kv_token, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
     )
-    if is_causal_mask(attn_mask_type):
+    if attn_mask_type.is_causal():
         inv_causal_mask = make_causal_mask(q_token, kv_token)
         inv_mask = combine_masks(inv_causal_mask, inv_mask)
     if segment_pad_q is not None and segment_pad_kv is not None:
@@ -454,7 +447,7 @@ def generate_random_segment_ids(
         )
         # TODO(rewang): Check if qkvpacked supported different q/kv
         # TODO(rewang): Causal with different q/kv segment_id fails
-        if self.qkv_layout == QKVLayout.T3HD or is_causal_mask(self.attn_mask_type):
+        if self.qkv_layout == QKVLayout.T3HD or self.attn_mask_type.is_causal():
             self.token_kv = self.token_q
             self.segment_pad_kv = self.segment_pad_q
         else:
@@ -552,7 +545,7 @@ def test_backward(self):
         def grad_func(func, *args, **kwargs):
             # Gradient is small, use a gradient multiplier to amplify the gradient
             gradient_multiplier = self.max_seqlen_q * self.num_heads_q
-            if is_causal_mask(self.attn_mask_type):
+            if self.attn_mask_type.is_causal():
                 gradient_multiplier /= 10
             # Keep only valid result for the gradient
             ret_valid = jnp.where(
diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp
index 9cde765401..505042cf18 100644
--- a/transformer_engine/common/fused_attn/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn/fused_attn.cpp
@@ -84,12 +84,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
   auto cudnn_runtime_version = cudnnGetVersion();
 
+  const bool is_ragged = (qkv_format == NVTE_THD);
   // For ragged offsets we only support 32-bit prior to cuDNN 9.5
   // Only used when THD format is requested.
   const bool requires_64bit_ragged_offset =
-      (qkv_format == NVTE_THD && fused_attn::get_ragged_offset_dtype(
-                                     layout_group, num_attn_heads, num_gqa_groups, max_seqlen_q,
-                                     max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64);
+      (is_ragged && fused_attn::get_ragged_offset_dtype(layout_group, num_attn_heads,
+                                                        num_gqa_groups, max_seqlen_q, max_seqlen_kv,
+                                                        head_dim_qk, head_dim_v) == DType::kInt64);
   const bool supported_ragged_offset_size =
       (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500);
 
@@ -99,9 +100,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
          (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim_qk == 64) &&
          (head_dim_v == 64) && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) ||
         ((cudnn_runtime_version >= 90201) && (max_seqlen_q % 128 == 0) &&
-         (max_seqlen_kv % 128 == 0) && (head_dim_qk == 128) && (head_dim_v == 128) &&
-         ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) ||
-          (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) &&
+         (max_seqlen_kv % 128 == 0) && (head_dim_qk == 128) && (head_dim_v == 128) && (!is_ragged) &&
          ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
           (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)))) &&
        !requires_64bit_ragged_offset) {
@@ -173,8 +172,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
        ((cudnn_runtime_version >= 90300) &&
         attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
         max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 &&
-        bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
-        (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) &&
+        bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && (!is_ragged) &&
         max_seqlen_q <= max_seqlen_kv && dropout == 0.0)) &&
       // bias + mask combination
       (!(cudnn_runtime_version >= 8906 &&
@@ -182,22 +180,20 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
           attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) &&
          bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) &&
       // qkv format
-      ((qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) ||
-       (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 &&
-        ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) ||
-         (cudnn_runtime_version >= 90600)))) &&
+      ((!is_ragged) || (is_ragged && sm_arch_ >= 90 &&
+                        ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) ||
+                         (cudnn_runtime_version >= 90600)))) &&
      // sliding window
      ((cudnn_runtime_version < 90200 && window_size_left == -1 &&
        (window_size_right == -1 || window_size_right == 0)) ||
       (cudnn_runtime_version >= 90200 &&
        ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
         ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
-         (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
+         ((!is_ragged && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
+          (is_ragged && attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) ||
           (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
            max_seqlen_q == max_seqlen_kv)) &&
-         dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
-         (qkv_format == NVTE_QKV_Format::NVTE_BSHD ||
-          qkv_format == NVTE_QKV_Format::NVTE_SBHD))))) &&
+         dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)))) &&
      // check 64-bit ragged offset support
      (supported_ragged_offset_size)) {
     flag_arb = true;
diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py
index 218fd174de..7782373620 100644
--- a/transformer_engine/jax/attention.py
+++ b/transformer_engine/jax/attention.py
@@ -46,6 +46,30 @@ class AttnMaskType(Enum):
     CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK
     PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK
 
+    def is_causal(self):
+        """Returns True if the mask is a causal mask"""
+        return self in [
+            AttnMaskType.CAUSAL_MASK,
+            AttnMaskType.PADDING_CAUSAL_MASK,
+            AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
+            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
+        ]
+
+    def is_padding(self):
+        """Returns True if the mask includes padding"""
+        return self in [
+            AttnMaskType.PADDING_MASK,
+            AttnMaskType.PADDING_CAUSAL_MASK,
+            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
+        ]
+
+    def is_bottom_right(self):
+        """Returns True if the causal mask is calculated from the bottom-right section"""
+        return self in [
+            AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
+            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
+        ]
+
 
 class QKVLayout(Enum):
     """
@@ -123,12 +147,8 @@ def make_swa_mask(
     swa_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype)
     if window_size is None:
         return swa_mask
-    bottom_right_masks = [
-        AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
-        AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
-    ]
     left_window, right_window = window_size
-    if attn_mask_type in bottom_right_masks:
+    if attn_mask_type.is_bottom_right():
         if left_window < 0:
             left_window = max_seqlen_kv
         if right_window < 0:
@@ -313,11 +333,7 @@ def fused_attn(
     ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
 
     # convert the mask to seqlens, mask doesn't support ragged offsets
-    if attn_mask_type in [
-        AttnMaskType.NO_MASK,
-        AttnMaskType.CAUSAL_MASK,
-        AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
-    ]:
+    if not attn_mask_type.is_padding():
         batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
         q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32)
         kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32)
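
Note (illustration only, not part of the diff): the three predicates added to AttnMaskType above replace the deleted module-level is_causal_mask helper and the hard-coded membership lists in the tests, make_swa_mask, and fused_attn. A minimal sketch of how callers can use them, assuming a transformer_engine build that includes this change so AttnMaskType is importable from transformer_engine.jax.attention:

    from transformer_engine.jax.attention import AttnMaskType

    mask = AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK

    # Replaces checks such as `mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]`
    assert mask.is_causal()        # True for all four causal variants
    assert mask.is_padding()       # True for the variants that also mask padded tokens
    assert mask.is_bottom_right()  # True only for bottom-right-aligned causal variants

    # fused_attn() uses the padding predicate to decide when full-length sequence
    # lengths can be assumed instead of being derived from the mask
    if not mask.is_padding():
        print("q/kv seqlens default to max_seqlen")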