
Commit 83ea7c8

buptzyb and yifeis-nv committed
Add TE modules and weights filters to support MoE models
Signed-off-by: Robin Zhang <[email protected]>
Co-authored-by: Yifei Song <[email protected]>
1 parent 3485806 commit 83ea7c8

File tree

1 file changed: +55 −1 lines changed


transformer_engine/pytorch/graph.py

Lines changed: 55 additions & 1 deletion
@@ -232,25 +232,68 @@ def _make_graphed_callables(
         set(warmup_func_idx)
     ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique."
 
-    # Run warmup.
+    # Filter the TE modules that cudagraph can access.
+    visited_te_modules = set()
+
+    def hook_fn(module, input, output):
+        if (
+            isinstance(module, TransformerEngineBaseModule)
+            and FP8GlobalStateManager.is_fp8_enabled()
+        ):
+            visited_te_modules.add(module)
+
+    # Filter the weights without gradients in backward. These weights are not in the computation graph so can be removed from cudagraph inputs.
+    no_grad_weights = None
+
+    def get_no_grads(static_input_surface, grad_inputs, func_idx):
+        grad_index = 0
+        none_grads = []
+        for i in range(len(static_input_surface)):
+            if static_input_surface[i].requires_grad:
+                if grad_inputs[grad_index] is None and i >= len(flatten_sample_args[func_idx]):
+                    none_grads.append(i)
+                grad_index += 1
+        return set(none_grads)
+
+    # Run warmup and do the above two filtering.
     with torch.cuda.stream(torch.cuda.Stream()):
         for func_idx, func in zip(warmup_func_idx, warmup_func):
             args = sample_args[func_idx]
             kwargs = sample_kwargs[func_idx]
             static_input_surface = per_callable_static_input_surfaces[func_idx]
             for _ in range(num_warmup_iters):
+                hooks = []
+                for module in func.modules():
+                    hook = module.register_forward_hook(hook_fn)
+                    hooks.append(hook)
                 outputs, _ = _tree_flatten(func(*args, **kwargs))
+                for hook in hooks:
+                    hook.remove()
                 grad_inputs = torch.autograd.grad(
                     outputs=tuple(o for o in outputs if o.requires_grad),
                     inputs=tuple(i for i in static_input_surface if i.requires_grad),
                     grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
                     only_inputs=True,
                     allow_unused=allow_unused_input,
                 )
+                if no_grad_weights is None:
+                    no_grad_weights = get_no_grads(static_input_surface, grad_inputs, func_idx)
                 del outputs, grad_inputs
             for module in func.modules():
                 if hasattr(module, "is_first_microbatch"):
                     module.is_first_microbatch = True
+            if len(no_grad_weights) > 0:
+                per_callable_static_input_surfaces[func_idx] = tuple(
+                    inp
+                    for i, inp in enumerate(per_callable_static_input_surfaces[func_idx])
+                    if i not in no_grad_weights
+                )
+                per_callable_module_params[func_idx] = tuple(
+                    param
+                    for i, param in enumerate(per_callable_module_params[func_idx])
+                    if i + len(flatten_sample_args[func_idx]) not in no_grad_weights
+                )
+                no_grad_weights = None
     torch.cuda.synchronize()
 
     # All captures here share a mempool. To avoid replays corrupting each other's memory,

@@ -495,6 +538,17 @@ def new_fwd(*user_args, **user_kwargs):
                     isinstance(m, TransformerEngineBaseModule)
                     and FP8GlobalStateManager.is_fp8_enabled()
                 ):
+                    if m not in visited_te_modules:
+                        # Only Set the FP8 meta for the modules included by forward
+                        continue
+                    fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
+                    if (
+                        not fp8_recipe.fp8_mha
+                        and not fp8_recipe.fp8_dpa
+                        and hasattr(m, "attention_dropout")
+                        and m.deterministic
+                    ):
+                        continue
                     m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
                     m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
                     FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
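
For context on the first filter: the warmup loop above registers a standard PyTorch forward hook on every submodule, so only modules that actually execute during the warmup forward (for an MoE model, only the experts the router dispatches tokens to) end up in visited_te_modules. A minimal standalone sketch of the same pattern, using a toy nn.Sequential rather than the commit's Transformer Engine modules:

import torch
import torch.nn as nn

visited = set()

def record_visit(module, inputs, output):
    # Fires only for modules that actually run in the forward pass.
    visited.add(module)

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
handles = [m.register_forward_hook(record_visit) for m in model.modules()]
model(torch.randn(2, 8))
for handle in handles:
    handle.remove()  # drop the hooks once the set is built, as the warmup loop does
# `visited` now holds only the modules the forward pass touched.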
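
The second filter, get_no_grads, keys off documented torch.autograd.grad behavior: with allow_unused=True, an input that did not contribute to the outputs gets None back instead of a gradient, which is what happens to the weights of an expert that received no tokens during warmup. A small sketch of that behavior (variable names are illustrative, not from the commit):

import torch

w_used = torch.randn(4, requires_grad=True)
w_unused = torch.randn(4, requires_grad=True)  # stand-in for an expert weight that saw no tokens

out = (w_used * 2.0).sum()  # w_unused never enters the computation graph
grads = torch.autograd.grad(
    outputs=(out,),
    inputs=(w_used, w_unused),
    allow_unused=True,
)
print(grads[0] is None, grads[1] is None)  # prints: False True
# Inputs whose gradient comes back as None can be dropped from the static
# input surface that the CUDA graph captures, which is what the commit does.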
