File tree Expand file tree Collapse file tree 1 file changed +5
-5
lines changed
transformer_engine/pytorch Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -1576,12 +1576,12 @@ def forward(
1576
1576
req .wait ()
1577
1577
1578
1578
if i < (cp_size - 1 ):
1579
- p2p_comm_buffers [(i + 1 ) % 2 ] = torch .empty_like (p2p_comm_buffers [i % 2 ])
1579
+ p2p_comm_buffers [(i + 1 ) % 2 ] = torch .empty_like (p2p_comm_buffers [i % 2 ])
1580
1580
send_recv_reqs [i % 2 ] = flash_attn_p2p_communicate (
1581
1581
rank ,
1582
- p2p_comm_buffers [i % 2 ],
1582
+ p2p_comm_buffers [i % 2 ],
1583
1583
send_dst ,
1584
- p2p_comm_buffers [(i + 1 ) % 2 ],
1584
+ p2p_comm_buffers [(i + 1 ) % 2 ],
1585
1585
recv_src ,
1586
1586
cp_group ,
1587
1587
batch_p2p_comm ,
@@ -1592,11 +1592,11 @@ def forward(
1592
1592
or fp8_meta ["recipe" ].fp8_mha
1593
1593
or int (os .getenv ("NVTE_FP8_DPA_BWD" , "1" ))
1594
1594
):
1595
- kv_inputs [i % 2 ] = p2p_comm_buffers [i % 2 ]
1595
+ kv_inputs [i % 2 ] = p2p_comm_buffers [i % 2 ]
1596
1596
else :
1597
1597
# KV exchange is in BF16/FP16, cast received KV in each step
1598
1598
kv_inputs [i % 2 ] = cast_to_fp8 (
1599
- p2p_comm_buffers [i % 2 ],
1599
+ p2p_comm_buffers [i % 2 ],
1600
1600
fp8_meta ["scaling_fwd" ],
1601
1601
META_QKV ,
1602
1602
fp8_dtype_forward ,
You can’t perform that action at this time.
0 commit comments