Skip to content

Commit aad8294

Browse files
committed
fixed UB config reference before assignment and corrected FP8 UB buffer logic
Signed-off-by: Alp Dener <[email protected]>
1 parent 2ca29de commit aad8294

File tree

1 file changed

+4
-4
lines changed
  • transformer_engine/pytorch/module

1 file changed

+4
-4
lines changed

transformer_engine/pytorch/module/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ def add_ub(
391391

392392
# Loop over user configs and disable dgrad and wgrad bulk overlaps for every layer that has a
393393
# reduce-scatter dgrad overlap.
394-
ub_cfg = {} if ub_cfg is None else ub_cfg
394+
ub_cfgs = {} if ub_cfgs is None else ub_cfgs
395395
for name in dgrad_reduce_scatter_overlap:
396396
if name in ub_cfgs:
397397
final_cfg = get_default_config(name)
@@ -410,15 +410,15 @@ def add_ub(
410410
new_method = final_cfg["method"]
411411
methods[new_method].append(name)
412412

413-
ub_cfg[name] = final_cfg
413+
ub_cfgs[name] = final_cfg
414414

415415
# Now initialize the UB objects for each layer
416416
for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
417-
if ub_cfgs is not None and name in ub_cfgs:
417+
if name in ub_cfgs:
418418
final_cfg = get_default_config(name)
419419
final_cfg.update(ub_cfgs[name])
420420
final_cfg["fp8_buf"] = (name in layers_all_gather_overlap) or (
421-
ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"]
421+
ub_cfgs[name].get("fp8_buf", False) and name not in methods["pipeline"]
422422
)
423423
add_ub(name, **final_cfg)
424424

0 commit comments

Comments
 (0)