[JAX] Regression tests for custom ops with jax.experimental.custom_partitioning #471

Merged 4 commits into NVIDIA:main on Nov 3, 2023

denera
Collaborator

@denera denera commented Oct 11, 2023

Included ops:

  • layernorm() inherited from Donglin Yang's fork (cc: @Young768) + rmsnorm type
  • softmax(): SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED
  • self_fused_attn(): NO_BIAS, PRE_SCALE_BIAS, POST_SCALE_BIAS, NO_MASK, CAUSAL_MASK, PADDING_MASK
  • cross_fused_attn(): NO_BIAS, NO_MASK, PADDING_MASK

Tested shardings:

  • DP
  • TP_COL
  • TP_ROW (softmax only)
  • DP_TP_COL
  • DP_TP_ROW (softmax only)

@denera denera self-assigned this Oct 11, 2023
@mingxu1067
Collaborator

In summary, could we have the following changes?

  • Remove the xmap testing, since xmap is being deprecated and will be removed in the near future.
  • Support a different number of collective ops for each custom call under different sharding patterns.
  • Add correctness checks to guarantee that the output values match the expected ones.

@denera
Collaborator Author

denera commented Oct 12, 2023

I'll expand the tests to check correctness against JAX-native flax.linen counterparts, but I'm not clear on the Xmap deprecation.

Unless it's getting deprecated in JAX itself, it may be useful to keep it around at least internally so we can use it as a reference to check custom partitioning communication patterns. Otherwise I don't know what to use as reference in its place.
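
A minimal sketch of the kind of correctness check discussed here: it compares forward outputs and gradients of an op against a reference within a tolerance. To stay self-contained it uses a plain `jax.numpy` layernorm as the reference rather than a `flax.linen` module, and `op_under_test` is a hypothetical stand-in (a numerically different layernorm formulation), not the Transformer Engine custom op. The helper name `check_fwd_and_grad` and the tolerances are assumptions as well.

```python
import jax
import jax.numpy as jnp
import numpy as np


def reference_layernorm(x, gamma, beta, eps=1e-6):
    """Plain jax.numpy layernorm over the last axis, used as ground truth."""
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
    return (x - mean) * jax.lax.rsqrt(var + eps) * gamma + beta


@jax.jit
def op_under_test(x, gamma, beta, eps=1e-6):
    """Hypothetical stand-in for the custom op (NOT the real TE kernel).

    It computes the variance as E[x^2] - E[x]^2, so results differ from the
    reference only by floating-point rounding.
    """
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean(jnp.square(x), axis=-1, keepdims=True) - jnp.square(mean)
    return (x - mean) / jnp.sqrt(var + eps) * gamma + beta


def check_fwd_and_grad(op, ref, args, rtol=1e-4, atol=1e-4):
    """Assert that `op` matches `ref` in forward output and in gradients."""
    np.testing.assert_allclose(op(*args), ref(*args), rtol=rtol, atol=atol)

    def scalar_loss(fn):
        # Reduce to a scalar so jax.grad can differentiate it.
        return lambda *a: jnp.sum(jnp.square(fn(*a)))

    argnums = tuple(range(len(args)))
    grads = jax.grad(scalar_loss(op), argnums=argnums)(*args)
    ref_grads = jax.grad(scalar_loss(ref), argnums=argnums)(*args)
    for got, want in zip(grads, ref_grads):
        np.testing.assert_allclose(got, want, rtol=rtol, atol=atol)


# Small random inputs: (batch, hidden) activations plus scale and shift.
kx, kg, kb = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (16, 64), dtype=jnp.float32)
gamma = jax.random.normal(kg, (64,), dtype=jnp.float32)
beta = jax.random.normal(kb, (64,), dtype=jnp.float32)

check_fwd_and_grad(op_under_test, reference_layernorm, (x, gamma, beta))
```

In a sharded test, the same helper could be applied after placing the inputs on a device mesh, so that the partitioned op is still compared against an unsharded reference.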

@nouiz
Collaborator

nouiz commented Oct 12, 2023

> I'll expand the tests to check correctness against JAX-native flax.linen counterparts, but I'm not clear on the Xmap deprecation.
>
> Unless it's getting deprecated in JAX itself, it may be useful to keep it around at least internally so we can use it as a reference to check custom partitioning communication patterns. Otherwise I don't know what to use as reference in its place.

I think we need to hardcode the expected collectives we want.
We won't keep xmap around: it is too much code and would be a big maintenance burden. If we hardcode the expectations, we won't need it.
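
One way to hardcode the expected collectives, sketched under stated assumptions: count the communication opcodes in the compiled HLO text and compare them with a fixed table. The helper names, the `EXPECTED` table, and its numbers are hypothetical placeholders for illustration and are not taken from this PR. In a real test the HLO text could come from something like `jax.jit(fn).lower(*args).compile().as_text()`; here a hand-written snippet stands in for it.

```python
import re
from collections import Counter

# HLO opcodes that imply cross-device communication.
COLLECTIVE_OPS = (
    "all-reduce",
    "all-gather",
    "reduce-scatter",
    "all-to-all",
    "collective-permute",
)

# Match the opcode where it is applied, e.g. "all-reduce(" or the async
# "all-reduce-start(". Instruction names such as "%all-reduce.1" and the
# paired "-done" ops are skipped so each collective is counted once.
_PATTERN = re.compile(
    r"(?<![\w.%-])(" + "|".join(COLLECTIVE_OPS) + r")(?:-start)?\("
)


def count_collectives(hlo_text: str) -> Counter:
    """Return how many times each collective opcode appears in HLO text."""
    return Counter(m.group(1) for m in _PATTERN.finditer(hlo_text))


# Hypothetical expectations keyed by (op, sharding). The counts below are
# placeholders, NOT the values used in the actual tests.
EXPECTED = {
    ("softmax", "DP"): {},
    ("layernorm_grad", "DP"): {"all-reduce": 2},
}


def assert_collectives(hlo_text: str, op: str, sharding: str) -> None:
    """Fail if the compiled program's collectives differ from the table."""
    actual = dict(count_collectives(hlo_text))
    expected = EXPECTED[(op, sharding)]
    assert actual == expected, (
        f"{op}/{sharding}: expected {expected}, got {actual}"
    )


# Hand-written HLO fragment standing in for real compiler output.
SAMPLE_HLO = """
ENTRY %main (p0: f32[8,16]) -> f32[16] {
  %p0 = f32[8,16]{1,0} parameter(0)
  %reduce.1 = f32[16]{0} reduce(%p0, %c0), dimensions={0}, to_apply=%add
  %all-reduce.1 = f32[16]{0} all-reduce(%reduce.1), to_apply=%add
  %all-reduce-start.2 = f32[16]{0} all-reduce-start(%all-reduce.1), to_apply=%add
  ROOT %all-reduce-done.2 = f32[16]{0} all-reduce-done(%all-reduce-start.2)
}
"""

assert_collectives(SAMPLE_HLO, "layernorm_grad", "DP")
```

Keying the table by both op and sharding pattern allows each custom call to have its own expected number of collectives, which is the flexibility requested earlier in this thread.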

@denera denera force-pushed the jax-custom-partitioning-tests branch from 7b6ea00 to 5d13090 Compare October 13, 2023 19:24
tests/jax/custom_ops_helper.py (outdated, resolved)
tests/jax/custom_ops_helper.py (outdated, resolved)
tests/jax/sharding_configs.py (resolved)
tests/jax/sharding_configs.py (outdated, resolved)
tests/jax/test_custom_ops_cpar.py (outdated, resolved)
tests/jax/test_custom_ops_xmap.py (outdated, resolved)
@nouiz
Collaborator

nouiz commented Oct 20, 2023

@mingxu1067 can you also review?

@denera
Collaborator Author

denera commented Oct 20, 2023

@mingxu1067 Quick note on the xmap tests: I needed to run these ops and mesh configs with xmap shardings just to get the collective counts that are now hardcoded for the custom_partitioning tests. I kept that code in the PR for now as regression tests for xmap, in case there's a brief phase where we keep xmap alive with a deprecation warning instead of removing it immediately.

@denera denera force-pushed the jax-custom-partitioning-tests branch 3 times, most recently from 140e2fe to 8678bcf Compare October 23, 2023 15:28
@denera denera marked this pull request as ready for review October 23, 2023 15:30
@denera denera requested review from nouiz and mingxu1067 October 23, 2023 16:57
@denera
Collaborator Author

denera commented Oct 26, 2023

/te-ci jax

@denera
Collaborator Author

denera commented Oct 26, 2023

/te-ci jax

@denera
Collaborator Author

denera commented Nov 1, 2023

/te-ci jax

…rtitioning.

Xmap-based sharding tests are functional, while the custom_partitioning tests are awaiting the custom ops migration to be merged in.

Coverage:
- layernorm: fwd/grad, zero_centered_gamma, DP, TP_COL, DP_TP_COL
- rmsnorm: fwd/grad, DP, TP_COL, DP_TP_COL
- softmax: fwd/grad, SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED, DP, TP_COL, TP_ROW, DP_TP_COL, DP_TP_ROW
- self_fused_attn: fwd/grad, NO_BIAS, PRE_SCALE_BIAS, POST_SCALE_BIAS, NO_MASK, CAUSAL_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL
- cross_fused_attn: fwd/grad, NO_BIAS, NO_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL

Signed-off-by: Alp Dener <[email protected]>
@denera denera force-pushed the jax-custom-partitioning-tests branch from 1e2d5f8 to c832ea1 Compare November 1, 2023 20:42
@denera
Collaborator Author

denera commented Nov 1, 2023

/te-ci jax

@denera
Collaborator Author

denera commented Nov 2, 2023

/te-ci jax

…r testing custom_partitioning with 4GPUs.

Signed-off-by: Alp Dener <[email protected]>
@denera
Collaborator Author

denera commented Nov 2, 2023

/te-ci jax

@denera denera changed the title Regression tests for custom ops with jax.experimental.custom_partitioning [JAX] Regression tests for custom ops with jax.experimental.custom_partitioning Nov 3, 2023
@denera denera merged commit d20ba9f into NVIDIA:main Nov 3, 2023
9 checks passed
cyanguwa pushed a commit to cyanguwa/TransformerEngine that referenced this pull request Nov 13, 2023
…rtitioning (NVIDIA#471)

[JAX] Regression tests for custom ops sharding with both xmap and custom_partitioning.

Coverage:
- layernorm: fwd/grad, zero_centered_gamma, DP, TP_COL, DP_TP_COL
- rmsnorm: fwd/grad, DP, TP_COL, DP_TP_COL
- softmax: fwd/grad, SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED, DP, TP_COL, TP_ROW, DP_TP_COL, DP_TP_ROW
- self_fused_attn: fwd/grad, NO_BIAS, PRE_SCALE_BIAS, POST_SCALE_BIAS, NO_MASK, CAUSAL_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL
- cross_fused_attn: fwd/grad, NO_BIAS, NO_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL

Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
3 participants