Skip to content

Commit 4906662

Browse files
author
Yen-Chen Lin
committed
Improve CP P2P efficiency
Signed-off-by: Yen-Chen Lin <[email protected]>
1 parent: 209b8e5 · commit: 4906662

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

transformer_engine/pytorch/attention.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1561,7 +1561,7 @@ def forward(
15611561
fused_attn_qkv_dtype = TE_DType[q.dtype]
15621562
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
15631563

1564-
p2p_comm_buffers = [None for _ in range(cp_size)]
1564+
p2p_comm_buffers = [None, None]
15651565
if use_fused_attention and qkv_format in ["bshd", "sbhd"]:
15661566
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
15671567
else:
@@ -1576,12 +1576,12 @@ def forward(
15761576
req.wait()
15771577

15781578
if i < (cp_size - 1):
1579-
p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i])
1579+
p2p_comm_buffers[(i + 1) % 2] = torch.empty_like(p2p_comm_buffers[i % 2])
15801580
send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
15811581
rank,
1582-
p2p_comm_buffers[i],
1582+
p2p_comm_buffers[i % 2],
15831583
send_dst,
1584-
p2p_comm_buffers[i + 1],
1584+
p2p_comm_buffers[(i + 1) % 2],
15851585
recv_src,
15861586
cp_group,
15871587
batch_p2p_comm,
@@ -1592,11 +1592,11 @@ def forward(
15921592
or fp8_meta["recipe"].fp8_mha
15931593
or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
15941594
):
1595-
kv_inputs[i % 2] = p2p_comm_buffers[i]
1595+
kv_inputs[i % 2] = p2p_comm_buffers[i % 2]
15961596
else:
15971597
# KV exchange is in BF16/FP16, cast received KV in each step
15981598
kv_inputs[i % 2] = cast_to_fp8(
1599-
p2p_comm_buffers[i],
1599+
p2p_comm_buffers[i % 2],
16001600
fp8_meta["scaling_fwd"],
16011601
META_QKV,
16021602
fp8_dtype_forward,

0 commit comments

Comments
 (0)