Skip to content

Commit 5505867

Browse files
pre-commit-ci[bot]denera
authored andcommitted
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 0973a05 commit 5505867

File tree

2 files changed

+5
-11
lines changed

2 files changed

+5
-11
lines changed

tests/pytorch/distributed/run_layer_with_overlap.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def _get_ub_cfg(config):
9292
"qkv_fprop": dict(),
9393
"qkv_dgrad": {
9494
"method": "pipeline" if config.rs_dgrad_overlap else "bulk",
95-
"fp8_buf": True if config.fp8_buf and config.rs_dgrad_overlap else False
95+
"fp8_buf": True if config.fp8_buf and config.rs_dgrad_overlap else False,
9696
},
9797
}
9898
)
@@ -101,9 +101,7 @@ def _get_ub_cfg(config):
101101
if config.layer_type in [te.Linear, te.MultiheadAttention, te.TransformerLayer]:
102102
ub_cfg.update(
103103
{
104-
"proj_fprop": {
105-
"fp8_buf": True if config.fp8_buf else False
106-
},
104+
"proj_fprop": {"fp8_buf": True if config.fp8_buf else False},
107105
"proj_dgrad": dict(),
108106
}
109107
)
@@ -113,7 +111,7 @@ def _get_ub_cfg(config):
113111
"fc1_fprop": dict(),
114112
"fc1_dgrad": {
115113
"method": "pipeline" if config.rs_dgrad_overlap else "bulk",
116-
"fp8_buf": True if config.fp8_buf and config.rs_dgrad_overlap else False
114+
"fp8_buf": True if config.fp8_buf and config.rs_dgrad_overlap else False,
117115
},
118116
"fc2_fprop": {
119117
"fp8_buf": True if config.fp8_buf else False,
@@ -150,7 +148,7 @@ def _parse_args(argv=None, namespace=None):
150148
"--fp8-buf",
151149
action="store_true",
152150
default=False,
153-
help="Allocate FP8 communication buffers for layers that support it."
151+
help="Allocate FP8 communication buffers for layers that support it.",
154152
)
155153
parser.add_argument(
156154
"--rs-dgrad-overlap",

tests/pytorch/distributed/test_comm_gemm_overlap.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -274,11 +274,7 @@ def test_bulk_overlaps(comm_type, fp8, connections):
274274
)
275275
@pytest.mark.parametrize(
276276
"fp8,fp8_init",
277-
[
278-
(False, False),
279-
(True, False),
280-
(True, True)
281-
],
277+
[(False, False), (True, False), (True, True)],
282278
ids=[
283279
" BF16 GEMM - BF16 PARAMS ",
284280
" FP8 GEMM - BF16 PARAMS ",

0 commit comments

Comments
 (0)