Add scalable sparse map kernel for large MoE models (128+ experts) #375
base: add-gpt-oss-converter
Conversation
This commit adds a new Triton kernel implementation that supports large numbers of experts (e.g., 128 for GPT-OSS) without hitting register pressure limits.

## Problem

The original sparse_map_kernel uses one-hot encoding which creates large intermediate tensors (block_size × num_experts). For 128 experts with block_size=1024, this creates 131K register elements, exceeding Triton's limits and causing failures for num_experts > 32.

## Solution

Multi-pass atomic kernel approach:

- Pass 1: Histogram using atomic operations (no large matrices)
- Pass 2a: Atomic assignment (fast, for ≤128 experts)
- Pass 2b: Chunked assignment (memory-efficient, for >128 experts)

## Key Benefits

- Supports arbitrary numbers of experts (tested up to 256)
- No register pressure issues
- Two variants: atomic (fastest) and chunked (most scalable)
- Matches original kernel for small num_experts
- Matches PyTorch fallback for large num_experts

## Testing

Comprehensive test suite validates:

- Correctness vs. original kernel (num_experts ≤ 32)
- Correctness vs. PyTorch fallback (num_experts > 32)
- GPT-OSS specific config (128 experts, 4 active)
- Edge cases and determinism

Run with: pytest tests/functional/test_sparse_map_scalable.py

🤖 Generated with Claude Code
This commit integrates the scalable sparse map kernel into the existing get_sparse_map() function with automatic kernel selection based on num_experts.

## Changes

### Automatic Kernel Selection

Modified get_sparse_map() to automatically choose:

- Original one-hot kernel: num_experts ≤ 64 (fastest)
- Scalable atomic kernel: 64 < num_experts ≤ 128 (GPT-OSS)
- Scalable chunked kernel: num_experts > 128 (maximum scalability)

### Integration Points

1. fast_llm/functional/triton/sparse_copy.py:
   - Added max_experts_onehot parameter (default 64)
   - Lazy import of scalable kernel when needed
   - Clear documentation of kernel selection logic
2. tests/functional/test_sparse_map_scalable.py:
   - Added test_automatic_kernel_selection()
   - Validates kernel selection works correctly
   - Verifies results match PyTorch reference
3. tests/utils/model_configs.py:
   - Added comment noting GPT-OSS uses 128 experts
   - Test config uses 4 experts (validates original kernel still works)

## Usage

No code changes needed! Existing MoE layers automatically benefit:

```python
# In MixtureOfExpertMLP._forward_dropless()
sparse_map = get_sparse_map(
    top_experts,
    self._config.experts,  # Automatically selects correct kernel
)
```

For GPT-OSS with 128 experts, this now uses the scalable atomic kernel instead of failing or falling back to the slow looped implementation.

🤖 Generated with Claude Code
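A hedged sketch of the selection thresholds described above (the real get_sparse_map() takes more parameters; the helper name below is hypothetical and only illustrates the dispatch):

```python
def _select_kernel(num_experts: int, max_experts_onehot: int = 64) -> str:
    # Thresholds follow the commit message above; the helper name is hypothetical.
    if num_experts <= max_experts_onehot:
        return "onehot"   # original one-hot kernel, fastest for small expert counts
    if num_experts <= 128:
        return "atomic"   # scalable atomic kernel (covers GPT-OSS's 128 experts)
    return "chunked"      # chunked kernel for maximum scalability
```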
Tests should be run via pytest, not as scripts. 🤖 Generated with Claude Code
Adds a gpt_oss_128_experts test configuration to validate the scalable MoE kernel with the actual GPT-OSS expert count.

Configuration:

- 128 experts (same as GPT-OSS 120B)
- Tiny experts (intermediate_size=256) for fast tests
- 2 blocks (reduced from 4) to minimize test time
- Tests basic, checkpoint, convert, and distributed groups

This ensures the scalable kernel works end-to-end in the full training/evaluation pipeline, not just in unit tests.

🤖 Generated with Claude Code
Thanks for the work. Have you tried it for a small number of experts? The existing kernel has disappointing performance, so I'm not sure we still need it.
)

# Compute padded counts and offsets (on CPU is fine, small tensor)
expert_counts_cpu = expert_counts.cpu()
Better to leave it on the GPU. Moving to CPU is a lot slower and needs a CUDA sync (really bad). You can use a @torch.compile function to avoid launching lots of kernels.
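A minimal sketch of this suggestion (function name and exact rounding formula are assumptions): keep the small tensor on the GPU and let torch.compile fuse the pointwise ops into a few kernels, avoiding the CPU round-trip entirely.

```python
import torch

@torch.compile
def _padded_counts_and_offsets(expert_counts: torch.Tensor, pad_to_multiple: int):
    # Round each expert's token count up to a multiple of the block row size.
    padded = (expert_counts + pad_to_multiple - 1) // pad_to_multiple * pad_to_multiple
    # Exclusive prefix sum gives each expert's starting row offset.
    offsets = torch.cumsum(padded, 0) - padded
    return padded, offsets
```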
pad_to_multiple: int = MAX_DROPLESS_BLOCK_SIZE_ROW,
block_size=TritonConfig.POINTWISE_BLOCK_SIZE,
use_triton: bool | None = None,
max_experts_onehot: int = 64,
Global constant?
Problem

The original sparse_map_kernel uses one-hot encoding to perform histogram and sorting operations. For GPT-OSS with 128 experts and block_size=1024, this creates a [1024, 128] intermediate tensor requiring 131,072 register elements. This exceeds Triton's register pressure limits, causing the kernel to fail for num_experts > 32.
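To make the register blow-up concrete, here is an illustrative single-block Triton sketch of the one-hot pattern (not the repository's actual sparse_map_kernel; the name and signature are hypothetical):

```python
import triton
import triton.language as tl

@triton.jit
def onehot_histogram_sketch(expert_ids_ptr, counts_ptr,
                            BLOCK_SIZE: tl.constexpr, NUM_EXPERTS: tl.constexpr):
    # Single-block illustration: load one block of routed expert ids.
    offsets = tl.arange(0, BLOCK_SIZE)
    experts = tl.load(expert_ids_ptr + offsets)
    # The [BLOCK_SIZE, NUM_EXPERTS] comparison tile lives in registers:
    # 1024 x 128 = 131,072 elements for GPT-OSS, which is where it fails.
    one_hot = (experts[:, None] == tl.arange(0, NUM_EXPERTS)[None, :]).to(tl.int32)
    tl.store(counts_ptr + tl.arange(0, NUM_EXPERTS), tl.sum(one_hot, 0))
```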
Solution

This PR introduces a new multi-pass atomic kernel that avoids large intermediate tensors:
Pass 1: Histogram with Atomics
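An illustrative sketch of the idea behind this pass (the PR's actual kernel is sparse_map_histogram_kernel; the name and signature below are hypothetical): each program instance bumps per-expert counters with tl.atomic_add, so no block_size × num_experts tile is ever materialized.

```python
import triton
import triton.language as tl

@triton.jit
def atomic_histogram_sketch(expert_ids_ptr, counts_ptr, num_tokens,
                            BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_tokens
    experts = tl.load(expert_ids_ptr + offsets, mask=mask, other=0)
    # counts_ptr must point to a zero-initialized [num_experts] buffer;
    # one atomic increment per token replaces the one-hot reduction.
    tl.atomic_add(counts_ptr + experts, 1, mask=mask)
```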
Pass 2: Two Assignment Variants
Variant A: Atomic Assignment (faster, recommended for ≤128 experts)
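A sketch of the atomic-assignment idea (the PR's kernel is sparse_map_assign_kernel; buffer names here are assumptions): each token atomically reserves the next free slot in its expert's contiguous output range.

```python
import triton
import triton.language as tl

@triton.jit
def atomic_assign_sketch(expert_ids_ptr, expert_begins_ptr, cursors_ptr,
                         dest_rows_ptr, num_tokens, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_tokens
    experts = tl.load(expert_ids_ptr + offsets, mask=mask, other=0)
    begins = tl.load(expert_begins_ptr + experts, mask=mask, other=0)
    # tl.atomic_add returns the pre-increment value, so each token gets a
    # unique slot within its expert's row range (order is non-deterministic).
    slots = tl.atomic_add(cursors_ptr + experts, 1, mask=mask)
    tl.store(dest_rows_ptr + offsets, begins + slots, mask=mask)
```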
Variant B: Chunked Assignment (more scalable for >128 experts)
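And an eager-PyTorch reference for what the chunked variant computes (the real kernel, sparse_map_assign_chunked_kernel, does this on-device; the chunk size and names are assumptions): experts are processed a chunk at a time, so intermediate state is bounded by the chunk width rather than by num_experts.

```python
import torch

def chunked_assign_reference(expert_ids, expert_begins, num_experts, chunk=32):
    dest_rows = torch.empty_like(expert_ids)
    for start in range(0, num_experts, chunk):
        ids = torch.arange(start, min(start + chunk, num_experts),
                           device=expert_ids.device)
        one_hot = expert_ids[:, None] == ids[None, :]  # [tokens, chunk], not [tokens, E]
        rank = one_hot.cumsum(0) - 1                   # position within each expert
        token, col = one_hot.nonzero(as_tuple=True)
        dest_rows[token] = expert_begins[ids[col]] + rank[token, col]
    return dest_rows
```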
Performance
Testing
Comprehensive test suite validates correctness:
1. Small Experts (≤32): Match Original Kernel
2. Large Experts (>32): Match PyTorch Fallback
3. GPT-OSS Specific Config
4. Correctness Validation
For every test, we verify:
Run tests with:

```
pytest tests/functional/test_sparse_map_scalable.py
```
Files Changed

New Files

fast_llm/functional/triton/sparse_copy_scalable.py: New kernel implementation
- sparse_map_histogram_kernel: Atomic histogram computation
- sparse_map_assign_kernel: Atomic index assignment
- sparse_map_assign_chunked_kernel: Chunked index assignment
- get_sparse_map_scalable(): Main API function

tests/functional/test_sparse_map_scalable.py: Comprehensive test suite

Usage
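A minimal example of calling the new entry point directly; the exact signature of get_sparse_map_scalable() is an assumption here, mirroring get_sparse_map() from the integration commit above:

```python
import torch
from fast_llm.functional.triton.sparse_copy_scalable import get_sparse_map_scalable

# GPT-OSS routing shape: 1024 tokens, top-4 of 128 experts.
top_experts = torch.randint(0, 128, (1024, 4), device="cuda")
sparse_map = get_sparse_map_scalable(top_experts, 128)  # signature assumed
```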
Next Steps

After this PR is merged, a follow-up PR will:
- Integrate the scalable kernel into get_sparse_map()
- Update MixtureOfExpertMLP to use the scalable kernel for num_experts > 32

Related

This PR builds on #374 (GPT-OSS converter), which requires support for 128 experts.
🤖 Generated with Claude Code