diff --git a/csrc/config.hpp b/csrc/config.hpp
index 6bef6255..c674be33 100644
--- a/csrc/config.hpp
+++ b/csrc/config.hpp
@@ -133,6 +133,7 @@ struct LowLatencyLayout {
     LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
         const int num_scales = hidden / 128;
+        const int num_nodes = num_ranks / NUM_MAX_NVL_PEERS; // TODO: derive NUM_MAX_NVL_PEERS automatically from the runtime process topology
 
         // Dispatch and combine layout:
         //  - 2 symmetric odd/even send buffer
@@ -143,7 +144,9 @@ struct LowLatencyLayout {
         // NOTES: you should add a control `int4` for combine messages if you want to do data transformation
         // NOTES: `num_scales * sizeof(nv_bfloat162)` means the per-128-channel min/max
         EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden);
-        size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float));
+        size_t per_meta_data_size = sizeof(int4);
+        size_t per_token_size = std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float));
+        size_t num_bytes_per_dispatch_msg = per_meta_data_size + per_token_size;
         size_t num_bytes_per_combine_msg = num_scales * sizeof(nv_bfloat162) + hidden * sizeof(nv_bfloat16);
 
         // Send buffer
@@ -155,14 +158,15 @@ struct LowLatencyLayout {
         // Symmetric receive buffers
         // TODO: optimize memory usages
-        size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg;
+        size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * per_meta_data_size + num_nodes * num_max_dispatch_tokens_per_rank * per_token_size; // NOTES: num_experts == local_experts * num_ranks
         size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg;
         size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes);
         EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0);
         total_bytes += recv_buffer_bytes * 2;
 
         // Symmetric signaling buffers
-        size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
+        size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int4) * 2 + // NOTES: num_experts == local_experts * num_ranks == local_experts * NUM_MAX_NVL_PEERS * num_nodes; half is used by dispatch, the other half by combine
+                                                  NUM_MAX_NVL_PEERS * num_nodes * num_max_dispatch_tokens_per_rank * sizeof(int) + NUM_MAX_NVL_PEERS * sizeof(int);
         size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
         size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
         size_t signaling_buffer_bytes_aligned = align_up(signaling_buffer_bytes, 128);
@@ -173,7 +177,7 @@ struct LowLatencyLayout {
         // so you may see some parameters are duplicated
         for (int i = 0; i < 2; ++ i) {
             buffers[i] = {
-                static_cast<int>(signaling_buffer_bytes / sizeof(int)),
+                static_cast<int>(signaling_buffer_bytes / sizeof(int4)),
                 advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * i),
                 advance(rdma_buffer, signaling_buffer_bytes_aligned * 2 + send_buffer_bytes * 2 + recv_buffer_bytes * i),
                 advance(rdma_buffer, signaling_buffer_bytes_aligned * i),
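Note (illustrative, not part of the patch): the new sizing keeps one int4 of metadata per (local expert, source rank, token) slot but stores each token's payload only once per source node, while the signaling region now holds two int4 counters per expert (one half used by dispatch, the other by combine), a per-(node, token, local GPU) data-ready flag array, and NUM_MAX_NVL_PEERS staging ints. A minimal host-side sketch of the same arithmetic, assuming NUM_MAX_NVL_PEERS == 8 and illustrative problem sizes; all names below are hypothetical:

#include <algorithm>
#include <cstddef>
#include <cstdio>

constexpr int kNumMaxNvlPeers = 8;  // assumption: mirrors NUM_MAX_NVL_PEERS

int main() {
    const int hidden = 7168, num_ranks = 32, num_experts = 288, num_tokens = 128;
    const int num_scales = hidden / 128, num_nodes = num_ranks / kNumMaxNvlPeers;
    const size_t per_meta = 16;                                                  // one int4 per token slot
    const size_t per_token = std::max(hidden * 2ul, hidden + num_scales * 4ul);  // BF16 payload vs. FP8 payload + scales
    // Metadata stays replicated per (local expert, source rank); the payload is kept once per source node.
    const size_t recv_bytes = (size_t)num_experts * num_tokens * per_meta +
                              (size_t)num_nodes * num_tokens * per_token;
    // Two int4 counters per expert, per-token data-ready flags, and a small staging buffer of ones.
    const size_t signaling_bytes = (size_t)num_experts * 16 * 2 +
                                   (size_t)kNumMaxNvlPeers * num_nodes * num_tokens * sizeof(int) +
                                   kNumMaxNvlPeers * sizeof(int);
    std::printf("recv buffer: %zu bytes, signaling buffer: %zu bytes\n", recv_bytes, signaling_bytes);
    return 0;
}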
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index d18d6fd3..a099ff3f 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -2,11 +2,16 @@
 #include "exception.cuh"
 #include "launch.cuh"
 #include "ibgda_device.cuh"
+#include "utils.cuh"
+#include
+#include "cooperative_groups.h"
 
 namespace deep_ep {
 
 namespace internode_ll {
 
+namespace cg = cooperative_groups;
+
 template <int kNumThreads> __launch_bounds__(kNumThreads, 1)
 __global__ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
                                          int* clean_1, int num_clean_int_1) {
@@ -43,7 +48,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
          int* packed_recv_count,
          int* cumulative_local_expert_recv_stats,
          int64_t* dispatch_wait_recv_cost_stats,
-         void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
+         void* rdma_recv_x, int* new_rdma_recv_count, void* rdma_x,
          const void* x, const topk_idx_t* topk_idx,
          int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
          int* next_clean, int num_next_clean_int,
@@ -60,6 +65,20 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     const auto warp_group_id = warp_id / num_warps_per_group;
     const auto sub_warp_id = warp_id % num_warps_per_group;
     const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
+    const auto num_nvl_ranks = NUM_MAX_NVL_PEERS;
+    const auto num_nodes = num_ranks / num_nvl_ranks;
+    int4* rdma_recv_count = reinterpret_cast<int4*>(new_rdma_recv_count) + num_experts;
+    int* data_ready_counter = reinterpret_cast<int*>(reinterpret_cast<int4*>(new_rdma_recv_count) + num_experts * 2);
+
+    int4* next_clean_int4 = reinterpret_cast<int4*>(next_clean);
+    int* next_clean_data_ready_counter = reinterpret_cast<int*>(reinterpret_cast<int4*>(next_clean) + num_experts * 2);
+    auto* data_ready_send_buffer = reinterpret_cast<int*>(data_ready_counter) +
+                                   num_nodes * num_max_dispatch_tokens_per_rank * num_nvl_ranks;
+    if (thread_id < num_nvl_ranks) {
+        st_na_global(reinterpret_cast<int*>(data_ready_send_buffer) + thread_id, 1); // set to 1
+    }
+    __syncthreads();
+    EP_DEVICE_ASSERT(num_ranks % num_nvl_ranks == 0);
 
     // May extract UE8M0 from the scales
     using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
@@ -75,10 +94,14 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     // Message package: index at source (int), 3 reserved int fields, hidden data, FP8 scales
     // NOTES: currently we have 3 reserved int fields for future use
     using vec_t = std::conditional_t<kUseFP8, int2, int4>;
-    const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
-    const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
+    const size_t num_bytes_per_meta = sizeof(int4);
+    const size_t num_bytes_per_data = (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
+    const size_t num_bytes_per_msg = num_bytes_per_meta + num_bytes_per_data;
     EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
 
+    void* rdma_recv_x_meta = rdma_recv_x;
+    void* rdma_recv_x_data = (void*)(uint64_t(rdma_recv_x) + num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_meta);
+
     // Expert counts
     constexpr int kNumMaxWarpGroups = 32;
     __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups];
@@ -150,25 +173,73 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
 
         // Issue IBGDA sends
         if (dst_expert_idx >= 0) {
+            int send_node_id = dst_expert_idx >= 0 ? dst_expert_idx / num_local_experts / num_nvl_ranks : -1;
             int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0;
             slot_idx = __shfl_sync(0xffffffff, slot_idx, 0);
             const auto dst_rank = dst_expert_idx / num_local_experts;
             const auto dst_expert_local_idx = dst_expert_idx % num_local_experts;
+            auto real_write_dst_rank = dst_rank / num_nvl_ranks * num_nvl_ranks + rank % num_nvl_ranks; // write to the remote rank with the same local GPU index (same-rail RDMA traffic)
+            { // send token
+                { // avoid sending the same token to the same node more than once
+                    EP_DEVICE_ASSERT(num_topk <= 32);
+                    auto tmp_dst_expert_id = lane_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + lane_id)) : -1;
+                    auto tmp_dst_node_id = tmp_dst_expert_id >= 0 ? tmp_dst_expert_id / num_local_experts / num_nvl_ranks : -1;
+                    #pragma unroll
+                    for (int i = 0; i < warp_id; ++ i) {
+                        auto dst_node_id = __shfl_sync(0xffffffff, tmp_dst_node_id, i); // broadcast
+                        if (dst_node_id == send_node_id) { // an earlier top-k slot already targets this node
+                            send_node_id = -1;
+                            break;
+                        }
+                    }
+                }
+
+                if (send_node_id != -1) { // send token
+                    const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx) + num_bytes_per_meta;
+                    const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x_data) +
+                                         (rank / num_nvl_ranks) * num_max_dispatch_tokens_per_rank * num_bytes_per_data +
+                                         token_idx * num_bytes_per_data;
+                    const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, real_write_dst_rank);
+                    if (dst_p2p_ptr == 0) { // each token is sent to a node only once
+                        nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes_per_data, real_write_dst_rank, dst_expert_local_idx, lane_id, slot_idx);
+                    } else {
+                        // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
+                        const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
+                        const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
+                        UNROLLED_WARP_COPY(7, lane_id, num_bytes_per_data / sizeof(int4), dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
+                    }
+                }
+                if (send_node_id != -1) { // send the data-ready flag
+                    const auto src_ptr = reinterpret_cast<uint64_t>(data_ready_send_buffer);
+                    const auto data_ready_counter_ptr = reinterpret_cast<uint64_t>(data_ready_counter) +
+                                                        (rank / num_nvl_ranks) * num_max_dispatch_tokens_per_rank * num_nvl_ranks * sizeof(int) +
+                                                        token_idx * num_nvl_ranks * sizeof(int);
+                    const auto data_ready_counter_p2p_ptr = nvshmemi_get_p2p_ptr(data_ready_counter_ptr, rank, real_write_dst_rank);
+                    if (data_ready_counter_p2p_ptr == 0) { // each token is sent to a node only once
+                        nvshmemi_ibgda_put_nbi_warp<true>(data_ready_counter_ptr, uint64_t(src_ptr), num_nvl_ranks * sizeof(int), real_write_dst_rank, dst_expert_local_idx, lane_id, slot_idx + 1);
+                    } else {
+                        const auto* src_int_ptr = reinterpret_cast<const int*>(src_ptr);
+                        const auto* dst_int_ptr = reinterpret_cast<int*>(data_ready_counter_p2p_ptr);
+                        UNROLLED_WARP_COPY(1, lane_id, num_nvl_ranks, dst_int_ptr, src_int_ptr, ld_nc_global, st_na_global);
+                    }
+                }
+            }
+            // send meta
             const auto src_ptr = reinterpret_cast<uint64_t>(rdma_x_src_idx);
-            const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x) +
-                                 dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
-                                 rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
-                                 slot_idx * num_bytes_per_msg;
-            const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
+            const auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_x_meta) +
+                                 dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_meta +
+                                 (rank / num_nvl_ranks) * num_nvl_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_meta +
+                                 (dst_rank % num_nvl_ranks) * num_max_dispatch_tokens_per_rank * num_bytes_per_meta +
+                                 slot_idx * num_bytes_per_meta;
+            const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, real_write_dst_rank);
             if (dst_p2p_ptr == 0) {
-                nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
+                nvshmemi_ibgda_put_nbi_warp<true>(dst_ptr, src_ptr, num_bytes_per_meta, real_write_dst_rank, dst_expert_local_idx, lane_id, slot_idx);
             } else {
                 // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
                 const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
                 const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_p2p_ptr);
-                UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
+                UNROLLED_WARP_COPY(1, lane_id, num_bytes_per_meta / sizeof(int4), dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
             }
-
             // Increase counter after finishing
             __syncwarp();
             lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0;
@@ -182,9 +253,24 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
 
         // The first SM is also responsible for cleaning the next buffer
         #pragma unroll
-        for (int i = lane_id; i < num_next_clean_int; i += 32)
+        for (int i = lane_id; i < num_experts; i += 32) // clean for combine
            next_clean[i] = 0;
+        auto* dispatch_next_clean_int4 = next_clean_int4 + num_experts;
+        #pragma unroll
+        for (int i = lane_id; i < num_experts; i += 32) { // clean for dispatch
+            const auto src_rank = i / num_local_experts;
+            const auto local_expert_idx = i % num_local_experts;
+            const auto read_recv_counter_rank = rank / num_nvl_ranks * num_nvl_ranks + src_rank % num_nvl_ranks; // the recv counter lives on the remote NVL rank
+
+            const auto counter_ptr = reinterpret_cast<int4*>(dispatch_next_clean_int4) +
+                                     local_expert_idx * num_ranks +
+                                     (src_rank / num_nvl_ranks) * num_nvl_ranks +
+                                     (rank % num_nvl_ranks);
+
+            const auto real_counter_p2p_src_ptr = nvshmemi_get_p2p_ptr(uint64_t(counter_ptr), rank, read_recv_counter_rank);
+            st_release_sys_global(reinterpret_cast<int*>(real_counter_p2p_src_ptr), 0);
+        }
 
         // Notify before executing `int_p`
         __syncwarp();
         #pragma unroll
@@ -215,6 +301,31 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
             }
         }
     }
+
+    if (sm_id == num_sms - 1) {
+        // Clean the data-ready flags
+        #pragma unroll 8
+        for (int i = thread_id; i < num_max_dispatch_tokens_per_rank * num_ranks; i += blockDim.x) {
+            int token_idx = i / num_ranks;
+            int rank_id = i % num_ranks;
+            {
+                auto node_id = rank_id / num_nvl_ranks;
+                auto nvl_rank_id = rank_id % num_nvl_ranks;
+                auto* data_ready_flag_ptr = reinterpret_cast<int*>(next_clean_data_ready_counter) +
+                                            node_id * num_max_dispatch_tokens_per_rank * num_nvl_ranks +
+                                            token_idx * num_nvl_ranks +
+                                            rank % num_nvl_ranks;
+                EP_DEVICE_ASSERT(data_ready_flag_ptr - next_clean_data_ready_counter < num_max_dispatch_tokens_per_rank * num_nodes * num_nvl_ranks * sizeof(int));
+                const auto data_ready_p2p_src_ptr = nvshmemi_get_p2p_ptr(uint64_t(data_ready_flag_ptr), rank, rank / num_nvl_ranks * num_nvl_ranks + nvl_rank_id);
+                reinterpret_cast<int*>(data_ready_p2p_src_ptr)[0] = 0;
+            }
+        }
+        __syncthreads();
+        #pragma unroll
+        for (int i = thread_id; i < num_experts; i += blockDim.x)
+            atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
+    }
+
     __syncthreads();
 
     // Issue count sends
@@ -222,13 +333,17 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         const auto dst_rank = responsible_expert_idx / num_local_experts;
         const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts;
         const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups];
-
+        auto real_write_dst_rank = dst_rank / num_nvl_ranks * num_nvl_ranks + rank % num_nvl_ranks;
+        auto start_time = clock64();
         // Wait local sends issued and send expert counts
-        while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
-        auto dst_ptr = reinterpret_cast<uint64_t>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank);
-        auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank);
+        while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 3);
+        auto dst_ptr = reinterpret_cast<int4*>(rdma_recv_count) +
+                       dst_expert_local_idx * num_ranks +
+                       (rank / num_nvl_ranks) * num_nvl_ranks +
+                       (dst_rank % num_nvl_ranks);
+        auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(uint64_t(dst_ptr), rank, real_write_dst_rank);
         if (dst_p2p_ptr == 0) {
-            nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
+            nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), -num_tokens_sent - 1, real_write_dst_rank, dst_expert_local_idx);
         } else {
             st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), -num_tokens_sent - 1);
         }
@@ -256,9 +371,6 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     if (responsible_expert_idx < num_experts) {
         const auto src_rank = responsible_expert_idx / num_local_experts;
         const auto local_expert_idx = responsible_expert_idx % num_local_experts;
-        const auto rdma_recv_x_uint8 = static_cast<uint8_t*>(rdma_recv_x) +
-                                       local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
-                                       src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg;
         const auto recv_x_int4 = static_cast<int4*>(packed_recv_x) +
                                  local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
         const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
@@ -272,10 +384,27 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         // Wait tokens to arrive
         // NOTES: using sub-warp 1 to overlap with sub-warp 0
         int num_recv_tokens, recv_token_begin_idx;
-        EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15);
-        if (sub_warp_id == 1 and lane_id == 0) {
+        EP_DEVICE_ASSERT(num_warps_per_group > 2 and num_warp_groups < 15);
+        if (sub_warp_id == 1 and lane_id == 0) { // wait for the recv count from the rank with the same local GPU index (same-rail RDMA traffic)
             auto start_time = clock64();
-            while ((num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
+            const auto read_recv_counter_rank = rank / num_nvl_ranks * num_nvl_ranks + src_rank % num_nvl_ranks;
+            const auto counter_p2p_ptr = reinterpret_cast<int4*>(rdma_recv_count) +
+                                         local_expert_idx * num_ranks +
+                                         (src_rank / num_nvl_ranks) * num_nvl_ranks +
+                                         (rank % num_nvl_ranks);
+
+            const auto real_counter_p2p_src_ptr = nvshmemi_get_p2p_ptr(uint64_t(counter_p2p_ptr), rank, read_recv_counter_rank);
+            int num_recv_tokens = 0;
+            while (num_recv_tokens == 0) { // read the recv counter from the remote NVL rank
+                num_recv_tokens = ld_acquire_sys_global(reinterpret_cast<const int*>(real_counter_p2p_src_ptr));
+                // Timeout check
+                if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) {
+                    printf("DeepEP LL dispatch recv counter timeout, src_rank: %d, dst_rank: %d, dst RDMA lane: %d, num_recv_tokens: %d\n",
+                           src_rank, rank, lane_id, num_recv_tokens);
+                    trap();
+                }
+            }
+            auto wait_recv_cost = clock64() - start_time;
             num_recv_tokens = -num_recv_tokens - 1;
             recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
@@ -292,19 +421,39 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(num_warps_per_group * 32));
         num_recv_tokens = shared_num_recv_tokens[warp_group_id];
         recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
+        const auto real_read_src_rank = src_rank % num_nvl_ranks + rank / num_nvl_ranks * num_nvl_ranks;
+        const auto src_copy_ptr = reinterpret_cast<uint64_t>(rdma_recv_x_meta) +
+                                  local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_meta +
+                                  (src_rank / num_nvl_ranks) * num_nvl_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_meta +
+                                  (rank % num_nvl_ranks) * num_max_dispatch_tokens_per_rank * num_bytes_per_meta;
+        const auto real_p2p_src_ptr = nvshmemi_get_p2p_ptr(src_copy_ptr, rank, real_read_src_rank);
 
         // Copy tokens
         EP_DEVICE_ASSERT(num_scales <= 64);
         for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
             // Copy source info
-            const auto src_src_idx = reinterpret_cast<int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
-            if (lane_id == 0)
-                recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
+            const auto src_src_idx = reinterpret_cast<int*>(real_p2p_src_ptr + i * num_bytes_per_meta);
+            int src_token_idx = 0;
+            if (lane_id == 0) {
+                src_token_idx = ld_nc_global(src_src_idx);
+                recv_src_info[recv_token_begin_idx + i] = src_token_idx;
+            }
+            src_token_idx = __shfl_sync(0xffffffff, src_token_idx, 0);
+            const auto data_ready_flag_src_ptr = reinterpret_cast<int*>(data_ready_counter) +
+                                                 (src_rank / num_nvl_ranks) * num_max_dispatch_tokens_per_rank * num_nvl_ranks +
+                                                 src_token_idx * num_nvl_ranks +
+                                                 rank % num_nvl_ranks;
+            const auto src_data_ready_flag_p2p_ptr = reinterpret_cast<int*>(nvshmemi_get_p2p_ptr(uint64_t(data_ready_flag_src_ptr), rank, real_read_src_rank));
+            if (lane_id == 0) {
+                while (ld_acquire_sys_global(src_data_ready_flag_p2p_ptr) == 0); // wait for the data to be ready
+            }
             __syncwarp();
-
+            const auto src_ptr = reinterpret_cast<uint64_t>(rdma_recv_x_data) +
+                                 (src_rank / num_nvl_ranks) * num_max_dispatch_tokens_per_rank * num_bytes_per_data +
+                                 src_token_idx * num_bytes_per_data;
+            const auto src_data = reinterpret_cast<int4*>(nvshmemi_get_p2p_ptr(src_ptr, rank, real_read_src_rank));
             // Copy data
             // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
-            const auto src_data = reinterpret_cast<int4*>(reinterpret_cast<uint64_t>(src_src_idx) + sizeof(int4));
             const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
             UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
@@ -574,7 +723,11 @@ combine(void* combined_x,
     const auto warp_group_id = warp_id / num_warps_per_group;
     const auto sub_warp_id = warp_id % num_warps_per_group;
     const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
+    int* next_clean_data_ready_counter = reinterpret_cast<int*>(reinterpret_cast<int4*>(next_clean) + num_experts * 2);
+    const auto num_nvl_ranks = NUM_MAX_NVL_PEERS;
+    const auto num_nodes = num_ranks / num_nvl_ranks;
+    int4* next_clean_int4 = reinterpret_cast<int4*>(next_clean);
 
     extern __shared__ __align__(1024) uint8_t smem_buffer[];
 
     // Data type staffs
@@ -602,14 +755,46 @@ combine(void* combined_x,
         goto LOW_LATENCY_COMBINE_RECV;
 
     // Clean up next buffer
-    if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) {
+    if (sm_id == num_sms - 1) {
         #pragma unroll
-        for (int i = lane_id; i < num_next_clean_int; i += 32)
+        for (int i = thread_id; i < num_experts; i += num_threads)
             next_clean[i] = 0;
+        next_clean_int4 = next_clean_int4 + num_experts;
+        #pragma unroll
+        for (int i = thread_id; i < num_experts; i += num_threads) { // clean for dispatch
+            const auto src_rank = i / num_local_experts;
+            const auto local_expert_idx = i % num_local_experts;
+            const auto read_recv_counter_rank = rank / num_nvl_ranks * num_nvl_ranks + src_rank % num_nvl_ranks; // the recv counter lives on the remote NVL rank
+
+            const auto counter_ptr = reinterpret_cast<int4*>(next_clean_int4) +
+                                     local_expert_idx * num_ranks +
+                                     (src_rank / num_nvl_ranks) * num_nvl_ranks +
+                                     (rank % num_nvl_ranks);
+
+            const auto real_counter_p2p_src_ptr = nvshmemi_get_p2p_ptr(uint64_t(counter_ptr), rank, read_recv_counter_rank);
+            st_release_sys_global(reinterpret_cast<int4*>(real_counter_p2p_src_ptr), {0, 0, 0, 0});
+        }
+        // Clean the data-ready flags
+        #pragma unroll 8
+        for (int i = thread_id; i < num_max_dispatch_tokens_per_rank * num_ranks; i += num_threads) {
+            int token_idx = i / num_ranks;
+            int rank_id = i % num_ranks;
+            {
+                auto node_id = rank_id / num_nvl_ranks;
+                auto nvl_rank_id = rank_id % num_nvl_ranks;
+                auto* data_ready_flag_ptr = reinterpret_cast<int*>(next_clean_data_ready_counter) +
+                                            node_id * num_max_dispatch_tokens_per_rank * num_nvl_ranks +
+                                            token_idx * num_nvl_ranks +
+                                            rank % num_nvl_ranks;
+                EP_DEVICE_ASSERT(data_ready_flag_ptr - next_clean_data_ready_counter < num_max_dispatch_tokens_per_rank * num_nodes * num_nvl_ranks * sizeof(int));
+                const auto data_ready_p2p_src_ptr = nvshmemi_get_p2p_ptr(uint64_t(data_ready_flag_ptr), rank, rank / num_nvl_ranks * num_nvl_ranks + nvl_rank_id);
+                reinterpret_cast<int*>(data_ready_p2p_src_ptr)[0] = 0;
+            }
+        }
 
         // Notify before executing `int_p`
-        __syncwarp();
-        if (lane_id == 0)
+        __syncthreads();
+        if (thread_id == 0)
             atomic_add_release_global(atomic_clean_flag, num_experts);
     }
 
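Note (illustrative, not part of the patch): the dispatch changes above route every remote write through the rank that shares this rank's local GPU index, and send each token's payload to a given node at most once; the first top-k slot targeting a node wins, and later slots only ship the small int4 meta record. A stand-alone host-side rendering of that mapping and de-duplication rule, with hypothetical helper names and NUM_MAX_NVL_PEERS assumed to be 8:

#include <cstdio>

constexpr int kNvlRanks = 8;  // assumption: NUM_MAX_NVL_PEERS

// All traffic to any rank on node (dst_rank / kNvlRanks) is funneled through the peer
// that shares this rank's local GPU index, keeping RDMA traffic on a single rail.
int same_rail_peer(int my_rank, int dst_rank) {
    return dst_rank / kNvlRanks * kNvlRanks + my_rank % kNvlRanks;
}

// Each warp owns top-k slot `slot`; it skips the payload send if an earlier slot of the
// same token already targets the same node (mirrors the __shfl_sync loop above).
bool should_send_payload(const int* topk_experts, int num_topk, int slot, int num_local_experts) {
    const int my_node = topk_experts[slot] / num_local_experts / kNvlRanks;
    for (int i = 0; i < slot; ++ i)
        if (topk_experts[i] >= 0 and topk_experts[i] / num_local_experts / kNvlRanks == my_node)
            return false;  // an earlier slot already ships this token to that node
    return true;
}

int main() {
    std::printf("rank 3 writes to rank 21's node via rank %d\n", same_rail_peer(3, 21));  // prints 19
    const int topk[4] = {5, 80, 150, 90};  // with 9 local experts these map to nodes 0, 1, 2, 1
    std::printf("slot 3 sends payload? %d\n", should_send_payload(topk, 4, 3, 9));        // prints 0: node 1 already covered by slot 1
    return 0;
}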
diff --git a/csrc/kernels/utils.cuh b/csrc/kernels/utils.cuh
index da6c34fa..ba642fb5 100644
--- a/csrc/kernels/utils.cuh
+++ b/csrc/kernels/utils.cuh
@@ -96,6 +96,21 @@ __device__ __forceinline__ uint64_t ld_acquire_sys_global(const uint64_t *ptr)
     return ret;
 }
 
+__device__ __forceinline__ int4 ld_acquire_sys_global(const int4 *ptr) {
+    int4 ret;
+    asm volatile("ld.acquire.sys.global.v4.s32 {%0, %1, %2, %3}, [%4];"
+                 : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w)
+                 : "l"(ptr));
+    return ret;
+}
+
+__device__ __forceinline__ void st_release_sys_global(const int4 *ptr, int4 val) {
+    asm volatile("st.release.sys.global.v4.s32 [%0], {%1, %2, %3, %4};"
+                 :
+                 : "l"(ptr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
+}
+
 __device__ __forceinline__ int ld_acquire_global(const int *ptr) {
     int ret;
     asm volatile("ld.acquire.gpu.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
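Note (illustrative, not part of the patch): the two utils.cuh additions provide 128-bit, system-scope acquire/release accesses so a whole int4 counter slot can be read or cleared with a single instruction. A minimal, hypothetical CUDA usage sketch, assuming "utils.cuh" is on the include path and that the helpers live in the deep_ep namespace; the producer/consumer kernels and the "-count - 1" encoding mirror the counter convention used by the dispatch path above:

#include "utils.cuh"          // assumption: provides ld_acquire_sys_global / st_release_sys_global
using namespace deep_ep;      // assumption: the helpers are declared inside this namespace

__global__ void producer(int4* slot, int num_tokens) {
    if (threadIdx.x == 0)
        st_release_sys_global(slot, make_int4(-num_tokens - 1, 0, 0, 0));  // publish "-count - 1" in .x
}

__global__ void consumer(const int4* slot, int* out) {
    if (threadIdx.x == 0) {
        int4 v;
        do { v = ld_acquire_sys_global(slot); } while (v.x == 0);  // acquire-spin until published
        *out = -v.x - 1;                                           // recover the token count
    }
}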
diff --git a/tests/test_low_latency.py b/tests/test_low_latency.py
index b076f341..925f33a5 100644
--- a/tests/test_low_latency.py
+++ b/tests/test_low_latency.py
@@ -61,6 +61,8 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int,
                                              cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
                                              async_finish=not return_recv_hook, return_recv_hook=return_recv_hook)
         hook() if return_recv_hook else event.current_stream_wait()
+        if not do_check:
+            continue
         packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x
         simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \
             if dispatch_use_fp8 else packed_recv_x.clone()
@@ -131,9 +133,9 @@ def test_func(return_recv_hook: bool):
                                                   cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats,
                                                   use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook)
         large_gemm_with_hook(hook) if return_recv_hook else None
-        combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
-                                                             use_logfmt=use_logfmt, return_recv_hook=return_recv_hook)
-        large_gemm_with_hook(hook) if return_recv_hook else None
+        # combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle,
+        #                                                      use_logfmt=use_logfmt, return_recv_hook=return_recv_hook)
+        # large_gemm_with_hook(hook) if return_recv_hook else None
 
     # Calculate bandwidth
     num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2
@@ -152,15 +154,15 @@
     # Separate profiling
     for return_recv_hook in (False, True):
         group.barrier()
-        dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
-                                             kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True,
+        dispatch_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook),
+                                  kernel_names=('dispatch'), barrier_comm_profiling=True,
                                              suppress_kineto_output=True,
                                              num_kernels_per_period=2 if return_recv_hook else 1)
         if not return_recv_hook:
             print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | '
-                  f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', flush=True)
+                  )
         else:
             print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | '
-                  f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', flush=True)
+                  )
 
     return hash_value
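Note on the receive path (illustrative, not part of the patch): after reading the source token index from the int4 meta slot, each copying warp spins on a per-token data-ready flag before pulling the single per-node payload copy, since the payload and the meta record arrive as separate writes. The flag array is laid out as [source node][token][local GPU]; a hypothetical index helper mirroring the data_ready_flag_src_ptr arithmetic, assuming NUM_MAX_NVL_PEERS == 8:

#include <cstddef>

constexpr int kNumMaxNvlPeers = 8;  // assumption: mirrors NUM_MAX_NVL_PEERS

// Index of the flag this rank polls for a token received from `src_rank`.
size_t data_ready_flag_index(int src_rank, int src_token_idx, int my_rank,
                             int num_max_dispatch_tokens_per_rank) {
    return (size_t)(src_rank / kNumMaxNvlPeers) * num_max_dispatch_tokens_per_rank * kNumMaxNvlPeers +
           (size_t)src_token_idx * kNumMaxNvlPeers +
           (size_t)(my_rank % kNumMaxNvlPeers);
}

int main() {
    // Rank 11 (local GPU 3) waiting for token 5 from rank 17 (node 2): 2*128*8 + 5*8 + 3 = 2091
    return (int)data_ready_flag_index(17, 5, 11, 128) == 2091 ? 0 : 1;
}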
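One more note (illustrative, not part of the patch): the count-send wait in the dispatch kernel changes from FINISHED_SUM_TAG * 2 to * 3 because the new last-SM pass that clears the data-ready flags also performs atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG) for every expert, adding one full tag on top of the two contributions the original code already waited for. A tiny compile-time check of that bookkeeping, with a placeholder tag value:

// Assumption: placeholder value; the real FINISHED_SUM_TAG is defined in the kernel sources.
constexpr long kFinishedSumTag = 1024;
constexpr long kOldTarget = 2 * kFinishedSumTag;           // what the original dispatch waited for
constexpr long kNewTarget = kOldTarget + kFinishedSumTag;  // plus the last-SM flag-cleaning notification
static_assert(kNewTarget == 3 * kFinishedSumTag, "count-send warps now wait for three tags");

int main() { return 0; }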