-
Notifications
You must be signed in to change notification settings - Fork 337
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[JAX] Expose cp params to jax DPA api #1292
[JAX] Expose cp params to jax DPA api #1292
Conversation
@@ -215,11 +214,6 @@ def make_helper(attn_mask_type): | |||
if not make_helper(attn_mask_type).is_fused_attn_kernel_available(): | |||
return False | |||
|
|||
# For context parallel need to check additional masking types |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So there is no way to pass None
from maxtext when CP is not used for the axis? I think we should keep this check here so that we can avoid exceptions on cuDNN when we don't have appropriate support.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @mgoldfarb-nvidia , we can definitely pass a boolean from jax side. However, I decided to remove this because:
- Wanted to make as minimal changes as possible to the high level
DotProductAttention
api - The same check is performed here. It essentially provides meaningful error msg and its higher than cudnn level:
...
...
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/opt/workspace/maxtext_workspace/context/maxtext/MaxText/train.py", line 757, in <module>
app.run(main)
File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/opt/workspace/maxtext_workspace/context/maxtext/MaxText/train.py", line 753, in main
train_loop(config)
File "/opt/workspace/maxtext_workspace/context/maxtext/MaxText/train.py", line 649, in train_loop
state, metrics = p_train_step(state, example_batch, nextrng)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: custom_partitioner: Traceback (most recent call last):
File "/opt/workspace/maxtext_workspace/context/maxtext/MaxText/train.py", line 757, in <module>
File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
File "/opt/workspace/maxtext_workspace/context/maxtext/MaxText/train.py", line 753, in main
File "/opt/workspace/maxtext_workspace/context/maxtext/MaxText/train.py", line 649, in train_loop
File "/opt/jax/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
File "/opt/jax/jax/_src/pjit.py", line 337, in cache_miss
File "/opt/jax/jax/_src/pjit.py", line 187, in _python_pjit_helper
File "/opt/jax/jax/_src/core.py", line 2820, in bind
File "/opt/jax/jax/_src/core.py", line 442, in bind_with_trace
File "/opt/jax/jax/_src/core.py", line 955, in process_primitive
File "/opt/jax/jax/_src/pjit.py", line 1736, in _pjit_call_impl
File "/opt/jax/jax/_src/pjit.py", line 1712, in call_impl_cache_miss
File "/opt/jax/jax/_src/pjit.py", line 1642, in _pjit_call_impl_python
File "/opt/jax/jax/_src/interpreters/pxla.py", line 2346, in compile
File "/opt/jax/jax/_src/interpreters/pxla.py", line 2855, in from_hlo
File "/opt/jax/jax/_src/interpreters/pxla.py", line 2667, in _cached_compilation
File "/opt/jax/jax/_src/compiler.py", line 434, in compile_or_get_cached
File "/opt/jax/jax/_src/compiler.py", line 662, in _compile_and_write_cache
File "/opt/jax/jax/_src/profiler.py", line 333, in wrapper
File "/opt/jax/jax/_src/compiler.py", line 267, in backend_compile
File "/opt/jax/jax/_src/custom_partitioning.py", line 155, in _custom_partitioning_partition
File "/opt/transformer-engine/transformer_engine/jax/cpp_extensions/attention.py", line 1202, in partition
File "/opt/transformer-engine/transformer_engine/jax/cpp_extensions/attention.py", line 1046, in check_supported
ValueError: Context parallel fused attention only supports masking types: NVTE_Mask_Type.NVTE_NO_MASK,NVTE_Mask_Type.NVTE_CAUSAL_MASK got: NVTE_Mask_Type.NVTE_PADDING_CAUSAL_MASK
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay- I have a proposal although it might not be the cleanest. We use this method in our unit tests to also know if the configuration is supported before running i.e. to skip invalid configs we know to be failing.
We could add the is_context_parallel
argument back but as an Optional[bool]
and if its passed as None
we don't attempt to do the more specific check. That way we can still query the full support at the top level
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The challenge @kocchop ran into is that this check only works once the code is being transformed by Jax jit. He found that in MaxText the axis information is not available at the DPA api level and thus the implicit axis check fails.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it. Then how about this:
We keep is_fused_attn_kernel_available
context_parallism-agnostic, means no is_context_parallel
argument. But we call is_fused_attn_kernel_available
again in _FusedAttnCPWithAllGatherHelper.check_support
if the attn_mask_type == AttnMaskType.CAUSAL_MASK
.
Unlike other configs, we can still fall back to the unfused attn. If we don't have the fused attn kernels with CP, then we can just raise a ValueError for that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That works and maybe the best compromise here. The one place we need to update then is the unit test which needs to skip configs not supported. We currently rely on is_fused_attn_kernel_available
to do this check.
We can call into _FusedAttnCPWithAllGatherHelper.check_support
as you suggest here from the unit test to also properly skip as needed.
@@ -215,11 +214,6 @@ def make_helper(attn_mask_type): | |||
if not make_helper(attn_mask_type).is_fused_attn_kernel_available(): | |||
return False | |||
|
|||
# For context parallel need to check additional masking types |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Signed-off-by: Md Fahim Faysal Khan <[email protected]>
Signed-off-by: Michael Goldfarb <[email protected]>
b0c5c06
to
0b302d4
Compare
/te-ci jax |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
LGTM, I will approve it once all CI passed. |
/te-ci jax L1 |
Exposed context parallel params to DPA api Signed-off-by: Md Fahim Faysal Khan <[email protected]> Signed-off-by: Michael Goldfarb <[email protected]> --------- Signed-off-by: Md Fahim Faysal Khan <[email protected]> Signed-off-by: Michael Goldfarb <[email protected]> Co-authored-by: Michael Goldfarb <[email protected]>
Description
Surface the context parallelism parameters to the JAX
DotProductAttention
APIFixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
is_context_parallel
arg from `is_fused_attn_available().is_fused_attn_available()
. Thecheck_supported()
inside_FusedAttnCPWithAllGatherHelper
essentially performs the checktest_distributed_fused_attn
accordingly by removing the last arg related tois_context_parallel
flag fromis _fused_attn_available()
api callsChecklist: