[JAX] Regression tests for custom ops with jax.experimental.custom_partitioning #471
|
In summary, could we have the following changes?
|
|
I'll expand the tests to check correctness against JAX-native Unless it's getting deprecated in JAX itself, it may be useful to keep it around at least internally so we can use it as a reference to check custom partitioning communication patterns. Otherwise I don't know what to use as reference in its place. |
I think we need to hardcode the expected collectives we want. |
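One way to hardcode expected collectives is to count collective opcodes in the compiled HLO text and compare them against a fixed table. The sketch below is only an illustration of that idea, not the helper used in this PR: `count_collectives`, `assert_collective_counts`, and `SAMPLE_HLO` are hypothetical names, and the sample text is a hand-written stand-in for real compiler output.

```python
import re
from collections import Counter

# HLO opcodes for the collectives a sharded op can introduce.
COLLECTIVE_OPS = (
    "all-reduce",
    "all-gather",
    "reduce-scatter",
    "all-to-all",
    "collective-permute",
)

# Match an opcode only where it is applied: preceded by whitespace and
# followed by "(". Instruction names such as "%all-reduce.1" are skipped.
# Async collectives appear as "-start"/"-done" pairs; only "-start" is counted.
_COLLECTIVE_RE = re.compile(
    r"(?<=\s)(" + "|".join(re.escape(op) for op in COLLECTIVE_OPS) + r")(?:-start)?\("
)


def count_collectives(hlo_text):
    """Return a Counter mapping collective opcode -> number of uses."""
    return Counter(m.group(1) for m in _COLLECTIVE_RE.finditer(hlo_text))


def assert_collective_counts(hlo_text, expected):
    """Fail if any known collective differs from its hardcoded count.

    Opcodes missing from `expected` are required to appear zero times,
    so an unexpected extra collective also fails the check.
    """
    actual = count_collectives(hlo_text)
    for op in COLLECTIVE_OPS:
        want = expected.get(op, 0)
        assert actual[op] == want, f"{op}: expected {want}, found {actual[op]}"


# Hand-written stand-in for compiled HLO text (assumption, not real output).
SAMPLE_HLO = """
ENTRY %main (x: f32[8]) -> f32[16] {
  %x = f32[8]{0} parameter(0)
  %all-reduce.1 = f32[8]{0} all-reduce(f32[8]{0} %x), replica_groups={}, to_apply=%add
  ROOT %all-gather.2 = f32[16]{0} all-gather(f32[8]{0} %all-reduce.1), dimensions={0}
}
"""

assert_collective_counts(SAMPLE_HLO, {"all-reduce": 1, "all-gather": 1})
```

In an actual test the text would presumably come from something like `jax.jit(fn).lower(*args).compile().as_text()` on a sharded input, with one hardcoded table per op and mesh configuration.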
7b6ea00 to 5d13090
|
@mingxu1067 can you also review? |
|
@mingxu1067 Quick note on xmap tests: I needed to run these ops and mesh configs with xmap shardings just to get the collective counts that are now hardcoded for the custom_partitioning tests. I kept the code for that in the PR for now as regression tests for xmap in case there's a brief phase where we keep xmap alive with a deprecation warning instead of removing it immediately. |
140e2fe to 8678bcf
|
/te-ci jax |
…rtitioning. Xmap-based sharding tests are functional, while custom_partitioning tests are awaiting the custom ops migration to be merged in.

Coverage:
- layernorm: fwd/grad, zero_centered_gamma, DP, TP_COL, DP_TP_COL
- rmsnorm: fwd/grad, DP, TP_COL, DP_TP_COL
- softmax: fwd/grad, SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED, DP, TP_COL, TP_ROW, DP_TP_COL, DP_TP_ROW
- self_fused_attn: fwd/grad, NO_BIAS, PRE_SCALE_BIAS, POST_SCALE_BIAS, NO_MASK, CAUSAL_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL
- cross_fused_attn: fwd/grad, NO_BIAS, NO_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
1e2d5f8 to c832ea1
|
/te-ci jax |
…r testing custom_partitioning with 4 GPUs. Signed-off-by: Alp Dener <adener@nvidia.com>
|
/te-ci jax |
Signed-off-by: Alp Dener <adener@nvidia.com>
Included ops:
- layernorm() inherited from Donglin Yang's fork (cc: @Young768) + rmsnorm type
- softmax(): SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED
- self_fused_attn(): NO_BIAS, PRE_SCALE_BIAS, POST_SCALE_BIAS, NO_MASK, CAUSAL_MASK, PADDING_MASK
- cross_fused_attn(): NO_BIAS, NO_MASK, PADDING_MASK

Tested shardings: