feat: overlap DP and MP embedding for better performance#317
feat: overlap DP and MP embedding for better performance#317ShaobinChen-AH wants to merge 10 commits intoNVIDIA:mainfrom
Conversation
Greptile SummaryThis PR restructures Key changes and issues found:
Confidence Score: 1/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant MS as Main Stream
participant SS as Side Stream
participant CPU as CPU Thread
Note over MS,SS: BEFORE (serialized)
MS->>MS: dispatch MP embedding
MS->>MS: mp_result.wait()
MS->>SS: dispatch DP embedding
MS->>MS: wait_stream(side_stream)
MS->>MS: merge results
Note over MS,SS: AFTER (overlapped)
MS->>MS: dispatch MP embedding (async awaitable)
CPU->>SS: dispatch DP embedding (side stream)
Note over MS,SS: MP and DP run concurrently on GPU
MS->>MS: mp_result.wait() (resolves MP)
MS->>MS: wait_stream(side_stream) (resolves DP)
MS->>MS: merge results
|
|
Thanks for your work! I believe this addresses issue #307 . But I wonder if it's effective. Could you help verify that the dp is actually overlapping with mp? (past your nsys timeline) Because even if you submit DP ahead, there might be an implict wait (e.g internal H2D/D2H sync that blocks you from submiting MP forward). |
OK, I will submit it for verification. |
Verification ResultsI verified the DP/MP overlap using timing measurements and PyTorch Profiler. Test Environment:
Timing Results: Profiler Analysis: Test Code: # Test CUDA stream overlap
side_stream = torch.cuda.Stream()
# Serial execution
start.record()
y1 = x * 2 # MP
y2 = x * 3 # DP
end.record()
# Parallel execution
start.record()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
y2 = x * 3 # DP in side stream
y1 = x * 2 # MP in main stream
torch.cuda.current_stream().wait_stream(side_stream)
end.record() |
|
Sorry, apology for my confusing reply. I am not suspecting the concurrency of the GPU kernels. What I'm suspecting is torchrec itself has some sync inside mp/dp forward which prevent from being parallelized. I would like to see the realistic embedding performance improvement. (DP / MP overlap). A timeline view is preferred (see attached in the #307 ) |
ok. get it. I will submit it later |
Timeline Verification ResultsI generated nsys timeline to verify the DP/MP overlap with more complex kernels. Test Environment:
Test Code:
Speedup: 1.83x |
|
Hi @ShaobinChen-AH. I am actually expecting you to run the real model training, not some dummy workloads (like sin / exp). Because the realistic embedding lookup is very different from those computations. Embedding overlap may not happen as expected. |

ProblemCurrently, DP and MP embedding are serialized due to immediate wait operations, making the side stream ineffective.
Solution
Start DP embedding in side stream first (async)
Timeline
Before: DP → wait → MP → wait
After: DP ────────→ MP → wait → merge
Testing