[JAX] Fix failure on pattern matching of FP8 GEMM when enabling FSDP. #547
a63ec56 to 591e21c
Basically, it seems that `cast_transpose` is replaced with either the custom `cast` or the native `quantize`. I feel that `cast` and `quantize` are functionally the same. Do we have a rule or guide on when to use which?
@denera to review.
We apply the native […]. However, there are a few cases in backward, like […]. We also target replacing all custom […].
Currently, the UTs do not include FSDP-related tests, so they cannot catch this kind of failure. Extending the UTs to cover a wide range of cases will take some time. We had an internal discussion and will add this to the TODO list.
LGTM!
/te-ci jax
The PR description says that this fixes a bug.
/te-ci jax
39ca75d to 60a0eab
Signed-off-by: Ming Huang <[email protected]>
60a0eab to 2ce5724
/te-ci jax
- […] `cast`.
- Replace `cast_and_transpose` with `cast` for the kernel of `layernorm_fp8_dot` and the kernel_1 of `layernrom_geglu_fp8_mlp`, so that XLA handles the `transpose`. This avoids an unnecessary `copy` that would break FP8 GEMM pattern matching.
- Replace `cast_and_transpose` with the native XLA cast for the x and kernel of `fp8_dot` and the kernel_2 of `layernrom_geglu_fp8_mlp`, so that XLA handles the `transpose`. This avoids an unnecessary `copy` that would break FP8 GEMM pattern matching.
- […] `layernrom_geglu_fp8_mlp` in `flax.LayernormMLP`.
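As a rough illustration of the "native XLA cast" idea described above for `fp8_dot`, the sketch below quantizes both operands to FP8 with plain JAX ops, dequantizes them, and leaves the matmul (and any transpose) to XLA. It is not the code from this PR: `native_quantize`, `native_dequantize`, and `fp8_dot_sketch` are hypothetical names, the scales are passed in by hand rather than taken from FP8 metadata, and whether XLA actually rewrites this pattern into an FP8 GEMM depends on the backend and XLA version.

```python
import jax
import jax.numpy as jnp

# Assumption: E4M3 is used for both operands; 448 is its largest finite value.
FP8_DTYPE = jnp.float8_e4m3fn
FP8_MAX = 448.0


def native_quantize(x, scale):
    """Scale, saturate, and cast to FP8 using only native XLA ops (hypothetical helper)."""
    scaled = x.astype(jnp.float32) * scale
    return jnp.clip(scaled, -FP8_MAX, FP8_MAX).astype(FP8_DTYPE)


def native_dequantize(x_fp8, scale, dtype=jnp.float32):
    """Cast back to a wider dtype and undo the scaling (hypothetical helper)."""
    return x_fp8.astype(dtype) / scale


@jax.jit
def fp8_dot_sketch(x, kernel, x_scale, kernel_scale):
    # Quantize with native ops instead of a fused cast-and-transpose custom call.
    x_q = native_quantize(x, x_scale)
    k_q = native_quantize(kernel, kernel_scale)
    # Dequantize and multiply; no explicit transposed copy is materialized here,
    # so the compiler is free to choose operand layouts for the GEMM.
    x_dq = native_dequantize(x_q, x_scale)
    k_dq = native_dequantize(k_q, kernel_scale)
    return jnp.dot(x_dq, k_dq)


x = jnp.ones((4, 8), dtype=jnp.float32)
kernel = jnp.full((8, 16), 0.5, dtype=jnp.float32)
out = fp8_dot_sketch(x, kernel, 1.0, 1.0)
```

The values `1.0` and `0.5` are exactly representable in E4M3, so this toy input round-trips through FP8 without rounding error. Real activations and weights generally would not.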