diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 8edb4768d2..626edd76a2 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -391,7 +391,7 @@ def add_ub( # Loop over user configs and disable dgrad and wgrad bulk overlaps for every layer that has a # reduce-scatter dgrad overlap. - ub_cfg = {} if ub_cfg is None else ub_cfg + ub_cfgs = {} if ub_cfgs is None else ub_cfgs for name in dgrad_reduce_scatter_overlap: if name in ub_cfgs: final_cfg = get_default_config(name) @@ -410,15 +410,16 @@ def add_ub( new_method = final_cfg["method"] methods[new_method].append(name) - ub_cfg[name] = final_cfg + ub_cfgs[name] = final_cfg # Now initialize the UB objects for each layer for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]: - if ub_cfgs is not None and name in ub_cfgs: + if name in ub_cfgs: final_cfg = get_default_config(name) final_cfg.update(ub_cfgs[name]) - final_cfg["fp8_buf"] = (name in layers_all_gather_overlap) or ( - ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"] + final_cfg["fp8_buf"] = ( + (name in layers_all_gather_overlap) + or ub_cfgs[name].get("fp8_buf", False) ) add_ub(name, **final_cfg)