Revert "Use hooks to filter module params"
This reverts commit 73a22e2.

Signed-off-by: Yifei Song <[email protected]>
yifeis-nv committed Nov 20, 2024
1 parent 41d6100 commit c82faa2
Showing 1 changed file with 29 additions and 14 deletions: transformer_engine/pytorch/graph.py
@@ -233,16 +233,27 @@ def _make_graphed_callables(
set(warmup_func_idx)
), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique."

# Filter the TE modules and parameters that cudagraph can access.
# Filter the TE modules that cudagraph can access.
visited_te_modules = set()
visited_params = set()

def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
if isinstance(module, TransformerEngineBaseModule):
visited_te_modules.add(module)
visited_params.update(module.parameters(recurse=False))

# Run warmup and do the above filtering.
# Filter the weights without gradients in backward. These weights are not in the computation graph so can be removed from cudagraph inputs.
no_grad_weights = None

def get_no_grads(static_input_surface, grad_inputs, func_idx):
grad_index = 0
none_grads = []
for i in range(len(static_input_surface)):
if static_input_surface[i].requires_grad:
if grad_inputs[grad_index] is None and i >= len(flatten_sample_args[func_idx]):
none_grads.append(i)
grad_index += 1
return set(none_grads)

# Run warmup and do the above two filtering.
with torch.cuda.stream(torch.cuda.Stream()):
for func_idx, func in zip(warmup_func_idx, warmup_func):
args = sample_args[func_idx]
@@ -263,22 +274,26 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument
only_inputs=True,
allow_unused=allow_unused_input,
)
if no_grad_weights is None:
no_grad_weights = get_no_grads(static_input_surface, grad_inputs, func_idx)
del outputs, grad_inputs
# The following code is added specifically for MCore's special requirements,
# aimed at preventing warmup from altering the control flow.
for module in func.modules():
if hasattr(module, "is_first_microbatch"):
module.is_first_microbatch = True

# Remove cudagraph input params that are not accessed.
for idx, _ in enumerate(flatten_sample_args):
per_callable_module_params[idx] = tuple(
param for param in per_callable_module_params[idx] if param in visited_params
)
per_callable_static_input_surfaces[idx] = (
flatten_sample_args[idx] + per_callable_module_params[idx]
)

if len(no_grad_weights) > 0:
per_callable_static_input_surfaces[func_idx] = tuple(
inp
for i, inp in enumerate(per_callable_static_input_surfaces[func_idx])
if i not in no_grad_weights
)
per_callable_module_params[func_idx] = tuple(
param
for i, param in enumerate(per_callable_module_params[func_idx])
if i + len(flatten_sample_args[func_idx]) not in no_grad_weights
)
no_grad_weights = None
torch.cuda.synchronize()

# All captures here share a mempool. To avoid replays corrupting each other's memory,
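For context, a minimal sketch (not Transformer Engine code) of the forward-hook filtering that this revert keeps: a hook registered on every submodule records which modules actually execute during the warmup pass, and only those are treated as cudagraph-visible TE modules. Here a plain torch.nn.Linear stands in for TransformerEngineBaseModule; the model and tensor shapes are invented for illustration.

import torch
import torch.nn as nn

visited_modules = set()

def hook_fn(module, inputs, outputs):  # same signature as the hook in graph.py
    # Stand-in for `isinstance(module, TransformerEngineBaseModule)`.
    if isinstance(module, nn.Linear):
        visited_modules.add(module)

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
handles = [m.register_forward_hook(hook_fn) for m in model.modules()]

with torch.no_grad():
    model(torch.randn(2, 8))  # "warmup" forward populates visited_modules

for handle in handles:
    handle.remove()

print(len(visited_modules))  # -> 2: both Linear layers ran during warmup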
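And a hedged sketch of the restored no-grad-weight filtering: after one warmup backward, any entry of the static input surface that requires grad but received a None gradient, and that lies past the flattened sample args (i.e. is a module parameter), is dropped from the cudagraph inputs. The tensors and the toy forward below are assumptions made purely for illustration.

import torch

sample_args = (torch.randn(4, 8, requires_grad=True),)
used_weight = torch.nn.Parameter(torch.randn(8, 8))
unused_weight = torch.nn.Parameter(torch.randn(8, 8))  # never used in the forward
static_input_surface = sample_args + (used_weight, unused_weight)

out = sample_args[0] @ used_weight  # unused_weight is not in the computation graph
grad_inputs = torch.autograd.grad(
    outputs=(out,),
    inputs=tuple(t for t in static_input_surface if t.requires_grad),
    grad_outputs=(torch.ones_like(out),),
    allow_unused=True,  # unused inputs yield a None gradient instead of an error
)

# Mirrors get_no_grads: indices of params (past the sample args) whose grad is None.
grad_index, none_grads = 0, []
for i, tensor in enumerate(static_input_surface):
    if tensor.requires_grad:
        if grad_inputs[grad_index] is None and i >= len(sample_args):
            none_grads.append(i)
        grad_index += 1

filtered_surface = tuple(
    t for i, t in enumerate(static_input_surface) if i not in set(none_grads)
)
print(none_grads)             # -> [2], the unused weight
print(len(filtered_surface))  # -> 2 tensors remain as cudagraph inputs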
