We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3def47c commit 0973a05Copy full SHA for 0973a05
transformer_engine/pytorch/module/base.py
@@ -418,9 +418,8 @@ def add_ub(
418
if name in ub_cfgs:
419
final_cfg = get_default_config(name)
420
final_cfg.update(ub_cfgs[name])
421
- final_cfg["fp8_buf"] = (
422
- (name in layers_all_gather_overlap)
423
- or ub_cfgs[name].get("fp8_buf", False)
+ final_cfg["fp8_buf"] = (name in layers_all_gather_overlap) or ub_cfgs[name].get(
+ "fp8_buf", False
424
)
425
add_ub(name, **final_cfg)
426
0 commit comments