Skip to content

Conversation

yuantailing
Copy link
Contributor

If I understand correctly,

  1. L62 is corresponding to
    auto rdma_channel_meta = SymBuffer<int>(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels);
  2. L67 is corresponding to
    auto idx_value = static_cast<int>(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx));
    auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx);
    st_na_global(reinterpret_cast<int*>(dst_send_buffers[rank_idx]) + copy_idx, idx_value);
    , and it is also counted as sizeof(int) in get_num_bytes_per_token
  3. L70 is corresponding to int4 alignment:
    int get_num_bytes_per_token(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights) {
    return static_cast<int>(align_up(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float), sizeof(int4)));
    }
  4. L94 is similar to L67
  5. L98 is corresponding to
    auto rdma_channel_head = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);
    auto rdma_channel_tail = SymBuffer<uint64_t, false>(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels);

Before change:

>>> deep_ep.Config(24, 8, 512, 16, 128).get_nvl_buffer_size_hint(7168, 16)
453380736
>>> deep_ep.Config(24, 8, 512, 16, 128).get_rdma_buffer_size_hint(7168, 16)
56774016

After change:

>>> deep_ep.Config(24, 8, 512, 16, 128).get_nvl_buffer_size_hint(7168, 16)
429000960
>>> deep_ep.Config(24, 8, 512, 16, 128).get_rdma_buffer_size_hint(7168, 16)
53629056

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant