@@ -1561,7 +1561,7 @@ def forward(
1561
1561
fused_attn_qkv_dtype = TE_DType [q .dtype ]
1562
1562
fused_attn_backend = FusedAttnBackend ["F16_arbitrary_seqlen" ]
1563
1563
1564
- p2p_comm_buffers = [None for _ in range ( cp_size ) ]
1564
+ p2p_comm_buffers = [None , None ]
1565
1565
if use_fused_attention and qkv_format in ["bshd" , "sbhd" ]:
1566
1566
p2p_comm_buffers [0 ] = torch .cat ((k .unsqueeze (- 3 ), v .unsqueeze (- 3 )), dim = - 3 )
1567
1567
else :
@@ -1576,12 +1576,12 @@ def forward(
1576
1576
req .wait ()
1577
1577
1578
1578
if i < (cp_size - 1 ):
1579
- p2p_comm_buffers [i + 1 ] = torch .empty_like (p2p_comm_buffers [i ])
1579
+ p2p_comm_buffers [( i + 1 ) % 2 ] = torch .empty_like (p2p_comm_buffers [i % 2 ])
1580
1580
send_recv_reqs [i % 2 ] = flash_attn_p2p_communicate (
1581
1581
rank ,
1582
- p2p_comm_buffers [i ],
1582
+ p2p_comm_buffers [i % 2 ],
1583
1583
send_dst ,
1584
- p2p_comm_buffers [i + 1 ],
1584
+ p2p_comm_buffers [( i + 1 ) % 2 ],
1585
1585
recv_src ,
1586
1586
cp_group ,
1587
1587
batch_p2p_comm ,
@@ -1592,11 +1592,11 @@ def forward(
1592
1592
or fp8_meta ["recipe" ].fp8_mha
1593
1593
or int (os .getenv ("NVTE_FP8_DPA_BWD" , "1" ))
1594
1594
):
1595
- kv_inputs [i % 2 ] = p2p_comm_buffers [i ]
1595
+ kv_inputs [i % 2 ] = p2p_comm_buffers [i % 2 ]
1596
1596
else :
1597
1597
# KV exchange is in BF16/FP16, cast received KV in each step
1598
1598
kv_inputs [i % 2 ] = cast_to_fp8 (
1599
- p2p_comm_buffers [i ],
1599
+ p2p_comm_buffers [i % 2 ],
1600
1600
fp8_meta ["scaling_fwd" ],
1601
1601
META_QKV ,
1602
1602
fp8_dtype_forward ,
0 commit comments