Skip to content

[JAX] Regression tests for custom ops with jax.experimental.custom_partitioning#471

Merged
denera merged 4 commits intoNVIDIA:mainfrom
denera:jax-custom-partitioning-tests
Nov 3, 2023
Merged

[JAX] Regression tests for custom ops with jax.experimental.custom_partitioning#471
denera merged 4 commits intoNVIDIA:mainfrom
denera:jax-custom-partitioning-tests

Conversation

@denera
Copy link
Copy Markdown
Collaborator

@denera denera commented Oct 11, 2023

Included ops:

  • layernorm() inherited from Donglin Yang's fork (cc: @Young768) + rmsnorm type
  • softmax() SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED
  • self_fused_attn(), NO_BIAS, PRE_SCALE_BIAS, POST_SCALE_BIAS, NO_MASK, CAUSAL_MASK, PADDING_MASK
  • cross_fused_attn(), NO_BIAS, NO_MASK, PADDING_MASK

Tested shardings:

  • DP
  • TP_COL
  • TP_ROW (softmax only)
  • DP_TP_COL
  • DP_TP_ROW (softmax only)

@denera denera self-assigned this Oct 11, 2023
Comment thread tests/jax/test_custom_partitioning.py Outdated
Comment thread tests/jax/test_custom_partitioning.py Outdated
Comment thread tests/jax/test_custom_partitioning.py Outdated
Comment thread tests/jax/test_custom_partitioning.py Outdated
Comment thread tests/jax/test_custom_partitioning.py Outdated
Comment thread tests/jax/test_custom_partitioning.py Outdated
Comment thread tests/jax/test_custom_partitioning.py Outdated
Comment thread tests/jax/test_custom_partitioning.py Outdated
@mingxu1067
Copy link
Copy Markdown
Collaborator

In summary, could we have following changes?

  • To remove Xmap testing, since this is deprecating and would be removed in the nearly feature.
  • To support different num of collective ops of each custom calls with different sharding patterns.
  • To add correctness checking to guarantee the output values are expected.

@denera
Copy link
Copy Markdown
Collaborator Author

denera commented Oct 12, 2023

I'll expand the tests to check correctness against JAX-native flax.linen counterparts, but I'm not clear on the Xmap deprecation.

Unless it's getting deprecated in JAX itself, it may be useful to keep it around at least internally so we can use it as a reference to check custom partitioning communication patterns. Otherwise I don't know what to use as reference in its place.

@nouiz
Copy link
Copy Markdown
Collaborator

nouiz commented Oct 12, 2023

I'll expand the tests to check correctness against JAX-native flax.linen counterparts, but I'm not clear on the Xmap deprecation.

Unless it's getting deprecated in JAX itself, it may be useful to keep it around at least internally so we can use it as a reference to check custom partitioning communication patterns. Otherwise I don't know what to use as reference in its place.

I think we need to hardcode the expected collective we wants.
We won't keep xmap around, too much code and would be a big maintenance burden. If we hardcode the expectation, we won't need it.

@denera denera force-pushed the jax-custom-partitioning-tests branch from 7b6ea00 to 5d13090 Compare October 13, 2023 19:24
Comment thread tests/jax/custom_ops_helper.py Outdated
Comment thread tests/jax/custom_ops_helper.py Outdated
Comment thread tests/jax/sharding_configs.py Outdated
Comment thread tests/jax/sharding_configs.py Outdated
Comment thread tests/jax/test_custom_ops_cpar.py Outdated
Comment thread tests/jax/test_custom_ops_xmap.py Outdated
@nouiz
Copy link
Copy Markdown
Collaborator

nouiz commented Oct 20, 2023

@mingxu1067 can you also review?

@denera
Copy link
Copy Markdown
Collaborator Author

denera commented Oct 20, 2023

@mingxu1067 Quick note on xmap tests: I needed to run these ops and mesh configs with xmap shardings just to get the collective counts that are now hardcoded for the custom_partioning tests. I kept the code for that in the PR for now as regression tests for xmap in case there's a brief phase where we keep xmap alive with a deprecation warning instead of removing it immediately.

@denera denera force-pushed the jax-custom-partitioning-tests branch 3 times, most recently from 140e2fe to 8678bcf Compare October 23, 2023 15:28
@denera denera marked this pull request as ready for review October 23, 2023 15:30
@denera denera requested review from mingxu1067 and nouiz October 23, 2023 16:57
@denera
Copy link
Copy Markdown
Collaborator Author

denera commented Oct 26, 2023

/te-ci jax

2 similar comments
@denera
Copy link
Copy Markdown
Collaborator Author

denera commented Oct 26, 2023

/te-ci jax

@denera
Copy link
Copy Markdown
Collaborator Author

denera commented Nov 1, 2023

/te-ci jax

…rtitioning.

Xma-based sharding tests are functional, while custom_partitioning tests are awaiting the custom ops migration to be merged in.

Coverage:
- layernorm: fwd/grad, zero_centered_gamma, DP, TP_COL, DP_TP_COL
- rmsnorm: fwd/grad, DP, TP_COL, DP_TP_COL
- softmax: fwd/grad, SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED, DP, TP_COL, TP_ROW, DP_TP_COL, DP_TP_ROW
- self_fused_attn: fwd/grad, NO_BIAS, PRE_SCALE_BIAS, POST_SCALE_BIAS, NO_MASK, CAUSAL_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL
- cross_fused_attn: fwd/grad, NO_BIAS, NO_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the jax-custom-partitioning-tests branch from 1e2d5f8 to c832ea1 Compare November 1, 2023 20:42
@denera
Copy link
Copy Markdown
Collaborator Author

denera commented Nov 1, 2023

/te-ci jax

1 similar comment
@denera
Copy link
Copy Markdown
Collaborator Author

denera commented Nov 2, 2023

/te-ci jax

…r testing custom_partitioning with 4GPUs.

Signed-off-by: Alp Dener <adener@nvidia.com>
@denera
Copy link
Copy Markdown
Collaborator Author

denera commented Nov 2, 2023

/te-ci jax

Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera changed the title Regression tests for custom ops with jax.experimental.custom_partitioning [JAX] Regression tests for custom ops with jax.experimental.custom_partitioning Nov 3, 2023
@denera denera merged commit d20ba9f into NVIDIA:main Nov 3, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants