38 changes: 36 additions & 2 deletions csrc/deep_ep.cpp
@@ -1094,8 +1094,10 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Te
Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
const std::optional<torch::Tensor>& x_global_scale,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0,
bool use_nvfp4, bool use_ue8m0_for_sf,
bool async, bool return_recv_hook) {
#ifndef DISABLE_NVSHMEM
EP_HOST_ASSERT(low_latency_mode);
@@ -1140,8 +1142,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
stream_wait(launch_stream, compute_stream);

// Allocate packed tensors
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16));
auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, use_nvfp4 ? hidden / 2 : hidden},
x.options().dtype(use_nvfp4 ? torch::kUInt8 : (use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16)));
auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
@@ -1151,6 +1153,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
void* packed_recv_x_scales_ptr = nullptr;
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");

EP_HOST_ASSERT(not (use_fp8 and use_nvfp4));
if (use_fp8) {
// TODO: support unaligned cases
EP_HOST_ASSERT(hidden % 512 == 0);
@@ -1164,6 +1167,35 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
}
packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
} else if (use_nvfp4) {
constexpr int kNumPerChannels = 16;
constexpr int NUM_SF_ELEMS_PER_PACK = 4;
constexpr int mTileSize_dim_0 = 32;
constexpr int mTileSize_dim_1 = 4;
constexpr int mTileSize = mTileSize_dim_0 * mTileSize_dim_1;

EP_HOST_ASSERT(hidden % kNumPerChannels == 0);
auto l = num_local_experts;
auto m = num_ranks * num_max_dispatch_tokens_per_rank;
auto rm = (m + 127) / 128;
auto rk = (hidden + (kNumPerChannels * NUM_SF_ELEMS_PER_PACK) - 1) / (kNumPerChannels * NUM_SF_ELEMS_PER_PACK);
// The physical layout is (l, rm, rk, 32, 4, 4).
if (use_ue8m0_for_sf) {
packed_recv_x_scales = torch::empty({l, rm, rk, 32, 4, 4},
torch::dtype(torch::kInt).device(torch::kCUDA));
} else {
packed_recv_x_scales = torch::empty({l, rm, rk, 32, 4, 4},
torch::dtype(torch::kFloat8_e4m3fn).device(torch::kCUDA));
}
// After permute, the logical shape is (32, 4, rm, 4, rk, l)
packed_recv_x_scales = packed_recv_x_scales.value().permute({3, 4, 1, 5, 2, 0});

// The physical layout is (l, m, k // 2).
// After permute, the logical shape is (m, k // 2, l).
packed_recv_x = packed_recv_x.permute({1, 2, 0});

packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
EP_HOST_ASSERT(packed_recv_x_scales_ptr != nullptr);
}

// Kernel launch
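
For concreteness, the shape arithmetic in the use_nvfp4 branch above can be checked in isolation. The standalone sketch below uses hypothetical sizes (hidden = 7168, 8 local experts, 1024 receive slots); only the constants and rounding formulas are taken from the code above.

// Standalone check of the NVFP4 scale-factor shape arithmetic.
// The concrete sizes below are hypothetical example values, not from the PR.
#include <cstdio>

int main() {
    constexpr int kNumPerChannels = 16;       // hidden elements covered by one scale factor
    constexpr int NUM_SF_ELEMS_PER_PACK = 4;  // scale factors packed per 32-bit word

    const int num_local_experts = 8;   // hypothetical
    const int hidden = 7168;           // hypothetical
    const int m = 1024;                // num_ranks * num_max_dispatch_tokens_per_rank (hypothetical)

    const int l = num_local_experts;
    const int rm = (m + 127) / 128;    // ceil(m / 128) = 8
    const int rk = (hidden + kNumPerChannels * NUM_SF_ELEMS_PER_PACK - 1) /
                   (kNumPerChannels * NUM_SF_ELEMS_PER_PACK);  // ceil(7168 / 64) = 112

    std::printf("scales physical: (%d, %d, %d, 32, 4, 4)\n", l, rm, rk);
    // permute({3, 4, 1, 5, 2, 0}) exposes the logical view (32, 4, rm, 4, rk, l)
    std::printf("scales logical : (32, 4, %d, 4, %d, %d)\n", rm, rk, l);
    // FP4 payload: two values per byte, so the last dimension is hidden / 2
    std::printf("payload        : (%d, %d, %d) uint8\n", l, m, hidden / 2);
    return 0;
}
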
@@ -1174,13 +1206,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
packed_recv_count.data_ptr<int>(),
cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
x_global_scale.has_value() ? x_global_scale->data_ptr<float>() : nullptr,
buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
buffer.dispatch_rdma_send_buffer,
x.data_ptr(), topk_idx.data_ptr<int64_t>(),
next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
use_fp8, round_scale, use_ue8m0,
use_nvfp4, use_ue8m0_for_sf,
workspace, num_device_sms,
launch_stream, phases);
};
2 changes: 2 additions & 0 deletions csrc/deep_ep.hpp
@@ -167,8 +167,10 @@ struct Buffer {
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
const std::optional<torch::Tensor>& x_global_scale,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0,
bool use_nvfp4, bool use_ue8m0_for_sf,
bool async, bool return_recv_hook);

std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
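
A minimal caller-side sketch of the extended declaration above, assuming an already constructed Buffer named buffer and CUDA tensors x (bfloat16), topk_idx (int64), and a float32 scalar tensor x_global_scale; every variable name and numeric argument here is illustrative, not part of this change. Note that use_fp8 and use_nvfp4 are mutually exclusive (asserted in deep_ep.cpp).

// Hypothetical caller-side sketch; names and sizes are illustrative only.
auto result = buffer.low_latency_dispatch(
    x, topk_idx,
    /*cumulative_local_expert_recv_stats=*/std::nullopt,
    /*dispatch_wait_recv_cost_stats=*/std::nullopt,
    /*x_global_scale=*/x_global_scale,         // optional float32 tensor on CUDA
    /*num_max_dispatch_tokens_per_rank=*/128,  // hypothetical
    /*num_experts=*/256,                       // hypothetical
    /*use_fp8=*/false, /*round_scale=*/false, /*use_ue8m0=*/false,
    /*use_nvfp4=*/true, /*use_ue8m0_for_sf=*/false,
    /*async=*/false, /*return_recv_hook=*/false);
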
2 changes: 2 additions & 0 deletions csrc/kernels/api.cuh
@@ -145,12 +145,14 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
const float* x_global_scale,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0,
bool use_nvfp4, bool use_ue8m0_for_sf,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases);
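
On the receive side, each byte of packed_recv_x holds two FP4 (e2m1) values and each group of kNumPerChannels = 16 hidden elements shares one scale factor. The sketch below decodes one packed byte with the standard e2m1 value table; the nibble order and the way x_global_scale enters the reconstruction follow the kernel's convention and are assumptions here, not taken from this PR.

#include <cstdint>

// Standard FP4 (e2m1) value table: 1 sign bit, 2 exponent bits, 1 mantissa bit.
static constexpr float kE2M1[16] = {
     0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

// Dequantize the two 4-bit values packed into `byte`, applying the per-block
// scale (one scale per 16 hidden elements). Nibble order (low nibble = even
// element) is an assumption; folding in x_global_scale is left to the caller
// according to the kernel's quantization convention.
inline void unpack_nvfp4(uint8_t byte, float block_scale, float& even, float& odd) {
    even = kE2M1[byte & 0xF] * block_scale;
    odd  = kE2M1[byte >> 4]  * block_scale;
}
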
