Add TE modules and weights filters to support MoE models
Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: Yifei Song <[email protected]>
buptzyb and yifeis-nv committed Oct 9, 2024
1 parent cef94e4 commit 34967b6
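For context, a minimal sketch of the capture path this commit targets: graphing a TE module under FP8 with make_graphed_callables. The layer, shapes, and keyword arguments are illustrative assumptions, not part of the commit; keyword names follow the documented transformer_engine.pytorch API and may differ across versions.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Hypothetical single layer standing in for an MoE block; MoE models exercise the new
# filters because some expert weights never receive gradients and some TE modules are
# skipped in a given forward pass.
layer = te.Linear(1024, 1024).cuda()
sample_input = torch.randn(8, 1024, device="cuda", requires_grad=True)

graphed_layer = te.make_graphed_callables(
    layer,
    (sample_input,),
    num_warmup_iters=3,        # the warmup loop below also runs the two new filters
    allow_unused_input=True,   # tolerate weights that get no gradient (unused experts)
    fp8_enabled=True,
    fp8_recipe=DelayedScaling(),
)
output = graphed_layer(torch.randn(8, 1024, device="cuda", requires_grad=True))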
Showing 1 changed file with 55 additions and 1 deletion.
transformer_engine/pytorch/graph.py: 55 additions & 1 deletion
@@ -232,25 +232,68 @@ def _make_graphed_callables(
        set(warmup_func_idx)
    ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique."

    # Run warmup.
    # Filter the TE modules that cudagraph can access.
    visited_te_modules = set()

    def hook_fn(module, input, output):
        if (
            isinstance(module, TransformerEngineBaseModule)
            and FP8GlobalStateManager.is_fp8_enabled()
        ):
            visited_te_modules.add(module)

    # Filter the weights without gradients in backward. These weights are not in the
    # computation graph, so they can be removed from the cudagraph inputs.
    no_grad_weights = None

    def get_no_grads(static_input_surface, grad_inputs, func_idx):
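        # grad_inputs only has entries for inputs that require grad, so it is walked with
        # its own index; only module weights (indices past the flattened sample args) are
        # eligible for removal.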
        grad_index = 0
        none_grads = []
        for i in range(len(static_input_surface)):
            if static_input_surface[i].requires_grad:
                if grad_inputs[grad_index] is None and i >= len(flatten_sample_args[func_idx]):
                    none_grads.append(i)
                grad_index += 1
        return set(none_grads)

    # Run warmup and apply the two filters above.
    with torch.cuda.stream(torch.cuda.Stream()):
        for func_idx, func in zip(warmup_func_idx, warmup_func):
            args = sample_args[func_idx]
            kwargs = sample_kwargs[func_idx]
            static_input_surface = per_callable_static_input_surfaces[func_idx]
            for _ in range(num_warmup_iters):
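                # Forward hooks record which TE modules this callable actually executes
                # during warmup (see hook_fn above).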
                hooks = []
                for module in func.modules():
                    hook = module.register_forward_hook(hook_fn)
                    hooks.append(hook)
                outputs, _ = _tree_flatten(func(*args, **kwargs))
                for hook in hooks:
                    hook.remove()
                grad_inputs = torch.autograd.grad(
                    outputs=tuple(o for o in outputs if o.requires_grad),
                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
                    grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
                    only_inputs=True,
                    allow_unused=allow_unused_input,
                )
                if no_grad_weights is None:
                    no_grad_weights = get_no_grads(static_input_surface, grad_inputs, func_idx)
                del outputs, grad_inputs
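                # Reset the weight-caching flag after each warmup iteration so the later
                # capture behaves like a first microbatch.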
                for module in func.modules():
                    if hasattr(module, "is_first_microbatch"):
                        module.is_first_microbatch = True
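            # Drop weights that never received gradients (e.g. unused experts) from the
            # static input surface and the saved module parameters, so cudagraph does not
            # treat them as graph inputs.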
            if len(no_grad_weights) > 0:
                per_callable_static_input_surfaces[func_idx] = tuple(
                    inp
                    for i, inp in enumerate(per_callable_static_input_surfaces[func_idx])
                    if i not in no_grad_weights
                )
                per_callable_module_params[func_idx] = tuple(
                    param
                    for i, param in enumerate(per_callable_module_params[func_idx])
                    if i + len(flatten_sample_args[func_idx]) not in no_grad_weights
                )
                no_grad_weights = None
    torch.cuda.synchronize()

    # All captures here share a mempool. To avoid replays corrupting each other's memory,
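To see why the get_no_grads() filter above matters for MoE, here is a plain-PyTorch sketch (not from this commit; TinyMoE is a hypothetical toy module): one expert is never executed, so torch.autograd.grad returns None for its parameters, and those are exactly the indices the filter removes from the static input surface.

import torch

class TinyMoE(torch.nn.Module):
    """Toy router-less 'MoE': only expert0 runs, so expert1's weights get no gradient."""

    def __init__(self):
        super().__init__()
        self.expert0 = torch.nn.Linear(4, 4)
        self.expert1 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.expert0(x)  # expert1 is never touched in this pass

moe = TinyMoE()
x = torch.randn(2, 4, requires_grad=True)
grads = torch.autograd.grad(
    outputs=moe(x).sum(),
    inputs=[x] + list(moe.parameters()),
    allow_unused=True,
)
# expert1's entries come back as None; get_no_grads() records exactly such indices so
# the corresponding weights are dropped from the cudagraph static input surface.
print([g is None for g in grads])  # [False, False, False, True, True]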
@@ -495,6 +538,17 @@ def new_fwd(*user_args, **user_kwargs):
                        isinstance(m, TransformerEngineBaseModule)
                        and FP8GlobalStateManager.is_fp8_enabled()
                    ):
                        if m not in visited_te_modules:
                            # Only set the FP8 meta for modules visited during the forward pass.
                            continue
                        fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
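                        # Skip attention modules (identified via attention_dropout) when the
                        # recipe enables neither FP8 MHA nor FP8 DPA and the module runs
                        # deterministically.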
                        if (
                            not fp8_recipe.fp8_mha
                            and not fp8_recipe.fp8_dpa
                            and hasattr(m, "attention_dropout")
                            and m.deterministic
                        ):
                            continue
                        m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
                        m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
                        FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
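A hedged plain-PyTorch sketch of the "visited modules" bookkeeping used in both hunks (record which submodules actually execute, then refresh per-module state only for those); the model and hook names are illustrative, while the commit applies the same idea to TransformerEngineBaseModule instances and their fp8_meta.

import torch

visited = set()

def record_visit(module, inputs, output):
    # Forward hook: remember every module that actually ran.
    visited.add(module)

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU(), torch.nn.Linear(4, 4))
handles = [m.register_forward_hook(record_visit) for m in model.modules()]
model(torch.randn(1, 4))
for handle in handles:
    handle.remove()

# Analogue of the new_fwd() change: only modules that ran get their state refreshed.
for module in model.modules():
    if module not in visited:
        continue
    # ... refresh per-module state here (the commit updates m.fp8_meta) ...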
