
Commit

Check for backend support in Jax context parallel fused attention test (#1227)

Update test to check support for context parallel attention.

Signed-off-by: Michael Goldfarb <[email protected]>
mgoldfarb-nvidia authored Oct 15, 2024
1 parent 86f07be commit 20c55e4
Showing 3 changed files with 71 additions and 33 deletions.
27 changes: 24 additions & 3 deletions tests/jax/test_distributed_fused_attn.py
@@ -124,8 +124,10 @@ def test_self_attn(
seqlen,
seqlen,
hidden,
None, # no window
False, # not context parallel
):
pytest.skip(f"No FusedAttn backwend found")
pytest.skip(f"No FusedAttn backend found")

def target_func(qkv, bias, mask):
return jnp.mean(
@@ -257,8 +259,10 @@ def test_cross_attn(
seqlen,
seqlen,
hidden,
None, # no window
False, # not context parallel
):
pytest.skip(f"No FusedAttn backwend found")
pytest.skip(f"No FusedAttn backend found")

def target_func(q, kv, mask):
return jnp.mean(
@@ -403,7 +407,24 @@ def test_contex_parallel_self_attn(
_, seqlen, num_head, hidden = data_shape
num_kv_heads = num_head // kv_groups

# make sure the mesh evently divides cp and tp axis
if not is_fused_attn_kernel_available(
dtype,
dtype,
qkv_layout,
attn_bias_type,
attn_mask_type,
dropout_prob,
num_head,
num_kv_heads,
seqlen,
seqlen,
hidden,
None, # no window
cp_size > 1,
):
pytest.skip(f"No FusedAttn backend found")

    # make sure the mesh evenly divides cp and tp axis
if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0:
pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")

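For reference, a minimal sketch of the gating pattern the updated tests now follow: query backend support up front (including the new sliding-window and context-parallel arguments) and skip when no fused-attention backend is available. The helper name maybe_skip_unsupported, the example values, and the enum members QKVLayout.BSHD_BS2HD and AttnBiasType.NO_BIAS are illustrative assumptions, not part of this commit; the real tests pass their pytest fixture values directly.

import jax.numpy as jnp
import pytest

from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
    is_fused_attn_kernel_available,
)


def maybe_skip_unsupported(dtype, num_head, num_kv_heads, seqlen, hidden, cp_size):
    """Skip the calling test when no fused-attention backend supports the config."""
    if not is_fused_attn_kernel_available(
        dtype,  # q dtype
        dtype,  # kv dtype
        QKVLayout.BSHD_BS2HD,  # assumed layout enum member (illustrative)
        AttnBiasType.NO_BIAS,
        AttnMaskType.CAUSAL_MASK,
        0.0,  # dropout probability
        num_head,
        num_kv_heads,
        seqlen,  # q max seqlen
        seqlen,  # kv max seqlen
        hidden,  # head dim
        None,  # no sliding window
        cp_size > 1,  # is_context_parallel
    ):
        pytest.skip("No FusedAttn backend found")


def test_cp_attention_example():  # illustrative placeholder test, not from this commit
    maybe_skip_unsupported(jnp.bfloat16, num_head=16, num_kv_heads=16,
                           seqlen=2048, hidden=64, cp_size=2)
    # ... exercise context-parallel fused attention here ...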
41 changes: 27 additions & 14 deletions transformer_engine/jax/attention.py
@@ -190,24 +190,37 @@ def is_fused_attn_kernel_available(
kv_max_seqlen,
head_dim,
window_size: Optional[Tuple[int, int]] = None,
is_context_parallel: bool = False,
):
"""
To check whether the fused attention kernel is supported
"""
return tex.FusedAttnHelper(
q_dtype,
kv_dtype,
qkv_layout.value,
attn_bias_type.value,
attn_mask_type.value,
dropout_probability,
q_num_heads,
kv_num_heads,
q_max_seqlen,
kv_max_seqlen,
head_dim,
(-1, -1) if window_size is None else window_size,
).is_fused_attn_kernel_available()

def make_helper(attn_mask_type):
return tex.FusedAttnHelper(
q_dtype,
kv_dtype,
qkv_layout.value,
attn_bias_type.value,
attn_mask_type.value,
dropout_probability,
q_num_heads,
kv_num_heads,
q_max_seqlen,
kv_max_seqlen,
head_dim,
(-1, -1) if window_size is None else window_size,
)

if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
return False

# For context parallel need to check additional masking types
if is_context_parallel and attn_mask_type == AttnMaskType.CAUSAL_MASK:
if not make_helper(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK).is_fused_attn_kernel_available():
return False

return True


def _obtain_batch_and_max_seqlen(qkv, qkv_layout):
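The extra check for AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK reflects what a causal mask looks like once a sequence is split across context-parallel ranks: queries keep their global positions, so for any chunk after the first the allowed region over the full key sequence is a causal triangle aligned to the bottom-right of the block. A toy NumPy sketch of that geometry, assuming a plain contiguous split (real context-parallel implementations may reorder chunks for load balancing):

import numpy as np


def cp_chunk_causal_mask(chunk_idx: int, chunk_len: int, total_len: int) -> np.ndarray:
    """Toy illustration: the causal mask seen by one context-parallel chunk.

    Queries keep their global positions, so the (chunk_len x total_len) block of
    the attention matrix is causal with its diagonal aligned to the bottom-right
    of the keys visible to that chunk.
    """
    q_pos = chunk_idx * chunk_len + np.arange(chunk_len)[:, None]  # global query positions
    k_pos = np.arange(total_len)[None, :]  # global key positions
    return k_pos <= q_pos  # True means "may attend"


# Two context-parallel chunks of 4 tokens: chunk 1 sees a bottom-right causal mask over 8 keys.
print(cp_chunk_causal_mask(chunk_idx=1, chunk_len=4, total_len=8).astype(int))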
36 changes: 20 additions & 16 deletions transformer_engine/jax/cpp_extensions/attention.py
@@ -923,26 +923,30 @@ def check_supported(self):
header = "Context parallel fused attention"

allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD]
assert self.config.qkv_layout in allowed_layouts, (
f"{header} only supports layouts: {','.join([str(x) for x in allowed_layouts])} got:"
f" {self.config.qkv_layout}"
)
if self.config.qkv_layout not in allowed_layouts:
raise ValueError(
f"{header} only supports layouts:"
f" {','.join([str(x) for x in allowed_layouts])} got: {self.config.qkv_layout}"
)

assert (
self.config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS
), f"{header} does not support bias got: {self.config.attn_bias_type}"
if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS:
raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")

allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK]
assert self.config.attn_mask_type in allowed_masks, (
f"{header} only supports masking types: "
f" {','.join([str(x) for x in allowed_masks])} got: {self.config.attn_mask_type}"
)
if self.config.attn_mask_type not in allowed_masks:
raise ValueError(
f"{header} only supports masking types: "
f" {','.join([str(x) for x in allowed_masks])} got: {self.config.attn_mask_type}"
)

assert self.config.max_segments_per_seq == 1, (
f"{header} only supports max_segments_per_seq == 1 got:"
f" {self.config.max_segments_per_seq}"
)
assert self.config.dropout_probability == 0.0, f"{header} does not support dropout"
if self.config.max_segments_per_seq != 1:
raise ValueError(
f"{header} only supports max_segments_per_seq == 1 got:"
f" {self.config.max_segments_per_seq}"
)

if self.config.dropout_probability != 0.0:
raise ValueError(f"{header} does not support dropout")

def get_adjusted_mask(self):
"""Converts the mask for context parallelism."""
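A practical consequence of raising ValueError instead of asserting is that an unsupported configuration surfaces as an ordinary exception (and is not stripped out under python -O), so callers can catch it and fall back. A minimal, hypothetical caller-side sketch; the helper argument merely stands in for whatever object exposes check_supported() and is not a name from this commit:

def run_with_fallback(helper, run_context_parallel, run_single_rank):
    """Try the context-parallel fused attention path, falling back when unsupported."""
    try:
        helper.check_supported()  # now raises ValueError on unsupported configs
    except ValueError as err:
        print(f"Context parallel fused attention unavailable: {err}")
        return run_single_rank()
    return run_context_parallel()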
