[JAX] Regression tests for custom ops with jax.experimental.custom_partitioning #471
In summary, could we have the following changes?
I'll expand the tests to check correctness against JAX-native Unless it's getting deprecated in JAX itself, it may be useful to keep it around at least internally so we can use it as a reference to check custom partitioning communication patterns. Otherwise I don't know what to use as reference in its place.
I think we need to hardcode the expected collectives we want.
Force-pushed from 7b6ea00 to 5d13090.
@mingxu1067 can you also review?
@mingxu1067 Quick note on xmap tests: I needed to run these ops and mesh configs with xmap shardings just to get the collective counts that are now hardcoded for the custom_partitioning tests. I kept the code for that in the PR for now as regression tests for xmap, in case there's a brief phase where we keep xmap alive with a deprecation warning instead of removing it immediately.
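For readers following the discussion about hardcoded collective counts, here is a minimal sketch of how such a check could look. It is not the code in this PR: the helper names (`count_collectives`, `check_collectives`), the sample HLO text, and the expected counts are all hypothetical, and the sketch assumes the compiled module is available as plain HLO text.

```python
import re
from collections import Counter

# Collective opcodes as they are spelled in XLA HLO text.
COLLECTIVE_OPS = (
    "all-reduce",
    "all-gather",
    "reduce-scatter",
    "all-to-all",
    "collective-permute",
)

# Match an opcode that directly precedes its operand list. Async collectives
# appear as "<op>-start(" / "<op>-done(" pairs, so only "-start" is counted.
_PATTERN = re.compile(r"\s(" + "|".join(COLLECTIVE_OPS) + r")(?:-start)?\(")


def count_collectives(hlo_text: str) -> Counter:
    """Count collective instructions in a dump of HLO text."""
    return Counter(m.group(1) for m in _PATTERN.finditer(hlo_text))


def check_collectives(hlo_text: str, expected: dict) -> None:
    """Fail if the collectives found differ from the hardcoded reference."""
    found = dict(count_collectives(hlo_text))
    assert found == expected, f"expected {expected}, found {found}"


# Hypothetical HLO snippet, written by hand purely for illustration.
SAMPLE_HLO = """
ENTRY %main (x: f32[8,16]) -> f32[16,16] {
  %x = f32[8,16]{1,0} parameter(0)
  %all-reduce.1 = f32[8,16]{1,0} all-reduce(f32[8,16]{1,0} %x), replica_groups={}, to_apply=%add
  %all-gather.2 = f32[16,16]{1,0} all-gather(f32[8,16]{1,0} %all-reduce.1), dimensions={0}
  ROOT %all-reduce.3 = f32[16,16]{1,0} all-reduce(f32[16,16]{1,0} %all-gather.2), replica_groups={}, to_apply=%add
}
"""

# Hypothetical hardcoded reference for the snippet above.
EXPECTED = {"all-reduce": 2, "all-gather": 1}

check_collectives(SAMPLE_HLO, EXPECTED)
```

In a real test the text would presumably come from the compiled function, for example `jax.jit(fn).lower(*args).compile().as_text()`, with the expected counts recorded once from a trusted reference run, as described in the comment above.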
Force-pushed from 140e2fe to 8678bcf.
/te-ci jax
…rtitioning. xmap-based sharding tests are functional, while custom_partitioning tests are awaiting the custom ops migration to be merged in.
Coverage:
- layernorm: fwd/grad, zero_centered_gamma, DP, TP_COL, DP_TP_COL
- rmsnorm: fwd/grad, DP, TP_COL, DP_TP_COL
- softmax: fwd/grad, SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED, DP, TP_COL, TP_ROW, DP_TP_COL, DP_TP_ROW
- self_fused_attn: fwd/grad, NO_BIAS, PRE_SCALE_BIAS, POST_SCALE_BIAS, NO_MASK, CAUSAL_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL
- cross_fused_attn: fwd/grad, NO_BIAS, NO_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL
Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Alp Dener <[email protected]>
Force-pushed from 1e2d5f8 to c832ea1.
/te-ci jax
…r testing custom_partitioning with 4 GPUs.
Signed-off-by: Alp Dener <[email protected]>
/te-ci jax
Signed-off-by: Alp Dener <[email protected]>
…rtitioning (NVIDIA#471)
[JAX] Regression tests for custom ops sharding with both xmap and custom_partitioning.
Coverage:
- layernorm: fwd/grad, zero_centered_gamma, DP, TP_COL, DP_TP_COL
- rmsnorm: fwd/grad, DP, TP_COL, DP_TP_COL
- softmax: fwd/grad, SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED, DP, TP_COL, TP_ROW, DP_TP_COL, DP_TP_ROW
- self_fused_attn: fwd/grad, NO_BIAS, PRE_SCALE_BIAS, POST_SCALE_BIAS, NO_MASK, CAUSAL_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL
- cross_fused_attn: fwd/grad, NO_BIAS, NO_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL
Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
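The coverage list in the commit message can be read as a test matrix. The sketch below shows one way such a matrix might be expanded into individual test cases. It is an illustration only: the `COVERAGE` layout, the axis names, the `expand_cases` helper, the boolean encoding of `zero_centered_gamma`, and the full cross product are all assumptions, and the actual tests may exercise only a subset of these combinations.

```python
from itertools import product

MODES = ["fwd", "grad"]
BASE_SHARDINGS = ["DP", "TP_COL", "DP_TP_COL"]

# Option values are copied from the commit message; the dict layout and the
# axis names ("mode", "softmax_type", "bias", "mask", "sharding") are hypothetical.
COVERAGE = {
    "layernorm": {
        "mode": MODES,
        "zero_centered_gamma": [False, True],  # assumed to be a boolean switch
        "sharding": BASE_SHARDINGS,
    },
    "rmsnorm": {"mode": MODES, "sharding": BASE_SHARDINGS},
    "softmax": {
        "mode": MODES,
        "softmax_type": ["SCALED", "SCALED_MASKED", "SCALED_UPPER_TRIANG_MASKED"],
        "sharding": ["DP", "TP_COL", "TP_ROW", "DP_TP_COL", "DP_TP_ROW"],
    },
    "self_fused_attn": {
        "mode": MODES,
        "bias": ["NO_BIAS", "PRE_SCALE_BIAS", "POST_SCALE_BIAS"],
        "mask": ["NO_MASK", "CAUSAL_MASK", "PADDING_MASK"],
        "sharding": BASE_SHARDINGS,
    },
    "cross_fused_attn": {
        "mode": MODES,
        "bias": ["NO_BIAS"],
        "mask": ["NO_MASK", "PADDING_MASK"],
        "sharding": BASE_SHARDINGS,
    },
}


def expand_cases(coverage):
    """Expand each op's option axes into (op, options) pairs via a cross product."""
    cases = []
    for op, axes in coverage.items():
        names = list(axes)
        for values in product(*(axes[name] for name in names)):
            cases.append((op, dict(zip(names, values))))
    return cases


CASES = expand_cases(COVERAGE)
```

A list like `CASES` could then be fed to `pytest.mark.parametrize`, so that each op, option, and sharding combination is reported as a separate test.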
Included ops:
- layernorm(): inherited from Donglin Yang's fork (cc: @Young768) + rmsnorm type
- softmax(): SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED
- self_fused_attn(): NO_BIAS, PRE_SCALE_BIAS, POST_SCALE_BIAS, NO_MASK, CAUSAL_MASK, PADDING_MASK
- cross_fused_attn(): NO_BIAS, NO_MASK, PADDING_MASK
Tested shardings: