diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index 4d5ee07e..b99fa84a 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -171,7 +171,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, // Increase counter after finishing __syncwarp(); - lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0; } } } else if (warp_id == num_warps - 1) { @@ -184,47 +183,16 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, #pragma unroll for (int i = lane_id; i < num_next_clean_int; i += 32) next_clean[i] = 0; - - // Notify before executing `int_p` - __syncwarp(); - #pragma unroll - for (int i = lane_id; i < num_experts; i += 32) - atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG); - } - - // This SM should be responsible for some destination experts, read `topk_idx` for them - int expert_count[kNumMaxWarpGroups] = {0}; - const auto expert_begin_idx = sm_id * num_warp_groups; - const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts); - - // Per lane count - #pragma unroll 8 - for (int i = lane_id; i < num_tokens * num_topk; i += 32) { - auto idx = static_cast(__ldg(topk_idx + i)); - if (idx >= expert_begin_idx and idx < expert_end_idx) - expert_count[idx - expert_begin_idx] ++; - } - - // Warp reduce - #pragma unroll - for (int i = expert_begin_idx; i < expert_end_idx; ++ i) { - auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]); - if (lane_id == 0) { - shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum; - atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum); - } } } - __syncthreads(); + cg::this_grid().sync(); // Issue count sends if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) { const auto dst_rank = responsible_expert_idx / num_local_experts; const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts; - const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups]; + const auto num_tokens_sent = atomic_counter_per_expert[responsible_expert_idx]; - // Wait local sends issued and send expert counts - while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2); auto dst_ptr = reinterpret_cast(rdma_recv_count + dst_expert_local_idx * num_ranks + rank); auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); if (dst_p2p_ptr == 0) { @@ -235,7 +203,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, // Clean workspace for next use atomic_counter_per_expert[responsible_expert_idx] = 0; - atomic_finish_counter_per_expert[responsible_expert_idx] = 0; // Clean `packed_recv_count` if (dst_rank == 0)