diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index d53ede87bf..6c33cc72b9 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -240,20 +240,7 @@ def hook_fn(module, inputs, outputs):  # pylint: disable=unused-argument
         if isinstance(module, TransformerEngineBaseModule):
             visited_te_modules.add(module)
 
-    # Filter the weights without gradients in backward. These weights are not in the computation graph so can be removed from cudagraph inputs.
-    no_grad_weights = None
-
-    def get_no_grads(static_input_surface, grad_inputs, func_idx):
-        grad_index = 0
-        none_grads = []
-        for i in range(len(static_input_surface)):
-            if static_input_surface[i].requires_grad:
-                if grad_inputs[grad_index] is None and i >= len(flatten_sample_args[func_idx]):
-                    none_grads.append(i)
-                grad_index += 1
-        return set(none_grads)
-
-    # Run warmup and do the above two filtering.
+    # Run warmup and do the above filtering.
     with torch.cuda.stream(torch.cuda.Stream()):
         for func_idx, func in zip(warmup_func_idx, warmup_func):
             args = sample_args[func_idx]
@@ -274,26 +261,12 @@ def get_no_grads(static_input_surface, grad_inputs, func_idx):
                     only_inputs=True,
                     allow_unused=allow_unused_input,
                 )
-                if no_grad_weights is None:
-                    no_grad_weights = get_no_grads(static_input_surface, grad_inputs, func_idx)
                 del outputs, grad_inputs
             # The following code is added specifically for MCore's special requirements,
             # aimed at preventing warmup from altering the control flow.
             for module in func.modules():
                 if hasattr(module, "is_first_microbatch"):
                     module.is_first_microbatch = True
-            if len(no_grad_weights) > 0:
-                per_callable_static_input_surfaces[func_idx] = tuple(
-                    inp
-                    for i, inp in enumerate(per_callable_static_input_surfaces[func_idx])
-                    if i not in no_grad_weights
-                )
-                per_callable_module_params[func_idx] = tuple(
-                    param
-                    for i, param in enumerate(per_callable_module_params[func_idx])
-                    if i + len(flatten_sample_args[func_idx]) not in no_grad_weights
-                )
-                no_grad_weights = None
     torch.cuda.synchronize()
 
     # All captures here share a mempool. To avoid replays corrupting each other's memory,