[COMMON/JAX] Support sliding window on THD format #1327

Open · wants to merge 3 commits into base: main
3rdparty/cudnn-frontend (2 changes: 1 addition & 1 deletion, submodule pointer update)
tests/jax/test_distributed_fused_attn.py (3 changes: 1 addition & 2 deletions)

@@ -20,7 +20,6 @@
 from utils import (
     make_causal_mask,
     make_self_mask,
-    assert_tree_like_allclose,
     assert_allclose,
     print_debug_tensor_stats,
 )

@@ -492,7 +491,7 @@ def grad_func(func, *args, **kwargs):
     # Gradient is small, use a gradient multiplier to amplify the gradient
     _, max_seq_len, num_heads, _ = data_shape
     gradient_multiplier = max_seq_len * num_heads
-    if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]:
+    if attn_mask_type.is_causal():
         gradient_multiplier /= 10
     ret_valid = func(*args, **kwargs)
     return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype)
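Note: the multiplier exists because `jnp.mean` over `max_seq_len × num_heads` elements scales each element's gradient by the reciprocal of that count. A minimal, self-contained illustration (shapes arbitrary, not taken from the test):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((512, 16))                 # (max_seq_len, num_heads)
g = jax.grad(lambda t: jnp.mean(t))(x)  # every entry equals 1/(512 * 16)
amplified = g * (512 * 16)              # rescaled back to O(1) magnitudes
```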
tests/jax/test_fused_attn.py (13 changes: 3 additions & 10 deletions)

@@ -101,13 +101,6 @@ def general_dot_product_attention(
     return context


-def is_causal_mask(mask: AttnMaskType):
-    """
-    Check if the mask is a causal mask
-    """
-    return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]
-
-
 def make_causal_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
     """
     Create inverse padded causal mask where `True` means allowing the corresponding
@@ -135,7 +128,7 @@ def make_mask(
     inv_mask = make_attention_mask(
         q_token, kv_token, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
     )
-    if is_causal_mask(attn_mask_type):
+    if attn_mask_type.is_causal():
         inv_causal_mask = make_causal_mask(q_token, kv_token)
         inv_mask = combine_masks(inv_causal_mask, inv_mask)
     if segment_pad_q is not None and segment_pad_kv is not None:
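Note on the composition above: `make_attention_mask` and `combine_masks` come from `flax.linen`; the former builds a `[batch, 1, q_len, kv_len]` mask from a pairwise predicate and the latter ANDs masks together. A minimal sketch, using `jnp.tril` as a stand-in for the test's own `make_causal_mask`:

```python
import jax.numpy as jnp
from flax.linen import combine_masks, make_attention_mask

seg = jnp.array([[1, 1, 2, 0]])  # segment ids per token; 0 marks padding
inv_mask = make_attention_mask(
    seg, seg, lambda x, y: jnp.logical_and(jnp.equal(x, y), x != 0)
)                                              # [1, 1, 4, 4]: same non-pad segment
inv_causal = jnp.tril(jnp.ones((1, 1, 4, 4)))  # stand-in causal mask
inv_combined = combine_masks(inv_causal, inv_mask)  # logical AND of the two
```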
@@ -454,7 +447,7 @@ def generate_random_segment_ids(
         )
         # TODO(rewang): Check if qkvpacked supported different q/kv
         # TODO(rewang): Causal with different q/kv segment_id fails
-        if self.qkv_layout == QKVLayout.T3HD or is_causal_mask(self.attn_mask_type):
+        if self.qkv_layout == QKVLayout.T3HD or self.attn_mask_type.is_causal():
             self.token_kv = self.token_q
             self.segment_pad_kv = self.segment_pad_q
         else:
@@ -552,7 +545,7 @@ def test_backward(self):
         def grad_func(func, *args, **kwargs):
             # Gradient is small, use a gradient multiplier to amplify the gradient
             gradient_multiplier = self.max_seqlen_q * self.num_heads_q
-            if is_causal_mask(self.attn_mask_type):
+            if self.attn_mask_type.is_causal():
                 gradient_multiplier /= 10
             # Keep only valid result for the gradient
             ret_valid = jnp.where(
transformer_engine/common/fused_attn/fused_attn.cpp (28 changes: 12 additions & 16 deletions)

@@ -84,12 +84,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
   auto cudnn_runtime_version = cudnnGetVersion();

+  const bool is_ragged = (qkv_format == NVTE_THD);
   // For ragged offsets we only support 32-bit prior to cuDNN 9.5
+  // Only used when THD format is requested.
   const bool requires_64bit_ragged_offset =
-      (qkv_format == NVTE_THD && fused_attn::get_ragged_offset_dtype(
-                                     layout_group, num_attn_heads, num_gqa_groups, max_seqlen_q,
-                                     max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64);
+      (is_ragged && fused_attn::get_ragged_offset_dtype(layout_group, num_attn_heads,
+                                                        num_gqa_groups, max_seqlen_q, max_seqlen_kv,
+                                                        head_dim_qk, head_dim_v) == DType::kInt64);
   const bool supported_ragged_offset_size =
       (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500);
@@ -99,9 +100,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
        (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim_qk == 64) &&
        (head_dim_v == 64) && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) ||
       ((cudnn_runtime_version >= 90201) && (max_seqlen_q % 128 == 0) &&
-       (max_seqlen_kv % 128 == 0) && (head_dim_qk == 128) && (head_dim_v == 128) &&
-       ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) ||
-        (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) &&
+       (max_seqlen_kv % 128 == 0) && (head_dim_qk == 128) && (head_dim_v == 128) && (!is_ragged) &&
        ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
         (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)))) &&
      !requires_64bit_ragged_offset) {
@@ -173,31 +172,28 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
        ((cudnn_runtime_version >= 90300) &&
         attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
         max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 &&
-        bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
-        (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) &&
+        bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && (!is_ragged) &&
         max_seqlen_q <= max_seqlen_kv && dropout == 0.0)) &&
       // bias + mask combination
       (!(cudnn_runtime_version >= 8906 &&
          (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
           attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) &&
          bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) &&
       // qkv format
-      ((qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) ||
-       (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 &&
-        ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) ||
-         (cudnn_runtime_version >= 90600)))) &&
+      ((!is_ragged) || (is_ragged && sm_arch_ >= 90 &&
+                        ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) ||
+                         (cudnn_runtime_version >= 90600)))) &&
       // sliding window
       ((cudnn_runtime_version < 90200 && window_size_left == -1 &&
         (window_size_right == -1 || window_size_right == 0)) ||
        (cudnn_runtime_version >= 90200 &&
         ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
          ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
-           (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
+           ((!is_ragged && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
+            (is_ragged && attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) ||
[Review comment · Collaborator, on the line above] Could you factor the self-attn bottom_right combination into this condition as well?

    (!is_ragged && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
                    (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
                     max_seqlen_q == max_seqlen_kv))) ||
    (is_ragged && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
                   (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
                    max_seqlen_q == max_seqlen_kv)))
            (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
             max_seqlen_q == max_seqlen_kv)) &&
-          dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
-          (qkv_format == NVTE_QKV_Format::NVTE_BSHD ||
-           qkv_format == NVTE_QKV_Format::NVTE_SBHD))))) &&
+          dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)))) &&
       // check 64-bit ragged offset support
       (supported_ragged_offset_size)) {
     flag_arb = true;
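To make the sliding-window gate above easier to audit, here is the same logic transliterated into a standalone Python predicate (illustrative only; mask types abbreviated to strings, and the always-true `window_size_left >= 0 || window_size_left == -1` clause folded away):

```python
def swa_allowed(cudnn_version, is_ragged, mask, win_left, win_right,
                dropout, bias, max_seqlen_q, max_seqlen_kv):
    """Mirror of the C++ sliding-window condition above (not a public API)."""
    no_swa = win_left == -1 and win_right in (-1, 0)
    if cudnn_version < 90200:
        return no_swa
    causal_ok = (
        (not is_ragged and mask == "CAUSAL_MASK")
        or (is_ragged and mask == "PADDING_CAUSAL_MASK")
        or (mask == "CAUSAL_BOTTOM_RIGHT_MASK" and max_seqlen_q == max_seqlen_kv)
    )
    return no_swa or (win_right == 0 and causal_ok
                      and dropout == 0.0 and bias == "NO_BIAS")
```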
transformer_engine/jax/attention.py (36 changes: 26 additions & 10 deletions)

@@ -46,6 +46,30 @@ class AttnMaskType(Enum):
     CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK
     PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK

+    def is_causal(self):
+        """Returns True if the mask is a causal mask"""
+        return self in [
+            AttnMaskType.CAUSAL_MASK,
+            AttnMaskType.PADDING_CAUSAL_MASK,
+            AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
+            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
+        ]

[Review comment · Collaborator, on `def is_causal`] Nice pythonic helpers!

+    def is_padding(self):
+        """Returns True if the mask includes padding"""
+        return self in [
+            AttnMaskType.PADDING_MASK,
+            AttnMaskType.PADDING_CAUSAL_MASK,
+            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
+        ]
+
+    def is_bottom_right(self):
+        """Returns True if the causal mask is calculated from the bottom-right section"""
+        return self in [
+            AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
+            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
+        ]


 class QKVLayout(Enum):
     """
@@ -123,12 +147,8 @@ def make_swa_mask(
     swa_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype)
     if window_size is None:
         return swa_mask
-    bottom_right_masks = [
-        AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
-        AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
-    ]
     left_window, right_window = window_size
-    if attn_mask_type in bottom_right_masks:
+    if attn_mask_type.is_bottom_right():
         if left_window < 0:
             left_window = max_seqlen_kv
         if right_window < 0:
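For intuition, a self-contained sketch of the band that a bottom-right-aligned `(left, right)` window selects; this is illustrative, not the library's `make_swa_mask`:

```python
import jax.numpy as jnp

def swa_band_sketch(max_seqlen_q, max_seqlen_kv, left_window, right_window):
    # Bottom-right alignment: the last query row lines up with the last KV column.
    q_pos = jnp.arange(max_seqlen_q)[:, None] + (max_seqlen_kv - max_seqlen_q)
    kv_pos = jnp.arange(max_seqlen_kv)[None, :]
    return (kv_pos >= q_pos - left_window) & (kv_pos <= q_pos + right_window)

# swa_band_sketch(4, 6, 2, 0) keeps a width-3 diagonal band ending at the
# bottom-right corner of the [4, 6] score matrix.
```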
@@ -313,11 +333,7 @@ def fused_attn(
     ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"

     # convert the mask to seqlens, mask doesn't support ragged offsets
-    if attn_mask_type in [
-        AttnMaskType.NO_MASK,
-        AttnMaskType.CAUSAL_MASK,
-        AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
-    ]:
+    if not attn_mask_type.is_padding():
         batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout)
         q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32)
         kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32)
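For contrast, a padding-style mask cannot be summarized by broadcasting the max length, since each row carries its own valid length; it has to be recovered from the mask itself. A minimal illustration (not the library's code):

```python
import jax.numpy as jnp

segment_valid = jnp.array([[1, 1, 1, 0],
                           [1, 1, 0, 0]])  # 1 = valid token, 0 = padding
seq_lens = segment_valid.sum(axis=-1).astype(jnp.int32)  # -> [3, 2]
```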