-
Notifications
You must be signed in to change notification settings - Fork 337
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[JAX] Expose cp params to jax DPA api #1292
Merged
mgoldfarb-nvidia
merged 4 commits into
NVIDIA:main
from
kocchop:faysal/expose-cp-to-jax-dpa
Nov 4, 2024
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
535af66
exposed cp params to DPA api
kocchop 0b302d4
Clean up checks in unit test.
mgoldfarb-nvidia cac031b
Merge branch 'main' into faysal/expose-cp-to-jax-dpa
mgoldfarb-nvidia 6a12787
Merge branch 'main' into faysal/expose-cp-to-jax-dpa
mgoldfarb-nvidia File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So there is no way to pass
None
from maxtext when CP is not used for the axis? I think we should keep this check here so that we can avoid exceptions on cuDNN when we don't have appropriate support.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @mgoldfarb-nvidia , we can definitely pass a boolean from jax side. However, I decided to remove this because:
DotProductAttention
apiThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay- I have a proposal although it might not be the cleanest. We use this method in our unit tests to also know if the configuration is supported before running i.e. to skip invalid configs we know to be failing.
We could add the
is_context_parallel
argument back but as anOptional[bool]
and if its passed asNone
we don't attempt to do the more specific check. That way we can still query the full support at the top levelThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we infer
is_context_parallel
fromcp_axis
? I noticed there is a similar logic here. If so, we can passis_context_parallel
here in theDotProductAttention
module without adding an additional argument.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The challenge @kocchop ran into is that this check only works once the code is being transformed by Jax jit. He found that in MaxText the axis information is not available at the DPA api level and thus the implicit axis check fails.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it. Then how about this:
We keep
is_fused_attn_kernel_available
context_parallism-agnostic, means nois_context_parallel
argument. But we callis_fused_attn_kernel_available
again in_FusedAttnCPWithAllGatherHelper.check_support
if theattn_mask_type == AttnMaskType.CAUSAL_MASK
.Unlike other configs, we can still fall back to the unfused attn. If we don't have the fused attn kernels with CP, then we can just raise a ValueError for that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That works and maybe the best compromise here. The one place we need to update then is the unit test which needs to skip configs not supported. We currently rely on
is_fused_attn_kernel_available
to do this check.We can call into
_FusedAttnCPWithAllGatherHelper.check_support
as you suggest here from the unit test to also properly skip as needed.