-
Notifications
You must be signed in to change notification settings - Fork 946
[Feat] Single Batch Overlap (SBO): Overlaping of Down GEMM with Combine Send #390
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Co-authored-by: Sulfur6 <[email protected]> Co-authored-by: wangfakang <[email protected]> Co-authored-by: alpha-baby <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a very short glance and only tiny nits
csrc/kernels/internode_ll.cu
Outdated
const auto dst_rank = responsible_expert_idx / num_local_experts; | ||
const auto local_expert_idx = responsible_expert_idx % num_local_experts; | ||
// Shared between warps in sms for overlap mode, where each sm only has one warp group | ||
__shared__ int shared_vaild_signal_prefix_sum[288]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: would be great to make it a const and assert num experts smaller than this etc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the comments. I have added const kNumMaxExperts and assert statements.
deep_ep/buffer.py
Outdated
return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook | ||
|
||
# noinspection PyTypeChecker | ||
def ll_overlap_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: tuple, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: wondering whether we should use the existing api since this seems to only add a few flags
(when using in in blackwell sgl-project/sglang#9870 I feel it would be great to have one single function)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the comments. We initially added the ll_overlap_combine
Python interface to facilitate development and testing. We will reference your implementation and switch to using the low_latency_combine
interface directly in our refactoring.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
great, looking forward to that (and no worries for now - my 9870 also temporarily use the ll_overlap_combine api now)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, I have deleted ll_overlap_combine
and now reuse low_latency_combine
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks great
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
qq: is this the optimized version, or is there a version that is faster but not put here yet? b/c I see it seems a bit slow and may want to see whether I can try to optimize its speed (on blackwell), thus to double check I do not conflict code with you
This is already the latest version, feel free to optimize it. |
Co-authored-by: Sulfur6 <[email protected]> Co-authored-by: wangfakang <[email protected]> Co-authored-by: alpha-baby <[email protected]> Co-authored-by: AniZpZ <[email protected]>
1. Motivation
The optimization effect of Two-Batch Overlap (TBO) is suboptimal for the Decode phase on low-compute-power cards (i.e., H20). This is due to two main factors: First, on the Hopper architecture, the WGMMA block_m is 64. Consequently, when TBO is enabled with a small Decode batch size, the MLP GEMM suffers from redundant computations. A positive throughput gain is only observed at larger batch sizes (e.g., 64, 128). Second, at these larger batch sizes, low-compute-power cards like the H20 fail to meet the SLA guarantees for TPOT/ITL.
Therefore, it is necessary to find a solution that can improve Decode throughput even with small batch sizes. Single Batch Overlap (SBO) presents itself as a viable solution.
We implement SBO for DeepSeek v3/R1 by modifying DeepEP and DeepGEMM, including the overlap of Shared Expert and Dispatch Recv, as well as the overlap of Down GEMM with Combine Send.
The overlap of Down GEMM with Combine Send is implemented by modifying SGlang, DeepEP and DeepGEMM, with the detailed implementation available in the PRs below:
We also conduct integration and evaluation in SGLang: sgl-project/sglang#9660.
2. Overlap Design
SBO implements two overlap for the MoE layers of DeepSeek-V3/R1. One is to overlap the Shared Expert computation with the Dispatch Recv communication, and the other is to overlap the Down GEMM computation with the Combine Send communication.


The interaction between Down GEMM and Combine Send is structured as a producer-consumer model synchronized by signals. For each local expert, a signal unit is allocated for every block_m tokens. The Down GEMM computes the results for these block_m tokens and atomically increments the signaling unit after completing a portion of the work. The Combine Send polls this signaling unit. Once the value reaches a threshold, it sends the corresponding block_m tokens.
3. Modifications
ll_overlap_combine
Python interface to support overlapping Down GEMM with Combine Send.low_latency_dispatch
, change the type of packed_recv_src_info from kInt32 to kInt64 to track the src_rank for each token.internode_ll::combine
kernel uses fewer SMs when theoverlap
parameter is enabled and integrates logic for overlap mode, such as polling signals, sending tokens and puting finish flag.4. Evaluation
We integrated the modified DeepEP and DeepGEMM into SGLang for performance evaluation.
4.1. Experiment Setup
4.2. Performance Evaluation
4.3. Accuracy Tests
4.4. Repro Script
Please refer to sgl-project/sglang#9660.