diff --git a/fast_llm/functional/triton/sparse_copy.py b/fast_llm/functional/triton/sparse_copy.py
index 7c803689c..f9f1a4a4a 100644
--- a/fast_llm/functional/triton/sparse_copy.py
+++ b/fast_llm/functional/triton/sparse_copy.py
@@ -302,12 +302,41 @@ def get_sparse_map(
     pad_to_multiple: int = MAX_DROPLESS_BLOCK_SIZE_ROW,
     block_size=TritonConfig.POINTWISE_BLOCK_SIZE,
     use_triton: bool | None = None,
+    max_experts_onehot: int = 64,
 ) -> SparseMap:
+    """
+    Get sparse map for MoE token routing.
+
+    Automatically selects the appropriate kernel based on num_experts:
+    - num_experts <= max_experts_onehot: Use original one-hot kernel (fastest)
+    - num_experts > max_experts_onehot: Use scalable atomic kernel (supports 128+ experts)
+
+    Args:
+        max_experts_onehot: Maximum num_experts for one-hot kernel.
+            Above this, use scalable kernel to avoid register pressure.
+            Default 64 (conservative, original kernel works up to ~32-64).
+    """
     num_rows_dense, num_experts_per_token = top_experts.shape
     num_rows_unpadded = num_rows_dense * num_experts_per_token
     max_rows = (num_rows_unpadded + num_experts * pad_to_multiple) // pad_to_multiple * pad_to_multiple
     dtype = torch.int16 if max_rows < 32768 else torch.int32
