@@ -232,25 +232,68 @@ def _make_graphed_callables(
         set(warmup_func_idx)
     ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique."

-    # Run warmup.
+    # Filter the TE modules that cudagraph can access.
+    visited_te_modules = set()
+
+    def hook_fn(module, input, output):
+        if (
+            isinstance(module, TransformerEngineBaseModule)
+            and FP8GlobalStateManager.is_fp8_enabled()
+        ):
+            visited_te_modules.add(module)
+
+    # Filter out weights that receive no gradient in backward. These weights are not in the computation graph, so they can be removed from the cudagraph inputs.
+    no_grad_weights = None
+
+    def get_no_grads(static_input_surface, grad_inputs, func_idx):
+        grad_index = 0
+        none_grads = []
+        for i in range(len(static_input_surface)):
+            if static_input_surface[i].requires_grad:
+                if grad_inputs[grad_index] is None and i >= len(flatten_sample_args[func_idx]):
+                    none_grads.append(i)
+                grad_index += 1
+        return set(none_grads)
+
+    # Run warmup and apply the two filters above.
     with torch.cuda.stream(torch.cuda.Stream()):
         for func_idx, func in zip(warmup_func_idx, warmup_func):
             args = sample_args[func_idx]
             kwargs = sample_kwargs[func_idx]
             static_input_surface = per_callable_static_input_surfaces[func_idx]
             for _ in range(num_warmup_iters):
+                hooks = []
+                for module in func.modules():
+                    hook = module.register_forward_hook(hook_fn)
+                    hooks.append(hook)
                 outputs, _ = _tree_flatten(func(*args, **kwargs))
+                for hook in hooks:
+                    hook.remove()
                 grad_inputs = torch.autograd.grad(
                     outputs=tuple(o for o in outputs if o.requires_grad),
                     inputs=tuple(i for i in static_input_surface if i.requires_grad),
                     grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
                     only_inputs=True,
                     allow_unused=allow_unused_input,
                 )
+                if no_grad_weights is None:
+                    no_grad_weights = get_no_grads(static_input_surface, grad_inputs, func_idx)
                 del outputs, grad_inputs
                 for module in func.modules():
                     if hasattr(module, "is_first_microbatch"):
                         module.is_first_microbatch = True
+            if len(no_grad_weights) > 0:
+                per_callable_static_input_surfaces[func_idx] = tuple(
+                    inp
+                    for i, inp in enumerate(per_callable_static_input_surfaces[func_idx])
+                    if i not in no_grad_weights
+                )
+                per_callable_module_params[func_idx] = tuple(
+                    param
+                    for i, param in enumerate(per_callable_module_params[func_idx])
+                    if i + len(flatten_sample_args[func_idx]) not in no_grad_weights
+                )
+                no_grad_weights = None
     torch.cuda.synchronize()

     # All captures here share a mempool. To avoid replays corrupting each other's memory,
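
For reference, here is a minimal sketch of the no-grad filtering idea introduced in this hunk, using a toy model and invented names rather than the Transformer Engine code path: run one warmup backward with allow_unused=True, treat a None gradient as evidence that the input is outside the computation graph, and drop it from the static inputs before capture.

import torch

# Toy sketch: detect static inputs whose warmup gradient comes back as None
# and prune them before graph capture. All names here are illustrative only.
model = torch.nn.Linear(4, 4)
unused_weight = torch.nn.Parameter(torch.randn(4))  # never used in forward
sample = torch.randn(2, 4, requires_grad=True)

static_input_surface = (sample,) + tuple(model.parameters()) + (unused_weight,)

out = model(sample)
grad_inputs = torch.autograd.grad(
    outputs=(out,),
    inputs=static_input_surface,
    grad_outputs=(torch.ones_like(out),),
    allow_unused=True,  # unused inputs yield None instead of raising
)
no_grad_idx = {i for i, g in enumerate(grad_inputs) if g is None}
static_input_surface = tuple(
    t for i, t in enumerate(static_input_surface) if i not in no_grad_idx
)
print(no_grad_idx)  # {3}: only the unused weight is dropped
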
@@ -495,6 +538,17 @@ def new_fwd(*user_args, **user_kwargs):
                         isinstance(m, TransformerEngineBaseModule)
                         and FP8GlobalStateManager.is_fp8_enabled()
                     ):
+                        if m not in visited_te_modules:
+                            # Only set the FP8 meta for modules that ran in forward
+                            continue
+                        fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
+                        if (
+                            not fp8_recipe.fp8_mha
+                            and not fp8_recipe.fp8_dpa
+                            and hasattr(m, "attention_dropout")
+                            and m.deterministic
+                        ):
+                            continue
                         m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
                         m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
                         FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
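
And a minimal sketch of the visited_te_modules bookkeeping this second hunk relies on, with plain nn.Linear modules and invented names standing in for the TE-specific FP8 handling: forward hooks record which submodules actually execute during warmup, and later per-module state updates are skipped for anything never visited.

import torch

# Toy sketch: record which submodules actually execute during warmup via
# forward hooks, then skip per-module state updates for unvisited ones.
visited = set()

def hook_fn(module, inputs, output):
    visited.add(module)

model = torch.nn.ModuleDict({
    "used": torch.nn.Linear(4, 4),
    "skipped": torch.nn.Linear(4, 4),  # present in the model, never called
})

hooks = [m.register_forward_hook(hook_fn) for m in model.values()]
model["used"](torch.randn(2, 4))  # warmup forward touches only "used"
for h in hooks:
    h.remove()

for name, module in model.items():
    if module not in visited:
        continue  # analogous to skipping the fp8_meta update above
    print(f"updating state for {name}")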