Skip to content

Commit ffc39ba

Browse files
committed
Stronger acquire scope for low-latency kernels
1 parent 7d52ad7 commit ffc39ba

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

csrc/kernels/internode_ll.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -260,7 +260,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
260 260
int num_recv_tokens, recv_token_begin_idx;
261 261
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
262 262
if (sub_warp_id == 1 and lane_id == 0) {
263-
while ((num_recv_tokens = ld_acquire_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
263+
while ((num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
264 264
num_recv_tokens = -num_recv_tokens - 1;
265 265
recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
266 266
shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
@@ -450,7 +450,7 @@ combine(void* combined_x,
450 450
if (responsible_expert_idx < num_experts) {
451 451
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group");
452 452
if (sub_warp_id == 0 and lane_id == 0)
453-
while (ld_acquire_global(rdma_recv_flag + responsible_expert_idx) == 0);
453+
while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0);
454 454
}
455 455
cg::this_grid().sync();
456 456

0 commit comments

Comments
 (0)