Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions csrc/kernels/internode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1095,42 +1095,42 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
if (is_cached_dispatch)
return;

EP_DEVICE_ASSERT(num_warps >= num_channels);
EP_DEVICE_ASSERT(num_rdma_ranks <= 32);

// Iterate in reverse order
if (lane_id < num_rdma_ranks and warp_id < num_channels) {
int token_start_idx, token_end_idx;
get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx, token_end_idx);

// NOTES: `1 << 25` is a heuristic large number
int last_head = 1 << 25;
for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; -- token_idx) {
auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id);
if (current_head < 0) {
combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1;
} else {
last_head = current_head;
if (lane_id < num_rdma_ranks) {
for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
int token_start_idx, token_end_idx;
get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx);

// NOTES: `1 << 25` is a heuristic large number
int last_head = 1 << 25;
for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; -- token_idx) {
auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id);
if (current_head < 0) {
combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1;
} else {
last_head = current_head;
}
}
}
}
} else {
if (is_cached_dispatch)
return;

EP_DEVICE_ASSERT(num_warps >= num_channels);
EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr);
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Too many NVL peers");

if (warp_id < num_channels) {
for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) {
constexpr int tma_batch_size = kNumTMABytesPerWarp - sizeof(uint64_t);
constexpr int num_bytes_per_token = sizeof(int) * NUM_MAX_NVL_PEERS;
constexpr int num_tokens_per_batch = tma_batch_size / num_bytes_per_token;
EP_STATIC_ASSERT(num_bytes_per_token % 16 == 0, "num_bytes_per_token should be divisible by 16");

// TMA stuffs
extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];
auto tma_buffer = smem_tma_buffer + warp_id * kNumTMABytesPerWarp;
auto tma_buffer = smem_tma_buffer + channel_id * kNumTMABytesPerWarp;
auto tma_mbarrier = reinterpret_cast<uint64_t*>(tma_buffer + tma_batch_size);
uint32_t tma_phase = 0;
if (lane_id == 0) {
Expand All @@ -1142,8 +1142,8 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in

for (int dst_rdma_rank = sm_id - 2; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_channels * 2 - 2) {
// Iterate in reverse order
int token_start_idx = warp_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1];
int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id];
int token_start_idx = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1];
int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id];
int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1];
token_start_idx += shift, token_end_idx += shift;

Expand Down