diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index dc03c65a..28778724 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -36,7 +36,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
                   clean_0, num_clean_int_0, clean_1, num_clean_int_1);
 }
 
-template <bool kUseFP8, bool kUseUE8M0, int kHidden>
+template <bool kUseFP8, bool kUseUE8M0, int kHidden, int kNumTMABytesPerWarp>
 __global__ __launch_bounds__(1024, 1) void
 dispatch(void* packed_recv_x, void* packed_recv_x_scales,
          int* packed_recv_src_info, int64_t* packed_recv_layout_range,
@@ -82,6 +82,21 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     constexpr int kNumMaxWarpGroups = 32;
     __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
 
+    // TMA shared memory and barrier initialization
+    extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];
+    auto quarter_hidden_int4 = hidden_int4 / 4;
+    auto quarter_hidden_bytes = quarter_hidden_int4 * static_cast<int>(sizeof(int4));
+    auto tma_buffer_for_warp = smem_tma_buffer + warp_id * kNumTMABytesPerWarp;
+    auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer_for_warp + quarter_hidden_bytes);
+    uint32_t tma_phase = 0;
+    if (lane_id == 0) {
+        mbarrier_init(tma_mbarrier, 1);
+        fence_view_async_shared();
+        fence_barrier_init();
+        EP_DEVICE_ASSERT(quarter_hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp);
+    }
+    __syncwarp();
+
     // Sending phase
     if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
         goto LOW_LATENCY_DISPATCH_RECV;
@@ -293,18 +308,34 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
             if (lane_id == 0)
                 recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
             __syncwarp();
-
+
             // Copy data
             // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
             const auto src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
             const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
-            UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
+            // UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
+            // __syncwarp();
+            #pragma unroll
+            for (int j = 0; j < 4; ++j) {
+                if (lane_id == 0) {
+                    tma_load_1d(tma_buffer_for_warp, src_data + j * quarter_hidden_int4, tma_mbarrier, quarter_hidden_bytes);
+                    mbarrier_arrive_and_expect_tx(tma_mbarrier, quarter_hidden_bytes);
+                }
+                __syncwarp();
+                mbarrier_wait(tma_mbarrier, tma_phase);
 
-            // Copy scales
+                if (lane_id == 0) {
+                    tma_store_1d(tma_buffer_for_warp, dst_data + j * quarter_hidden_int4, quarter_hidden_bytes, false);
+                    tma_store_wait();
+                }
+                __syncwarp();
+
+            }
             if constexpr (kUseFP8) {
                 // Equivalent CuTe layout:
                 //   (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
                 const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
+                // const auto smem_scales = reinterpret_cast<float*>(static_cast<uint8_t*>(tma_buffer_for_warp) + data_bytes);
                 const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
                 const auto token_idx = recv_token_begin_idx + i;
                 const auto token_stride = num_elems_per_pack;
@@ -313,12 +344,14 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
                     const auto pack_idx = lane_id / num_elems_per_pack;
                     const auto elem_idx = lane_id % num_elems_per_pack;
                     auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
+                    // auto scale = extract_required_scale_format<kUseUE8M0>(smem_scales[lane_id]);
                     recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
                 }
                 if (lane_id + 32 < num_scales) {
                     const auto pack_idx = (lane_id + 32) / num_elems_per_pack;
                     const auto elem_idx = (lane_id + 32) % num_elems_per_pack;
                     auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + 32));
+                    // auto scale = extract_required_scale_format<kUseUE8M0>(smem_scales[lane_id + 32]);
                     recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
                 }
             }
@@ -347,6 +380,8 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     const auto num_warps = num_warp_groups * num_warps_per_group;
     const auto num_sms = ceil_div(num_experts, num_warp_groups);
     EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
+    constexpr int kNumTMABytesPerWarp = 4096;
+    const int smem_size = kNumTMABytesPerWarp * num_warps;
 
     // Workspace checks
     auto atomic_counter_per_expert = static_cast<int*>(workspace);
@@ -358,11 +393,12 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");
 
 #define DISPATCH_LAUNCH_CASE(hidden) { \
-auto dispatch_func = dispatch