Disable UB bulk wgrad when weights are frozen (#702)
Signed-off-by: Jaemin Choi <[email protected]>
Co-authored-by: Jaemin Choi <[email protected]>
minitu and Jaemin Choi authored Mar 5, 2024
1 parent 3f8baf9 commit b0f6535
Showing 2 changed files with 2 additions and 2 deletions.
transformer_engine/pytorch/module/layernorm_linear.py (1 addition, 1 deletion)
@@ -355,7 +355,7 @@ def backward(

 if ctx.ub_bulk_wgrad:
     tp_world_size = get_distributed_world_size(ctx.tp_group)
-    if tp_world_size == 1:
+    if tp_world_size == 1 or not weight.requires_grad:
         ctx.ub_bulk_wgrad = False
 
 # Column Parallel Linear
transformer_engine/pytorch/module/layernorm_mlp.py (1 addition, 1 deletion)
@@ -604,7 +604,7 @@ def backward(

 if ctx.ub_bulk_wgrad:
     tp_world_size = get_distributed_world_size(ctx.tp_group)
-    if tp_world_size == 1:
+    if tp_world_size == 1 or not fc1_weight.requires_grad:
         ctx.ub_bulk_wgrad = False
 # Column Parallel Linear
 # Overlap input AG with dgrad
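
Both hunks add the same guard: the UB (userbuffers) bulk wgrad path is skipped not only when there is no tensor parallelism, but also when the relevant weight is frozen (requires_grad is False), since no weight-gradient GEMM needs to run for a frozen weight and there is nothing for the bulk communication to overlap with. A minimal sketch of that condition, using a hypothetical helper name rather than the Transformer Engine code itself:

import torch

def use_bulk_wgrad_overlap(weight: torch.Tensor, tp_world_size: int) -> bool:
    # Hypothetical helper mirroring the check added in backward():
    # - tp_world_size == 1: no tensor-parallel communication to overlap.
    # - weight frozen: autograd skips the wgrad GEMM, so there is no
    #   compute to hide the bulk communication behind.
    return tp_world_size > 1 and weight.requires_grad

weight = torch.nn.Parameter(torch.empty(16, 16))
assert use_bulk_wgrad_overlap(weight, tp_world_size=8)        # overlap stays on
weight.requires_grad_(False)                                  # freeze the weight
assert not use_bulk_wgrad_overlap(weight, tp_world_size=8)    # disabled, as after this commit

The check appears twice because LayerNormLinear and LayerNormMLP each carry their own backward implementation; the MLP variant keys off fc1_weight.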
