Support CUDA Graph for MoE models #1233
Conversation
Technically this seems mostly reasonable, although I have questions and stylistic suggestions. Have you tested that it works with Mcore?
@ptrendx @ksivaman @sbhavani What is our priority for this feature? The custom Mcore logic in `make_graphed_callables` is already messy and fragile, and this PR exacerbates those problems.
```python
for m_chunk in range(num_model_chunks):
    for _ in range(num_microbatches):
        for l_no in range(num_layers):
            per_callable_module_params.append(
                tuple(callables[m_chunk * num_layers + l_no].parameters())
                if isinstance(callables[m_chunk * num_layers + l_no], torch.nn.Module)
                else ()
            )
```
This change seems correct to me, but it's odd if the Mcore integration was working before. @ksivaman Have we run this with Mcore, or did we run with `num_microbatches=1`?
This changes the interpretation of `per_callable_module_params` from (num_chunks, layers_per_chunk, num_microbatches) to (num_chunks, num_microbatches, layers_per_chunk). This matches the interpretation of the `per_callable_*` lists when capturing graphs:
TransformerEngine/transformer_engine/pytorch/graph.py, lines 237 to 239 (at 3b89c36):

```python
per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) + (
    fwd_idx[m_chunk] * num_layers + l_no
)
```
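As a self-contained sanity check (hypothetical sizes, not the actual TE code), the new (chunk, microbatch, layer) loop order produces exactly the flattened order that `per_callable_fwd_idx` indexes into:

```python
# Hypothetical sizes; the check depends only on the loop order.
num_model_chunks, num_microbatches, num_layers = 2, 3, 4

flat_order = []
for m_chunk in range(num_model_chunks):
    for microbatch in range(num_microbatches):
        for l_no in range(num_layers):
            flat_order.append((m_chunk, microbatch, l_no))

for m_chunk in range(num_model_chunks):
    for microbatch in range(num_microbatches):
        for l_no in range(num_layers):
            per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) + (
                microbatch * num_layers + l_no
            )
            # The entry appended at this flat index is the one the capture loop expects.
            assert flat_order[per_callable_fwd_idx] == (m_chunk, microbatch, l_no)
```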
/te-ci pytorch
Yes, we also made some changes in Mcore, together with the TE changes in this PR, to enable MoE CUDA graphs. You can refer to issue 193 in our Megatron-LM repo.
/te-ci pytorch
/te-ci pytorch
/te-ci pytorch
Hi @timmoon10, do you have more suggestions on this PR?
/te-ci pytorch
LGTM
Merging this PR since pipeline 20710146 passed and Tim approved.
Description
Unlike non-MoE models such as llama2, MoE models have dynamically shaped activations in their FFN layers, so a single CUDA graph can only capture part of a transformer layer rather than the whole layer. We call this the "breaking-layer" CUDA graph mode. This PR adds breaking-layer CUDA graph support for MoE models on the TE side and fixes several related bugs in TE.
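A minimal sketch of the breaking-layer idea (the module names `pre_moe`, `post_moe`, and `dynamic_expert_ffn` are hypothetical stand-ins, not the actual Mcore integration): only the static-shaped pieces of an MoE layer are captured with `make_graphed_callables`, while the dynamically shaped expert FFN runs eagerly between them.

```python
import torch
import transformer_engine.pytorch as te

hidden = 1024
# Hypothetical static-shaped pieces surrounding the expert dispatch.
pre_moe = te.LayerNormLinear(hidden, hidden)
post_moe = te.Linear(hidden, hidden)

sample = torch.randn(8, 2, hidden, device="cuda", requires_grad=True)

# Capture graphs only for the fixed-shape callables; the expert FFN stays
# outside the graphs because its per-expert token counts change every step.
pre_graphed, post_graphed = te.make_graphed_callables(
    (pre_moe, post_moe),
    ((sample,), (sample,)),
)

def moe_layer(x, dynamic_expert_ffn):
    # dynamic_expert_ffn: hypothetical eager MoE dispatch with variable shapes.
    x = pre_graphed(x)
    x = dynamic_expert_ffn(x)
    return post_graphed(x)
```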
Fixes # (issue)
Type of change
Changes

Please list the changes introduced in this PR:

- Fix the `per_callable_module_params` order bug in `_make_graphed_callables` when `_order` is given.
- Support breaking-layer CUDA graph capture in `_make_graphed_callables` when `_order` is given.
- Add the `fp8_group` argument to `make_graphed_callables()` and modify the `is_first_microbatch`, `skip_fp8_weight_update`, and `fp8_meta` code (a usage sketch follows this list).
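A minimal usage sketch of the new `fp8_group` argument (the surrounding keyword names follow the existing `make_graphed_callables` FP8 arguments; treat the exact signature as an assumption). The group is presumably the process group over which FP8 amax statistics are reduced, analogous to `fp8_group` in `fp8_autocast`, and the example assumes `torch.distributed` has already been initialized.

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

hidden = 1024
layers = (te.Linear(hidden, hidden), te.Linear(hidden, hidden))
sample_args = tuple(
    (torch.randn(8, 2, hidden, device="cuda", requires_grad=True),)
    for _ in layers
)

graphed = te.make_graphed_callables(
    layers,
    sample_args,
    fp8_enabled=True,
    fp8_recipe=DelayedScaling(),
    fp8_group=dist.group.WORLD,  # new in this PR; assumes an initialized process group
)
```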
Checklist: