@jershi425 jershi425 commented Sep 21, 2025

Hybrid Expert Parallel (Hybrid-EP) Implementation

Overview

This PR introduces the Hybrid Expert Parallel (Hybrid-EP) implementation to the DeepEP library. Developed by NVIDIA, Hybrid-EP is an optimized all-to-all communication path for large-scale Mixture-of-Experts (MoE) models. It is built around NVIDIA GPU hardware capabilities and significantly reduces Streaming Multiprocessor (SM) usage while improving communication bandwidth and overall throughput.

🎯 Design Goals

  1. Maximize Network Bandwidth Utilization - Achieve optimal network bandwidth usage for large-scale distributed training
  2. Minimize SM Resource Consumption - Preserve computational resources for core ML workloads
  3. Hardware-Aware Optimization - Leverage NVIDIA NVLink, RDMA, and other advanced hardware features for maximum efficiency

🏗️ Core Architecture

Communication Operators

  • Dispatch: Efficiently distribute tokens to corresponding expert nodes
  • Combine: Aggregate expert computation results with optimized reduction operations
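
Conceptually, dispatch groups each token by the experts its router selected, and combine computes the per-token weighted sum of expert outputs. The pure-Python sketch below illustrates only these semantics; it is not the DeepEP API, and the dict-based data layout is a simplification of the real packed-buffer layout.

```python
# Illustrative MoE all-to-all semantics (NOT the DeepEP API).

def dispatch(topk_ids):
    """Group token indices by the expert each token is routed to."""
    per_expert = {}
    for tok, experts in enumerate(topk_ids):
        for e in experts:
            per_expert.setdefault(e, []).append(tok)
    return per_expert

def combine(num_tokens, expert_outputs, topk_ids, topk_weights):
    """Weighted sum of expert outputs for each token (the reduction step)."""
    out = [0.0] * num_tokens
    for tok, (experts, weights) in enumerate(zip(topk_ids, topk_weights)):
        for e, w in zip(experts, weights):
            out[tok] += w * expert_outputs[e][tok]
    return out
```

In the real kernels the grouping and reduction are fused with the NVLink/RDMA transfers; this sketch only fixes the input/output contract of the two operators.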

Hierarchical Communication Design

  • Inter-node Communication: High-performance RDMA-based communication across nodes*
  • Intra-node Communication: NVLink-optimized data transfer using Tensor Memory Accelerator (TMA) instructions

*Note: RDMA functionality will be available in upcoming releases following comprehensive testing.

🔧 Implementation Features

Hardware Optimizations

  • TMA Instructions: Leverage Tensor Memory Accelerator instructions for minimal SM overhead
  • RDMA Integration: High-efficiency inter-node communication (coming soon)*
  • Pipeline Architecture: Warp-level pipeline parallelism within execution blocks
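
The pipeline idea is double buffering: while one chunk is being processed, the next chunk's load is already in flight. A language-neutral sketch of that pattern in Python (in the actual kernels each warp does this with asynchronous TMA loads, so the prefetch and the processing genuinely overlap):

```python
def pipelined(chunks, load, process):
    """Double-buffered pipeline: prefetch chunk i+1 while processing chunk i."""
    results = []
    buf = load(chunks[0]) if chunks else None  # prime the first buffer
    for i in range(len(chunks)):
        # Issue the next load; on a GPU this would run concurrently with process().
        nxt = load(chunks[i + 1]) if i + 1 < len(chunks) else None
        results.append(process(buf))
        buf = nxt  # swap buffers
    return results
```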

Supported Data Types

  • BF16 (Brain Floating Point 16-bit)
  • FP8 (8-bit Floating Point)
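
FP8 halves the dispatch payload relative to BF16 but requires scaling activations into the representable range. Below is a hedged sketch of symmetric per-token scaling to the FP8 E4M3 range (largest finite value 448); DeepEP's actual quantization scheme and scale granularity may differ.

```python
E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3

def quantize_per_token(row):
    """Symmetric per-token scaling into the E4M3 range (sketch only)."""
    amax = max(abs(x) for x in row) or 1.0
    scale = amax / E4M3_MAX
    q = [x / scale for x in row]  # scaled values now fit in [-448, 448]
    return q, scale

def dequantize(q, scale):
    """Invert the scaling on the receiving side."""
    return [x * scale for x in q]
```

The per-token scale travels alongside the quantized payload, which is why FP8 dispatch bandwidth is not exactly half of BF16 in the tables below.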

CUDA Graph Integration

  • Full CUDA Graph compatibility for reduced launch overhead
  • Zero CPU-GPU synchronization requirements
  • Dynamic block count configuration for optimal resource utilization

*RDMA features are currently under final testing and will be released shortly.

📊 Performance Results

B200 Platform

Test Configuration:

  • Device: B200
  • Tokens: 4096
  • Hidden Dimension: 7168
  • TopK: 8
  • Router: Random Uniform
  • Local Experts: 8
  • Ranks: 8

Performance Comparison (Bandwidth in GB/s):

| Implementation | Measurement Type | SM Count | Dispatch (FP8) | Dispatch (BF16) | Combine |
| --- | --- | --- | --- | --- | --- |
| DeepEP | Torch API | 16 | 246 | 348 | 302 |
| DeepEP | Torch API | 24 | 349 | 494 | 420 |
| DeepEP | Torch API | 28 | 397 | 560 | 477 |
| DeepEP | Torch API | 32 | 443 | 619 | 524 |
| DeepEP | Torch API | 36 | 482 | 635 | 549 |
| DeepEP | Torch API | 40 | 519 | 629 | 570 |
| DeepEP | Torch API | 44 | 544 | 640 | 577 |
| DeepEP | Torch API | 48 | 554 | 646 | 586 |
| HybridEP | Torch API | 16 | 409.71 | 535.94 | 530.86 |
| HybridEP | Kernel Only | 16 | 599.27 | 734.95 | 673.84 |

Key Performance Improvements (at 16 SM):

  • Dispatch (FP8):
    • Torch API: 66.5% improvement (409.71 vs 246 GB/s)
    • Kernel Only: 143.6% improvement (599.27 vs 246 GB/s)
  • Dispatch (BF16):
    • Torch API: 54.0% improvement (535.94 vs 348 GB/s)
    • Kernel Only: 111.2% improvement (734.95 vs 348 GB/s)
  • Combine:
    • Torch API: 75.8% improvement (530.86 vs 302 GB/s)
    • Kernel Only: 123.1% improvement (673.84 vs 302 GB/s)

GB200 Platform

Test Configuration:

  • Device: GB200
  • Tokens: 4096
  • Hidden Dimension: 7168
  • TopK: 8
  • Router: Random Uniform
  • Local Experts: 8
  • SM Count: 16/32
  • Ranks: 8/16/24/32/36

Note: All bandwidth values represent algorithm bandwidth.
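
As a rough illustration of how such a figure is derived: algorithm bandwidth is payload bytes moved divided by elapsed time. The sketch below uses one plausible accounting for dispatch (tokens × hidden dimension × element size × top-k); the exact formula DeepEP uses may differ, so treat the parameter names and the 1 ms timing as illustrative assumptions.

```python
# Hedged sketch of an "algorithm bandwidth" calculation for MoE dispatch.
# The payload accounting and parameter names are assumptions, not DeepEP's
# actual benchmark code.

def dispatch_algo_bandwidth_gbps(num_tokens, hidden, bytes_per_elem, topk, elapsed_s):
    payload_bytes = num_tokens * hidden * bytes_per_elem * topk
    return payload_bytes / elapsed_s / 1e9

# Using the test configuration above (4096 tokens, hidden 7168, BF16 = 2 bytes,
# top-8) and a hypothetical 1 ms elapsed time:
bw = dispatch_algo_bandwidth_gbps(4096, 7168, 2, 8, 1e-3)
```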

HybridEP Performance Results (Bandwidth in GB/s):

| Ranks | SM Count | Torch API Dispatch (FP8) | Torch API Dispatch (BF16) | Torch API Combine | Kernel Only Dispatch (FP8) | Kernel Only Dispatch (BF16) | Kernel Only Combine |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 8 | 16 | 421.67 | 550.10 | 538.44 | 620.98 | 750.15 | 684.27 |
| 8 | 32 | 455.35 | 545.71 | 568.94 | 713.98 | 764.03 | 737.13 |
| 16 | 16 | 397.33 | 472.84 | 474.48 | 577.17 | 661.93 | 600.75 |
| 16 | 32 | 444.67 | 523.48 | 521.55 | 650.48 | 706.95 | 666.26 |
| 24 | 16 | 281.73 | 441.89 | 444.40 | 360.12 | 637.80 | 565.53 |
| 24 | 32 | 403.20 | 507.32 | 483.76 | 577.96 | 665.97 | 639.80 |
| 32 | 16 | 236.33 | 485.50 | 423.19 | 286.93 | 629.79 | 547.25 |
| 32 | 32 | 392.70 | 484.22 | 464.54 | 538.86 | 642.23 | 605.15 |
| 36 | 16 | 215.36 | 469.96 | 418.27 | 260.53 | 612.85 | 543.27 |
| 36 | 32 | 361.13 | 479.02 | 447.89 | 489.27 | 632.31 | 596.99 |

DeepEP Performance Results (Bandwidth in GB/s):

| Ranks | SM Count | Torch API Dispatch (FP8) | Torch API Dispatch (BF16) | Torch API Combine |
| --- | --- | --- | --- | --- |
| 8 | 16 | 248.86 | 362.01 | 310.21 |
| 8 | 24 | 350.97 | 512.72 | 425.95 |
| 8 | 32 | 447.76 | 615.78 | 519.57 |
| 16 | 16 | 242.51 | 328.80 | 278.34 |
| 16 | 24 | 338.87 | 442.47 | 378.32 |
| 16 | 32 | 393.72 | 520.76 | 442.51 |
| 24 | 16 | 258.33 | 324.64 | 126.53 |
| 24 | 24 | 351.05 | 450.22 | 163.62 |
| 24 | 32 | 405.04 | 502.84 | 207.10 |

GB200 Performance Highlights:

  • Best HybridEP Performance: 764.03 GB/s (Dispatch BF16, 8 ranks, 32 SM, kernel only)
  • Scalability: bandwidth stays high as the rank count grows; kernel-only Dispatch (BF16) remains above 630 GB/s from 8 through 36 ranks at 32 SM

🏛️ Code Structure

New Files

```
csrc/
├── hybrid_ep.cu                      # Main CUDA implementation
├── hybrid_ep.cuh                     # Header definitions
└── kernels/
    ├── hybrid_ep_backend.cuh         # Backend core implementation
    └── hybrid_ep_backend_configs.hpp # Configuration parameters

deep_ep/
├── hybrid_ep_buffer.py               # Python interface
└── buffer.py                         # Buffer management

tests/
└── test_mnnvlink_hybridep.py         # Multi-node NVLink and intra-node tests
```

Build Instructions

Follow the same build process as the main branch. No additional dependencies required.

🚀 Usage Guide

Quick Start

Refer to tests/test_mnnvlink_hybridep.py for comprehensive usage examples including:

  • Multi-node NVLink configuration
  • Intra-node testing scenarios
  • Performance benchmarking setups

Important Configuration Note

Current Limitation: Due to template-based optimization, parameters in the Python test file must match those defined in csrc/kernels/hybrid_ep_backend_configs.hpp. After modifying the header file, recompilation and reinstallation are required.

Future Enhancement: We plan to implement Just-In-Time (JIT) compilation to eliminate this manual configuration requirement and improve developer experience.
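
Until JIT compilation lands, a small helper can catch Python/C++ configuration drift at test time. The sketch below parses `constexpr` constants from a header and compares them with the Python-side values. The constant names (`HIDDEN_DIM`, `TOPK`) and the `constexpr` format are illustrative assumptions, not the actual contents of `hybrid_ep_backend_configs.hpp`.

```python
import re

def read_cpp_constants(header_text):
    """Extract `constexpr <type> NAME = VALUE;` integer constants from a header."""
    pattern = re.compile(r"constexpr\s+\w+\s+(\w+)\s*=\s*(\d+)\s*;")
    return {name: int(val) for name, val in pattern.findall(header_text)}

# Hypothetical header contents; in practice you would read the real file, e.g.
# open("csrc/kernels/hybrid_ep_backend_configs.hpp").read()
header = """
constexpr int HIDDEN_DIM = 7168;
constexpr int TOPK = 8;
"""
cfg = read_cpp_constants(header)
```

A test could then assert `cfg["HIDDEN_DIM"]` matches the value the Python benchmark passes, failing fast instead of producing wrong results after a silent mismatch.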


📋 Implementation Status & Roadmap

✅ Current Features

  • Full compatibility with existing DeepEP codebase
  • Optimized intra-node communication via NVLink
  • Support for BF16 and FP8 data types
  • CUDA Graph integration
  • Comprehensive performance improvements

🚧 Upcoming Features

  • Low Latency Mode: Enhanced performance for latency-critical workloads
  • RDMA Integration: High-performance inter-node communication
  • JIT Compilation: Dynamic parameter configuration without recompilation

⚠️ Current Limitations

  • Template-based implementation requires recompilation for parameter changes
  • RDMA functionality not yet available (under final testing)
  • Configuration parameters must be manually synchronized between Python and C++ files

🎯 Migration Notes

This implementation maintains full backward compatibility with DeepEP. Users can seamlessly integrate Hybrid-EP into existing workflows without code modifications.

@wangdong1991

The main branch of DeepEP currently does not support mnnvl. The data for DeepEP mnnvl was tested using #218.

@LyricZhao

Thanks so much!

@jershi425 jershi425 changed the base branch from main to hybrid-ep September 22, 2025 03:24
@jershi425 jershi425 merged commit c9f647d into deepseek-ai:hybrid-ep Sep 22, 2025
@jershi425

Merged. Thanks @LyricZhao !

@robertgshaw2-redhat

Awesome work!
