Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 2 additions & 35 deletions csrc/kernels/internode_ll.cu
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,

// Increase counter after finishing
__syncwarp();
lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
}
}
} else if (warp_id == num_warps - 1) {
Expand All @@ -184,47 +183,16 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
#pragma unroll
for (int i = lane_id; i < num_next_clean_int; i += 32)
next_clean[i] = 0;

// Notify before executing `int_p`
__syncwarp();
#pragma unroll
for (int i = lane_id; i < num_experts; i += 32)
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
}

// This SM should be responsible for some destination experts, read `topk_idx` for them
int expert_count[kNumMaxWarpGroups] = {0};
const auto expert_begin_idx = sm_id * num_warp_groups;
const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts);

// Per lane count
#pragma unroll 8
for (int i = lane_id; i < num_tokens * num_topk; i += 32) {
auto idx = static_cast<int>(__ldg(topk_idx + i));
if (idx >= expert_begin_idx and idx < expert_end_idx)
expert_count[idx - expert_begin_idx] ++;
}

// Warp reduce
#pragma unroll
for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
if (lane_id == 0) {
shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
}
}
}
__syncthreads();
cg::this_grid().sync();

// Issue count sends
if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) {
const auto dst_rank = responsible_expert_idx / num_local_experts;
const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups];
const auto num_tokens_sent = atomic_counter_per_expert[responsible_expert_idx];

// Wait local sends issued and send expert counts
while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank);
auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
if (dst_p2p_ptr == 0) {
Expand All @@ -235,7 +203,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,

// Clean workspace for next use
atomic_counter_per_expert[responsible_expert_idx] = 0;
atomic_finish_counter_per_expert[responsible_expert_idx] = 0;

// Clean `packed_recv_count`
if (dst_rank == 0)
Expand Down