Support CUDA Graph for MoE models #1233

Merged
16 commits merged into NVIDIA:main on Nov 25, 2024

Conversation

Contributor

@buptzyb commented Oct 9, 2024

Description

Unlike non-MoE models such as Llama 2, MoE models have dynamic-shaped activations in their FFN layers, so a single CUDA graph can only capture part of a transformer layer rather than the whole layer. We call this the "breaking-layer" CUDA graph mode. This PR adds breaking-layer CUDA graph support for MoE models on the TE side and fixes several related bugs in TE.
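
For readers new to this mode, here is a minimal standalone sketch of the breaking-layer idea using plain torch.cuda.make_graphed_callables rather than TE's wrapper; static_part and moe_ffn are hypothetical stand-ins, not modules from this PR. The static-shaped portion of a layer is captured and replayed as a CUDA graph, while the dynamic-shaped MoE FFN keeps running eagerly.

import torch

hidden = 1024
# Stand-in for the static-shaped portion of a transformer layer (e.g. attention).
static_part = torch.nn.Sequential(
    torch.nn.LayerNorm(hidden), torch.nn.Linear(hidden, hidden)
).cuda()
# Stand-in for the dynamic-shaped MoE FFN; in a real MoE layer the number of
# tokens routed to each expert changes every step, so it cannot be graphed as-is.
moe_ffn = torch.nn.Linear(hidden, hidden).cuda()

sample = torch.randn(8, hidden, device="cuda", requires_grad=True)
graphed_static = torch.cuda.make_graphed_callables(static_part, (sample,))

def layer_forward(x):
    x = graphed_static(x)  # replayed from the captured CUDA graph (fixed shapes)
    return moe_ffn(x)      # eager region: shapes may change across iterations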

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Please list the changes introduced in this PR:

  • Fix the wrong ordering of per_callable_module_params in _make_graphed_callables when _order is given.
  • Fix a warmup argument mismatch bug in _make_graphed_callables when _order is given.
  • Fix an FP8 accuracy issue by adding an fp8_group argument to make_graphed_callables() (see the sketch after this list) and modifying the is_first_microbatch, skip_fp8_weight_update, and fp8_meta code.
  • Support CUDA graphs for MoE models by filtering the graphed TE modules during warmup.
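
A minimal sketch of how the new fp8_group argument is expected to be used, mirroring the existing fp8_group argument of fp8_autocast(). This is not code from the PR's diff: the process-group setup, layer configuration, and sample inputs are illustrative assumptions, and the other make_graphed_callables arguments shown may differ from the final signature.

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

dist.init_process_group(backend="nccl")
# Assumed: the group over which FP8 amax/scale statistics should be reduced
# (e.g. the data-parallel group in an Mcore setup).
amax_group = dist.new_group(ranks=list(range(dist.get_world_size())))

layer = te.TransformerLayer(
    hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16
).cuda()
# One tuple of sample inputs per graphed callable; (seq, batch, hidden) layout.
sample_args = (torch.randn(128, 2, 1024, device="cuda", requires_grad=True),)

graphed_layer = te.make_graphed_callables(
    layer,
    sample_args,
    fp8_enabled=True,
    fp8_recipe=DelayedScaling(),
    fp8_group=amax_group,  # added by this PR: group for FP8 amax reduction
)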

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Collaborator

@timmoon10 left a comment

Technically this seems mostly reasonable, although I have questions and stylistic suggestions. Have you tested that it works with Mcore?

@ptrendx @ksivaman @sbhavani What is our priority for this feature? The custom Mcore logic in make_graphed_callables is already messy and fragile, and this PR does exacerbate those problems.

transformer_engine/pytorch/module/layernorm_linear.py (review comment thread, resolved)
Comment on lines +176 to +184
for m_chunk in range(num_model_chunks):
    for _ in range(num_microbatches):
        for l_no in range(num_layers):
            per_callable_module_params.append(
                tuple(callables[m_chunk * num_layers + l_no].parameters())
                if isinstance(callables[m_chunk * num_layers + l_no], torch.nn.Module)
                else ()
            )
Collaborator

This change seems correct to me, but it's odd if the Mcore integration was working before. @ksivaman Have we run this with Mcore, or did we run with num_microbatches=1?

This changes the interpretation of per_callable_module_params from (num_chunks, layers_per_chunk, num_microbatches) to (num_chunks, num_microbatches, layers_per_chunk). This matches the interpretation of per_callable_* lists when capturing graphs:

per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) + (
    fwd_idx[m_chunk] * num_layers + l_no
)
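
A small standalone check of that indexing (hypothetical sizes, not code from the diff): enumerating callables in (chunk, microbatch, layer) order reproduces exactly the per_callable_fwd_idx formula quoted above.

num_model_chunks, num_microbatches, num_layers = 2, 4, 3

# Order in which per_callable_module_params is now built.
flat_order = [
    (m_chunk, microbatch, l_no)
    for m_chunk in range(num_model_chunks)
    for microbatch in range(num_microbatches)
    for l_no in range(num_layers)
]

# Flat index used when capturing graphs (fwd_idx plays the microbatch role).
def per_callable_fwd_idx(m_chunk, fwd_idx, l_no):
    return m_chunk * num_microbatches * num_layers + fwd_idx * num_layers + l_no

for idx, (m_chunk, microbatch, l_no) in enumerate(flat_order):
    assert per_callable_fwd_idx(m_chunk, microbatch, l_no) == idx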

transformer_engine/pytorch/graph.py (4 review comment threads, resolved)
@timmoon10
Collaborator

/te-ci pytorch

1 similar comment
@yaox12
Collaborator

yaox12 commented Oct 11, 2024

/te-ci pytorch

@buptzyb
Contributor Author

buptzyb commented Oct 11, 2024

Have you tested that it works with Mcore?

Yes, we also made some changes in Mcore, together with TE changes in this PR, to enable MoE cudagraph. You can refer to issue 193 in our Megatron-LM repo.

@yaox12
Collaborator

yaox12 commented Oct 11, 2024

/te-ci pytorch

@yifeis-nv force-pushed the cudagraph_moe branch 2 times, most recently from bb1c160 to 66748b9 on November 20, 2024
buptzyb and others added 12 commits November 20, 2024 06:50
Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: Yifei Song <[email protected]>
Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: Yifei Song <[email protected]>
Signed-off-by: Robin Zhang <[email protected]>
Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: Yifei Song <[email protected]>
Signed-off-by: Robin Zhang <[email protected]>
Signed-off-by: Xin Yao <[email protected]>
Signed-off-by: Yifei Song <[email protected]>
This reverts commit 73a22e2.

Signed-off-by: Yifei Song <[email protected]>
@yaox12
Collaborator

yaox12 commented Nov 20, 2024

/te-ci pytorch

@yaox12
Collaborator

yaox12 commented Nov 21, 2024

/te-ci pytorch

@buptzyb
Contributor Author

buptzyb commented Nov 22, 2024

Hi @timmoon10, do you have any more suggestions on this PR?

@timmoon10
Collaborator

/te-ci pytorch L1

Collaborator

@timmoon10 left a comment

LGTM

@yaox12
Collaborator

yaox12 commented Nov 25, 2024

Merging this PR since pipeline 20710146 passed and Tim approved.

@yaox12 merged commit ae393e8 into NVIDIA:main on Nov 25, 2024
14 of 15 checks passed