Skip to content

Commit

Permalink
Remove filtering module params
Browse files Browse the repository at this point in the history
Signed-off-by: Robin Zhang <[email protected]>
  • Loading branch information
buptzyb committed Nov 20, 2024
1 parent c82faa2 commit 938f325
Showing 1 changed file with 1 addition and 28 deletions.
29 changes: 1 addition & 28 deletions transformer_engine/pytorch/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,20 +240,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
if isinstance(module, TransformerEngineBaseModule):
visited_te_modules.add(module)

# Filter the weights without gradients in backward. These weights are not in the computation graph so can be removed from cudagraph inputs.
no_grad_weights = None

def get_no_grads(static_input_surface, grad_inputs, func_idx):
grad_index = 0
none_grads = []
for i in range(len(static_input_surface)):
if static_input_surface[i].requires_grad:
if grad_inputs[grad_index] is None and i >= len(flatten_sample_args[func_idx]):
none_grads.append(i)
grad_index += 1
return set(none_grads)

# Run warmup and do the above two filtering.
# Run warmup and do the above filtering.
with torch.cuda.stream(torch.cuda.Stream()):
for func_idx, func in zip(warmup_func_idx, warmup_func):
args = sample_args[func_idx]
Expand All @@ -274,26 +261,12 @@ def get_no_grads(static_input_surface, grad_inputs, func_idx):
only_inputs=True,
allow_unused=allow_unused_input,
)
if no_grad_weights is None:
no_grad_weights = get_no_grads(static_input_surface, grad_inputs, func_idx)
del outputs, grad_inputs
# The following code is added specifically for MCore's special requirements,
# aimed at preventing warmup from altering the control flow.
for module in func.modules():
if hasattr(module, "is_first_microbatch"):
module.is_first_microbatch = True
if len(no_grad_weights) > 0:
per_callable_static_input_surfaces[func_idx] = tuple(
inp
for i, inp in enumerate(per_callable_static_input_surfaces[func_idx])
if i not in no_grad_weights
)
per_callable_module_params[func_idx] = tuple(
param
for i, param in enumerate(per_callable_module_params[func_idx])
if i + len(flatten_sample_args[func_idx]) not in no_grad_weights
)
no_grad_weights = None
torch.cuda.synchronize()

# All captures here share a mempool. To avoid replays corrupting each other's memory,
Expand Down

0 comments on commit 938f325

Please sign in to comment.