diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu
index 5a343b89..5ed4b250 100644
--- a/csrc/kernels/internode.cu
+++ b/csrc/kernels/internode.cu
@@ -1095,22 +1095,23 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
         if (is_cached_dispatch)
             return;
 
-        EP_DEVICE_ASSERT(num_warps >= num_channels);
         EP_DEVICE_ASSERT(num_rdma_ranks <= 32);
 
         // Iterate in reverse order
-        if (lane_id < num_rdma_ranks and warp_id < num_channels) {
-            int token_start_idx, token_end_idx;
-            get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx, token_end_idx);
-
-            // NOTES: `1 << 25` is a heuristic large number
-            int last_head = 1 << 25;
-            for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; -- token_idx) {
-                auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id);
-                if (current_head < 0) {
-                    combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1;
-                } else {
-                    last_head = current_head;
+        if (lane_id < num_rdma_ranks) {
+            for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
+                int token_start_idx, token_end_idx;
+                get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx);
+
+                // NOTES: `1 << 25` is a heuristic large number
+                int last_head = 1 << 25;
+                for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; -- token_idx) {
+                    auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id);
+                    if (current_head < 0) {
+                        combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1;
+                    } else {
+                        last_head = current_head;
+                    }
                 }
             }
         }
@@ -1118,19 +1119,18 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
         if (is_cached_dispatch)
             return;
 
-        EP_DEVICE_ASSERT(num_warps >= num_channels);
         EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr);
         EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Too many NVL peers");
 
-        if (warp_id < num_channels) {
+        for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
             constexpr int tma_batch_size = kNumTMABytesPerWarp - sizeof(uint64_t);
             constexpr int num_bytes_per_token = sizeof(int) * NUM_MAX_NVL_PEERS;
             constexpr int num_tokens_per_batch = tma_batch_size / num_bytes_per_token;
             EP_STATIC_ASSERT(num_bytes_per_token % 16 == 0, "num_bytes_per_token should be divisible by 16");
-            
+
             // TMA stuffs
             extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];
-            auto tma_buffer = smem_tma_buffer + warp_id * kNumTMABytesPerWarp;
+            auto tma_buffer = smem_tma_buffer + channel_id * kNumTMABytesPerWarp;
             auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + tma_batch_size);
             uint32_t tma_phase = 0;
             if (lane_id == 0) {
@@ -1142,8 +1142,8 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
 
             for (int dst_rdma_rank = sm_id - 2; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_channels * 2 - 2) {
                 // Iterate in reverse order
-                int token_start_idx = warp_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1];
-                int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id];
+                int token_start_idx = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1];
+                int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id];
                 int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
                 token_start_idx += shift, token_end_idx += shift;
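The diff replaces the one-warp-per-channel mapping (guarded by `if (warp_id < num_channels)` and the now-removed `EP_DEVICE_ASSERT(num_warps >= num_channels)`) with a warp-stride loop, so correctness no longer depends on launching at least as many warps as there are channels. Below is a minimal standalone sketch of that pattern; `process_channel`, `sweep_channels`, and the launch shape are hypothetical stand-ins, not DeepEP's actual API — only the loop structure mirrors the diff.

// Minimal CUDA sketch of the warp-stride channel loop adopted above.
// `process_channel` is a hypothetical placeholder for the per-channel work
// (e.g. the reverse scan over `combined_rdma_head` in the first hunk).
#include <cuda_runtime.h>

__device__ void process_channel(int channel_id, int lane_id) {
    // Stand-in for the real per-channel work; intentionally empty here.
}

__global__ void sweep_channels(int num_channels) {
    const int num_warps = blockDim.x / 32;
    const int warp_id = static_cast<int>(threadIdx.x) / 32;
    const int lane_id = static_cast<int>(threadIdx.x) % 32;

    // Old mapping: `if (warp_id < num_channels) { ... }`, valid only when
    // num_warps >= num_channels. New mapping: each warp handles channels
    // warp_id, warp_id + num_warps, warp_id + 2 * num_warps, ...
    for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps)
        process_channel(channel_id, lane_id);
}

int main() {
    sweep_channels<<<1, 128>>>(10);  // 4 warps now safely cover 10 channels
    return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}

Note that the second hunk also switches the TMA staging buffer from `warp_id`- to `channel_id`-based indexing. Since each `channel_id` maps to exactly one warp (`channel_id % num_warps == warp_id`), every slot is still owned by a single warp, but the shared-memory allocation now has to provide `kNumTMABytesPerWarp` bytes per channel rather than per warp.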