52 changes: 44 additions & 8 deletions csrc/kernels/internode_ll.cu
@@ -36,7 +36,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
clean_0, num_clean_int_0, clean_1, num_clean_int_1);
}

template <bool kUseFP8, bool kUseUE8M0, int kHidden>
template <bool kUseFP8, bool kUseUE8M0, int kHidden, int kNumTMABytesPerWarp>
__global__ __launch_bounds__(1024, 1) void
dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
@@ -82,6 +82,21 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
constexpr int kNumMaxWarpGroups = 32;
__shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];

// TMA shared memory and barrier initialization
extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];
auto quarter_hidden_int4 = hidden_int4 / 4;
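// NOTE: assumes hidden_int4 is divisible by 4; e.g. for a 7168-dim BF16 hidden,
// hidden_int4 = 7168 * 2 / 16 = 896, so each quarter chunk is 224 int4 (3584 bytes)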
auto quarter_hidden_bytes = quarter_hidden_int4 * static_cast<int>(sizeof(int4));
auto tma_buffer_for_warp = smem_tma_buffer + warp_id * kNumTMABytesPerWarp;
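// Layout of each warp's slice: [quarter-hidden data | 8-byte mbarrier]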
auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer_for_warp + quarter_hidden_bytes);
uint32_t tma_phase = 0;
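// Only lane 0 initializes the mbarrier; the fences make the initialization
// visible to the TMA (async proxy) side before any lane uses the barrier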
if (lane_id == 0) {
mbarrier_init(tma_mbarrier, 1);
fence_view_async_shared();
fence_barrier_init();
EP_DEVICE_ASSERT(quarter_hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp);
}
__syncwarp();

// Sending phase
if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
goto LOW_LATENCY_DISPATCH_RECV;
@@ -293,18 +308,34 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
if (lane_id == 0)
recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
__syncwarp();

// Copy data
// NOTES: the unrolled warp copy kept below (commented out) needed only 2 load
// iterations for 7K hidden with 7 unrolls; the TMA path replaces it
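// The first int4 of each source packet carries the source index consumed above;
// the payload to copy starts one int4 later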
const auto src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
// UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
// __syncwarp();
#pragma unroll
for (int j = 0; j < 4; ++j) {
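// Lane 0 alone drives the TMA engine: it issues the async bulk load of one
// quarter-hidden chunk into shared memory and arms the barrier with the
// number of bytes to expect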
if (lane_id == 0) {
tma_load_1d(tma_buffer_for_warp, src_data + j * quarter_hidden_int4, tma_mbarrier, quarter_hidden_bytes);
mbarrier_arrive_and_expect_tx(tma_mbarrier, quarter_hidden_bytes);
}
__syncwarp();
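// All lanes block until the TMA engine signals that quarter_hidden_bytes have
// landed in shared memory; the phase bit alternates across loop iterations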
mbarrier_wait(tma_mbarrier, tma_phase);

// Store the chunk from shared memory back to global memory, then wait for
// the store to retire so the buffer slice can be reused by the next load
if (lane_id == 0) {
tma_store_1d(tma_buffer_for_warp, dst_data + j * quarter_hidden_int4, quarter_hidden_bytes, false);
tma_store_wait();
}
__syncwarp();

}
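// Copy scales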
if constexpr (kUseFP8) {
// Equivalent CuTe layout:
// (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
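// e.g. on the UE8M0 path (assuming scale_t = uint8_t packed into packed_t = uint32_t),
// num_elems_per_pack = 4, so four consecutive scales share one 32-bit pack;
// on the plain FP8 path both types are float and the packing degenerates to 1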
const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
// const auto smem_scales = reinterpret_cast<const float*>(static_cast<const uint8_t*>(tma_buffer_for_warp) + data_bytes);
const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
const auto token_idx = recv_token_begin_idx + i;
const auto token_stride = num_elems_per_pack;
@@ -313,12 +344,14 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const auto pack_idx = lane_id / num_elems_per_pack;
const auto elem_idx = lane_id % num_elems_per_pack;
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id));
// auto scale = extract_required_scale_format<kUseUE8M0>(smem_scales[lane_id]);
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
}
if (lane_id + 32 < num_scales) {
const auto pack_idx = (lane_id + 32) / num_elems_per_pack;
const auto elem_idx = (lane_id + 32) % num_elems_per_pack;
auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + 32));
// auto scale = extract_required_scale_format<kUseUE8M0>(smem_scales[lane_id + 32]);
recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
}
}
@@ -347,6 +380,8 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
const auto num_warps = num_warp_groups * num_warps_per_group;
const auto num_sms = ceil_div(num_experts, num_warp_groups);
EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
constexpr int kNumTMABytesPerWarp = 4096;
const int smem_size = kNumTMABytesPerWarp * num_warps;
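// Budget sanity check (assuming 7168 BF16 values is the largest dispatched hidden):
// a quarter chunk is 7168 * 2 / 4 = 3584 bytes, plus the 8-byte mbarrier = 3592 <= 4096.
// With up to 32 warps this requests 128 KB of dynamic shared memory, which is why
// SET_SHARED_MEMORY_FOR_TMA below must raise the kernel's shared-memory limit
// (presumably via cudaFuncAttributeMaxDynamicSharedMemorySize) before launch.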

// Workspace checks
auto atomic_counter_per_expert = static_cast<int*>(workspace);
@@ -358,11 +393,12 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");

#define DISPATCH_LAUNCH_CASE(hidden) { \
auto dispatch_func = dispatch<false, false, hidden>; \
auto dispatch_func = dispatch<false, false, hidden, kNumTMABytesPerWarp>; \
if (use_fp8 and not use_ue8m0) \
dispatch_func = dispatch<true, false, hidden>; \
dispatch_func = dispatch<true, false, hidden, kNumTMABytesPerWarp>; \
if (use_fp8 and use_ue8m0) \
dispatch_func = dispatch<true, true, hidden>; \
dispatch_func = dispatch<true, true, hidden, kNumTMABytesPerWarp>; \
SET_SHARED_MEMORY_FOR_TMA(dispatch_func); \
LAUNCH_KERNEL(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
@@ -658,4 +694,4 @@ LAUNCH_KERNEL(&cfg, combine_func, \

} // namespace internode_ll

} // namespace deep_ep