[JAX] Migrating from Xmap to Custom Partitioning for All Custom Calls #472

Merged: 51 commits, Nov 14, 2023

Conversation

@mingxu1067 (Collaborator) commented Oct 12, 2023:

  • Adding batching rules to all custom calls to further support pipeline parallelism in Paxml.
  • Replacing the FP8 GEMM custom call with the native XLA implementation.
  • Removing the explicit call to update_fp8_meta; it is now hidden inside the backward pass.
  • Migrating sharding from xmap to custom_partitioning (see the sketch after this list).
    • FP32/BF16 related:
      • layernorm_fwd, layernorm_bwd
      • rmsnorm_fwd, rmsnorm_bwd
      • fused_softmax_fwd, fused_softmax_bwd
      • fused_masked_softmax_fwd, fused_masked_softmax_bwd
      • fused_upper_triangle_softmax_fwd, fused_upper_triangle_softmax_bwd
      • geglu_fwd, geglu_bwd
      • self_fused_attention_fwd, self_fused_attention_bwd
      • cross_attention_fwd, cross_attention_bwd
    • FP8 related:
      • cast_and_transpose, transpose
      • layernorm_fp8, rmsnorm_fp8
      • geglu_fp8, dgeglu_cast_and_transpose
  • Removing all setup related to xmap, including ShardingMeta.
  • Setting FP8 meta as WeightHParamsCollection.OVERWRITE_WITH_GRADIENT in Praxis modules.
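
To make the migration concrete, below is a minimal, hedged sketch of the custom_partitioning pattern referenced in the list above. It is not the actual Transformer Engine code: `_layernorm_fwd_impl` and `_padded_spec` are hypothetical stand-ins for the bound custom-call primitive and a spec-padding helper, and the sharding rule is simplified (batch dims of x kept, the normalized hidden dim and the gamma/beta weights replicated).

```python
# Illustrative sketch only -- not the actual TE implementation.
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import NamedSharding, PartitionSpec


def _layernorm_fwd_impl(x, gamma, beta, epsilon):
    # Hypothetical stand-in for the registered custom-call primitive.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) * jax.lax.rsqrt(var + epsilon) * gamma + beta


def _padded_spec(arg_info):
    # Pad the operand's PartitionSpec to its rank (unspecified dims -> None).
    spec = tuple(arg_info.sharding.spec)
    return spec + (None,) * (len(arg_info.shape) - len(spec))


@partial(custom_partitioning, static_argnums=(3,))
def layernorm_fwd(x, gamma, beta, epsilon):
    return _layernorm_fwd_impl(x, gamma, beta, epsilon)


def _infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos):
    # Output follows x's batch sharding; the normalized dim stays replicated.
    x_spec = _padded_spec(arg_infos[0])
    return NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))


def _partition(epsilon, mesh, arg_infos, result_infos):
    x_spec = _padded_spec(arg_infos[0])
    out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
    arg_shardings = (
        out_sharding,                              # x: batch-sharded, hidden replicated
        NamedSharding(mesh, PartitionSpec(None)),  # gamma: replicated
        NamedSharding(mesh, PartitionSpec(None)),  # beta: replicated
    )

    def lower_fn(x, gamma, beta):
        # Runs on per-shard data; layernorm needs no cross-device collectives
        # as long as the hidden dim is unsharded.
        return _layernorm_fwd_impl(x, gamma, beta, epsilon)

    return mesh, lower_fn, out_sharding, arg_shardings


layernorm_fwd.def_partition(
    infer_sharding_from_operands=_infer_sharding_from_operands,
    partition=_partition,
)
```

With this pattern, GSPMD propagates shardings through the custom call and the per-shard lowering is supplied explicitly, which is what removes the need for xmap and ShardingMeta.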

@mingxu1067 (Collaborator, author):

@zlsh80826 and @jeng1220 for visibility.

@mingxu1067 (Collaborator, author):

/te-ci jax

@nouiz (Collaborator) commented Oct 12, 2023:

@timmoon10 can you review?

@mingxu1067 force-pushed the mingh/ln_softmax_custom_partitioning branch from 8e63b5a to 784749a on October 13, 2023 03:57.
@mingxu1067 (Collaborator, author):

/te-ci jax

Review threads (resolved): transformer_engine/jax/layernorm.py, transformer_engine/jax/sharding.py
@ptrendx (Member) commented Oct 13, 2023:

/te-ci jax

Review threads (resolved): transformer_engine/jax/cpp_extensions.py
@mingxu1067 (Collaborator, author):

/te-ci jax

@timmoon10 (Collaborator) left a review comment:

LGTM

@nouiz (Collaborator) left a review comment:

Small optional nitpick.

Review thread (resolved): transformer_engine/jax/cpp_extensions.py
@mingxu1067 (Collaborator, author):

/te-ci jax

1 similar comment:

@mingxu1067 (Collaborator, author):

/te-ci jax

@mingxu1067 (Collaborator, author) commented Oct 19, 2023:

Bad news: xmap and custom_partitioning (CP) cannot work together. CP does not support manual HloSharding, but xmap requires manual sharding, so combining them raises an assertion:

jax/_src/sharding_impls.py(1387): parse_flatten_op_sharding
jax/experimental/custom_partitioning.py(498): to_mesh_pspec_sharding
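
A minimal repro sketch of the conflict (an assumed reconstruction, not code from this PR; it presumes a multi-device mesh and a 2023-era JAX where jax.experimental.maps.xmap is still available):

```python
# Assumed repro sketch: calling a custom_partitioning-wrapped function under
# xmap. xmap lowers its body with MANUAL HloShardings, which
# to_mesh_pspec_sharding -> parse_flatten_op_sharding cannot convert back into
# a PartitionSpec, hence the assertion in the traceback above.
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import Mesh


@custom_partitioning
def identity(x):
    return x


def _infer(mesh, arg_infos, result_infos):
    return arg_infos[0].sharding


def _partition(mesh, arg_infos, result_infos):
    return mesh, (lambda x: x), arg_infos[0].sharding, (arg_infos[0].sharding,)


identity.def_partition(infer_sharding_from_operands=_infer, partition=_partition)

with Mesh(np.array(jax.devices()), ("dp",)):
    f = xmap(identity,
             in_axes=["batch", ...],
             out_axes=["batch", ...],
             axis_resources={"batch": "dp"})
    # Expected to hit the assertion, because the operand of the
    # custom_partitioning call carries a manual HloSharding under xmap.
    f(jnp.ones((jax.device_count(), 16)))
```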

What we can do:

  1. Hold this PR until all custom calls are ready, including the FP8-related ones.
  2. Isolate CP and xmap as a workaround (WAR), e.g. keep FP8 on pure xmap and BF16 on pure CP.

@nouiz @jeng1220 for comment.

@mingxu1067 changed the title from "[JAX] Migrating from Xmap to Custom Partitioning for LN, RMSN and Softmax" to "[JAX] Migrating from Xmap to Custom Partitioning for All Custom Calls" on Nov 3, 2023.
@mingxu1067 (Collaborator, author):

@nouiz @timmoon10
All custom calls have been migrated from xmap to custom_partitioning. Would you kindly help review? Thanks.

CC: @denera @Young768 @zlsh80826 @jeng1220 for visibility

@mingxu1067 force-pushed the mingh/ln_softmax_custom_partitioning branch 3 times, most recently from a74a53d to 2ab34ee, on November 3, 2023 04:37.
@denera (Collaborator) commented Nov 3, 2023:

@mingxu1067 I merged #471 earlier today with the custom partitioning unit tests. You're welcome to rebase your PR and run them. If we see any issues, I'd be happy to commit the fixes in your branch so that the regression tests are all fully functional and correct by the time we merge this.

@mingxu1067 force-pushed the mingh/ln_softmax_custom_partitioning branch from 794857d to 139a7db on November 14, 2023 06:03.
@mingxu1067 force-pushed the mingh/ln_softmax_custom_partitioning branch from 139a7db to 25eb6c0 on November 14, 2023 06:13.
@zlsh80826 (Collaborator) left a review comment:

LGTM. Thanks for this awesome PR. It really simplifies the implementation of custom calls and supports more complex sharding methods.

@mingxu1067 force-pushed the mingh/ln_softmax_custom_partitioning branch 2 times, most recently from 1a9d9b9 to 9ab60cc, on November 14, 2023 09:20.
@mingxu1067 force-pushed the mingh/ln_softmax_custom_partitioning branch from 9ab60cc to dda823f on November 14, 2023 10:27.
@mingxu1067 force-pushed the mingh/ln_softmax_custom_partitioning branch from dda823f to 1f8d903 on November 14, 2023 11:05.
@denera merged commit 71e51ea into NVIDIA:main on Nov 14, 2023. 9 checks passed.
Labels: enhancement (New feature or request)
7 participants