File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed
transformer_engine/pytorch/module Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -391,7 +391,7 @@ def add_ub(
391
391
392
392
# Loop over user configs and disable dgrad and wgrad bulk overlaps for every layer that has a
393
393
# reduce-scatter dgrad overlap.
394
- ub_cfg = {} if ub_cfg is None else ub_cfg
394
+ ub_cfgs = {} if ub_cfgs is None else ub_cfgs
395
395
for name in dgrad_reduce_scatter_overlap :
396
396
if name in ub_cfgs :
397
397
final_cfg = get_default_config (name )
@@ -410,15 +410,15 @@ def add_ub(
410
410
new_method = final_cfg ["method" ]
411
411
methods [new_method ].append (name )
412
412
413
- ub_cfg [name ] = final_cfg
413
+ ub_cfgs [name ] = final_cfg
414
414
415
415
# Now initialize the UB objects for each layer
416
416
for name in methods ["ring_exchange" ] + methods ["pipeline" ] + methods ["bulk" ]:
417
- if ub_cfgs is not None and name in ub_cfgs :
417
+ if name in ub_cfgs :
418
418
final_cfg = get_default_config (name )
419
419
final_cfg .update (ub_cfgs [name ])
420
420
final_cfg ["fp8_buf" ] = (name in layers_all_gather_overlap ) or (
421
- ub_cfgs [name ].get ("fp8_buf" , False ) and name in methods ["pipeline" ]
421
+ ub_cfgs [name ].get ("fp8_buf" , False ) and name not in methods ["pipeline" ]
422
422
)
423
423
add_ub (name , ** final_cfg )
424
424
You can’t perform that action at this time.
0 commit comments