From 235c78d1d735a9a58c579c42b578b94ae309121d Mon Sep 17 00:00:00 2001 From: Ji Shi Date: Sun, 21 Sep 2025 01:01:19 -0700 Subject: [PATCH] Add HybridEP --- csrc/hybrid_ep.cu | 845 +++++++ csrc/hybrid_ep.cuh | 270 ++ csrc/kernels/hybrid_ep_backend.cuh | 2571 ++++++++++++++++++++ csrc/kernels/hybrid_ep_backend_configs.hpp | 80 + deep_ep/__init__.py | 2 + deep_ep/hybrid_ep_buffer.py | 247 ++ setup.py | 48 +- tests/test_mnnvlink_hybridep.py | 274 +++ tests/utils.py | 124 + 9 files changed, 4460 insertions(+), 1 deletion(-) create mode 100644 csrc/hybrid_ep.cu create mode 100644 csrc/hybrid_ep.cuh create mode 100644 csrc/kernels/hybrid_ep_backend.cuh create mode 100644 csrc/kernels/hybrid_ep_backend_configs.hpp create mode 100644 deep_ep/hybrid_ep_buffer.py create mode 100644 tests/test_mnnvlink_hybridep.py diff --git a/csrc/hybrid_ep.cu b/csrc/hybrid_ep.cu new file mode 100644 index 00000000..911f0027 --- /dev/null +++ b/csrc/hybrid_ep.cu @@ -0,0 +1,845 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved +#include "hybrid_ep.cuh" + +HybridEpBuffer::HybridEpBuffer(HybridEpConfigInstance config, int rank, int group_size, + int num_of_ranks_per_node) + : config(config), rank(rank), group_size(group_size), + num_of_ranks_per_node(num_of_ranks_per_node) { + this->local_rank = rank % num_of_ranks_per_node; + this->node_rank = rank / num_of_ranks_per_node; + + allocate_buffer(); +} + +HybridEpBuffer::~HybridEpBuffer() { + auto free_buffer = [](void *ptr) { + if (ptr != nullptr) { + CUDA_CHECK(cudaFree(ptr)); + } + }; + + // Clean up preprocessing buffer + free_buffer(this->preprocessing_tmp); + + // Clean up dispatch buffers + std::vector dispatch_ptrs = { + dispatch_buffers.rdma_inter_node_group_token, + dispatch_buffers.rdma_inter_node_group_prob, + dispatch_buffers.rdma_inter_node_group_scaling_factor, + dispatch_buffers.rdma_inter_node_group_flags, + dispatch_buffers.expected_rdma_flag_value, + dispatch_buffers.expected_intra_node_flag_value + }; + std::for_each(dispatch_ptrs.begin(), dispatch_ptrs.end(), free_buffer); + + device_mem_free(dispatch_buffers.expert_output_token, USE_MNNVLINK); + device_mem_free(dispatch_buffers.expert_output_prob, USE_MNNVLINK); + device_mem_free(dispatch_buffers.expert_output_scaling_factor, USE_MNNVLINK); + + if (this->local_rank == 0) { + device_mem_free(dispatch_buffers.intra_node_write_completion_flags, USE_MNNVLINK); + } else { + close_device_mem_handle(dispatch_buffers.intra_node_write_completion_flags, USE_MNNVLINK); + } + + for (int i = 0; i < config.num_of_ranks_per_node; i++) { + if (i != this->local_rank) { + close_device_mem_handle(dispatch_buffers.expert_output_token_all_ranks[i], USE_MNNVLINK); + close_device_mem_handle(dispatch_buffers.expert_output_prob_all_ranks[i], USE_MNNVLINK); + close_device_mem_handle(dispatch_buffers.expert_output_scaling_factor_all_ranks[i], USE_MNNVLINK); + } + } + + // Clean up dispatch pointer arrays + delete[] dispatch_buffers.expert_output_token_all_ranks; + delete[] dispatch_buffers.expert_output_prob_all_ranks; + delete[] dispatch_buffers.expert_output_scaling_factor_all_ranks; + + // Clean up combine buffers + std::vector combine_ptrs = { + combine_buffers.rdma_intra_node_red_token, + combine_buffers.rdma_intra_node_red_prob, + combine_buffers.rdma_inter_node_group_token, + combine_buffers.rdma_inter_node_group_prob, + combine_buffers.rdma_inter_node_group_flags, + combine_buffers.expected_rdma_flag_value, + combine_buffers.expected_intra_node_flag_value + }; + std::for_each(combine_ptrs.begin(), combine_ptrs.end(), free_buffer); + + device_mem_free(combine_buffers.expert_input_token, USE_MNNVLINK); + device_mem_free(combine_buffers.expert_input_prob, USE_MNNVLINK); + + if (this->local_rank == 0) { + device_mem_free(combine_buffers.intra_node_write_completion_flags, USE_MNNVLINK); + } else { + close_device_mem_handle(combine_buffers.intra_node_write_completion_flags, USE_MNNVLINK); + } + + for (int i = 0; i < config.num_of_ranks_per_node; i++) { + if (i != this->local_rank) { + close_device_mem_handle(combine_buffers.expert_input_token_all_ranks[i], USE_MNNVLINK); + close_device_mem_handle(combine_buffers.expert_input_prob_all_ranks[i], USE_MNNVLINK); + } + } + // Clean up combine pointer arrays + delete[] combine_buffers.expert_input_token_all_ranks; + delete[] combine_buffers.expert_input_prob_all_ranks; +} + +void HybridEpBuffer::allocate_buffer_for_preprocessing() { + auto preprocessing_tmp_elts = + config.num_of_blocks_preprocessing_api * config.num_of_ranks_per_node; + CUDA_CHECK( + cudaMalloc((void **)&this->preprocessing_tmp, + preprocessing_tmp_elts * sizeof(hybrid_ep::tmp_state_t))); +} + +void HybridEpBuffer::allocate_buffer_for_dispatch() { + dispatch_buffers.data_type = config.token_data_type; + size_t sizeof_token_data_type = get_token_data_type_size(dispatch_buffers.data_type); + + // Calculate buffer sizes + auto expert_output_token_elts = max_num_of_tokens_for_experts * config.hidden_dim; + auto expert_output_prob_elts = max_num_of_tokens_for_experts * + (config.num_of_experts_per_rank * config.num_of_ranks_per_node); + auto expert_output_scaling_factor_elts = max_num_of_tokens_for_experts * (config.hidden_dim / 128); + + auto rdma_inter_node_group_token_elts = config.max_num_of_tokens_per_rank * + (config.num_of_nodes - 1) * config.hidden_dim; + auto rdma_inter_node_group_prob_elts = config.max_num_of_tokens_per_rank * (config.num_of_nodes - 1) * + (config.num_of_experts_per_rank * config.num_of_ranks_per_node); + auto rdma_inter_node_group_scaling_factor_elts = config.max_num_of_tokens_per_rank * + (config.num_of_nodes - 1) * (config.hidden_dim / 128); + auto rdma_inter_node_group_flags_elts = (config.max_num_of_tokens_per_rank / + config.num_of_tokens_per_chunk_dispatch_api) * + (config.num_of_nodes - 1); + + // Allocate main buffers + device_mem_malloc((void**)&dispatch_buffers.expert_output_token, expert_output_token_elts * sizeof_token_data_type, USE_MNNVLINK); + device_mem_malloc((void**)&dispatch_buffers.expert_output_prob, expert_output_prob_elts * sizeof(float), USE_MNNVLINK); + device_mem_malloc((void**)&dispatch_buffers.expert_output_scaling_factor, expert_output_scaling_factor_elts * sizeof(float), USE_MNNVLINK); + + // Allocate RDMA buffers + CUDA_CHECK(cudaMalloc((void**)&dispatch_buffers.rdma_inter_node_group_token, + rdma_inter_node_group_token_elts * sizeof_token_data_type)); + CUDA_CHECK(cudaMalloc((void**)&dispatch_buffers.rdma_inter_node_group_prob, + rdma_inter_node_group_prob_elts * sizeof(float))); + CUDA_CHECK(cudaMalloc((void**)&dispatch_buffers.rdma_inter_node_group_scaling_factor, + rdma_inter_node_group_scaling_factor_elts * sizeof(float))); + CUDA_CHECK(cudaMalloc((void**)&dispatch_buffers.rdma_inter_node_group_flags, + rdma_inter_node_group_flags_elts * sizeof(uint64_t))); + + // Allocate and initialize synchronization buffers + if (this->local_rank == 0) { + device_mem_malloc((void**)&dispatch_buffers.intra_node_write_completion_flags, sizeof(uint32_t), USE_MNNVLINK); + } + + CUDA_CHECK(cudaMalloc((void**)&dispatch_buffers.expected_rdma_flag_value, sizeof(uint64_t))); + CUDA_CHECK(cudaMalloc((void**)&dispatch_buffers.expected_intra_node_flag_value, sizeof(uint32_t))); + CUDA_CHECK(cudaMemset(dispatch_buffers.expected_rdma_flag_value, 0, sizeof(uint64_t))); + CUDA_CHECK(cudaMemset(dispatch_buffers.expected_intra_node_flag_value, 0, sizeof(uint32_t))); + + // Create IPC memory handles + MemHandle handles[4]; + get_device_mem_handle(&handles[0], dispatch_buffers.expert_output_token, USE_MNNVLINK); + get_device_mem_handle(&handles[1], dispatch_buffers.expert_output_prob, USE_MNNVLINK); + get_device_mem_handle(&handles[2], dispatch_buffers.expert_output_scaling_factor, USE_MNNVLINK); + if (this->local_rank == 0) { + get_device_mem_handle(&handles[3], dispatch_buffers.intra_node_write_completion_flags, USE_MNNVLINK); + } + + // Pack handles into tensor + dispatch_memory_handles = torch::empty({static_cast(sizeof(handles))}, + torch::dtype(torch::kUInt8).device(torch::kCPU)); + memcpy(dispatch_memory_handles.data_ptr(), handles, sizeof(handles)); +} + +void HybridEpBuffer::allocate_buffer_for_combine() { + // Calculate buffer sizes + auto expert_input_token_elts = max_num_of_tokens_for_experts * config.hidden_dim; + auto expert_input_prob_elts = max_num_of_tokens_for_experts * + (config.num_of_experts_per_rank * config.num_of_ranks_per_node); + auto rdma_intra_node_red_token_elts = config.max_num_of_tokens_per_rank * + (config.num_of_nodes - 1) * config.hidden_dim; + auto rdma_intra_node_red_prob_elts = config.max_num_of_tokens_per_rank * (config.num_of_nodes - 1) * + (config.num_of_experts_per_rank * config.num_of_ranks_per_node); + auto rdma_inter_node_group_token_elts = config.max_num_of_tokens_per_rank * + (config.num_of_nodes - 1) * config.hidden_dim; + auto rdma_inter_node_group_prob_elts = config.max_num_of_tokens_per_rank * (config.num_of_nodes - 1) * + (config.num_of_experts_per_rank * config.num_of_ranks_per_node); + auto rdma_inter_node_group_flags_elts = (config.max_num_of_tokens_per_rank / + config.num_of_tokens_per_chunk_combine_api) * + (config.num_of_nodes - 1); + + // Allocate main buffers + device_mem_malloc((void**)&combine_buffers.expert_input_token, expert_input_token_elts * sizeof(uint16_t), USE_MNNVLINK); + device_mem_malloc((void**)&combine_buffers.expert_input_prob, expert_input_prob_elts * sizeof(float), USE_MNNVLINK); + + // Allocate RDMA buffers + CUDA_CHECK(cudaMalloc((void**)&combine_buffers.rdma_intra_node_red_token, + rdma_intra_node_red_token_elts * sizeof(uint16_t))); + CUDA_CHECK(cudaMalloc((void**)&combine_buffers.rdma_intra_node_red_prob, + rdma_intra_node_red_prob_elts * sizeof(float))); + CUDA_CHECK(cudaMalloc((void**)&combine_buffers.rdma_inter_node_group_token, + rdma_inter_node_group_token_elts * sizeof(uint16_t))); + CUDA_CHECK(cudaMalloc((void**)&combine_buffers.rdma_inter_node_group_prob, + rdma_inter_node_group_prob_elts * sizeof(float))); + CUDA_CHECK(cudaMalloc((void**)&combine_buffers.rdma_inter_node_group_flags, + rdma_inter_node_group_flags_elts * sizeof(uint64_t))); + + // Allocate and initialize synchronization buffers + if (this->local_rank == 0) { + device_mem_malloc((void**)&combine_buffers.intra_node_write_completion_flags, sizeof(uint32_t), USE_MNNVLINK); + } + + CUDA_CHECK(cudaMalloc((void**)&combine_buffers.expected_rdma_flag_value, sizeof(uint64_t))); + CUDA_CHECK(cudaMalloc((void**)&combine_buffers.expected_intra_node_flag_value, sizeof(uint32_t))); + CUDA_CHECK(cudaMemset(combine_buffers.expected_rdma_flag_value, 0, sizeof(uint64_t))); + CUDA_CHECK(cudaMemset(combine_buffers.expected_intra_node_flag_value, 0, sizeof(uint32_t))); + + // Create IPC memory handles + MemHandle handles[3]; + get_device_mem_handle(&handles[0], combine_buffers.expert_input_token, USE_MNNVLINK); + get_device_mem_handle(&handles[1], combine_buffers.expert_input_prob, USE_MNNVLINK); + if (this->local_rank == 0) { + get_device_mem_handle(&handles[2], combine_buffers.intra_node_write_completion_flags, USE_MNNVLINK); + } + + // Pack handles into tensor + combine_memory_handles = torch::empty({static_cast(sizeof(handles))}, + torch::dtype(torch::kUInt8).device(torch::kCPU)); + memcpy(combine_memory_handles.data_ptr(), handles, sizeof(handles)); +} + +void HybridEpBuffer::allocate_buffer() { + // Token number at the worst case, all tokens are routed to the same expert. + this->max_num_of_tokens_for_experts = config.max_num_of_tokens_per_rank * + config.num_of_ranks_per_node * + config.num_of_nodes; + assert(this->max_num_of_tokens_for_experts % 4 == + 0); // The number of tokens for experts should be divisible by 4, this + // is required by the permute make_row_id_map kernel + allocate_buffer_for_preprocessing(); + allocate_buffer_for_dispatch(); + allocate_buffer_for_combine(); +} + +void HybridEpBuffer::exchange_ipc_address(py::object process_group) { + try { + // Use Python's torch.distributed APIs through py::object + auto torch_distributed = py::module_::import("torch.distributed"); + + // Move tensors to CUDA for communication + auto dispatch_cuda = dispatch_memory_handles.cuda(); + auto combine_cuda = combine_memory_handles.cuda(); + + // Get world size from process group + int world_size = process_group.attr("size")().cast(); + + // Create empty tensors for allgather output + py::list dispatch_output_list; + py::list combine_output_list; + + for (int i = 0; i < world_size; i++) { + dispatch_output_list.append(torch::empty_like(dispatch_cuda)); + combine_output_list.append(torch::empty_like(combine_cuda)); + } + + // Perform allgather using Python API + torch_distributed.attr("all_gather")(dispatch_output_list, dispatch_cuda, process_group); + torch_distributed.attr("all_gather")(combine_output_list, combine_cuda, process_group); + + // Convert back to C++ vectors and move to CPU + std::vector dispatch_cpu_tensors; + std::vector combine_cpu_tensors; + + for (int i = 0; i < world_size; i++) { + dispatch_cpu_tensors.push_back(dispatch_output_list[i].cast().cpu()); + combine_cpu_tensors.push_back(combine_output_list[i].cast().cpu()); + } + + // Open handles from other ranks + open_handles_from_other_ranks(dispatch_cpu_tensors, combine_cpu_tensors); + + } catch (const std::exception& e) { + throw std::runtime_error( + "C++ distributed communication failed: " + std::string(e.what()) + ); + } +} + +void HybridEpBuffer::update_num_of_tokens_per_rank(int num_of_tokens_per_rank) { + config.num_of_tokens_per_rank = num_of_tokens_per_rank; +} + +void HybridEpBuffer::open_handles_from_other_ranks( + std::vector dispatch_handles, + std::vector combine_handles) { + + // Malloc the pointer arrays used in the dispatch kernel. + dispatch_buffers.expert_output_token_all_ranks = + (void **)malloc(config.num_of_ranks_per_node * sizeof(void *)); + dispatch_buffers.expert_output_prob_all_ranks = + (float **)malloc(config.num_of_ranks_per_node * sizeof(float *)); + dispatch_buffers.expert_output_scaling_factor_all_ranks = + (float **)malloc(config.num_of_ranks_per_node * sizeof(float *)); + + // Global offset means the position in the multi-node case. + auto global_offset = node_rank * num_of_ranks_per_node; + + // Open the dispatch handles for intra_node_write_completion_flags + if (local_rank != 0) { + MemHandle intra_node_write_completion_flags_handle; + // Only rank 0 will allocate memory for this flag + memcpy(&intra_node_write_completion_flags_handle, + dispatch_handles[global_offset].data_ptr() + + sizeof(MemHandle) * 3, + sizeof(MemHandle)); + open_device_mem_handle((void**)(&dispatch_buffers.intra_node_write_completion_flags), + &intra_node_write_completion_flags_handle, USE_MNNVLINK); + } + + // Open the handles for export_output + for (int i = 0; i < num_of_ranks_per_node; i++) { + MemHandle expert_output_token_handle, expert_output_prob_handle, + expert_output_scaling_factor_handle; + + // Extract the handles from the tensor. + auto base_ptr = dispatch_handles[i + global_offset].data_ptr(); + memcpy(&expert_output_token_handle, base_ptr, sizeof(MemHandle)); + memcpy(&expert_output_prob_handle, base_ptr + sizeof(MemHandle), + sizeof(MemHandle)); + memcpy(&expert_output_scaling_factor_handle, + base_ptr + sizeof(MemHandle) * 2, + sizeof(MemHandle)); + + // Open the handles for export_output + if (i != local_rank) { + open_device_mem_handle((void**)(&dispatch_buffers.expert_output_token_all_ranks[i]), + &expert_output_token_handle, USE_MNNVLINK); + open_device_mem_handle((void**)(&dispatch_buffers.expert_output_prob_all_ranks[i]), + &expert_output_prob_handle, USE_MNNVLINK); + open_device_mem_handle((void**)(&dispatch_buffers.expert_output_scaling_factor_all_ranks[i]), + &expert_output_scaling_factor_handle, USE_MNNVLINK); + } else { + // For local rank, use direct pointer assignment (more efficient, no IPC overhead) + dispatch_buffers.expert_output_token_all_ranks[i] = + dispatch_buffers.expert_output_token; + dispatch_buffers.expert_output_prob_all_ranks[i] = + dispatch_buffers.expert_output_prob; + dispatch_buffers.expert_output_scaling_factor_all_ranks[i] = + dispatch_buffers.expert_output_scaling_factor; + } + } + + // Malloc the pointer arrays used in the combine kernel. + combine_buffers.expert_input_token_all_ranks = + (uint16_t **)malloc(config.num_of_ranks_per_node * sizeof(uint16_t *)); + combine_buffers.expert_input_prob_all_ranks = + (float **)malloc(config.num_of_ranks_per_node * sizeof(float *)); + // Open the combine handles for intra_node_write_completion_flags + if (local_rank != 0) { + MemHandle intra_node_write_completion_flags_handle; + // Only rank 0 will allocate memory for this flag + memcpy(&intra_node_write_completion_flags_handle, + combine_handles[global_offset].data_ptr() + + sizeof(MemHandle) * 2, + sizeof(MemHandle)); + open_device_mem_handle((void**)(&combine_buffers.intra_node_write_completion_flags), + &intra_node_write_completion_flags_handle, USE_MNNVLINK); + } + // Open the handles for expert_input + for (int i = 0; i < num_of_ranks_per_node; i++) { + MemHandle expert_input_token_handle, expert_input_prob_handle; + auto base_ptr = combine_handles[i + global_offset].data_ptr(); + // Extract the handles from the tensor. + memcpy(&expert_input_token_handle, base_ptr, sizeof(MemHandle)); + memcpy(&expert_input_prob_handle, base_ptr + sizeof(MemHandle), + sizeof(MemHandle)); + // Open the handles for expert_input + if (i != local_rank) { + open_device_mem_handle((void**)(&combine_buffers.expert_input_token_all_ranks[i]), + &expert_input_token_handle, USE_MNNVLINK); + open_device_mem_handle((void**)(&combine_buffers.expert_input_prob_all_ranks[i]), + &expert_input_prob_handle, USE_MNNVLINK); + } else { + // For local rank, use direct pointer assignment (more efficient, no IPC overhead) + combine_buffers.expert_input_token_all_ranks[i] = + combine_buffers.expert_input_token; + combine_buffers.expert_input_prob_all_ranks[i] = + combine_buffers.expert_input_prob; + } + } +} + +std::tuple +HybridEpBuffer::metadata_preprocessing(torch::Tensor routing_map, int64_t node_rank, + int64_t local_rank) { + assert(routing_map.device().is_cuda()); + assert(routing_map.is_contiguous()); + + // padding for the routing map + const int rdma_to_attn_map_size_per_node = (((config.num_of_tokens_per_rank - 1) / 16) + 1) * 16; + + // Construt the output tensor of the metadata preprocessing kernel. + auto sparse_to_dense_map = + torch::empty({config.num_of_tokens_per_rank * config.num_of_nodes, + config.num_of_ranks_per_node}, + torch::dtype(torch::kInt32).device(torch::kCUDA)); + auto rdma_to_attn_map = + torch::empty({rdma_to_attn_map_size_per_node, config.num_of_nodes}, + torch::dtype(torch::kBool).device(torch::kCUDA)); + auto attn_to_rdma_map = + torch::empty({config.num_of_tokens_per_rank, config.num_of_nodes - 1}, + torch::dtype(torch::kBool).device(torch::kCUDA)); + auto num_of_tokens_for_experts = + torch::zeros({1}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + auto local_expert_routing_map = torch::empty( + {config.num_of_tokens_per_rank * config.num_of_ranks_per_node * config.num_of_nodes, config.num_of_experts_per_rank}, + torch::dtype(torch::kBool).device(torch::kCUDA)); + + hybrid_ep::hybrid_ep::metadata_preprocessing( + routing_map.data_ptr(), this->preprocessing_tmp, + sparse_to_dense_map.data_ptr(), + rdma_to_attn_map.data_ptr(), attn_to_rdma_map.data_ptr(), + num_of_tokens_for_experts.data_ptr(), + local_expert_routing_map.data_ptr(), static_cast(node_rank), + static_cast(local_rank), static_cast(config.num_of_tokens_per_rank), at::cuda::getCurrentCUDAStream()); + + return std::make_tuple(sparse_to_dense_map, rdma_to_attn_map, + attn_to_rdma_map, num_of_tokens_for_experts, + local_expert_routing_map); +} + +std::tuple +HybridEpBuffer::dispatch(torch::Tensor hidden, c10::optional probs, + c10::optional scaling_factor, + torch::Tensor sparse_to_dense_map, + torch::Tensor rdma_to_attn_map, torch::Tensor attn_to_rdma_map, + int64_t num_of_tokens_for_experts, bool with_probs) { + + // Use exact token count if available, otherwise use maximum bound + auto token_count = (num_of_tokens_for_experts >= 0) ? num_of_tokens_for_experts : max_num_of_tokens_for_experts; + + auto stream = at::cuda::getCurrentCUDAStream(); + + // Create and return output tensors + size_t sizeof_token_data_type = get_token_data_type_size(dispatch_buffers.data_type); + auto create_output_tensors = [&](int64_t token_count) -> std::tuple { + torch::Tensor dispatched_tokens, dispatched_probs, dispatched_scaling_factor; + + // Create dispatched tokens tensor and copy data + dispatched_tokens = torch::empty({token_count, config.hidden_dim}, + torch::dtype(hidden.dtype()).device(torch::kCUDA)); + auto res_sz = token_count * config.hidden_dim * sizeof_token_data_type; + CUDA_CHECK(cudaMemcpyAsync(dispatched_tokens.data_ptr(), + dispatch_buffers.expert_output_token, + res_sz, cudaMemcpyDeviceToDevice, stream)); + + // Create and copy probs if needed + if (with_probs) { + dispatched_probs = torch::empty({token_count, + config.num_of_experts_per_rank * config.num_of_ranks_per_node}, + torch::dtype(torch::kFloat32).device(torch::kCUDA)); + auto probs_sz = token_count * config.num_of_experts_per_rank * + config.num_of_ranks_per_node * sizeof(float); + CUDA_CHECK(cudaMemcpyAsync(dispatched_probs.data_ptr(), + dispatch_buffers.expert_output_prob, + probs_sz, cudaMemcpyDeviceToDevice, stream)); + } + + // Create and copy scaling factor if using UINT8 + if (config.token_data_type == TOKEN_DATA_TYPE::UINT8) { + dispatched_scaling_factor = torch::empty({token_count, config.hidden_dim / 128}, + torch::dtype(torch::kFloat32).device(torch::kCUDA)); + auto scaling_factor_sz = token_count * config.hidden_dim / 128 * sizeof(float); + CUDA_CHECK(cudaMemcpyAsync(dispatched_scaling_factor.data_ptr(), + dispatch_buffers.expert_output_scaling_factor, + scaling_factor_sz, cudaMemcpyDeviceToDevice, stream)); + } + + return std::make_tuple(dispatched_tokens, dispatched_probs, dispatched_scaling_factor); + }; + + // Fast return if there are no tokens to dispatch + if (token_count == 0) { + return create_output_tensors(0); + } + + assert(hidden.device().is_cuda()); + assert(hidden.is_contiguous()); + + float *probs_fp32 = nullptr; + float *scaling_factor_fp32 = nullptr; + if (with_probs) { + assert(probs.has_value()); + assert(probs.value().device().is_cuda()); + assert(probs.value().is_contiguous()); + assert(probs.value().dtype() == torch::kFloat32); + auto probs_tensor = probs.value().view(torch::kFloat32); + probs_fp32 = probs_tensor.data_ptr(); + } + if (config.token_data_type == TOKEN_DATA_TYPE::UINT8) { + assert(scaling_factor.has_value()); + assert(scaling_factor.value().device().is_cuda()); + assert(scaling_factor.value().is_contiguous()); + auto scaling_factor_tensor = scaling_factor.value().view(torch::kFloat32); + scaling_factor_fp32 = scaling_factor_tensor.data_ptr(); + } + + // Template function to setup and launch kernel parameters for uint8 + auto launch_uint8_kernel = [&]() { + auto hidden_uint8 = hidden.view(torch::kUInt8); + assert(NUM_OF_RANKS_PER_NODE == config.num_of_ranks_per_node); + + hybrid_ep::dispatch_kernel_param_t param; + param.attn_input_token = hidden_uint8.data_ptr(); + param.attn_input_prob = probs_fp32; + param.attn_input_token_scaling_factor = scaling_factor_fp32; + + // Setup output pointers + for (int i = 0; i < config.num_of_ranks_per_node; i++) { + param.expert_output_token[i] = reinterpret_cast( + dispatch_buffers.expert_output_token_all_ranks[i]); + param.expert_output_prob[i] = dispatch_buffers.expert_output_prob_all_ranks[i]; + param.expert_output_scaling_factor[i] = + dispatch_buffers.expert_output_scaling_factor_all_ranks[i]; + } + + // Setup RDMA parameters + param.rdma_inter_node_group_token = reinterpret_cast( + dispatch_buffers.rdma_inter_node_group_token); + param.rdma_inter_node_group_prob = dispatch_buffers.rdma_inter_node_group_prob; + param.rdma_inter_node_group_scaling_factor = + dispatch_buffers.rdma_inter_node_group_scaling_factor; + param.rdma_inter_node_group_flags = dispatch_buffers.rdma_inter_node_group_flags; + param.intra_node_write_completion_flags = + dispatch_buffers.intra_node_write_completion_flags; + param.rdma_to_attn_map = rdma_to_attn_map.data_ptr(); + param.attn_to_rdma_map = attn_to_rdma_map.data_ptr(); + param.sparse_to_dense_map = sparse_to_dense_map.data_ptr(); + param.local_rank = local_rank; + param.node_rank = node_rank; + param.num_of_tokens_per_rank = config.num_of_tokens_per_rank; + param.expected_rdma_flag_value = dispatch_buffers.expected_rdma_flag_value; + param.expected_intra_node_flag_value = dispatch_buffers.expected_intra_node_flag_value; + + // Launch kernel + if (with_probs) { + hybrid_ep::hybrid_ep + ::dispatch(param, stream); + } else { + hybrid_ep::hybrid_ep + ::dispatch(param, stream); + } + }; + + // Template function to setup and launch kernel parameters for uint16 + auto launch_uint16_kernel = [&]() { + auto hidden_uint16 = hidden.view(torch::kUInt16); + assert(NUM_OF_RANKS_PER_NODE == config.num_of_ranks_per_node); + + hybrid_ep::dispatch_kernel_param_t param; + param.attn_input_token = hidden_uint16.data_ptr(); + param.attn_input_prob = probs_fp32; + param.attn_input_token_scaling_factor = scaling_factor_fp32; + + // Setup output pointers + for (int i = 0; i < config.num_of_ranks_per_node; i++) { + param.expert_output_token[i] = reinterpret_cast( + dispatch_buffers.expert_output_token_all_ranks[i]); + param.expert_output_prob[i] = dispatch_buffers.expert_output_prob_all_ranks[i]; + param.expert_output_scaling_factor[i] = + dispatch_buffers.expert_output_scaling_factor_all_ranks[i]; + } + + // Setup RDMA parameters + param.rdma_inter_node_group_token = reinterpret_cast( + dispatch_buffers.rdma_inter_node_group_token); + param.rdma_inter_node_group_prob = dispatch_buffers.rdma_inter_node_group_prob; + param.rdma_inter_node_group_scaling_factor = + dispatch_buffers.rdma_inter_node_group_scaling_factor; + param.rdma_inter_node_group_flags = dispatch_buffers.rdma_inter_node_group_flags; + param.intra_node_write_completion_flags = + dispatch_buffers.intra_node_write_completion_flags; + param.rdma_to_attn_map = rdma_to_attn_map.data_ptr(); + param.attn_to_rdma_map = attn_to_rdma_map.data_ptr(); + param.sparse_to_dense_map = sparse_to_dense_map.data_ptr(); + param.local_rank = local_rank; + param.node_rank = node_rank; + param.num_of_tokens_per_rank = config.num_of_tokens_per_rank; + param.expected_rdma_flag_value = dispatch_buffers.expected_rdma_flag_value; + param.expected_intra_node_flag_value = dispatch_buffers.expected_intra_node_flag_value; + + // Launch kernel + if (with_probs) { + hybrid_ep::hybrid_ep + ::dispatch(param, stream); + } else { + hybrid_ep::hybrid_ep + ::dispatch(param, stream); + } + }; + + // Dispatch based on token data type + bool kernel_launched = false; + switch (config.token_data_type) { + case TOKEN_DATA_TYPE::UINT8: + launch_uint8_kernel(); + kernel_launched = true; + break; + case TOKEN_DATA_TYPE::UINT16: + launch_uint16_kernel(); + kernel_launched = true; + break; + default: + throw std::runtime_error("Invalid token data type:" + + std::to_string(static_cast(config.token_data_type))); + } + + if (!kernel_launched) { + throw std::runtime_error("Failed to launch dispatch kernel for num_of_ranks_per_node: " + + std::to_string(config.num_of_ranks_per_node)); + } + + return create_output_tensors(token_count); +} + +std::tuple +HybridEpBuffer::combine(torch::Tensor hidden, c10::optional probs, + torch::Tensor sparse_to_dense_map, + torch::Tensor rdma_to_attn_map, torch::Tensor attn_to_rdma_map, + bool with_probs) { + + // The result tensor of the combine kernel + torch::Tensor combined_tokens, combined_probs; + combined_tokens = + torch::empty({config.num_of_tokens_per_rank, config.hidden_dim}, + torch::dtype(hidden.dtype()).device(torch::kCUDA)); + if (with_probs) { + combined_probs = + torch::empty({config.num_of_tokens_per_rank, + config.num_of_experts_per_rank * + config.num_of_ranks_per_node * config.num_of_nodes}, + torch::dtype(torch::kFloat32).device(torch::kCUDA)); + } + + // Fast return if there are no tokens after combine + if (config.num_of_tokens_per_rank == 0) { + return std::make_tuple(combined_tokens, combined_probs); + } + + assert(hidden.device().is_cuda()); + assert(hidden.dtype() != torch::kUInt8); + assert(hidden.is_contiguous()); + + float *probs_fp32 = nullptr; + auto stream = at::cuda::getCurrentCUDAStream(); + + if (with_probs) { + assert(probs.has_value()); + assert(probs.value().device().is_cuda()); + assert(probs.value().is_contiguous()); + assert(probs.value().dtype() == torch::kFloat32); + assert(probs.value().size(1) == + config.num_of_experts_per_rank * config.num_of_ranks_per_node); + auto probs_tensor = probs.value().view(torch::kFloat32); + probs_fp32 = probs_tensor.data_ptr(); + } + + // Copy the input tensor to the input buffer + auto input_sz = hidden.numel() * sizeof(uint16_t); + CUDA_CHECK( + cudaMemcpyAsync(combine_buffers.expert_input_token, + reinterpret_cast(hidden.data_ptr()), input_sz, + cudaMemcpyDeviceToDevice, stream)); + if (with_probs) { + auto probs_sz = probs.value().numel() * sizeof(float); + CUDA_CHECK(cudaMemcpyAsync(combine_buffers.expert_input_prob, + probs_fp32, probs_sz, + cudaMemcpyDeviceToDevice, stream)); + } + + bool kernel_launched = false; + + assert(NUM_OF_RANKS_PER_NODE == config.num_of_ranks_per_node); + hybrid_ep::combine_kernel_param_t param; + for (int i = 0; i < config.num_of_ranks_per_node; i++) { + param.expert_input_token[i] = + combine_buffers.expert_input_token_all_ranks[i]; + param.expert_input_prob[i] = + combine_buffers.expert_input_prob_all_ranks[i]; + } + param.attn_output_token = + reinterpret_cast(combined_tokens.data_ptr()); + param.attn_output_prob = + with_probs ? combined_probs.data_ptr() : nullptr; + param.rdma_intra_node_red_token = + combine_buffers.rdma_intra_node_red_token; + param.rdma_intra_node_red_prob = combine_buffers.rdma_intra_node_red_prob; + param.rdma_inter_node_group_token = + combine_buffers.rdma_inter_node_group_token; + param.rdma_inter_node_group_prob = + combine_buffers.rdma_inter_node_group_prob; + param.rdma_inter_node_group_flags = + combine_buffers.rdma_inter_node_group_flags; + param.intra_node_write_completion_flags = + combine_buffers.intra_node_write_completion_flags; + param.rdma_to_attn_map = rdma_to_attn_map.data_ptr(); + param.attn_to_rdma_map = attn_to_rdma_map.data_ptr(); + param.sparse_to_dense_map = sparse_to_dense_map.data_ptr(); + param.node_rank = this->node_rank; + param.num_of_tokens_per_rank = config.num_of_tokens_per_rank; + param.expected_rdma_flag_value = combine_buffers.expected_rdma_flag_value; + param.expected_intra_node_flag_value = + combine_buffers.expected_intra_node_flag_value; + + // param.dgqps = combine_buffers.dgqps; + // param.mr_info = combine_buffers.mr_info; + // Call the combine kernel directly using template instantiation + if (with_probs) { + hybrid_ep::hybrid_ep< + HIDDEN_DIM, + MAX_NUM_OF_TOKENS_PER_RANK, + NUM_OF_RANKS_PER_NODE, + NUM_OF_NODES, + NUM_OF_EXPERTS_PER_RANK + >::combine< + NUM_OF_STAGES_G2S_COMBINE_API, + NUM_OF_STAGES_S2G_COMBINE_API, + NUM_OF_TOKENS_PER_CHUNK_COMBINE_API, + NUM_OF_TOKENS_PER_GROUP_COMBINE_API, + NUM_OF_BLOCKS_COMBINE_API, + NUM_OF_ADDITIONAL_IN_FLIGHT_S2G_COMBINE_API, + true, + DEVICE_SIDE_SYNC_COMBINE_API + >(param, stream); + } else { + hybrid_ep::hybrid_ep< + HIDDEN_DIM, + MAX_NUM_OF_TOKENS_PER_RANK, + NUM_OF_RANKS_PER_NODE, + NUM_OF_NODES, + NUM_OF_EXPERTS_PER_RANK + >::combine< + NUM_OF_STAGES_G2S_COMBINE_API, + NUM_OF_STAGES_S2G_COMBINE_API, + NUM_OF_TOKENS_PER_CHUNK_COMBINE_API, + NUM_OF_TOKENS_PER_GROUP_COMBINE_API, + NUM_OF_BLOCKS_COMBINE_API, + NUM_OF_ADDITIONAL_IN_FLIGHT_S2G_COMBINE_API, + false, + DEVICE_SIDE_SYNC_COMBINE_API + >(param, stream); + } + kernel_launched = true; + + if (!kernel_launched) { + throw std::runtime_error( + "fail to launch the combine kernel, corresponding num_of_ranks_per_node:" + + std::to_string(config.num_of_ranks_per_node)); + } + + return std::make_tuple(combined_tokens, combined_probs); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "HybridEP, efficiently enable the expert-parallel communication in " + "the Hopper+ architectures"; + + pybind11::enum_(m, "TokenDataType") + .value("UINT16", TOKEN_DATA_TYPE::UINT16) + .value("UINT8", TOKEN_DATA_TYPE::UINT8) + .export_values() // So we can use hybrid_ep_cpp.TYPE instead of the + // hybrid_ep_cpp.TOKEN_DATA_TYPE.TYPE + .def("__str__", + [](const TOKEN_DATA_TYPE &type) { return type_to_string(type); }); + + pybind11::class_(m, "HybridEpConfigInstance") + .def(py::init<>()) + // Hybrid-ep Config + .def_readwrite("hidden_dim", &HybridEpConfigInstance::hidden_dim) + .def_readwrite("num_of_tokens_per_rank", + &HybridEpConfigInstance::num_of_tokens_per_rank) + .def_readwrite("max_num_of_tokens_per_rank", + &HybridEpConfigInstance::max_num_of_tokens_per_rank) + .def_readwrite("num_of_experts_per_rank", + &HybridEpConfigInstance::num_of_experts_per_rank) + .def_readwrite("num_of_ranks_per_node", + &HybridEpConfigInstance::num_of_ranks_per_node) + .def_readwrite("num_of_nodes", &HybridEpConfigInstance::num_of_nodes) + // Metadata-preprocessing API Config + .def_readwrite( + "num_of_threads_per_block_preprocessing_api", + &HybridEpConfigInstance::num_of_threads_per_block_preprocessing_api) + .def_readwrite("num_of_blocks_preprocessing_api", + &HybridEpConfigInstance::num_of_blocks_preprocessing_api) + // Dispatch API Config + .def_readwrite("token_data_type", &HybridEpConfigInstance::token_data_type) + .def_readwrite("num_of_stages_dispatch_api", + &HybridEpConfigInstance::num_of_stages_dispatch_api) + .def_readwrite("num_of_tokens_per_chunk_dispatch_api", + &HybridEpConfigInstance::num_of_tokens_per_chunk_dispatch_api) + .def_readwrite("num_of_blocks_dispatch_api", + &HybridEpConfigInstance::num_of_blocks_dispatch_api) + .def_readwrite("forward_dispatch_api", + &HybridEpConfigInstance::forward_dispatch_api) + .def_readwrite("device_side_sync_dispatch_api", + &HybridEpConfigInstance::device_side_sync_dispatch_api) + // Combine API Config + .def_readwrite("num_of_stages_g2s_combine_api", + &HybridEpConfigInstance::num_of_stages_g2s_combine_api) + .def_readwrite("num_of_stages_s2g_combine_api", + &HybridEpConfigInstance::num_of_stages_s2g_combine_api) + .def_readwrite("num_of_tokens_per_chunk_combine_api", + &HybridEpConfigInstance::num_of_tokens_per_chunk_combine_api) + .def_readwrite("num_of_tokens_per_group_combine_api", + &HybridEpConfigInstance::num_of_tokens_per_group_combine_api) + .def_readwrite("num_of_blocks_combine_api", + &HybridEpConfigInstance::num_of_blocks_combine_api) + .def_readwrite( + "num_of_additional_in_flight_s2g_combine_api", + &HybridEpConfigInstance::num_of_additional_in_flight_s2g_combine_api) + .def_readwrite("backward_combine_api", + &HybridEpConfigInstance::backward_combine_api) + .def_readwrite("device_side_sync_combine_api", + &HybridEpConfigInstance::device_side_sync_combine_api) + .def("__repr__", [](const HybridEpConfigInstance &config) { + return ""; + }); + + pybind11::class_(m, "HybridEpBuffer") + .def(py::init()) + .def("exchange_ipc_address", &HybridEpBuffer::exchange_ipc_address) + .def("update_num_of_tokens_per_rank", &HybridEpBuffer::update_num_of_tokens_per_rank, + py::arg("num_of_tokens_per_rank")) + .def("metadata_preprocessing", &HybridEpBuffer::metadata_preprocessing, + py::kw_only(), py::arg("routing_map"), py::arg("node_rank"), + py::arg("local_rank")) + .def("dispatch", &HybridEpBuffer::dispatch, py::kw_only(), py::arg("hidden"), + py::arg("probs") = c10::nullopt, + py::arg("scaling_factor") = c10::nullopt, + py::arg("sparse_to_dense_map"), py::arg("rdma_to_attn_map"), + py::arg("attn_to_rdma_map"), + py::arg("num_of_tokens_for_experts") = -1, py::arg("with_probs")) + .def("combine", &HybridEpBuffer::combine, py::kw_only(), py::arg("hidden"), + py::arg("probs") = c10::nullopt, py::arg("sparse_to_dense_map"), + py::arg("rdma_to_attn_map"), py::arg("attn_to_rdma_map"), + py::arg("with_probs")); +} diff --git a/csrc/hybrid_ep.cuh b/csrc/hybrid_ep.cuh new file mode 100644 index 00000000..ac87fdf2 --- /dev/null +++ b/csrc/hybrid_ep.cuh @@ -0,0 +1,270 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved +#pragma once +#include "kernels/hybrid_ep_backend_configs.hpp" +#include "kernels/hybrid_ep_backend.cuh" +#include +#include +#include +#include +#include +#include +#include +#include + +inline std::string type_to_string(TOKEN_DATA_TYPE token_data_type) { + switch (token_data_type) { + case TOKEN_DATA_TYPE::UINT16: + return "uint16_t"; + case TOKEN_DATA_TYPE::UINT8: + return "uint8_t"; + default: + return "unknown"; + } +} + +union MemHandleInner{ + cudaIpcMemHandle_t cuda_ipc_mem_handle; + CUmemFabricHandle cu_mem_fabric_handle; +}; + +struct MemHandle{ + MemHandleInner inner; + size_t size; +}; + +// Utility function to get token data type size +inline size_t get_token_data_type_size(TOKEN_DATA_TYPE data_type) { + switch (data_type) { + case TOKEN_DATA_TYPE::UINT8: + return sizeof(uint8_t); + case TOKEN_DATA_TYPE::UINT16: + return sizeof(uint16_t); + default: + throw std::runtime_error("Invalid token data type:" + std::to_string(static_cast(data_type))); + } +} + +// Round-up allocation size to fabric granularity. +inline size_t get_size_align_to_granularity(size_t size_raw, size_t granularity){ + size_t size = (size_raw + granularity - 1) & ~(granularity - 1); + if(size == 0) size = granularity; + return size; +} + +// Device memory allocator, allocate local device memory. Support both normal cudaMalloc and fabric allocator. +inline void device_mem_malloc(void** ptr, size_t size_raw, bool enable_fabric){ + if(enable_fabric){ + CUdevice device; + CU_CHECK(cuCtxGetDevice(&device)); + + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC; + prop.location.id = device; + + size_t granularity = 0; + CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + + size_t size = get_size_align_to_granularity(size_raw, granularity); + + CUmemGenericAllocationHandle handle; + CU_CHECK(cuMemCreate(&handle, size, &prop, 0)); + + CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, granularity, 0, 0)); + CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0)); + CUmemAccessDesc access_desc = {}; + access_desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + access_desc.location.id = device; + access_desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + CU_CHECK(cuMemSetAccess((CUdeviceptr)*ptr, size, &access_desc, 1)); + }else{ + CUDA_CHECK(cudaMalloc(ptr, size_raw)); + } +} + +// Get sharable memory handle of local device memory for remote ranks to access. Support both IPC handle and fabric handle. +inline void get_device_mem_handle(MemHandle* mem_handle, void* ptr, bool enable_fabric){ + size_t size = 0; + CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr)); + + mem_handle->size = size; + + if(enable_fabric){ + CUmemGenericAllocationHandle handle; + CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr)); + CU_CHECK(cuMemExportToShareableHandle(&mem_handle->inner.cu_mem_fabric_handle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0)); + }else{ + CUDA_CHECK(cudaIpcGetMemHandle(&mem_handle->inner.cuda_ipc_mem_handle, ptr)); + } +} + +// Open sharable memory handle from other remote ranks and map it for local device to access. Support both IPC handle and fabric handle. +inline void open_device_mem_handle(void** ptr, MemHandle* mem_handle, bool enable_fabric){ + if(enable_fabric){ + CUdevice device; + CU_CHECK(cuCtxGetDevice(&device)); + size_t size = mem_handle->size; + + CUmemGenericAllocationHandle handle; + CU_CHECK(cuMemImportFromShareableHandle(&handle, &mem_handle->inner.cu_mem_fabric_handle, CU_MEM_HANDLE_TYPE_FABRIC)); + + CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, 0, 0, 0)); + CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0)); + CUmemAccessDesc access_desc = {}; + access_desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + access_desc.location.id = device; + access_desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + CU_CHECK(cuMemSetAccess((CUdeviceptr)*ptr, size, &access_desc, 1)); + }else{ + CUDA_CHECK(cudaIpcOpenMemHandle(ptr, mem_handle->inner.cuda_ipc_mem_handle, cudaIpcMemLazyEnablePeerAccess)); + } +} + +// Close and unmap sharable memory handle from other remote ranks. Support both IPC handle and fabric handle. +inline void close_device_mem_handle(void* ptr, bool enable_fabric){ + if(enable_fabric){ + size_t size = 0; + CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr)); + + CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size)); + CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size)); + }else{ + CUDA_CHECK(cudaIpcCloseMemHandle(ptr)); + } +} + +// Free local device memory allocated by device_mem_malloc. +inline void device_mem_free(void* ptr, bool enable_fabric){ + if(enable_fabric){ + CUmemGenericAllocationHandle handle; + CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr)); + + size_t size = 0; + CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr)); + + CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size)); + CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size)); + CU_CHECK(cuMemRelease(handle)); + }else{ + CUDA_CHECK(cudaFree(ptr)); + } +} + +class HybridEpBuffer { +public: + HybridEpBuffer(HybridEpConfigInstance config, int local_rank, int node_rank, + int num_of_ranks_per_node); + ~HybridEpBuffer(); + + // Exchange IPC addresses using C++ distributed communication + void exchange_ipc_address(pybind11::object process_group); + + void update_num_of_tokens_per_rank(int num_of_tokens_per_rank); + + std::tuple + metadata_preprocessing(torch::Tensor routing_map, int64_t node_rank, + int64_t local_rank); + + std::tuple + dispatch(torch::Tensor hidden, c10::optional probs, + c10::optional scaling_factor, + torch::Tensor sparse_to_dense_map, torch::Tensor rdma_to_attn_map, + torch::Tensor attn_to_rdma_map, int64_t num_of_tokens_for_experts, + bool with_probs); + + std::tuple + combine(torch::Tensor hidden, c10::optional probs, + torch::Tensor sparse_to_dense_map, torch::Tensor rdma_to_attn_map, + torch::Tensor attn_to_rdma_map, bool with_probs); + +private: + void allocate_buffer(); + void allocate_buffer_for_preprocessing(); + void allocate_buffer_for_dispatch(); + void allocate_buffer_for_combine(); + void open_handles_from_other_ranks(std::vector dispatch_handles, + std::vector combine_handles); + + HybridEpConfigInstance config; + int rank; + int group_size; + int local_rank; + int node_rank; + int num_of_ranks_per_node; + + int64_t max_num_of_tokens_for_experts; + + hybrid_ep::tmp_state_t *preprocessing_tmp; + + struct DispatchBuffers { + TOKEN_DATA_TYPE data_type; + + void *expert_output_token; + + void **expert_output_token_all_ranks; + + float *expert_output_prob; + + float **expert_output_prob_all_ranks; + + float *expert_output_scaling_factor; + + float **expert_output_scaling_factor_all_ranks; + + void *rdma_inter_node_group_token; + + float *rdma_inter_node_group_prob; + + float *rdma_inter_node_group_scaling_factor; + + uint64_t *rdma_inter_node_group_flags; + + uint32_t *intra_node_write_completion_flags; + + uint64_t *expected_rdma_flag_value; + + uint32_t *expected_intra_node_flag_value; + + } dispatch_buffers; + + torch::Tensor + dispatch_memory_handles; + + struct CombineBuffers { + + uint16_t *expert_input_token; + + uint16_t **expert_input_token_all_ranks; + + float *expert_input_prob; + + float **expert_input_prob_all_ranks; + + uint16_t *rdma_intra_node_red_token; + + float *rdma_intra_node_red_prob; + + uint16_t *rdma_inter_node_group_token; + + float + *rdma_inter_node_group_prob; + + uint64_t + *rdma_inter_node_group_flags; + + uint32_t *intra_node_write_completion_flags; + + uint64_t *expected_rdma_flag_value; + + uint32_t *expected_intra_node_flag_value; + + + } combine_buffers; + + torch::Tensor + combine_memory_handles; + +}; \ No newline at end of file diff --git a/csrc/kernels/hybrid_ep_backend.cuh b/csrc/kernels/hybrid_ep_backend.cuh new file mode 100644 index 00000000..fef99f94 --- /dev/null +++ b/csrc/kernels/hybrid_ep_backend.cuh @@ -0,0 +1,2571 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#define CUDA_CHECK(call) \ + do { \ + cudaError_t const status = call; \ + if (status != cudaSuccess) { \ + cudaGetLastError(); \ + fprintf(stderr, \ + "CUDA error encountered at: " \ + "file=%s, line=%d, " \ + "call='%s', Reason=%s:%s", \ + __FILE__, __LINE__, #call, cudaGetErrorName(status), \ + cudaGetErrorString(status)); \ + abort(); \ + } \ + } while (0) + +#define CU_CHECK(call) \ + do { \ + auto result = call; \ + if (result != CUDA_SUCCESS) { \ + const char *p_err_str = nullptr; \ + if (cuGetErrorString(result, &p_err_str) == CUDA_ERROR_INVALID_VALUE) { \ + p_err_str = "Unrecoginzed CU error num"; \ + } \ + fprintf(stderr, "CU error encountered at: " \ + "file=%s line=%d, call='%s' Reason=%s.\n", \ + __FILE__, __LINE__, \ + #call, p_err_str); \ + abort(); \ + } \ + } while (0) + + +namespace hybrid_ep{ + +/*enum DATA_TYPE{ + HYBRID_EP_DATA_TYPE_FP32, + HYBRID_EP_DATA_TYPE_FP16, + HYBRID_EP_DATA_TYPE_BF16, + HYBRID_EP_DATA_TYPE_FP8 +};*/ + +/*template +struct bool_any_reduction_type{}; + +template<> struct bool_any_reduction_type<8> { using Type = uint64_t; }; +template<> struct bool_any_reduction_type<4> { using Type = uint32_t; }; +template<> struct bool_any_reduction_type<2> { using Type = uint16_t; }; +template<> struct bool_any_reduction_type<1> { using Type = uint8_t; };*/ + +template +using Reduce_t = + typename std::conditional::type + >::type + >::type; + +template +using Copy_t = + typename std::conditional::type + >::type + >::type + >::type; + +enum scan_state{ + EMPTY = 0, + PRIV_SUM = 1 +}; + +struct tmp_state_t{ + scan_state state; + int32_t value; +}; + +// Generic warp group for warp-specializaion. +template +struct warp_group{ + __host__ __device__ static constexpr int size(){ return 32 * NUM_WARPS; } + __host__ __device__ static constexpr int warp_size(){ return NUM_WARPS; } + + __host__ __device__ static int thread_rank(){ return threadIdx.x - (32 * STARTING_WARPS); } + __host__ __device__ static int warp_rank(){ return thread_rank() / 32; } +}; + +template +struct dispatch_kernel_dynamic_shared_memory_buffer_t{}; + +template +struct dispatch_kernel_dynamic_shared_memory_buffer_t{ + // Shared memory token buffer. Should be 128B alignment for optimal perf for TMA. + alignas(128) uint8_t intra_node_token_buffer[NUM_OF_STAGES][HIDDEN_DIM]; + // Shared memory Prob buffer. Only used in FW dispatch. Should be 16B alignment so can be used with TMA. 128B is too strict. + alignas(16) float intra_node_prob_buffer[NUM_OF_STAGES][NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE]; + // Shared memory scaling factor buffer. Only when using FP8 token. Should be 16B alignment so can be used with TMA. 128B is too strict. + alignas(16) float intra_node_scaling_factor_buffer[NUM_OF_STAGES][HIDDEN_DIM / 128]; + // Shared memory mbarrier that protect token entry, 1st for producer->consumer, 2nd for consumer->producer. Should be 8B alignment(natural alignment). + alignas(8) uint64_t intra_node_mbarrier_buffer[NUM_OF_STAGES][2]; +}; + +template +struct dispatch_kernel_dynamic_shared_memory_buffer_t{ + // Shared memory token buffer. Should be 128B alignment for optimal perf for TMA. + alignas(128) uint16_t intra_node_token_buffer[NUM_OF_STAGES][HIDDEN_DIM]; + // Shared memory Prob buffer. Only used in FW dispatch. Should be 16B alignment so can be used with TMA. 128B is too strict. + alignas(16) float intra_node_prob_buffer[NUM_OF_STAGES][NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE]; + // Shared memory mbarrier that protect token entry, 1st for producer->consumer, 2nd for consumer->producer. Should be 8B alignment(natural alignment). + alignas(8) uint64_t intra_node_mbarrier_buffer[NUM_OF_STAGES][2]; +}; + +template +struct dispatch_kernel_dynamic_shared_memory_buffer_t{ + // Shared memory token buffer. Should be 128B alignment for optimal perf for TMA. + alignas(128) uint8_t intra_node_token_buffer[NUM_OF_STAGES][HIDDEN_DIM]; + // Shared memory scaling factor buffer. Only when using FP8 token. Should be 16B alignment so can be used with TMA. 128B is too strict. + alignas(16) float intra_node_scaling_factor_buffer[NUM_OF_STAGES][HIDDEN_DIM / 128]; + // Shared memory mbarrier that protect token entry, 1st for producer->consumer, 2nd for consumer->producer. Should be 8B alignment(natural alignment). + alignas(8) uint64_t intra_node_mbarrier_buffer[NUM_OF_STAGES][2]; +}; + +template +struct dispatch_kernel_dynamic_shared_memory_buffer_t{ + // Shared memory token buffer. Should be 128B alignment for optimal perf for TMA. + alignas(128) uint16_t intra_node_token_buffer[NUM_OF_STAGES][HIDDEN_DIM]; + // Shared memory mbarrier that protect token entry, 1st for producer->consumer, 2nd for consumer->producer. Should be 8B alignment(natural alignment). + alignas(8) uint64_t intra_node_mbarrier_buffer[NUM_OF_STAGES][2]; +}; + +template +struct combine_kernel_dynamic_shared_memory_buffer_t{}; + +template +struct combine_kernel_dynamic_shared_memory_buffer_t{ + // Shared memory token buffer for inter node red warp group G2S data movement. Should be 128B alignment for optimal perf for TMA. + alignas(128) uint16_t inter_node_token_G2S_buffer[NUM_OF_STAGES_G2S][HIDDEN_DIM]; + // Shared memory token buffer for inter node red warp group S2G data movement. Should be 128B alignment for optimal perf for TMA. + alignas(128) uint16_t inter_node_token_S2G_buffer[NUM_OF_STAGES_S2G][HIDDEN_DIM]; + + // Shared memory prob buffer for inter node red warp group G2S data movement. Should be 16B alignment so can be used with TMA. 128B is too strict. + // Only used in BW combine. + alignas(16) float inter_node_prob_G2S_buffer[NUM_OF_STAGES_G2S][NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE]; + // Shared memory prob buffer for inter node red warp group S2G data movement. Should be 16B alignment so can be used with TMA. 128B is too strict. + // Only used in BW combine. + alignas(16) float inter_node_prob_S2G_buffer[NUM_OF_STAGES_S2G][NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE]; + + // Shared memory mbarrier that protect inter node red warp group G2S token entry. 1st for producer->consumer, 2nd for consumer->producer. Should be 8B alignment(natural alignment). + alignas(8) uint64_t inter_node_mbarrier_G2S_buffer[NUM_OF_STAGES_G2S][2]; + + // Endgroup flag for each token entry in G2S buffer. true means that this token is the last token of a intra-node reduction group, otherwise not. + bool inter_node_flag_G2S_buffer[NUM_OF_STAGES_G2S]; +}; + +template +struct combine_kernel_dynamic_shared_memory_buffer_t{ + // Shared memory token buffer for inter node red warp group G2S data movement. Should be 128B alignment for optimal perf for TMA. + alignas(128) uint16_t inter_node_token_G2S_buffer[NUM_OF_STAGES_G2S][HIDDEN_DIM]; + // Shared memory token buffer for inter node red warp group S2G data movement. Should be 128B alignment for optimal perf for TMA. + alignas(128) uint16_t inter_node_token_S2G_buffer[NUM_OF_STAGES_S2G][HIDDEN_DIM]; + + // Shared memory mbarrier that protect inter node red warp group G2S token entry. 1st for producer->consumer, 2nd for consumer->producer. Should be 8B alignment(natural alignment). + alignas(8) uint64_t inter_node_mbarrier_G2S_buffer[NUM_OF_STAGES_G2S][2]; + + // Endgroup flag for each token entry in G2S buffer. true means that this token is the last token of a intra-node reduction group, otherwise not. + bool inter_node_flag_G2S_buffer[NUM_OF_STAGES_G2S]; +}; + + +// Data structure for kernel parameter for dispatch kernel. +template +struct dispatch_kernel_param_t{ + // Input buffers. These buffers are local buffers. + const TOKEN_DATA_TYPE* attn_input_token; + const float* attn_input_prob; // Needed by expert layer, so only valid in forward dispatch. + const float* attn_input_token_scaling_factor; // If input token is FP8 dtype, we need scaling factor for tokens. + // Output buffers. These buffers are both local and remote buffers. + TOKEN_DATA_TYPE* expert_output_token[NUM_OF_RANKS_PER_NODE]; + float* expert_output_prob[NUM_OF_RANKS_PER_NODE]; // Only valid in forward dispatch. + float* expert_output_scaling_factor[NUM_OF_RANKS_PER_NODE]; // Only valid for FP8 token type. + // Internal temp buffers. These buffers are local buffers. + const TOKEN_DATA_TYPE* rdma_inter_node_group_token; + const float* rdma_inter_node_group_prob; // Only valid in forward dispatch. + const float* rdma_inter_node_group_scaling_factor; // Only valid for FP8 token type. + uint64_t* rdma_inter_node_group_flags; // For RDMA Atomic flags. + uint32_t* intra_node_write_completion_flags; // For intra-node S2G write completion notification. + // Metadata buffers. These buffers are local buffers. + const bool* rdma_to_attn_map; + const bool* attn_to_rdma_map; + const int32_t* sparse_to_dense_map; + uint64_t* expected_rdma_flag_value; + uint32_t* expected_intra_node_flag_value; + int local_rank; + int node_rank; + // The number of token output by attn layer on a rank/GPU. + int num_of_tokens_per_rank; +}; + +// Data structure for kernel parameter for combine kernel. +template +struct combine_kernel_param_t{ + // Input buffers. These buffers are both local and remote buffers. + uint16_t* expert_input_token[NUM_OF_RANKS_PER_NODE]; + float* expert_input_prob[NUM_OF_RANKS_PER_NODE]; + // Output buffers. These buffers are local buffers. + uint16_t* attn_output_token; + float* attn_output_prob; + // Internal temp buffers. These buffers are local buffers. + uint16_t* rdma_intra_node_red_token; + float* rdma_intra_node_red_prob; + const uint16_t* rdma_inter_node_group_token; + const float* rdma_inter_node_group_prob; + uint64_t* rdma_inter_node_group_flags; + uint32_t* intra_node_write_completion_flags; // For intra-node src ready notification. + // Metadata buffers. These buffers are local buffers. + const bool* rdma_to_attn_map; + const bool* attn_to_rdma_map; + const int32_t* sparse_to_dense_map; + uint64_t* expected_rdma_flag_value; + uint32_t* expected_intra_node_flag_value; + int node_rank; + // The number of token output by attn layer on a rank/GPU. + int num_of_tokens_per_rank; +}; + +__device__ __forceinline__ bool elect_sync(uint32_t membermask) { + uint32_t is_elected; + asm volatile("{\n\t" + " .reg .pred p;\n\t" + " elect.sync _|p, %1;\n\t" + " selp.u32 %0, 1, 0, p;\n\t" + "}\n\t" + : "=r"(is_elected) + : "r"(membermask)); + return is_elected != 0; +} + +// Each CUDA block has sixteen named barriers numbered 0..15. +// __syncthreads(); will use the 0 named barriers, so we want to avoid that. +// We want to use 1 for intra-node reduction warp group, 2 for inter-node reduction warp group, 3 for RDMA warp group. +inline __device__ void arrive_and_wait(uint32_t num_threads, uint32_t barrier_id = 0) { + asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); +} + +// Device function for intra-node G2S warp for dispatch kernel. There can be only 1 intra-node G2S warp per CUDA block! +template +inline __device__ void G2S_warp_group_device_function(const int node_rank, + const int num_of_tokens_per_rank, + const uint64_t* expected_flag_value, + const bool* rdma_to_attn_map, + const TOKEN_DATA_TYPE* attn_input_token, + const float* attn_input_prob, + const float* attn_input_token_scaling_factor, + const TOKEN_DATA_TYPE* rdma_inter_node_group_token, + const float* rdma_inter_node_group_prob, + const float* rdma_inter_node_group_scaling_factor, + const uint64_t* rdma_inter_node_group_flags, + SMEM_TYPE* smem_buffer_ptr) +{ + // Load rdma_to_attn_map using LDG.128. Each token will need 1 bool from this map. + using rdma_to_attn_map_load_t = uint4; + static_assert(sizeof(bool) == 1, "Bool is not 1 byte???"); + static_assert(NUM_OF_TOKENS_PER_CHUNK % sizeof(rdma_to_attn_map_load_t) == 0, "NUM_OF_TOKENS_PER_CHUNK must be multiple of rdma_to_attn_map_load_t."); + constexpr int NUM_OF_ROUTING_INFO_LOAD_ITER_PER_CHUNK = NUM_OF_TOKENS_PER_CHUNK / sizeof(rdma_to_attn_map_load_t); + constexpr int NUM_OF_TOKENS_PER_LOAD_ITER = sizeof(rdma_to_attn_map_load_t) / sizeof(bool); + + const int remainder_chunk_size = num_of_tokens_per_rank % NUM_OF_TOKENS_PER_CHUNK; + // How many chunks per rank. Including full chunks and the remainder chunk. + const int num_of_chunks_per_rank = ((num_of_tokens_per_rank - 1) / NUM_OF_TOKENS_PER_CHUNK) + 1; + // The rdma_to_attn_map need to be paded to multiple of rdma_to_attn_map_load_t per node. + // The largest size of rdma_to_attn_map_load_t allowed in all Hybrid-EP kernels are 16B(16 bools), so need to be paded to 16B per node. + // That means the size of rdma_to_attn_map should be rdma_to_attn_map_size_per_node * NUM_OF_NODES. + const int rdma_to_attn_map_size_per_node = (((num_of_tokens_per_rank - 1) / 16) + 1) * 16; + int stage = 0; + uint32_t consumer_parity = 1; + + // Only 1 thread within the G2S warp will be active, other threads will just exit. + if(elect_sync(~0)){ + // Loop through all data chunk. Data(chunk) parallel between multiple CUDA blocks. + for(int i = blockIdx.x; i < num_of_chunks_per_rank; i += NUM_OF_BLOCKS){ + // How many rdma_to_attn load iter for this chunk. + int num_of_routing_info_load_iter_for_current_chunk; + // How many token for this chunk. + int current_chunk_size; + if(remainder_chunk_size != 0 && i == num_of_chunks_per_rank - 1){ + num_of_routing_info_load_iter_for_current_chunk = ((remainder_chunk_size - 1) / sizeof(rdma_to_attn_map_load_t)) + 1; + current_chunk_size = remainder_chunk_size; + }else{ + num_of_routing_info_load_iter_for_current_chunk = NUM_OF_ROUTING_INFO_LOAD_ITER_PER_CHUNK; + current_chunk_size = NUM_OF_TOKENS_PER_CHUNK; + } + for(int j = 0; j < NUM_OF_NODES; j++){ + // The current node been processed. For each chunk id, node_id order is local_node, local_node - 1, local_node - 2, ......, local_node + 1 and will wrap around. + int node_id = node_rank >= j ? node_rank - j : node_rank + NUM_OF_NODES - j; + // The tile id within the rdma buffers for the current node id. Because rdma buffers only have NUM_OF_NODES - 1 tile. + int rdma_buffer_tile_id = node_id > node_rank ? node_id - 1 : node_id; + // Check if the chunk of this node is ready to be consumed. + // The chunks of local node is the attn input buffers, which are always ready to be consumed. + // The chunks of remote node is the rdma_inter_node_group buffers, which is produced by remote RDMA Write operation. Should poll the flag produced by remote RDMA Atomic FA before consumed. + if(node_id != node_rank){ + const uint64_t* flag_location = rdma_inter_node_group_flags + (rdma_buffer_tile_id * num_of_chunks_per_rank + i); + uint64_t rdma_flag = 0; + do{ + rdma_flag = 0; + // Need a strong system-scope load to observe external RDMA Atomic result. + asm volatile("ld.relaxed.sys.global.b64 %0, [%1];" + : "=l"(rdma_flag) + : "l"(__cvta_generic_to_global(flag_location)) + : "memory"); + }while(rdma_flag != *expected_flag_value); + } + // Load every token and its properties from Global to Shared. Only load tokens that is needed by this node. + const rdma_to_attn_map_load_t* rdma_to_attn_map_load_base_addr = reinterpret_cast(rdma_to_attn_map + + (node_id * rdma_to_attn_map_size_per_node + i * NUM_OF_TOKENS_PER_CHUNK)); + const TOKEN_DATA_TYPE* token_load_base_addr; + const float* prob_load_base_addr; + const float* scaling_factor_load_base_addr; + // For other node's attn token and properties, read from rdma_inter_node_group buffers. + // For this node's attn token and properties, read from attn input buffers. + if(node_id != node_rank){ + int chunk_first_token_id = rdma_buffer_tile_id * num_of_tokens_per_rank + i * NUM_OF_TOKENS_PER_CHUNK; + token_load_base_addr = rdma_inter_node_group_token + chunk_first_token_id * HIDDEN_DIM; + if constexpr(FORWARD_DISPATCH){ + prob_load_base_addr = rdma_inter_node_group_prob + chunk_first_token_id * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE); + } + if constexpr(std::is_same::value){ + scaling_factor_load_base_addr = rdma_inter_node_group_scaling_factor + chunk_first_token_id * (HIDDEN_DIM / 128); + } + }else{ + int chunk_first_token_id = i * NUM_OF_TOKENS_PER_CHUNK; + token_load_base_addr = attn_input_token + chunk_first_token_id * HIDDEN_DIM; + if constexpr(FORWARD_DISPATCH){ + prob_load_base_addr = attn_input_prob + chunk_first_token_id * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE * NUM_OF_NODES); + } + if constexpr(std::is_same::value){ + scaling_factor_load_base_addr = attn_input_token_scaling_factor + chunk_first_token_id * (HIDDEN_DIM / 128); + } + } + //#pragma unroll + for(int k = 0; k < num_of_routing_info_load_iter_for_current_chunk; k++){ + rdma_to_attn_map_load_t rdma_to_attn_map_data = rdma_to_attn_map_load_base_addr[k]; + #pragma unroll + for(int n = 0; n < NUM_OF_TOKENS_PER_LOAD_ITER; n++){ + int current_token_id = k * NUM_OF_TOKENS_PER_LOAD_ITER + n; + // If the current token is out-of-bound, then just end this load iter. + if(current_token_id >= current_chunk_size){ + break; + } + bool token_needed_by_this_node = *(reinterpret_cast(&rdma_to_attn_map_data) + n); + // If a token is needed by this node(i.e. any expert of this node), load the token and its properties to shared memory entry. + if(token_needed_by_this_node){ + // Wait until shared memory has free entry. + while(!cuda::ptx::mbarrier_try_wait_parity(&smem_buffer_ptr->intra_node_mbarrier_buffer[stage][1], consumer_parity)){} + // Issue TMA to load current token and its properties from global to shared memory. + uint32_t total_tx_size = 0; + // Load token. + cuda::ptx::cp_async_bulk(cuda::ptx::space_shared, + cuda::ptx::space_global, + reinterpret_cast(&smem_buffer_ptr->intra_node_token_buffer[stage][0]), + reinterpret_cast(token_load_base_addr + (current_token_id * HIDDEN_DIM)), + (uint32_t)(HIDDEN_DIM * sizeof(TOKEN_DATA_TYPE)), + &smem_buffer_ptr->intra_node_mbarrier_buffer[stage][0]); + + total_tx_size += (uint32_t)(HIDDEN_DIM * sizeof(TOKEN_DATA_TYPE)); + + // Optionally load prob(Only in FW dispatch). + if constexpr(FORWARD_DISPATCH){ + // rdma_inter_node_group prob buffers and attn prob buffers will have different prob vec length. + const float* prob_load_token_addr; + if(node_id != node_rank){ + prob_load_token_addr = prob_load_base_addr + (current_token_id * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE)); + }else{ + prob_load_token_addr = prob_load_base_addr + (current_token_id * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE * NUM_OF_NODES)) + + (node_rank * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE)); + } + cuda::ptx::cp_async_bulk(cuda::ptx::space_shared, + cuda::ptx::space_global, + reinterpret_cast(&smem_buffer_ptr->intra_node_prob_buffer[stage][0]), + reinterpret_cast(prob_load_token_addr), + (uint32_t)((NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE) * sizeof(float)), + &smem_buffer_ptr->intra_node_mbarrier_buffer[stage][0]); + + total_tx_size += (uint32_t)((NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE) * sizeof(float)); + } + + // Optionally load scaling factor(Only for FP8 token). + if constexpr(std::is_same::value){ + cuda::ptx::cp_async_bulk(cuda::ptx::space_shared, + cuda::ptx::space_global, + reinterpret_cast(&smem_buffer_ptr->intra_node_scaling_factor_buffer[stage][0]), + reinterpret_cast(scaling_factor_load_base_addr + (current_token_id * (HIDDEN_DIM / 128))), + (uint32_t)((HIDDEN_DIM / 128) * sizeof(float)), + &smem_buffer_ptr->intra_node_mbarrier_buffer[stage][0]); + + total_tx_size += (uint32_t)((HIDDEN_DIM / 128) * sizeof(float)); + } + + cuda::ptx::mbarrier_arrive_expect_tx(cuda::ptx::sem_release, + cuda::ptx::scope_cta, + cuda::ptx::space_shared, + &smem_buffer_ptr->intra_node_mbarrier_buffer[stage][0], + total_tx_size); + + stage += 1; + if(stage == NUM_OF_STAGES){ + stage = 0; + consumer_parity ^= 1; + } + } + } + } + } + } + } +} + +// Device function for intra-node S2G warp for dispatch kernel. There can be only 1 intra-node S2G warp per CUDA block! +template +inline __device__ void S2G_warp_group_device_function(const int local_rank, + const int node_rank, + const int num_of_tokens_per_rank, + const bool* rdma_to_attn_map, + const int32_t* sparse_to_dense_map, + TOKEN_DATA_TYPE* const* remote_expert_output_token, + float* const* remote_expert_output_prob, + float* const* remote_expert_output_scaling_factor, + SMEM_TYPE* smem_buffer_ptr) +{ + // Load rdma_to_attn_map using LDG.128. Each token will need 1 bool from this map. + using rdma_to_attn_map_load_t = uint4; + static_assert(sizeof(bool) == 1, "Bool is not 1 byte???"); + static_assert(NUM_OF_TOKENS_PER_CHUNK % sizeof(rdma_to_attn_map_load_t) == 0, "NUM_OF_TOKENS_PER_CHUNK must be multiple of rdma_to_attn_map_load_t."); + constexpr int NUM_OF_ROUTING_INFO_LOAD_ITER_PER_CHUNK = NUM_OF_TOKENS_PER_CHUNK / sizeof(rdma_to_attn_map_load_t); + constexpr int NUM_OF_TOKENS_PER_LOAD_ITER = sizeof(rdma_to_attn_map_load_t) / sizeof(bool); + + // Load sparse_to_dense_map according to the NUM_OF_RANKS_PER_NODE. + using sparse_to_dense_map_load_t = Copy_t; + constexpr int NUM_OF_SPARSE_TO_DENSE_MAP_LOAD_ITER_PER_INPUT_TOKEN = (NUM_OF_RANKS_PER_NODE * sizeof(int32_t)) / sizeof(sparse_to_dense_map_load_t); + constexpr int NUM_OF_OUTPUT_TOKENS_PER_LOAD_ITER = sizeof(sparse_to_dense_map_load_t) / sizeof(int32_t); + + const int remainder_chunk_size = num_of_tokens_per_rank % NUM_OF_TOKENS_PER_CHUNK; + // How many chunks per rank. Including full chunks and the remainder chunk. + const int num_of_chunks_per_rank = ((num_of_tokens_per_rank - 1) / NUM_OF_TOKENS_PER_CHUNK) + 1; + // The rdma_to_attn_map need to be paded to multiple of rdma_to_attn_map_load_t per node. + // The largest size of rdma_to_attn_map_load_t allowed in all Hybrid-EP kernels are 16B(16 bools), so need to be paded to 16B per node. + // That means the size of rdma_to_attn_map should be rdma_to_attn_map_size_per_node * NUM_OF_NODES. + const int rdma_to_attn_map_size_per_node = (((num_of_tokens_per_rank - 1) / 16) + 1) * 16; + int stage = 0; + uint32_t producer_parity = 0; + + // Only 1 thread within the S2G warp will be active, other threads will just exit. + if(elect_sync(~0)){ + // Loop through all data chunk. Data(chunk) parallel between multiple CUDA blocks. + for(int i = blockIdx.x; i < num_of_chunks_per_rank; i += NUM_OF_BLOCKS){ + // How many rdma_to_attn load iter for this chunk. + int num_of_routing_info_load_iter_for_current_chunk; + // How many token for this chunk. + int current_chunk_size; + if(remainder_chunk_size != 0 && i == num_of_chunks_per_rank - 1){ + num_of_routing_info_load_iter_for_current_chunk = ((remainder_chunk_size - 1) / sizeof(rdma_to_attn_map_load_t)) + 1; + current_chunk_size = remainder_chunk_size; + }else{ + num_of_routing_info_load_iter_for_current_chunk = NUM_OF_ROUTING_INFO_LOAD_ITER_PER_CHUNK; + current_chunk_size = NUM_OF_TOKENS_PER_CHUNK; + } + for(int j = 0; j < NUM_OF_NODES; j++){ + // The current node been processed. For each chunk id, node_id order is local_node, local_node - 1, local_node - 2, ......, local_node + 1 and will wrap around. + int node_id = node_rank >= j ? node_rank - j : node_rank + NUM_OF_NODES - j; + // Store every token and its properties from Shared to Global. Only store tokens that is needed by this node. + const rdma_to_attn_map_load_t* rdma_to_attn_map_load_base_addr = reinterpret_cast(rdma_to_attn_map + + (node_id * rdma_to_attn_map_size_per_node + i * NUM_OF_TOKENS_PER_CHUNK)); + + const int32_t* sparse_to_dense_map_load_base_addr = sparse_to_dense_map + (node_id * num_of_tokens_per_rank + i * NUM_OF_TOKENS_PER_CHUNK) * NUM_OF_RANKS_PER_NODE; + + //#pragma unroll + for(int k = 0; k < num_of_routing_info_load_iter_for_current_chunk; k++){ + rdma_to_attn_map_load_t rdma_to_attn_map_data = rdma_to_attn_map_load_base_addr[k]; + #pragma unroll + for(int n = 0; n < NUM_OF_TOKENS_PER_LOAD_ITER; n++){ + int current_token_id = k * NUM_OF_TOKENS_PER_LOAD_ITER + n; + // If the current token is out-of-bound, then just end this load iter. + if(current_token_id >= current_chunk_size){ + break; + } + bool token_needed_by_this_node = *(reinterpret_cast(&rdma_to_attn_map_data) + n); + if(token_needed_by_this_node){ + const sparse_to_dense_map_load_t* sparse_to_dense_map_load_addr = reinterpret_cast + (sparse_to_dense_map_load_base_addr + (k * NUM_OF_TOKENS_PER_LOAD_ITER + n) * NUM_OF_RANKS_PER_NODE); + // Wait until token entry within the shared memory has been produced. + while(!cuda::ptx::mbarrier_try_wait_parity(&smem_buffer_ptr->intra_node_mbarrier_buffer[stage][0], producer_parity)){} + + // This token entry will be multicast to all ranks within this node which need this token and its properties. + // The current implementation do the multicast by issue each unicast separately(we call it a unicast group). If NVLS can be used, we should use it here. + #pragma unroll + for(int m = 0; m < NUM_OF_SPARSE_TO_DENSE_MAP_LOAD_ITER_PER_INPUT_TOKEN; m++){ + // Load sparse_to_dense_map. + sparse_to_dense_map_load_t sparse_to_dense_map_data = sparse_to_dense_map_load_addr[m]; + #pragma unroll + for(int t = 0; t < NUM_OF_OUTPUT_TOKENS_PER_LOAD_ITER; t++){ + int32_t output_buffer_index = *(reinterpret_cast(&sparse_to_dense_map_data) + t); + // Only unicast to this rank if it need the current token. + if(output_buffer_index != -1){ + int remote_rank_id = m * NUM_OF_OUTPUT_TOKENS_PER_LOAD_ITER + t; + // Store the token from shared to remote global. + TOKEN_DATA_TYPE* remote_token_addr = remote_expert_output_token[remote_rank_id] + (output_buffer_index * HIDDEN_DIM); + cuda::ptx::cp_async_bulk(cuda::ptx::space_global, + cuda::ptx::space_shared, + reinterpret_cast(remote_token_addr), + reinterpret_cast(&smem_buffer_ptr->intra_node_token_buffer[stage][0]), + (uint32_t)(HIDDEN_DIM * sizeof(TOKEN_DATA_TYPE))); + + // Store the prob from shared to remote global for FW dispatch. + if constexpr(FORWARD_DISPATCH){ + float* remote_prob_addr = remote_expert_output_prob[remote_rank_id] + (output_buffer_index * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE)); + cuda::ptx::cp_async_bulk(cuda::ptx::space_global, + cuda::ptx::space_shared, + reinterpret_cast(remote_prob_addr), + reinterpret_cast(&smem_buffer_ptr->intra_node_prob_buffer[stage][0]), + (uint32_t)((NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE) * sizeof(float))); + + } + + // Store the scaling factor from shared to remote global for FP8 tokens. + if constexpr(std::is_same::value){ + float* remote_scaling_factor_addr = remote_expert_output_scaling_factor[remote_rank_id] + (output_buffer_index * (HIDDEN_DIM / 128)); + cuda::ptx::cp_async_bulk(cuda::ptx::space_global, + cuda::ptx::space_shared, + reinterpret_cast(remote_scaling_factor_addr), + reinterpret_cast(&smem_buffer_ptr->intra_node_scaling_factor_buffer[stage][0]), + (uint32_t)((HIDDEN_DIM / 128) * sizeof(float))); + + } + } + } + } + // Commit the previous issued S2G TMA instructions for the same shared memory token entry to a bulk async copy group. + cuda::ptx::cp_async_bulk_commit_group(); + // Wait for previous commited TMA instructions to finish reading the shared memory, so the shared memory can be reused by the producer. + cuda::ptx::cp_async_bulk_wait_group_read(cuda::ptx::n32_t<0>{}); + // Notify the producer warp to load next token entry to the shared memory as the shared memory can be reused. + cuda::ptx::mbarrier_arrive(&smem_buffer_ptr->intra_node_mbarrier_buffer[stage][1]); + // Goto next token entry in shared memory. + stage += 1; + if(stage == NUM_OF_STAGES){ + stage = 0; + producer_parity ^= 1; + } + } + } + } + } + } + // All S2G TMA operations for all tokens assigned to this CUDA block have been issued. + // If the synchronization for output buffer for current rank is on host-side(i.e. cudaStreamSynchronize + MPI_Barrier etc.), then all CUDA block can exit. + // The result of output buffer for current rank is not ready when the dipatch kernel is completed, a Barrier within the node is needed. + // Otherwise, the S2G warp of the first CUDA block must wait for all writes to the local output buffer complete before exit. So kernel completion means the output buffers for current rank is ready. + /*if constexpr(DEVICE_SIDE_SYNC){ + // Wait for all previous issued TMA instructions to complete writing to remote global memory. + cuda::ptx::cp_async_bulk_wait_group(cuda::ptx::n32_t<0>{}); + // Atomically add 1 to the remote flag on remote ranks within the node to notify the remote rank. + for(int i = 0; i < NUM_OF_RANKS_PER_NODE; i++){ + // red.release.sys.global.add.u32 [a], 1; + asm volatile("red.release.sys.global.add.u32 [%0], %1;" + : + : "l"(__cvta_generic_to_global(&remote_write_completion_flags[i][local_rank])) , "n"(1) + : "memory"); + } + if(blockIdx.x == 0){ + // Wait for all flags on local rank to reach the expected value before exit. + for(int i = 0; i < NUM_OF_RANKS_PER_NODE; i++){ + uint32_t flag_data = 0; + do{ + flag_data = 0; + // Need a strong system-scope load to observe peer ranks' Atomic result. + asm volatile("ld.relaxed.sys.global.u32 %0, [%1];" + : "=r"(flag_data) + : "l"(__cvta_generic_to_global(&remote_write_completion_flags[local_rank][i])) + : "memory"); + }while(flag_data != expected_flag_value); + } + } + }*/ + } +} + +// Device function for intra-node G2S warp for combine kernel. There can be only 1 such warp per CUDA block! +template +inline __device__ void intra_node_G2S_warp_group_device_function(const int node_rank, + const int num_of_tokens_per_rank, + const bool* rdma_to_attn_map, + const int32_t* sparse_to_dense_map, + uint16_t* const* remote_expert_input_token, + float* const* remote_expert_input_prob, + SMEM_TYPE* smem_buffer_ptr) +{ + // Load rdma_to_attn_map using LDG.128. Each dst token will need 1 bool from this map. + using rdma_to_attn_map_load_t = uint4; + static_assert(sizeof(bool) == 1, "Bool is not 1 byte???"); + static_assert(NUM_OF_TOKENS_PER_CHUNK % sizeof(rdma_to_attn_map_load_t) == 0, "NUM_OF_TOKENS_PER_CHUNK must be multiple of rdma_to_attn_map_load_t."); + constexpr int NUM_OF_RDMA_TO_ATTN_LOAD_ITER_PER_CHUNK = NUM_OF_TOKENS_PER_CHUNK / sizeof(rdma_to_attn_map_load_t); + constexpr int NUM_OF_TOKENS_PER_RDMA_TO_ATTN_LOAD_ITER = sizeof(rdma_to_attn_map_load_t) / sizeof(bool); + + // Load sparse_to_dense_map according to the NUM_OF_RANKS_PER_NODE. + using sparse_to_dense_map_load_t = Copy_t; + constexpr int NUM_OF_SPARSE_TO_DENSE_MAP_LOAD_ITER_PER_OUTPUT_TOKEN = (NUM_OF_RANKS_PER_NODE * sizeof(int32_t)) / sizeof(sparse_to_dense_map_load_t); + constexpr int NUM_OF_INPUT_TOKENS_PER_LOAD_ITER = sizeof(sparse_to_dense_map_load_t) / sizeof(int32_t); + + // The intra node reduction warp group of each CUDA block produce a chunk at a time. + // The chunk order is: first produce the same chunk id for all other nodes id, then produce following chunk id. + // (i.e. chunk 0 for node + 1, node + 2, ... node - 1, then chunk 1 for node + 1, node + 2, ... node - 1) + // The RDMA warp group of a CUDA block will consume the chunk by the same order. So each CUDA block will produce and consume the same set of chunks id. + // The reason to distribute chunk in this order is that the inter-node reduction will need the same chunk id from all other nodes, so we need to produce and send chunks in this order. + + const int remainder_chunk_size = num_of_tokens_per_rank % NUM_OF_TOKENS_PER_CHUNK; + // How many chunks per rank. Including full chunks and the remainder chunk. + const int num_of_chunks_per_rank = ((num_of_tokens_per_rank - 1) / NUM_OF_TOKENS_PER_CHUNK) + 1; + // Total number of chunks to produce for RDMA warps to consume. + const int total_num_of_chunks = (NUM_OF_NODES - 1) * num_of_chunks_per_rank; + // The rdma_to_attn_map need to be paded to multiple of rdma_to_attn_map_load_t per node. + // The largest size of rdma_to_attn_map_load_t allowed in all Hybrid-EP kernels are 16B(16 bools), so need to be paded to 16B per node. + // That means the size of rdma_to_attn_map should be rdma_to_attn_map_size_per_node * NUM_OF_NODES. + const int rdma_to_attn_map_size_per_node = (((num_of_tokens_per_rank - 1) / 16) + 1) * 16; + // Token stage id and phase. + int token_stage = 0; + uint32_t token_consumer_parity = 1; + + // Only 1 thread within the intra-node G2S warp will be active, other threads will just exit. + if(elect_sync(~0)){ + // Iterate through all chunks assigned to this block. + for(int i = blockIdx.x; i < total_num_of_chunks; i += NUM_OF_BLOCKS){ + // Which node this chunk will be sent to. + int node_id = (i % (NUM_OF_NODES - 1) + (node_rank + 1)) % NUM_OF_NODES; + // What is the chunk id of this chunk for the node it will be sent to. + int chunk_id = i / (NUM_OF_NODES - 1); + // How many rdma_to_attn load iter for this chunk. + int num_of_routing_info_load_iter_for_current_chunk; + // How many token for this chunk. + int current_chunk_size; + if(remainder_chunk_size != 0 && chunk_id == num_of_chunks_per_rank - 1){ + num_of_routing_info_load_iter_for_current_chunk = ((remainder_chunk_size - 1) / sizeof(rdma_to_attn_map_load_t)) + 1; + current_chunk_size = remainder_chunk_size; + }else{ + num_of_routing_info_load_iter_for_current_chunk = NUM_OF_RDMA_TO_ATTN_LOAD_ITER_PER_CHUNK; + current_chunk_size = NUM_OF_TOKENS_PER_CHUNK; + } + + const rdma_to_attn_map_load_t* rdma_to_attn_map_load_base_addr = reinterpret_cast(rdma_to_attn_map + + (node_id * rdma_to_attn_map_size_per_node + chunk_id * NUM_OF_TOKENS_PER_CHUNK)); + + const int32_t* sparse_to_dense_map_load_base_addr = sparse_to_dense_map + (node_id * num_of_tokens_per_rank + chunk_id * NUM_OF_TOKENS_PER_CHUNK) * NUM_OF_RANKS_PER_NODE; + + // Iterate through all dst tokens within this chunk. + for(int j = 0; j < num_of_routing_info_load_iter_for_current_chunk; j++){ + rdma_to_attn_map_load_t rdma_to_attn_map_data = rdma_to_attn_map_load_base_addr[j]; + #pragma unroll + for(int k = 0; k < NUM_OF_TOKENS_PER_RDMA_TO_ATTN_LOAD_ITER; k++){ + int current_token_id = j * NUM_OF_TOKENS_PER_RDMA_TO_ATTN_LOAD_ITER + k; + // If the current token is out-of-bound, then just end this load iter. + if(current_token_id >= current_chunk_size){ + break; + } + // Check whether this dst token is needed by this node. If not needed, just skip. + bool token_needed_by_this_node = *(reinterpret_cast(&rdma_to_attn_map_data) + k); + // If this dst token is needed by this node, load the sparse_to_dense map and load the src token for this dst token. + if(token_needed_by_this_node){ + const sparse_to_dense_map_load_t* sparse_to_dense_map_load_addr = reinterpret_cast + (sparse_to_dense_map_load_base_addr + (j * NUM_OF_TOKENS_PER_RDMA_TO_ATTN_LOAD_ITER + k) * NUM_OF_RANKS_PER_NODE); + // Load sparse_to_dense map for this dst token(i.e. a row in sparse_to_dense map). + sparse_to_dense_map_load_t sparse_to_dense_map_data[NUM_OF_SPARSE_TO_DENSE_MAP_LOAD_ITER_PER_OUTPUT_TOKEN]; + // First load sparse_to_dense map and decide the last src token within this row. + int last_src_token_id = 0; + #pragma unroll + for(int n = 0; n < NUM_OF_SPARSE_TO_DENSE_MAP_LOAD_ITER_PER_OUTPUT_TOKEN; n++){ + sparse_to_dense_map_data[n] = sparse_to_dense_map_load_addr[n]; + #pragma unroll + for(int m = 0; m < NUM_OF_INPUT_TOKENS_PER_LOAD_ITER; m++){ + int32_t sparse_to_dense_map_value = *(reinterpret_cast(&sparse_to_dense_map_data[n]) + m); + if(sparse_to_dense_map_value != -1){ + last_src_token_id = n * NUM_OF_INPUT_TOKENS_PER_LOAD_ITER + m; + } + } + } + + // Then issue all G2S TMA for this row. + #pragma unroll + for(int n = 0; n < NUM_OF_SPARSE_TO_DENSE_MAP_LOAD_ITER_PER_OUTPUT_TOKEN; n++){ + #pragma unroll + for(int m = 0; m < NUM_OF_INPUT_TOKENS_PER_LOAD_ITER; m++){ + int32_t sparse_to_dense_map_value = *(reinterpret_cast(&sparse_to_dense_map_data[n]) + m); + if(sparse_to_dense_map_value != -1){ + int current_src_token_id = n * NUM_OF_INPUT_TOKENS_PER_LOAD_ITER + m; + // Wait until current token entry within the shared memory has been consumed. + while(!cuda::ptx::mbarrier_try_wait_parity(&smem_buffer_ptr->intra_node_mbarrier_G2S_buffer[token_stage][1], token_consumer_parity)){} + + uint32_t total_tx_size = 0; + cuda::ptx::cp_async_bulk(cuda::ptx::space_shared, + cuda::ptx::space_global, + reinterpret_cast(&smem_buffer_ptr->intra_node_token_G2S_buffer[token_stage][0]), + reinterpret_cast(remote_expert_input_token[current_src_token_id] + (sparse_to_dense_map_value * HIDDEN_DIM)), + (uint32_t)(HIDDEN_DIM * sizeof(uint16_t)), + &smem_buffer_ptr->intra_node_mbarrier_G2S_buffer[token_stage][0]); + + total_tx_size += (uint32_t)(HIDDEN_DIM * sizeof(uint16_t)); + + if constexpr(BACKWARD_COMBINE){ + cuda::ptx::cp_async_bulk(cuda::ptx::space_shared, + cuda::ptx::space_global, + reinterpret_cast(&smem_buffer_ptr->intra_node_prob_G2S_buffer[token_stage][0]), + reinterpret_cast(remote_expert_input_prob[current_src_token_id] + (sparse_to_dense_map_value * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE))), + (uint32_t)((NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE) * sizeof(float)), + &smem_buffer_ptr->intra_node_mbarrier_G2S_buffer[token_stage][0]); + + total_tx_size += (uint32_t)((NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE) * sizeof(float)); + } + + if(current_src_token_id == last_src_token_id){ + smem_buffer_ptr->intra_node_flag_G2S_buffer[token_stage] = true; + } + else{ + smem_buffer_ptr->intra_node_flag_G2S_buffer[token_stage] = false; + } + + cuda::ptx::mbarrier_arrive_expect_tx(cuda::ptx::sem_release, + cuda::ptx::scope_cta, + cuda::ptx::space_shared, + &smem_buffer_ptr->intra_node_mbarrier_G2S_buffer[token_stage][0], + total_tx_size); + + // Goto next token entry in shared memory. + token_stage += 1; + if(token_stage == NUM_OF_STAGES_G2S){ + token_stage = 0; + token_consumer_parity ^= 1; + } + } + } + } + } + } + } + } + } +} + +// Device function for intra-node reduction warp group for combine kernel. +template +inline __device__ void intra_node_red_warp_group_device_function(const int node_rank, + const int num_of_tokens_per_rank, + const bool* rdma_to_attn_map, + uint16_t* rdma_intra_node_red_token, + float* rdma_intra_node_red_prob, + SMEM_TYPE* smem_buffer_ptr) +{ + // Load rdma_to_attn_map using LDG.128. Each dst token will need 1 bool from this map. + using rdma_to_attn_map_load_t = uint4; + static_assert(sizeof(bool) == 1, "Bool is not 1 byte???"); + static_assert(NUM_OF_TOKENS_PER_CHUNK % sizeof(rdma_to_attn_map_load_t) == 0, "NUM_OF_TOKENS_PER_CHUNK must be multiple of rdma_to_attn_map_load_t."); + constexpr int NUM_OF_RDMA_TO_ATTN_LOAD_ITER_PER_CHUNK = NUM_OF_TOKENS_PER_CHUNK / sizeof(rdma_to_attn_map_load_t); + constexpr int NUM_OF_TOKENS_PER_RDMA_TO_ATTN_LOAD_ITER = sizeof(rdma_to_attn_map_load_t) / sizeof(bool); + + // Load sparse_to_dense_map according to the NUM_OF_RANKS_PER_NODE. + /*using sparse_to_dense_map_load_t = Copy_t; + constexpr int NUM_OF_SPARSE_TO_DENSE_MAP_LOAD_ITER_PER_OUTPUT_TOKEN = (NUM_OF_RANKS_PER_NODE * sizeof(int32_t)) / sizeof(sparse_to_dense_map_load_t); + constexpr int NUM_OF_INPUT_TOKENS_PER_LOAD_ITER = sizeof(sparse_to_dense_map_load_t) / sizeof(int32_t);*/ + + // Processing token using BF16x2 intruction, HIDDEN_DIM must be multiple of 2. + static_assert(HIDDEN_DIM % 2 == 0, "HIDDEN_DIM must be multiple of 2."); + constexpr int NUM_OF_ELEMENT_PER_THREAD = (HIDDEN_DIM / 2) / INTRA_NODE_RED_GROUP::size(); + // Processing prob using fp32. + constexpr int NUM_OF_PROB_VEC_ELEMENT_PER_THREAD = ((NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE - 1) / INTRA_NODE_RED_GROUP::size()) + 1; + //static_assert(INTRA_NODE_RED_GROUP::size() >= NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE, "The size of intra-node reduction warp group must not be smaller than prob size."); + + // The intra node reduction warp group of each CUDA block produce a chunk at a time. + // The chunk order is: first produce the same chunk id for all other nodes id, then produce following chunk id. + // (i.e. chunk 0 for node + 1, node + 2, ... node - 1, then chunk 1 for node + 1, node + 2, ... node - 1) + // The RDMA warp group of a CUDA block will consume the chunk by the same order. So each CUDA block will produce and consume the same set of chunks id. + // The reason to distribute chunk in this order is that the inter-node reduction will need the same chunk id from all other nodes, so we need to produce and send chunks in this order. + + const int remainder_chunk_size = num_of_tokens_per_rank % NUM_OF_TOKENS_PER_CHUNK; + // How many chunks per rank. Including full chunks and the remainder chunk. + const int num_of_chunks_per_rank = ((num_of_tokens_per_rank - 1) / NUM_OF_TOKENS_PER_CHUNK) + 1; + // Total number of chunks to produce for RDMA warps to consume. + const int total_num_of_chunks = (NUM_OF_NODES - 1) * num_of_chunks_per_rank; + // The rdma_to_attn_map need to be paded to multiple of rdma_to_attn_map_load_t per node. + // The largest size of rdma_to_attn_map_load_t allowed in all Hybrid-EP kernels are 16B(16 bools), so need to be paded to 16B per node. + // That means the size of rdma_to_attn_map should be rdma_to_attn_map_size_per_node * NUM_OF_NODES. + const int rdma_to_attn_map_size_per_node = (((num_of_tokens_per_rank - 1) / 16) + 1) * 16; + // Src token stage id and phase. + int token_stage = 0; + uint32_t token_producer_parity = 0; + + // Dst token stage id. + int dst_token_stage = 0; + + // Whether there are S2G TMA operations of a previous chunk's dst token in-flight(unfinished). + bool outstanding_in_flight_chunk = false; + + // rdma_remote_node_id and chunk_id for previous chunk. + int last_chunk_id; + int last_rdma_remote_node_id; + + // Iterate through all chunks assigned to this block. + for(int i = blockIdx.x; i < total_num_of_chunks; i += NUM_OF_BLOCKS){ + // Which node this chunk will be sent to. + int node_id = (i % (NUM_OF_NODES - 1) + (node_rank + 1)) % NUM_OF_NODES; + // What is the chunk id of this chunk for the node it will be sent to. + int chunk_id = i / (NUM_OF_NODES - 1); + // Which node this chunk belongs to in output rdma reduction buffers. + int rdma_remote_node_id = node_id > node_rank ? node_id - 1 : node_id; + // How many rdma_to_attn load iter for this chunk. + int num_of_routing_info_load_iter_for_current_chunk; + // How many token for this chunk. + int current_chunk_size; + if(remainder_chunk_size != 0 && chunk_id == num_of_chunks_per_rank - 1){ + num_of_routing_info_load_iter_for_current_chunk = ((remainder_chunk_size - 1) / sizeof(rdma_to_attn_map_load_t)) + 1; + current_chunk_size = remainder_chunk_size; + }else{ + num_of_routing_info_load_iter_for_current_chunk = NUM_OF_RDMA_TO_ATTN_LOAD_ITER_PER_CHUNK; + current_chunk_size = NUM_OF_TOKENS_PER_CHUNK; + } + + const rdma_to_attn_map_load_t* rdma_to_attn_map_load_base_addr = reinterpret_cast(rdma_to_attn_map + + (node_id * rdma_to_attn_map_size_per_node + chunk_id * NUM_OF_TOKENS_PER_CHUNK)); + + uint16_t* rdma_intra_node_red_token_base_ptr = rdma_intra_node_red_token + (rdma_remote_node_id * num_of_tokens_per_rank + chunk_id * NUM_OF_TOKENS_PER_CHUNK) * HIDDEN_DIM; + float* rdma_intra_node_red_prob_base_ptr; + if constexpr(BACKWARD_COMBINE){ + rdma_intra_node_red_prob_base_ptr = rdma_intra_node_red_prob + + (rdma_remote_node_id * num_of_tokens_per_rank + chunk_id * NUM_OF_TOKENS_PER_CHUNK) * + (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE); + } + + // How many dst token entry of current chunk have been in-flight. + int additional_in_flight_s2g = 0; + // Iterate through all dst tokens within this chunk. + for(int j = 0; j < num_of_routing_info_load_iter_for_current_chunk; j++){ + rdma_to_attn_map_load_t rdma_to_attn_map_data = rdma_to_attn_map_load_base_addr[j]; + #pragma unroll + for(int k = 0; k < NUM_OF_TOKENS_PER_RDMA_TO_ATTN_LOAD_ITER; k++){ + // Check whether there is a previous chunk's dst token S2G in-flight and also current chunk already has NUM_OF_ADDITIONAL_IN_FLIGHT_S2G dst token S2G in-flight. + // If so, wait for previous chunk's S2G finish and notify the RDMA warp groups. + if(outstanding_in_flight_chunk && (additional_in_flight_s2g == NUM_OF_ADDITIONAL_IN_FLIGHT_S2G)){ + if(INTRA_NODE_RED_GROUP::warp_rank() == 0){ + if(elect_sync(~0)){ + // Wait for previous chunk's S2G finish. + cuda::ptx::cp_async_bulk_wait_group(cuda::ptx::n32_t{}); + // Notify the rdma warp group. + if constexpr(NUM_OF_NODES != 1){ + cuda::ptx::mbarrier_arrive(&smem_buffer_ptr->intra_node_to_rdma_mbarrier_buffer[last_rdma_remote_node_id][last_chunk_id]); + } + } + } + outstanding_in_flight_chunk = false; + } + int current_token_id = j * NUM_OF_TOKENS_PER_RDMA_TO_ATTN_LOAD_ITER + k; + // If the current token is out-of-bound, then just end this load iter. + if(current_token_id >= current_chunk_size){ + break; + } + // Check whether this dst token is needed by this node. If not needed, just skip. + bool token_needed_by_this_node = *(reinterpret_cast(&rdma_to_attn_map_data) + k); + // If this dst token is needed by this node, which means this dst token will have at least 1 src token within the shread memory. + // Then, load the src token for this dst token from shared memory and accumulate it to the accumulator. + if(token_needed_by_this_node){ + // Accumulator for this dst token. Token must be accumulated in FP32. + float2 acc_token_fp32[NUM_OF_ELEMENT_PER_THREAD]; + // Optional Accumulator for this dst token prob. + float acc_prob[NUM_OF_PROB_VEC_ELEMENT_PER_THREAD]; + // End reduction group flag. + bool last_src_token = false; + // Init accumulator. + #pragma unroll + for(int n = 0; n < NUM_OF_ELEMENT_PER_THREAD; n++){ + acc_token_fp32[n].x = 0.0f; + acc_token_fp32[n].y = 0.0f; + } + #pragma unroll + for(int n = 0; n < NUM_OF_PROB_VEC_ELEMENT_PER_THREAD; n++){ + acc_prob[n] = 0.0f; + } + + // Continue loading src token for this dst token and reduce them to accumulator until all src token for this dst token have been accumulated. + do{ + // Base address for current token and prob(optional) in shared memory. + __nv_bfloat162* load_token_base_ptr = reinterpret_cast<__nv_bfloat162*>(&smem_buffer_ptr->intra_node_token_G2S_buffer[token_stage][0]); + float* load_prob_base_ptr; + if constexpr(BACKWARD_COMBINE){ + load_prob_base_ptr = &smem_buffer_ptr->intra_node_prob_G2S_buffer[token_stage][0]; + } + + // Wait until current src token ready in shared memory. + if(INTRA_NODE_RED_GROUP::warp_rank() == 0){ + if(elect_sync(~0)){ + while(!cuda::ptx::mbarrier_try_wait_parity(&smem_buffer_ptr->intra_node_mbarrier_G2S_buffer[token_stage][0], token_producer_parity)){} + } + } + arrive_and_wait(INTRA_NODE_RED_GROUP::size(), 1); + + // Accumulate token and prob(optional). + #pragma unroll + for(int n = 0; n < NUM_OF_ELEMENT_PER_THREAD; n++){ + int element_id = (n * INTRA_NODE_RED_GROUP::size()) + INTRA_NODE_RED_GROUP::thread_rank(); + __nv_bfloat162 src_data = load_token_base_ptr[element_id]; + float2 src_data_fp32 = __bfloat1622float2(src_data); + acc_token_fp32[n].x += src_data_fp32.x; + acc_token_fp32[n].y += src_data_fp32.y; + } + + if constexpr(BACKWARD_COMBINE){ + #pragma unroll + for(int n = 0; n < NUM_OF_PROB_VEC_ELEMENT_PER_THREAD; n++){ + int element_id = INTRA_NODE_RED_GROUP::thread_rank() + n * INTRA_NODE_RED_GROUP::size(); + if(element_id < NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE){ + float src_data = load_prob_base_ptr[element_id]; + acc_prob[n] += src_data; + } + } + } + + // Check flag for last src token. + last_src_token = smem_buffer_ptr->intra_node_flag_G2S_buffer[token_stage]; + + // Make sure all warp group have finished loading the token entry and accumulate it to the register accumulator. + // Then notify the producer warp to load next token entry to the shared memory as the shared memory can be reused. + arrive_and_wait(INTRA_NODE_RED_GROUP::size(), 1); + if(INTRA_NODE_RED_GROUP::warp_rank() == 0){ + if(elect_sync(~0)){ + cuda::ptx::mbarrier_arrive(&smem_buffer_ptr->intra_node_mbarrier_G2S_buffer[token_stage][1]); + } + } + + // Goto next src token entry. + token_stage += 1; + if(token_stage == NUM_OF_STAGES_G2S){ + token_stage = 0; + token_producer_parity ^= 1; + } + + }while(!last_src_token); + + // Base address for current dst token and prob(optional) in shared memory. + __nv_bfloat162* store_token_base_ptr = reinterpret_cast<__nv_bfloat162*>(&smem_buffer_ptr->intra_node_token_S2G_buffer[dst_token_stage][0]); + float* store_prob_base_ptr; + if constexpr(BACKWARD_COMBINE){ + store_prob_base_ptr = &smem_buffer_ptr->intra_node_prob_S2G_buffer[dst_token_stage][0]; + } + + // Let the TMA thread to wait for previously issued TMA S2G operations finish reading this entry. + if(INTRA_NODE_RED_GROUP::warp_rank() == 0){ + if(elect_sync(~0)){ + cuda::ptx::cp_async_bulk_wait_group_read(cuda::ptx::n32_t{}); + } + } + // Make sure all threads within the red warp group have wait for previously issued TMA S2G operations finish reading this entry before storing new data to this entry. + arrive_and_wait(INTRA_NODE_RED_GROUP::size(), 1); + + // Store the token. + #pragma unroll + for(int n = 0; n < NUM_OF_ELEMENT_PER_THREAD; n++){ + int element_id = (n * INTRA_NODE_RED_GROUP::size()) + INTRA_NODE_RED_GROUP::thread_rank(); + // Convert accumulated token back to BF16 and store the result back to shared memory token entry. + store_token_base_ptr[element_id] = __float22bfloat162_rn(acc_token_fp32[n]); + } + + // Store the prob(optional). + if constexpr(BACKWARD_COMBINE){ + #pragma unroll + for(int n = 0; n < NUM_OF_PROB_VEC_ELEMENT_PER_THREAD; n++){ + int element_id = INTRA_NODE_RED_GROUP::thread_rank() + n * INTRA_NODE_RED_GROUP::size(); + if(element_id < NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE){ + store_prob_base_ptr[element_id] = acc_prob[n]; + } + } + } + + // Make sure the shared memory stored by current thread is visible by async proxy. + cuda::ptx::fence_proxy_async(cuda::ptx::space_shared); + + // Make sure all threads within the red warp group have finished storing the current token entry and making it visible to async proxy. + arrive_and_wait(INTRA_NODE_RED_GROUP::size(), 1); + + // Let the TMA thread to issue S2G TMA operations for current token entry. + if(INTRA_NODE_RED_GROUP::warp_rank() == 0){ + if(elect_sync(~0)){ + uint16_t* current_token_addr = rdma_intra_node_red_token_base_ptr + (j * NUM_OF_TOKENS_PER_RDMA_TO_ATTN_LOAD_ITER + k) * HIDDEN_DIM; + // Store the token from shared to global. + cuda::ptx::cp_async_bulk(cuda::ptx::space_global, + cuda::ptx::space_shared, + reinterpret_cast(current_token_addr), + reinterpret_cast(&smem_buffer_ptr->intra_node_token_S2G_buffer[dst_token_stage][0]), + (uint32_t)(HIDDEN_DIM * sizeof(uint16_t))); + + // Store the prob from shared to global(Optional). + if constexpr(BACKWARD_COMBINE){ + float* current_prob_addr = rdma_intra_node_red_prob_base_ptr + (j * NUM_OF_TOKENS_PER_RDMA_TO_ATTN_LOAD_ITER + k) * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE); + cuda::ptx::cp_async_bulk(cuda::ptx::space_global, + cuda::ptx::space_shared, + reinterpret_cast(current_prob_addr), + reinterpret_cast(&smem_buffer_ptr->intra_node_prob_S2G_buffer[dst_token_stage][0]), + (uint32_t)((NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE) * sizeof(float))); + + } + // Commit S2G TMA operations for this dst token into a bulk async copy group. + cuda::ptx::cp_async_bulk_commit_group(); + } + } + + // Goto next dst token entry. + dst_token_stage += 1; + if(dst_token_stage == NUM_OF_STAGES_S2G){ + dst_token_stage = 0; + } + + // Another token entry's S2G in-flight. + additional_in_flight_s2g += 1; + } + } + } + // If the current chunk does not have NUM_OF_ADDITIONAL_IN_FLIGHT_S2G dst token entry in-flight, which is possible of rdma_to_attn map is really sparse. + // We need to wait for both previous and current chunks' dst token entry S2G to finish and notify the RDMA warp group. + if(outstanding_in_flight_chunk){ + if(INTRA_NODE_RED_GROUP::warp_rank() == 0){ + if(elect_sync(~0)){ + // Wait for all previous chunk's(i.e. previous and current chunk) S2G finish. + cuda::ptx::cp_async_bulk_wait_group(cuda::ptx::n32_t<0>{}); + // Notify the rdma warp group. + if constexpr(NUM_OF_NODES != 1){ + cuda::ptx::mbarrier_arrive(&smem_buffer_ptr->intra_node_to_rdma_mbarrier_buffer[last_rdma_remote_node_id][last_chunk_id]); + cuda::ptx::mbarrier_arrive(&smem_buffer_ptr->intra_node_to_rdma_mbarrier_buffer[rdma_remote_node_id][chunk_id]); + } + } + } + outstanding_in_flight_chunk = false; + }else{ // Otherwise, the current chunks is in-flight. + outstanding_in_flight_chunk = true; + } + + // Update last chunk's id. + last_rdma_remote_node_id = rdma_remote_node_id; + last_chunk_id = chunk_id; + } + + // When all chunks have been processed, we need to check whether the last chunk is still in-flight. + // If so, wait for it and notify RDMA warp group. + if(outstanding_in_flight_chunk){ + if(INTRA_NODE_RED_GROUP::warp_rank() == 0){ + if(elect_sync(~0)){ + // Wait for the last chunk's S2G finish. + cuda::ptx::cp_async_bulk_wait_group(cuda::ptx::n32_t<0>{}); + // Notify the rdma warp group. + if constexpr(NUM_OF_NODES != 1){ + cuda::ptx::mbarrier_arrive(&smem_buffer_ptr->intra_node_to_rdma_mbarrier_buffer[last_rdma_remote_node_id][last_chunk_id]); + } + } + } + } +} + +// Device function for inter-node G2S warp for combine kernel. There can be only 1 such warp per CUDA block! +template +inline __device__ void inter_node_G2S_warp_group_device_function(const int node_rank, + const int num_of_tokens_per_rank, + const uint64_t* expected_flag_value, + const bool* rdma_to_attn_map, + const bool* attn_to_rdma_map, + const int32_t* sparse_to_dense_map, + uint16_t* const* remote_expert_input_token, + float* const* remote_expert_input_prob, + const uint16_t* rdma_inter_node_group_token, + const float* rdma_inter_node_group_prob, + const uint64_t* rdma_inter_node_group_flags, + SMEM_TYPE* smem_buffer_ptr) +{ + // All chunks in output buffer(attn buffer) will be divided into token groups and assigned to different CUDA blocks. + // This is different than other functions where chunks are assigned to different CUDA blocks. + static_assert(NUM_OF_TOKENS_PER_CHUNK % NUM_OF_TOKENS_PER_GROUP == 0, "NUM_OF_TOKENS_PER_CHUNK must be multiple of NUM_OF_TOKENS_PER_GROUP."); + constexpr int NUM_OF_TOKEN_GROUPS_PER_CHUNK = NUM_OF_TOKENS_PER_CHUNK / NUM_OF_TOKENS_PER_GROUP; + + static_assert(sizeof(bool) == 1, "Bool is not 1 byte???"); + // Load rdma_to_attn_map for a token group at once. Each dst token will need 1 bool from this map. + using rdma_to_attn_map_load_t = Copy_t; + static_assert(NUM_OF_TOKENS_PER_GROUP == sizeof(rdma_to_attn_map_load_t), "Current implementation requires NUM_OF_TOKENS_PER_GROUP to be 1/2/4/8/16."); + + //constexpr int NUM_OF_RDMA_TO_ATTN_LOAD_ITER_PER_CHUNK = NUM_OF_TOKENS_PER_CHUNK / sizeof(rdma_to_attn_map_load_t); + //constexpr int NUM_OF_TOKENS_PER_RDMA_TO_ATTN_LOAD_ITER = sizeof(rdma_to_attn_map_load_t) / sizeof(bool); + + // Load sparse_to_dense_map according to the NUM_OF_RANKS_PER_NODE. + using sparse_to_dense_map_load_t = Copy_t; + constexpr int NUM_OF_SPARSE_TO_DENSE_MAP_LOAD_ITER_PER_OUTPUT_TOKEN = (NUM_OF_RANKS_PER_NODE * sizeof(int32_t)) / sizeof(sparse_to_dense_map_load_t); + constexpr int NUM_OF_INPUT_TOKENS_PER_LOAD_ITER = sizeof(sparse_to_dense_map_load_t) / sizeof(int32_t); + + // The inter node reduction warp group of each CUDA block produce a token group of a chunk at a time. Token groups of each chunk assigned to each CUDA block in interleave pattern. + // The chunk order is: i.e. chunk 0, then chunk 1, ... the last chunk of attn output buffer. + // The RDMA network for current rank will produce the same chunk id from node - 1, node - 2 ... node + 1. + // So inter node reduction warp group will consume the src chunk in the same order. + + const int remainder_chunk_size = num_of_tokens_per_rank % NUM_OF_TOKENS_PER_CHUNK; + // How many chunks per rank. Including full chunks and the remainder chunk. + const int num_of_chunks_per_rank = ((num_of_tokens_per_rank - 1) / NUM_OF_TOKENS_PER_CHUNK) + 1; + // Total number of chunks to process in the output buffer(attn buffer). output buffer(attn buffer) will only have 1 rank's tokens. + const int total_num_of_chunks = num_of_chunks_per_rank; + // The rdma_to_attn_map need to be paded to multiple of rdma_to_attn_map_load_t per node. + // The largest size of rdma_to_attn_map_load_t allowed in all Hybrid-EP kernels are 16B(16 bools), so need to be paded to 16B per node. + // That means the size of rdma_to_attn_map should be rdma_to_attn_map_size_per_node * NUM_OF_NODES. + const int rdma_to_attn_map_size_per_node = (((num_of_tokens_per_rank - 1) / 16) + 1) * 16; + // Token stage id and phase. + int token_stage = 0; + uint32_t token_consumer_parity = 1; + + // Only 1 thread within the intra-node G2S warp will be active, other threads will just exit. + if(elect_sync(~0)){ + // Iterate through all chunks. All chunks will assign to all CUDA block. + for(int i = 0; i < total_num_of_chunks; i++){ + // How many rdma_to_attn load iter(a.k.a token group) for this chunk. + int num_of_token_groups_for_current_chunk; + // How many token for this chunk. + int current_chunk_size; + if(remainder_chunk_size != 0 && i == num_of_chunks_per_rank - 1){ + num_of_token_groups_for_current_chunk = ((remainder_chunk_size - 1) / NUM_OF_TOKENS_PER_GROUP) + 1; + current_chunk_size = remainder_chunk_size; + }else{ + num_of_token_groups_for_current_chunk = NUM_OF_TOKEN_GROUPS_PER_CHUNK; + current_chunk_size = NUM_OF_TOKENS_PER_CHUNK; + } + const rdma_to_attn_map_load_t* rdma_to_attn_map_load_base_addr = reinterpret_cast(rdma_to_attn_map + + (node_rank * rdma_to_attn_map_size_per_node + i * NUM_OF_TOKENS_PER_CHUNK)); + const int32_t* sparse_to_dense_map_load_base_addr = sparse_to_dense_map + (node_rank * num_of_tokens_per_rank + i * NUM_OF_TOKENS_PER_CHUNK) * NUM_OF_RANKS_PER_NODE; + + const bool* attn_to_rdma_map_load_base_addr = attn_to_rdma_map + (i * NUM_OF_TOKENS_PER_CHUNK) * (NUM_OF_NODES - 1); + + // Padding from NUM_OF_NODES - 1 to NUM_OF_NODES in case NUM_OF_NODES = 1. + // We still only use first NUM_OF_NODES - 1 flags, the last flag is the padding and not been used. + bool rdma_flag_clear[NUM_OF_NODES]; + #pragma unroll + for(int j = 0; j < NUM_OF_NODES; j++){ + rdma_flag_clear[j] = false; + } + + // Iterate through all token groups within this chunk which assign to this CUDA block. + for(int j = blockIdx.x; j < num_of_token_groups_for_current_chunk; j += NUM_OF_BLOCKS){ + rdma_to_attn_map_load_t rdma_to_attn_map_data = rdma_to_attn_map_load_base_addr[j]; + // Iterate through all dst(output) tokens within this token group. + #pragma unroll + for(int k = 0; k < NUM_OF_TOKENS_PER_GROUP; k++){ + int current_token_id = j * NUM_OF_TOKENS_PER_GROUP + k; + // If the current token is out-of-bound, then just end this load iter. + if(current_token_id >= current_chunk_size){ + break; + } + // Each dst token need to accumulate src tokens from local node's ranks(this part is the same as intra-node reduction), and src tokens from rdma inter-node buffers. + // Accumulate local tokens first, then rdma tokens. + + // Check whether this dst token is needed by this(local) node. If not needed, just skip local accumulation. + bool token_needed_by_this_node = *(reinterpret_cast(&rdma_to_attn_map_data) + k); + // If this dst token is needed by this node, load the sparse_to_dense map and load the local src token for this dst token. + if(token_needed_by_this_node){ + const sparse_to_dense_map_load_t* sparse_to_dense_map_load_addr = reinterpret_cast + (sparse_to_dense_map_load_base_addr + (j * NUM_OF_TOKENS_PER_GROUP + k) * NUM_OF_RANKS_PER_NODE); + // Load sparse_to_dense map for this dst token(i.e. a row in sparse_to_dense map). + sparse_to_dense_map_load_t sparse_to_dense_map_data[NUM_OF_SPARSE_TO_DENSE_MAP_LOAD_ITER_PER_OUTPUT_TOKEN]; + // First load sparse_to_dense map and decide the last src token within this row. + int last_src_token_id = 0; + #pragma unroll + for(int n = 0; n < NUM_OF_SPARSE_TO_DENSE_MAP_LOAD_ITER_PER_OUTPUT_TOKEN; n++){ + sparse_to_dense_map_data[n] = sparse_to_dense_map_load_addr[n]; + #pragma unroll + for(int m = 0; m < NUM_OF_INPUT_TOKENS_PER_LOAD_ITER; m++){ + int32_t sparse_to_dense_map_value = *(reinterpret_cast(&sparse_to_dense_map_data[n]) + m); + if(sparse_to_dense_map_value != -1){ + last_src_token_id = n * NUM_OF_INPUT_TOKENS_PER_LOAD_ITER + m; + } + } + } + // Then issue all G2S TMA for this row. + #pragma unroll + for(int n = 0; n < NUM_OF_SPARSE_TO_DENSE_MAP_LOAD_ITER_PER_OUTPUT_TOKEN; n++){ + #pragma unroll + for(int m = 0; m < NUM_OF_INPUT_TOKENS_PER_LOAD_ITER; m++){ + int32_t sparse_to_dense_map_value = *(reinterpret_cast(&sparse_to_dense_map_data[n]) + m); + if(sparse_to_dense_map_value != -1){ + int current_src_token_id = n * NUM_OF_INPUT_TOKENS_PER_LOAD_ITER + m; + // Wait until current token entry within the shared memory has been consumed. + while(!cuda::ptx::mbarrier_try_wait_parity(&smem_buffer_ptr->inter_node_mbarrier_G2S_buffer[token_stage][1], token_consumer_parity)){} + + uint32_t total_tx_size = 0; + cuda::ptx::cp_async_bulk(cuda::ptx::space_shared, + cuda::ptx::space_global, + reinterpret_cast(&smem_buffer_ptr->inter_node_token_G2S_buffer[token_stage][0]), + reinterpret_cast(remote_expert_input_token[current_src_token_id] + (sparse_to_dense_map_value * HIDDEN_DIM)), + (uint32_t)(HIDDEN_DIM * sizeof(uint16_t)), + &smem_buffer_ptr->inter_node_mbarrier_G2S_buffer[token_stage][0]); + + total_tx_size += (uint32_t)(HIDDEN_DIM * sizeof(uint16_t)); + + if constexpr(BACKWARD_COMBINE){ + cuda::ptx::cp_async_bulk(cuda::ptx::space_shared, + cuda::ptx::space_global, + reinterpret_cast(&smem_buffer_ptr->inter_node_prob_G2S_buffer[token_stage][0]), + reinterpret_cast(remote_expert_input_prob[current_src_token_id] + (sparse_to_dense_map_value * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE))), + (uint32_t)((NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE) * sizeof(float)), + &smem_buffer_ptr->inter_node_mbarrier_G2S_buffer[token_stage][0]); + + total_tx_size += (uint32_t)((NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE) * sizeof(float)); + } + + if(current_src_token_id == last_src_token_id){ + smem_buffer_ptr->inter_node_flag_G2S_buffer[token_stage] = true; + } + else{ + smem_buffer_ptr->inter_node_flag_G2S_buffer[token_stage] = false; + } + + cuda::ptx::mbarrier_arrive_expect_tx(cuda::ptx::sem_release, + cuda::ptx::scope_cta, + cuda::ptx::space_shared, + &smem_buffer_ptr->inter_node_mbarrier_G2S_buffer[token_stage][0], + total_tx_size); + + // Goto next token entry in shared memory. + token_stage += 1; + if(token_stage == NUM_OF_STAGES_G2S){ + token_stage = 0; + token_consumer_parity ^= 1; + } + } + } + } + } + // Then accumulate from rdma inter-node buffers. There are total NUM_OF_NODES - 1 (possible) src tokens from rdma buffer to reduce. + const bool* attn_to_rdma_map_load_addr = attn_to_rdma_map_load_base_addr + (j * NUM_OF_TOKENS_PER_GROUP + k) * (NUM_OF_NODES - 1); + #pragma unroll + for(int n = 1; n < NUM_OF_NODES; n++){ + // The current node been processed. For each chunk id, node_id order is + // (no local_node itself, which is already been accumulated above) local_node - 1, local_node - 2, ......, local_node + 1 and will wrap around. + int node_id = node_rank >= n ? node_rank - n : node_rank + NUM_OF_NODES - n; + // The tile id within the rdma buffers for the current node id. Because rdma buffers only have NUM_OF_NODES - 1 tile. + int rdma_buffer_tile_id = node_id > node_rank ? node_id - 1 : node_id; + // Check wether current dst token need src token from this node. + if(attn_to_rdma_map_load_addr[rdma_buffer_tile_id]){ + // If the current chunk is not ready yet, wait for related rdma inter-node group buffer chunks ready first. + if(rdma_flag_clear[n - 1] == false){ + const uint64_t* flag_location = rdma_inter_node_group_flags + (rdma_buffer_tile_id * num_of_chunks_per_rank + i); + uint64_t rdma_flag = 0; + do{ + rdma_flag = 0; + // Need a strong system-scope load to observe external RDMA Atomic result. + asm volatile("ld.relaxed.sys.global.b64 %0, [%1];" + : "=l"(rdma_flag) + : "l"(__cvta_generic_to_global(flag_location)) + : "memory"); + }while(rdma_flag != *expected_flag_value); + + // Mark the chunk from this node(tile) is already clear. + rdma_flag_clear[n - 1] = true; + } + // Wait until current token entry within the shared memory has been consumed. + while(!cuda::ptx::mbarrier_try_wait_parity(&smem_buffer_ptr->inter_node_mbarrier_G2S_buffer[token_stage][1], token_consumer_parity)){} + // Load the src token from this rdma inter-node group buffer chunk to shared memory entry. + uint32_t total_tx_size = 0; + const uint16_t* rdma_inter_node_group_token_load_addr = rdma_inter_node_group_token + + (rdma_buffer_tile_id * num_of_tokens_per_rank + + i * NUM_OF_TOKENS_PER_CHUNK + + j * NUM_OF_TOKENS_PER_GROUP + k) * HIDDEN_DIM; + cuda::ptx::cp_async_bulk(cuda::ptx::space_shared, + cuda::ptx::space_global, + reinterpret_cast(&smem_buffer_ptr->inter_node_token_G2S_buffer[token_stage][0]), + reinterpret_cast(rdma_inter_node_group_token_load_addr), + (uint32_t)(HIDDEN_DIM * sizeof(uint16_t)), + &smem_buffer_ptr->inter_node_mbarrier_G2S_buffer[token_stage][0]); + + total_tx_size += (uint32_t)(HIDDEN_DIM * sizeof(uint16_t)); + + if constexpr(BACKWARD_COMBINE){ + const float* rdma_inter_node_group_prob_load_addr = rdma_inter_node_group_prob + + (rdma_buffer_tile_id * num_of_tokens_per_rank + + i * NUM_OF_TOKENS_PER_CHUNK + + j * NUM_OF_TOKENS_PER_GROUP + k) * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE); + + cuda::ptx::cp_async_bulk(cuda::ptx::space_shared, + cuda::ptx::space_global, + reinterpret_cast(&smem_buffer_ptr->inter_node_prob_G2S_buffer[token_stage][0]), + reinterpret_cast(rdma_inter_node_group_prob_load_addr), + (uint32_t)((NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE) * sizeof(float)), + &smem_buffer_ptr->inter_node_mbarrier_G2S_buffer[token_stage][0]); + + total_tx_size += (uint32_t)((NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE) * sizeof(float)); + } + + // Inter-node token does not need flag since the red warp group will also read attn_to_rdma_map. + + cuda::ptx::mbarrier_arrive_expect_tx(cuda::ptx::sem_release, + cuda::ptx::scope_cta, + cuda::ptx::space_shared, + &smem_buffer_ptr->inter_node_mbarrier_G2S_buffer[token_stage][0], + total_tx_size); + + // Goto next token entry in shared memory. + token_stage += 1; + if(token_stage == NUM_OF_STAGES_G2S){ + token_stage = 0; + token_consumer_parity ^= 1; + } + } + } + } + } + } + } +} + +// Device function for inter-node reduction warp group for combine kernel. +template +inline __device__ void inter_node_red_warp_group_device_function(const int node_rank, + const int num_of_tokens_per_rank, + const bool* rdma_to_attn_map, + const bool* attn_to_rdma_map, + uint16_t* attn_output_token, + float* attn_output_prob, + SMEM_TYPE* smem_buffer_ptr) +{ + // All chunks in output buffer(attn buffer) will be divided into token groups and assigned to different CUDA blocks. + // This is different than other functions where chunks are assigned to different CUDA blocks. + static_assert(NUM_OF_TOKENS_PER_CHUNK % NUM_OF_TOKENS_PER_GROUP == 0, "NUM_OF_TOKENS_PER_CHUNK must be multiple of NUM_OF_TOKENS_PER_GROUP."); + constexpr int NUM_OF_TOKEN_GROUPS_PER_CHUNK = NUM_OF_TOKENS_PER_CHUNK / NUM_OF_TOKENS_PER_GROUP; + + static_assert(sizeof(bool) == 1, "Bool is not 1 byte???"); + // Load rdma_to_attn_map for a token group at once. Each dst token will need 1 bool from this map. + using rdma_to_attn_map_load_t = Copy_t; + static_assert(NUM_OF_TOKENS_PER_GROUP == sizeof(rdma_to_attn_map_load_t), "Current implementation requires NUM_OF_TOKENS_PER_GROUP to be 1/2/4/8/16."); + + // Processing token using BF16x2 intruction, HIDDEN_DIM must be multiple of 2. + static_assert(HIDDEN_DIM % 2 == 0, "HIDDEN_DIM must be multiple of 2."); + constexpr int NUM_OF_ELEMENT_PER_THREAD = (HIDDEN_DIM / 2) / INTER_NODE_RED_GROUP::size(); + // Processing prob using fp32. + constexpr int NUM_OF_PROB_VEC_ELEMENT_PER_THREAD = ((NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE - 1) / INTER_NODE_RED_GROUP::size()) + 1; + //static_assert(INTER_NODE_RED_GROUP::size() >= NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE, "The size of inter-node reduction warp group must not be smaller than prob size."); + + // The inter node reduction warp group of each CUDA block produce a token group of a chunk at a time. Token groups of each chunk assigned to each CUDA block in interleave pattern. + // The chunk order is: i.e. chunk 0, then chunk 1, ... the last chunk of attn output buffer. + // The RDMA network for current rank will produce the same chunk id from node - 1, node - 2 ... node + 1. + // So inter node reduction warp group will consume the src chunk in the same order. + + const int remainder_chunk_size = num_of_tokens_per_rank % NUM_OF_TOKENS_PER_CHUNK; + // How many chunks per rank. Including full chunks and the remainder chunk. + const int num_of_chunks_per_rank = ((num_of_tokens_per_rank - 1) / NUM_OF_TOKENS_PER_CHUNK) + 1; + // Total number of chunks to process in the output buffer(attn buffer). output buffer(attn buffer) will only have 1 rank's tokens. + const int total_num_of_chunks = num_of_chunks_per_rank; + // The rdma_to_attn_map need to be paded to multiple of rdma_to_attn_map_load_t per node. + // The largest size of rdma_to_attn_map_load_t allowed in all Hybrid-EP kernels are 16B(16 bools), so need to be paded to 16B per node. + // That means the size of rdma_to_attn_map should be rdma_to_attn_map_size_per_node * NUM_OF_NODES. + const int rdma_to_attn_map_size_per_node = (((num_of_tokens_per_rank - 1) / 16) + 1) * 16; + // Src token stage id and phase. + int token_stage = 0; + uint32_t token_producer_parity = 0; + + // Dst token stage id. + int dst_token_stage = 0; + + // Iterate through all chunks. All chunks will assign to all CUDA block. + for(int i = 0; i < total_num_of_chunks; i++){ + // How many rdma_to_attn load iter(a.k.a token group) for this chunk. + int num_of_token_groups_for_current_chunk; + // How many token for this chunk. + int current_chunk_size; + if(remainder_chunk_size != 0 && i == num_of_chunks_per_rank - 1){ + num_of_token_groups_for_current_chunk = ((remainder_chunk_size - 1) / NUM_OF_TOKENS_PER_GROUP) + 1; + current_chunk_size = remainder_chunk_size; + }else{ + num_of_token_groups_for_current_chunk = NUM_OF_TOKEN_GROUPS_PER_CHUNK; + current_chunk_size = NUM_OF_TOKENS_PER_CHUNK; + } + const rdma_to_attn_map_load_t* rdma_to_attn_map_load_base_addr = reinterpret_cast(rdma_to_attn_map + + (node_rank * rdma_to_attn_map_size_per_node + i * NUM_OF_TOKENS_PER_CHUNK)); + const bool* attn_to_rdma_map_load_base_addr = attn_to_rdma_map + (i * NUM_OF_TOKENS_PER_CHUNK) * (NUM_OF_NODES - 1); + uint16_t* attn_output_token_base_ptr = attn_output_token + (i * NUM_OF_TOKENS_PER_CHUNK) * HIDDEN_DIM; + float* attn_output_prob_base_ptr; + if constexpr(BACKWARD_COMBINE){ + attn_output_prob_base_ptr = attn_output_prob + (i * NUM_OF_TOKENS_PER_CHUNK) * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE * NUM_OF_NODES); + } + // Iterate through all token groups within this chunk which assign to this CUDA block. + for(int j = blockIdx.x; j < num_of_token_groups_for_current_chunk; j += NUM_OF_BLOCKS){ + rdma_to_attn_map_load_t rdma_to_attn_map_data = rdma_to_attn_map_load_base_addr[j]; + // Iterate through all dst(output) tokens within this token group. + #pragma unroll + for(int k = 0; k < NUM_OF_TOKENS_PER_GROUP; k++){ + int current_token_id = j * NUM_OF_TOKENS_PER_GROUP + k; + // If the current token is out-of-bound, then just end this load iter. + if(current_token_id >= current_chunk_size){ + break; + } + // Each dst token need to accumulate src tokens from local node's ranks(this part is the same as intra-node reduction), and src tokens from rdma inter-node buffers. + // Accumulate local tokens first, then rdma tokens. + // Accumulator for this dst token. Token must be accumulated in FP32. + float2 acc_token_fp32[NUM_OF_ELEMENT_PER_THREAD]; + // Optional Accumulator for this dst token prob. + // Different node's prob need to be gathered together to output. + // 0 used for local node's prob, [1, NUM_OF_NODES - 1] used for remote node's prob. + float acc_prob[NUM_OF_NODES][NUM_OF_PROB_VEC_ELEMENT_PER_THREAD]; + // Init accumulator. + #pragma unroll + for(int n = 0; n < NUM_OF_ELEMENT_PER_THREAD; n++){ + acc_token_fp32[n].x = 0.0f; + acc_token_fp32[n].y = 0.0f; + } + #pragma unroll + for(int n = 0; n < NUM_OF_NODES; n++){ + #pragma unroll + for(int m = 0; m < NUM_OF_PROB_VEC_ELEMENT_PER_THREAD; m++){ + acc_prob[n][m] = 0.0f; + } + } + + // Check whether this dst token is needed by this(local) node. If not needed, just skip local accumulation. + bool token_needed_by_this_node = *(reinterpret_cast(&rdma_to_attn_map_data) + k); + // If this dst token is needed by this node, load the local src token from shared memory and accumulate them. + if(token_needed_by_this_node){ + // End reduction group flag. + bool last_local_node_src_token = false; + + // Continue loading local src token for this dst token and reduce them to accumulator until all local src token for this dst token have been accumulated. + do{ + // Base address for current token and prob(optional) in shared memory. + __nv_bfloat162* load_token_base_ptr = reinterpret_cast<__nv_bfloat162*>(&smem_buffer_ptr->inter_node_token_G2S_buffer[token_stage][0]); + float* load_prob_base_ptr; + if constexpr(BACKWARD_COMBINE){ + load_prob_base_ptr = &smem_buffer_ptr->inter_node_prob_G2S_buffer[token_stage][0]; + } + + // Wait until current src token ready in shared memory. + if(INTER_NODE_RED_GROUP::warp_rank() == 0){ + if(elect_sync(~0)){ + while(!cuda::ptx::mbarrier_try_wait_parity(&smem_buffer_ptr->inter_node_mbarrier_G2S_buffer[token_stage][0], token_producer_parity)){} + } + } + arrive_and_wait(INTER_NODE_RED_GROUP::size(), 2); + + // Accumulate token and prob(optional). + #pragma unroll + for(int n = 0; n < NUM_OF_ELEMENT_PER_THREAD; n++){ + int element_id = (n * INTER_NODE_RED_GROUP::size()) + INTER_NODE_RED_GROUP::thread_rank(); + __nv_bfloat162 src_data = load_token_base_ptr[element_id]; + float2 src_data_fp32 = __bfloat1622float2(src_data); + acc_token_fp32[n].x += src_data_fp32.x; + acc_token_fp32[n].y += src_data_fp32.y; + } + + if constexpr(BACKWARD_COMBINE){ + #pragma unroll + for(int n = 0; n < NUM_OF_PROB_VEC_ELEMENT_PER_THREAD; n++){ + int element_id = INTER_NODE_RED_GROUP::thread_rank() + n * INTER_NODE_RED_GROUP::size(); + if(element_id < NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE){ + float src_data = load_prob_base_ptr[element_id]; + acc_prob[0][n] += src_data; + } + } + } + + // Check flag for last src token. + last_local_node_src_token = smem_buffer_ptr->inter_node_flag_G2S_buffer[token_stage]; + + // Make sure all warp group have finished loading the token entry and accumulate it to the register accumulator. + // Then notify the producer warp to load next token entry to the shared memory as the shared memory can be reused. + arrive_and_wait(INTER_NODE_RED_GROUP::size(), 2); + if(INTER_NODE_RED_GROUP::warp_rank() == 0){ + if(elect_sync(~0)){ + cuda::ptx::mbarrier_arrive(&smem_buffer_ptr->inter_node_mbarrier_G2S_buffer[token_stage][1]); + } + } + + // Goto next src token entry. + token_stage += 1; + if(token_stage == NUM_OF_STAGES_G2S){ + token_stage = 0; + token_producer_parity ^= 1; + } + + }while(!last_local_node_src_token); + } + + // Then accumulate from rdma inter-node buffers. There are total NUM_OF_NODES - 1 (possible) src tokens from rdma buffer to reduce. + const bool* attn_to_rdma_map_load_addr = attn_to_rdma_map_load_base_addr + (j * NUM_OF_TOKENS_PER_GROUP + k) * (NUM_OF_NODES - 1); + #pragma unroll + for(int n = 1; n < NUM_OF_NODES; n++){ + // The current node been processed. For each chunk id, node_id order is + // (no local_node itself, which is already been accumulated above) local_node - 1, local_node - 2, ......, local_node + 1 and will wrap around. + int node_id = node_rank >= n ? node_rank - n : node_rank + NUM_OF_NODES - n; + // The tile id within the rdma buffers(include attn_to_rdma map) for the current node id. Because these rdma buffers only have NUM_OF_NODES - 1 tile or element. + int rdma_buffer_tile_id = node_id > node_rank ? node_id - 1 : node_id; + // Check wether current dst token need src token from this (remote) node. + if(attn_to_rdma_map_load_addr[rdma_buffer_tile_id]){ + // Base address for current token and prob(optional) in shared memory. + __nv_bfloat162* load_token_base_ptr = reinterpret_cast<__nv_bfloat162*>(&smem_buffer_ptr->inter_node_token_G2S_buffer[token_stage][0]); + float* load_prob_base_ptr; + if constexpr(BACKWARD_COMBINE){ + load_prob_base_ptr = &smem_buffer_ptr->inter_node_prob_G2S_buffer[token_stage][0]; + } + // Wait until current src token ready in shared memory. + if(INTER_NODE_RED_GROUP::warp_rank() == 0){ + if(elect_sync(~0)){ + while(!cuda::ptx::mbarrier_try_wait_parity(&smem_buffer_ptr->inter_node_mbarrier_G2S_buffer[token_stage][0], token_producer_parity)){} + } + } + arrive_and_wait(INTER_NODE_RED_GROUP::size(), 2); + + // Accumulate token and prob(optional). + #pragma unroll + for(int m = 0; m < NUM_OF_ELEMENT_PER_THREAD; m++){ + int element_id = (m * INTER_NODE_RED_GROUP::size()) + INTER_NODE_RED_GROUP::thread_rank(); + __nv_bfloat162 src_data = load_token_base_ptr[element_id]; + float2 src_data_fp32 = __bfloat1622float2(src_data); + acc_token_fp32[m].x += src_data_fp32.x; + acc_token_fp32[m].y += src_data_fp32.y; + } + + if constexpr(BACKWARD_COMBINE){ + #pragma unroll + for(int m = 0; m < NUM_OF_PROB_VEC_ELEMENT_PER_THREAD; m++){ + int element_id = INTER_NODE_RED_GROUP::thread_rank() + m * INTER_NODE_RED_GROUP::size(); + if(element_id < NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE){ + acc_prob[n][m] = load_prob_base_ptr[element_id]; + } + } + } + + // Inter-node token does not need flag. + + // Make sure all warp group have finished loading the token entry and accumulate it to the register accumulator. + // Then notify the producer warp to load next token entry to the shared memory as the shared memory can be reused. + arrive_and_wait(INTER_NODE_RED_GROUP::size(), 2); + if(INTER_NODE_RED_GROUP::warp_rank() == 0){ + if(elect_sync(~0)){ + cuda::ptx::mbarrier_arrive(&smem_buffer_ptr->inter_node_mbarrier_G2S_buffer[token_stage][1]); + } + } + + // Goto next src token entry. + token_stage += 1; + if(token_stage == NUM_OF_STAGES_G2S){ + token_stage = 0; + token_producer_parity ^= 1; + } + } + } + + // Store the dst token back to share memory. + // Because each attn token must have go to TOPK rank in dispatch, so it must have been reduced in combine. So each attn dst token must be written back. + // Base address for current dst token and prob(optional) in shared memory. + __nv_bfloat162* store_token_base_ptr = reinterpret_cast<__nv_bfloat162*>(&smem_buffer_ptr->inter_node_token_S2G_buffer[dst_token_stage][0]); + float* store_prob_base_ptr; + if constexpr(BACKWARD_COMBINE){ + store_prob_base_ptr = &smem_buffer_ptr->inter_node_prob_S2G_buffer[dst_token_stage][0]; + } + + // Let the TMA thread to wait for previously issued TMA S2G operations finish reading this entry. + if(INTER_NODE_RED_GROUP::warp_rank() == 0){ + if(elect_sync(~0)){ + cuda::ptx::cp_async_bulk_wait_group_read(cuda::ptx::n32_t{}); + } + } + // Make sure all threads within the red warp group have wait for previously issued TMA S2G operations finish reading this entry before storing new data to this entry. + arrive_and_wait(INTER_NODE_RED_GROUP::size(), 2); + + // Store the token. + #pragma unroll + for(int n = 0; n < NUM_OF_ELEMENT_PER_THREAD; n++){ + int element_id = (n * INTER_NODE_RED_GROUP::size()) + INTER_NODE_RED_GROUP::thread_rank(); + // Convert accumulated token back to BF16 and store the result back to shared memory token entry. + store_token_base_ptr[element_id] = __float22bfloat162_rn(acc_token_fp32[n]); + } + + // Store the prob(optional). + if constexpr(BACKWARD_COMBINE){ + #pragma unroll + for(int n = 0; n < NUM_OF_NODES; n++){ + int attn_prob_output_node_id = (node_rank - n) >= 0 ? node_rank - n : node_rank + NUM_OF_NODES - n; + int element_base_id = attn_prob_output_node_id * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE); + #pragma unroll + for(int m = 0; m < NUM_OF_PROB_VEC_ELEMENT_PER_THREAD; m++){ + int element_id = INTER_NODE_RED_GROUP::thread_rank() + m * INTER_NODE_RED_GROUP::size(); + if(element_id < NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE){ + store_prob_base_ptr[element_base_id + element_id] = acc_prob[n][m]; + } + } + } + } + + // Make sure the shared memory stored by current thread is visible by async proxy. + cuda::ptx::fence_proxy_async(cuda::ptx::space_shared); + + // Make sure all threads within the red warp group have finished storing the current token entry and making it visible to async proxy. + arrive_and_wait(INTER_NODE_RED_GROUP::size(), 2); + + // Let the TMA thread to issue S2G TMA operations for current token entry. + if(INTER_NODE_RED_GROUP::warp_rank() == 0){ + if(elect_sync(~0)){ + uint16_t* current_token_addr = attn_output_token_base_ptr + (j * NUM_OF_TOKENS_PER_GROUP + k) * HIDDEN_DIM; + // Store the token from shared to global output. + cuda::ptx::cp_async_bulk(cuda::ptx::space_global, + cuda::ptx::space_shared, + reinterpret_cast(current_token_addr), + reinterpret_cast(&smem_buffer_ptr->inter_node_token_S2G_buffer[dst_token_stage][0]), + (uint32_t)(HIDDEN_DIM * sizeof(uint16_t))); + + // Store the prob from shared to global output. + if constexpr(BACKWARD_COMBINE){ + float* current_prob_addr = attn_output_prob_base_ptr + (j * NUM_OF_TOKENS_PER_GROUP + k) * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE * NUM_OF_NODES); + cuda::ptx::cp_async_bulk(cuda::ptx::space_global, + cuda::ptx::space_shared, + reinterpret_cast(current_prob_addr), + reinterpret_cast(&smem_buffer_ptr->inter_node_prob_S2G_buffer[dst_token_stage][0]), + (uint32_t)((NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE * NUM_OF_NODES) * sizeof(float))); + + } + // Commit S2G TMA operations for this dst token into a bulk async copy group. + cuda::ptx::cp_async_bulk_commit_group(); + } + } + + // Goto next dst token entry. + dst_token_stage += 1; + if(dst_token_stage == NUM_OF_STAGES_S2G){ + dst_token_stage = 0; + } + } + } + } + // Because the attn output buffers will only be produced by local combine kernel, not by the combine kernels on other ranks, + // so we only need to wait for local combine kernel to finish writing all token data back to output buffer before we can exit. + // Also, a kernel will be considered completed from CUDA stream's perspective if and only if all the threads are exit and all memory operations(including TMA operations) + // issued by all threads have been completed and made visible to sys scope. + // So the CUDA stream's kernel boundary implicit synchronization should be enough to sync with all TMA operations issued in the combine kernel. + // So we can directly exit w/o any explicit synchronization with TMA operations. +} + +__launch_bounds__(1, 1) +__global__ void device_sync_kernel(uint32_t* intra_node_remote_flags, const uint32_t* expected_flag_value) +{ + // Atomically reduce add 1 to the u32 flag on rank #0 in current NVLink domain. + // Need a strong system-scope red to make sure all ranks from current NVLink domain can see the side effect. + // But no memory fence(i.e. .release) needed since CUDA stream already do that for us. + // red.relaxed.sys.global.add.u32 [a], 1; + asm volatile("red.relaxed.sys.global.add.u32 [%0], %1;" + : + : "l"(__cvta_generic_to_global(intra_node_remote_flags)), "n"(1) + : "memory"); + + // Polling flag value from the u32 flag on rank #0 in current NVLink domain. + // Keep polling until reach the expected value. + uint32_t flag_data = 0; + do{ + flag_data = 0; + // Need a strong system-scope load to observe other ranks' Atomic result. + // But no no memory fence(i.e. .aquired) needed since no memory operation behind this. + asm volatile("ld.relaxed.sys.global.u32 %0, [%1];" + : "=r"(flag_data) + : "l"(__cvta_generic_to_global(intra_node_remote_flags)) + : "memory"); + }while(flag_data != *expected_flag_value); +} + +// This kernel will update expected_rdma_flag_value and expected_intra_node_flag_value in local device memory +// by increasing the expected_rdma_flag_value by 1 and expected_intra_node_flag_value by NUM_OF_RANKS_PER_NODE. +template +__launch_bounds__(1, 1) +__global__ void update_expected_value_kernel(uint64_t* expected_rdma_flag_value, uint32_t* expected_intra_node_flag_value) +{ + if constexpr(NUM_OF_NODES != 1){ + (*expected_rdma_flag_value) += 1; + } + if constexpr(DEVICE_SIDE_SYNC){ + (*expected_intra_node_flag_value) += NUM_OF_RANKS_PER_NODE; + } +} + +template +// Each CUDA block of dispatch kernel has 3 warp groups and has the following layout: +// 1. inter-node warp group(i.e. RDMA N2N warp group, 1 warp, only valid for multinode scenario) 2. intra-node G2S warp group(i.e. NVL G2S warp group, 1 warp). +// 3. intra-node S2G warp group(i.e. NVL S2G warp group, 1 warp). Total 2 or 3 warps per CUDA block/SM. +__launch_bounds__(INTER_NODE_GROUP::size() + INTRA_NODE_G2S_GROUP::size() + INTRA_NODE_S2G_GROUP::size(), 1) +__global__ void dispatch_kernel(const __grid_constant__ dispatch_kernel_param_t param) +{ + // Compile-time check. For now, 1 G2S and 1 S2G warp should be enough. + static_assert(INTRA_NODE_G2S_GROUP::size() == 32, "Dispatch kernel only support 1 G2S warp currently."); + static_assert(INTRA_NODE_S2G_GROUP::size() == 32, "Dispatch kernel only support 1 S2G warp currently."); + // The token and its properties should meet size and alignment requirement. + // Currently, we use TMA to copy prob data, which need at least 16B size and alignment(which requires expert per node to be multiple of 4). + // We need to add padding or not using TMA for prob, if we want to support other scenario. + static_assert((NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE * sizeof(float)) % 16 == 0, "Currently, expert per node must be multiple of 4(So the prob for each token is multiple of 16B) to make TMA work."); + // If FP8 token is used, HIDDEN_DIM must be multiple of 512 to make scaling factor multiple of 16B to make TMA work. + static_assert(((HIDDEN_DIM / 128) * sizeof(float)) % 16 == 0, "Currently, scaling factor per token must be multiple of 16B."); + + + // Shared memory used over 48KB, should use dynamic shared memory. + extern __shared__ uint8_t smem_bytes[]; + using cur_smem_t = dispatch_kernel_dynamic_shared_memory_buffer_t; + cur_smem_t* smem_buffer_ptr = reinterpret_cast(smem_bytes); + + // Let first thread of each CUDA block initialize the mbarrier. + if(threadIdx.x == 0){ + for(int i = 0; i < NUM_OF_STAGES; i++){ + // Initialize mbarrier + cuda::ptx::mbarrier_init(&smem_buffer_ptr->intra_node_mbarrier_buffer[i][0], 1); + cuda::ptx::mbarrier_init(&smem_buffer_ptr->intra_node_mbarrier_buffer[i][1], 1); + } + // Make mbarriers initialization visible to async proxy(TMA). + cuda::ptx::fence_proxy_async(cuda::ptx::space_shared); + } + + // Make sure all the warps wait for mbarriers to be initialized before producing/consuming data. + __syncthreads(); + + // Now warps can become specialized. + // The input warp group data type must match the warp groups layout. + // To prevent compiler generate pointless comparison warning. + int threadIdx_x_int = (int)threadIdx.x; + if(threadIdx_x_int < INTER_NODE_GROUP::size()){ + }else if(threadIdx_x_int < INTER_NODE_GROUP::size() + INTRA_NODE_G2S_GROUP::size()){ + // Intra-node G2S warp groups. + G2S_warp_group_device_function + + (param.node_rank, param.num_of_tokens_per_rank, param.expected_rdma_flag_value, param.rdma_to_attn_map, param.attn_input_token, param.attn_input_prob, param.attn_input_token_scaling_factor, param.rdma_inter_node_group_token, + param.rdma_inter_node_group_prob, param.rdma_inter_node_group_scaling_factor, param.rdma_inter_node_group_flags, smem_buffer_ptr); + }else if(threadIdx_x_int < INTER_NODE_GROUP::size() + INTRA_NODE_G2S_GROUP::size() + INTRA_NODE_S2G_GROUP::size()){ + // Intra-node S2G warp groups. + S2G_warp_group_device_function + + (param.local_rank, param.node_rank, param.num_of_tokens_per_rank, param.rdma_to_attn_map, param.sparse_to_dense_map, param.expert_output_token, param.expert_output_prob, + param.expert_output_scaling_factor, smem_buffer_ptr); + }else{ + // Too many threads, should not goes here. + } +} + +template +// Each CUDA block of combine kernel has 5 warp groups and has the following layout: +// 1. intra-node reduction warp group(4 warps, only valid for multinode scenario). 2. inter-node reduction warp group(4 warps). +// 3. intra-node G2S warp group(1 warp, only valid for multinode scenario). 4. inter-node G2S warp group(1 warp). 5. inter-node N2N rdma warp group(1 warp, only valid for multinode scenario). +// Total 5 or 11 warps per CUDA block/SM. +__launch_bounds__(INTRA_NODE_RED_GROUP::size() + INTER_NODE_RED_GROUP::size() + INTRA_NODE_G2S_GROUP::size() + INTER_NODE_G2S_GROUP::size() + INTER_NODE_RDMA_GROUP::size(), 1) +__global__ void combine_kernel(const __grid_constant__ combine_kernel_param_t param) +{ + // Compile-time check. For now, 1 G2S and 1 S2G warp should be enough. + static_assert(INTER_NODE_G2S_GROUP::size() == 32, "Combine kernel only support 1 INTER_NODE_G2S warp currently."); + // The token and its properties should meet size and alignment requirement. + // Currently, we use TMA to copy prob data, which need at least 16B size and alignment(which requires expert per node to be multiple of 4). + // We need to add padding or not using TMA for prob, if we want to support other scenario. + static_assert((NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE * sizeof(float)) % 16 == 0, "Currently, expert per node must be multiple of 4(So the prob for each token is multiple of 16B) to make TMA work."); + static_assert(MAX_NUM_OF_TOKENS_PER_RANK % NUM_OF_TOKENS_PER_CHUNK == 0, "MAX_NUM_OF_TOKENS_PER_RANK must be multiple of NUM_OF_TOKENS_PER_CHUNK."); + constexpr int MAX_NUM_OF_CHUNKS_PER_RANK = MAX_NUM_OF_TOKENS_PER_RANK / NUM_OF_TOKENS_PER_CHUNK; + + // Shared memory used over 48KB, should use dynamic shared memory. + extern __shared__ uint8_t smem_bytes[]; + using cur_smem_t = combine_kernel_dynamic_shared_memory_buffer_t + ; + cur_smem_t* smem_buffer_ptr = reinterpret_cast(smem_bytes); + + // Let first thread of each CUDA block initialize the mbarrier. + if(threadIdx.x == 0){ + for(int i = 0; i < NUM_OF_STAGES_G2S; i++){ + // Initialize mbarrier + if constexpr(NUM_OF_NODES != 1){ + cuda::ptx::mbarrier_init(&smem_buffer_ptr->intra_node_mbarrier_G2S_buffer[i][0], 1); + cuda::ptx::mbarrier_init(&smem_buffer_ptr->intra_node_mbarrier_G2S_buffer[i][1], 1); + } + cuda::ptx::mbarrier_init(&smem_buffer_ptr->inter_node_mbarrier_G2S_buffer[i][0], 1); + cuda::ptx::mbarrier_init(&smem_buffer_ptr->inter_node_mbarrier_G2S_buffer[i][1], 1); + } + if constexpr(NUM_OF_NODES != 1){ + // Initialize mbarrier + for(int i = 0; i < NUM_OF_NODES - 1; i++){ + for(int j = 0; j < MAX_NUM_OF_CHUNKS_PER_RANK; j++){ + cuda::ptx::mbarrier_init(&smem_buffer_ptr->intra_node_to_rdma_mbarrier_buffer[i][j], 1); + } + } + } + // Make mbarriers initialization visible to async proxy(TMA). + cuda::ptx::fence_proxy_async(cuda::ptx::space_shared); + } + + // Make sure all the warps wait for mbarriers to be initialized before producing/consuming data. + __syncthreads(); + + // Now warps can become specialized. + // The input warp group data type must match the warp groups layout. + // To prevent compiler generate pointless comparison warning. + int threadIdx_x_int = (int)threadIdx.x; + if(threadIdx_x_int < INTRA_NODE_RED_GROUP::size()){ + // Intra-node reduction warp group. + if constexpr(NUM_OF_NODES != 1){ + intra_node_red_warp_group_device_function + + (param.node_rank, param.num_of_tokens_per_rank, param.rdma_to_attn_map, param.rdma_intra_node_red_token, param.rdma_intra_node_red_prob, smem_buffer_ptr); + } + }else if(threadIdx_x_int < INTRA_NODE_RED_GROUP::size() + INTER_NODE_RED_GROUP::size()){ + // Inter-node reduction warp group. + inter_node_red_warp_group_device_function + + (param.node_rank, param.num_of_tokens_per_rank, param.rdma_to_attn_map, param.attn_to_rdma_map, param.attn_output_token, param.attn_output_prob, smem_buffer_ptr); + }else if(threadIdx_x_int < INTRA_NODE_RED_GROUP::size() + INTER_NODE_RED_GROUP::size() + INTRA_NODE_G2S_GROUP::size()){ + // Intra-node G2S warp group. + if constexpr(NUM_OF_NODES != 1){ + intra_node_G2S_warp_group_device_function + + (param.node_rank, param.num_of_tokens_per_rank, param.rdma_to_attn_map, param.sparse_to_dense_map, param.expert_input_token, param.expert_input_prob, smem_buffer_ptr); + } + }else if(threadIdx_x_int < INTRA_NODE_RED_GROUP::size() + INTER_NODE_RED_GROUP::size() + INTRA_NODE_G2S_GROUP::size() + INTER_NODE_G2S_GROUP::size()){ + // Inter-node G2S warp group. + inter_node_G2S_warp_group_device_function + + (param.node_rank, param.num_of_tokens_per_rank, param.expected_rdma_flag_value, param.rdma_to_attn_map, param.attn_to_rdma_map, param.sparse_to_dense_map, param.expert_input_token, param.expert_input_prob, + param.rdma_inter_node_group_token, param.rdma_inter_node_group_prob, param.rdma_inter_node_group_flags, smem_buffer_ptr); + }else if(threadIdx_x_int < INTRA_NODE_RED_GROUP::size() + INTER_NODE_RED_GROUP::size() + INTRA_NODE_G2S_GROUP::size() + INTER_NODE_G2S_GROUP::size() + INTER_NODE_RDMA_GROUP::size()){ + // Inter-node rdma warp group. + }else{ + // Too many threads, should not goes here. + } +} + +template +__launch_bounds__(NUM_THREADS_PER_BLOCK, 1) +__global__ void scan(const bool* input_routing_map, + tmp_state_t* tmp, + int32_t* sparse_to_dense_map, + bool* rdma_to_attn_map, + bool* attn_to_rdma_map, + int32_t* num_of_tokens_for_experts, + bool* local_expert_routing_map, + const int node_rank, + const int local_rank, + const int num_of_tokens_per_rank) +{ + // Calculate the warps per block. + constexpr int WARP_SIZE = 32; + constexpr int NUM_OF_WARPS_PER_BLOCK = NUM_THREADS_PER_BLOCK / WARP_SIZE; + + // Calculate total threads count. + constexpr int NUM_OF_TOTAL_THREADS = NUM_THREADS_PER_BLOCK * NUM_OF_BLOCKS; + + // Calculate the number of tokens belong to each CUDA block, warp and thread. + // We assign 1 token(row in routing map) to 1 thread. + const int num_of_total_attn_tokens = num_of_tokens_per_rank * NUM_OF_RANKS_PER_NODE * NUM_OF_NODES; + //static_assert(NUM_OF_TOTAL_ATTN_TOKENS % NUM_OF_TOTAL_THREADS == 0, "NUM_OF_TOTAL_ATTN_TOKENS must be multiple of NUM_OF_TOTAL_THREADS"); + const int num_of_tokens_per_thread = ((num_of_total_attn_tokens - 1) / NUM_OF_TOTAL_THREADS) + 1; + const int num_of_tokens_per_warp = num_of_tokens_per_thread * WARP_SIZE; + const int num_of_tokens_per_block = num_of_tokens_per_warp * NUM_OF_WARPS_PER_BLOCK; + // The rdma_to_attn_map need to be paded to multiple of rdma_to_attn_map_load_t per node. + // The largest size of rdma_to_attn_map_load_t allowed in all Hybrid-EP kernels are 16B(16 bools), so need to be paded to 16B per node. + // That means the size of rdma_to_attn_map should be rdma_to_attn_map_size_per_node * NUM_OF_NODES. + const int rdma_to_attn_map_size_per_node = (((num_of_tokens_per_rank - 1) / 16) + 1) * 16; + + // For each token(row in routing map), calculate how many bytes need to be loaded from the routing map and how to load them. + static_assert(sizeof(bool) == 1, "Bool is not 1 byte???"); + constexpr int NUM_OF_BYTES_TO_LOAD_FOR_EACH_TOKEN = NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE; + using copy_t = Copy_t; + static_assert(NUM_OF_BYTES_TO_LOAD_FOR_EACH_TOKEN % sizeof(copy_t) == 0, "NUM_OF_BYTES_TO_LOAD_FOR_EACH_TOKEN and copy_t mismatch"); + constexpr int ROUTING_MAP_LOAD_ITER = NUM_OF_BYTES_TO_LOAD_FOR_EACH_TOKEN / sizeof(copy_t); + + // For each token, calculate how many bytes need to be store to sparse_to_dense_map. + constexpr int NUM_OF_BYTES_TO_STORE_FOR_EACH_TOKEN = sizeof(int32_t) * NUM_OF_RANKS_PER_NODE; + using write_t = Copy_t; + static_assert(NUM_OF_BYTES_TO_STORE_FOR_EACH_TOKEN % sizeof(write_t) == 0, "NUM_OF_BYTES_TO_STORE_FOR_EACH_TOKEN and write_t mismatch"); + constexpr int S2D_MAP_STORE_ITER = NUM_OF_BYTES_TO_STORE_FOR_EACH_TOKEN / sizeof(write_t); + + // How to convert per-expert routing info to per-rank routing info. We support any number of expert per rank. + using expert_to_rank_t = Reduce_t; + static_assert(NUM_OF_EXPERTS_PER_RANK % sizeof(expert_to_rank_t) == 0, "NUM_OF_EXPERTS_PER_RANK and expert_to_rank_t mismatch"); + constexpr int EXPERTS_TO_RANK_REDUCE_ITER = NUM_OF_EXPERTS_PER_RANK / sizeof(expert_to_rank_t); + + // How to convert per-rank routing info to per-node routing info. We support any number of ranks per node(nvl domain). + //using rank_to_node_t = Reduce_t; + //static_assert(NUM_OF_RANKS_PER_NODE % sizeof(rank_to_node_t) == 0, "NUM_OF_RANKS_PER_NODE and rank_to_node_t mismatch"); + //constexpr int RANKS_TO_NODE_REDUCE_ITER = NUM_OF_RANKS_PER_NODE / sizeof(rank_to_node_t); + + // How do a warp save per-rank routing info back to shared memory. What's the max number of elements does each thread save back. + constexpr int NUM_OF_RANKS_PER_THREAD = ((NUM_OF_RANKS_PER_NODE - 1) / WARP_SIZE) + 1; + + // Sum of per-rank routing info of all warps within the block. + __shared__ int32_t warp_token_routing_map_sum[NUM_OF_WARPS_PER_BLOCK][NUM_OF_RANKS_PER_NODE]; + // Sum of previous blocks' per-rank routing info. + __shared__ int32_t previous_block_sum[NUM_OF_RANKS_PER_NODE]; + + // We assign contiguous tokens called chunk to each CUDA block, each CUDA block get the same size of chunk. + int block_starting_token = blockIdx.x * num_of_tokens_per_block; + // warp id and lane id. + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + // We assign contiguous tokens called sub-chunk to each warp within a CUDA block, each warp within a CUDA block get the same size of sub-chunk. + int warp_starting_token = block_starting_token + warp_id * num_of_tokens_per_warp; + // Within a sub-chunk, we assign tokens to thread in a interleave pattern. So each thread process a token each time and each warp sum a tile of 32 tokens each time. + int thread_starting_token = warp_starting_token + lane_id; + + // Step 0: Each warp sum the sub-chunk assigned to them and store the sum back to shared memory. + // All warps within all CTA attend this step. + // Also, some tokens need per-node info which store to rdma_to_attn_map, also processed here. + + // Sum of per-rank token routing map within a thread. + int32_t token_routing_map_sum[NUM_OF_RANKS_PER_NODE]; + #pragma unroll + for(int i = 0; i < NUM_OF_RANKS_PER_NODE; i++){ + token_routing_map_sum[i] = 0; + } + + //#pragma unroll + for(int i = 0; i < num_of_tokens_per_thread; i++){ + // The global token id conditions for current token. + int current_token_id = thread_starting_token + i * WARP_SIZE; + // If the current token is out-of-bound, then just end summing tokens assigned to this thread. + if(current_token_id >= num_of_total_attn_tokens){ + break; + } + int current_token_node_rank = current_token_id / (num_of_tokens_per_rank * NUM_OF_RANKS_PER_NODE); + int current_token_local_rank = (current_token_id % (num_of_tokens_per_rank * NUM_OF_RANKS_PER_NODE)) / num_of_tokens_per_rank; + int current_token_local_id = current_token_id % num_of_tokens_per_rank; + // If the token belongs to the inter-node group. + // We need to calculate the per-node routing info and save back to rdma_to_attn_map. + bool per_node_routing_info = (current_token_local_rank == local_rank); + int current_token_rdma_to_attn_map_id = current_token_node_rank * rdma_to_attn_map_size_per_node + current_token_local_id; + // Global routing map load base addr for current token. + const copy_t* routing_map_load_base_addr = reinterpret_cast(input_routing_map + + current_token_id * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE * NUM_OF_NODES) + + node_rank * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE)); + + // Load the routing map for current token. + bool token_routing_map[NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE]; + #pragma unroll + for(int j = 0; j < ROUTING_MAP_LOAD_ITER; j++){ + *(reinterpret_cast(token_routing_map) + j) = routing_map_load_base_addr[j]; + } + + // Convert the routing map to per rank routing info and accumulate to accumulator. + // Also convert the per rank routing info to per node routing info. + bool token_needed_by_this_node = false; + #pragma unroll + for(int j = 0; j < NUM_OF_RANKS_PER_NODE; j++){ + bool token_needed_by_this_rank = false; + #pragma unroll + for(int k = 0; k < EXPERTS_TO_RANK_REDUCE_ITER; k++){ + int current_expert_to_rank_t_id = j * EXPERTS_TO_RANK_REDUCE_ITER + k; + expert_to_rank_t reduction_data = *(reinterpret_cast(token_routing_map) + current_expert_to_rank_t_id); + if(reduction_data != (expert_to_rank_t)0){ + token_needed_by_this_rank = true; + break; + } + } + if(token_needed_by_this_rank){ + token_routing_map_sum[j] += 1; + token_needed_by_this_node = true; + } + } + + // Save the per node routing info back to rdma_to_attn_map if needed. + if(per_node_routing_info){ + rdma_to_attn_map[current_token_rdma_to_attn_map_id] = token_needed_by_this_node; + } + } + + // Each warp sum the per-rank routing info from all its threads. + #pragma unroll + for(int i = 0; i < NUM_OF_RANKS_PER_NODE; i++){ + int dst_tid = i % WARP_SIZE; + int dst_id = i / WARP_SIZE; + int32_t temp_sum = __reduce_add_sync(~0, token_routing_map_sum[i]); + if(lane_id == dst_tid){ + token_routing_map_sum[dst_id] = temp_sum; + } + } + + // Each warp store the sum of per-rank routing info back to shared memory. + #pragma unroll + for(int i = 0; i < NUM_OF_RANKS_PER_THREAD; i++){ + int element_id = i * WARP_SIZE + lane_id; + if(element_id < NUM_OF_RANKS_PER_NODE){ + warp_token_routing_map_sum[warp_id][element_id] = token_routing_map_sum[i]; + } + } + + // Sync within a CUDA block to make sure all warps have produced the per-rank sum data to the shared memory before any thread can consume them to produce CUDA block level's sum data. + __syncthreads(); + + // Step 1: Communication between CUDA blocks. Each CUDA block's threads need to produce and store the current block's per-rank sum data to global memory, + // and load and accumulate previous blocks' per-rank sum data and save the result to shared memory. + + // Each thread within a CUDA block calculate the CUDA block level sum for a single rank at a time. + for(int i = threadIdx.x; i < NUM_OF_RANKS_PER_NODE; i += NUM_THREADS_PER_BLOCK){ + int32_t rank_acc = 0; + // Calculate the sum of current rank within this CUDA block. + #pragma unroll + for(int j = 0; j < NUM_OF_WARPS_PER_BLOCK; j++){ + rank_acc += warp_token_routing_map_sum[j][i]; + } + + // Store the sum of current rank within this CUDA block to global memory for later scan opeartions. + // Strong(atomic) store is needed to be visible to strong(atomic) load from other blocks. + tmp_state_t* tmp_dst = &tmp[blockIdx.x * NUM_OF_RANKS_PER_NODE + i]; + tmp_state_t tmp_data{PRIV_SUM, rank_acc}; + uint64_t data = *reinterpret_cast(&tmp_data); + asm volatile("st.relaxed.gpu.global.b64 [%0], %1;" + : + : "l"(__cvta_generic_to_global(tmp_dst)), "l"(data) + : "memory"); + } + + // Each thread within a CUDA block load previous blocks' block level sum for a single rank at a time. + for(int i = threadIdx.x; i < NUM_OF_RANKS_PER_NODE; i += NUM_THREADS_PER_BLOCK){ + int32_t previous_block_sum_for_current_rank = 0; + for(int j = 0; j < blockIdx.x; j++){ + tmp_state_t tmp_data{EMPTY, 0}; + tmp_state_t* tmp_src = &tmp[j * NUM_OF_RANKS_PER_NODE + i]; + do{ + // Load previous blocks' per-rank sum from global memory. + // Strong(atomic) load is needed to view strong(atomic) store from other blocks. + uint64_t data = 0; + asm volatile("ld.relaxed.gpu.global.b64 %0, [%1];" + : "=l"(data) + : "l"(__cvta_generic_to_global(tmp_src)) + : "memory"); + tmp_data = *reinterpret_cast(&data); + }while(tmp_data.state != PRIV_SUM); + previous_block_sum_for_current_rank += tmp_data.value; + } + previous_block_sum[i] = previous_block_sum_for_current_rank; + } + + // Sync within a CUDA block to make sure all previous blocks' per-rank sum have been produced to the shared memory before any thread can consume them in scan operation. + __syncthreads(); + + // Step 2: Each warp scan the sub-chunk assigned to them(the same sub-chunk as step 0) and produce sparse_to_dense_map, local_expert_routing_map and num_of_tokens_for_experts. + int32_t previous_token_sum[NUM_OF_RANKS_PER_NODE]; + + // Each warp load the previous blocks' per-rank sum from shared memory. + #pragma unroll + for(int i = 0; i < NUM_OF_RANKS_PER_THREAD; i++){ + int element_id = i * WARP_SIZE + lane_id; + if(element_id < NUM_OF_RANKS_PER_NODE){ + previous_token_sum[i] = previous_block_sum[element_id]; + } + } + + // Each warp accumulate the previous warps' per-rank sum from shared memory. + #pragma unroll + for(int i = 0; i < NUM_OF_RANKS_PER_THREAD; i++){ + int element_id = i * WARP_SIZE + lane_id; + if(element_id < NUM_OF_RANKS_PER_NODE){ + for(int j = 0; j < warp_id; j++){ + previous_token_sum[i] += warp_token_routing_map_sum[j][element_id]; + } + } + } + + // Each warp broadcast the accumulated previous per-rank routing info to all its threads. + // Exact reverse of warp reduce operation. + #pragma unroll + for(int i = NUM_OF_RANKS_PER_NODE - 1; i >= 0 ; i--){ + int src_tid = i % WARP_SIZE; + int src_id = i / WARP_SIZE; + previous_token_sum[i] = __shfl_sync(~0, previous_token_sum[src_id], src_tid); + } + + // Each warp scan all the tiles within its sub-chunk. + //#pragma unroll + for(int i = 0; i < num_of_tokens_per_thread; i++){ + // The global token id conditions for current token. + int current_token_id = thread_starting_token + i * WARP_SIZE; + // If the current token is out-of-bound, then just end scanning tokens assigned to this thread. + if(current_token_id >= num_of_total_attn_tokens){ + break; + } + int current_token_node_rank = current_token_id / (num_of_tokens_per_rank * NUM_OF_RANKS_PER_NODE); + int current_token_local_rank = (current_token_id % (num_of_tokens_per_rank * NUM_OF_RANKS_PER_NODE)) / num_of_tokens_per_rank; + int current_token_local_id = current_token_id % num_of_tokens_per_rank; + + // Since some thread may end scanning earlier, we need to calculate the active mask and number of active thread. + uint32_t active_mask = __activemask(); + int active_thread_count = __popc(active_mask); + + // Global routing map load base addr for current token. + const copy_t* routing_map_load_base_addr = reinterpret_cast(input_routing_map + + current_token_id * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE * NUM_OF_NODES) + + node_rank * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE)); + + // Load the routing map for current token. + bool token_routing_map[NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE]; + #pragma unroll + for(int j = 0; j < ROUTING_MAP_LOAD_ITER; j++){ + *(reinterpret_cast(token_routing_map) + j) = routing_map_load_base_addr[j]; + } + + // Convert the routing map to per rank routing info for current token, + // then produce the per-rank final exclusive scan within the warp for this tile. + int32_t final_ex_scan[NUM_OF_RANKS_PER_NODE]; + #pragma unroll + for(int j = 0; j < NUM_OF_RANKS_PER_NODE; j++){ + int32_t temp_scan; + bool token_needed_by_this_rank = false; + #pragma unroll + for(int k = 0; k < EXPERTS_TO_RANK_REDUCE_ITER; k++){ + int current_expert_to_rank_t_id = j * EXPERTS_TO_RANK_REDUCE_ITER + k; + expert_to_rank_t reduction_data = *(reinterpret_cast(token_routing_map) + current_expert_to_rank_t_id); + if(reduction_data != (expert_to_rank_t)0){ + token_needed_by_this_rank = true; + break; + } + } + if(token_needed_by_this_rank){ + temp_scan = 1; + }else{ + temp_scan = 0; + } + + // Each warp perform a inclusive scan from all threads(lanes). + for(int k = 1; k < active_thread_count; k *= 2){ + int32_t temp = __shfl_up_sync(active_mask, temp_scan, (unsigned)k); + if(lane_id >= k){ + temp_scan += temp; + } + } + + // The inclusive scan from last lane is the sum of this rank of this tile. Need to accumulate that for later tiles. + int32_t temp_sum = __shfl_sync(active_mask, temp_scan, active_thread_count - 1); + + // Make scan exclusive. + int32_t exclusive_scan = __shfl_up_sync(active_mask, temp_scan, 1); + temp_scan = (lane_id >= 1) ? exclusive_scan : 0; + + // Calculate the final exclusive scan for current token. -1 represent that the current rank does not need the current token. + final_ex_scan[j] = token_needed_by_this_rank ? previous_token_sum[j] + temp_scan : -1; + + // Accumulate the sum to accumulator. + previous_token_sum[j] += temp_sum; + + // Each thread save local routing map for this token of the local rank to local_expert_routing_map if this token is needed by the local rank. + if(j == local_rank && token_needed_by_this_rank){ + expert_to_rank_t* local_expert_routing_map_store_base_addr = reinterpret_cast(local_expert_routing_map + (final_ex_scan[j] * NUM_OF_EXPERTS_PER_RANK)); + #pragma unroll + for(int k = 0; k < EXPERTS_TO_RANK_REDUCE_ITER; k++){ + int current_expert_to_rank_t_id = j * EXPERTS_TO_RANK_REDUCE_ITER + k; + local_expert_routing_map_store_base_addr[k] = *(reinterpret_cast(token_routing_map) + current_expert_to_rank_t_id); + } + } + + // The thread that processing the global last token save the final sum for current rank to num_of_tokens_for_experts. + if(current_token_id == num_of_total_attn_tokens - 1 && j == local_rank){ + *num_of_tokens_for_experts = previous_token_sum[j]; + } + } + + // Save final exclusive scan of this token back to sparse_to_dense_map if current token needed. + if(current_token_local_rank == local_rank){ + // sparse_to_dense_map store base addr for current token. + write_t* sparse_to_dense_map_store_base_addr = reinterpret_cast(sparse_to_dense_map + + (current_token_node_rank * num_of_tokens_per_rank + current_token_local_id) * NUM_OF_RANKS_PER_NODE); + #pragma unroll + for(int j = 0; j < S2D_MAP_STORE_ITER; j++){ + sparse_to_dense_map_store_base_addr[j] = *(reinterpret_cast(final_ex_scan) + j); + } + } + } +} + +template< + // Hidden size of a token. + int HIDDEN_DIM, + // The max num of attn tokens output by a rank/GPU. Used by combine API. + int MAX_NUM_OF_TOKENS_PER_RANK, + // Number of ranks/GPU per NVLink domain. + int NUM_OF_RANKS_PER_NODE, + // Number of total NVLink domain, i.e. the size of RDMA domain. + int NUM_OF_NODES, + // Number of experts running on each rank/GPU. Hybrid-ep support multiple experts running on a single rank/GPU. + int NUM_OF_EXPERTS_PER_RANK> +class hybrid_ep{ +public: + + // Ctor, don't need for now. + /*hybrid_ep(int local_rank, int node_rank, MPI_Comm comm): + local_rank_(local_rank), + node_rank_(node_rank), + comm_(comm) {}*/ + + // Dtor, don't need for now. + //~hybrid_ep() {} + + // Processing metadata. Calculate routing info needed by dispatch and combine operations. + // input_routing_map: IO: input, dtype: bool, shape: [NUM_OF_TOKENS_PER_RANK * NUM_OF_RANKS_PER_NODE * NUM_OF_NODES, NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE * NUM_OF_NODES]. + // Routing map which contain global routing info from all tokens to all expert. Allgather is needed before passing the routing map to this API. + // preprocessing_tmp: IO: output/input, dtype: tmp_state_t, shape: [NUM_OF_BLOCKS for preprocessing kernel, NUM_OF_RANKS_PER_NODE]. + // The temp buffer needed by the preprocessing kernel. + // sparse_to_dense_map: IO: output, dtype: int32_t, shape: [NUM_OF_TOKENS_PER_RANK * NUM_OF_NODES, NUM_OF_RANKS_PER_NODE]. + // The routing info needed by NVL warps(i.e. intra-node communication warps) during both dispatch and combine operation. Remains the same in a trainning iteration(FW+BP). + // rdma_to_attn_map: IO: output, dtype: bool, shape: [NUM_OF_TOKENS_PER_RANK padded to 16 * NUM_OF_NODES] + // The routing info mainly needed by RDMA warps during the combine operation. Remains the same in a trainning iteration(FW+BP). + // attn_to_rdma_map: IO: output, dtype: bool, shape: [NUM_OF_TOKENS_PER_RANK, NUM_OF_NODES - 1]. + // The routing info mainly needed by RDMA warps during the dispatch operation. Remains the same in a trainning iteration(FW+BP). + // num_of_tokens_for_experts: IO: output, dtype: int32_t, shape: [1]. + // The total size of expert buffer on this rank(in number of tokens), according to the global routing map. If there are multiple expert on this rank, each token will only appear once. + // Remains the same in a trainning iteration(FW+BP). + // local_expert_routing_map: IO: output, dtype: bool, shape: [at least num_of_tokens_for_experts, NUM_OF_EXPERTS_PER_RANK]. + // The per-expert routing info for all tokens within the expert buffer of this rank. It is used by later layer to routing the tokens to different experts on this rank. + // Remains the same in a trainning iteration(FW+BP). + template + static void metadata_preprocessing(const bool* input_routing_map, + tmp_state_t* preprocessing_tmp, + int32_t* sparse_to_dense_map, + bool* rdma_to_attn_map, + bool* attn_to_rdma_map, + int32_t* num_of_tokens_for_experts, + bool* local_expert_routing_map, + const int node_rank, + const int local_rank, + const int num_of_tokens_per_rank, + cudaStream_t stream) + { + // Gather routing map from all ranks to all ranks. + // All ranks should have the same global routing map after this communication. + // It is a synchronous communication. + /*MPI_CHECK(MPI_Allgather(reinterpret_cast(input_routing_map), + NUM_OF_TOKENS_PER_RANK * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE * NUM_OF_NODES), + MPI_BYTE, + reinterpret_cast(global_routing_map_), + NUM_OF_TOKENS_PER_RANK * (NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE * NUM_OF_NODES), + MPI_BYTE, + comm_));*/ + + // Init preprocessing_tmp buffers. + constexpr size_t preprocessing_tmp_sz = NUM_OF_BLOCKS * NUM_OF_RANKS_PER_NODE * sizeof(tmp_state_t); + CUDA_CHECK(cudaMemsetAsync(preprocessing_tmp, 0, preprocessing_tmp_sz, stream)); + + // Launch the preprocessing kernel to process the global routing map. + scan + <<>> + (input_routing_map, preprocessing_tmp, sparse_to_dense_map, rdma_to_attn_map, attn_to_rdma_map, num_of_tokens_for_experts, local_expert_routing_map, node_rank, local_rank, num_of_tokens_per_rank); + + // Check if there is any CUDA error. + CUDA_CHECK(cudaGetLastError()); + } + + // Dispatch tokens or token gradient to expert MLPs. + template + static void dispatch(dispatch_kernel_param_t param, cudaStream_t stream) + { + // The warp groups data type for dispatch kernel, must match the warp groups layout required by the dispatch kernel. + using INTER_NODE_GROUP = warp_group<0, 0>; + using INTRA_NODE_G2S_GROUP = warp_group<1, 0>; + using INTRA_NODE_S2G_GROUP = warp_group<1, 1>; + // The shared memory needed by the dispatch kernel. + using dispatch_kernel_smem_t = dispatch_kernel_dynamic_shared_memory_buffer_t; + // The dispatch kernel to be launched. + const auto dispatch_kernel_ptr = dispatch_kernel; + + // Configure dynamic shared memory for the dispatch kernel. + constexpr int SMEM_SIZE = sizeof(dispatch_kernel_smem_t); + // The dispatch kernel only need to be configured once. + static bool config_completed = false; + if(!config_completed){ + // If the dynamic shared memory requested is too large, we may need to modify the carveout. + //CUDA_CHECK(cudaFuncSetAttribute(dispatch_kernel_ptr, cudaFuncAttributePreferredSharedMemoryCarveout, 100)); + CUDA_CHECK(cudaFuncSetAttribute(dispatch_kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_SIZE)); + config_completed = true; + } + + // Launch update_expected_value_kernel to update expected flag value. + update_expected_value_kernel + <<<1, 1, 0, stream>>>(param.expected_rdma_flag_value, param.expected_intra_node_flag_value); + + // Launch dispatch kernel. + constexpr int BLOCK_DIM = INTER_NODE_GROUP::size() + INTRA_NODE_G2S_GROUP::size() + INTRA_NODE_S2G_GROUP::size(); + dispatch_kernel_ptr<<>>(param); + + // Launch device sync kernel if needed. + if constexpr(DEVICE_SIDE_SYNC){ + device_sync_kernel<<<1, 1, 0, stream>>>(param.intra_node_write_completion_flags, param.expected_intra_node_flag_value); + } + + // Check if there is any CUDA error. + CUDA_CHECK(cudaGetLastError()); + } + + // Combine tokens or token gradient from expert MLPs. + template + static void combine(combine_kernel_param_t param, cudaStream_t stream) + { + // The warp groups data type for combine kernel, must match the warp groups layout required by the combine kernel. + using INTRA_NODE_RED_GROUP = warp_group<0, 0>; + using INTER_NODE_RED_GROUP = warp_group<4, 0>; + using INTRA_NODE_G2S_GROUP = warp_group<0, 4>; + using INTER_NODE_G2S_GROUP = warp_group<1, 4>; + using INTER_NODE_RDMA_GROUP = warp_group<0, 5>; + + // The shared memory needed by the combine kernel. + using combine_kernel_smem_t = combine_kernel_dynamic_shared_memory_buffer_t; + // The combine kernel to be launched. + const auto combine_kernel_ptr = combine_kernel; + + // Configure dynamic shared memory for the combine kernel. + constexpr int SMEM_SIZE = sizeof(combine_kernel_smem_t); + // The combine kernel only need to be configured once. + static bool config_completed = false; + if(!config_completed){ + // If the dynamic shared memory requested is too large, we may need to modify the carveout. + //CUDA_CHECK(cudaFuncSetAttribute(combine_kernel_ptr, cudaFuncAttributePreferredSharedMemoryCarveout, 100)); + CUDA_CHECK(cudaFuncSetAttribute(combine_kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_SIZE)); + config_completed = true; + } + + // Launch update_expected_value_kernel to update expected flag value. + update_expected_value_kernel + <<<1, 1, 0, stream>>>(param.expected_rdma_flag_value, param.expected_intra_node_flag_value); + + // Launch device sync kernel if needed. + if constexpr(DEVICE_SIDE_SYNC){ + device_sync_kernel<<<1, 1, 0, stream>>>(param.intra_node_write_completion_flags, param.expected_intra_node_flag_value); + } + + // Launch combine kernel. + constexpr int BLOCK_DIM = INTRA_NODE_RED_GROUP::size() + INTER_NODE_RED_GROUP::size() + INTRA_NODE_G2S_GROUP::size() + INTER_NODE_G2S_GROUP::size() + INTER_NODE_RDMA_GROUP::size(); + combine_kernel_ptr<<>>(param); + + // Check if there is any CUDA error. + CUDA_CHECK(cudaGetLastError()); + } + + + + /*private: + // Rank within the current node/host. + int local_rank_; + // Rank for the current node/host. + int node_rank_; + + // MPI Communicator for out-of-bond communication. + // This is used to gather routing map from all other ranks, so the communicator should contains all ranks. + MPI_Comm comm_; + + // The global routing map which collected from all other ranks, remains the same in a trainning iteration(FW+BP). + // dtype: bool, shape: [NUM_OF_TOKENS_PER_RANK * NUM_OF_RANKS_PER_NODE * NUM_OF_NODES, NUM_OF_EXPERTS_PER_RANK * NUM_OF_RANKS_PER_NODE * NUM_OF_NODES]. + bool* global_routing_map_; + // The temp buffer needed by the preprocessing kernel. + // dtype: tmp_state_t, shape: [NUM_OF_BLOCKS for preprocessing kernel, NUM_OF_RANKS_PER_NODE]. + tmp_state_t* preprocessing_tmp_; + // The routing info needed by NVL warps(i.e. intra-node communication warps) during both dispatch and combine operation. + // Remains the same in a trainning iteration(FW+BP). + // dtype: int32_t, shape: [NUM_OF_TOKENS_PER_RANK * NUM_OF_NODES, NUM_OF_RANKS_PER_NODE]. + int32_t* sparse_to_dense_map_; + // The routing info mainly needed by RDMA warps during the combine operation. + // Remains the same in a trainning iteration(FW+BP). + // dtype: bool, shape: [NUM_OF_TOKENS_PER_RANK padded to 16 * NUM_OF_NODES]. + bool* rdma_to_attn_map_; + // The routing info mainly needed by RDMA warps during the dispatch operation. + // Remains the same in a trainning iteration(FW+BP). + // dtype: bool, shape: [NUM_OF_TOKENS_PER_RANK, NUM_OF_NODES - 1]. + bool* attn_to_rdma_map_; + // The total size of expert input/output buffer on this rank(in number of tokens), according to the global routing map. + // If there are multiple expert on this rank, each token will only appear once. + // Remains the same in a trainning iteration(FW+BP). + int32_t* num_of_tokens_for_experts_; + // The per-expert routing info for all tokens within the expert input/output buffer of this rank. + // It is used by later layer to routing the tokens to different experts on this rank. + // Remains the same in a trainning iteration(FW+BP). + // dtype: bool, shape: [at least num_of_tokens_for_experts_, NUM_OF_EXPERTS_PER_RANK]. + bool* local_expert_routing_map_;*/ +}; +} // namespace hybrid_ep + diff --git a/csrc/kernels/hybrid_ep_backend_configs.hpp b/csrc/kernels/hybrid_ep_backend_configs.hpp new file mode 100644 index 00000000..fe19e0be --- /dev/null +++ b/csrc/kernels/hybrid_ep_backend_configs.hpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved +#pragma once +#include + +enum class TOKEN_DATA_TYPE { UINT16, UINT8 }; + +constexpr int HIDDEN_DIM = 7168; // HIDDEN_DIM = 512xN, N in [0,1,2,....] +constexpr int MAX_NUM_OF_TOKENS_PER_RANK = 4096; // NUM_OF_TOKENS_PER_RANK = NUM_OF_TOKENS_PER_CHUNK_DISPATCH_APIxN, N in [0,1,2,....] +constexpr int NUM_OF_EXPERTS_PER_RANK = 8; // (NUM_OF_EXPERTS_PER_RANKxNUM_OF_RANKS_PER_NODE) = 4xN + +constexpr int NUM_OF_NODES = 1; // Note: this is the number of nvlink domains +constexpr int NUM_OF_RANKS_PER_NODE = 32; // Note: this is the number of ranks in each NVLink domain + +// Multi-node NVLink Staff +constexpr bool USE_MNNVLINK = true; + +// Metadata-preprocessing API Config +constexpr int NUM_THREADS_PER_BLOCK_PREPROCESSING_API = 128; +constexpr int NUM_OF_BLOCKS_PREPROCESSING_API = 32; // how much SM will be used for preprocessing + +// Dispatch API Config +constexpr int NUM_OF_STAGES_DISPATCH_API = 12; // fix to 12 +constexpr int NUM_OF_TOKENS_PER_CHUNK_DISPATCH_API = 128; // fix to 128 +constexpr int NUM_OF_BLOCKS_DISPATCH_API = 32; // how much SM will be used for dispatch +constexpr bool FORWARD_DISPATCH_API = true; +constexpr bool DEVICE_SIDE_SYNC_DISPATCH_API = true; + + + +// Combine API Config +// Combine API specific configuration. +constexpr int NUM_OF_STAGES_G2S_COMBINE_API = 12; +constexpr int NUM_OF_STAGES_S2G_COMBINE_API = 2; +constexpr int NUM_OF_TOKENS_PER_CHUNK_COMBINE_API = 128; +constexpr int NUM_OF_TOKENS_PER_GROUP_COMBINE_API = 4; +constexpr int NUM_OF_BLOCKS_COMBINE_API = 32; // how much SM will be used for combine +constexpr int NUM_OF_ADDITIONAL_IN_FLIGHT_S2G_COMBINE_API = 2; +constexpr bool BACKWARD_COMBINE_API = false; +constexpr bool DEVICE_SIDE_SYNC_COMBINE_API = true; + +struct HybridEpConfigInstance { + /* + * Hybrid-ep Config + */ + int hidden_dim; + int num_of_tokens_per_rank; + int max_num_of_tokens_per_rank; + int num_of_experts_per_rank; + int num_of_ranks_per_node; + int num_of_nodes; + + /* + * Metadata-preprocessing API Config + */ + int num_of_threads_per_block_preprocessing_api; + int num_of_blocks_preprocessing_api; + + /* + * Dispatch API Config + */ + TOKEN_DATA_TYPE token_data_type; + int num_of_stages_dispatch_api; + int num_of_tokens_per_chunk_dispatch_api; + int num_of_blocks_dispatch_api; + bool forward_dispatch_api; + bool device_side_sync_dispatch_api; + + /* + * Combine API Config + */ + int num_of_stages_g2s_combine_api; + int num_of_stages_s2g_combine_api; + int num_of_tokens_per_chunk_combine_api; + int num_of_tokens_per_group_combine_api; + int num_of_blocks_combine_api; + int num_of_additional_in_flight_s2g_combine_api; + bool backward_combine_api; + bool device_side_sync_combine_api; +}; diff --git a/deep_ep/__init__.py b/deep_ep/__init__.py index 7fb801fc..d18e97d6 100644 --- a/deep_ep/__init__.py +++ b/deep_ep/__init__.py @@ -2,6 +2,8 @@ from .utils import EventOverlap from .buffer import Buffer +from .hybrid_ep_buffer import HybridEpBuffer # noinspection PyUnresolvedReferences from deep_ep_cpp import Config +from hybrid_ep_cpp import HybridEpConfigInstance diff --git a/deep_ep/hybrid_ep_buffer.py b/deep_ep/hybrid_ep_buffer.py new file mode 100644 index 00000000..9c1c6160 --- /dev/null +++ b/deep_ep/hybrid_ep_buffer.py @@ -0,0 +1,247 @@ +# SPDX-License-Identifier: MIT +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved +import torch +import hybrid_ep_cpp + +class HybridEpBuffer: + def __init__( + self, + group: torch.distributed.ProcessGroup, + hidden_dim: int, + max_num_of_tokens_per_rank: int, + num_local_experts: int, + num_of_experts: int, + use_fp8: bool = False, + num_of_ranks_per_node: int = 32, + num_sms_preprocessing_api: int = 32, + num_sms_dispatch_api: int = 32, + num_sms_combine_api: int = 32, + ): + self.group = group + self.rank = self.group.rank() + self.group_size = self.group.size() + + assert ( + self.group_size % num_of_ranks_per_node == 0 + ), f"The number of ranks {self.group_size} should be divisible by the number of ranks per node {num_of_ranks_per_node}." + assert ( + self.group_size > 1 + ), f"The hybrid-ep kernel should be used with at least 2 ranks, but got {self.group_size}." + self.num_of_ranks_per_node = num_of_ranks_per_node + + # Local rank: the active rank in the nvlink domain. + self.local_rank = self.rank % self.num_of_ranks_per_node + # Node rank: the active rank between the nvlink domains. + self.node_rank = self.rank // self.num_of_ranks_per_node + # The number of nodes. + self.num_of_nodes = self.group_size // self.num_of_ranks_per_node + + self.hidden_dim = hidden_dim + self.max_num_of_tokens_per_rank = max_num_of_tokens_per_rank + self.num_local_experts = num_local_experts + self.num_of_experts = num_of_experts + self.use_fp8 = use_fp8 + self.num_sms_preprocessing_api = num_sms_preprocessing_api + self.num_sms_dispatch_api = num_sms_dispatch_api + self.num_sms_combine_api = num_sms_combine_api + self.init_config() + self.init_buffer() + + def init_config( + self, + num_of_threads_per_block_preprocessing_api: int = 512, + num_of_stages_dispatch_api: int = 12, + num_of_tokens_per_chunk_dispatch_api: int = 128, + device_side_sync_dispatch_api: bool = True, + num_of_stages_g2s_combine_api: int = 12, + num_of_stages_s2g_combine_api: int = 2, + num_of_tokens_per_chunk_combine_api: int = 128, + num_of_tokens_per_group_combine_api: int = 4, + num_of_additional_in_flight_s2g_combine_api: int = 2, + device_side_sync_combine_api: bool = True, + ): + """ + Initialize the HybridEpConfigInstance for the hybrid-ep kernel. + We can contoal the detailed setting of the hybrid-ep kernel. + In common case, no need to change the default setting. + """ + config = hybrid_ep_cpp.HybridEpConfigInstance() + + # Initialize the HybridEpConfigInstance + # Hybrid-ep Config + config.hidden_dim = self.hidden_dim + config.max_num_of_tokens_per_rank = self.max_num_of_tokens_per_rank + config.num_of_tokens_per_rank = self.max_num_of_tokens_per_rank # init to max_num_of_tokens_per_rank, will be updated in dispatch + config.num_of_experts_per_rank = self.num_local_experts + config.num_of_ranks_per_node = self.num_of_ranks_per_node + config.num_of_nodes = self.num_of_nodes + # Metadata-preprocessing API Config + config.num_of_threads_per_block_preprocessing_api = ( + num_of_threads_per_block_preprocessing_api + ) + config.num_of_blocks_preprocessing_api = self.num_sms_preprocessing_api + # Dispatch API Config + if self.use_fp8: + # The fp8 data is communicated in the uint8 format. + config.token_data_type = hybrid_ep_cpp.UINT8 + else: + # The bf16 data is communicated in the uint16 format. + config.token_data_type = hybrid_ep_cpp.UINT16 + config.num_of_stages_dispatch_api = num_of_stages_dispatch_api + config.num_of_tokens_per_chunk_dispatch_api = ( + num_of_tokens_per_chunk_dispatch_api + ) + config.num_of_blocks_dispatch_api = self.num_sms_dispatch_api + config.device_side_sync_dispatch_api = device_side_sync_dispatch_api + # Combine API Config + config.num_of_stages_g2s_combine_api = num_of_stages_g2s_combine_api + config.num_of_stages_s2g_combine_api = num_of_stages_s2g_combine_api + config.num_of_tokens_per_chunk_combine_api = num_of_tokens_per_chunk_combine_api + config.num_of_tokens_per_group_combine_api = num_of_tokens_per_group_combine_api + config.num_of_blocks_combine_api = self.num_sms_combine_api + config.num_of_additional_in_flight_s2g_combine_api = ( + num_of_additional_in_flight_s2g_combine_api + ) + config.device_side_sync_combine_api = device_side_sync_combine_api + + self.config = config + + def init_buffer(self): + """ + Initialize the buffer for the hybrid-ep kernel. + Creates the C++ buffer (which allocates buffers) and exchanges IPC addresses. + """ + assert self.config is not None, "Please initialize the config first." + # Create C++ buffer - this will allocate all buffers during construction + self.runtime = hybrid_ep_cpp.HybridEpBuffer( + self.config, self.rank, self.group_size, self.num_of_ranks_per_node + ) + + # Exchange IPC addresses using C++ distributed communication + self.runtime.exchange_ipc_address(self.group) + + def dispatch( + self, + tensor: torch.Tensor, + scaling_factor: torch.Tensor = None, + topk_idx: torch.Tensor = None, + topk_weights: torch.Tensor = None, + routing_map: torch.Tensor = None, + num_of_tokens_for_experts: int = -1, + handle: tuple = None, + async_mode: bool = False, + ): + """ + Dispatch the data to the experts. + + Forward direction: + dispatch_in_forward -> local_permute -> epxert_mlp -> local_unpermute -> combine_in_forward + + Backward direction: + combine_in_backward <- local_unpermute -> expert_mlp -> local_permute -> dispatch_in_backward + """ + num_of_tokens = tensor.shape[0] + # Update the num_of_tokens_per_rank, both dispatch and combine will use this value + self.runtime.update_num_of_tokens_per_rank(num_of_tokens) + routing_map_as_input_and_probs_as_output = routing_map is not None + if routing_map is not None: + assert routing_map.dtype == torch.bool + else: + # Generate the routing map and the probs according to the topk_idx and topk_weights. + assert topk_idx is not None + routing_map = torch.zeros(num_of_tokens, self.num_of_experts, device="cuda", dtype=torch.bool) + routing_map = routing_map.scatter(1, topk_idx.to(torch.int64), 1).bool() + if topk_weights is not None: + probs = torch.zeros(num_of_tokens, self.num_of_experts, device="cuda", dtype=torch.float32) + probs = probs.scatter(1, topk_idx.to(torch.int64), topk_weights) + else: + probs = None + + global_routing_map = torch.empty( + num_of_tokens * self.group_size, + self.num_of_experts, + device="cuda", + dtype=torch.bool, + ) + assert ( + handle is not None or routing_map is not None + ), "The handle and routing_map should be both None" + # If the handle is not provided, we need to generate the handle using the preprocessing kernel. + if handle is None: + torch.distributed.all_gather_into_tensor( + global_routing_map, routing_map, self.group + ) + ( + sparse_to_dense_map, + rdma_to_attn_map, + attn_to_rdma_map, + num_of_tokens_for_experts_tensor, + local_expert_routing_map, + ) = self.runtime.metadata_preprocessing( + routing_map=global_routing_map, + node_rank=self.node_rank, + local_rank=self.local_rank, + ) + # Create the handle using the data generated by the preprocessing kernel. + handle = ( + sparse_to_dense_map, + rdma_to_attn_map, + attn_to_rdma_map, + ) + if not async_mode: + num_of_tokens_for_experts = num_of_tokens_for_experts_tensor.item() + else: + ( + sparse_to_dense_map, + rdma_to_attn_map, + attn_to_rdma_map, + ) = handle + num_of_tokens_for_experts_tensor = None + local_expert_routing_map = None + if not async_mode: + assert ( + num_of_tokens_for_experts >= 0 + ), "The num_of_tokens_for_experts should be provided." + + dispatched_token, dispatched_probs, dispatched_scaling_factor = ( + self.runtime.dispatch( + hidden=tensor, + probs=probs, + scaling_factor=scaling_factor, + sparse_to_dense_map=sparse_to_dense_map, + rdma_to_attn_map=rdma_to_attn_map, + attn_to_rdma_map=attn_to_rdma_map, + num_of_tokens_for_experts=( + num_of_tokens_for_experts if not async_mode else -1 + ), + with_probs=probs is not None, + ) + ) + + return ( + dispatched_token, + dispatched_probs, + dispatched_scaling_factor, + num_of_tokens_for_experts_tensor, + local_expert_routing_map, + handle, + ) + + def combine( + self, tensor: torch.Tensor, probs: torch.Tensor = None, handle: tuple = None + ): + """ + Combine the data from the experts. + Do not require preprocessing, but the handle is necessary. + """ + assert handle is not None, "The handle is necessary for combine." + sparse_to_dense_map, rdma_to_attn_map, attn_to_rdma_map = handle + combined_token, combined_probs = self.runtime.combine( + hidden=tensor, + probs=probs, + sparse_to_dense_map=sparse_to_dense_map, + rdma_to_attn_map=rdma_to_attn_map, + attn_to_rdma_map=attn_to_rdma_map, + with_probs=probs is not None, + ) + return combined_token, combined_probs diff --git a/setup.py b/setup.py index 63ce332b..19724497 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,51 @@ def get_nvshmem_host_lib_name(base_dir): return file.name raise ModuleNotFoundError('libnvshmem_host.so not found') +def get_extension_hybrid_ep_cpp(): + current_dir = os.path.dirname(os.path.abspath(__file__)) + USE_GB200 = os.getenv("USE_GB200", "False") + enable_multinode = os.getenv("HYBRID_EP_MULTINODE", "0") != "0" + assert not enable_multinode, "Multinode is not supported yet" + + # Basic compile arguments + compile_args = { + "cxx": [ + "-std=c++17", + "-O3", + ], + "nvcc": [ + "-std=c++17", + "-Xcompiler", + "-fPIC", + "--expt-relaxed-constexpr", + "-O3", + "--shared", + ], + } + if USE_GB200 == "True": + compile_args["nvcc"].append("-DUSE_GB200") + + sources = [ + os.path.join(current_dir, "csrc/hybrid_ep.cu"), + ] + include_dirs = [ + os.path.join(current_dir, "csrc"), + ] + extra_link_args = [ + "-lnvtx3interop", + ] + libraries = ["cuda"] + + extension_hybrid_ep_cpp = CUDAExtension( + "hybrid_ep_cpp", + sources=sources, + include_dirs=include_dirs, + libraries=libraries, + extra_compile_args=compile_args, + extra_link_args=extra_link_args, + ) + + return extension_hybrid_ep_cpp if __name__ == '__main__': disable_nvshmem = False @@ -120,7 +165,8 @@ def get_nvshmem_host_lib_name(base_dir): sources=sources, extra_compile_args=extra_compile_args, extra_link_args=extra_link_args - ) + ), + get_extension_hybrid_ep_cpp() ], cmdclass={ 'build_ext': BuildExtension diff --git a/tests/test_mnnvlink_hybridep.py b/tests/test_mnnvlink_hybridep.py new file mode 100644 index 00000000..9e786b32 --- /dev/null +++ b/tests/test_mnnvlink_hybridep.py @@ -0,0 +1,274 @@ +import argparse +import time +import torch +import torch.distributed as dist +import os +import deep_ep + +from utils import TorchRef, bench, bench_kineto + +HIDDEN_DIM = 7168 +MAX_NUM_OF_TOKENS_PER_RANK = 4096 +# NUM_TOKENS_PER_RANK should equal or less than MAX_NUM_OF_TOKENS_PER_RANK +NUM_TOKENS_PER_RANK = 4096 +NUM_LOCAL_EXPERTS = 8 +NUM_OF_RANKS_PER_NODE = 32 +TOPK = 8 +NUM_OF_EXPERTS = NUM_LOCAL_EXPERTS * NUM_OF_RANKS_PER_NODE +ITERATIONS = 100 +SEED = 42 +torch.manual_seed(SEED) +torch.cuda.manual_seed(SEED) +torch.cuda.manual_seed_all(SEED) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + + +def init_dist(local_rank: int, num_local_ranks: int): + # NOTES: you may rewrite this function with your own cluster settings + num_nodes = int(os.getenv('WORLD_SIZE', 1)) + node_rank = int(os.getenv('RANK', 0)) + # local_size = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + + # Call the init process. + ip = os.getenv('MASTER_ADDR', '127.0.0.1') + port = int(os.getenv('MASTER_PORT', '8361')) + + dist.init_process_group( + backend="nccl", + init_method=f'tcp://{ip}:{port}', + world_size=num_nodes * num_local_ranks, + rank=node_rank * num_local_ranks + local_rank, + ) + torch.set_default_dtype(torch.bfloat16) + torch.set_default_device("cuda") + torch.cuda.set_device(local_rank) + + return dist.get_rank(), dist.get_world_size(), dist.new_group(list(range(num_local_ranks * num_nodes))) + + +def init_tensor( + hidden_dim: int, + seq_len: int, + topk: int, + num_of_experts: int, + use_fp8: bool = False, +): + if use_fp8: + hidden = torch.randint( + low=0, + high=256, + size=(seq_len, hidden_dim), + device="cuda", + dtype=torch.uint8, + ) + else: + hidden = torch.randn(seq_len, hidden_dim, device="cuda", dtype=torch.bfloat16) + probs = torch.zeros(seq_len, num_of_experts, device="cuda", dtype=torch.float32) + topk_idx = torch.zeros(seq_len, topk, device="cuda", dtype=torch.int64) + topk_weights = torch.zeros(seq_len, topk, device="cuda", dtype=torch.float32) + scaling_factor = torch.randn( + seq_len, hidden_dim // 128, device="cuda", dtype=torch.float32 + ) + + routing_map = torch.zeros(seq_len, num_of_experts, device="cuda", dtype=torch.bool) + + for i in range(seq_len): + selected_experts = torch.randperm(num_of_experts, device="cuda")[:topk] + topk_idx[i, :] = selected_experts.to(torch.int64) + topk_weights[i, :] = torch.rand(topk, device="cuda", dtype=torch.float32) + # selected_experts = [0,8,16,24,32,40,48,56] # force balanced routing for testing + routing_map[i, selected_experts] = True + probs[i, selected_experts] = topk_weights[i, :] + + return hidden, probs, scaling_factor, routing_map, topk_idx, topk_weights + + +def test_intra_node_correctness(buffer: deep_ep.HybridEpBuffer, ref: TorchRef, use_fp8: bool): + hidden, probs, scaling_factor, routing_map, topk_idx, topk_weights = init_tensor( + hidden_dim=HIDDEN_DIM, + seq_len=NUM_TOKENS_PER_RANK, + topk=TOPK, + num_of_experts=NUM_OF_EXPERTS, + use_fp8=use_fp8, + ) + + # Dispatch correctness check + for with_probs in [True, False]: + # The check for the dispatch + dispatched_hidden_ref, dispatched_probs_ref, dispatched_scaling_factor_ref = ( + ref.dispatch( + hidden, routing_map, probs if with_probs else None, scaling_factor + ) + ) + ( + dispatched_hidden, + dispatched_probs, + dispatched_scaling_factor, + num_of_tokens_for_experts, + local_expert_routing_map, + handle, + ) = buffer.dispatch( + tensor=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights if with_probs else None + ) + + assert torch.allclose(dispatched_hidden_ref, dispatched_hidden) + if dispatched_probs is not None and dispatched_probs_ref is not None: + start, end = ref._local_expert_range() + masked_probs = torch.zeros_like(dispatched_probs) + masked_probs[:, start:end] = dispatched_probs[:, start:end] + assert torch.allclose(dispatched_probs_ref, dispatched_probs[:, start:end]) + dispatched_probs = masked_probs + if ( + dispatched_scaling_factor is not None + and dispatched_scaling_factor_ref is not None + ): + assert torch.allclose( + dispatched_scaling_factor_ref, dispatched_scaling_factor + ) + + # expert the local routing map from the local routing map + num_of_tokens_for_experts = num_of_tokens_for_experts.cpu() + local_expert_routing_map = local_expert_routing_map[ + : num_of_tokens_for_experts.item() + ] + # Simulate the permute and expert and unpermute. The expert is identity op + copy_times = local_expert_routing_map.sum(dim=1) + dispatched_hidden = dispatched_hidden.to( + torch.bfloat16 + ) # The combine only support bf16 + hidden_to_combine = dispatched_hidden * copy_times.unsqueeze(1) + probs_to_combine = dispatched_probs + + # The check for the combine + combined_hidden, combined_probs = buffer.combine( + hidden_to_combine, probs_to_combine, handle + ) + + # The reconstucted value should be TOPK times larger than the input hidden + combined_hidden = combined_hidden / TOPK + + assert torch.allclose( + combined_hidden, hidden.to(torch.bfloat16), atol=2e-5, rtol=1e-2 + ) + if combined_probs is not None and probs is not None: + assert torch.allclose(combined_probs, probs, atol=2e-5, rtol=1e-2) + + if torch.distributed.get_rank() == 0: + print("Correctness check passed") + + +def test_intra_node_benchmark(buffer: deep_ep.HybridEpBuffer, group: dist.ProcessGroup, use_fp8: bool, nsys_profile: bool): + hidden, probs, scaling_factor, routing_map, topk_idx, topk_weights = init_tensor( + hidden_dim=HIDDEN_DIM, + seq_len=NUM_TOKENS_PER_RANK, + topk=TOPK, + num_of_experts=NUM_OF_EXPERTS, + use_fp8=use_fp8, + ) + + # warmup + for _ in range(10): + dispatched_hidden, dispatched_probs, _, _, _, handle = ( + buffer.dispatch(tensor=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights) + ) + if dispatched_hidden.dtype == torch.uint8: + dispatched_hidden_bf16 = dispatched_hidden.to(torch.bfloat16) + else: + dispatched_hidden_bf16 = dispatched_hidden + dispatched_probs = None + _, _ = buffer.combine(dispatched_hidden_bf16, dispatched_probs, handle) + + rank = torch.distributed.get_rank() + fp8_factor = (1 + 4 / 128) / 2 + dispatch_bf16_nvl_recv_bytes = dispatched_hidden.numel() * 2 + combine_bf16_nvl_send_bytes = dispatch_bf16_nvl_recv_bytes + + dispatch_args = {'tensor': hidden, 'scaling_factor': scaling_factor, 'topk_idx': topk_idx, 'topk_weights': topk_weights} + t = bench(lambda: buffer.dispatch(**dispatch_args))[0] + nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if hidden.dtype == torch.uint8 else dispatch_bf16_nvl_recv_bytes + print(f'[rank {rank}] HybridEP dispatch torch API ({"FP8" if hidden.dtype == torch.uint8 else "BF16"}): ' + f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), t: {t * 1e6:.2f} us, nvl_recv_bytes: {nvl_recv_bytes / 1e6:.2f} MB', flush=True) + + dispatched_hidden, dispatched_probs, _, _, _, handle= ( + buffer.dispatch(tensor=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights) + ) + combine_args = {'tensor': dispatched_hidden_bf16, 'probs': dispatched_probs, 'handle': handle} + t = bench(lambda: buffer.combine(**combine_args))[0] + print(f'[rank {rank}] HybridEP combine torch API: ' + f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), t: {t * 1e6:.2f} us, combine_send_bytes: {combine_bf16_nvl_send_bytes / 1e6:.2f} MB', flush=True) + + + if not nsys_profile: + # noinspection PyShadowingNames + def test_func(): + dispatched_hidden, dispatched_probs, _, _, _, handle = ( + buffer.dispatch(tensor=hidden, scaling_factor=scaling_factor, topk_idx=topk_idx, topk_weights=topk_weights) + ) + if dispatched_hidden.dtype == torch.uint8: + dispatched_hidden_bf16 = dispatched_hidden.to(torch.bfloat16) + else: + dispatched_hidden_bf16 = dispatched_hidden + dispatched_probs = None + _, _ = buffer.combine(dispatched_hidden_bf16, dispatched_probs, handle) + + group.barrier() + dispatch_t, combine_t = bench_kineto(test_func, + kernel_names=('dispatch_kernel', 'combine_kernel'), barrier_comm_profiling=True, + suppress_kineto_output=True) + print(f'[rank {rank}] HybridEP dispatch kernel ({"FP8" if hidden.dtype == torch.uint8 else "BF16"}): {nvl_recv_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | ' + f'HybridEP combine kernel: {combine_bf16_nvl_send_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', flush=True) + else: + torch.cuda.profiler.start() + with torch.cuda.nvtx.range(f"hybrid-ep dispatch ({"FP8" if hidden.dtype == torch.uint8 else "BF16"})"): + if rank == 0: + print(f"profile hybrid-ep dispatch ({"FP8" if hidden.dtype == torch.uint8 else "BF16"})", flush=True) + dispatch_args = {'tensor': hidden, 'scaling_factor': scaling_factor, 'topk_idx': topk_idx, 'topk_weights': topk_weights} + bench(lambda: buffer.dispatch(**dispatch_args)) + with torch.cuda.nvtx.range("hybrid-ep combine"): + if rank == 0: + print(f"profile hybrid-ep combine", flush=True) + combine_args = {'tensor': dispatched_hidden_bf16, 'probs': dispatched_probs, 'handle': handle} + bench(lambda: buffer.combine(**combine_args)) + time.sleep(1) + torch.cuda.profiler.stop() + + +def test_main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + buffer = deep_ep.HybridEpBuffer( + group=group, + hidden_dim=HIDDEN_DIM, + max_num_of_tokens_per_rank=MAX_NUM_OF_TOKENS_PER_RANK, + num_local_experts=NUM_LOCAL_EXPERTS, + num_of_experts=NUM_OF_EXPERTS, + use_fp8=args.use_fp8, + num_of_ranks_per_node=NUM_OF_RANKS_PER_NODE, + ) + + # Initialize the torchRef + ref = TorchRef( + ep_group=group, + num_of_experts=NUM_OF_EXPERTS, + num_of_ranks_per_node=NUM_OF_RANKS_PER_NODE, + ) + + # Test body + test_intra_node_correctness(buffer, ref, args.use_fp8) + test_intra_node_benchmark(buffer, group, args.use_fp8, args.nsys_profile) + + # Destroy + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Test intranode EP kernels') + parser.add_argument('--num-processes', type=int, default=4, + help='Number of processes to spawn (default: 4)') + parser.add_argument('--use-fp8', action='store_true', default=False, + help='Use fp8 in dispatch or not (default: False)') + parser.add_argument('--nsys-profile', action='store_true', default=False, + help='benchmark with nsys profile or not (default: False)') + args = parser.parse_args() + torch.multiprocessing.spawn(test_main, args=(args.num_processes, args), nprocs=args.num_processes) diff --git a/tests/utils.py b/tests/utils.py index 4cdcd876..c2b70f75 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -234,3 +234,127 @@ def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppr def hash_tensor(t: torch.Tensor): return t.view(torch.int).sum().item() + + +class TorchRef: + def __init__( + self, + ep_group: torch.distributed.ProcessGroup, + num_of_experts: int, + num_of_ranks_per_node: int, + ): + self.ep_group = ep_group + self.group_rank = torch.distributed.get_rank(self.ep_group) + self.group_size = torch.distributed.get_world_size(self.ep_group) + self.num_of_ranks_per_node = num_of_ranks_per_node + # at least one node + self.num_of_nodes = max(1, self.group_size // self.num_of_ranks_per_node) + self.local_rank = self.group_rank % self.num_of_ranks_per_node + + self.num_of_experts = num_of_experts + self.num_local_experts = num_of_experts // self.num_of_ranks_per_node + + def _local_expert_range(self): + start = self.local_rank * self.num_local_experts + end = start + self.num_local_experts # [start, end) + return start, end + + def _select_local_tokens( + self, + global_hidden: torch.Tensor, + global_probs: torch.Tensor, + global_scaling_factor: torch.Tensor | None, + global_routing_map: torch.Tensor, + ): + start, end = self._local_expert_range() + row_mask = global_routing_map[:, start:end].any(dim=1) + + dispatched_hidden = global_hidden[row_mask] + dispatched_probs = ( + global_probs[row_mask, start:end] if global_probs is not None else None + ) + dispatched_scaling_factor = ( + global_scaling_factor[row_mask] + if global_scaling_factor is not None + else None + ) + + return ( + dispatched_hidden, + dispatched_probs, + dispatched_scaling_factor, + ) + + def dispatch( + self, + hidden: torch.Tensor, + routing_map: torch.Tensor, + probs: torch.Tensor = None, + scaling_factor: torch.Tensor = None, + ): + seq_len, hidden_dim = hidden.shape + # Cache sizes for combine + self._last_seq_len = seq_len + self._last_hidden_dim = hidden_dim + # gather the routing map + global_routing_map = torch.empty( + seq_len * self.group_size, + self.num_of_experts, + device=hidden.device, + dtype=torch.bool, + ) + torch.distributed.all_gather_into_tensor( + global_routing_map, routing_map, self.ep_group + ) + + # dispatch the hidden tensor + global_hidden = torch.empty( + seq_len * self.group_size, + hidden_dim, + device=hidden.device, + dtype=hidden.dtype, + ) + torch.distributed.all_gather_into_tensor(global_hidden, hidden, self.ep_group) + + # dispatch the probs tensor + if probs is not None: + global_probs = torch.empty( + seq_len * self.group_size, + self.num_of_experts, + device=probs.device, + dtype=probs.dtype, + ) + torch.distributed.all_gather_into_tensor(global_probs, probs, self.ep_group) + else: + global_probs = None + + # dispatch the scaling factor tensor + if scaling_factor is not None: + global_scaling_factor = torch.empty( + seq_len * self.group_size, + hidden_dim // 128, + device=scaling_factor.device, + dtype=scaling_factor.dtype, + ) + torch.distributed.all_gather_into_tensor( + global_scaling_factor, scaling_factor, self.ep_group + ) + else: + global_scaling_factor = None + + ( + dispatched_hidden, + dispatched_probs, + dispatched_scaling_factor, + ) = self._select_local_tokens( + global_hidden=global_hidden, + global_probs=global_probs, + global_scaling_factor=global_scaling_factor, + global_routing_map=global_routing_map, + ) + + return ( + dispatched_hidden, + dispatched_probs, + dispatched_scaling_factor, + )