[JAX] Fix failure on pattern matching of FP8 GEMM when enabling FSDP. #547

Merged: 8 commits merged into NVIDIA:main on Jan 12, 2024

Conversation

@mingxu1067 (Collaborator)

  • Add a custom call, cast.
  • Replace cast_and_transpose with cast for the kernel of layernorm_fp8_dot and for kernel_1 of layernorm_geglu_fp8_mlp, letting XLA handle the transpose and avoiding an unnecessary copy that breaks FP8 GEMM pattern matching.
  • Replace cast_and_transpose with a native XLA cast for the x and kernel of fp8_dot and for kernel_2 of layernorm_geglu_fp8_mlp, again letting XLA handle the transpose so no unnecessary copy breaks FP8 GEMM pattern matching (see the sketch below).
  • Fix a bug when enabling layernorm_geglu_fp8_mlp in flax.LayerNormMLP.
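
For reference, here is a minimal JAX sketch of the dequantize-then-dot shape that the changes above try to preserve for XLA's FP8 GEMM rewriter. The function name and the exact rewrite pattern are illustrative assumptions, not the TE implementation; the idea is simply that the FP8 operands are upcast, rescaled, and fed straight into the dot, with no copy or fused custom call in between.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical sketch (not the TE kernels): XLA can rewrite an upcast + rescale
# of FP8 operands feeding a dot into a single FP8 GEMM, but only if no stray
# copy sits between the converts and the dot; the unnecessary copy mentioned
# above is what breaks that match.
def fp8_dot_sketch(x_fp8, kernel_fp8, x_scale_inv, kernel_scale_inv):
    # scale_inv values are assumed to be Python float scalars, so the
    # dequantized operands stay in bfloat16 under JAX's weak-type promotion.
    x = x_fp8.astype(jnp.bfloat16) * x_scale_inv
    k = kernel_fp8.astype(jnp.bfloat16) * kernel_scale_inv
    # Contract x's last dim with the kernel's first dim; no explicit transpose
    # or copy is inserted, leaving XLA free to fuse this into one FP8 GEMM.
    return lax.dot_general(x, k, dimension_numbers=(((1,), (0,)), ((), ())))
```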

@mingxu1067 self-assigned this Dec 1, 2023
@mingxu1067 force-pushed the mingh/fix_failure_of_xla_fp8_with_fsdp branch from a63ec56 to 591e21c on December 1, 2023 06:03
@kaixih (Contributor) left a comment

Basically, it seems cast_transpose is replaced with either the custom cast or the native quantize. The cast and the quantize seem functionally the same. Do we have a rule or guideline on when to use which?

Review thread on transformer_engine/jax/layernorm.py (outdated, resolved)
@nouiz (Collaborator) commented Dec 1, 2023

@denera to review.
@mingxu1067 Can you extend all the current FP8 sharding tests to verify that the GEMM runs in FP8?
This will cover the failure you fixed, make sure it doesn't regress, and make sure no equivalent issue appears elsewhere.

@mingxu1067 (Collaborator, Author)

> Basically, it seems cast_transpose is replaced with either the custom cast or the native quantize. The cast and the quantize seem functionally the same. Do we have a rule or guideline on when to use which?

We apply the native quantize and transpose when a tensor is sharded along its columns, for example a tensor of shape (M, N) that is sharded along N. In this case, the native quantize and transpose give XLA more flexibility to schedule the all-gather and the transpose, avoiding the unnecessary copy.

However, a few cases in the backward pass, like x of fp8_dot, fall outside the above rule yet still introduce an unexpected copy. For these cases we apply the native quantize and transpose as well.
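
As a rough illustration of that rule (a hypothetical sketch, not the TE code; the mesh, shapes, and function name are made up, and jnp.float8_e4m3fn is assumed to be available in the installed JAX), quantizing with plain jnp ops and transposing with jnp.transpose keeps both steps as ordinary XLA HLO, so the compiler can schedule them around the all-gather instead of materializing an extra copy:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical native quantize + transpose for a (M, N) tensor sharded along N.
def native_quantize_transpose(x, scale, dtype=jnp.float8_e4m3fn):
    fmax = float(jnp.finfo(dtype).max)
    q = jnp.clip(x * scale, -fmax, fmax).astype(dtype)  # saturating FP8 cast
    # jnp.transpose lowers to a plain XLA transpose, which the compiler can
    # move across the FSDP all-gather, unlike a fused cast_transpose call.
    return jnp.transpose(q)

# Example: a kernel sharded along its column axis, as described above.
mesh = Mesh(np.array(jax.devices()), ("fsdp",))
kernel = jax.device_put(
    jnp.ones((128, 256), jnp.bfloat16),
    NamedSharding(mesh, P(None, "fsdp")),  # (M, N) sharded along N
)
kernel_t_fp8 = jax.jit(native_quantize_transpose)(kernel, 1.0)
```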

We also aim to replace all of the custom cast_fp8 and cast_transpose calls with the native implementation in the future. We will schedule time to evaluate the performance gap and the effort required.

@mingxu1067 (Collaborator, Author)

> @denera to review. @mingxu1067 Can you extend all the current FP8 sharding tests to verify that the GEMM runs in FP8? This will cover the failure you fixed, make sure it doesn't regress, and make sure no equivalent issue appears elsewhere.

Currently, the unit tests do not include FSDP-related tests, so they cannot catch this kind of failure. Extending the UTs to cover a wider range of cases will take some time. We had an internal discussion and will add this to the TODO list.

@denera (Collaborator) left a comment

LGTM!

@mingxu1067 (Collaborator, Author)

/te-ci jax

@nouiz (Collaborator) commented Dec 6, 2023

The PR description says that this fixes a bug.
We need a test that makes sure we don't regress.

@mingxu1067 (Collaborator, Author)

/te-ci jax

@mingxu1067 force-pushed the mingh/fix_failure_of_xla_fp8_with_fsdp branch from 39ca75d to 60a0eab on December 14, 2023 07:28
@mingxu1067 force-pushed the mingh/fix_failure_of_xla_fp8_with_fsdp branch from 60a0eab to 2ce5724 on December 15, 2023 01:39
@mingxu1067 (Collaborator, Author)

/te-ci jax

@denera merged commit 2ae121d into NVIDIA:main on Jan 12, 2024
16 checks passed