Draft
Changes from all commits
Commits
396 commits
f5d6fe2
more
fzyzcjy Aug 27, 2025
f1c9d16
more
fzyzcjy Aug 27, 2025
1191e61
more
fzyzcjy Aug 27, 2025
39ba248
fix
fzyzcjy Aug 27, 2025
4528bba
more
fzyzcjy Aug 27, 2025
a25bc01
more
fzyzcjy Aug 27, 2025
2fbb21d
more
fzyzcjy Aug 27, 2025
3180d23
more
fzyzcjy Aug 27, 2025
14ecba2
more
fzyzcjy Aug 27, 2025
8b6af18
more
fzyzcjy Aug 27, 2025
e8c33a4
more
fzyzcjy Aug 27, 2025
945f4a7
more
fzyzcjy Aug 27, 2025
c1b586f
Revert "fix"
fzyzcjy Aug 27, 2025
b1d5959
more
fzyzcjy Aug 27, 2025
ed6ae45
more
fzyzcjy Aug 27, 2025
e5f369c
more
fzyzcjy Aug 28, 2025
bf8c15a
more
fzyzcjy Aug 28, 2025
614d452
more
fzyzcjy Aug 28, 2025
8166727
more
fzyzcjy Aug 28, 2025
e474017
more
fzyzcjy Aug 28, 2025
db7af14
more
fzyzcjy Aug 28, 2025
bf13e6a
more
fzyzcjy Aug 28, 2025
98a22cd
more
fzyzcjy Aug 28, 2025
f521d34
more
fzyzcjy Aug 28, 2025
5a2fa4e
fix
fzyzcjy Aug 28, 2025
8d82ae7
more
fzyzcjy Aug 28, 2025
159a8d4
more
fzyzcjy Aug 28, 2025
6c8772c
more
fzyzcjy Aug 28, 2025
acaf1c1
more
fzyzcjy Aug 28, 2025
31815d7
more
fzyzcjy Aug 28, 2025
b797e50
more
fzyzcjy Aug 28, 2025
9e81422
more
fzyzcjy Aug 28, 2025
aae399f
more
fzyzcjy Aug 28, 2025
20865f7
more
fzyzcjy Aug 28, 2025
81aff45
more
fzyzcjy Aug 29, 2025
77be765
more
fzyzcjy Aug 29, 2025
33d4025
more
fzyzcjy Aug 29, 2025
7005ae8
more
fzyzcjy Aug 29, 2025
30b2938
more
fzyzcjy Aug 29, 2025
d6fbbcb
more
fzyzcjy Aug 29, 2025
74aec6c
more
fzyzcjy Aug 29, 2025
4e42c82
more
fzyzcjy Aug 29, 2025
0d1488a
more
fzyzcjy Aug 29, 2025
85234e3
more
fzyzcjy Aug 29, 2025
946d954
more
fzyzcjy Aug 29, 2025
5b0fc04
more
fzyzcjy Aug 29, 2025
9a9efe9
more
fzyzcjy Aug 29, 2025
dbe13bf
more
fzyzcjy Aug 29, 2025
edd3e94
more
fzyzcjy Aug 29, 2025
c85d29e
more
fzyzcjy Aug 29, 2025
4fec000
more
fzyzcjy Aug 29, 2025
c7e3bec
more
fzyzcjy Aug 29, 2025
52fa6b4
more
fzyzcjy Aug 29, 2025
4afcd14
more
fzyzcjy Aug 29, 2025
4d65f42
more
fzyzcjy Aug 29, 2025
d07bded
more
fzyzcjy Aug 29, 2025
3592fcc
more
fzyzcjy Aug 29, 2025
7fd3718
more
fzyzcjy Aug 29, 2025
2a5ca59
more
fzyzcjy Aug 29, 2025
8632fe5
more
fzyzcjy Aug 29, 2025
123e7ec
more
fzyzcjy Aug 29, 2025
f71758c
more
fzyzcjy Aug 29, 2025
36ed956
more
fzyzcjy Aug 29, 2025
a49eac3
more
fzyzcjy Aug 29, 2025
700c772
more
fzyzcjy Aug 29, 2025
8d4bfb5
more
fzyzcjy Aug 29, 2025
46042a0
more
fzyzcjy Aug 29, 2025
1e5b911
more
fzyzcjy Aug 29, 2025
a9a6388
more
fzyzcjy Aug 29, 2025
2834f12
more
fzyzcjy Aug 29, 2025
0e679e4
more
fzyzcjy Aug 29, 2025
a4b8345
Revert "more"
fzyzcjy Aug 29, 2025
45deb5d
Revert "more"
fzyzcjy Aug 29, 2025
3291a40
Revert "more"
fzyzcjy Aug 29, 2025
4c708a7
rm cond
fzyzcjy Aug 29, 2025
f5cf546
more
fzyzcjy Aug 29, 2025
7c37192
more
fzyzcjy Aug 29, 2025
5736c08
more
fzyzcjy Aug 29, 2025
c6fd0cd
more
fzyzcjy Aug 29, 2025
e2378e7
more
fzyzcjy Aug 29, 2025
ee84ee0
more
fzyzcjy Aug 29, 2025
f1fd7e3
more
fzyzcjy Aug 29, 2025
a4c50da
more
fzyzcjy Aug 29, 2025
23cf179
more
fzyzcjy Aug 29, 2025
550a037
more
fzyzcjy Aug 29, 2025
0323a11
more
fzyzcjy Aug 29, 2025
8e2d277
more
fzyzcjy Aug 29, 2025
621535a
more
fzyzcjy Aug 29, 2025
75ec102
more
fzyzcjy Aug 29, 2025
3ff888a
more
fzyzcjy Aug 29, 2025
e791b3e
more
fzyzcjy Aug 29, 2025
b57ec16
more
fzyzcjy Aug 29, 2025
030cdf0
more
fzyzcjy Aug 29, 2025
86c6c7e
more
fzyzcjy Aug 29, 2025
2c53894
more
fzyzcjy Aug 29, 2025
4b21761
more
fzyzcjy Aug 29, 2025
38394eb
more
fzyzcjy Aug 29, 2025
52bbde5
more
fzyzcjy Aug 29, 2025
05ac9c3
more
fzyzcjy Aug 29, 2025
cfbfeb9
more
fzyzcjy Aug 29, 2025
729cfcc
more
fzyzcjy Aug 29, 2025
aba542e
more
fzyzcjy Aug 29, 2025
94d1032
more
fzyzcjy Aug 29, 2025
e6e2349
more
fzyzcjy Aug 29, 2025
83485f0
more
fzyzcjy Aug 29, 2025
bcba00f
more
fzyzcjy Aug 29, 2025
1727b57
more
fzyzcjy Aug 29, 2025
f1b79fa
more
fzyzcjy Aug 29, 2025
32a93d1
more
fzyzcjy Aug 29, 2025
0c57568
more
fzyzcjy Aug 29, 2025
398ea19
more
fzyzcjy Aug 29, 2025
09a779a
more
fzyzcjy Aug 29, 2025
040decc
more
fzyzcjy Aug 29, 2025
dc90347
more
fzyzcjy Aug 29, 2025
c67ad13
more
fzyzcjy Aug 29, 2025
93a7d9c
more
fzyzcjy Aug 29, 2025
a3e21f4
more
fzyzcjy Aug 29, 2025
c4262e2
more
fzyzcjy Aug 29, 2025
df7942a
more
fzyzcjy Aug 29, 2025
8dca043
more
fzyzcjy Aug 29, 2025
78b5443
more
fzyzcjy Aug 29, 2025
2764ddc
more
fzyzcjy Aug 29, 2025
6435f7f
more
fzyzcjy Aug 29, 2025
d812f47
more
fzyzcjy Aug 29, 2025
3f2f02b
more
fzyzcjy Aug 29, 2025
700d1b8
more
fzyzcjy Aug 29, 2025
471dd36
more
fzyzcjy Aug 29, 2025
36f4669
more
fzyzcjy Aug 29, 2025
1e4aab8
more
fzyzcjy Aug 29, 2025
4c81bb2
more
fzyzcjy Aug 29, 2025
3eae3a2
more
fzyzcjy Aug 29, 2025
67805f1
more
fzyzcjy Aug 29, 2025
440bbd2
more
fzyzcjy Aug 29, 2025
b8b4d35
more
fzyzcjy Aug 29, 2025
33cbd0d
more
fzyzcjy Aug 29, 2025
c2e0cee
more
fzyzcjy Aug 29, 2025
7a96ac8
more
fzyzcjy Aug 29, 2025
83cebc0
more
fzyzcjy Aug 29, 2025
ddf364a
more
fzyzcjy Aug 29, 2025
9176b41
more
fzyzcjy Aug 29, 2025
7922f70
more
fzyzcjy Aug 29, 2025
1674d26
more
fzyzcjy Aug 29, 2025
b3521a0
fix
fzyzcjy Aug 29, 2025
5d2c128
rename
fzyzcjy Aug 29, 2025
abb6aea
more
fzyzcjy Aug 29, 2025
da18c34
rm log
fzyzcjy Aug 29, 2025
b0c4599
fix
fzyzcjy Aug 29, 2025
526d74a
more
fzyzcjy Aug 29, 2025
97cd74d
more
fzyzcjy Aug 29, 2025
96d23e9
more
fzyzcjy Aug 29, 2025
ba16fd9
more
fzyzcjy Aug 29, 2025
a87fe74
more
fzyzcjy Aug 29, 2025
63bb84a
more
fzyzcjy Aug 29, 2025
9d56557
more
fzyzcjy Aug 29, 2025
88649f5
more
fzyzcjy Aug 29, 2025
b1ed835
more
fzyzcjy Aug 29, 2025
077ce78
more
fzyzcjy Aug 29, 2025
15f5ea4
more
fzyzcjy Aug 29, 2025
5157097
more
fzyzcjy Aug 29, 2025
e6bb239
hack
fzyzcjy Aug 29, 2025
8539ab0
more
fzyzcjy Aug 30, 2025
8846c1b
fix bug
fzyzcjy Aug 30, 2025
f816735
more
fzyzcjy Aug 30, 2025
4c3388e
Revert "more"
fzyzcjy Aug 30, 2025
f536b27
more
fzyzcjy Aug 30, 2025
8fc2331
more
fzyzcjy Aug 30, 2025
a163cf1
layout use st_volatile_global
fzyzcjy Aug 30, 2025
3a4bf6b
more
fzyzcjy Aug 30, 2025
f0e44a4
more
fzyzcjy Aug 30, 2025
3fe1ee9
temp revert unrolled copy
fzyzcjy Aug 30, 2025
dc5e367
logs
fzyzcjy Aug 30, 2025
3fb1df4
more
fzyzcjy Aug 30, 2025
628ab70
Revert "temp revert unrolled copy"
fzyzcjy Aug 30, 2025
c27021b
temp log
fzyzcjy Aug 30, 2025
6982472
temp hack: rm next_clean cleaning!
fzyzcjy Aug 30, 2025
7a4e94b
Revert "temp hack: rm next_clean cleaning!"
fzyzcjy Aug 30, 2025
158101c
logs
fzyzcjy Aug 30, 2025
186c378
logs
fzyzcjy Aug 30, 2025
318aabe
hack: maxnreg 32
fzyzcjy Aug 30, 2025
98dca42
logs
fzyzcjy Aug 30, 2025
176a082
rm logs
fzyzcjy Aug 30, 2025
c24c65a
more
fzyzcjy Aug 30, 2025
21f8941
rm log
fzyzcjy Aug 30, 2025
daae47a
rm t_start
fzyzcjy Aug 30, 2025
2f0004b
rm timeout check
fzyzcjy Aug 30, 2025
c580196
temp revert unroll copy speedup
fzyzcjy Aug 30, 2025
fa39e42
temp use weaker set layout_range_buffer
fzyzcjy Aug 30, 2025
7f1f948
temp rm maxnreg
fzyzcjy Aug 30, 2025
b43c25a
fix compile
fzyzcjy Aug 30, 2025
bdb4c51
maxnreg 48
fzyzcjy Aug 30, 2025
227dcba
maxnreg 32
fzyzcjy Aug 30, 2025
129b968
rm (unused) debug_tensor
fzyzcjy Aug 30, 2025
56cac67
fix compile
fzyzcjy Aug 30, 2025
a69cafb
fix logical error introduced when copying yesterday
fzyzcjy Aug 30, 2025
aa7632b
maxnreg 48
fzyzcjy Aug 30, 2025
48c6ea7
Revert "temp revert unroll copy speedup"
fzyzcjy Aug 30, 2025
cc8cedc
temp use st_volatile_global layout_range_buffer
fzyzcjy Aug 30, 2025
8a1763d
Revert "temp use st_volatile_global layout_range_buffer"
fzyzcjy Aug 30, 2025
092ea74
simp
fzyzcjy Aug 30, 2025
ee223ae
simp
fzyzcjy Aug 30, 2025
ee618b2
extract
fzyzcjy Aug 30, 2025
5f9443a
more
fzyzcjy Aug 30, 2025
461074b
more
fzyzcjy Aug 30, 2025
f820a2b
more
fzyzcjy Aug 30, 2025
134b928
more
fzyzcjy Aug 30, 2025
becde3a
more
fzyzcjy Aug 30, 2025
ff3f52a
more
fzyzcjy Aug 30, 2025
731f2b4
more
fzyzcjy Aug 30, 2025
3962b41
more
fzyzcjy Aug 30, 2025
2f5d8df
more
fzyzcjy Aug 30, 2025
ee9e523
mv
fzyzcjy Aug 30, 2025
ebf0607
delay remote_start_offset
fzyzcjy Aug 30, 2025
083b684
Revert "delay remote_start_offset"
fzyzcjy Aug 30, 2025
342fcfa
Revert "mv"
fzyzcjy Aug 30, 2025
97a3b20
copy new nvfp4 swizzle
fzyzcjy Aug 30, 2025
8a42c16
more
fzyzcjy Aug 30, 2025
c438037
compile
fzyzcjy Aug 30, 2025
5aa2194
more
fzyzcjy Aug 30, 2025
81811fd
token_idx_and_dst_expert_and_dst_slot_idx_flat_list
fzyzcjy Aug 30, 2025
0c4d561
slot_idx provided by external
fzyzcjy Aug 30, 2025
7ec8209
var
fzyzcjy Aug 30, 2025
9f38ae7
print
fzyzcjy Aug 30, 2025
ca4b823
temp rm slot_idx logic
fzyzcjy Aug 30, 2025
ea4be13
Revert "temp rm slot_idx logic"
fzyzcjy Aug 30, 2025
5cee33d
temp enable stuck handler
fzyzcjy Aug 30, 2025
5e44fd9
Revert "temp enable stuck handler"
fzyzcjy Aug 30, 2025
8511b31
use 2 warp to send 1 token
fzyzcjy Aug 30, 2025
1fca712
fix warppair bug
fzyzcjy Aug 30, 2025
3770baa
Revert "fix warppair bug"
fzyzcjy Aug 30, 2025
3b3a561
Revert "use 2 warp to send 1 token"
fzyzcjy Aug 30, 2025
061e52b
re-introduce 2warp for 1token
fzyzcjy Aug 30, 2025
cd98732
hack: grid sync after each signal
fzyzcjy Aug 30, 2025
07ac7d7
Revert "hack: grid sync after each signal"
fzyzcjy Aug 30, 2025
84d9c7a
Revert "re-introduce 2warp for 1token"
fzyzcjy Aug 30, 2025
e787723
re-introduce debug_tensor
fzyzcjy Aug 30, 2025
8c209be
disable debug
fzyzcjy Aug 31, 2025
5cb62f0
hack: rm scale copying
fzyzcjy Aug 31, 2025
25804f0
Revert "hack: rm scale copying"
fzyzcjy Aug 31, 2025
ad90fe8
change ld_token_signal debug tensor
fzyzcjy Aug 31, 2025
c45fbf1
add send::after_get_remote_start_offset
fzyzcjy Aug 31, 2025
3498db0
re-introduce 2warp for 1token again
fzyzcjy Aug 31, 2025
c117099
more
fzyzcjy Aug 31, 2025
3f34053
Revert "re-introduce 2warp for 1token again"
fzyzcjy Aug 31, 2025
44899aa
disable debug tensor
fzyzcjy Aug 31, 2025
7dedcc8
enable debug_tensor
fzyzcjy Aug 31, 2025
c1d3606
Merge branch 'main-upstream_public' into feat/cu_mem_api
fzyzcjy Sep 1, 2025
c648269
Revert "enable debug_tensor"
fzyzcjy Sep 6, 2025
180a028
Merge branch 'main-upstream_public' into feat/dev_20250825
fzyzcjy Sep 7, 2025
feef99e
Merge branch 'feat/cu_mem_api' into feat/dev_20250825
fzyzcjy Sep 7, 2025
dc79a3c
more
fzyzcjy Sep 7, 2025
28 changes: 21 additions & 7 deletions csrc/config.hpp
@@ -6,7 +6,7 @@
namespace deep_ep {

template <typename dtype_t>
dtype_t ceil_div(dtype_t a, dtype_t b) {
constexpr dtype_t ceil_div(dtype_t a, dtype_t b) {
return (a + b - 1) / b;
}
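
Making ceil_div constexpr lets callers fold it at compile time, e.g. in static buffer-size constants. A minimal sketch of what the change enables (the call sites below are illustrative, not from this PR):

```cpp
#include <cstddef>

template <typename dtype_t>
constexpr dtype_t ceil_div(dtype_t a, dtype_t b) {
    return (a + b - 1) / b;
}

// Usable in compile-time contexts now that it is constexpr:
static_assert(ceil_div(130, 128) == 2, "130 bytes need two 128-byte chunks");
constexpr std::size_t kNumChunks = ceil_div<std::size_t>(1048576, 128);  // 8192
```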

@@ -89,6 +89,11 @@ struct Config {
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2;
num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2;

// NOTE Please keep in sync: Config.get_nvl_buffer_size_hint, LowLatencyLayout.constructor, internode_ll_v2
// NOTE add a large number to be safe
num_bytes += 1048576;

num_bytes = ((num_bytes + 127) / 128) * 128;
return num_bytes;
#else
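
The added 1 MiB margin and the 128-byte rounding on the lines above can be sanity-checked with small numbers. A hedged sketch; the helper name and sample value are invented for illustration:

```cpp
#include <cstddef>

// Mirrors the tail of the size-hint computation above: pad, then round up to 128.
constexpr std::size_t pad_and_align(std::size_t num_bytes) {
    num_bytes += 1048576;                    // the "large number to be safe"
    return ((num_bytes + 127) / 128) * 128;  // next multiple of 128
}
static_assert(pad_and_align(2000001) == 3048704, "3048577 rounds up to 3048704");
```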
@@ -102,7 +107,9 @@ struct LowLatencyBuffer {

void* dispatch_rdma_send_buffer = nullptr;
void* dispatch_rdma_recv_data_buffer = nullptr;
int* dispatch_rdma_recv_count_buffer = nullptr;
// NOTE rename
// int* dispatch_rdma_recv_count_buffer = nullptr;
int* dispatch_rdma_general_signal_buffer = nullptr;

void* combine_rdma_send_buffer = nullptr;
void* combine_rdma_recv_data_buffer = nullptr;
Expand All @@ -112,8 +119,8 @@ struct LowLatencyBuffer {
size_t num_bytes_per_combine_msg = 0;

std::pair<int*, int> clean_meta() {
EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer);
return {dispatch_rdma_recv_count_buffer, num_clean_int};
EP_HOST_ASSERT(dispatch_rdma_general_signal_buffer == combine_rdma_recv_flag_buffer);
return {dispatch_rdma_general_signal_buffer, num_clean_int};
}
};
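
clean_meta() now returns the renamed general-signal pointer plus the number of ints to clear, and the assert guarantees dispatch and combine still share one region, so a single clean covers both phases. A hedged sketch of a consumer (the PR's real call site is in deep_ep.cpp, whose diff is not rendered here):

```cpp
#include <cuda_runtime.h>

// Hypothetical helper; assumes the repo's CUDA_CHECK macro and C++17 bindings.
void clean_signals(deep_ep::LowLatencyBuffer& buffer, cudaStream_t stream) {
    auto [signal_ptr, num_clean_int] = buffer.clean_meta();
    CUDA_CHECK(cudaMemsetAsync(signal_ptr, 0, num_clean_int * sizeof(int), stream));
}
```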

@@ -129,6 +136,9 @@ struct LowLatencyLayout {
LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) {
const int num_scales = hidden / 128;

EP_HOST_ASSERT(num_experts % num_ranks == 0);
const int num_local_experts = num_experts / num_ranks;

// Dispatch and combine layout:
// - 2 symmetric odd/even send buffer
// - 2 symmetric odd/even receive buffers
@@ -157,9 +167,13 @@
total_bytes += recv_buffer_bytes * 2;

// Symmetric signaling buffers
size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes;
size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes);
// NOTE can only increase instead of decrease to be compatible with v1
// NOTE be careful about alignment
// size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int);
// NOTE Please keep in sync: Config.get_nvl_buffer_size_hint, LowLatencyLayout.constructor, internode_ll_v2
size_t dispatch_general_signal_buffer_bytes = num_experts * sizeof(int64_t) + num_local_experts * sizeof(int);
size_t combine_recv_flag_buffer_bytes = dispatch_general_signal_buffer_bytes;
size_t signaling_buffer_bytes = std::max(dispatch_general_signal_buffer_bytes, combine_recv_flag_buffer_bytes);
size_t signaling_buffer_bytes_aligned = align<size_t>(signaling_buffer_bytes, 128);
total_bytes += signaling_buffer_bytes_aligned * 2;

332 changes: 303 additions & 29 deletions csrc/deep_ep.cpp

Large diffs are not rendered by default.

47 changes: 42 additions & 5 deletions csrc/deep_ep.hpp
@@ -20,6 +20,33 @@
#define TORCH_EXTENSION_NAME deep_ep_cpp
#endif

namespace shared_memory {

union MemHandleInner {
cudaIpcMemHandle_t cuda_ipc_mem_handle;
CUmemFabricHandle cu_mem_fabric_handle;
};

struct MemHandle {
MemHandleInner inner;
size_t size;
};

constexpr size_t HANDLE_SIZE = sizeof(MemHandle);

class SharedMemoryAllocator {
public:
SharedMemoryAllocator();
void malloc(void** ptr, size_t size);
void free(void* ptr);
void get_mem_handle(MemHandle* mem_handle, void* ptr);
void open_mem_handle(void** ptr, MemHandle* mem_handle);
void close_mem_handle(void* ptr);
private:
bool enable_fabric;
};
}
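
The union lets a single fixed-size MemHandle carry either a legacy CUDA IPC handle or a fabric handle, and enable_fabric picks the path at runtime. The PR's actual implementation lives in csrc/deep_ep.cpp (large diff, not rendered above); below is a minimal sketch of how the export side might dispatch, assuming fabric allocations come from cuMemCreate, reusing the repo's CU_CHECK/CUDA_CHECK macros, and noting that CU_MEM_HANDLE_TYPE_FABRIC requires a recent CUDA toolkit:

```cpp
#include <cuda.h>
#include <cuda_runtime.h>

// Hedged sketch only; not the PR's actual SharedMemoryAllocator::get_mem_handle.
void get_mem_handle_sketch(bool enable_fabric, shared_memory::MemHandle* handle,
                           void* ptr, size_t size) {
    handle->size = size;
    if (enable_fabric) {
        // Recover the driver allocation handle backing `ptr`, export it as a
        // fabric handle, then drop the extra reference taken by the lookup.
        CUmemGenericAllocationHandle alloc;
        CU_CHECK(cuMemRetainAllocationHandle(&alloc, ptr));
        CU_CHECK(cuMemExportToShareableHandle(&handle->inner.cu_mem_fabric_handle,
                                              alloc, CU_MEM_HANDLE_TYPE_FABRIC, 0));
        CU_CHECK(cuMemRelease(alloc));
    } else {
        CUDA_CHECK(cudaIpcGetMemHandle(&handle->inner.cuda_ipc_mem_handle, ptr));
    }
}
```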

namespace deep_ep {

struct Buffer {
@@ -44,7 +71,7 @@ struct Buffer {
int num_device_sms;
int rank, rdma_rank, nvl_rank;
int num_ranks, num_rdma_ranks, num_nvl_ranks;
cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
shared_memory::MemHandle ipc_handles[NUM_MAX_NVL_PEERS];

// Stream for communication
at::cuda::CUDAStream comm_stream;
@@ -76,6 +103,8 @@ struct Buffer {
volatile int* moe_recv_rdma_counter = nullptr;
int* moe_recv_rdma_counter_mapped = nullptr;

shared_memory::SharedMemoryAllocator shared_memory_allocator;

public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy);

@@ -144,20 +173,28 @@ struct Buffer {
void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);

std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
low_latency_dispatch(bool enable_v2, const torch::Tensor& x, const torch::Tensor& topk_idx,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0,
bool async, bool return_recv_hook);
bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& zeroed_tensor_a,
const std::optional<torch::Tensor>& zeroed_tensor_b,
const std::optional<torch::Tensor>& zeroed_buffer_for_atomic_counter_per_expert,
bool use_nvfp4,
const std::optional<torch::Tensor>& dst_signals,
const std::optional<torch::Tensor>& count_per_expert, const std::optional<torch::Tensor>& token_idx_and_dst_expert_and_dst_slot_idx_flat_list,
const std::optional<torch::Tensor>& debug_tensor);

std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
low_latency_combine(bool enable_v2, const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
const torch::Tensor& src_info, const torch::Tensor& layout_range,
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
const std::optional<torch::Tensor>& out = std::nullopt);
const std::optional<torch::Tensor>& out = std::nullopt,
const std::optional<torch::Tensor>& src_signals = std::nullopt, uint32_t src_signal_expect_value = 0);

torch::Tensor
get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const;
18 changes: 12 additions & 6 deletions csrc/kernels/api.cuh
@@ -139,21 +139,26 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
int* clean_1, int num_clean_int_1,
cudaStream_t stream);

void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
void dispatch(bool enable_v2, void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
void* x, const int64_t* topk_idx, // NOTE rm `const` of x
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases);

void combine(void* combined_x,
cudaStream_t stream, int phases,
bool use_nvfp4, uint32_t* dst_signals,
uint32_t* count_per_expert, int64_t* token_idx_and_dst_expert_and_dst_slot_idx_flat_list,
int* remote_start_offset_buffer,
int* zeroed_buffer_for_atomic_counter_per_expert,
int* debug_tensor);

void combine(bool enable_v2, void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
const int* src_info, const int64_t* layout_range,
@@ -163,7 +168,8 @@ void combine(void* combined_x,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_logfmt,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases, bool zero_copy);
cudaStream_t stream, int phases, bool zero_copy,
uint32_t* src_signals, uint32_t src_signal_expect_value);

} // namespace internode_ll

15 changes: 15 additions & 0 deletions csrc/kernels/exception.cuh
@@ -31,6 +31,18 @@ do { \
} while (0)
#endif

#ifndef CU_CHECK
#define CU_CHECK(cmd) \
do { \
CUresult e = (cmd); \
if (e != CUDA_SUCCESS) { \
const char *error_str = NULL; \
cuGetErrorString(e, &error_str); \
throw EPException("CU", __FILE__, __LINE__, std::string(error_str)); \
} \
} while (0)
#endif

#ifndef EP_HOST_ASSERT
#define EP_HOST_ASSERT(cond) \
do { \
@@ -49,3 +61,6 @@ do { \
} \
} while (0)
#endif

#define EP_DEBUG_DEVICE_ASSERT(cond) EP_DEVICE_ASSERT(cond)
// #define EP_DEBUG_DEVICE_ASSERT(cond) do {} while (0)
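
CU_CHECK mirrors the existing runtime-API checker but decodes driver-API CUresult codes, which the new cuMem-based allocation path returns. Typical usage, with an arbitrary driver-API call for illustration:

```cpp
// Arbitrary driver-API call; CU_CHECK throws EPException with the decoded string on failure.
CUdevice device;
CU_CHECK(cuCtxGetDevice(&device));
```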