[JAX] Expose cp params to jax DPA api #1292

Merged · 4 commits · Nov 4, 2024
44 changes: 26 additions & 18 deletions tests/jax/test_distributed_fused_attn.py
@@ -133,7 +133,6 @@ def test_self_attn(
seqlen,
hidden,
None, # no window
False, # not context parallel
):
pytest.skip(f"No FusedAttn backend found")

@@ -268,7 +267,6 @@ def test_cross_attn(
seqlen,
hidden,
None, # no window
False, # not context parallel
):
pytest.skip(f"No FusedAttn backend found")

@@ -425,22 +423,32 @@ def test_contex_parallel_self_attn(
num_kv_heads = num_head // kv_groups
scaling_factor = 1.0 / np.sqrt(num_head)

    if not is_fused_attn_kernel_available(
        dtype,
        dtype,
        qkv_layout,
        attn_bias_type,
        attn_mask_type,
        dropout_prob,
        num_head,
        num_kv_heads,
        seqlen,
        seqlen,
        hidden,
        None,  # no window
        cp_size > 1,
    ):
        pytest.skip(f"No FusedAttn backend found")
    def check_has_backend_for_mask(mask_type):
        return is_fused_attn_kernel_available(
            dtype,
            dtype,
            qkv_layout,
            attn_bias_type,
            mask_type,
            dropout_prob,
            num_head,
            num_kv_heads,
            seqlen,
            seqlen,
            hidden,
            None,
        )  # no SWA for CP

    # For causal masking we depend on having bottom-right support also.
    # The API does not check this; instead we rely on lower-level checks to raise
    # an exception if the backend is not supported. This was a deliberate API
    # decision to keep the CP size or flag out of the function.
    has_backend = check_has_backend_for_mask(attn_mask_type)
    if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK:
        has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK)

    if not has_backend:
        pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.")

if dp_size > 1 and batch % dp_size != 0:
pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}")
6 changes: 0 additions & 6 deletions transformer_engine/jax/attention.py
@@ -190,7 +190,6 @@ def is_fused_attn_kernel_available(
kv_max_seqlen,
head_dim,
window_size: Optional[Tuple[int, int]] = None,
is_context_parallel: bool = False,
):
"""
To check whether the fused attention kernel is supported
@@ -215,11 +214,6 @@ def make_helper(attn_mask_type):
if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
return False

# For context parallel need to check additional masking types
Collaborator (mgoldfarb-nvidia):

So there is no way to pass None from maxtext when CP is not used for the axis? I think we should keep this check here so that we can avoid exceptions on cuDNN when we don't have appropriate support.

Collaborator (PR author):

Hi @mgoldfarb-nvidia, we can definitely pass a boolean from the JAX side. However, I decided to remove this because:

  1. I wanted to keep changes to the high-level DotProductAttention API as minimal as possible.
  2. The same check is performed here; it essentially provides a meaningful error message and sits above the cuDNN level:
...
...
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/workspace/maxtext_workspace/context/maxtext/MaxText/train.py", line 757, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/opt/workspace/maxtext_workspace/context/maxtext/MaxText/train.py", line 753, in main
    train_loop(config)
  File "/opt/workspace/maxtext_workspace/context/maxtext/MaxText/train.py", line 649, in train_loop
    state, metrics = p_train_step(state, example_batch, nextrng)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: custom_partitioner: Traceback (most recent call last):
  File "/opt/workspace/maxtext_workspace/context/maxtext/MaxText/train.py", line 757, in <module>
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
  File "/opt/workspace/maxtext_workspace/context/maxtext/MaxText/train.py", line 753, in main
  File "/opt/workspace/maxtext_workspace/context/maxtext/MaxText/train.py", line 649, in train_loop
  File "/opt/jax/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/opt/jax/jax/_src/pjit.py", line 337, in cache_miss
  File "/opt/jax/jax/_src/pjit.py", line 187, in _python_pjit_helper
  File "/opt/jax/jax/_src/core.py", line 2820, in bind
  File "/opt/jax/jax/_src/core.py", line 442, in bind_with_trace
  File "/opt/jax/jax/_src/core.py", line 955, in process_primitive
  File "/opt/jax/jax/_src/pjit.py", line 1736, in _pjit_call_impl
  File "/opt/jax/jax/_src/pjit.py", line 1712, in call_impl_cache_miss
  File "/opt/jax/jax/_src/pjit.py", line 1642, in _pjit_call_impl_python
  File "/opt/jax/jax/_src/interpreters/pxla.py", line 2346, in compile
  File "/opt/jax/jax/_src/interpreters/pxla.py", line 2855, in from_hlo
  File "/opt/jax/jax/_src/interpreters/pxla.py", line 2667, in _cached_compilation
  File "/opt/jax/jax/_src/compiler.py", line 434, in compile_or_get_cached
  File "/opt/jax/jax/_src/compiler.py", line 662, in _compile_and_write_cache
  File "/opt/jax/jax/_src/profiler.py", line 333, in wrapper
  File "/opt/jax/jax/_src/compiler.py", line 267, in backend_compile
  File "/opt/jax/jax/_src/custom_partitioning.py", line 155, in _custom_partitioning_partition
  File "/opt/transformer-engine/transformer_engine/jax/cpp_extensions/attention.py", line 1202, in partition
  File "/opt/transformer-engine/transformer_engine/jax/cpp_extensions/attention.py", line 1046, in check_supported
ValueError: Context parallel fused attention only supports masking types:  NVTE_Mask_Type.NVTE_NO_MASK,NVTE_Mask_Type.NVTE_CAUSAL_MASK got: NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
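
For illustration, a minimal sketch of the kind of lower-level validation the traceback points at. The function name and signature here are simplified assumptions; the real check lives inside a partitioning helper in transformer_engine/jax/cpp_extensions/attention.py and is driven by its config object.

from transformer_engine.jax.attention import AttnMaskType


def check_cp_mask_supported(attn_mask_type, cp_size):
    """Sketch of the lower-level CP mask validation referenced above."""
    allowed = (AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK)
    if cp_size > 1 and attn_mask_type not in allowed:
        raise ValueError(
            "Context parallel fused attention only supports masking types: "
            f"{allowed}, got: {attn_mask_type}"
        )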

Collaborator:

Okay, I have a proposal, although it might not be the cleanest. We use this method in our unit tests to know whether a configuration is supported before running, i.e. to skip invalid configs we know will fail.

We could add the is_context_parallel argument back, but as an Optional[bool]: if it is passed as None we don't attempt the more specific check. That way we can still query the full support at the top level.
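
A minimal sketch of that Optional[bool] idea (hypothetical function name and reduced argument list; make_helper stands in for the existing per-mask helper factory in attention.py; this is not what was merged):

from typing import Optional

from transformer_engine.jax.attention import AttnMaskType


def kernel_available_cp_aware(make_helper, attn_mask_type,
                              is_context_parallel: Optional[bool] = None) -> bool:
    """Sketch: None means 'unknown', so the stricter CP check is skipped."""
    if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
        return False
    if is_context_parallel and attn_mask_type == AttnMaskType.CAUSAL_MASK:
        # CP additionally needs the bottom-right-aligned causal mask.
        if not make_helper(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK).is_fused_attn_kernel_available():
            return False
    return True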

Collaborator:

Can we infer is_context_parallel from cp_axis? I noticed there is similar logic here. If so, we can pass is_context_parallel here in the DotProductAttention module without adding an additional argument.

Collaborator:

The challenge @kocchop ran into is that this check only works once the code is being transformed by JAX jit. He found that in MaxText the axis information is not available at the DPA API level, and thus the implicit axis check fails.

Collaborator (zlsh80826), Oct 30, 2024:

Got it. Then how about this:
We keep is_fused_attn_kernel_available context-parallelism-agnostic, meaning no is_context_parallel argument, but we call is_fused_attn_kernel_available again in _FusedAttnCPWithAllGatherHelper.check_support if attn_mask_type == AttnMaskType.CAUSAL_MASK.
Unlike other configs, where we can still fall back to the unfused attention, if we don't have the fused attention kernels with CP we can just raise a ValueError.
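
As a rough sketch of that proposal (standalone function with assumed names; backend_available_for stands in for a closure over is_fused_attn_kernel_available with the run's actual dtypes and shapes):

from transformer_engine.jax.attention import AttnMaskType


def check_cp_causal_support(backend_available_for, attn_mask_type):
    """Sketch of the extra check proposed for _FusedAttnCPWithAllGatherHelper.check_support."""
    if attn_mask_type == AttnMaskType.CAUSAL_MASK:
        if not backend_available_for(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK):
            raise ValueError(
                "Context parallel fused attention with a causal mask requires "
                "backend support for the bottom-right causal mask."
            )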

Collaborator:

That works and is maybe the best compromise here. The one place we then need to update is the unit test, which needs to skip unsupported configs; we currently rely on is_fused_attn_kernel_available for this check.

We can call into _FusedAttnCPWithAllGatherHelper.check_support from the unit test, as you suggest, to properly skip as needed.
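
A small sketch of that test-side pattern. It assumes a constructed helper object exposing the check_supported method seen in the traceback above; the helper's constructor arguments are not shown in this PR.

import pytest


def skip_if_cp_config_unsupported(helper):
    try:
        helper.check_supported()
    except ValueError as err:
        pytest.skip(f"Unsupported CP fused-attention config: {err}")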

    if is_context_parallel and attn_mask_type == AttnMaskType.CAUSAL_MASK:
        if not make_helper(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK).is_fused_attn_kernel_available():
            return False

    return True


15 changes: 15 additions & 0 deletions transformer_engine/jax/flax/transformer.py
@@ -262,6 +262,8 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = False
window_size: Optional[Tuple[int, int]] = None
context_parallel_causal_load_balanced: bool = False
context_parallel_axis: str = ""

@nn.compact
def __call__(
@@ -308,6 +310,8 @@ def __call__(
dropout_probability=self.attention_dropout,
is_training=not deterministic,
window_size=self.window_size,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
)
elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
"""kvpacked format, treat
@@ -331,6 +335,8 @@ def __call__(
dropout_probability=self.attention_dropout,
is_training=not deterministic,
window_size=self.window_size,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
)
elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
if self.transpose_batch_sequence:
@@ -349,6 +355,8 @@ def __call__(
dropout_probability=self.attention_dropout,
is_training=not deterministic,
window_size=self.window_size,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
)
else:
raise ValueError(f"Unsupported {self.qkv_layout=}.")
@@ -463,6 +471,9 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
window_size: Optional[Tuple[int, int]], default = None
Sliding window size. The default value is no sliding window.
context_parallel_causal_load_balanced (bool):
Indicates whether the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.

Optimization parameters
-----------------------
@@ -483,6 +494,8 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
window_size: Optional[Tuple[int, int]] = None
context_parallel_causal_load_balanced: bool = False
context_parallel_axis: str = ""
zlsh80826 marked this conversation as resolved.

@nn.compact
def __call__(
@@ -614,6 +627,8 @@ def __call__(
transpose_batch_sequence=self.transpose_batch_sequence,
qkv_layout=qkv_layout,
window_size=self.window_size,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
)(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)

return x
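
For reference, a hedged usage sketch of the two newly exposed options on the flax DotProductAttention module. The mesh axis name "cp" and the head_dim/num_attention_heads values are illustrative assumptions; only the two context_parallel_* arguments are what this PR adds.

from transformer_engine.jax.flax import DotProductAttention

attn = DotProductAttention(
    head_dim=128,
    num_attention_heads=16,
    context_parallel_causal_load_balanced=True,  # inputs pre-reordered for CP load balancing
    context_parallel_axis="cp",  # name of the mesh axis carrying the context-parallel dimension
)

The axis name is expected to match a context-parallel axis of the device mesh the module is run under.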