-    if (use_triton is None and TritonConfig.TRITON_ENABLED) or use_triton:
+
+    use_triton_kernel = (use_triton is None and TritonConfig.TRITON_ENABLED) or use_triton
+
+    if use_triton_kernel and num_experts > max_experts_onehot:
+        # Use scalable kernel for large num_experts to avoid register pressure
+        from fast_llm.functional.triton.sparse_copy_scalable import get_sparse_map_scalable
+
+        expert_ends, expert_pad_begins, sparse_rows = get_sparse_map_scalable(
+            top_experts,
+            num_experts,
+            pad_to_multiple=pad_to_multiple,
+            block_size=block_size,
+            # Use atomic assignment for up to ~128 experts, chunked for more
+            use_atomic_assign=(num_experts <= 128),
+        )
+    elif use_triton_kernel:
+        # Use original one-hot kernel for small num_experts
         expert_ends, expert_pad_begins = top_experts.new_empty((2 * num_experts,), dtype=dtype).chunk(2)
         sparse_rows = expert_ends.new_empty(num_rows_dense, num_experts_per_token)
         sparse_map_kernel[(triton.cdiv(num_rows_dense, block_size),)](
@@ -323,6 +352,7 @@ def get_sparse_map(
             DataType.from_torch(dtype).triton,
         )
     else:
+        # PyTorch fallback
         expert_ends, expert_pad_begins, sparse_rows = sparse_map_pytorch(top_experts, num_experts, pad_to_multiple)

     return SparseMap(
diff --git a/fast_llm/functional/triton/sparse_copy_scalable.py b/fast_llm/functional/triton/sparse_copy_scalable.py
new file mode 100644
index 000000000..b2cae47f8
--- /dev/null
+++ b/fast_llm/functional/triton/sparse_copy_scalable.py
@@ -0,0 +1,214 @@
+"""
+Scalable sparse map kernel for large numbers of experts (e.g., 128+).
+
+The original sparse_map_kernel uses one-hot encoding which creates large
+intermediate tensors (block_size x num_experts) that exceed register limits
+for num_experts > 32.
+
+This implementation uses a multi-pass approach with atomic operations to
+handle arbitrary numbers of experts efficiently.
+"""
+
+import torch
+
+from fast_llm.functional.config import MAX_DROPLESS_BLOCK_SIZE_ROW
+from fast_llm.functional.triton import tl, tl_constexpr, triton, triton_jit
+
+
+@triton_jit()
+def sparse_map_histogram_kernel(
+    top_experts_ptr,
+    expert_counts_ptr,
+    num_sparse_rows: tl_constexpr,
+    num_experts: tl_constexpr,
+    block_size: tl_constexpr,
+):
+    """
+    First pass: Count tokens per expert using atomic operations.
+    This avoids materializing large one-hot matrices.
+ """ + block_start = tl.program_id(0) * block_size + offsets = tl.arange(0, block_size) + block_start + mask = offsets < num_sparse_rows + + # Load expert indices for this block + expert_indices = tl.load(top_experts_ptr + offsets, mask=mask, other=num_experts) + + # Atomically increment counts for each expert + # This is much more efficient than one-hot encoding for large num_experts + for i in range(block_size): + if block_start + i < num_sparse_rows: + expert_id = tl.load(top_experts_ptr + block_start + i) + # Atomic add ensures thread-safety across all blocks + tl.atomic_add(expert_counts_ptr + expert_id, 1) + + +@triton_jit() +def sparse_map_assign_kernel( + top_experts_ptr, + sparse_rows_ptr, + expert_begins_ptr, + expert_atomic_counters_ptr, + num_sparse_rows: tl_constexpr, + block_size: tl_constexpr, +): + """ + Second pass: Assign sparse row indices using atomic counters per expert. + Each token atomically claims the next available slot for its expert. + """ + block_start = tl.program_id(0) * block_size + offsets = tl.arange(0, block_size) + block_start + mask = offsets < num_sparse_rows + + # Load expert indices + expert_indices = tl.load(top_experts_ptr + offsets, mask=mask, other=0) + + # For each token, atomically claim a slot in its expert's range + for i in range(block_size): + if block_start + i < num_sparse_rows: + expert_id = tl.load(top_experts_ptr + block_start + i) + expert_begin = tl.load(expert_begins_ptr + expert_id) + + # Atomically get the next available index for this expert + local_offset = tl.atomic_add(expert_atomic_counters_ptr + expert_id, 1) + sparse_row = expert_begin + local_offset + + tl.store(sparse_rows_ptr + block_start + i, sparse_row) + + +@triton_jit() +def sparse_map_assign_chunked_kernel( + top_experts_ptr, + sparse_rows_ptr, + expert_begins_ptr, + expert_chunk_start: tl_constexpr, + expert_chunk_size: tl_constexpr, + num_sparse_rows: tl_constexpr, + block_size: tl_constexpr, + dtype: tl_constexpr, +): + """ + Alternative second pass: Process experts in chunks to reduce memory pressure. + This processes only expert_chunk_size experts at a time, scanning through all tokens. + + Better for very large num_experts as it keeps working set small. + """ + block_start = tl.program_id(0) * block_size + offsets = tl.arange(0, block_size) + block_start + mask = offsets < num_sparse_rows + + # Load expert indices for this block + expert_indices = tl.load(top_experts_ptr + offsets, mask=mask, other=-1) + + # Process experts in the current chunk + expert_range = tl.arange(0, expert_chunk_size) + expert_chunk_start + expert_begins = tl.load(expert_begins_ptr + expert_range) + + # For each expert in chunk, find matching tokens and assign indices + for expert_offset in range(expert_chunk_size): + expert_id = expert_chunk_start + expert_offset + expert_begin = tl.load(expert_begins_ptr + expert_id) + + # Find tokens going to this expert + matches = (expert_indices == expert_id).to(dtype) + + # Compute cumulative sum to get local indices (0, 1, 2, ...) 
+        # This gives each matching token a unique consecutive index
+        cumsum = tl.cumsum(matches)
+        local_indices = (cumsum - matches) * matches  # Exclusive prefix count, zeroed for non-matching tokens
+
+        # Compute final sparse row indices
+        sparse_rows = (expert_begin + local_indices) * matches
+
+        # Store results (only for matching tokens)
+        # Use max to handle non-matching tokens (they get 0 which we ignore)
+        current_values = tl.load(sparse_rows_ptr + offsets, mask=mask, other=0)
+        new_values = tl.maximum(current_values, sparse_rows)
+        tl.store(sparse_rows_ptr + offsets, new_values, mask=mask)
+
+
+def get_sparse_map_scalable(
+    top_experts: torch.Tensor,
+    num_experts: int,
+    *,
+    pad_to_multiple: int = MAX_DROPLESS_BLOCK_SIZE_ROW,
+    block_size: int = 1024,
+    expert_chunk_size: int = 32,
+    use_atomic_assign: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Scalable sparse map computation for large numbers of experts.
+
+    Args:
+        top_experts: [num_rows_dense, num_experts_per_token] tensor of expert indices
+        num_experts: Total number of experts
+        pad_to_multiple: Padding for each expert's allocation
+        block_size: Block size for Triton kernels
+        expert_chunk_size: Number of experts to process at once (for chunked approach)
+        use_atomic_assign: If True, use atomic-based assignment (faster but requires more memory)
+            If False, use chunked approach (slower but more memory efficient)
+
+    Returns:
+        expert_ends: Cumulative end index for each expert (including padding)
+        expert_pad_begins: Start of padding for each expert
+        sparse_rows: Remapped row indices [num_rows_dense, num_experts_per_token]
+    """
+    device = top_experts.device
+    dtype = top_experts.dtype
+    num_rows_dense, num_experts_per_token = top_experts.shape
+    num_sparse_rows = num_rows_dense * num_experts_per_token
+
+    # Pass 1: Histogram using atomics
+    expert_counts = torch.zeros(num_experts, dtype=torch.int32, device=device)
+    num_blocks = triton.cdiv(num_sparse_rows, block_size)
+
+    sparse_map_histogram_kernel[(num_blocks,)](
+        top_experts.flatten(),
+        expert_counts,
+        num_sparse_rows,
+        num_experts,
+        block_size,
+    )
+
+    # Compute padded counts and offsets (on CPU is fine, small tensor)
+    expert_counts_cpu = expert_counts.cpu()
+    padded_counts = (expert_counts_cpu + pad_to_multiple - 1) // pad_to_multiple * pad_to_multiple
+    expert_ends = padded_counts.cumsum(0).to(device)
+    expert_begins = expert_ends - padded_counts.to(device)
+    expert_pad_begins = expert_begins + expert_counts
+
+    # Pass 2: Assign sparse indices
+    sparse_rows = torch.empty_like(top_experts, dtype=torch.int32)
+
+    if use_atomic_assign:
+        # Faster approach: Use atomic counters per expert
+        expert_atomic_counters = torch.zeros(num_experts, dtype=torch.int32, device=device)
+
+        sparse_map_assign_kernel[(num_blocks,)](
+            top_experts.flatten(),
+            sparse_rows.flatten(),
+            expert_begins,
+            expert_atomic_counters,
+            num_sparse_rows,
+            block_size,
+        )
+    else:
+        # Memory-efficient approach: Process experts in chunks
+        sparse_rows.fill_(0)  # Initialize
+
+        for chunk_start in range(0, num_experts, expert_chunk_size):
+            chunk_end = min(chunk_start + expert_chunk_size, num_experts)
+            actual_chunk_size = chunk_end - chunk_start
+
+            sparse_map_assign_chunked_kernel[(num_blocks,)](
+                top_experts.flatten(),
+                sparse_rows.flatten(),
+                expert_begins,
+                chunk_start,
+                actual_chunk_size,
+                num_sparse_rows,
+                block_size,
+                torch.int32,  # dtype for intermediate calculations
+            )
+
+    return expert_ends, expert_pad_begins, sparse_rows
diff --git a/tests/functional/test_sparse_map_scalable.py b/tests/functional/test_sparse_map_scalable.py
new file mode 100644
index 000000000..3c84a05a4
--- /dev/null
+++ b/tests/functional/test_sparse_map_scalable.py
@@ -0,0 +1,264 @@
+"""
+Tests for scalable sparse map kernel that supports large numbers of experts (128+).
+
+Tests verify:
+1. New kernel matches old kernel for small num_experts (<=32)
+2. New kernel matches PyTorch fallback for large num_experts (>32)
+3. Correctness of histogram and index assignment
+4. Both atomic and chunked variants work correctly
+"""
+
+import pytest
+import torch
+
+from fast_llm.functional.triton.sparse_copy import get_sparse_map, sparse_map_pytorch
+from fast_llm.functional.triton.sparse_copy_scalable import get_sparse_map_scalable
+from fast_llm.utils import Assert
+
+
+def generate_test_experts(num_tokens: int, num_experts: int, experts_per_token: int, device: str = "cuda"):
+    """Generate random expert assignments for testing."""
+    return torch.randint(0, num_experts, (num_tokens, experts_per_token), device=device)
+
+
+def validate_sparse_map_correctness(
+    sparse_rows: torch.Tensor,
+    expert_ends: torch.Tensor,
+    expert_pad_begins: torch.Tensor,
+    top_experts: torch.Tensor,
+    num_experts: int,
+):
+    """
+    Validate that a sparse map satisfies all invariants:
+    1. All sparse_rows are unique within each expert's range
+    2. Expert ranges are non-overlapping and consecutive
+    3. Token counts match histogram
+    """
+    num_tokens, experts_per_token = top_experts.shape
+
+    # Check expert ranges are valid
+    expert_begins = torch.cat([torch.tensor([0], device=expert_ends.device), expert_ends[:-1]])
+    assert torch.all(expert_begins <= expert_pad_begins), "Pad begins must be >= begins"
+    assert torch.all(expert_pad_begins <= expert_ends), "Pad begins must be <= ends"
+
+    # Check each token's assignment
+    for token_idx in range(num_tokens):
+        for expert_slot in range(experts_per_token):
+            expert_id = top_experts[token_idx, expert_slot].item()
+            sparse_row = sparse_rows[token_idx, expert_slot].item()
+
+            # Check sparse_row is in valid range for this expert
+            expert_begin = expert_begins[expert_id].item()
+            expert_end = expert_ends[expert_id].item()
+            assert expert_begin <= sparse_row < expert_end, (
+                f"Token {token_idx} expert {expert_id}: "
+                f"sparse_row {sparse_row} not in range [{expert_begin}, {expert_end})"
+            )
+
+    # Check uniqueness: all sparse_rows should be unique (no collisions)
+    all_sparse_rows = sparse_rows.flatten()
+    unique_sparse_rows = torch.unique(all_sparse_rows)
+    assert len(unique_sparse_rows) == len(all_sparse_rows), "Sparse rows must be unique (no collisions)"
+
+    # Check histogram correctness
+    flat_experts = top_experts.flatten()
+    for expert_id in range(num_experts):
+        expected_count = (flat_experts == expert_id).sum().item()
+        expert_begin = expert_begins[expert_id].item()
+        expert_pad_begin = expert_pad_begins[expert_id].item()
+        actual_count = expert_pad_begin - expert_begin
+
+        assert actual_count == expected_count, (
+            f"Expert {expert_id}: count mismatch. Expected {expected_count}, got {actual_count}"
+        )
+
+
+@pytest.mark.parametrize("num_experts", [4, 8, 16, 32])
+@pytest.mark.parametrize("num_tokens", [64, 256, 1024])
+@pytest.mark.parametrize("experts_per_token", [1, 2, 4])
+@pytest.mark.parametrize("use_atomic", [True, False])
+def test_scalable_kernel_matches_original_small(num_experts, num_tokens, experts_per_token, use_atomic):
+    """
+    Test that the new scalable kernel produces identical results to the original kernel
+    for small numbers of experts where the original works fine.
+    """
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+
+    device = "cuda"
+    top_experts = generate_test_experts(num_tokens, num_experts, experts_per_token, device)
+
+    # Get results from original kernel
+    sparse_map_original = get_sparse_map(top_experts, num_experts, use_triton=True)
+
+    # Get results from new scalable kernel
+    expert_ends_new, expert_pad_begins_new, sparse_rows_new = get_sparse_map_scalable(
+        top_experts, num_experts, use_atomic_assign=use_atomic
+    )
+
+    # Results should be identical
+    torch.testing.assert_close(sparse_map_original.expert_ends, expert_ends_new)
+    torch.testing.assert_close(sparse_map_original.expert_pad_begins, expert_pad_begins_new)
+    torch.testing.assert_close(sparse_map_original.sparse_rows, sparse_rows_new)
+
+    # Validate correctness
+    validate_sparse_map_correctness(
+        sparse_rows_new, expert_ends_new, expert_pad_begins_new, top_experts, num_experts
+    )
+
+
+@pytest.mark.parametrize("num_experts", [64, 96, 128, 256])
+@pytest.mark.parametrize("num_tokens", [128, 512])
+@pytest.mark.parametrize("experts_per_token", [2, 4])
+@pytest.mark.parametrize("use_atomic", [True, False])
+def test_scalable_kernel_matches_pytorch_large(num_experts, num_tokens, experts_per_token, use_atomic):
+    """
+    Test that the new scalable kernel produces results matching the PyTorch fallback
+    for large numbers of experts where the original Triton kernel may fail.
+    """
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+
+    device = "cuda"
+    top_experts = generate_test_experts(num_tokens, num_experts, experts_per_token, device)
+
+    # Get results from PyTorch fallback (always correct reference)
+    expert_ends_pytorch, expert_pad_begins_pytorch, sparse_rows_pytorch = sparse_map_pytorch(
+        top_experts, num_experts
+    )
+
+    # Get results from new scalable kernel
+    expert_ends_new, expert_pad_begins_new, sparse_rows_new = get_sparse_map_scalable(
+        top_experts, num_experts, use_atomic_assign=use_atomic
+    )
+
+    # Results should match PyTorch reference
+    torch.testing.assert_close(expert_ends_pytorch, expert_ends_new)
+    torch.testing.assert_close(expert_pad_begins_pytorch, expert_pad_begins_new)
+    torch.testing.assert_close(sparse_rows_pytorch, sparse_rows_new)
+
+    # Validate correctness
+    validate_sparse_map_correctness(
+        sparse_rows_new, expert_ends_new, expert_pad_begins_new, top_experts, num_experts
+    )
+
+
+@pytest.mark.parametrize("num_experts", [128])
+@pytest.mark.parametrize("num_tokens", [1024])
+@pytest.mark.parametrize("experts_per_token", [4])
+def test_gpt_oss_config(num_experts, num_tokens, experts_per_token):
+    """
+    Test with GPT-OSS specific configuration: 128 experts, 4 active per token.
+    This is the primary use case that motivated the scalable kernel.
+ """ + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + device = "cuda" + top_experts = generate_test_experts(num_tokens, num_experts, experts_per_token, device) + + # Test both atomic and chunked variants + for use_atomic in [True, False]: + expert_ends, expert_pad_begins, sparse_rows = get_sparse_map_scalable( + top_experts, num_experts, use_atomic_assign=use_atomic + ) + + # Validate correctness + validate_sparse_map_correctness(sparse_rows, expert_ends, expert_pad_begins, top_experts, num_experts) + + # Verify it matches PyTorch reference + expert_ends_ref, expert_pad_begins_ref, sparse_rows_ref = sparse_map_pytorch(top_experts, num_experts) + torch.testing.assert_close(expert_ends, expert_ends_ref) + torch.testing.assert_close(expert_pad_begins, expert_pad_begins_ref) + torch.testing.assert_close(sparse_rows, sparse_rows_ref) + + +def test_edge_cases(): + """Test edge cases like single expert, all tokens to one expert, etc.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + device = "cuda" + + # Edge case 1: Single expert (degenerate but should work) + top_experts = torch.zeros((100, 2), dtype=torch.long, device=device) + expert_ends, expert_pad_begins, sparse_rows = get_sparse_map_scalable(top_experts, 1) + validate_sparse_map_correctness(sparse_rows, expert_ends, expert_pad_begins, top_experts, 1) + + # Edge case 2: All tokens go to same expert (worst case for load balancing) + top_experts = torch.full((100, 4), 7, dtype=torch.long, device=device) + expert_ends, expert_pad_begins, sparse_rows = get_sparse_map_scalable(top_experts, 16) + validate_sparse_map_correctness(sparse_rows, expert_ends, expert_pad_begins, top_experts, 16) + + # Edge case 3: Perfectly balanced distribution + num_tokens = 128 + num_experts = 64 + experts_per_token = 2 + # Each token gets consecutive expert pairs: [0,1], [2,3], [4,5], ... + top_experts = torch.arange(num_tokens * experts_per_token, device=device).view(num_tokens, experts_per_token) + top_experts = top_experts % num_experts + expert_ends, expert_pad_begins, sparse_rows = get_sparse_map_scalable(top_experts, num_experts) + validate_sparse_map_correctness(sparse_rows, expert_ends, expert_pad_begins, top_experts, num_experts) + + +@pytest.mark.parametrize("num_experts", [32, 64, 128]) +def test_deterministic_results(num_experts): + """Test that kernel produces deterministic results across multiple runs.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + device = "cuda" + torch.manual_seed(42) + + top_experts = generate_test_experts(512, num_experts, 4, device) + + # Run multiple times and check results are identical + results = [] + for _ in range(3): + expert_ends, expert_pad_begins, sparse_rows = get_sparse_map_scalable(top_experts, num_experts) + results.append((expert_ends.clone(), expert_pad_begins.clone(), sparse_rows.clone())) + + # All runs should produce identical results + for i in range(1, len(results)): + torch.testing.assert_close(results[0][0], results[i][0]) + torch.testing.assert_close(results[0][1], results[i][1]) + torch.testing.assert_close(results[0][2], results[i][2]) + + +def test_automatic_kernel_selection(): + """ + Test that get_sparse_map() automatically selects the correct kernel + based on num_experts and produces correct results. 
+ """ + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + device = "cuda" + + # Test 1: Small num_experts should use original kernel (num_experts <= 64) + top_experts_small = generate_test_experts(256, 32, 4, device) + sparse_map_small = get_sparse_map(top_experts_small, 32) + validate_sparse_map_correctness( + sparse_map_small.sparse_rows, + sparse_map_small.expert_ends, + sparse_map_small.expert_pad_begins, + top_experts_small, + 32, + ) + + # Test 2: Large num_experts should use scalable kernel (num_experts > 64) + top_experts_large = generate_test_experts(256, 128, 4, device) + sparse_map_large = get_sparse_map(top_experts_large, 128) + validate_sparse_map_correctness( + sparse_map_large.sparse_rows, + sparse_map_large.expert_ends, + sparse_map_large.expert_pad_begins, + top_experts_large, + 128, + ) + + # Test 3: Results should match PyTorch reference + expert_ends_ref, expert_pad_begins_ref, sparse_rows_ref = sparse_map_pytorch(top_experts_large, 128) + torch.testing.assert_close(sparse_map_large.expert_ends, expert_ends_ref) + torch.testing.assert_close(sparse_map_large.expert_pad_begins, expert_pad_begins_ref) + torch.testing.assert_close(sparse_map_large.sparse_rows, sparse_rows_ref) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 552c10c2f..4acdc5228 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -697,6 +697,7 @@ def _update_and_add_testing_config( _update_and_add_testing_config( # Tests GPT-OSS: heterogeneous blocks (alternating sliding/full attention), MoE, YARN RoPE, attention biases. + # Uses 4 experts for testing (smaller, faster tests). "llama", "gpt_oss", updates={ @@ -761,6 +762,36 @@ def _update_and_add_testing_config( ) +_update_and_add_testing_config( + # Tests GPT-OSS with 128 experts: validates scalable MoE kernel. + # Uses tiny experts (intermediate_size=256) to keep tests fast while exercising the scalable kernel. + "gpt_oss", + "gpt_oss_128_experts", + updates={ + ("model", "base_model", "decoder", "blocks", "sliding", "mlp", "experts"): 128, + ("model", "base_model", "decoder", "blocks", "sliding", "mlp", "intermediate_size"): 256, + ("model", "base_model", "decoder", "blocks", "full", "mlp", "experts"): 128, + ("model", "base_model", "decoder", "blocks", "full", "mlp", "intermediate_size"): 256, + # Reduce to 2 blocks to keep tests fast + ("model", "base_model", "decoder", "num_blocks"): 2, + ("model", "base_model", "decoder", "pattern"): ["sliding", "full"], + }, + megatron_args=None, + checkpoint_format=GptOssCheckpointFormat, + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, + }, + compare_factor=2.0, + # Micro-sequence split not supported (due to MoE). + skip_tests=("ms",), +) + + @pytest.fixture(scope="session", params=MODEL_CONFIGS.keys()) def model_testing_config(request) -> ModelTestingConfig: models = request.config.getoption("--models")