Skip to content

Commit ef2ff72

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent a93256f commit ef2ff72

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

transformer_engine/pytorch/attention.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1576,12 +1576,12 @@ def forward(
15761576
req.wait()
15771577

15781578
if i < (cp_size - 1):
1579-
p2p_comm_buffers[(i+1)%2] = torch.empty_like(p2p_comm_buffers[i%2])
1579+
p2p_comm_buffers[(i + 1) % 2] = torch.empty_like(p2p_comm_buffers[i % 2])
15801580
send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
15811581
rank,
1582-
p2p_comm_buffers[i%2],
1582+
p2p_comm_buffers[i % 2],
15831583
send_dst,
1584-
p2p_comm_buffers[(i+1)%2],
1584+
p2p_comm_buffers[(i + 1) % 2],
15851585
recv_src,
15861586
cp_group,
15871587
batch_p2p_comm,
@@ -1592,11 +1592,11 @@ def forward(
15921592
or fp8_meta["recipe"].fp8_mha
15931593
or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
15941594
):
1595-
kv_inputs[i % 2] = p2p_comm_buffers[i%2]
1595+
kv_inputs[i % 2] = p2p_comm_buffers[i % 2]
15961596
else:
15971597
# KV exchange is in BF16/FP16, cast received KV in each step
15981598
kv_inputs[i % 2] = cast_to_fp8(
1599-
p2p_comm_buffers[i%2],
1599+
p2p_comm_buffers[i % 2],
16001600
fp8_meta["scaling_fwd"],
16011601
META_QKV,
16021602
fp8_dtype_forward,

0 commit comments

Comments
 (0)