Support CUDA Graph for MoE models #1233

Merged: 16 commits, Nov 25, 2024
12 changes: 6 additions & 6 deletions transformer_engine/pytorch/fp8.py
@@ -442,16 +442,16 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft()

# Replace amaxes and scales with stashed values for phase 2 forward
fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0]
fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1]
fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2]
fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0])
fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1])
fp8_meta["scaling_fwd"].scale_inv.copy_(stashed_fp8_meta[2])

@staticmethod
def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
"""Restore latest scaling factors and amaxes after recompute forward run."""
fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"]
fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"]
fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"]
fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"])
fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"])
fp8_meta["scaling_fwd"].scale_inv.copy_(fp8_meta["updated_scale_inv_fwd"])


@contextmanager
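Note on the fp8.py hunk above: a captured CUDA graph replays its kernels against the tensor addresses recorded at capture time, so the stashed amaxes and scales have to be written into the existing buffers in place rather than rebound to new tensors. A minimal PyTorch sketch of that distinction (illustrative only, not TE code):

import torch

scale = torch.ones(4)            # stands in for fp8_meta["scaling_fwd"].scale
stashed = torch.full((4,), 2.0)  # stashed values from the phase-1 forward
ptr_before = scale.data_ptr()

# Rebinding (`scale = stashed`) would point the name at a new storage, and a
# graph captured against the old buffer would keep reading stale data.
scale.copy_(stashed)             # in-place: same storage, new values
assert scale.data_ptr() == ptr_before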
86 changes: 79 additions & 7 deletions transformer_engine/pytorch/graph.py
@@ -12,6 +12,7 @@
from torch._C import _graph_pool_handle

from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch.constants import dist_group_type
from .fp8 import (
fp8_autocast,
FP8GlobalStateManager,
@@ -173,11 +174,14 @@ def _make_graphed_callables(
]
else:
per_callable_module_params = []
- for c in callables:
- for i in range(num_microbatches):
- per_callable_module_params.append(
- tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
- )
+ for m_chunk in range(num_model_chunks):
+ for _ in range(num_microbatches):
+ for l_no in range(num_layers):
+ per_callable_module_params.append(
+ tuple(callables[m_chunk * num_layers + l_no].parameters())
+ if isinstance(callables[m_chunk * num_layers + l_no], torch.nn.Module)
+ else ()
+ )
Comment on lines +177 to +184 (Collaborator):
This change seems correct to me, but it's odd if the Mcore integration was working before. @ksivaman Have we run this with Mcore, or did we run with num_microbatches=1?

This changes the interpretation of per_callable_module_params from (num_chunks, layers_per_chunk, num_microbatches) to (num_chunks, num_microbatches, layers_per_chunk). This matches the interpretation of per_callable_* lists when capturing graphs:

per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) + (
fwd_idx[m_chunk] * num_layers + l_no
)
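A tiny worked example of the flat ordering this comment describes, with hypothetical sizes (none of these numbers come from the PR):

# (model chunk, microbatch, layer) -> flat index, matching the capture-time formula above.
num_model_chunks, num_microbatches, num_layers = 2, 3, 4  # hypothetical

def per_callable_idx(m_chunk: int, microbatch: int, l_no: int) -> int:
    return m_chunk * num_microbatches * num_layers + microbatch * num_layers + l_no

assert per_callable_idx(0, 1, 2) == 6                              # second microbatch, third layer of chunk 0
assert per_callable_idx(1, 0, 0) == num_microbatches * num_layers  # chunk 1 starts after all of chunk 0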

assert len(per_callable_module_params) == len(flatten_sample_args)
per_callable_static_input_surfaces = [
flatten_sample_args[i] + per_callable_module_params[i]
@@ -201,13 +205,55 @@
# Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
# from ending up in any captures.
torch.cuda.synchronize()
with torch.cuda.stream(torch.cuda.Stream()):

# Get warmup func and func_idx.
warmup_func_idx = []
warmup_func = []
if _order is None:
for func_idx, func in enumerate(callables):
warmup_func_idx.append(func_idx)
warmup_func.append(func)
else:
fwd_idx = [0] * num_model_chunks
for c_id in _order:
if c_id > 0:
m_chunk = c_id - 1
for l_no in range(num_layers):
func = callables[m_chunk * num_layers + l_no]
func_idx = (m_chunk * num_microbatches * num_layers) + (
fwd_idx[m_chunk] * num_layers + l_no
)
warmup_func_idx.append(func_idx)
warmup_func.append(func)
fwd_idx[m_chunk] += 1
assert len(warmup_func) == len(
sample_args
), f"Warmup runs {len(warmup_func)} don't match args {len(sample_args)}."
assert len(warmup_func_idx) == len(
set(warmup_func_idx)
), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique."

# Filter the TE modules that cudagraph can access.
visited_te_modules = set()

def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
if isinstance(module, TransformerEngineBaseModule):
visited_te_modules.add(module)

# Run warmup and do the above filtering.
with torch.cuda.stream(torch.cuda.Stream()):
for func_idx, func in zip(warmup_func_idx, warmup_func):
args = sample_args[func_idx]
kwargs = sample_kwargs[func_idx]
static_input_surface = per_callable_static_input_surfaces[func_idx]
for _ in range(num_warmup_iters):
hooks = []
for module in func.modules():
hook = module.register_forward_hook(hook_fn)
hooks.append(hook)
outputs, _ = _tree_flatten(func(*args, **kwargs))
for hook in hooks:
hook.remove()
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
@@ -216,6 +262,11 @@
allow_unused=allow_unused_input,
)
del outputs, grad_inputs
# The following code is added specifically for MCore's special requirements,
# aimed at preventing warmup from altering the control flow.
for module in func.modules():
if hasattr(module, "is_first_microbatch"):
module.is_first_microbatch = True
timmoon10 marked this conversation as resolved.
torch.cuda.synchronize()
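Aside on the warmup loop above: the temporary forward hooks are what populate visited_te_modules, so later only TE modules that the callables actually executed get their FP8 metadata refreshed. A standalone sketch of the same hook pattern on a plain torch.nn model (hypothetical module, not the TE classes):

import torch

visited = set()

def hook_fn(module, inputs, outputs):  # signature torch passes to forward hooks
    visited.add(module)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
handles = [m.register_forward_hook(hook_fn) for m in model.modules()]

model(torch.randn(2, 8))  # only modules that actually run end up in `visited`

for h in handles:
    h.remove()  # hooks are temporary, exactly as in the warmup loop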

# All captures here share a mempool. To avoid replays corrupting each other's memory,
@@ -462,6 +513,19 @@ def new_fwd(*user_args, **user_kwargs):
isinstance(m, TransformerEngineBaseModule)
and FP8GlobalStateManager.is_fp8_enabled()
):
if m not in visited_te_modules:
# Only set the FP8 meta for modules that the forward pass actually ran
continue
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
from transformer_engine.pytorch.attention import DotProductAttention

if (
isinstance(m, DotProductAttention)
and not fp8_recipe.fp8_mha
and not fp8_recipe.fp8_dpa
):
# Don't need to update FP8 meta for non-FP8 DPA
continue
m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
@@ -538,6 +602,7 @@ def make_graphed_callables(
fp8_enabled: bool = False,
fp8_calibrating: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
fp8_group: Optional[dist_group_type] = None,
fp8_weight_caching: bool = False,
_order: Optional[List[int]] = None,
pool: Optional[Tuple[int, ...]] = None,
@@ -579,6 +644,9 @@
using a higher precision.
fp8_recipe: recipe.DelayedScaling, default = `None`
recipe used for FP8 training.
fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
distributed group over which amaxes for the fp8 tensors
are reduced at the end of each training step.
fp8_weight_caching: bool, default = `False`
Whether or not to cache FP8 weights across microbatches. If set to `True`,
the `is_first_microbatch` boolean argument must be passed into the forward
@@ -607,7 +675,11 @@ def wrap_autocast(block):

def forward_func(*args, **kwargs):
with fp8_autocast(
- enabled=fp8_enabled, calibrating=fp8_calibrating, fp8_recipe=fp8_recipe, _graph=True
+ enabled=fp8_enabled,
+ calibrating=fp8_calibrating,
+ fp8_recipe=fp8_recipe,
+ fp8_group=fp8_group,
+ _graph=True,
):
outputs = old_forward(*args, **kwargs)
return outputs
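Putting the graph.py changes together: callers can now route the amax-reduction process group into capture via the new fp8_group argument. A hedged usage sketch; the process-group setup, layer size, and recipe values below are assumptions for illustration, not taken from this PR:

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

dist.init_process_group(backend="nccl")
amax_group = dist.new_group(ranks=list(range(dist.get_world_size())))

layer = te.Linear(1024, 1024).cuda()
sample_input = (torch.randn(16, 1024, device="cuda", requires_grad=True),)

graphed_layer = te.make_graphed_callables(
    layer,
    sample_input,
    fp8_enabled=True,
    fp8_recipe=DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16),
    fp8_group=amax_group,  # new in this PR: forwarded into fp8_autocast at capture time
)

out = graphed_layer(torch.randn(16, 1024, device="cuda", requires_grad=True))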
5 changes: 4 additions & 1 deletion transformer_engine/pytorch/module/layernorm_linear.py
@@ -1152,7 +1152,10 @@ def forward(
produced)
"""

- skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
+ if FP8GlobalStateManager.fp8_graph_capturing():
+ skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
+ else:
+ skip_fp8_weight_update = None
if skip_fp8_weight_update is not None:
is_first_microbatch = False

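The same guard is repeated in layernorm_mlp.py and linear.py below. It matters under CUDA graphs because a Python bool cannot change between replays; the skip decision instead lives in a tensor whose contents are updated outside the captured region. A heavily simplified sketch of that pattern (assumed behaviour, not TE internals):

import torch

skip_flag = torch.zeros(1, device="cuda")  # static buffer baked into the captured region

def captured_region(x: torch.Tensor) -> torch.Tensor:
    # At replay time the kernel reads the buffer's current contents, so flipping
    # skip_flag between replays changes behaviour without re-capturing.
    return torch.where(skip_flag.bool(), x, x * 2.0)

skip_flag.fill_(0.0)  # first microbatch: take the full weight-update path
skip_flag.fill_(1.0)  # later microbatches: skip and reuse cached FP8 weights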
5 changes: 4 additions & 1 deletion transformer_engine/pytorch/module/layernorm_mlp.py
@@ -1484,7 +1484,10 @@ def forward(
produced)
"""

- skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
+ if FP8GlobalStateManager.fp8_graph_capturing():
+ skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
+ else:
+ skip_fp8_weight_update = None
if skip_fp8_weight_update is not None:
is_first_microbatch = False

6 changes: 4 additions & 2 deletions transformer_engine/pytorch/module/linear.py
@@ -938,8 +938,10 @@ def forward(
first microbatch (since it is the first gradient being
produced)
"""

- skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
+ if FP8GlobalStateManager.fp8_graph_capturing():
+ skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
+ else:
+ skip_fp8_weight_update = None
if skip_fp8_weight_update is not None:
is_first_microbatch = False
