
[JAX] Expose cp params to jax DPA api #1292

Merged — 4 commits merged into NVIDIA:main from faysal/expose-cp-to-jax-dpa on Nov 4, 2024

Conversation

@kocchop (Collaborator) commented Oct 27, 2024

Description

Surface the context parallelism parameters to the JAX DotProductAttention API

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Please list the changes introduced in this PR:

  • Added the two required CP configs to the high-level JAX DPA API (see the usage sketch below)
  • Removed the `is_context_parallel` arg from `is_fused_attn_available()`
  • Removed the attention-mask check for CP from `is_fused_attn_available()`; the `check_supported()` inside `_FusedAttnCPWithAllGatherHelper` already performs this check
  • Fixed `test_distributed_fused_attn` accordingly by removing the trailing `is_context_parallel` argument from `is_fused_attn_available()` calls
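
For illustration, a minimal sketch of how the two CP knobs might be passed through the high-level API. The parameter names below are assumptions for the sake of the example, not quoted from this diff:

```python
# Hypothetical usage sketch; the two context-parallel parameter names below
# are illustrative assumptions, not necessarily the ones added by this PR.
from transformer_engine.jax.flax import DotProductAttention

attn = DotProductAttention(
    head_dim=128,
    num_attention_heads=16,
    num_gqa_groups=16,
    attn_mask_type="causal",
    context_parallel_causal_load_balanced=True,  # reorder tokens to balance causal work across CP ranks
    context_parallel_axis="cp",                  # name of the mesh axis that shards the sequence
)
```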

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

```diff
@@ -215,11 +214,6 @@ def make_helper(attn_mask_type):
     if not make_helper(attn_mask_type).is_fused_attn_kernel_available():
         return False

-    # For context parallel need to check additional masking types
```

@mgoldfarb-nvidia (Collaborator): So there is no way to pass None from MaxText when CP is not used for the axis? I think we should keep this check here so that we can avoid exceptions from cuDNN when we don't have appropriate support.

@kocchop (Collaborator, Author): Hi @mgoldfarb-nvidia, we can definitely pass a boolean from the JAX side. However, I decided to remove this because:

  1. I wanted to make as minimal a change as possible to the high-level DotProductAttention API.
  2. The same check is performed here; it provides a meaningful error message and sits above the cuDNN level:
```
...
...
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/workspace/maxtext_workspace/context/maxtext/MaxText/train.py", line 757, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/opt/workspace/maxtext_workspace/context/maxtext/MaxText/train.py", line 753, in main
    train_loop(config)
  File "/opt/workspace/maxtext_workspace/context/maxtext/MaxText/train.py", line 649, in train_loop
    state, metrics = p_train_step(state, example_batch, nextrng)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: custom_partitioner: Traceback (most recent call last):
  File "/opt/workspace/maxtext_workspace/context/maxtext/MaxText/train.py", line 757, in <module>
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
  File "/opt/workspace/maxtext_workspace/context/maxtext/MaxText/train.py", line 753, in main
  File "/opt/workspace/maxtext_workspace/context/maxtext/MaxText/train.py", line 649, in train_loop
  File "/opt/jax/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/opt/jax/jax/_src/pjit.py", line 337, in cache_miss
  File "/opt/jax/jax/_src/pjit.py", line 187, in _python_pjit_helper
  File "/opt/jax/jax/_src/core.py", line 2820, in bind
  File "/opt/jax/jax/_src/core.py", line 442, in bind_with_trace
  File "/opt/jax/jax/_src/core.py", line 955, in process_primitive
  File "/opt/jax/jax/_src/pjit.py", line 1736, in _pjit_call_impl
  File "/opt/jax/jax/_src/pjit.py", line 1712, in call_impl_cache_miss
  File "/opt/jax/jax/_src/pjit.py", line 1642, in _pjit_call_impl_python
  File "/opt/jax/jax/_src/interpreters/pxla.py", line 2346, in compile
  File "/opt/jax/jax/_src/interpreters/pxla.py", line 2855, in from_hlo
  File "/opt/jax/jax/_src/interpreters/pxla.py", line 2667, in _cached_compilation
  File "/opt/jax/jax/_src/compiler.py", line 434, in compile_or_get_cached
  File "/opt/jax/jax/_src/compiler.py", line 662, in _compile_and_write_cache
  File "/opt/jax/jax/_src/profiler.py", line 333, in wrapper
  File "/opt/jax/jax/_src/compiler.py", line 267, in backend_compile
  File "/opt/jax/jax/_src/custom_partitioning.py", line 155, in _custom_partitioning_partition
  File "/opt/transformer-engine/transformer_engine/jax/cpp_extensions/attention.py", line 1202, in partition
  File "/opt/transformer-engine/transformer_engine/jax/cpp_extensions/attention.py", line 1046, in check_supported
ValueError: Context parallel fused attention only supports masking types:  NVTE_Mask_Type.NVTE_NO_MASK,NVTE_Mask_Type.NVTE_CAUSAL_MASK got: NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
```

Collaborator: Okay, I have a proposal, although it might not be the cleanest. We use this method in our unit tests to know whether a configuration is supported before running, i.e. to skip invalid configs we know to be failing.

We could add the `is_context_parallel` argument back, but as an `Optional[bool]`: if it is passed as `None`, we don't attempt the more specific check. That way we can still query the full support at the top level (a sketch follows).
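
A minimal sketch of that proposal, with stand-in names; `_base_kernel_available` and the mask-type strings are illustrative, not the real Transformer Engine internals:

```python
from typing import Optional

def _base_kernel_available(attn_mask_type: str) -> bool:
    # Stand-in for the real cuDNN capability query.
    return True

def is_fused_attn_kernel_available(
    attn_mask_type: str,
    is_context_parallel: Optional[bool] = None,
) -> bool:
    if not _base_kernel_available(attn_mask_type):
        return False
    if is_context_parallel is None:
        # Caller could not resolve the CP axis (e.g. outside jit):
        # skip the CP-specific check and report base availability only.
        return True
    if is_context_parallel:
        # CP fused attention currently supports only these mask types.
        return attn_mask_type in ("no_mask", "causal")
    return True
```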

Collaborator: Can we infer `is_context_parallel` from `cp_axis`? I noticed there is similar logic here. If so, we can pass `is_context_parallel` here in the DotProductAttention module without adding an additional argument.
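
One way that inference could look if the mesh is passed in explicitly; the helper name and signature below are assumptions for illustration, not code from this PR:

```python
from jax.sharding import Mesh

def infer_is_context_parallel(mesh: Mesh, cp_axis: str) -> bool:
    # CP is active iff the named mesh axis exists and spans more than one device.
    return dict(mesh.shape).get(cp_axis, 1) > 1
```

The catch, as the next reply explains, is that this axis information may not be in scope at the DPA API level.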

Collaborator: The challenge @kocchop ran into is that this check only works once the code is being transformed by JAX jit. He found that in MaxText the axis information is not available at the DPA API level, and thus the implicit axis check fails.

@zlsh80826 (Collaborator) commented Oct 30, 2024:

Got it. Then how about this: we keep `is_fused_attn_kernel_available` context-parallelism-agnostic, meaning no `is_context_parallel` argument, but we call `is_fused_attn_kernel_available` again in `_FusedAttnCPWithAllGatherHelper.check_supported` if `attn_mask_type == AttnMaskType.CAUSAL_MASK`. Unlike other configs, where we can still fall back to the unfused attn, if we don't have the fused-attn kernels with CP we can just raise a ValueError.
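
A rough sketch of that flow, following the names visible in the traceback above (`_FusedAttnCPWithAllGatherHelper.check_supported`); the enum values and the availability query are simplified stand-ins:

```python
from enum import Enum, auto

class AttnMaskType(Enum):
    NO_MASK = auto()
    CAUSAL_MASK = auto()
    PADDING_CAUSAL_MASK = auto()

def is_fused_attn_kernel_available(attn_mask_type: AttnMaskType) -> bool:
    # CP-agnostic stand-in for the real capability query.
    return attn_mask_type in (AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK)

class _FusedAttnCPWithAllGatherHelper:
    def __init__(self, attn_mask_type: AttnMaskType):
        self.attn_mask_type = attn_mask_type

    def check_supported(self) -> None:
        # Re-query availability here rather than in the top-level API;
        # with CP there is no unfused fallback, so fail loudly.
        if not is_fused_attn_kernel_available(self.attn_mask_type):
            raise ValueError(
                "Context parallel fused attention does not support "
                f"{self.attn_mask_type}"
            )
```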

Collaborator: That works, and is maybe the best compromise here. The one place we then need to update is the unit test, which needs to skip unsupported configs; we currently rely on `is_fused_attn_kernel_available` to do this check.

We can call into `_FusedAttnCPWithAllGatherHelper.check_supported` as you suggest here from the unit test to properly skip as needed (see the sketch below).
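
A hypothetical pytest-style skip along those lines, reusing the sketch above; the test name and parameter wiring are assumptions, not the actual test change in this PR:

```python
import pytest

def test_distributed_fused_attn_cp(attn_mask_type):
    helper = _FusedAttnCPWithAllGatherHelper(attn_mask_type)
    try:
        helper.check_supported()
    except ValueError as reason:
        # Config is known-unsupported for CP fused attention: skip, don't fail.
        pytest.skip(f"Unsupported CP config: {reason}")
    # ... run the distributed fused-attention test body ...
```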

@phu0ngng phu0ngng requested a review from zlsh80826 October 29, 2024 18:37
@denera denera requested review from denera and phu0ngng October 29, 2024 18:39
(Resolved review thread on transformer_engine/jax/flax/transformer.py.)
kocchop and others added 2 commits October 30, 2024 20:43
Signed-off-by: Md Fahim Faysal Khan <[email protected]>
Signed-off-by: Michael Goldfarb <[email protected]>
@mgoldfarb-nvidia force-pushed the faysal/expose-cp-to-jax-dpa branch from b0c5c06 to 0b302d4 on October 30, 2024 20:43
@phu0ngng phu0ngng removed their request for review October 31, 2024 15:20
@kocchop kocchop requested a review from zlsh80826 October 31, 2024 22:17
@mgoldfarb-nvidia (Collaborator): /te-ci jax

@mgoldfarb-nvidia (Collaborator) left a review: LGTM!

@zlsh80826 zlsh80826 changed the title expose cp params to jax DPA api [JAX] Expose cp params to jax DPA api Nov 1, 2024
@zlsh80826 (Collaborator): LGTM, I will approve it once all CI passes.

@phu0ngng (Collaborator) commented Nov 1, 2024: Please rebase on the main branch to include the fix introduced in #1304, to avoid unresolved failures caused by FFI, and rerun the L1 CI, @kocchop. Thanks!

@mgoldfarb-nvidia (Collaborator): /te-ci jax L1

@mgoldfarb-nvidia merged commit d725686 into NVIDIA:main on Nov 4, 2024 — 21 checks passed
huanghua1994 pushed a commit to huanghua1994/TransformerEngine that referenced this pull request Nov 4, 2024
Exposed context parallel params to DPA api

Signed-off-by: Md Fahim Faysal Khan <[email protected]>
Signed-off-by: Michael Goldfarb <[email protected]>

---------

Signed-off-by: Md Fahim Faysal Khan <[email protected]>
Signed-off-by: Michael Goldfarb <[email protected]>
Co-authored-by: Michael Goldfarb <[email protected]>