From 1f86a2dd2e6423155a44a604e71e9760defca268 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 1 Apr 2025 19:50:59 +0000 Subject: [PATCH 001/190] moe refactoring Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_moe.py | 4 + .../layers/fused_moe/modular_kernel.py | 99 +++++++++++++++++++ 2 files changed, 103 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/modular_kernel.py diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f6305822c2d..0cc0dccaccb 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1420,6 +1420,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) + if True: + intermediate_cache3 = intermediate_cache3.view(-1, top_k_num, K) + intermediate_cache3.mul_(curr_topk_weights.view(tokens_in_chunk, -1, 1)) + ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx]) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py new file mode 100644 index 00000000000..a688ae41a75 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -0,0 +1,99 @@ +from abc import ABC, abstractmethod +from typing import Optional, Tuple +import torch + + +class FusedMoEDispatchQuantize(ABC): + def __init__(self): + pass + + @abstractmethod + def apply( + self, + hidden_states, + hidden_states_scales, + topk_ids, + num_experts, + expert_map, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + # returns (hidden_states, scales, sorted_token_ids, expert_ids, inv_perm) # make more abstract? + raise NotImplementedError + + +# store weights, etc. here +class FusedMoEExperts(ABC): + def __init__(self): + pass + + @abstractmethod + def apply(self): + raise NotImplementedError + + +class FusedMoEUnpermuteCombine(ABC): + def __init__(self): + pass + + @abstractmethod + def apply( + self, + out, + hidden_states, + topk_weights, + topk, + inv_perm, + ) -> torch>Tensor: + raise NotImplementedError + + +class ModularFusedMoEKernel(torch.nn.Module): # should this be a module? + def __init__( + self, + dispatch: FusedMoEDispatchQuantize, + fused_experts: FusedMoEExperts, + combine: FusedMoEUnpermuteCombine, + ): + self.dispatch = dispatch + self.fused_experts = fused_experts + self.combine = combine + + def forward( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + self.dispatch() + + fused_out = self.fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace, + activation, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + ) + + self.combine(hidden_states, fused_out) + return hidden_states From 6b6630bb401151985acaab76d681a2601f913d86 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 1 Apr 2025 22:06:39 +0000 Subject: [PATCH 002/190] module deepgemm moe working Signed-off-by: Bill Nell --- tests/kernels/quantization/test_block_fp8.py | 16 +- .../layers/fused_moe/deep_gemm_moe.py | 139 ++++++++++++++++++ .../layers/fused_moe/modular_kernel.py | 112 ++++++++++---- 3 files changed, 235 insertions(+), 32 deletions(-) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 38c7e461bb9..44efd48d689 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,7 +11,8 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - deep_gemm_moe_fp8) + deep_gemm_moe_fp8, + modular_deep_gemm_fused_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) @@ -381,12 +382,12 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): block_size = [block_m, block_m] dtype = torch.bfloat16 - # only aligned sizes + # only aligned sizes TODO: use _valid_deep_gemm here instead? if (N % block_m != 0 or K % block_m != 0 or topk > E): pytest.skip( f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") - if N <= 512: + if False and N <= 512: pytest.skip("Skipping N <= 512 until performance issues solved.") vllm_config = VllmConfig() @@ -427,6 +428,13 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) + if True: + dgm = modular_deep_gemm_fused_moe_fp8() + def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids): + return dgm(a, w1, w2, topk_weights, topk_ids, w1_scale=w1_s, w2_scale=w2_s) + else: + deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 + # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): if M >= 128: @@ -439,7 +447,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): topk_weights, topk_ids, token_expert_indices = fused_topk( a, score.float(), topk, False) - out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 353c8cc9d59..1c5212e943e 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -7,6 +7,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, @@ -292,3 +293,141 @@ def deep_gemm_moe_fp8( workspace3.view(*workspace3.shape), inv_perm, curr_topk_weights) return out_hidden_states + + +class DeepGemmDispatch(mk.FusedMoEDispatchQuantize): + def __init__(self): + super().__init__() + import deep_gemm as dg + block_m = dg.get_m_alignment_for_contiguous_layout() + self.block_shape = [block_m, block_m] + + def apply( + self, + hidden_states: torch.Tensor, + hidden_states_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: + q_hidden_states, q_hidden_states_scale = _fp8_quantize( + hidden_states, + hidden_states_scale, + self.block_shape, + ) + + q_hidden_states, q_hidden_states_scale, _, expert_ids, inv_perm = _moe_permute( + q_hidden_states, + q_hidden_states_scale, + topk_ids, + num_experts, + expert_map, + self.block_shape[0], + ) + + return q_hidden_states, q_hidden_states_scale, expert_ids, inv_perm + + +class DeepGemmExperts(mk.FusedMoEExperts): + def __init__(self): + super().__init__() + import deep_gemm as dg + block_m = dg.get_m_alignment_for_contiguous_layout() + self.block_shape = [block_m, block_m] + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + num_experts: int + ) -> Tuple[int, int]: + block_m = self.block_shape[0] + M_sum = (M * topk) + num_experts * (block_m - 1) + M_sum = round_up(M_sum, block_m) + workspace1 = M_sum * max(N, K) + workspace2 = M_sum * (N // 2) + # return tuples???? + return (workspace1, workspace2) + + def apply( + self, + q_hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + inplace: bool, + activation: str, + expert_ids: torch.Tensor, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + ) -> torch.Tensor: # or None? assume inplace? + import deep_gemm as dg + + # chunking in here or in ModularFusedMoEKernel? ignore for now + M_sum = q_hidden_states.shape[0] # double check this + E, N, _ = w1.shape + _, K, _ = w2.shape + + #print(f"M_sum = {M_sum}") + + workspace1 = _resize_cache(workspace13, (M_sum, N)) + workspace2 = _resize_cache(workspace2, (M_sum, N // 2)) + workspace3 = _resize_cache(workspace13, (M_sum, K)) + + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (q_hidden_states, a1_scale), (w1, w1_scale), + workspace1, + expert_ids) + + if activation == "silu": + torch.ops._C.silu_and_mul(workspace2, + workspace1.view(-1, N)) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(workspace2, + workspace1.view(-1, N)) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + + a2q_scale: Optional[torch.Tensor] = None + + qworkspace2, a2q_scale = _fp8_quantize( + workspace2, a2_scale, self.block_shape) + + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (qworkspace2, a2q_scale), (w2, w2_scale), + workspace3, expert_ids) + + return workspace3 + + +class DeepGemmUnpermuteCombine(mk.FusedMoEUnpermuteCombine): + def __init__(self): + super().__init__() + + def apply( + self, + out: torch.Tensor, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + inv_perm: Optional[torch.Tensor], + ) -> torch.Tensor: + _moe_unpermute_and_reduce( + out, + hidden_states, + inv_perm, + topk_weights + ) + return out + + +def modular_deep_gemm_fused_moe_fp8() -> mk.ModularFusedMoEKernel: + return mk.ModularFusedMoEKernel( + DeepGemmDispatch(), + DeepGemmExperts(), + DeepGemmUnpermuteCombine(), + ) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index a688ae41a75..5866129eccb 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -10,13 +10,13 @@ def __init__(self): @abstractmethod def apply( self, - hidden_states, - hidden_states_scales, - topk_ids, - num_experts, - expert_map, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - # returns (hidden_states, scales, sorted_token_ids, expert_ids, inv_perm) # make more abstract? + hidden_states: torch.Tensor, + hidden_states_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: + # returns (hidden_states, scales, expert_ids, inv_perm) # make more abstract? raise NotImplementedError @@ -26,7 +26,32 @@ def __init__(self): pass @abstractmethod - def apply(self): + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + num_experts: int + ) -> Tuple[int, int]: + raise NotImplementedError + + @abstractmethod + def apply( + self, + out: torch.Tensor, + q_hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + activation: str, + expert_ids: torch.Tensor, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + q_hidden_states_scale: Optional[torch.Tensor], + hidden_states_scale_2: Optional[torch.Tensor], + workspace1: torch.Tensor, + workspace2: torch.Tensor, + ) -> torch.Tensor: # or None? assume inplace? raise NotImplementedError @@ -37,12 +62,11 @@ def __init__(self): @abstractmethod def apply( self, - out, - hidden_states, - topk_weights, - topk, - inv_perm, - ) -> torch>Tensor: + out: torch.Tensor, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + inv_perm: Optional[torch.Tensor], + ) -> torch.Tensor: raise NotImplementedError @@ -53,6 +77,7 @@ def __init__( fused_experts: FusedMoEExperts, combine: FusedMoEUnpermuteCombine, ): + super().__init__() self.dispatch = dispatch self.fused_experts = fused_experts self.combine = combine @@ -75,25 +100,56 @@ def forward( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: - self.dispatch() + M, _ = hidden_states.shape + E, N, _ = w1.shape + K = w2.shape[1] + if global_num_experts == -1: + global_num_experts = E + top_k = topk_ids.shape[1] - fused_out = self.fused_experts( + if inplace: + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) + + #print(f"TKN = {topk_ids.numel()} {M*top_k}") + + workspace13_shape, workspace2_shape = self.fused_experts.workspace_shapes(M, N, K, top_k, global_num_experts) + + # We can reuse the memory between cache1 and cache3 because by the time + # we need cache3, we're done with cache1 + workspace13 = torch.empty(workspace13_shape, + device=hidden_states.device, + dtype=hidden_states.dtype) + workspace2 = torch.empty(workspace2_shape, + device=hidden_states.device, + dtype=hidden_states.dtype) + + #print(f"\nbefore M = {hidden_states.shape[0]}") + + hidden_states, a1_scale, expert_ids, inv_perm = self.dispatch.apply( + hidden_states, + a1_scale, + topk_ids, + global_num_experts, + expert_map, + ) + + #print(f"after M = {hidden_states.shape[0]}") + + fused_out = self.fused_experts.apply( hidden_states, w1, w2, - topk_weights, - topk_ids, inplace, activation, - global_num_experts, - expert_map, - w1_scale, - w2_scale, - w1_zp, - w2_zp, - a1_scale, - a2_scale, + expert_ids, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + workspace13=workspace13, + workspace2=workspace2, ) - self.combine(hidden_states, fused_out) - return hidden_states + return self.combine.apply(out_hidden_states, fused_out, topk_weights, inv_perm) From 33c61d66ff5d5f682374fd761331455d634b76c3 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 01:13:38 +0000 Subject: [PATCH 003/190] working deep gemm, wip cutlass Signed-off-by: Bill Nell --- tests/kernels/quantization/test_block_fp8.py | 5 + .../layers/fused_moe/cutlass_moe.py | 180 ++++++++++++++++++ .../layers/fused_moe/deep_gemm_moe.py | 2 +- .../layers/fused_moe/modular_kernel.py | 2 + 4 files changed, 188 insertions(+), 1 deletion(-) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 44efd48d689..176a158493a 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,6 +11,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + _valid_deep_gemm, deep_gemm_moe_fp8, modular_deep_gemm_fused_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk @@ -404,6 +405,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) +# if not _valid_deep_gemm(a, w1_bf16, w2_bf16, None): +# pytest.skip( +# f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") + score = torch.randn((M, E), dtype=dtype) block_n, block_k = block_size[0], block_size[1] diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 960c7f83485..416630be522 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -5,6 +5,9 @@ import torch from vllm import _custom_ops as ops +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.utils import (_resize_cache, + _fp8_perm) #TODO make the grouped gemm kernel consistent with scaled gemm kernel @@ -173,8 +176,185 @@ def cutlass_moe_fp8( ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale, expert_offsets[:-1], problem_sizes2, ab_strides2, ab_strides2, c_strides2) + # Gather tokens c2 = c2[c_map].view(m, topk, k) if not apply_router_weight_on_input: c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype) return c2.sum(dim=1) + + +class CutlassDispatch(mk.FusedMoEDispatchQuantize): + def __init__(self): + super().__init__() + + def apply( + self, + hidden_states: torch.Tensor, + hidden_states_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: + m = hidden_states.size(0) + k = w1_q.size(1) + n = w2_q.size(1) + device = hidden_states.device + + # a2_scale.numel() != 1 if a2_scale is not None else False + per_act_token = hidden_states_scale.numel() != 1 if hidden_states_scale is not None else False + + expert_offsets = torch.empty((num_experts + 1), + dtype=torch.int32, + device=device) + problem_sizes1 = torch.empty((num_experts, 3), + dtype=torch.int32, + device=device) + problem_sizes2 = torch.empty((num_experts, 3), + dtype=torch.int32, + device=device) + + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + + ops.get_cutlass_moe_mm_data(topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + a_map, c_map, + num_experts, + n, + k) + + rep_a_q = _fp8_perm(hidden_states, a_map) + rep_a1_scales = hidden_states_scale[a_map] if per_act_token else hidden_states_scale + + return rep_a_q, rep_a1_scales, expert_offsets, c_map + + +class CutlassExperts(mk.FusedMoEExperts): + def __init__( + self, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + ): + super().__init__() + self.ab_strides1 = ab_strides1 + self.c_strides1 = c_strides1 + self.ab_strides2 = ab_strides2 + self.c_strides2 = c_strides2 + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + num_experts: int + ) -> Tuple[int, int]: + workspace1 = M * topk * N + workspace2 = M * topk * K + # return tuples???? + return (workspace1, workspace2) + + def apply( + self, + q_hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + inplace: bool, + activation: str, + expert_offsets: torch.Tensor, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + ) -> torch.Tensor: # or None? assume inplace? + # chunking in here or in ModularFusedMoEKernel? ignore for now + M = q_hidden_states.shape[0] + E, N, _ = w1.shape + _, K, _ = w2.shape + topk = X + device = q_hidden_states.device + + # fix names + c1 = _resize_cache(workspace13, (M * topk, N)) + c2 = _resize_cache(workspace13, (M * topk, K)) + c3 = _resize_cache(workspace2, (M * topk, N // 2)) + + # HACK, share these with other bits + problem_sizes1 = torch.empty((E, 3), + dtype=torch.int32, + device=device) + problem_sizes2 = torch.empty((E, 3), + dtype=torch.int32, + device=E) + + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + ops.cutlass_moe_mm(c1, q_hidden_states, w1, a1_scale, w1_scale, + expert_offsets[:-1], + problem_sizes1, + self.ab_strides1, + self.ab_strides1, + self.c_strides1) + + if activation == "silu": + torch.ops._C.silu_and_mul(c3, c1) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(c3, c1) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + + intemediate_q, a2_scale = ops.scaled_fp8_quant( + c3, a2_scale, use_per_token_if_dynamic=per_act_token) + + ops.cutlass_moe_mm(c2, intemediate_q, w2, a2_scale, w2_scale, + expert_offsets[:-1], problem_sizes2, self.ab_strides2, + self.ab_strides2, self.c_strides2) + + return c2 + + +class CutlassUnpermuteCombine(mk.FusedMoEUnpermuteCombine): + def __init__(self, out_dtype): + super().__init__() + self.out_dtype = out_dtype + + def apply( + self, + out: torch.Tensor, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + inv_perm: Optional[torch.Tensor], + ) -> torch.Tensor: + M, topk = topk_weights.shape + K = hidden_states.shape[1] + hidden_states = hidden_states[inv_perm, ...] + hidden_states = hidden_states.view(M, topk, K) + out = hidden_states.mul_(topk_weights.view(M, topk, 1).to(self.out_dtype)).sum(dim=1) + return out + + +def modular_cutlass_moe_fp8( + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + out_dtype, +) -> mk.ModularFusedMoEKernel: + return mk.ModularFusedMoEKernel( + CutlassDispatch(), + CutlassExperts( + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + ), + CutlassUnpermuteCombine(out_dtype), + ) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 1c5212e943e..5050e251f54 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -19,7 +19,7 @@ has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None - +# TODO: check types? def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 5866129eccb..fbce6dbb14c 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -70,6 +70,8 @@ def apply( raise NotImplementedError +# Note: only intended for use with a single model layer (due to temp buffers, constants, etc.) +# TODO: permute/unpermute must be paired class ModularFusedMoEKernel(torch.nn.Module): # should this be a module? def __init__( self, From e1ab18a8fb3173480b4b9247e224c121c728ca53 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 13:49:41 +0000 Subject: [PATCH 004/190] working cutlass Signed-off-by: Bill Nell --- tests/kernels/test_cutlass_moe.py | 274 ++++++++++++++++++ .../layers/fused_moe/cutlass_moe.py | 116 ++++---- .../layers/fused_moe/deep_gemm_moe.py | 14 +- .../layers/fused_moe/modular_kernel.py | 16 +- 4 files changed, 359 insertions(+), 61 deletions(-) create mode 100644 tests/kernels/test_cutlass_moe.py diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py new file mode 100644 index 00000000000..d4b62a8c86e --- /dev/null +++ b/tests/kernels/test_cutlass_moe.py @@ -0,0 +1,274 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8, modular_cutlass_moe_fp8 +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, + fused_topk) +from vllm.platforms import current_platform + +NUM_EXPERTS = [40, 64] +TOP_KS = [6, 8] + + +def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, + w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, c_strides2: torch.Tensor): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + return cutlass_moe_fp8(a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale) + + +@pytest.mark.parametrize("m", [2, 64, 224]) +@pytest.mark.parametrize("n", [1024, 3072]) +@pytest.mark.parametrize("k", [1024, 1536]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_moe_no_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + dtype = torch.half + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + # Get the right scale for tests. + _, a_scale1 = ops.scaled_fp8_quant( + a, use_per_token_if_dynamic=per_act_token) + a_q, _ = ops.scaled_fp8_quant(a, + a_scale1, + use_per_token_if_dynamic=per_act_token) + + a_d = a_q.float().mul(a_scale1).to(dtype) + + n_b_scales = 2 * n if per_out_ch else 1 + k_b_scales = k if per_out_ch else 1 + + w1_q = torch.empty((e, 2 * n, k), + device="cuda", + dtype=torch.float8_e4m3fn) + w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) + w1_scale = torch.empty((e, n_b_scales, 1), + device="cuda", + dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), + device="cuda", + dtype=torch.float32) + + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + + for expert in range(e): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( + w1[expert], use_per_token_if_dynamic=per_out_ch) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( + w2[expert], use_per_token_if_dynamic=per_out_ch) + w1_q = w1_q.transpose(1, 2) + w2_q = w2_q.transpose(1, 2) + + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + + w1_d = torch.empty_like(w1) + w2_d = torch.empty_like(w2) + for expert in range(e): + w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() + w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) + + if True: + cutlass_moe_fp8_fn = modular_cutlass_moe_fp8( + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + ) + else: + def cutlass_moe_fp8_fn( + a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + a1_scale=a_scale1 + ): + return cutlass_moe_fp8( + a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale1 + ) + + cutlass_output = cutlass_moe_fp8_fn( + a, + w1_q, + w2_q, + w1_scale=w1_scale, + w2_scale=w2_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + a1_scale=a_scale1) + + #print(triton_output) + #print(cutlass_output) + #print("*") + + torch.testing.assert_close(triton_output, + cutlass_output, + atol=5e-2, + rtol=1e-2) + + +@pytest.mark.parametrize("m", [2, 64, 224]) +@pytest.mark.parametrize("n", [1024, 3072]) +@pytest.mark.parametrize("k", [1024, 1536]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_moe_cuda_graph( + m: int, + n: int, + k: int, + e: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, +): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + dtype = torch.half + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + # Get the right scale for tests. + _, a_scale1 = ops.scaled_fp8_quant( + a, use_per_token_if_dynamic=per_act_token) + a_q, _ = ops.scaled_fp8_quant(a, + a_scale1, + use_per_token_if_dynamic=per_act_token) + + a_d = a_q.float().mul(a_scale1).to(dtype) + + n_b_scales = 2 * n if per_out_ch else 1 + k_b_scales = k if per_out_ch else 1 + + w1_q = torch.empty((e, 2 * n, k), + device="cuda", + dtype=torch.float8_e4m3fn) + w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) + w1_scale = torch.empty((e, n_b_scales, 1), + device="cuda", + dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), + device="cuda", + dtype=torch.float32) + + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + + for expert in range(e): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( + w1[expert], use_per_token_if_dynamic=per_out_ch) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( + w2[expert], use_per_token_if_dynamic=per_out_ch) + w1_q = w1_q.transpose(1, 2) + w2_q = w2_q.transpose(1, 2) + + ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + + w1_d = torch.empty_like(w1) + w2_d = torch.empty_like(w2) + for expert in range(e): + w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() + w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) + + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + cutlass_output = run(a, a_scale1, w1_q, w2_q, w1_scale, w2_scale, + topk_weights, topk_ids, ab_strides1, + c_strides1, ab_strides2, c_strides2) + torch.cuda.synchronize() + graph.replay() + torch.cuda.synchronize() + + #print(triton_output) + #print(cutlass_output) + #print("*") + + torch.testing.assert_close(triton_output, + cutlass_output, + atol=9e-2, + rtol=1e-2) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 416630be522..dafe0d8c601 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -60,7 +60,7 @@ def cutlass_moe_fp8( - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize the intermediate result between the gemms. Shape: scalar or [M] - - out_dtype (torch.Tensor): The output tensor type. + - out_dtype (torch.dtype): The output tensor type. - expert_map (Optional[torch.Tensor]): In the case of Expert parallel, every Rank is responsible for a subset of experts. expert_map is a mapping from global expert-id to local expert-id. When expert_map[i] @@ -190,19 +190,24 @@ def __init__(self): def apply( self, - hidden_states: torch.Tensor, - hidden_states_scale: Optional[torch.Tensor], + a: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], + k: int # Try to get rid of? ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: - m = hidden_states.size(0) - k = w1_q.size(1) - n = w2_q.size(1) - device = hidden_states.device + m, n = a.shape + device = a.device # a2_scale.numel() != 1 if a2_scale is not None else False - per_act_token = hidden_states_scale.numel() != 1 if hidden_states_scale is not None else False + #per_act_token = hidden_states_scale.numel() != 1 if hidden_states_scale is not None else False + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + a_q, a1_scale = ops.scaled_fp8_quant( + a, a1_scale, use_per_token_if_dynamic=per_act_token) expert_offsets = torch.empty((num_experts + 1), dtype=torch.int32, @@ -221,15 +226,16 @@ def apply( expert_offsets, problem_sizes1, problem_sizes2, - a_map, c_map, + a_map, + c_map, num_experts, - n, - k) + k, + n) - rep_a_q = _fp8_perm(hidden_states, a_map) - rep_a1_scales = hidden_states_scale[a_map] if per_act_token else hidden_states_scale + rep_a_q = _fp8_perm(a_q, a_map) + rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale - return rep_a_q, rep_a1_scales, expert_offsets, c_map + return rep_a_q, rep_a1_scales, expert_offsets, c_map, (problem_sizes1, problem_sizes2) class CutlassExperts(mk.FusedMoEExperts): @@ -249,13 +255,13 @@ def __init__( def workspace_shapes( self, M: int, - N: int, K: int, + N: int, topk: int, num_experts: int ) -> Tuple[int, int]: - workspace1 = M * topk * N - workspace2 = M * topk * K + workspace1 = M * topk * max(2 * N, K) + workspace2 = M * topk * N # return tuples???? return (workspace1, workspace2) @@ -273,52 +279,61 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, + context: Optional[Any] = None, ) -> torch.Tensor: # or None? assume inplace? # chunking in here or in ModularFusedMoEKernel? ignore for now M = q_hidden_states.shape[0] - E, N, _ = w1.shape - _, K, _ = w2.shape - topk = X - device = q_hidden_states.device + E, N, K = w2.shape # because w1 + w2 are transposed # fix names - c1 = _resize_cache(workspace13, (M * topk, N)) - c2 = _resize_cache(workspace13, (M * topk, K)) - c3 = _resize_cache(workspace2, (M * topk, N // 2)) - - # HACK, share these with other bits - problem_sizes1 = torch.empty((E, 3), - dtype=torch.int32, - device=device) - problem_sizes2 = torch.empty((E, 3), - dtype=torch.int32, - device=E) + c1 = _resize_cache(workspace13, (M, N * 2)) + c2 = _resize_cache(workspace2, (M, N)) + c3 = _resize_cache(workspace13, (M, K)) + # why check a1_scale again? per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) - ops.cutlass_moe_mm(c1, q_hidden_states, w1, a1_scale, w1_scale, - expert_offsets[:-1], - problem_sizes1, - self.ab_strides1, - self.ab_strides1, - self.c_strides1) + assert context is not None + problem_sizes1, problem_sizes2 = context + + ops.cutlass_moe_mm( + c1, + q_hidden_states, + w1, + a1_scale, + w1_scale, + expert_offsets[:-1], + problem_sizes1, + self.ab_strides1, + self.ab_strides1, + self.c_strides1 + ) if activation == "silu": - torch.ops._C.silu_and_mul(c3, c1) + torch.ops._C.silu_and_mul(c2, c1) elif activation == "gelu": - torch.ops._C.gelu_and_mul(c3, c1) + torch.ops._C.gelu_and_mul(c2, c1) else: raise ValueError(f"Unsupported FusedMoe activation: {activation}") intemediate_q, a2_scale = ops.scaled_fp8_quant( - c3, a2_scale, use_per_token_if_dynamic=per_act_token) + c2, a2_scale, use_per_token_if_dynamic=per_act_token) - ops.cutlass_moe_mm(c2, intemediate_q, w2, a2_scale, w2_scale, - expert_offsets[:-1], problem_sizes2, self.ab_strides2, - self.ab_strides2, self.c_strides2) + ops.cutlass_moe_mm( + c3, + intemediate_q, + w2, + a2_scale, + w2_scale, + expert_offsets[:-1], + problem_sizes2, + self.ab_strides2, + self.ab_strides2, + self.c_strides2 + ) - return c2 + return c3 class CutlassUnpermuteCombine(mk.FusedMoEUnpermuteCombine): @@ -335,10 +350,11 @@ def apply( ) -> torch.Tensor: M, topk = topk_weights.shape K = hidden_states.shape[1] - hidden_states = hidden_states[inv_perm, ...] - hidden_states = hidden_states.view(M, topk, K) - out = hidden_states.mul_(topk_weights.view(M, topk, 1).to(self.out_dtype)).sum(dim=1) - return out + hidden_states = hidden_states[inv_perm, ...].view(-1, topk, K) + hidden_states = (hidden_states * topk_weights.view(M, -1, 1).to(self.out_dtype)).sum(dim=1) + # use moe_sum? to write into out? + return hidden_states + def modular_cutlass_moe_fp8( @@ -346,7 +362,7 @@ def modular_cutlass_moe_fp8( c_strides1: torch.Tensor, ab_strides2: torch.Tensor, c_strides2: torch.Tensor, - out_dtype, + out_dtype: torch.dtype = torch.half, ) -> mk.ModularFusedMoEKernel: return mk.ModularFusedMoEKernel( CutlassDispatch(), diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 5050e251f54..2be26da9fa1 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import torch @@ -306,10 +306,13 @@ def apply( self, hidden_states: torch.Tensor, hidden_states_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: + n: int, # TODO try to get rid of this? + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor], Optional[Any]]: + # TODO: move? q_hidden_states, q_hidden_states_scale = _fp8_quantize( hidden_states, hidden_states_scale, @@ -325,7 +328,7 @@ def apply( self.block_shape[0], ) - return q_hidden_states, q_hidden_states_scale, expert_ids, inv_perm + return q_hidden_states, q_hidden_states_scale, expert_ids, inv_perm, None class DeepGemmExperts(mk.FusedMoEExperts): @@ -346,8 +349,8 @@ def workspace_shapes( block_m = self.block_shape[0] M_sum = (M * topk) + num_experts * (block_m - 1) M_sum = round_up(M_sum, block_m) - workspace1 = M_sum * max(N, K) - workspace2 = M_sum * (N // 2) + workspace1 = M_sum * max(N * 2, K) + workspace2 = M_sum * N # return tuples???? return (workspace1, workspace2) @@ -365,6 +368,7 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, + context: Optional[Any] = None, ) -> torch.Tensor: # or None? assume inplace? import deep_gemm as dg diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index fbce6dbb14c..ed358273bb4 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import torch @@ -12,11 +12,13 @@ def apply( self, hidden_states: torch.Tensor, hidden_states_scale: Optional[torch.Tensor], + a2: Optional[torch.Tensor], topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: - # returns (hidden_states, scales, expert_ids, inv_perm) # make more abstract? + n: int, # TODO try to get rid of this? + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor], Optional[Any]]: + # returns (hidden_states, scales, expert_ids, inv_perm, context) # make more abstract? raise NotImplementedError @@ -103,8 +105,7 @@ def forward( a2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: M, _ = hidden_states.shape - E, N, _ = w1.shape - K = w2.shape[1] + E, K, N = w2.shape if global_num_experts == -1: global_num_experts = E top_k = topk_ids.shape[1] @@ -129,12 +130,14 @@ def forward( #print(f"\nbefore M = {hidden_states.shape[0]}") - hidden_states, a1_scale, expert_ids, inv_perm = self.dispatch.apply( + hidden_states, a1_scale, expert_ids, inv_perm, context = self.dispatch.apply( hidden_states, a1_scale, + a2_scale, topk_ids, global_num_experts, expert_map, + w2.shape[1], ) #print(f"after M = {hidden_states.shape[0]}") @@ -152,6 +155,7 @@ def forward( a2_scale=a2_scale, workspace13=workspace13, workspace2=workspace2, + context=context, ) return self.combine.apply(out_hidden_states, fused_out, topk_weights, inv_perm) From 4965854709dc952ed5a68de6e49971f6e51097d5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 20:33:59 +0000 Subject: [PATCH 005/190] deepgemm working again Signed-off-by: Bill Nell --- .../layers/fused_moe/cutlass_moe.py | 119 +++++++++--------- .../layers/fused_moe/deep_gemm_moe.py | 107 ++++++++-------- .../layers/fused_moe/modular_kernel.py | 101 ++++++++------- 3 files changed, 163 insertions(+), 164 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index dafe0d8c601..9f9d24aa385 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -184,11 +184,12 @@ def cutlass_moe_fp8( return c2.sum(dim=1) -class CutlassDispatch(mk.FusedMoEDispatchQuantize): - def __init__(self): +class CutlassDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): + def __init__(self, out_dtype: torch.dtype): super().__init__() + self.out_dtype = out_dtype - def apply( + def dispatch( self, a: torch.Tensor, a1_scale: Optional[torch.Tensor], @@ -196,31 +197,27 @@ def apply( topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - k: int # Try to get rid of? - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: - m, n = a.shape - device = a.device - - # a2_scale.numel() != 1 if a2_scale is not None else False - #per_act_token = hidden_states_scale.numel() != 1 if hidden_states_scale is not None else False + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + # why do we need to check a2_scale here? per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) a_q, a1_scale = ops.scaled_fp8_quant( a, a1_scale, use_per_token_if_dynamic=per_act_token) - expert_offsets = torch.empty((num_experts + 1), - dtype=torch.int32, - device=device) - problem_sizes1 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) - problem_sizes2 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) + return a_q, a1_scale, topk_ids - a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) - c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + def combine( + self, + out: torch.Tensor, #TBD + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + ) -> torch.Tensor: + M, topk = topk_weights.shape + K = hidden_states.shape[1] + hidden_states = (hidden_states.view(-1, topk, K) * topk_weights.view(M, -1, 1).to(self.out_dtype)).sum(dim=1) + # use moe_sum? to write into out? + return hidden_states ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, @@ -238,7 +235,7 @@ def apply( return rep_a_q, rep_a1_scales, expert_offsets, c_map, (problem_sizes1, problem_sizes2) -class CutlassExperts(mk.FusedMoEExperts): +class CutlassExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, ab_strides1: torch.Tensor, @@ -267,36 +264,64 @@ def workspace_shapes( def apply( self, + out: torch.Tensor, # TBD q_hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_ids: torch.Tensor, inplace: bool, activation: str, - expert_offsets: torch.Tensor, + expert_map: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, - context: Optional[Any] = None, ) -> torch.Tensor: # or None? assume inplace? # chunking in here or in ModularFusedMoEKernel? ignore for now M = q_hidden_states.shape[0] - E, N, K = w2.shape # because w1 + w2 are transposed + E, N, _ = w2.shape # because w1 + w2 are transposed + K = w1.shape[1] #? + assert K == w2.shape[-1] + device = q_hidden_states.device + + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + expert_offsets = torch.empty((E + 1), + dtype=torch.int32, + device=device) + problem_sizes1 = torch.empty((E, 3), + dtype=torch.int32, + device=device) + problem_sizes2 = torch.empty((E, 3), + dtype=torch.int32, + device=device) + + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + + #print(f"prob {k}, {n}") + + ops.get_cutlass_moe_mm_data(topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + a_map, + c_map, + E, + N, + K) + + q_hidden_states = _fp8_perm(q_hidden_states, a_map) + a1_scale = a1_scale[a_map] if per_act_token else a1_scale # fix names c1 = _resize_cache(workspace13, (M, N * 2)) c2 = _resize_cache(workspace2, (M, N)) c3 = _resize_cache(workspace13, (M, K)) - # why check a1_scale again? - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) - - assert context is not None - problem_sizes1, problem_sizes2 = context - ops.cutlass_moe_mm( c1, q_hidden_states, @@ -333,28 +358,9 @@ def apply( self.c_strides2 ) - return c3 - - -class CutlassUnpermuteCombine(mk.FusedMoEUnpermuteCombine): - def __init__(self, out_dtype): - super().__init__() - self.out_dtype = out_dtype - - def apply( - self, - out: torch.Tensor, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - inv_perm: Optional[torch.Tensor], - ) -> torch.Tensor: - M, topk = topk_weights.shape - K = hidden_states.shape[1] - hidden_states = hidden_states[inv_perm, ...].view(-1, topk, K) - hidden_states = (hidden_states * topk_weights.view(M, -1, 1).to(self.out_dtype)).sum(dim=1) - # use moe_sum? to write into out? - return hidden_states + c3 = c3[c_map, ...] + return c3 def modular_cutlass_moe_fp8( @@ -363,14 +369,13 @@ def modular_cutlass_moe_fp8( ab_strides2: torch.Tensor, c_strides2: torch.Tensor, out_dtype: torch.dtype = torch.half, -) -> mk.ModularFusedMoEKernel: - return mk.ModularFusedMoEKernel( - CutlassDispatch(), +) -> mk.FusedMoEModularKernel: + return mk.FusedMoEModularKernel( + CutlassDispatchCombine(out_dtype), CutlassExperts( ab_strides1, c_strides1, ab_strides2, c_strides2, ), - CutlassUnpermuteCombine(out_dtype), ) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 2be26da9fa1..3b39e45c7ea 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util -from typing import Any, Optional, Tuple +from typing import Any, List, Optional, Tuple import torch @@ -19,6 +19,13 @@ has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None + +def deep_gemm_block_shape() -> List[int]: + import deep_gemm as dg + block = dg.get_m_alignment_for_contiguous_layout() + return [block, block] + + # TODO: check types? def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, @@ -109,7 +116,8 @@ def _moe_unpermute_and_reduce( """ M, topk = topk_weight.shape K = curr_hidden.shape[1] - curr_hidden = curr_hidden[inv_perm, ...] + if inv_perm is not None: + curr_hidden = curr_hidden[inv_perm, ...] curr_hidden = curr_hidden.view(-1, topk, K) curr_hidden.mul_(topk_weight.view(M, -1, 1)) ops.moe_sum(curr_hidden, out) @@ -295,48 +303,46 @@ def deep_gemm_moe_fp8( return out_hidden_states -class DeepGemmDispatch(mk.FusedMoEDispatchQuantize): +class DeepGemmDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): def __init__(self): super().__init__() - import deep_gemm as dg - block_m = dg.get_m_alignment_for_contiguous_layout() - self.block_shape = [block_m, block_m] + self.block_shape = deep_gemm_block_shape() - def apply( + def dispatch( self, - hidden_states: torch.Tensor, - hidden_states_scale: Optional[torch.Tensor], + a: torch.Tensor, + a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - n: int, # TODO try to get rid of this? - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor], Optional[Any]]: - # TODO: move? + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: q_hidden_states, q_hidden_states_scale = _fp8_quantize( - hidden_states, - hidden_states_scale, + a, + a1_scale, self.block_shape, ) + return q_hidden_states, q_hidden_states_scale, topk_ids - q_hidden_states, q_hidden_states_scale, _, expert_ids, inv_perm = _moe_permute( - q_hidden_states, - q_hidden_states_scale, - topk_ids, - num_experts, - expert_map, - self.block_shape[0], + def combine( + self, + out: torch.Tensor, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + ) -> torch.Tensor: + _moe_unpermute_and_reduce( + out, + hidden_states, + None, + topk_weights ) - - return q_hidden_states, q_hidden_states_scale, expert_ids, inv_perm, None + return out -class DeepGemmExperts(mk.FusedMoEExperts): +class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self): super().__init__() - import deep_gemm as dg - block_m = dg.get_m_alignment_for_contiguous_layout() - self.block_shape = [block_m, block_m] + self.block_shape = deep_gemm_block_shape() def workspace_shapes( self, @@ -352,33 +358,43 @@ def workspace_shapes( workspace1 = M_sum * max(N * 2, K) workspace2 = M_sum * N # return tuples???? - return (workspace1, workspace2) + return (workspace1, workspace2) # TODO add type def apply( self, + out: torch.Tensor, #unused tbd q_hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_ids: torch.Tensor, inplace: bool, activation: str, - expert_ids: torch.Tensor, + expert_map: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, - context: Optional[Any] = None, ) -> torch.Tensor: # or None? assume inplace? import deep_gemm as dg # chunking in here or in ModularFusedMoEKernel? ignore for now - M_sum = q_hidden_states.shape[0] # double check this E, N, _ = w1.shape _, K, _ = w2.shape #print(f"M_sum = {M_sum}") + q_hidden_states, a1_scale, _, expert_ids, inv_perm = _moe_permute( + q_hidden_states, + a1_scale, + topk_ids, + E, + expert_map, + self.block_shape[0], + ) + + M_sum = q_hidden_states.shape[0] workspace1 = _resize_cache(workspace13, (M_sum, N)) workspace2 = _resize_cache(workspace2, (M_sum, N // 2)) workspace3 = _resize_cache(workspace13, (M_sum, K)) @@ -406,32 +422,13 @@ def apply( (qworkspace2, a2q_scale), (w2, w2_scale), workspace3, expert_ids) - return workspace3 - - -class DeepGemmUnpermuteCombine(mk.FusedMoEUnpermuteCombine): - def __init__(self): - super().__init__() + workspace3 = workspace3[inv_perm, ...] - def apply( - self, - out: torch.Tensor, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - inv_perm: Optional[torch.Tensor], - ) -> torch.Tensor: - _moe_unpermute_and_reduce( - out, - hidden_states, - inv_perm, - topk_weights - ) - return out + return workspace3 -def modular_deep_gemm_fused_moe_fp8() -> mk.ModularFusedMoEKernel: - return mk.ModularFusedMoEKernel( - DeepGemmDispatch(), +def modular_deep_gemm_fused_moe_fp8() -> mk.FusedMoEModularKernel: + return mk.FusedMoEModularKernel( + DeepGemmDispatchCombine(), DeepGemmExperts(), - DeepGemmUnpermuteCombine(), ) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index ed358273bb4..cef11efe22a 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -3,27 +3,36 @@ import torch -class FusedMoEDispatchQuantize(ABC): +class FusedMoEQuantizeDispatchCombine(ABC): def __init__(self): pass @abstractmethod - def apply( + def dispatch( self, - hidden_states: torch.Tensor, - hidden_states_scale: Optional[torch.Tensor], - a2: Optional[torch.Tensor], + a: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - n: int, # TODO try to get rid of this? - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor], Optional[Any]]: - # returns (hidden_states, scales, expert_ids, inv_perm, context) # make more abstract? + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + # TODO: figure this out + # returns (quantized+dispatched hidden_states, quantized+dispatched scales, dispatched topk_ids) + raise NotImplementedError + + @abstractmethod + def combine( + self, + out: torch.Tensor, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + ) -> torch.Tensor: raise NotImplementedError # store weights, etc. here -class FusedMoEExperts(ABC): +class FusedMoEPermuteExpertsUnpermute(ABC): def __init__(self): pass @@ -45,46 +54,31 @@ def apply( q_hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool, activation: str, - expert_ids: torch.Tensor, + expert_map: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], - q_hidden_states_scale: Optional[torch.Tensor], - hidden_states_scale_2: Optional[torch.Tensor], - workspace1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, workspace2: torch.Tensor, ) -> torch.Tensor: # or None? assume inplace? raise NotImplementedError -class FusedMoEUnpermuteCombine(ABC): - def __init__(self): - pass - - @abstractmethod - def apply( - self, - out: torch.Tensor, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - inv_perm: Optional[torch.Tensor], - ) -> torch.Tensor: - raise NotImplementedError - - # Note: only intended for use with a single model layer (due to temp buffers, constants, etc.) # TODO: permute/unpermute must be paired -class ModularFusedMoEKernel(torch.nn.Module): # should this be a module? +class FusedMoEModularKernel(torch.nn.Module): # should this be a module? def __init__( self, - dispatch: FusedMoEDispatchQuantize, - fused_experts: FusedMoEExperts, - combine: FusedMoEUnpermuteCombine, + dispatch_combine: FusedMoEQuantizeDispatchCombine, + fused_experts: FusedMoEPermuteExpertsUnpermute, ): super().__init__() - self.dispatch = dispatch + self.dispatch_combine = dispatch_combine self.fused_experts = fused_experts - self.combine = combine def forward( self, @@ -110,14 +104,17 @@ def forward( global_num_experts = E top_k = topk_ids.shape[1] - if inplace: + if False and inplace: out_hidden_states = hidden_states else: out_hidden_states = torch.empty_like(hidden_states) #print(f"TKN = {topk_ids.numel()} {M*top_k}") - workspace13_shape, workspace2_shape = self.fused_experts.workspace_shapes(M, N, K, top_k, global_num_experts) + workspace13_shape, workspace2_shape = ( + self.fused_experts.workspace_shapes( + M, N, K, top_k, global_num_experts) + ) # We can reuse the memory between cache1 and cache3 because by the time # we need cache3, we're done with cache1 @@ -130,32 +127,32 @@ def forward( #print(f"\nbefore M = {hidden_states.shape[0]}") - hidden_states, a1_scale, expert_ids, inv_perm, context = self.dispatch.apply( - hidden_states, - a1_scale, - a2_scale, - topk_ids, - global_num_experts, - expert_map, - w2.shape[1], + hidden_states, a1_scale, new_topk_ids = self.dispatch_combine.dispatch( + a=hidden_states, + a1_scale=a1_scale, + a2_scale=a2_scale, + topk_ids=topk_ids, + num_experts=global_num_experts, + expert_map=expert_map, ) #print(f"after M = {hidden_states.shape[0]}") fused_out = self.fused_experts.apply( - hidden_states, - w1, - w2, - inplace, - activation, - expert_ids, + out=hidden_states, + q_hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_ids=new_topk_ids, + inplace=inplace, + activation=activation, + expert_map=expert_map, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, workspace13=workspace13, workspace2=workspace2, - context=context, ) - return self.combine.apply(out_hidden_states, fused_out, topk_weights, inv_perm) + return self.dispatch_combine.combine(out_hidden_states, fused_out, topk_weights) From a6564c15dd0b1ec1c09205ac16ebc70fa415e7d0 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 20:36:32 +0000 Subject: [PATCH 006/190] cutlass working again Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 9f9d24aa385..4ebf48d026c 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -283,7 +283,9 @@ def apply( M = q_hidden_states.shape[0] E, N, _ = w2.shape # because w1 + w2 are transposed K = w1.shape[1] #? + topk = topk_ids.shape[1] assert K == w2.shape[-1] + assert E == w1.shape[0] device = q_hidden_states.device per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( @@ -318,9 +320,9 @@ def apply( a1_scale = a1_scale[a_map] if per_act_token else a1_scale # fix names - c1 = _resize_cache(workspace13, (M, N * 2)) - c2 = _resize_cache(workspace2, (M, N)) - c3 = _resize_cache(workspace13, (M, K)) + c1 = _resize_cache(workspace13, (M * topk, N * 2)) + c2 = _resize_cache(workspace2, (M * topk, N)) + c3 = _resize_cache(workspace13, (M * topk, K)) ops.cutlass_moe_mm( c1, From 1968c4ab09ecf352299f41a892c2f4551e4c3419 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 20:37:13 +0000 Subject: [PATCH 007/190] cutlass working again Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/modular_kernel.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index cef11efe22a..c780a494f4e 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -104,7 +104,9 @@ def forward( global_num_experts = E top_k = topk_ids.shape[1] - if False and inplace: + assert not inplace, "NYI" + + if inplace: out_hidden_states = hidden_states else: out_hidden_states = torch.empty_like(hidden_states) From be10ed4fdabaf7674ac24c0ac1e6080ba58a7906 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 22:14:54 +0000 Subject: [PATCH 008/190] fix inplace, format and name cleanups Signed-off-by: Bill Nell --- tests/kernels/quantization/test_block_fp8.py | 21 +- tests/kernels/test_cutlass_moe.py | 132 +++++++----- .../layers/fused_moe/cutlass_moe.py | 195 ++++++++---------- .../layers/fused_moe/deep_gemm_moe.py | 124 +++++------ .../layers/fused_moe/fused_moe.py | 4 - .../layers/fused_moe/modular_kernel.py | 186 ++++++++--------- 6 files changed, 319 insertions(+), 343 deletions(-) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 176a158493a..3fb17b26284 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -405,9 +405,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) -# if not _valid_deep_gemm(a, w1_bf16, w2_bf16, None): -# pytest.skip( -# f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") + # if not _valid_deep_gemm(a, w1_bf16, w2_bf16, None): + # pytest.skip( + # f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") score = torch.randn((M, E), dtype=dtype) @@ -435,8 +435,16 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): if True: dgm = modular_deep_gemm_fused_moe_fp8() - def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids): - return dgm(a, w1, w2, topk_weights, topk_ids, w1_scale=w1_s, w2_scale=w2_s) + + def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids): + return dgm(a, + w1, + w2, + topk_weights, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s) else: deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 @@ -452,7 +460,8 @@ def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids): topk_weights, topk_ids, token_expert_indices = fused_topk( a, score.float(), topk, False) - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index d4b62a8c86e..0dc572c7288 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Callable, Optional + import pytest import torch from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8, modular_cutlass_moe_fp8 +from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp8, modular_cutlass_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, fused_topk) from vllm.platforms import current_platform @@ -13,6 +16,48 @@ TOP_KS = [6, 8] +def get_cutlass_moe_fp8(ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + out_dtype=torch.half) -> Callable: + if True: + return modular_cutlass_moe_fp8( + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + out_dtype, + ) + else: + + def cutlass_moe_fp8_fn( + a: torch.Tensor, + w1_q: torch.Tensor, + w2_q: torch.Tensor, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + a1_scale: Optional[torch.Tensor], + ) -> torch.Tensor: + return cutlass_moe_fp8(a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale, + out_dtype=out_dtype) + + return cutlass_moe_fp8_fn + + def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, @@ -21,18 +66,22 @@ def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - return cutlass_moe_fp8(a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale) + + cutlass_moe_fp8_fn = get_cutlass_moe_fp8( + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + ) + + return cutlass_moe_fp8_fn(a, + w1_q, + w2_q, + w1_scale=w1_scale, + w2_scale=w2_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + a1_scale=a_scale) @pytest.mark.parametrize("m", [2, 64, 224]) @@ -118,48 +167,21 @@ def test_cutlass_moe_no_graph( triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) - if True: - cutlass_moe_fp8_fn = modular_cutlass_moe_fp8( - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - ) - else: - def cutlass_moe_fp8_fn( - a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - a1_scale=a_scale1 - ): - return cutlass_moe_fp8( - a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale1 - ) - - cutlass_output = cutlass_moe_fp8_fn( - a, - w1_q, - w2_q, - w1_scale=w1_scale, - w2_scale=w2_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - a1_scale=a_scale1) + cutlass_moe_fp8_fn = get_cutlass_moe_fp8( + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + ) + + cutlass_output = cutlass_moe_fp8_fn(a, + w1_q, + w2_q, + w1_scale=w1_scale, + w2_scale=w2_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + a1_scale=a_scale1) #print(triton_output) #print(cutlass_output) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 4ebf48d026c..68783501a72 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -1,13 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 """Fused MoE kernel.""" -from typing import Optional +from typing import Optional, Tuple import torch -from vllm import _custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.utils import (_resize_cache, - _fp8_perm) +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache #TODO make the grouped gemm kernel consistent with scaled gemm kernel @@ -185,39 +184,42 @@ def cutlass_moe_fp8( class CutlassDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): + def __init__(self, out_dtype: torch.dtype): super().__init__() self.out_dtype = out_dtype def dispatch( - self, - a: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: # why do we need to check a2_scale here? per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) - a_q, a1_scale = ops.scaled_fp8_quant( - a, a1_scale, use_per_token_if_dynamic=per_act_token) + a1q, a1q_scale = ops.scaled_fp8_quant( + a1, a1_scale, use_per_token_if_dynamic=per_act_token) - return a_q, a1_scale, topk_ids + return a1q, a1_scale, topk_ids def combine( - self, - out: torch.Tensor, #TBD - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - ) -> torch.Tensor: + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + ) -> None: M, topk = topk_weights.shape - K = hidden_states.shape[1] - hidden_states = (hidden_states.view(-1, topk, K) * topk_weights.view(M, -1, 1).to(self.out_dtype)).sum(dim=1) - # use moe_sum? to write into out? - return hidden_states + K = fused_expert_output.shape[1] + fused_expert_output = fused_expert_output.view( + -1, topk, K) * topk_weights.view( + M, -1, 1) #.to(self.out_dtype)).sum(dim=1) + assert output.dtype == self.out_dtype + ops.moe_sum(fused_expert_output, output) ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, @@ -236,106 +238,85 @@ def combine( class CutlassExperts(mk.FusedMoEPermuteExpertsUnpermute): + def __init__( - self, - ab_strides1: torch.Tensor, - c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides2: torch.Tensor, + self, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + out_dtype: torch.dtype, ): super().__init__() self.ab_strides1 = ab_strides1 self.c_strides1 = c_strides1 self.ab_strides2 = ab_strides2 self.c_strides2 = c_strides2 + self.out_dtype = out_dtype def workspace_shapes( self, M: int, - K: int, + K: int, # Note that K, N are transposed N: int, topk: int, - num_experts: int - ) -> Tuple[int, int]: + num_experts: int) -> Tuple[int, int, torch.dtype]: workspace1 = M * topk * max(2 * N, K) workspace2 = M * topk * N - # return tuples???? - return (workspace1, workspace2) + return (workspace1, workspace2, self.out_dtype) def apply( - self, - out: torch.Tensor, # TBD - q_hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool, - activation: str, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - ) -> torch.Tensor: # or None? assume inplace? - # chunking in here or in ModularFusedMoEKernel? ignore for now - M = q_hidden_states.shape[0] - E, N, _ = w2.shape # because w1 + w2 are transposed - K = w1.shape[1] #? + self, + a1q: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + ) -> torch.Tensor: + # TODO: chunking in here or in FusedMoEModularKernel? ignore for now + M = a1q.shape[0] + E, N, K = w2.shape # because w1 + w2 are transposed topk = topk_ids.shape[1] - assert K == w2.shape[-1] - assert E == w1.shape[0] - device = q_hidden_states.device + device = a1q.device - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) + assert w1.shape[1] == K + assert w1.shape[0] == E - expert_offsets = torch.empty((E + 1), - dtype=torch.int32, - device=device) - problem_sizes1 = torch.empty((E, 3), - dtype=torch.int32, - device=device) - problem_sizes2 = torch.empty((E, 3), - dtype=torch.int32, - device=device) + per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) - a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) - c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + expert_offsets = torch.empty((E + 1), dtype=torch.int32, device=device) + problem_sizes1 = torch.empty((E, 3), dtype=torch.int32, device=device) + problem_sizes2 = torch.empty((E, 3), dtype=torch.int32, device=device) - #print(f"prob {k}, {n}") + a_map = torch.empty((topk_ids.numel()), + dtype=torch.int32, + device=device) + c_map = torch.empty((topk_ids.numel()), + dtype=torch.int32, + device=device) - ops.get_cutlass_moe_mm_data(topk_ids, - expert_offsets, - problem_sizes1, - problem_sizes2, - a_map, - c_map, - E, - N, - K) + ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, + problem_sizes2, a_map, c_map, E, N, K) - q_hidden_states = _fp8_perm(q_hidden_states, a_map) - a1_scale = a1_scale[a_map] if per_act_token else a1_scale + a1q = _fp8_perm(a1q, a_map) + a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale # fix names c1 = _resize_cache(workspace13, (M * topk, N * 2)) c2 = _resize_cache(workspace2, (M * topk, N)) c3 = _resize_cache(workspace13, (M * topk, K)) - ops.cutlass_moe_mm( - c1, - q_hidden_states, - w1, - a1_scale, - w1_scale, - expert_offsets[:-1], - problem_sizes1, - self.ab_strides1, - self.ab_strides1, - self.c_strides1 - ) + ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, + expert_offsets[:-1], problem_sizes1, + self.ab_strides1, self.ab_strides1, self.c_strides1) if activation == "silu": torch.ops._C.silu_and_mul(c2, c1) @@ -344,21 +325,12 @@ def apply( else: raise ValueError(f"Unsupported FusedMoe activation: {activation}") - intemediate_q, a2_scale = ops.scaled_fp8_quant( + a2q, a2q_scale = ops.scaled_fp8_quant( c2, a2_scale, use_per_token_if_dynamic=per_act_token) - ops.cutlass_moe_mm( - c3, - intemediate_q, - w2, - a2_scale, - w2_scale, - expert_offsets[:-1], - problem_sizes2, - self.ab_strides2, - self.ab_strides2, - self.c_strides2 - ) + ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, + expert_offsets[:-1], problem_sizes2, + self.ab_strides2, self.ab_strides2, self.c_strides2) c3 = c3[c_map, ...] @@ -366,11 +338,11 @@ def apply( def modular_cutlass_moe_fp8( - ab_strides1: torch.Tensor, - c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides2: torch.Tensor, - out_dtype: torch.dtype = torch.half, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + out_dtype: torch.dtype = torch.half, ) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( CutlassDispatchCombine(out_dtype), @@ -379,5 +351,6 @@ def modular_cutlass_moe_fp8( c_strides1, ab_strides2, c_strides2, + out_dtype, ), ) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 3b39e45c7ea..f170e4a02d5 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util -from typing import Any, List, Optional, Tuple +from typing import List, Optional, Tuple import torch import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger -import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, @@ -304,123 +304,109 @@ def deep_gemm_moe_fp8( class DeepGemmDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): + def __init__(self): super().__init__() self.block_shape = deep_gemm_block_shape() def dispatch( - self, - a: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: - q_hidden_states, q_hidden_states_scale = _fp8_quantize( - a, + a1q, a1q_scale = _fp8_quantize( + a1, a1_scale, self.block_shape, ) - return q_hidden_states, q_hidden_states_scale, topk_ids + return a1q, a1q_scale, topk_ids def combine( - self, - out: torch.Tensor, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - ) -> torch.Tensor: - _moe_unpermute_and_reduce( - out, - hidden_states, - None, - topk_weights - ) - return out + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + ) -> None: + _moe_unpermute_and_reduce(output, fused_expert_output, None, + topk_weights) class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): + def __init__(self): super().__init__() self.block_shape = deep_gemm_block_shape() + self.out_dtype = torch.bfloat16 - def workspace_shapes( - self, - M: int, - N: int, - K: int, - topk: int, - num_experts: int - ) -> Tuple[int, int]: + def workspace_shapes(self, M: int, N: int, K: int, topk: int, + num_experts: int) -> Tuple[int, int, torch.dtype]: block_m = self.block_shape[0] M_sum = (M * topk) + num_experts * (block_m - 1) M_sum = round_up(M_sum, block_m) workspace1 = M_sum * max(N * 2, K) workspace2 = M_sum * N - # return tuples???? - return (workspace1, workspace2) # TODO add type + return (workspace1, workspace2, self.out_dtype) def apply( - self, - out: torch.Tensor, #unused tbd - q_hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool, - activation: str, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - ) -> torch.Tensor: # or None? assume inplace? + self, + a1q: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + ) -> torch.Tensor: import deep_gemm as dg - # chunking in here or in ModularFusedMoEKernel? ignore for now - E, N, _ = w1.shape - _, K, _ = w2.shape + # TODO: chunking in here or in FusedMoEModularKernel? ignore for now + #E, N, _ = w1.shape + #_, K, _ = w2.shape + E, N, K = w1.shape - #print(f"M_sum = {M_sum}") + assert w2.shape[1] == K + assert w2.shape[0] == E - q_hidden_states, a1_scale, _, expert_ids, inv_perm = _moe_permute( - q_hidden_states, - a1_scale, + a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute( + a1q, + a1q_scale, topk_ids, E, expert_map, self.block_shape[0], ) - M_sum = q_hidden_states.shape[0] + # Note: M_sum is different than the pre-permuted shape of a1q. + M_sum = a1q.shape[0] workspace1 = _resize_cache(workspace13, (M_sum, N)) workspace2 = _resize_cache(workspace2, (M_sum, N // 2)) workspace3 = _resize_cache(workspace13, (M_sum, K)) dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (q_hidden_states, a1_scale), (w1, w1_scale), - workspace1, - expert_ids) + (a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids) if activation == "silu": - torch.ops._C.silu_and_mul(workspace2, - workspace1.view(-1, N)) + torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N)) elif activation == "gelu": - torch.ops._C.gelu_and_mul(workspace2, - workspace1.view(-1, N)) + torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N)) else: raise ValueError(f"Unsupported FusedMoe activation: {activation}") a2q_scale: Optional[torch.Tensor] = None - qworkspace2, a2q_scale = _fp8_quantize( - workspace2, a2_scale, self.block_shape) + a2q, a2q_scale = _fp8_quantize(workspace2, a2_scale, self.block_shape) dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (qworkspace2, a2q_scale), (w2, w2_scale), - workspace3, expert_ids) + (a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids) workspace3 = workspace3[inv_perm, ...] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 0cc0dccaccb..f6305822c2d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1420,10 +1420,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - if True: - intermediate_cache3 = intermediate_cache3.view(-1, top_k_num, K) - intermediate_cache3.mul_(curr_topk_weights.view(tokens_in_chunk, -1, 1)) - ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx]) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index c780a494f4e..ce08d984c3a 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -1,160 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Any, Optional, Tuple +from typing import Optional, Tuple + import torch +# TODO: add comments + class FusedMoEQuantizeDispatchCombine(ABC): + def __init__(self): pass @abstractmethod def dispatch( - self, - a: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: # TODO: figure this out - # returns (quantized+dispatched hidden_states, quantized+dispatched scales, dispatched topk_ids) + # returns (quantized+dispatched a, quantized+dispatched a1_scales, dispatched topk_ids) raise NotImplementedError @abstractmethod def combine( - self, - out: torch.Tensor, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - ) -> torch.Tensor: + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, # not reduced or weighted + topk_weights: torch.Tensor, + ) -> None: raise NotImplementedError # store weights, etc. here class FusedMoEPermuteExpertsUnpermute(ABC): + def __init__(self): pass @abstractmethod - def workspace_shapes( - self, - M: int, - N: int, - K: int, - topk: int, - num_experts: int - ) -> Tuple[int, int]: + def workspace_shapes(self, M: int, N: int, K: int, topk: int, + num_experts: int) -> Tuple[int, int, torch.dtype]: raise NotImplementedError @abstractmethod def apply( - self, - out: torch.Tensor, - q_hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool, - activation: str, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - ) -> torch.Tensor: # or None? assume inplace? + self, + a1q: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + ) -> torch.Tensor: raise NotImplementedError # Note: only intended for use with a single model layer (due to temp buffers, constants, etc.) # TODO: permute/unpermute must be paired -class FusedMoEModularKernel(torch.nn.Module): # should this be a module? +class FusedMoEModularKernel(torch.nn.Module): # should this be a module? + def __init__( - self, - dispatch_combine: FusedMoEQuantizeDispatchCombine, - fused_experts: FusedMoEPermuteExpertsUnpermute, + self, + dispatch_combine: FusedMoEQuantizeDispatchCombine, + fused_experts: FusedMoEPermuteExpertsUnpermute, ): super().__init__() self.dispatch_combine = dispatch_combine self.fused_experts = fused_experts def forward( - self, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, + self, + a1: torch.Tensor, # aka hidden states + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: - M, _ = hidden_states.shape + M, _ = a1.shape E, K, N = w2.shape if global_num_experts == -1: global_num_experts = E top_k = topk_ids.shape[1] - assert not inplace, "NYI" - if inplace: - out_hidden_states = hidden_states + output = a1 else: - out_hidden_states = torch.empty_like(hidden_states) - - #print(f"TKN = {topk_ids.numel()} {M*top_k}") + output = torch.empty_like(a1) - workspace13_shape, workspace2_shape = ( - self.fused_experts.workspace_shapes( - M, N, K, top_k, global_num_experts) - ) + workspace13_shape, workspace2_shape, workspace_dtype = ( + self.fused_experts.workspace_shapes(M, N, K, top_k, + global_num_experts)) # We can reuse the memory between cache1 and cache3 because by the time # we need cache3, we're done with cache1 workspace13 = torch.empty(workspace13_shape, - device=hidden_states.device, - dtype=hidden_states.dtype) + device=a1.device, + dtype=workspace_dtype) workspace2 = torch.empty(workspace2_shape, - device=hidden_states.device, - dtype=hidden_states.dtype) - - #print(f"\nbefore M = {hidden_states.shape[0]}") - - hidden_states, a1_scale, new_topk_ids = self.dispatch_combine.dispatch( - a=hidden_states, - a1_scale=a1_scale, - a2_scale=a2_scale, - topk_ids=topk_ids, - num_experts=global_num_experts, - expert_map=expert_map, + device=a1.device, + dtype=workspace_dtype) + + a1q, a1q_scale, dispatched_topk_ids = self.dispatch_combine.dispatch( + a1, + a1_scale, + a2_scale, + topk_ids, + global_num_experts, + expert_map, ) - #print(f"after M = {hidden_states.shape[0]}") - fused_out = self.fused_experts.apply( - out=hidden_states, - q_hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_ids=new_topk_ids, - inplace=inplace, - activation=activation, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, + a1q, + w1, + w2, + dispatched_topk_ids, + activation, + expert_map, + w1_scale, + w2_scale, + a1q_scale, + a2_scale, workspace13=workspace13, workspace2=workspace2, ) - return self.dispatch_combine.combine(out_hidden_states, fused_out, topk_weights) + self.dispatch_combine.combine(output, fused_out, topk_weights) + + return output From f8b64f5be166c2df0ee08748caecfcf41965d4d9 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 22:20:58 +0000 Subject: [PATCH 009/190] fix inplace, format + name cleanups Signed-off-by: Bill Nell --- .../layers/fused_moe/modular_kernel.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index ce08d984c3a..3bef7ee30d1 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -9,8 +9,8 @@ class FusedMoEQuantizeDispatchCombine(ABC): - def __init__(self): - pass + # def __init__(self): + # pass @abstractmethod def dispatch( @@ -23,7 +23,9 @@ def dispatch( expert_map: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: # TODO: figure this out - # returns (quantized+dispatched a, quantized+dispatched a1_scales, dispatched topk_ids) + # returns (quantized+dispatched a, + # quantized+dispatched a1_scales, + # dispatched topk_ids) raise NotImplementedError @abstractmethod @@ -39,8 +41,8 @@ def combine( # store weights, etc. here class FusedMoEPermuteExpertsUnpermute(ABC): - def __init__(self): - pass + # def __init__(self): + # pass @abstractmethod def workspace_shapes(self, M: int, N: int, K: int, topk: int, @@ -66,8 +68,8 @@ def apply( raise NotImplementedError -# Note: only intended for use with a single model layer (due to temp buffers, constants, etc.) -# TODO: permute/unpermute must be paired +# Note: only intended for use with a single model layer (due to temp buffers, +# constants, etc.) class FusedMoEModularKernel(torch.nn.Module): # should this be a module? def __init__( @@ -103,10 +105,7 @@ def forward( global_num_experts = E top_k = topk_ids.shape[1] - if inplace: - output = a1 - else: - output = torch.empty_like(a1) + output = a1 if inplace else torch.empty_like(a1) workspace13_shape, workspace2_shape, workspace_dtype = ( self.fused_experts.workspace_shapes(M, N, K, top_k, From e94d6c1d4c06fe587dd9a0f22dbaf4c9a93a964c Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 01:18:53 +0000 Subject: [PATCH 010/190] test improvements Signed-off-by: Bill Nell --- tests/kernels/quantization/test_block_fp8.py | 18 +++++---------- .../layers/fused_moe/deep_gemm_moe.py | 22 ++++++++----------- .../layers/fused_moe/fused_moe.py | 5 ++++- 3 files changed, 19 insertions(+), 26 deletions(-) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 3fb17b26284..ac2e002ce5a 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,9 +11,9 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm, deep_gemm_moe_fp8, - modular_deep_gemm_fused_moe_fp8) + modular_deep_gemm_fused_moe_fp8, + _valid_deep_gemm_shape) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) @@ -383,13 +383,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): block_size = [block_m, block_m] dtype = torch.bfloat16 - # only aligned sizes TODO: use _valid_deep_gemm here instead? - if (N % block_m != 0 or K % block_m != 0 or topk > E): - pytest.skip( - f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") + if topk > E: + pytest.skip(f"Skipping test: topk={topk} > E={E}") - if False and N <= 512: - pytest.skip("Skipping N <= 512 until performance issues solved.") + if not _valid_deep_gemm_shape(M, N, K): + pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") vllm_config = VllmConfig() @@ -405,10 +403,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) - # if not _valid_deep_gemm(a, w1_bf16, w2_bf16, None): - # pytest.skip( - # f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") - score = torch.randn((M, E), dtype=dtype) block_n, block_k = block_size[0], block_size[1] diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index f170e4a02d5..943e383f3bc 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util -from typing import List, Optional, Tuple +from typing import Optional, Tuple import torch @@ -20,12 +20,18 @@ has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None -def deep_gemm_block_shape() -> List[int]: +def deep_gemm_block_shape() -> list[int]: + # Lazy import to avoid CUDA initialization problems. import deep_gemm as dg block = dg.get_m_alignment_for_contiguous_layout() return [block, block] +def _valid_deep_gemm_shape(M: int, N: int, K: int): + align = deep_gemm_block_shape()[0] + return M >= align and N % align == 0 and K % align == 0 + + # TODO: check types? def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, @@ -39,23 +45,13 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, if not has_deep_gemm: return False - # Lazy import to avoid CUDA initialization problems. - import deep_gemm as dg - # Expert maps not supported yet. if expert_map is not None: return False - align = dg.get_m_alignment_for_contiguous_layout() M = hidden_states.shape[0] _, K, N = w2.shape - - # For now, disable DeepGemm for small N until better permute/unpermute - # ops are available. - if N <= 512: - return False - - if align > M or N % align != 0 or K % align != 0: + if not _valid_deep_gemm_shape(M, N, K): return False return (hidden_states.is_contiguous() and w1.is_contiguous() diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f6305822c2d..5c66b208f31 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1131,7 +1131,10 @@ def fused_experts(hidden_states: torch.Tensor, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, allow_deep_gemm: bool = False) -> torch.Tensor: - if (allow_deep_gemm and use_fp8_w8a8 + # For now, disable DeepGemm for small N (<= 512) until better + # permute/unpermute ops are available. + N = w1.shape[1] + if (allow_deep_gemm and use_fp8_w8a8 and N > 512 and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): assert apply_router_weight_on_input is False return deep_gemm_moe_fp8( From d6751384014f70b4aae498dc3cffb9ac99be75f7 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 04:41:01 +0000 Subject: [PATCH 011/190] make modular triton classes, fix edge cases Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 53 ++-- .../layers/fused_moe/cutlass_moe.py | 37 +-- .../layers/fused_moe/deep_gemm_moe.py | 27 +- .../layers/fused_moe/fused_moe.py | 258 ++++++++++++++++++ .../layers/fused_moe/modular_kernel.py | 32 ++- 5 files changed, 351 insertions(+), 56 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index abf3e3667a7..0e671ac9683 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -68,31 +68,34 @@ def test_fused_moe( else: e_map = None - torch_output = torch_moe(a, w1, w2, score, topk, e_map) - iterative_output = iterative_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + torch_output = torch_moe(a, w1, w2, score, topk, e_map) + iterative_output = iterative_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) + + # Pad the weight if moe padding is enabled + if padding: + w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] + torch.cuda.empty_cache() + w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] + torch.cuda.empty_cache() + + triton_output = fused_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) - # Pad the weight if moe padding is enabled - if padding: - w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() - w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() - - triton_output = fused_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(iterative_output, torch_output, @@ -113,7 +116,7 @@ def test_fused_moe( def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, ep_size: int, dtype: torch.dtype, group_size: int, has_zp: bool, weight_bits: int): - print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits) + #print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits) a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 68783501a72..75d25f418d1 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -214,10 +214,9 @@ def combine( topk_weights: torch.Tensor, ) -> None: M, topk = topk_weights.shape - K = fused_expert_output.shape[1] - fused_expert_output = fused_expert_output.view( - -1, topk, K) * topk_weights.view( - M, -1, 1) #.to(self.out_dtype)).sum(dim=1) + K = fused_expert_output.shape[-1] + fused_expert_output = (fused_expert_output.view(-1, topk, K) * + topk_weights.view(M, -1, 1)) assert output.dtype == self.out_dtype ops.moe_sum(fused_expert_output, output) @@ -255,12 +254,14 @@ def __init__( self.out_dtype = out_dtype def workspace_shapes( - self, - M: int, - K: int, # Note that K, N are transposed - N: int, - topk: int, - num_experts: int) -> Tuple[int, int, torch.dtype]: + self, + a_dtype: torch.dtype, + M: int, + K: int, # Note that K, N are transposed + N: int, + topk: int, + num_experts: int + ) -> Tuple[int, int, torch.dtype]: workspace1 = M * topk * max(2 * N, K) workspace2 = M * topk * N return (workspace1, workspace2, self.out_dtype) @@ -272,9 +273,12 @@ def apply( w2: torch.Tensor, topk_ids: torch.Tensor, activation: str, + global_num_experts: int, expert_map: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, @@ -282,19 +286,19 @@ def apply( ) -> torch.Tensor: # TODO: chunking in here or in FusedMoEModularKernel? ignore for now M = a1q.shape[0] - E, N, K = w2.shape # because w1 + w2 are transposed + _, N, K = w2.shape # because w1 + w2 are transposed topk = topk_ids.shape[1] device = a1q.device assert w1.shape[1] == K - assert w1.shape[0] == E + assert global_num_experts != -1 per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) - expert_offsets = torch.empty((E + 1), dtype=torch.int32, device=device) - problem_sizes1 = torch.empty((E, 3), dtype=torch.int32, device=device) - problem_sizes2 = torch.empty((E, 3), dtype=torch.int32, device=device) + expert_offsets = torch.empty((global_num_experts + 1), dtype=torch.int32, device=device) + problem_sizes1 = torch.empty((global_num_experts, 3), dtype=torch.int32, device=device) + problem_sizes2 = torch.empty((global_num_experts, 3), dtype=torch.int32, device=device) a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, @@ -304,7 +308,8 @@ def apply( device=device) ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, a_map, c_map, E, N, K) + problem_sizes2, a_map, c_map, global_num_experts, + N, K) a1q = _fp8_perm(a1q, a_map) a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 943e383f3bc..2ecba0be45c 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -111,7 +111,7 @@ def _moe_unpermute_and_reduce( reduction on the hidden states. """ M, topk = topk_weight.shape - K = curr_hidden.shape[1] + K = curr_hidden.shape[-1] if inv_perm is not None: curr_hidden = curr_hidden[inv_perm, ...] curr_hidden = curr_hidden.view(-1, topk, K) @@ -336,16 +336,22 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self): super().__init__() self.block_shape = deep_gemm_block_shape() - self.out_dtype = torch.bfloat16 - def workspace_shapes(self, M: int, N: int, K: int, topk: int, - num_experts: int) -> Tuple[int, int, torch.dtype]: + def workspace_shapes( + self, + a_dtype: torch.dtype, + M: int, + N: int, + K: int, + topk: int, + num_experts: int + ) -> Tuple[int, int, torch.dtype]: block_m = self.block_shape[0] M_sum = (M * topk) + num_experts * (block_m - 1) M_sum = round_up(M_sum, block_m) workspace1 = M_sum * max(N * 2, K) workspace2 = M_sum * N - return (workspace1, workspace2, self.out_dtype) + return (workspace1, workspace2, a_dtype) def apply( self, @@ -354,9 +360,12 @@ def apply( w2: torch.Tensor, topk_ids: torch.Tensor, activation: str, + global_num_experts: int, expert_map: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, @@ -365,18 +374,16 @@ def apply( import deep_gemm as dg # TODO: chunking in here or in FusedMoEModularKernel? ignore for now - #E, N, _ = w1.shape - #_, K, _ = w2.shape - E, N, K = w1.shape + _, N, K = w1.shape + assert global_num_experts != -1 assert w2.shape[1] == K - assert w2.shape[0] == E a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute( a1q, a1q_scale, topk_ids, - E, + global_num_experts, expert_map, self.block_shape[0], ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 5c66b208f31..d78fe5c38a2 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -8,6 +8,7 @@ import torch import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( @@ -18,6 +19,8 @@ per_token_group_quant_fp8) from vllm.model_executor.layers.quantization.utils.int8_utils import ( per_token_group_quant_int8, per_token_quant_int8) +from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, + _resize_cache) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op @@ -1152,6 +1155,30 @@ def fused_experts(hidden_states: torch.Tensor, a1_scale=a1_scale, a2_scale=a2_scale, ) + elif hidden_states.shape[0] <= envs.VLLM_FUSED_MOE_CHUNK_SIZE: + fe = modular_triton_fused_moe( + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + block_shape, + ) + return fe( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace, + activation, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + ) else: return dispatch_fused_experts_func(inplace)( hidden_states=hidden_states, @@ -1159,6 +1186,7 @@ def fused_experts(hidden_states: torch.Tensor, w2=w2, topk_weights=topk_weights, topk_ids=topk_ids, + inplace=inplace, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, use_fp8_w8a8=use_fp8_w8a8, @@ -1540,3 +1568,233 @@ def fused_moe( a1_scale=a1_scale, a2_scale=a2_scale, block_shape=block_shape) + + +class TritonDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): + + def __init__(self, use_fp8_w8a8: bool, block_shape: Optional[List[int]]): + super().__init__() + self.use_fp8_w8a8 = use_fp8_w8a8 + self.block_shape = block_shape + + def dispatch( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + if self.use_fp8_w8a8: + a1q, a1q_scale = _fp8_quantize( + a1, + a1_scale, + self.block_shape, + ) + else: + a1q = a1 + a1q_scale = a1_scale + + return a1q, a1q_scale, topk_ids + + def combine( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + ) -> None: + M, topk = topk_weights.shape + K = fused_expert_output.shape[-1] + fused_expert_output = fused_expert_output.view(-1, topk, K) + fused_expert_output.mul_(topk_weights.view(M, -1, 1)) + ops.moe_sum(fused_expert_output, output) + + +class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): + def __init__( + self, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]], + ): + super().__init__() + self.use_fp8_w8a8 = use_fp8_w8a8 + self.use_int4_w4a16 = use_int4_w4a16 + self.use_int8_w8a16 = use_int8_w8a16 + self.block_shape = block_shape + + def workspace_shapes( + self, + a_dtype: torch.dtype, + M: int, + N: int, + K: int, + topk: int, + num_experts: int + ) -> Tuple[int, int, torch.dtype]: + workspace1 = M * topk * max(N * 2, K) + workspace2 = M * topk * N + return (workspace1, workspace2, a_dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + ) -> torch.Tensor: + # Check constraints. + if self.use_int4_w4a16: + assert hidden_states.shape[1] // 2 == w1.shape[ + 2], "Hidden size mismatch" + else: + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn + ] + + num_tokens, _ = hidden_states.shape + E, N, _ = w1.shape + K = w2.shape[1] + if global_num_experts == -1: + global_num_experts = E + top_k_num = topk_ids.shape[1] + # We execute the fused_moe kernel in chunks to circumvent this issue: + # https://github.com/vllm-project/vllm/issues/5938 + M = num_tokens + config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + dtype=hidden_states.dtype) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.shape, + w2.shape, + top_k_num, + config_dtype, + block_shape=self.block_shape, + ) + + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + elif hidden_states.dtype == torch.float8_e4m3fn: + compute_type = tl.bfloat16 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") + + curr_hidden_states = hidden_states + tokens_in_chunk, _ = curr_hidden_states.shape + + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + intermediate_cache1 = _resize_cache(workspace13, (tokens_in_chunk, top_k_num, N)) + intermediate_cache2 = _resize_cache(workspace2, (tokens_in_chunk * top_k_num, N // 2)) + intermediate_cache3 = _resize_cache(workspace13, (tokens_in_chunk, top_k_num, K)) + + config = get_config_func(tokens_in_chunk) + + curr_topk_ids = topk_ids + + qcurr_hidden_states, a1q_scale = hidden_states, a1q_scale + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], + global_num_experts, expert_map)) + + invoke_fused_moe_kernel(qcurr_hidden_states, + w1, + intermediate_cache1, + a1q_scale, + w1_scale, + w1_zp, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + top_k_num, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + block_shape=self.block_shape) + + if activation == "silu": + torch.ops._C.silu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, N)) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(intermediate_cache2, + intermediate_cache1.view(-1, N)) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + + a2q_scale: Optional[torch.Tensor] = None + + if self.use_fp8_w8a8: + qintermediate_cache2, a2q_scale = _fp8_quantize( + intermediate_cache2, a2_scale, self.block_shape) + else: + qintermediate_cache2 = intermediate_cache2 + a2q_scale = a2_scale + + invoke_fused_moe_kernel( + qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + w2_scale, + w2_zp, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + block_shape=self.block_shape) + + return intermediate_cache3 + + +def modular_triton_fused_moe( + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, +) -> mk.FusedMoEModularKernel: + return mk.FusedMoEModularKernel( + TritonDispatchCombine(use_fp8_w8a8, block_shape), + TritonExperts( + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + block_shape, + ), + ) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 3bef7ee30d1..08a004f7565 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -45,8 +45,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC): # pass @abstractmethod - def workspace_shapes(self, M: int, N: int, K: int, topk: int, - num_experts: int) -> Tuple[int, int, torch.dtype]: + def workspace_shapes( + self, + a_dtype: torch.dtype, + M: int, + N: int, + K: int, + topk: int, + num_experts: int + ) -> Tuple[int, int, torch.dtype]: raise NotImplementedError @abstractmethod @@ -57,9 +64,12 @@ def apply( w2: torch.Tensor, topk_ids: torch.Tensor, activation: str, + global_num_experts: int, expert_map: Optional[torch.Tensor], w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, @@ -100,7 +110,9 @@ def forward( a2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: M, _ = a1.shape - E, K, N = w2.shape + E, N, _ = w1.shape + K = w2.shape[1] + #E, K, N = w2.shape if global_num_experts == -1: global_num_experts = E top_k = topk_ids.shape[1] @@ -108,8 +120,15 @@ def forward( output = a1 if inplace else torch.empty_like(a1) workspace13_shape, workspace2_shape, workspace_dtype = ( - self.fused_experts.workspace_shapes(M, N, K, top_k, - global_num_experts)) + self.fused_experts.workspace_shapes( + a1.dtype, + M, + N, + K, + top_k, + global_num_experts + ) + ) # We can reuse the memory between cache1 and cache3 because by the time # we need cache3, we're done with cache1 @@ -135,9 +154,12 @@ def forward( w2, dispatched_topk_ids, activation, + global_num_experts, expert_map, w1_scale, w2_scale, + w1_zp, + w2_zp, a1q_scale, a2_scale, workspace13=workspace13, From a665564ae8338d8726d32a5b3aadf180ee6afed2 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 05:33:17 +0000 Subject: [PATCH 012/190] fix outplace bug Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d78fe5c38a2..aaf68061b85 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1186,7 +1186,6 @@ def fused_experts(hidden_states: torch.Tensor, w2=w2, topk_weights=topk_weights, topk_ids=topk_ids, - inplace=inplace, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, use_fp8_w8a8=use_fp8_w8a8, From a6459aae1c8d6d2dc6e8d32f08130cb65102b6dd Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 17:17:16 +0000 Subject: [PATCH 013/190] refactor dispatch/combine stuff Signed-off-by: Bill Nell --- .../layers/fused_moe/cutlass_moe.py | 58 +--------- .../layers/fused_moe/deep_gemm_moe.py | 108 ++---------------- .../layers/fused_moe/dispatch_combine.py | 44 +++++++ .../layers/fused_moe/modular_kernel.py | 7 +- .../layers/fused_moe/moe_permute_unpermute.py | 67 ++++++++++- .../layers/fused_moe/pplx_dispatch_combine.py | 64 +++++++++++ vllm/model_executor/layers/fused_moe/utils.py | 9 +- 7 files changed, 200 insertions(+), 157 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/dispatch_combine.py create mode 100644 vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 75d25f418d1..77fa6daec95 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -7,6 +7,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache +from vllm.model_executor.layers.fused_moe.dispatch_combine import ( + StandardDispatchCombine +) #TODO make the grouped gemm kernel consistent with scaled gemm kernel @@ -183,59 +186,6 @@ def cutlass_moe_fp8( return c2.sum(dim=1) -class CutlassDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - - def __init__(self, out_dtype: torch.dtype): - super().__init__() - self.out_dtype = out_dtype - - def dispatch( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: - # why do we need to check a2_scale here? - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) - - a1q, a1q_scale = ops.scaled_fp8_quant( - a1, a1_scale, use_per_token_if_dynamic=per_act_token) - - return a1q, a1_scale, topk_ids - - def combine( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - ) -> None: - M, topk = topk_weights.shape - K = fused_expert_output.shape[-1] - fused_expert_output = (fused_expert_output.view(-1, topk, K) * - topk_weights.view(M, -1, 1)) - assert output.dtype == self.out_dtype - ops.moe_sum(fused_expert_output, output) - - ops.get_cutlass_moe_mm_data(topk_ids, - expert_offsets, - problem_sizes1, - problem_sizes2, - a_map, - c_map, - num_experts, - k, - n) - - rep_a_q = _fp8_perm(a_q, a_map) - rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale - - return rep_a_q, rep_a1_scales, expert_offsets, c_map, (problem_sizes1, problem_sizes2) - - class CutlassExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( @@ -350,7 +300,7 @@ def modular_cutlass_moe_fp8( out_dtype: torch.dtype = torch.half, ) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( - CutlassDispatchCombine(out_dtype), + StandardDispatchCombine(), CutlassExperts( ab_strides1, c_strides1, diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 2ecba0be45c..550a8153693 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -6,13 +6,16 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( - moe_align_block_size) -from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, - _fp8_quantize, +from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, _resize_cache) +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + _moe_permute, + _moe_unpermute_and_reduce +) +from vllm.model_executor.layers.fused_moe.dispatch_combine import ( + StandardDispatchCombine +) from vllm.utils import round_up logger = init_logger(__name__) @@ -58,67 +61,6 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, and w2.is_contiguous()) -def _moe_permute( - curr_hidden_states: torch.Tensor, - a1q_scale: Optional[torch.Tensor], - curr_topk_ids: torch.Tensor, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - block_m: int, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, - Optional[torch.Tensor]]: - """ - Determine the sorted_token_ids, expert_ids for the given problem size. - Permute the hidden states and scales according to `sorted_token_ids`. - """ - top_k_num = curr_topk_ids.shape[1] - - tokens_in_chunk, _ = curr_hidden_states.shape - - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, - block_m, - global_num_experts, - expert_map, - pad_sorted_ids=True)) - - inv_perm: Optional[torch.Tensor] = None - - num_tokens = top_k_num * tokens_in_chunk - sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) - expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0) - inv_perm = torch.argsort(sorted_token_ids)[:num_tokens] - - # Permute according to sorted token ids. - curr_hidden_states = _fp8_perm(curr_hidden_states, - sorted_token_ids // top_k_num) - - if a1q_scale is not None: - a1q_scale = a1q_scale[sorted_token_ids // top_k_num] - - return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, - inv_perm) - - -def _moe_unpermute_and_reduce( - out: torch.Tensor, - curr_hidden: torch.Tensor, - inv_perm: Optional[torch.Tensor], - topk_weight: torch.Tensor, -) -> None: - """ - Unpermute the final result and apply topk_weights, then perform the final - reduction on the hidden states. - """ - M, topk = topk_weight.shape - K = curr_hidden.shape[-1] - if inv_perm is not None: - curr_hidden = curr_hidden[inv_perm, ...] - curr_hidden = curr_hidden.view(-1, topk, K) - curr_hidden.mul_(topk_weight.view(M, -1, 1)) - ops.moe_sum(curr_hidden, out) - - def deep_gemm_moe_fp8( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -299,38 +241,6 @@ def deep_gemm_moe_fp8( return out_hidden_states -class DeepGemmDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - - def __init__(self): - super().__init__() - self.block_shape = deep_gemm_block_shape() - - def dispatch( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: - a1q, a1q_scale = _fp8_quantize( - a1, - a1_scale, - self.block_shape, - ) - return a1q, a1q_scale, topk_ids - - def combine( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - ) -> None: - _moe_unpermute_and_reduce(output, fused_expert_output, None, - topk_weights) - - class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self): @@ -418,6 +328,6 @@ def apply( def modular_deep_gemm_fused_moe_fp8() -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( - DeepGemmDispatchCombine(), + StandardDispatchCombine(deep_gemm_block_shape()), DeepGemmExperts(), ) diff --git a/vllm/model_executor/layers/fused_moe/dispatch_combine.py b/vllm/model_executor/layers/fused_moe/dispatch_combine.py new file mode 100644 index 00000000000..589955fb65d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/dispatch_combine.py @@ -0,0 +1,44 @@ +import torch +from typing import Optional, Tuple + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + _moe_unpermute_and_reduce +) + +class StandardDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): + + def __init__(self, block_shape: Optional[list[int]] = None): + super().__init__() + self.block_shape = block_shape + + def dispatch( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + a1q, a1q_scale = _fp8_quantize( + a1, + a1_scale, + self.block_shape, + per_act_token, + ) + return a1q, a1q_scale, topk_ids + + def combine( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + ) -> None: + _moe_unpermute_and_reduce(output, fused_expert_output, None, + topk_weights) + diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 08a004f7565..b7582bcb4fe 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -109,10 +109,15 @@ def forward( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: + # Note: extracting the problem shape from the weight and activation tensors is + # tricky. It needs to be done this way specifically due to subtle issues with + # particular kernels, e.g. the int4 kernels divide the trailing dimension by + # two, so it's not "correct" to extract N or K from the trailing dimension of + # w1 or w2. Similarly, some kernels transpose the weights, so this needs to + # be kept in mind. M, _ = a1.shape E, N, _ = w1.shape K = w2.shape[1] - #E, K, N = w2.shape if global_num_experts == -1: global_num_experts = E top_k = topk_ids.shape[1] diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index cdf7e31c143..ad9e149f5d7 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -1,7 +1,72 @@ # SPDX-License-Identifier: Apache-2.0 +import torch from typing import Optional, Tuple -import torch +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size) +from vllm.model_executor.layers.fused_moe.utils import _fp8_perm + + +def _moe_permute( + curr_hidden_states: torch.Tensor, + a1q_scale: Optional[torch.Tensor], + curr_topk_ids: torch.Tensor, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + block_m: int, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, + Optional[torch.Tensor]]: + """ + Determine the sorted_token_ids, expert_ids for the given problem size. + Permute the hidden states and scales according to `sorted_token_ids`. + """ + top_k_num = curr_topk_ids.shape[1] + + tokens_in_chunk, _ = curr_hidden_states.shape + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(curr_topk_ids, + block_m, + global_num_experts, + expert_map, + pad_sorted_ids=True)) + + inv_perm: Optional[torch.Tensor] = None + + num_tokens = top_k_num * tokens_in_chunk + sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) + expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0) + inv_perm = torch.argsort(sorted_token_ids)[:num_tokens] + + # Permute according to sorted token ids. + curr_hidden_states = _fp8_perm(curr_hidden_states, + sorted_token_ids // top_k_num) + + if a1q_scale is not None: + a1q_scale = a1q_scale[sorted_token_ids // top_k_num] + + return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, + inv_perm) + + +def _moe_unpermute_and_reduce( + out: torch.Tensor, + curr_hidden: torch.Tensor, + inv_perm: Optional[torch.Tensor], + topk_weight: torch.Tensor, +) -> None: + """ + Unpermute the final result and apply topk_weights, then perform the final + reduction on the hidden states. + """ + M, topk = topk_weight.shape + K = curr_hidden.shape[-1] + if inv_perm is not None: + curr_hidden = curr_hidden[inv_perm, ...] + curr_hidden = curr_hidden.view(-1, topk, K) + curr_hidden.mul_(topk_weight.view(M, -1, 1)) + ops.moe_sum(curr_hidden, out) def moe_permute( diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py new file mode 100644 index 00000000000..1eb500d932a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -0,0 +1,64 @@ +import torch +from typing import Optional, Tuple + +import pplx_kernels as pplx +import vllm.model_executor.layers.fused_moe.modular_kernel as mk + + +class PplxDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): + def __init__(self, a2a: pplx.AllToAll): + super().__init__() + self.a2a = a2a + + def dispatch( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + self.a2a.dispatch( + out_expert_num_tokens, # torch.Tensor, + out_expert_x, # torch.Tensor, + out_expert_x_scale, # torch.Tensor | None, + dp_x, # torch.Tensor, + dp_x_scale, # torch.Tensor | None, + indices, # torch.Tensor, + bound_m, # torch.Tensor | None, + do_send, # bool = True, + do_recv, # bool = True, + ) + return 1q, a1q_scale, topk_ids + + def combine( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + ) -> None: + self.a2a.combine( + out_tokens, #: torch.Tensor, + indices, #: torch.Tensor, + weights, #: torch.Tensor, + expert_y, #: torch.Tensor, + bound_m, #: torch.Tensor | None, + do_send, #: bool = True, + do_recv, #: bool = True, + ) + + +# singleton-ish +def get_a2a( + max_num_tokens: int, + num_experts: int, + experts_per_token: int, + rank: int, + world_size: int, + dp_size: int, + hidden_dim: int, + hidden_dim_bytes: int, + hidden_dim_scale_bytes: int, +) -> pplx.AllToAll: + pass diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index db31422f727..ee8e8857fab 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -22,14 +22,19 @@ def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor: def _fp8_quantize( A: torch.Tensor, A_scale: Optional[torch.Tensor], - block_shape: Optional[List[int]], + block_shape: Optional[List[int]] = None, + per_act_token: bool = False, # make sure this is the same default as op ) -> Tuple[torch.Tensor, torch.Tensor]: """ Perform fp8 quantization on the inputs. If a block_shape is provided, the output will be blocked. """ if block_shape is None: - A, A_scale = ops.scaled_fp8_quant(A, A_scale) + A, A_scale = ops.scaled_fp8_quant( + A, + A_scale, + use_per_token_if_dynamic=per_act_token + ) else: assert len(block_shape) == 2 _, block_k = block_shape[0], block_shape[1] From da3fe2b543b8281c177ffc3da2fadca549c98dad Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 19:39:35 +0000 Subject: [PATCH 014/190] initial pplx dispatch/combine class Signed-off-by: Bill Nell --- .../layers/fused_moe/dispatch_combine.py | 6 +- .../layers/fused_moe/fused_moe.py | 6 +- .../layers/fused_moe/modular_kernel.py | 20 ++- .../layers/fused_moe/pplx_dispatch_combine.py | 114 ++++++++++++------ 4 files changed, 92 insertions(+), 54 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/dispatch_combine.py b/vllm/model_executor/layers/fused_moe/dispatch_combine.py index 589955fb65d..cd981cfb696 100644 --- a/vllm/model_executor/layers/fused_moe/dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/dispatch_combine.py @@ -21,7 +21,7 @@ def dispatch( topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) @@ -31,14 +31,14 @@ def dispatch( self.block_shape, per_act_token, ) - return a1q, a1q_scale, topk_ids + return a1q, a1q_scale def combine( self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> None: _moe_unpermute_and_reduce(output, fused_expert_output, None, topk_weights) - diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index aaf68061b85..f86b3f7276d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1569,6 +1569,7 @@ def fused_moe( block_shape=block_shape) +# TODO: merge with StandardDispatchCombine class TritonDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): def __init__(self, use_fp8_w8a8: bool, block_shape: Optional[List[int]]): @@ -1584,7 +1585,7 @@ def dispatch( topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if self.use_fp8_w8a8: a1q, a1q_scale = _fp8_quantize( a1, @@ -1595,13 +1596,14 @@ def dispatch( a1q = a1 a1q_scale = a1_scale - return a1q, a1q_scale, topk_ids + return a1q, a1q_scale def combine( self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> None: M, topk = topk_weights.shape K = fused_expert_output.shape[-1] diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index b7582bcb4fe..6ff85c21cee 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -9,9 +9,6 @@ class FusedMoEQuantizeDispatchCombine(ABC): - # def __init__(self): - # pass - @abstractmethod def dispatch( self, @@ -21,11 +18,9 @@ def dispatch( topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: - # TODO: figure this out + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # returns (quantized+dispatched a, - # quantized+dispatched a1_scales, - # dispatched topk_ids) + # quantized+dispatched a1_scales) raise NotImplementedError @abstractmethod @@ -34,6 +29,7 @@ def combine( output: torch.Tensor, fused_expert_output: torch.Tensor, # not reduced or weighted topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> None: raise NotImplementedError @@ -41,9 +37,6 @@ def combine( # store weights, etc. here class FusedMoEPermuteExpertsUnpermute(ABC): - # def __init__(self): - # pass - @abstractmethod def workspace_shapes( self, @@ -115,6 +108,7 @@ def forward( # two, so it's not "correct" to extract N or K from the trailing dimension of # w1 or w2. Similarly, some kernels transpose the weights, so this needs to # be kept in mind. + # TODO: make this a method/utility function, e.g. problem_size(a, w1, w2, topk_ids, ...) M, _ = a1.shape E, N, _ = w1.shape K = w2.shape[1] @@ -144,7 +138,7 @@ def forward( device=a1.device, dtype=workspace_dtype) - a1q, a1q_scale, dispatched_topk_ids = self.dispatch_combine.dispatch( + a1q, a1q_scale = self.dispatch_combine.dispatch( a1, a1_scale, a2_scale, @@ -157,7 +151,7 @@ def forward( a1q, w1, w2, - dispatched_topk_ids, + topk_ids, activation, global_num_experts, expert_map, @@ -171,6 +165,6 @@ def forward( workspace2=workspace2, ) - self.dispatch_combine.combine(output, fused_out, topk_weights) + self.dispatch_combine.combine(output, fused_out, topk_weights, topk_ids) return output diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 1eb500d932a..fea0c5c1f16 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -1,64 +1,106 @@ import torch -from typing import Optional, Tuple +from typing import List, Optional, Tuple import pplx_kernels as pplx import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize +# Note use: layer.get_all_to_all() to get an AllToAll instance +# The max_num_tokens, world_size and dp_size must be the same +# as the ones used to create the AllToAll. Unfortunately, there's +# no way(?) to extract this info from AllToAll class PplxDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - def __init__(self, a2a: pplx.AllToAll): + def __init__( + self, + a2a: pplx.AllToAll, + max_num_tokens: int, + world_size: int, + dp_size: int, + block_shape: Optional[List[int]] = None): super().__init__() self.a2a = a2a + self.block_shape = block_shape + self.dp_num_tokens = max_num_tokens * (world_size // dp_size) def dispatch( self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, + rank_topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # Is this always going to be a1.device? + device = a1.device + + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + a1q, a1q_scale = _fp8_quantize( + a1, + a1_scale, + self.block_shape, + per_act_token, + ) + + expert_num_tokens = torch.empty( + num_experts, + dtype=torch.int32, + device=device, + ) + + expert_x = torch.empty( + (num_experts, self.dp_num_tokens, a1q.shape[-1]), + dtype=a1q.dtype, + device=device, + ) + + expert_x_scale: torch.Tensor | None = None + if a1q.dtype.itemsize == 1: + float32_size = torch.float32.itemsize + block_size = (self.block_shape[0] if self.block_shape is not None else 1) * float32_size + expert_x_scale = torch.empty( + ( + num_experts, + expert_x.size(1), + (expert_x.size(2) + block_size - 1) // block_size, + ), + dtype=torch.float32, + device=device, + ) + + # This argument is optional + bound_m = torch.tensor([a1q.shape[0]], dtype=torch.uint32, device=device) + self.a2a.dispatch( - out_expert_num_tokens, # torch.Tensor, - out_expert_x, # torch.Tensor, - out_expert_x_scale, # torch.Tensor | None, - dp_x, # torch.Tensor, - dp_x_scale, # torch.Tensor | None, - indices, # torch.Tensor, - bound_m, # torch.Tensor | None, - do_send, # bool = True, - do_recv, # bool = True, + out_expert_num_tokens=expert_num_tokens, + out_expert_x=expert_x, + out_expert_x_scale=expert_x_scale, + dp_x=a1q, + dp_x_scale=a1q_scale, + indices=rank_topk_ids, + bound_m=bound_m, ) - return 1q, a1q_scale, topk_ids + return expert_x, expert_x_scale def combine( self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, ) -> None: - self.a2a.combine( - out_tokens, #: torch.Tensor, - indices, #: torch.Tensor, - weights, #: torch.Tensor, - expert_y, #: torch.Tensor, - bound_m, #: torch.Tensor | None, - do_send, #: bool = True, - do_recv, #: bool = True, - ) + # This argument is optional + bound_m = torch.tensor([output.shape[0]], dtype=torch.uint32, device=output.device) + # TODO assert output is the proper size -# singleton-ish -def get_a2a( - max_num_tokens: int, - num_experts: int, - experts_per_token: int, - rank: int, - world_size: int, - dp_size: int, - hidden_dim: int, - hidden_dim_bytes: int, - hidden_dim_scale_bytes: int, -) -> pplx.AllToAll: - pass + self.a2a.combine( + out_tokens=output, + indices=topk_ids, + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m + ) From f2fe65a98fa4fc9b6fbc864a6629f63565663dad Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 20:41:40 +0000 Subject: [PATCH 015/190] merge triton dispatch into standard, add some comments Signed-off-by: Bill Nell --- .../layers/fused_moe/cutlass_moe.py | 2 +- .../layers/fused_moe/deep_gemm_moe.py | 3 +- .../layers/fused_moe/dispatch_combine.py | 28 ++- .../layers/fused_moe/fused_moe.py | 51 +----- .../layers/fused_moe/modular_kernel.py | 172 ++++++++++++++++-- .../layers/fused_moe/pplx_dispatch_combine.py | 21 ++- 6 files changed, 196 insertions(+), 81 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 77fa6daec95..7ea999d5086 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -300,7 +300,7 @@ def modular_cutlass_moe_fp8( out_dtype: torch.dtype = torch.half, ) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( - StandardDispatchCombine(), + StandardDispatchCombine(quant_dtype=torch.float8_e4m3fn), CutlassExperts( ab_strides1, c_strides1, diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 550a8153693..19c54dd2c31 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -328,6 +328,7 @@ def apply( def modular_deep_gemm_fused_moe_fp8() -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( - StandardDispatchCombine(deep_gemm_block_shape()), + StandardDispatchCombine(quant_dtype=torch.float8_e4m3fn, + block_shape=deep_gemm_block_shape()), DeepGemmExperts(), ) diff --git a/vllm/model_executor/layers/fused_moe/dispatch_combine.py b/vllm/model_executor/layers/fused_moe/dispatch_combine.py index cd981cfb696..207a1c69860 100644 --- a/vllm/model_executor/layers/fused_moe/dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/dispatch_combine.py @@ -9,9 +9,14 @@ class StandardDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - def __init__(self, block_shape: Optional[list[int]] = None): + def __init__( + self, + quant_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None + ): super().__init__() self.block_shape = block_shape + self.quant_dtype = quant_dtype def dispatch( self, @@ -22,15 +27,20 @@ def dispatch( num_experts: int, expert_map: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) + if self.quant_dtype == torch.float8_e4m3fn: + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + a1q, a1q_scale = _fp8_quantize( + a1, + a1_scale, + self.block_shape, + per_act_token, + ) + else: + a1q = a1 + a1q_scale = a1_scale - a1q, a1q_scale = _fp8_quantize( - a1, - a1_scale, - self.block_shape, - per_act_token, - ) return a1q, a1q_scale def combine( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f86b3f7276d..182bee11f18 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -13,6 +13,9 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( _valid_deep_gemm, deep_gemm_moe_fp8) +from vllm.model_executor.layers.fused_moe.dispatch_combine import ( + StandardDispatchCombine +) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -1569,49 +1572,6 @@ def fused_moe( block_shape=block_shape) -# TODO: merge with StandardDispatchCombine -class TritonDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - - def __init__(self, use_fp8_w8a8: bool, block_shape: Optional[List[int]]): - super().__init__() - self.use_fp8_w8a8 = use_fp8_w8a8 - self.block_shape = block_shape - - def dispatch( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if self.use_fp8_w8a8: - a1q, a1q_scale = _fp8_quantize( - a1, - a1_scale, - self.block_shape, - ) - else: - a1q = a1 - a1q_scale = a1_scale - - return a1q, a1q_scale - - def combine( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - ) -> None: - M, topk = topk_weights.shape - K = fused_expert_output.shape[-1] - fused_expert_output = fused_expert_output.view(-1, topk, K) - fused_expert_output.mul_(topk_weights.view(M, -1, 1)) - ops.moe_sum(fused_expert_output, output) - - class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, @@ -1791,7 +1751,10 @@ def modular_triton_fused_moe( block_shape: Optional[List[int]] = None, ) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( - TritonDispatchCombine(use_fp8_w8a8, block_shape), + StandardDispatchCombine( + quant_dtype=torch.float8_e4m3fn if use_fp8_w8a8 else None, + block_shape=block_shape + ), TritonExperts( use_fp8_w8a8, use_int8_w8a16, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 6ff85c21cee..7f617a06e2d 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -4,11 +4,51 @@ import torch -# TODO: add comments +def moe_problem_size( + a1: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, +) -> Tuple[int, int, int, int]: + """ + Extract the MoE problem size from the given tensor arguments: + - a: The hidden states, input to the MoE layer. + - w1: The first set of expert weights. + - w2: The second set of expert weights. + - topk_ids: The topk ids. + Note: extracting the problem shape from the weight and activation tensors is + not obvious. It needs to be done this way specifically due to subtle issues + with particular kernels, e.g. the int4 kernels divide the trailing dimension + by two, so it's not "correct" to extract N or K from the trailing dimension + of w1 or w2. Similarly, some kernels transpose the weights, so this needs to + be kept in mind. + """ + # Make sure we are using the correct a1 (pre-permute) + assert topk_ids.shape[0] == a1.shape[0] + M, _ = a1.shape + E, N, _ = w1.shape + K = w2.shape[1] + topk = topk_ids.shape[1] + return E, M, N, K, topk -class FusedMoEQuantizeDispatchCombine(ABC): +# +# A set of base classes used to make MoE kernels more modular. +# +# Architecture: +# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine] +# +# [Quantize-Dispatch] and [Combine] functionality are bundled into a single +# class `FusedMoEQuantizeDispatchCombine` since they could use collective +# communication mechanisms that need to be consistent. +# +# Ideal architecture: +# [Router] → [Quantize-Dispatch-Permute] → [Experts] → [Unpermute-Combine] +# +class FusedMoEQuantizeDispatchCombine(ABC): + """ + """ @abstractmethod def dispatch( self, @@ -19,22 +59,43 @@ def dispatch( num_experts: int, expert_map: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - # returns (quantized+dispatched a, - # quantized+dispatched a1_scales) + """ + Perform any quantization (and/or) dispatching needed + for this kernel. + - a1: The (unquantized) input to the MoE layer. + - a1_scale: Optional scales for a1 + - a2_scale: Optional scales for the second MoE gemm. Required to make sure the quantization is consistent for both gemms. + - topk_ids: The topk_ids. + - num_experts: The total number of experts in the global expert space. + - expert_map: A tensor mapping expert indices from the global expert + space to the local expert space of the expert parallel shard. + + Returns a tuple of: + - quantized + dispatched a. + - quantized + dispatched a1_scales. + """ raise NotImplementedError @abstractmethod def combine( self, output: torch.Tensor, - fused_expert_output: torch.Tensor, # not reduced or weighted + fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, ) -> None: + """ + Perform any combine plus apply weights and perform a reduction on the + fused experts output. + - output: The output tensor, written in place. Must be (M, K) shape. + - fused_expert_output: The unweighted, unreduced output of the fused + experts, it will have (M, topk, K) shape. + - topk_weights: The weights to be applied to the fused_experts_output. + - topk_ids: The topk_ids. + """ raise NotImplementedError -# store weights, etc. here class FusedMoEPermuteExpertsUnpermute(ABC): @abstractmethod @@ -47,6 +108,19 @@ def workspace_shapes( topk: int, num_experts: int ) -> Tuple[int, int, torch.dtype]: + """ + Compute the number of elements for the temporary outputs of the two + gemms and activation in the fused expert function. Since the + gemms are independent, the workspace for the first gemm can be shared + with the workspace for the last gemm. + + Returns a tuple of: + - Number of workspace13 elements: must be large enough to hold the result + of either expert gemm. + - Number of workspace2 elements: must be large enough to hold the result + of the activation function. + - Workspace type: The dtype to use for the workspace tensors. + """ raise NotImplementedError @abstractmethod @@ -68,6 +142,42 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, ) -> torch.Tensor: + """ + This function computes the intermediate result of a Mixture of Experts (MoE) + layer using two sets of weights, w1 and w2. + + Parameters: + - a1q: (torch.Tensor): The (quantized) input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - topk_ids (torch.Tensor): A map of row to expert id. + - activation (str): The activation function to apply after the first + MoE layer. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for + w1. + - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for + w2. + - a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - workspace13 (torch.Tensor): A scratch tensor used for gemm outputs + must be large enough to hold output of either MoE gemm. + - workspace2 (torch.Tensor): A scratch tensor used for the activation + function. + + Returns: + - torch.Tensor: The unweighted, unreduced output tensor + """ raise NotImplementedError @@ -86,7 +196,7 @@ def __init__( def forward( self, - a1: torch.Tensor, # aka hidden states + a1: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, @@ -102,19 +212,45 @@ def forward( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # Note: extracting the problem shape from the weight and activation tensors is - # tricky. It needs to be done this way specifically due to subtle issues with - # particular kernels, e.g. the int4 kernels divide the trailing dimension by - # two, so it's not "correct" to extract N or K from the trailing dimension of - # w1 or w2. Similarly, some kernels transpose the weights, so this needs to - # be kept in mind. - # TODO: make this a method/utility function, e.g. problem_size(a, w1, w2, topk_ids, ...) - M, _ = a1.shape - E, N, _ = w1.shape - K = w2.shape[1] + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - a1: (torch.Tensor): The input tensor to the MoE layer (aka hidden_states). + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - topk_weights (torch.Tensor): The topk weights applied at the end of the layer. + - topk_ids (torch.Tensor): A map of row to expert id. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - activation (str): The activation function to apply after the first + MoE layer. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for + w1. + - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for + w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + E, M, N, K, top_k = moe_problem_size(a1, w1, w2, topk_ids) + if global_num_experts == -1: global_num_experts = E - top_k = topk_ids.shape[1] output = a1 if inplace else torch.empty_like(a1) diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index fea0c5c1f16..3bc6b50720c 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -17,6 +17,7 @@ def __init__( max_num_tokens: int, world_size: int, dp_size: int, + quant_dtype: Optional[torch.dtype] = None, block_shape: Optional[List[int]] = None): super().__init__() self.a2a = a2a @@ -35,15 +36,19 @@ def dispatch( # Is this always going to be a1.device? device = a1.device - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) + if self.quant_dtype == torch.float8_e4m3fn: + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) - a1q, a1q_scale = _fp8_quantize( - a1, - a1_scale, - self.block_shape, - per_act_token, - ) + a1q, a1q_scale = _fp8_quantize( + a1, + a1_scale, + self.block_shape, + per_act_token, + ) + else: + a1q = a1 + a1q_scale = a1_scale expert_num_tokens = torch.empty( num_experts, From caf9805747789ad4a4564a527ec8420916c3ed8b Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 20:47:41 +0000 Subject: [PATCH 016/190] format Signed-off-by: Bill Nell --- tests/kernels/quantization/test_block_fp8.py | 4 +- .../layers/fused_moe/cutlass_moe.py | 36 +++--- .../layers/fused_moe/deep_gemm_moe.py | 25 ++-- .../layers/fused_moe/dispatch_combine.py | 21 ++-- .../layers/fused_moe/fused_moe.py | 84 +++++++------- .../layers/fused_moe/modular_kernel.py | 109 ++++++++---------- .../layers/fused_moe/pplx_dispatch_combine.py | 46 ++++---- vllm/model_executor/layers/fused_moe/utils.py | 5 +- 8 files changed, 156 insertions(+), 174 deletions(-) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index ac2e002ce5a..a05effa5bd6 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,9 +11,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - deep_gemm_moe_fp8, - modular_deep_gemm_fused_moe_fp8, - _valid_deep_gemm_shape) + _valid_deep_gemm_shape, deep_gemm_moe_fp8, modular_deep_gemm_fused_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 7ea999d5086..c6b50729b24 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -6,10 +6,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache from vllm.model_executor.layers.fused_moe.dispatch_combine import ( - StandardDispatchCombine -) + StandardDispatchCombine) +from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache #TODO make the grouped gemm kernel consistent with scaled gemm kernel @@ -204,14 +203,13 @@ def __init__( self.out_dtype = out_dtype def workspace_shapes( - self, - a_dtype: torch.dtype, - M: int, - K: int, # Note that K, N are transposed - N: int, - topk: int, - num_experts: int - ) -> Tuple[int, int, torch.dtype]: + self, + a_dtype: torch.dtype, + M: int, + K: int, # Note that K, N are transposed + N: int, + topk: int, + num_experts: int) -> Tuple[int, int, torch.dtype]: workspace1 = M * topk * max(2 * N, K) workspace2 = M * topk * N return (workspace1, workspace2, self.out_dtype) @@ -246,9 +244,15 @@ def apply( per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) - expert_offsets = torch.empty((global_num_experts + 1), dtype=torch.int32, device=device) - problem_sizes1 = torch.empty((global_num_experts, 3), dtype=torch.int32, device=device) - problem_sizes2 = torch.empty((global_num_experts, 3), dtype=torch.int32, device=device) + expert_offsets = torch.empty((global_num_experts + 1), + dtype=torch.int32, + device=device) + problem_sizes1 = torch.empty((global_num_experts, 3), + dtype=torch.int32, + device=device) + problem_sizes2 = torch.empty((global_num_experts, 3), + dtype=torch.int32, + device=device) a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, @@ -258,8 +262,8 @@ def apply( device=device) ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, a_map, c_map, global_num_experts, - N, K) + problem_sizes2, a_map, c_map, + global_num_experts, N, K) a1q = _fp8_perm(a1q, a_map) a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 19c54dd2c31..6ffb40cb52c 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -7,15 +7,12 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.dispatch_combine import ( + StandardDispatchCombine) +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + _moe_permute, _moe_unpermute_and_reduce) from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, _resize_cache) -from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - _moe_permute, - _moe_unpermute_and_reduce -) -from vllm.model_executor.layers.fused_moe.dispatch_combine import ( - StandardDispatchCombine -) from vllm.utils import round_up logger = init_logger(__name__) @@ -32,7 +29,7 @@ def deep_gemm_block_shape() -> list[int]: def _valid_deep_gemm_shape(M: int, N: int, K: int): align = deep_gemm_block_shape()[0] - return M >= align and N % align == 0 and K % align == 0 + return align <= M and N % align == 0 and K % align == 0 # TODO: check types? @@ -247,15 +244,9 @@ def __init__(self): super().__init__() self.block_shape = deep_gemm_block_shape() - def workspace_shapes( - self, - a_dtype: torch.dtype, - M: int, - N: int, - K: int, - topk: int, - num_experts: int - ) -> Tuple[int, int, torch.dtype]: + def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, + topk: int, + num_experts: int) -> Tuple[int, int, torch.dtype]: block_m = self.block_shape[0] M_sum = (M * topk) + num_experts * (block_m - 1) M_sum = round_up(M_sum, block_m) diff --git a/vllm/model_executor/layers/fused_moe/dispatch_combine.py b/vllm/model_executor/layers/fused_moe/dispatch_combine.py index 207a1c69860..06b90c35025 100644 --- a/vllm/model_executor/layers/fused_moe/dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/dispatch_combine.py @@ -1,19 +1,19 @@ -import torch +# SPDX-License-Identifier: Apache-2.0 from typing import Optional, Tuple +import torch + import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - _moe_unpermute_and_reduce -) + _moe_unpermute_and_reduce) +from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize + class StandardDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - def __init__( - self, - quant_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None - ): + def __init__(self, + quant_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None): super().__init__() self.block_shape = block_shape self.quant_dtype = quant_dtype @@ -28,7 +28,8 @@ def dispatch( expert_map: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if self.quant_dtype == torch.float8_e4m3fn: - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + per_act_token = a1_scale.numel( + ) != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) a1q, a1q_scale = _fp8_quantize( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 182bee11f18..a578924fb3d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -14,8 +14,7 @@ from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( _valid_deep_gemm, deep_gemm_moe_fp8) from vllm.model_executor.layers.fused_moe.dispatch_combine import ( - StandardDispatchCombine -) + StandardDispatchCombine) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -1573,6 +1572,7 @@ def fused_moe( class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): + def __init__( self, use_fp8_w8a8: bool, @@ -1586,15 +1586,9 @@ def __init__( self.use_int8_w8a16 = use_int8_w8a16 self.block_shape = block_shape - def workspace_shapes( - self, - a_dtype: torch.dtype, - M: int, - N: int, - K: int, - topk: int, - num_experts: int - ) -> Tuple[int, int, torch.dtype]: + def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, + topk: int, + num_experts: int) -> Tuple[int, int, torch.dtype]: workspace1 = M * topk * max(N * 2, K) workspace2 = M * topk * N return (workspace1, workspace2, a_dtype) @@ -1622,9 +1616,11 @@ def apply( assert hidden_states.shape[1] // 2 == w1.shape[ 2], "Hidden size mismatch" else: - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert hidden_states.shape[1] == w1.shape[ + 2], "Hidden size mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert hidden_states.is_contiguous( + ), "Hidden_states must be contiguous" assert w1.stride(-1) == 1, "Stride of last dimension must be 1" assert w2.stride(-1) == 1, "Stride of last dimension must be 1" assert hidden_states.dtype in [ @@ -1637,9 +1633,9 @@ def apply( if global_num_experts == -1: global_num_experts = E top_k_num = topk_ids.shape[1] + # We execute the fused_moe kernel in chunks to circumvent this issue: # https://github.com/vllm-project/vllm/issues/5938 - M = num_tokens config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, @@ -1663,16 +1659,20 @@ def apply( elif hidden_states.dtype == torch.float8_e4m3fn: compute_type = tl.bfloat16 else: - raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") + raise ValueError( + f"Unsupported compute_type: {hidden_states.dtype}") curr_hidden_states = hidden_states tokens_in_chunk, _ = curr_hidden_states.shape # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 - intermediate_cache1 = _resize_cache(workspace13, (tokens_in_chunk, top_k_num, N)) - intermediate_cache2 = _resize_cache(workspace2, (tokens_in_chunk * top_k_num, N // 2)) - intermediate_cache3 = _resize_cache(workspace13, (tokens_in_chunk, top_k_num, K)) + intermediate_cache1 = _resize_cache(workspace13, + (tokens_in_chunk, top_k_num, N)) + intermediate_cache2 = _resize_cache( + workspace2, (tokens_in_chunk * top_k_num, N // 2)) + intermediate_cache3 = _resize_cache(workspace13, + (tokens_in_chunk, top_k_num, K)) config = get_config_func(tokens_in_chunk) @@ -1721,40 +1721,38 @@ def apply( qintermediate_cache2 = intermediate_cache2 a2q_scale = a2_scale - invoke_fused_moe_kernel( - qintermediate_cache2, - w2, - intermediate_cache3, - a2q_scale, - w2_scale, - w2_zp, - None, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - 1, - config, - compute_type=compute_type, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - block_shape=self.block_shape) + invoke_fused_moe_kernel(qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + w2_scale, + w2_zp, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + block_shape=self.block_shape) return intermediate_cache3 def modular_triton_fused_moe( - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - block_shape: Optional[List[int]] = None, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, ) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( StandardDispatchCombine( quant_dtype=torch.float8_e4m3fn if use_fp8_w8a8 else None, - block_shape=block_shape - ), + block_shape=block_shape), TritonExperts( use_fp8_w8a8, use_int8_w8a16, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 7f617a06e2d..196c29eca8a 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -4,6 +4,7 @@ import torch + def moe_problem_size( a1: torch.Tensor, w1: torch.Tensor, @@ -21,8 +22,8 @@ def moe_problem_size( not obvious. It needs to be done this way specifically due to subtle issues with particular kernels, e.g. the int4 kernels divide the trailing dimension by two, so it's not "correct" to extract N or K from the trailing dimension - of w1 or w2. Similarly, some kernels transpose the weights, so this needs to - be kept in mind. + of w1 or w2. Similarly, some kernels transpose the weights, so this needs + to be kept in mind. """ # Make sure we are using the correct a1 (pre-permute) assert topk_ids.shape[0] == a1.shape[0] @@ -32,6 +33,7 @@ def moe_problem_size( topk = topk_ids.shape[1] return E, M, N, K, topk + # # A set of base classes used to make MoE kernels more modular. # @@ -46,9 +48,11 @@ def moe_problem_size( # [Router] → [Quantize-Dispatch-Permute] → [Experts] → [Unpermute-Combine] # + class FusedMoEQuantizeDispatchCombine(ABC): """ """ + @abstractmethod def dispatch( self, @@ -64,7 +68,8 @@ def dispatch( for this kernel. - a1: The (unquantized) input to the MoE layer. - a1_scale: Optional scales for a1 - - a2_scale: Optional scales for the second MoE gemm. Required to make sure the quantization is consistent for both gemms. + - a2_scale: Optional scales for the second MoE gemm. Required to make + sure the quantization is consistent for both gemms. - topk_ids: The topk_ids. - num_experts: The total number of experts in the global expert space. - expert_map: A tensor mapping expert indices from the global expert @@ -99,15 +104,9 @@ def combine( class FusedMoEPermuteExpertsUnpermute(ABC): @abstractmethod - def workspace_shapes( - self, - a_dtype: torch.dtype, - M: int, - N: int, - K: int, - topk: int, - num_experts: int - ) -> Tuple[int, int, torch.dtype]: + def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, + topk: int, + num_experts: int) -> Tuple[int, int, torch.dtype]: """ Compute the number of elements for the temporary outputs of the two gemms and activation in the fused expert function. Since the @@ -115,10 +114,10 @@ def workspace_shapes( with the workspace for the last gemm. Returns a tuple of: - - Number of workspace13 elements: must be large enough to hold the result - of either expert gemm. - - Number of workspace2 elements: must be large enough to hold the result - of the activation function. + - Number of workspace13 elements: must be large enough to hold the + result of either expert gemm. + - Number of workspace2 elements: must be large enough to hold the + result of the activation function. - Workspace type: The dtype to use for the workspace tensors. """ raise NotImplementedError @@ -143,8 +142,8 @@ def apply( workspace2: torch.Tensor, ) -> torch.Tensor: """ - This function computes the intermediate result of a Mixture of Experts (MoE) - layer using two sets of weights, w1 and w2. + This function computes the intermediate result of a Mixture of Experts + (MoE) layer using two sets of weights, w1 and w2. Parameters: - a1q: (torch.Tensor): The (quantized) input tensor to the MoE layer. @@ -152,24 +151,21 @@ def apply( - w2 (torch.Tensor): The second set of expert weights. - topk_ids (torch.Tensor): A map of row to expert id. - activation (str): The activation function to apply after the first - MoE layer. + MoE layer. - global_num_experts (int): The total number of experts in the global - expert space. + expert space. - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert - parallel shard. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. + from the global expert space to the local expert space of the expert + parallel shard. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for - w1. + w1. - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for - w2. - - a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be used for - a1. - - a2_scale (Optional[torch.Tensor]): Optional scale to be used for - a2. + w2. + - a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be + used for a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. - workspace13 (torch.Tensor): A scratch tensor used for gemm outputs must be large enough to hold output of either MoE gemm. - workspace2 (torch.Tensor): A scratch tensor used for the activation @@ -213,36 +209,33 @@ def forward( a2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. + This function computes a Mixture of Experts (MoE) layer using two sets + of weights, w1 and w2, and top-k gating mechanism. Parameters: - - a1: (torch.Tensor): The input tensor to the MoE layer (aka hidden_states). + - a1: (torch.Tensor): The input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - - topk_weights (torch.Tensor): The topk weights applied at the end of the layer. + - topk_weights (torch.Tensor): The topk weights applied at the end of + the layer. - topk_ids (torch.Tensor): A map of row to expert id. - inplace (bool): If True, perform the operation in-place. - Defaults to False. + Defaults to False. - activation (str): The activation function to apply after the first - MoE layer. + MoE layer. - global_num_experts (int): The total number of experts in the global - expert space. + expert space. - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert - parallel shard. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. + from the global expert space to the local expert space of the expert + parallel shard. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. - w1_zp (Optional[torch.Tensor]): Optional zero points to be used for - w1. + w1. - w2_zp (Optional[torch.Tensor]): Optional zero points to be used for - w2. - - a1_scale (Optional[torch.Tensor]): Optional scale to be used for - a1. - - a2_scale (Optional[torch.Tensor]): Optional scale to be used for - a2. + w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -255,15 +248,8 @@ def forward( output = a1 if inplace else torch.empty_like(a1) workspace13_shape, workspace2_shape, workspace_dtype = ( - self.fused_experts.workspace_shapes( - a1.dtype, - M, - N, - K, - top_k, - global_num_experts - ) - ) + self.fused_experts.workspace_shapes(a1.dtype, M, N, K, top_k, + global_num_experts)) # We can reuse the memory between cache1 and cache3 because by the time # we need cache3, we're done with cache1 @@ -301,6 +287,7 @@ def forward( workspace2=workspace2, ) - self.dispatch_combine.combine(output, fused_out, topk_weights, topk_ids) + self.dispatch_combine.combine(output, fused_out, topk_weights, + topk_ids) return output diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 3bc6b50720c..7219ea2c0a3 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -1,7 +1,9 @@ -import torch +# SPDX-License-Identifier: Apache-2.0 from typing import List, Optional, Tuple import pplx_kernels as pplx +import torch + import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize @@ -11,14 +13,14 @@ # as the ones used to create the AllToAll. Unfortunately, there's # no way(?) to extract this info from AllToAll class PplxDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - def __init__( - self, - a2a: pplx.AllToAll, - max_num_tokens: int, - world_size: int, - dp_size: int, - quant_dtype: Optional[torch.dtype] = None, - block_shape: Optional[List[int]] = None): + + def __init__(self, + a2a: pplx.AllToAll, + max_num_tokens: int, + world_size: int, + dp_size: int, + quant_dtype: Optional[torch.dtype] = None, + block_shape: Optional[List[int]] = None): super().__init__() self.a2a = a2a self.block_shape = block_shape @@ -37,7 +39,8 @@ def dispatch( device = a1.device if self.quant_dtype == torch.float8_e4m3fn: - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + per_act_token = a1_scale.numel( + ) != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) a1q, a1q_scale = _fp8_quantize( @@ -65,7 +68,8 @@ def dispatch( expert_x_scale: torch.Tensor | None = None if a1q.dtype.itemsize == 1: float32_size = torch.float32.itemsize - block_size = (self.block_shape[0] if self.block_shape is not None else 1) * float32_size + block_size = (self.block_shape[0] if self.block_shape is not None + else 1) * float32_size expert_x_scale = torch.empty( ( num_experts, @@ -77,7 +81,9 @@ def dispatch( ) # This argument is optional - bound_m = torch.tensor([a1q.shape[0]], dtype=torch.uint32, device=device) + bound_m = torch.tensor([a1q.shape[0]], + dtype=torch.uint32, + device=device) self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, @@ -98,14 +104,14 @@ def combine( topk_ids: torch.Tensor, ) -> None: # This argument is optional - bound_m = torch.tensor([output.shape[0]], dtype=torch.uint32, device=output.device) + bound_m = torch.tensor([output.shape[0]], + dtype=torch.uint32, + device=output.device) # TODO assert output is the proper size - self.a2a.combine( - out_tokens=output, - indices=topk_ids, - weights=topk_weights, - expert_y=fused_expert_output, - bound_m=bound_m - ) + self.a2a.combine(out_tokens=output, + indices=topk_ids, + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index ee8e8857fab..152007d4216 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -31,10 +31,7 @@ def _fp8_quantize( """ if block_shape is None: A, A_scale = ops.scaled_fp8_quant( - A, - A_scale, - use_per_token_if_dynamic=per_act_token - ) + A, A_scale, use_per_token_if_dynamic=per_act_token) else: assert len(block_shape) == 2 _, block_k = block_shape[0], block_shape[1] From fab51b8f6bf15b2b3921641c2ed10d3d93cd50ed Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 21:04:11 +0000 Subject: [PATCH 017/190] comments Signed-off-by: Bill Nell --- .../layers/fused_moe/modular_kernel.py | 53 ++++++++++++------- 1 file changed, 35 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 196c29eca8a..1b084b198f3 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -4,6 +4,24 @@ import torch +# +# This file defines a set of base classes used to make MoE kernels more modular. +# The goal is to be able to utilize different communication mechanisms with +# any fused MoE kernel without needing to have combinatoric implementations. +# +# Break the fused moe layer down into the following components. Each component +# will be independent of the others except for [Quantize-Dispatch] and +# [Combine]. The components can then be mixed and matched with different fused +# moe kernels so that DP+EP can be supported easily for multiple MoE +# implementations. +# +# Architecture: +# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine] +# +# [Quantize-Dispatch] and [Combine] functionality are bundled into a single +# class `FusedMoEQuantizeDispatchCombine` since they could use collective +# communication mechanisms that need to be consistent. +# def moe_problem_size( a1: torch.Tensor, @@ -34,23 +52,10 @@ def moe_problem_size( return E, M, N, K, topk -# -# A set of base classes used to make MoE kernels more modular. -# -# Architecture: -# [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine] -# -# [Quantize-Dispatch] and [Combine] functionality are bundled into a single -# class `FusedMoEQuantizeDispatchCombine` since they could use collective -# communication mechanisms that need to be consistent. -# -# Ideal architecture: -# [Router] → [Quantize-Dispatch-Permute] → [Experts] → [Unpermute-Combine] -# - - class FusedMoEQuantizeDispatchCombine(ABC): """ + An abstract base class for the [Quantize-Dispatch] and [Combine] steps + described above. """ @abstractmethod @@ -102,6 +107,10 @@ def combine( class FusedMoEPermuteExpertsUnpermute(ABC): + """ + An abstract base class for the [Permute-Experts-Unpermute] step described + above. + """ @abstractmethod def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, @@ -177,10 +186,18 @@ def apply( raise NotImplementedError -# Note: only intended for use with a single model layer (due to temp buffers, -# constants, etc.) -class FusedMoEModularKernel(torch.nn.Module): # should this be a module? +class FusedMoEModularKernel(torch.nn.Module): + """ + This class combines a FusedMoEQuantizeDispatchCombine instance and + a FusedMoEPermuteExpertsUnpermute to provide an interface that + is compatible with the `fused_experts` function in fused_moe.py. + + It takes care of managing any required scratch space. + Note: Instances of this class should only be used for a single model + layer due to any layer specific state that may be used by the component + objects. + """ def __init__( self, dispatch_combine: FusedMoEQuantizeDispatchCombine, From 5694036d895ea306d4924a3df921cdca33d98c20 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 21:18:30 +0000 Subject: [PATCH 018/190] fix linter Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 5 ++++- vllm/model_executor/layers/fused_moe/modular_kernel.py | 2 +- .../model_executor/layers/fused_moe/pplx_dispatch_combine.py | 1 + 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index c6b50729b24..ea903e2500a 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -206,10 +206,12 @@ def workspace_shapes( self, a_dtype: torch.dtype, M: int, - K: int, # Note that K, N are transposed N: int, + K: int, topk: int, num_experts: int) -> Tuple[int, int, torch.dtype]: + # Note that K, N are transposed + N, K = K, N workspace1 = M * topk * max(2 * N, K) workspace2 = M * topk * N return (workspace1, workspace2, self.out_dtype) @@ -240,6 +242,7 @@ def apply( assert w1.shape[1] == K assert global_num_experts != -1 + assert a1q_scale is not None per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 1b084b198f3..f56790d4dcc 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -28,7 +28,7 @@ def moe_problem_size( w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor, -) -> Tuple[int, int, int, int]: +) -> Tuple[int, int, int, int, int]: """ Extract the MoE problem size from the given tensor arguments: - a: The hidden states, input to the MoE layer. diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 7219ea2c0a3..fc5ff1ae020 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -25,6 +25,7 @@ def __init__(self, self.a2a = a2a self.block_shape = block_shape self.dp_num_tokens = max_num_tokens * (world_size // dp_size) + self.quant_dtype = quant_dtype def dispatch( self, From 124f0ba503fa4bd140acb2703ce31f85e2900fba Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 21:26:18 +0000 Subject: [PATCH 019/190] fix more linter stuff Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 11 +++-------- .../model_executor/layers/fused_moe/modular_kernel.py | 5 ++++- .../layers/fused_moe/pplx_dispatch_combine.py | 2 +- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index ea903e2500a..64e6d425bde 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -202,14 +202,9 @@ def __init__( self.c_strides2 = c_strides2 self.out_dtype = out_dtype - def workspace_shapes( - self, - a_dtype: torch.dtype, - M: int, - N: int, - K: int, - topk: int, - num_experts: int) -> Tuple[int, int, torch.dtype]: + def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, + topk: int, + num_experts: int) -> Tuple[int, int, torch.dtype]: # Note that K, N are transposed N, K = K, N workspace1 = M * topk * max(2 * N, K) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index f56790d4dcc..5db49a630a4 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -23,6 +23,7 @@ # communication mechanisms that need to be consistent. # + def moe_problem_size( a1: torch.Tensor, w1: torch.Tensor, @@ -43,7 +44,8 @@ def moe_problem_size( of w1 or w2. Similarly, some kernels transpose the weights, so this needs to be kept in mind. """ - # Make sure we are using the correct a1 (pre-permute) + + # Make sure we are using the correct a1 (pre-permute). assert topk_ids.shape[0] == a1.shape[0] M, _ = a1.shape E, N, _ = w1.shape @@ -198,6 +200,7 @@ class FusedMoEModularKernel(torch.nn.Module): layer due to any layer specific state that may be used by the component objects. """ + def __init__( self, dispatch_combine: FusedMoEQuantizeDispatchCombine, diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index fc5ff1ae020..5c844ff57a7 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -66,7 +66,7 @@ def dispatch( device=device, ) - expert_x_scale: torch.Tensor | None = None + expert_x_scale: Optional[torch.Tensor] = None if a1q.dtype.itemsize == 1: float32_size = torch.float32.itemsize block_size = (self.block_shape[0] if self.block_shape is not None From d98a3c34ee7cf398962036d7243b93e270827aa8 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 23:20:14 +0000 Subject: [PATCH 020/190] cleanup for review Signed-off-by: Bill Nell --- tests/kernels/quantization/test_block_fp8.py | 20 +- tests/kernels/test_cutlass_moe.py | 102 ++----- .../layers/fused_moe/cutlass_moe.py | 74 ++++- .../layers/fused_moe/deep_gemm_moe.py | 254 +++++------------- .../layers/fused_moe/fused_moe.py | 51 +--- .../layers/fused_moe/modular_kernel.py | 28 +- 6 files changed, 199 insertions(+), 330 deletions(-) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index a05effa5bd6..5a02270b3bf 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm_shape, deep_gemm_moe_fp8, modular_deep_gemm_fused_moe_fp8) + _valid_deep_gemm_shape, deep_gemm_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) @@ -425,21 +425,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - if True: - dgm = modular_deep_gemm_fused_moe_fp8() - - def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids): - return dgm(a, - w1, - w2, - topk_weights, - topk_ids, - w1_scale=w1_s, - w2_scale=w2_s) - else: - deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 - # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): if M >= 128: @@ -452,8 +437,7 @@ def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_weights, topk_ids, token_expert_indices = fused_topk( a, score.float(), topk, False) - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids) + out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py index 0dc572c7288..3cfed6ae853 100644 --- a/tests/kernels/test_cutlass_moe.py +++ b/tests/kernels/test_cutlass_moe.py @@ -1,13 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Optional - import pytest import torch from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp8, modular_cutlass_moe_fp8) +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, fused_topk) from vllm.platforms import current_platform @@ -16,48 +13,6 @@ TOP_KS = [6, 8] -def get_cutlass_moe_fp8(ab_strides1: torch.Tensor, - c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides2: torch.Tensor, - out_dtype=torch.half) -> Callable: - if True: - return modular_cutlass_moe_fp8( - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - out_dtype, - ) - else: - - def cutlass_moe_fp8_fn( - a: torch.Tensor, - w1_q: torch.Tensor, - w2_q: torch.Tensor, - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - a1_scale: Optional[torch.Tensor], - ) -> torch.Tensor: - return cutlass_moe_fp8(a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale, - out_dtype=out_dtype) - - return cutlass_moe_fp8_fn - - def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, @@ -66,22 +21,18 @@ def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig( pipeline_parallel_size=1))): - - cutlass_moe_fp8_fn = get_cutlass_moe_fp8( - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - ) - - return cutlass_moe_fp8_fn(a, - w1_q, - w2_q, - w1_scale=w1_scale, - w2_scale=w2_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - a1_scale=a_scale) + return cutlass_moe_fp8(a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale) @pytest.mark.parametrize("m", [2, 64, 224]) @@ -167,21 +118,18 @@ def test_cutlass_moe_no_graph( triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) - cutlass_moe_fp8_fn = get_cutlass_moe_fp8( - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - ) - - cutlass_output = cutlass_moe_fp8_fn(a, - w1_q, - w2_q, - w1_scale=w1_scale, - w2_scale=w2_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - a1_scale=a_scale1) + cutlass_output = cutlass_moe_fp8(a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale1) #print(triton_output) #print(cutlass_output) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 64e6d425bde..669505c656d 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -229,7 +229,6 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, ) -> torch.Tensor: - # TODO: chunking in here or in FusedMoEModularKernel? ignore for now M = a1q.shape[0] _, N, K = w2.shape # because w1 + w2 are transposed topk = topk_ids.shape[1] @@ -311,3 +310,76 @@ def modular_cutlass_moe_fp8( out_dtype, ), ) + + +#TODO make the grouped gemm kernel consistent with scaled gemm kernel +def cutlass_moe_fp8( + a: torch.Tensor, + w1_q: torch.Tensor, + w2_q: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.half, +) -> torch.Tensor: + """ + This function computes a a8w8-quantized Mixture of Experts (MoE) layer + using two sets of quantized weights, w1_q and w2_q, and top-k gating + mechanism. The matrix multiplications are implemented with CUTLASS + grouped gemm. + + Parameters: + - a (torch.Tensor): The input tensor to the MoE layer. + Shape: [M, K] + - w1_q (torch.Tensor): The first set of fp8-quantized expert weights. + Shape: [num_experts, K, 2N] (the weights are passed transposed) + - w2_q (torch.Tensor): The second set of fp8-quantized expert weights. + Shape: [num_experts, N, K] (the weights are passed transposed) + - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. + Shape: [num_experts] or [num_experts, 2N] + - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. + Shape: [num_experts] or [num_experts, K] + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - ab_strides1 (torch.Tensor): The input and weights strides of the first + grouped gemm. + - c_strides1 (torch.Tensor): The output strides of the first grouped gemm. + - ab_strides2 (torch.Tensor): The input and weights strides of the second + grouped gemm. + - c_strides2 (torch.Tensor): The output strides of the second grouped gemm. + - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. + Shape: scalar or [M] + - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to + quantize the intermediate result between the gemms. + Shape: scalar or [M] + - out_dtype (torch.dtype): The output tensor type. + + Returns: + - torch.Tensor: The fp16 output tensor after applying the MoE layer. + """ + fn = modular_cutlass_moe_fp8( + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + out_dtype, + ) + return fn( + a, + w1_q, + w2_q, + topk_weights, + topk_ids, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 6ffb40cb52c..b19d1f52fa4 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -4,13 +4,12 @@ import torch -import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.dispatch_combine import ( StandardDispatchCombine) from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - _moe_permute, _moe_unpermute_and_reduce) + _moe_permute) from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, _resize_cache) from vllm.utils import round_up @@ -58,186 +57,6 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, and w2.is_contiguous()) -def deep_gemm_moe_fp8( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - This function computes a a8w8-quantized Mixture of Experts (MoE) layer - using two sets of quantized weights, w1_q and w2_q, and top-k gating - mechanism. The matrix multiplications are implemented with DeepGemm - grouped gemm. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - Shape: [M, K] - - w1 (torch.Tensor): The first set of fp8 quantized expert weights. - Shape: [num_experts, K, 2N] (the weights are passed transposed) - - w2 (torch.Tensor): The second set of fp8 quantized expert weights. - Shape: [num_experts, N, K] (the weights are passed transposed) - - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. - Shape: [num_experts] or [num_experts, 2N] - - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. - Shape: [num_experts] or [num_experts, K] - - topk_weights (torch.Tensor): The weights of each token->expert mapping. - - topk_ids (torch.Tensor): The token->expert mapping for topk_weights. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - activation (str): The activation function to apply after the first - MoE layer. - - global_num_experts (int): The total number of experts in the global - expert space. - - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert - parallel shard. - - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. - Shape: scalar or [M] - - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to - quantize the intermediate result between the gemms. - Shape: scalar or [M] - - Returns: - - torch.Tensor: The bfloat16 output tensor after applying the MoE layer. - """ - # Lazy import to avoid CUDA initialization problems. - import deep_gemm as dg - - assert expert_map is None, "Expert maps not supported yet" - - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" - - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.stride(-1) == 1, "Stride of last dimension must be 1" - assert w2.stride(-1) == 1, "Stride of last dimension must be 1" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - assert w1.dtype == torch.float8_e4m3fn - assert w2.dtype == torch.float8_e4m3fn - assert w1.shape[0] == w2.shape[0], "Expert number mismatch" - assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" - assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" - assert a1_scale is None or a1_scale.dim( - ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[ - 0] == hidden_states.shape[0], "Input scale shape mismatch" - assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 - - num_tokens, _ = hidden_states.shape - E, N, _ = w1.shape - K = w2.shape[1] - if global_num_experts == -1: - global_num_experts = E - - # We execute the fused_moe kernel in chunks to circumvent this issue: - # https://github.com/vllm-project/vllm/issues/5938 - CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE - - assert _valid_deep_gemm(hidden_states, w1, w2, expert_map) - - if inplace: - out_hidden_states = hidden_states - else: - out_hidden_states = torch.empty_like(hidden_states) - - block_m = dg.get_m_alignment_for_contiguous_layout() - block_shape = [block_m, block_m] - - assert w1_scale is not None - assert w2_scale is not None - - # We attempt to transpose and align offline in Fp8MoEMethod, in which - # case these calls will be nops. Otherwise, they'll be performed every - # time the layer is executed. - w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() - w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() - - M_sum = topk_ids.numel() + global_num_experts * (block_m - 1) - M_sum = round_up(M_sum, block_m) - - num_chunks = (num_tokens // CHUNK_SIZE) + 1 - - # We can reuse the memory between cache1 and cache3 because by the time - # we need cache3, we're done with cache1 - workspace13 = torch.empty(M_sum * max(N, K), - device=hidden_states.device, - dtype=hidden_states.dtype) - - workspace1 = workspace13[:M_sum * N].view(M_sum, N) - workspace2 = torch.empty((M_sum, N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) - workspace3 = workspace13[:M_sum * K].view(M_sum, K) - - for chunk in range(num_chunks): - begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, - min((chunk + 1) * CHUNK_SIZE, - num_tokens)) - curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] - tokens_in_chunk, _ = curr_hidden_states.shape - - if tokens_in_chunk == 0: - break - - curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] - curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] - - a1q_scale: Optional[torch.Tensor] = None - - qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states, - a1_scale, block_shape) - - (qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, - inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale, - curr_topk_ids, global_num_experts, - expert_map, block_m) - - # Adjust the intermediate cache size and config for the last chunk. - # Note that in most cases we only have one chunk so the cache size - # and config are already set correctly and do not need to be adjusted. - if tokens_in_chunk < CHUNK_SIZE and chunk > 0: - curr_M = sorted_token_ids.numel() - workspace1 = _resize_cache(workspace1, (curr_M, N)) - workspace2 = _resize_cache(workspace2, (curr_M, N // 2)) - workspace3 = _resize_cache(workspace3, (curr_M, K)) - - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (qcurr_hidden_states, a1q_scale), (w1, w1_scale), workspace1, - expert_ids) - - if activation == "silu": - torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N)) - elif activation == "gelu": - torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N)) - else: - raise ValueError(f"Unsupported FusedMoe activation: {activation}") - - a2q_scale: Optional[torch.Tensor] = None - - qworkspace2, a2q_scale = _fp8_quantize(workspace2, a2_scale, - block_shape) - - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (qworkspace2, a2q_scale), (w2, w2_scale), workspace3, expert_ids) - - _moe_unpermute_and_reduce( - out_hidden_states[begin_chunk_idx:end_chunk_idx], - workspace3.view(*workspace3.shape), inv_perm, curr_topk_weights) - - return out_hidden_states - - class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self): @@ -274,7 +93,6 @@ def apply( ) -> torch.Tensor: import deep_gemm as dg - # TODO: chunking in here or in FusedMoEModularKernel? ignore for now _, N, K = w1.shape assert global_num_experts != -1 @@ -323,3 +141,73 @@ def modular_deep_gemm_fused_moe_fp8() -> mk.FusedMoEModularKernel: block_shape=deep_gemm_block_shape()), DeepGemmExperts(), ) + + +def deep_gemm_moe_fp8( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This function computes a a8w8-quantized Mixture of Experts (MoE) layer + using two sets of quantized weights, w1_q and w2_q, and top-k gating + mechanism. The matrix multiplications are implemented with DeepGemm + grouped gemm. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + Shape: [M, K] + - w1 (torch.Tensor): The first set of fp8 quantized expert weights. + Shape: [num_experts, K, 2N] (the weights are passed transposed) + - w2 (torch.Tensor): The second set of fp8 quantized expert weights. + Shape: [num_experts, N, K] (the weights are passed transposed) + - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. + Shape: [num_experts] or [num_experts, 2N] + - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. + Shape: [num_experts] or [num_experts, K] + - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - topk_ids (torch.Tensor): The token->expert mapping for topk_weights. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - activation (str): The activation function to apply after the first + MoE layer. + - global_num_experts (int): The total number of experts in the global + expert space. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. + Shape: scalar or [M] + - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to + quantize the intermediate result between the gemms. + Shape: scalar or [M] + + Returns: + - torch.Tensor: The bfloat16 output tensor after applying the MoE layer. + """ + fn = modular_deep_gemm_fused_moe_fp8() + return fn( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace, + activation, + global_num_experts, + expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a578924fb3d..3927c56b62e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1157,30 +1157,6 @@ def fused_experts(hidden_states: torch.Tensor, a1_scale=a1_scale, a2_scale=a2_scale, ) - elif hidden_states.shape[0] <= envs.VLLM_FUSED_MOE_CHUNK_SIZE: - fe = modular_triton_fused_moe( - use_fp8_w8a8, - use_int8_w8a16, - use_int4_w4a16, - block_shape, - ) - return fe( - hidden_states, - w1, - w2, - topk_weights, - topk_ids, - inplace, - activation, - global_num_experts, - expert_map, - w1_scale, - w2_scale, - w1_zp, - w2_zp, - a1_scale, - a2_scale, - ) else: return dispatch_fused_experts_func(inplace)( hidden_states=hidden_states, @@ -1634,19 +1610,17 @@ def apply( global_num_experts = E top_k_num = topk_ids.shape[1] - # We execute the fused_moe kernel in chunks to circumvent this issue: - # https://github.com/vllm-project/vllm/issues/5938 config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, dtype=hidden_states.dtype) - get_config_func = functools.partial( - try_get_optimal_moe_config, + config = try_get_optimal_moe_config( w1.shape, w2.shape, top_k_num, config_dtype, + num_tokens, block_shape=self.block_shape, ) @@ -1662,29 +1636,20 @@ def apply( raise ValueError( f"Unsupported compute_type: {hidden_states.dtype}") - curr_hidden_states = hidden_states - tokens_in_chunk, _ = curr_hidden_states.shape - # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 intermediate_cache1 = _resize_cache(workspace13, - (tokens_in_chunk, top_k_num, N)) - intermediate_cache2 = _resize_cache( - workspace2, (tokens_in_chunk * top_k_num, N // 2)) + (num_tokens, top_k_num, N)) + intermediate_cache2 = _resize_cache(workspace2, + (num_tokens * top_k_num, N // 2)) intermediate_cache3 = _resize_cache(workspace13, - (tokens_in_chunk, top_k_num, K)) - - config = get_config_func(tokens_in_chunk) - - curr_topk_ids = topk_ids - - qcurr_hidden_states, a1q_scale = hidden_states, a1q_scale + (num_tokens, top_k_num, K)) sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], + moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map)) - invoke_fused_moe_kernel(qcurr_hidden_states, + invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, a1q_scale, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 5db49a630a4..2dcbf0dd341 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -9,22 +9,34 @@ # The goal is to be able to utilize different communication mechanisms with # any fused MoE kernel without needing to have combinatoric implementations. # -# Break the fused moe layer down into the following components. Each component -# will be independent of the others except for [Quantize-Dispatch] and -# [Combine]. The components can then be mixed and matched with different fused -# moe kernels so that DP+EP can be supported easily for multiple MoE -# implementations. +# The fused moe kernels are broken down into the following components: # -# Architecture: # [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine] # +# Each component will be independent of the others except for +# [Quantize-Dispatch] and `[Combine] (see below). The components can then be +# mixed and matched with so that DP+EP can be supported easily for multiple +# MoE kernel implementations. +# +# The following main classes are defined: +# * FusedMoEQuantizeDispatchCombine - an abstract base class for quantization, +# dispatching and combing. The dispatch method takes care of any needed +# quantization and the combine method applies weights and does the final +# reduction of the output. +# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused +# MoE operation. One important feature to note is that this class does not +# apply topk weights or reduce the final output. +# * FusedMoEModularKernel - an interface class that combines a +# FusedMoEQuantizeDispatchCombine and a FusedMoEPermuteExpertsUnpermute to +# provide the standard fused MoE kernel interface. +# # [Quantize-Dispatch] and [Combine] functionality are bundled into a single # class `FusedMoEQuantizeDispatchCombine` since they could use collective # communication mechanisms that need to be consistent. # -def moe_problem_size( +def _moe_problem_size( a1: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -260,7 +272,7 @@ def forward( Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ - E, M, N, K, top_k = moe_problem_size(a1, w1, w2, topk_ids) + E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids) if global_num_experts == -1: global_num_experts = E From a00c12cbbf4640c2c853346b1c760a18359bc020 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 4 Apr 2025 01:51:15 +0000 Subject: [PATCH 021/190] review comments Signed-off-by: Bill Nell --- .../layers/fused_moe/deep_gemm_moe.py | 16 ++++++++++++---- .../layers/fused_moe/pplx_dispatch_combine.py | 4 +++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index b19d1f52fa4..250f03ae7f0 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -31,7 +31,6 @@ def _valid_deep_gemm_shape(M: int, N: int, K: int): return align <= M and N % align == 0 and K % align == 0 -# TODO: check types? def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -42,19 +41,28 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, aligned by `dg.get_m_alignment_for_contiguous_layout()`. """ if not has_deep_gemm: + logger.debug("DeepGemm disabled: deep_gemm not available.") return False - # Expert maps not supported yet. if expert_map is not None: + logger.debug("DeepGemm disabled: expert map NYI.") return False M = hidden_states.shape[0] _, K, N = w2.shape if not _valid_deep_gemm_shape(M, N, K): + logger.debug("DeepGemm disabled: unalinged problem size.") return False - return (hidden_states.is_contiguous() and w1.is_contiguous() - and w2.is_contiguous()) + if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): + logger.debug("DeepGemm disabled: invalid weight dtype(s).") + return False + + if (not hidden_states.is_contiguous() or not w1.is_contiguous() + or not w2.is_contiguous()): + logger.debug( + "DeepGemm disabled: weights or activations not contiguous.") + return False class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 5c844ff57a7..936aee14a7b 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -24,6 +24,7 @@ def __init__(self, super().__init__() self.a2a = a2a self.block_shape = block_shape + self.max_num_tokens = max_num_tokens self.dp_num_tokens = max_num_tokens * (world_size // dp_size) self.quant_dtype = quant_dtype @@ -109,7 +110,8 @@ def combine( dtype=torch.uint32, device=output.device) - # TODO assert output is the proper size + assert output.shape[0] == self.max_num_tokens + assert output.shape[1] == fused_expert_output.shape[-1] self.a2a.combine(out_tokens=output, indices=topk_ids, From 27cc6d165580979761df9d619f958ac34723612e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 4 Apr 2025 01:58:16 +0000 Subject: [PATCH 022/190] forgot return Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 250f03ae7f0..e9adb335355 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -64,6 +64,8 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, "DeepGemm disabled: weights or activations not contiguous.") return False + return True + class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): From 1db9fcf18d7f3bb75fef8082bbf39b04c477556e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 4 Apr 2025 15:08:02 +0000 Subject: [PATCH 023/190] add dp_rank_num_tokens to DPMetadata Signed-off-by: Bill Nell --- vllm/forward_context.py | 7 ++++++- .../layers/fused_moe/pplx_dispatch_combine.py | 9 +++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 9ddc3d1f2c5..a9815aba1d6 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -32,6 +32,7 @@ @dataclass class DPMetadata: cu_tokens_across_dp_cpu: torch.Tensor + dp_rank_num_tokens: torch.Tensor @dataclass @@ -95,7 +96,11 @@ def set_forward_context(attn_metadata: Any, from vllm.distributed.parallel_state import get_dp_group dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0) - dp_metadata = DPMetadata(cu_tokens_across_dp_cpu) + dp_rank_num_tokens = torch.tensor( + [num_tokens], + dtype=torch.uint32, + device=vllm_config.device_config.device) + dp_metadata = DPMetadata(cu_tokens_across_dp_cpu, dp_rank_num_tokens) global _forward_context prev_context = _forward_context diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 936aee14a7b..d35cfaccd39 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -5,6 +5,7 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize @@ -83,9 +84,7 @@ def dispatch( ) # This argument is optional - bound_m = torch.tensor([a1q.shape[0]], - dtype=torch.uint32, - device=device) + bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, @@ -106,9 +105,7 @@ def combine( topk_ids: torch.Tensor, ) -> None: # This argument is optional - bound_m = torch.tensor([output.shape[0]], - dtype=torch.uint32, - device=output.device) + bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens assert output.shape[0] == self.max_num_tokens assert output.shape[1] == fused_expert_output.shape[-1] From 374c55cf9dcb881b456d32090de25a470ec8f4d7 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 4 Apr 2025 22:29:28 +0000 Subject: [PATCH 024/190] better check for fp8 in _fp8_permute Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 152007d4216..93f9158f15e 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -44,7 +44,7 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: """ A permutation routine that works on fp8 types. """ - if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8: + if torch.is_floating_point(m) and m.dtype.itemsize == 1: return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) else: return m[idx, ...] From fbf937060d540dc7813130da3fefa4d604f10727 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 28 Apr 2025 18:38:48 +0000 Subject: [PATCH 025/190] updates Signed-off-by: Bill Nell --- .../layers/fused_moe/deep_gemm_moe.py | 22 ++++---- .../layers/fused_moe/dispatch_combine.py | 4 +- .../layers/fused_moe/fused_moe.py | 48 ++++++++++-------- .../layers/fused_moe/modular_kernel.py | 50 ++++++++++++++----- vllm/model_executor/layers/fused_moe/utils.py | 2 +- 5 files changed, 79 insertions(+), 47 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index e9adb335355..e43c984f7d5 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -73,15 +73,21 @@ def __init__(self): super().__init__() self.block_shape = deep_gemm_block_shape() - def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, - topk: int, - num_experts: int) -> Tuple[int, int, torch.dtype]: + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> Tuple[int, int, torch.dtype]: block_m = self.block_shape[0] M_sum = (M * topk) + num_experts * (block_m - 1) M_sum = round_up(M_sum, block_m) workspace1 = M_sum * max(N * 2, K) workspace2 = M_sum * N - return (workspace1, workspace2, a_dtype) + return (workspace1, workspace2, a.dtype) def apply( self, @@ -100,6 +106,7 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], ) -> torch.Tensor: import deep_gemm as dg @@ -126,12 +133,7 @@ def apply( dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids) - if activation == "silu": - torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N)) - elif activation == "gelu": - torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N)) - else: - raise ValueError(f"Unsupported FusedMoe activation: {activation}") + self.activation(activation, workspace2, workspace1.view(-1, N)) a2q_scale: Optional[torch.Tensor] = None diff --git a/vllm/model_executor/layers/fused_moe/dispatch_combine.py b/vllm/model_executor/layers/fused_moe/dispatch_combine.py index 06b90c35025..398aab60c66 100644 --- a/vllm/model_executor/layers/fused_moe/dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/dispatch_combine.py @@ -26,7 +26,7 @@ def dispatch( topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: if self.quant_dtype == torch.float8_e4m3fn: per_act_token = a1_scale.numel( ) != 1 if a1_scale is not None else ( @@ -42,7 +42,7 @@ def dispatch( a1q = a1 a1q_scale = a1_scale - return a1q, a1q_scale + return a1q, a1q_scale, None def combine( self, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 3927c56b62e..92d6d24799b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1262,7 +1262,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, assert hidden_states.shape[1] // 2 == w1.shape[ 2], "Hidden size mismatch" else: - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert hidden_states.shape[1] == w1.shape[2], \ + f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" @@ -1272,7 +1273,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, torch.float32, torch.float16, torch.bfloat16 ] - num_tokens, _ = hidden_states.shape + num_tokens = hidden_states.shape[0] E, N, _ = w1.shape K = w2.shape[1] if global_num_experts == -1: @@ -1554,20 +1555,28 @@ def __init__( use_fp8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, - block_shape: Optional[List[int]], + block_shape: Optional[List[int]] = None, + block_m: Optional[int] = None, ): super().__init__() self.use_fp8_w8a8 = use_fp8_w8a8 self.use_int4_w4a16 = use_int4_w4a16 self.use_int8_w8a16 = use_int8_w8a16 self.block_shape = block_shape + self.block_m = block_m - def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, - topk: int, - num_experts: int) -> Tuple[int, int, torch.dtype]: + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> Tuple[int, int, torch.dtype]: workspace1 = M * topk * max(N * 2, K) workspace2 = M * topk * N - return (workspace1, workspace2, a_dtype) + return (workspace1, workspace2, a.dtype) def apply( self, @@ -1586,14 +1595,16 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], ) -> torch.Tensor: # Check constraints. if self.use_int4_w4a16: - assert hidden_states.shape[1] // 2 == w1.shape[ + assert hidden_states.shape[-1] // 2 == w1.shape[ 2], "Hidden size mismatch" else: - assert hidden_states.shape[1] == w1.shape[ - 2], "Hidden size mismatch" + assert hidden_states.shape[-1] == w1.shape[2], \ + (f"Hidden size mismatch {hidden_states.shape[-1]} " + f"!= {w1.shape[2]}") assert hidden_states.is_contiguous( ), "Hidden_states must be contiguous" @@ -1603,12 +1614,11 @@ def apply( torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn ] - num_tokens, _ = hidden_states.shape - E, N, _ = w1.shape - K = w2.shape[1] + E, num_tokens, N, K, top_k_num = mk._moe_problem_size( + hidden_states, w1, w2, topk_ids) + if global_num_experts == -1: global_num_experts = E - top_k_num = topk_ids.shape[1] config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, use_int8_w8a16=self.use_int8_w8a16, @@ -1668,14 +1678,8 @@ def apply( use_int4_w4a16=self.use_int4_w4a16, block_shape=self.block_shape) - if activation == "silu": - torch.ops._C.silu_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, N)) - elif activation == "gelu": - torch.ops._C.gelu_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, N)) - else: - raise ValueError(f"Unsupported FusedMoe activation: {activation}") + self.activation(activation, intermediate_cache2, + intermediate_cache1.view(-1, N)) a2q_scale: Optional[torch.Tensor] = None diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 2dcbf0dd341..b517f6ee13c 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -56,13 +56,19 @@ def _moe_problem_size( of w1 or w2. Similarly, some kernels transpose the weights, so this needs to be kept in mind. """ - - # Make sure we are using the correct a1 (pre-permute). - assert topk_ids.shape[0] == a1.shape[0] - M, _ = a1.shape + assert w1.dim() == 3 and w2.dim() == 3 E, N, _ = w1.shape K = w2.shape[1] + + assert a1.dim() == 2 + assert topk_ids.dim() == 2 + # Make sure we are using the correct a1 (pre-permute). + assert topk_ids.shape[0] == a1.shape[ + 0], f"{topk_ids.shape[0]} != {a1.shape[0]}" + + M = a1.shape[0] topk = topk_ids.shape[1] + return E, M, N, K, topk @@ -81,7 +87,7 @@ def dispatch( topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Perform any quantization (and/or) dispatching needed for this kernel. @@ -127,9 +133,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC): """ @abstractmethod - def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, - topk: int, - num_experts: int) -> Tuple[int, int, torch.dtype]: + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> Tuple[int, int, torch.dtype]: """ Compute the number of elements for the temporary outputs of the two gemms and activation in the fused expert function. Since the @@ -145,6 +157,15 @@ def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, """ raise NotImplementedError + def activation(self, activation: str, output: torch.Tensor, + input: torch.Tensor) -> None: + if activation == "silu": + torch.ops._C.silu_and_mul(output, input) + elif activation == "gelu": + torch.ops._C.gelu_and_mul(output, input) + else: + raise ValueError(f"Unsupported FusedMoe activation: {activation}") + @abstractmethod def apply( self, @@ -163,6 +184,7 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], ) -> torch.Tensor: """ This function computes the intermediate result of a Mixture of Experts @@ -193,6 +215,8 @@ def apply( must be large enough to hold output of either MoE gemm. - workspace2 (torch.Tensor): A scratch tensor used for the activation function. + - expert_num_tokens: An optional tensor containing the number of tokens + assigned to each expert when using batched experts format input. Returns: - torch.Tensor: The unweighted, unreduced output tensor @@ -224,7 +248,7 @@ def __init__( def forward( self, - a1: torch.Tensor, + hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, @@ -245,7 +269,7 @@ def forward( of weights, w1 and w2, and top-k gating mechanism. Parameters: - - a1: (torch.Tensor): The input tensor to the MoE layer. + - hidden_states: (torch.Tensor): The input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - topk_weights (torch.Tensor): The topk weights applied at the end of @@ -272,6 +296,7 @@ def forward( Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ + a1 = hidden_states E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids) if global_num_experts == -1: @@ -280,7 +305,7 @@ def forward( output = a1 if inplace else torch.empty_like(a1) workspace13_shape, workspace2_shape, workspace_dtype = ( - self.fused_experts.workspace_shapes(a1.dtype, M, N, K, top_k, + self.fused_experts.workspace_shapes(a1, M, N, K, top_k, global_num_experts)) # We can reuse the memory between cache1 and cache3 because by the time @@ -292,7 +317,7 @@ def forward( device=a1.device, dtype=workspace_dtype) - a1q, a1q_scale = self.dispatch_combine.dispatch( + a1q, a1q_scale, expert_num_tokens = self.dispatch_combine.dispatch( a1, a1_scale, a2_scale, @@ -317,6 +342,7 @@ def forward( a2_scale, workspace13=workspace13, workspace2=workspace2, + expert_num_tokens=expert_num_tokens, ) self.dispatch_combine.combine(output, fused_out, topk_weights, diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 93f9158f15e..eff39c0f792 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -15,7 +15,7 @@ def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor: Shrink the given tensor and apply the given view to it. This is used to resize the intermediate fused_moe caches. """ - assert prod(v) <= x.numel() + assert prod(v) <= x.numel(), f"{prod(v)} <= {x.numel()}" return x.flatten()[:prod(v)].view(*v) From be5c8d8af507f648a95bfa80fd6f62f3125adb55 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 02:06:11 +0000 Subject: [PATCH 026/190] fix merge issues Signed-off-by: Bill Nell --- tests/kernels/moe/test_cutlass_moe.py | 22 +- tests/kernels/moe/test_moe.py | 43 +-- tests/kernels/quantization/test_block_fp8.py | 29 +- tests/kernels/quantization/test_block_int8.py | 5 +- tests/kernels/test_cutlass_moe.py | 244 --------------- .../layers/fused_moe/cutlass_moe.py | 286 ++++++------------ .../layers/fused_moe/deep_gemm_moe.py | 8 +- .../layers/fused_moe/dispatch_combine.py | 43 +-- .../layers/fused_moe/fused_moe.py | 148 ++++----- .../layers/fused_moe/modular_kernel.py | 44 +-- .../layers/fused_moe/moe_permute_unpermute.py | 4 +- .../layers/fused_moe/pplx_dispatch_combine.py | 41 ++- vllm/model_executor/layers/fused_moe/utils.py | 48 ++- 13 files changed, 347 insertions(+), 618 deletions(-) delete mode 100644 tests/kernels/test_cutlass_moe.py diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 975cd418a17..7d24307e353 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -30,6 +30,11 @@ (224, 3072, 1536), ] +vllm_config = VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1)) +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + @dataclasses.dataclass class MOETensors: @@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit, 'w1_q': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr] 'w2_q': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr] 'topk_weights': topk_weights, - 'topk_ids_': topk_ids, + 'topk_ids': topk_ids, 'ab_strides1': moe_tensors.ab_strides1, 'c_strides1': moe_tensors.c_strides1, 'ab_strides2': moe_tensors.ab_strides2, @@ -231,10 +236,7 @@ def test_cutlass_moe_8_bit_no_graph( per_out_ch: bool, ): current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) @@ -276,10 +278,7 @@ def test_cutlass_moe_8_bit_cuda_graph( per_out_ch: bool, ): current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + with set_current_vllm_config(vllm_config): dtype = torch.half mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, @@ -334,10 +333,7 @@ def test_cutlass_moe_8_bit_EP( ep_size: int, ): current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_channel) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 0e671ac9683..5250dd82fa1 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -11,7 +11,9 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock import vllm.model_executor.layers.fused_moe # noqa -from tests.kernels.utils import opcheck, stack_and_dev, torch_moe +from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe, + torch_moe_single) +from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( @@ -30,6 +32,10 @@ EP_SIZE = [1, 4] TOP_KS = [2, 6] +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + @pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @@ -68,7 +74,6 @@ def test_fused_moe( else: e_map = None - vllm_config = VllmConfig() with set_current_vllm_config(vllm_config): torch_output = torch_moe(a, w1, w2, score, topk, e_map) iterative_output = iterative_moe(a, @@ -195,22 +200,24 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, else: e_map = None - triton_output = fused_moe(a, - w1_qweight, - w2_qweight, - score, - topk, - renormalize=False, - use_int4_w4a16=weight_bits == 4, - use_int8_w8a16=weight_bits == 8, - global_num_experts=e, - expert_map=e_map, - w1_scale=w1_scales, - w2_scale=w2_scales, - w1_zp=w1_qzeros if has_zp else None, - w2_zp=w2_qzeros if has_zp else None, - block_shape=[0, group_size]) - torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) + with set_current_vllm_config(vllm_config): + triton_output = fused_moe(a, + w1_qweight, + w2_qweight, + score, + topk, + renormalize=False, + use_int4_w4a16=weight_bits == 4, + use_int8_w8a16=weight_bits == 8, + global_num_experts=e, + expert_map=e_map, + w1_scale=w1_scales, + w2_scale=w2_scales, + w1_zp=w1_qzeros if has_zp else None, + w2_zp=w2_qzeros if has_zp else None, + block_shape=[0, group_size]) + torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 5a02270b3bf..11fb5000713 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm_shape, deep_gemm_moe_fp8) + _valid_deep_gemm_shape, deep_gemm_moe_fp8, modular_deep_gemm_fused_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) @@ -211,6 +211,9 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): # Set the context to avoid lots of warning spam. vllm_config = VllmConfig() + vllm_config.scheduler_config.max_num_seqs = 128 + vllm_config.scheduler_config.max_model_len = 8192 + with set_current_vllm_config(vllm_config): out = fused_moe( a, @@ -387,8 +390,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): if not _valid_deep_gemm_shape(M, N, K): pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") - vllm_config = VllmConfig() - torch.manual_seed(seed) fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_max, fp8_min = fp8_info.max, fp8_info.min @@ -425,7 +426,26 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) + if True: + dgm = modular_deep_gemm_fused_moe_fp8() + + def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids): + return dgm(a, + w1, + w2, + topk_weights, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s) + else: + deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 + # Set the context to avoid lots of warning spam. + vllm_config = VllmConfig() + vllm_config.scheduler_config.max_num_seqs = 128 + vllm_config.scheduler_config.max_model_len = 8192 + with set_current_vllm_config(vllm_config): if M >= 128: ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, @@ -437,7 +457,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): topk_weights, topk_ids, token_expert_indices = fused_topk( a, score.float(), topk, False) - out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/tests/kernels/quantization/test_block_int8.py b/tests/kernels/quantization/test_block_int8.py index 104f23fd7cd..a4e9f83f0ea 100644 --- a/tests/kernels/quantization/test_block_int8.py +++ b/tests/kernels/quantization/test_block_int8.py @@ -18,6 +18,10 @@ pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + # For test def native_per_token_group_quant_int8(x, @@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) # Set the context to avoid lots of warning spam. - vllm_config = VllmConfig() with set_current_vllm_config(vllm_config): out = fused_moe( a, diff --git a/tests/kernels/test_cutlass_moe.py b/tests/kernels/test_cutlass_moe.py deleted file mode 100644 index 3cfed6ae853..00000000000 --- a/tests/kernels/test_cutlass_moe.py +++ /dev/null @@ -1,244 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import pytest -import torch - -from vllm import _custom_ops as ops -from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 -from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, - fused_topk) -from vllm.platforms import current_platform - -NUM_EXPERTS = [40, 64] -TOP_KS = [6, 8] - - -def run(a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, - w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, c_strides2: torch.Tensor): - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - return cutlass_moe_fp8(a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale) - - -@pytest.mark.parametrize("m", [2, 64, 224]) -@pytest.mark.parametrize("n", [1024, 3072]) -@pytest.mark.parametrize("k", [1024, 1536]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("per_act_token", [True, False]) -@pytest.mark.parametrize("per_out_ch", [True, False]) -@pytest.mark.skipif( - (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") -def test_cutlass_moe_no_graph( - m: int, - n: int, - k: int, - e: int, - topk: int, - per_act_token: bool, - per_out_ch: bool, -): - current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - - dtype = torch.half - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - - # Get the right scale for tests. - _, a_scale1 = ops.scaled_fp8_quant( - a, use_per_token_if_dynamic=per_act_token) - a_q, _ = ops.scaled_fp8_quant(a, - a_scale1, - use_per_token_if_dynamic=per_act_token) - - a_d = a_q.float().mul(a_scale1).to(dtype) - - n_b_scales = 2 * n if per_out_ch else 1 - k_b_scales = k if per_out_ch else 1 - - w1_q = torch.empty((e, 2 * n, k), - device="cuda", - dtype=torch.float8_e4m3fn) - w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - for expert in range(e): - w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=per_out_ch) - w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=per_out_ch) - w1_q = w1_q.transpose(1, 2) - w2_q = w2_q.transpose(1, 2) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - w1_d = torch.empty_like(w1) - w2_d = torch.empty_like(w2) - for expert in range(e): - w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() - w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() - - score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - - triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) - - cutlass_output = cutlass_moe_fp8(a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale1) - - #print(triton_output) - #print(cutlass_output) - #print("*") - - torch.testing.assert_close(triton_output, - cutlass_output, - atol=5e-2, - rtol=1e-2) - - -@pytest.mark.parametrize("m", [2, 64, 224]) -@pytest.mark.parametrize("n", [1024, 3072]) -@pytest.mark.parametrize("k", [1024, 1536]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("per_act_token", [True, False]) -@pytest.mark.parametrize("per_out_ch", [True, False]) -@pytest.mark.skipif( - (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") -def test_cutlass_moe_cuda_graph( - m: int, - n: int, - k: int, - e: int, - topk: int, - per_act_token: bool, - per_out_ch: bool, -): - current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - - dtype = torch.half - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - - # Get the right scale for tests. - _, a_scale1 = ops.scaled_fp8_quant( - a, use_per_token_if_dynamic=per_act_token) - a_q, _ = ops.scaled_fp8_quant(a, - a_scale1, - use_per_token_if_dynamic=per_act_token) - - a_d = a_q.float().mul(a_scale1).to(dtype) - - n_b_scales = 2 * n if per_out_ch else 1 - k_b_scales = k if per_out_ch else 1 - - w1_q = torch.empty((e, 2 * n, k), - device="cuda", - dtype=torch.float8_e4m3fn) - w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - for expert in range(e): - w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=per_out_ch) - w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=per_out_ch) - w1_q = w1_q.transpose(1, 2) - w2_q = w2_q.transpose(1, 2) - - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - - w1_d = torch.empty_like(w1) - w2_d = torch.empty_like(w2) - for expert in range(e): - w1_d[expert] = (w1_q[expert].t().float() * w1_scale[expert]).half() - w2_d[expert] = (w2_q[expert].t().float() * w2_scale[expert]).half() - - score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) - - triton_output = fused_experts(a_d, w1_d, w2_d, topk_weights, topk_ids) - - stream = torch.cuda.Stream() - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, stream=stream): - cutlass_output = run(a, a_scale1, w1_q, w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, ab_strides1, - c_strides1, ab_strides2, c_strides2) - torch.cuda.synchronize() - graph.replay() - torch.cuda.synchronize() - - #print(triton_output) - #print(cutlass_output) - #print("*") - - torch.testing.assert_close(triton_output, - cutlass_output, - atol=9e-2, - rtol=1e-2) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 669505c656d..d718ac1f3f3 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -11,180 +11,6 @@ from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache -#TODO make the grouped gemm kernel consistent with scaled gemm kernel -def cutlass_moe_fp8( - a: torch.Tensor, - w1_q: torch.Tensor, - w2_q: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids_: torch.Tensor, - ab_strides1: torch.Tensor, - c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides2: torch.Tensor, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - out_dtype: torch.dtype = torch.half, - expert_map: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, -) -> torch.Tensor: - """ - This function computes a a8w8-quantized Mixture of Experts (MoE) layer - using two sets of quantized weights, w1_q and w2_q, and top-k gating - mechanism. The matrix multiplications are implemented with CUTLASS - grouped gemm. - - Parameters: - - a (torch.Tensor): The input tensor to the MoE layer. - Shape: [M, K] - - w1_q (torch.Tensor): The first set of fp8-quantized expert weights. - Shape: [num_experts, K, 2N] (the weights are passed transposed) - - w2_q (torch.Tensor): The second set of fp8-quantized expert weights. - Shape: [num_experts, N, K] (the weights are passed transposed) - - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. - Shape: [num_experts] or [num_experts, 2N] - - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. - Shape: [num_experts] or [num_experts, K] - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk_weights (torch.Tensor): The weights of each token->expert mapping. - - ab_strides1 (torch.Tensor): The input and weights strides of the first - grouped gemm. - - c_strides1 (torch.Tensor): The output strides of the first grouped gemm. - - ab_strides2 (torch.Tensor): The input and weights strides of the second - grouped gemm. - - c_strides2 (torch.Tensor): The output strides of the second grouped gemm. - - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. - Shape: scalar or [M] - - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to - quantize the intermediate result between the gemms. - Shape: scalar or [M] - - out_dtype (torch.dtype): The output tensor type. - - expert_map (Optional[torch.Tensor]): In the case of Expert parallel, - every Rank is responsible for a subset of experts. expert_map is a - mapping from global expert-id to local expert-id. When expert_map[i] - is -1, it means that this Rank is not responsible for global - expert-id i. - - apply_router_weight_on_input (bool): When true, the topk weights are - applied directly on the inputs. This is only applicable when topk is 1. - - Returns: - - torch.Tensor: The fp16 output tensor after applying the MoE layer. - """ - - assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch" - assert w1_q.dtype == torch.float8_e4m3fn - assert w2_q.dtype == torch.float8_e4m3fn - assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" - assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" - assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" - assert a1_scale is None or a1_scale.dim( - ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[ - 0], "Input scale shape mismatch" - assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ - 1] == w1_q.shape[2], "W1 scale shape mismatch" - assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ - 1] == w2_q.shape[2], "W2 scale shape mismatch" - assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch" - assert w1_q.shape[0] == w1_scale.shape[ - 0], "w1 scales expert number mismatch" - assert w1_q.shape[0] == w2_scale.shape[ - 0], "w2 scales expert number mismatch" - assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 - assert ab_strides1.shape[0] == w1_q.shape[ - 0], "AB Strides 1 expert number mismatch" - assert c_strides1.shape[0] == w1_q.shape[ - 0], "C Strides 1 expert number mismatch" - assert ab_strides2.shape[0] == w2_q.shape[ - 0], "AB Strides 2 expert number mismatch" - assert c_strides2.shape[0] == w2_q.shape[ - 0], "C Strides 2 expert number mismatch" - assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" - - num_experts = w1_q.size(0) - m = a.size(0) - k = w1_q.size(1) - n = w2_q.size(1) - - local_topk_ids = topk_ids_ - if expert_map is not None: - "Translate info from expert_map to topk_ids" - local_topk_ids = torch.where(expert_map[topk_ids_] != -1, - expert_map[topk_ids_], -1) - - topk = local_topk_ids.size(1) - - per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) - if apply_router_weight_on_input: - assert topk == 1, \ - "apply_router_weight_on_input is only implemented for topk=1" - # TODO: this only works for topK=1, will need to update for topK>1 - a = a * topk_weights.to(out_dtype) - - a_q, a1_scale = ops.scaled_fp8_quant( - a, a1_scale, use_per_token_if_dynamic=per_act_token) - device = a_q.device - - expert_offsets = torch.empty((num_experts + 1), - dtype=torch.int32, - device=device) - problem_sizes1 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) - problem_sizes2 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) - - a_map_initializer = torch.empty - c2_initializer = torch.empty - if expert_map is not None: - # With expert_map each Rank processes only a subset of experts. As - # a result not all of a_map and c2 tensors are filled. We fill it - # zeros for correctness. - a_map_initializer = torch.zeros - c2_initializer = torch.zeros - - a_map = a_map_initializer((local_topk_ids.numel()), - dtype=torch.int32, - device=device) - c_map = torch.empty((local_topk_ids.numel()), - dtype=torch.int32, - device=device) - - ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, a_map, c_map, num_experts, n, - k) - - rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype) - rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale - - c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) - c2 = c2_initializer((m * topk, k), device=device, dtype=out_dtype) - - ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale, - expert_offsets[:-1], problem_sizes1, ab_strides1, - ab_strides1, c_strides1) - - intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype) - torch.ops._C.silu_and_mul(intermediate, c1) - - intemediate_q, a2_scale = ops.scaled_fp8_quant( - intermediate, a2_scale, use_per_token_if_dynamic=per_act_token) - - ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale, - expert_offsets[:-1], problem_sizes2, ab_strides2, - ab_strides2, c_strides2) - - # Gather tokens - c2 = c2[c_map].view(m, topk, k) - if not apply_router_weight_on_input: - c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype) - return c2.sum(dim=1) - - class CutlassExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( @@ -202,9 +28,15 @@ def __init__( self.c_strides2 = c_strides2 self.out_dtype = out_dtype - def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, - topk: int, - num_experts: int) -> Tuple[int, int, torch.dtype]: + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> Tuple[int, int, torch.dtype]: # Note that K, N are transposed N, K = K, N workspace1 = M * topk * max(2 * N, K) @@ -213,7 +45,7 @@ def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, def apply( self, - a1q: torch.Tensor, + hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor, @@ -228,16 +60,56 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], ) -> torch.Tensor: + a1q = hidden_states + + assert w1.dtype == torch.float8_e4m3fn + assert w2.dtype == torch.float8_e4m3fn + assert a1q.shape[1] == w1.shape[1], "Hidden size mismatch w1" + assert w1.shape[2] == w2.shape[1] * 2, "Hidden size mismatch w2" + assert w1.shape[0] == w2.shape[0], "Expert number mismatch" + assert a1q_scale is None or a1q_scale.dim( + ) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[ + 0], "Input scale shape mismatch" + assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ + 1] == w1.shape[2], "W1 scale shape mismatch" + assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ + 1] == w2.shape[2], "W2 scale shape mismatch" + assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch" + assert w1.shape[0] == w1_scale.shape[ + 0], "w1 scales expert number mismatch" + assert w1.shape[0] == w2_scale.shape[ + 0], "w2 scales expert number mismatch" + assert a2_scale is None or a1q_scale is None or a2_scale.shape == a1q_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 + assert self.ab_strides1.shape[0] == w1.shape[ + 0], "AB Strides 1 expert number mismatch" + assert self.c_strides1.shape[0] == w1.shape[ + 0], "C Strides 1 expert number mismatch" + assert self.ab_strides2.shape[0] == w2.shape[ + 0], "AB Strides 2 expert number mismatch" + assert self.c_strides2.shape[0] == w2.shape[ + 0], "C Strides 2 expert number mismatch" + assert self.out_dtype in [torch.half, + torch.bfloat16], "Invalid output dtype" + M = a1q.shape[0] _, N, K = w2.shape # because w1 + w2 are transposed - topk = topk_ids.shape[1] device = a1q.device assert w1.shape[1] == K assert global_num_experts != -1 assert a1q_scale is not None + if expert_map is not None: + "Translate info from expert_map to topk_ids" + local_topk_ids = torch.where(expert_map[topk_ids] != -1, + expert_map[topk_ids], -1) + else: + local_topk_ids = topk_ids + + topk = local_topk_ids.shape[1] + per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) @@ -251,21 +123,29 @@ def apply( dtype=torch.int32, device=device) - a_map = torch.empty((topk_ids.numel()), - dtype=torch.int32, - device=device) - c_map = torch.empty((topk_ids.numel()), + # With expert_map each Rank processes only a subset of experts. As + # a result not all of a_map and c2 tensors are filled. We fill it + # zeros for correctness. + if expert_map is not None: + a_map = torch.zeros((local_topk_ids.numel()), + dtype=torch.int32, + device=device) + else: + a_map = torch.empty((local_topk_ids.numel()), + dtype=torch.int32, + device=device) + + c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device) - ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, a_map, c_map, - global_num_experts, N, K) + ops.get_cutlass_moe_mm_data(local_topk_ids, expert_offsets, + problem_sizes1, problem_sizes2, a_map, + c_map, global_num_experts, N, K) a1q = _fp8_perm(a1q, a_map) a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale - # fix names c1 = _resize_cache(workspace13, (M * topk, N * 2)) c2 = _resize_cache(workspace2, (M * topk, N)) c3 = _resize_cache(workspace13, (M * topk, K)) @@ -274,16 +154,14 @@ def apply( expert_offsets[:-1], problem_sizes1, self.ab_strides1, self.ab_strides1, self.c_strides1) - if activation == "silu": - torch.ops._C.silu_and_mul(c2, c1) - elif activation == "gelu": - torch.ops._C.gelu_and_mul(c2, c1) - else: - raise ValueError(f"Unsupported FusedMoe activation: {activation}") + self.activation(activation, c2, c1) a2q, a2q_scale = ops.scaled_fp8_quant( c2, a2_scale, use_per_token_if_dynamic=per_act_token) + if expert_map is not None: + c3.fill_(0) + ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets[:-1], problem_sizes2, self.ab_strides2, self.ab_strides2, self.c_strides2) @@ -294,6 +172,7 @@ def apply( def modular_cutlass_moe_fp8( + per_act_token: bool, ab_strides1: torch.Tensor, c_strides1: torch.Tensor, ab_strides2: torch.Tensor, @@ -301,7 +180,10 @@ def modular_cutlass_moe_fp8( out_dtype: torch.dtype = torch.half, ) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( - StandardDispatchCombine(quant_dtype=torch.float8_e4m3fn), + StandardDispatchCombine( + per_channel_quant=per_act_token, + quant_dtype=torch.float8_e4m3fn, + ), CutlassExperts( ab_strides1, c_strides1, @@ -328,6 +210,8 @@ def cutlass_moe_fp8( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.half, + expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, ) -> torch.Tensor: """ This function computes a a8w8-quantized Mixture of Experts (MoE) layer @@ -361,25 +245,39 @@ def cutlass_moe_fp8( quantize the intermediate result between the gemms. Shape: scalar or [M] - out_dtype (torch.dtype): The output tensor type. + - expert_map (Optional[torch.Tensor]): In the case of Expert parallel, + every Rank is responsible for a subset of experts. expert_map is a + mapping from global expert-id to local expert-id. When expert_map[i] + is -1, it means that this Rank is not responsible for global + expert-id i. + - apply_router_weight_on_input (bool): When true, the topk weights are + applied directly on the inputs. This is only applicable when topk is 1. Returns: - torch.Tensor: The fp16 output tensor after applying the MoE layer. """ + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + fn = modular_cutlass_moe_fp8( + per_act_token, ab_strides1, c_strides1, ab_strides2, c_strides2, out_dtype, ) + return fn( a, w1_q, w2_q, topk_weights, topk_ids, + expert_map=expert_map, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, + apply_router_weight_on_input=apply_router_weight_on_input, ) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index e43c984f7d5..266ba3bfa07 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -91,7 +91,7 @@ def workspace_shapes( def apply( self, - a1q: torch.Tensor, + hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor, @@ -110,6 +110,7 @@ def apply( ) -> torch.Tensor: import deep_gemm as dg + a1q = hidden_states _, N, K = w1.shape assert global_num_experts != -1 @@ -137,7 +138,8 @@ def apply( a2q_scale: Optional[torch.Tensor] = None - a2q, a2q_scale = _fp8_quantize(workspace2, a2_scale, self.block_shape) + a2q, a2q_scale = _fp8_quantize(workspace2, a2_scale, False, + self.block_shape) dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids) @@ -169,6 +171,7 @@ def deep_gemm_moe_fp8( expert_map: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + apply_router_weight_on_input=False, ) -> torch.Tensor: """ This function computes a a8w8-quantized Mixture of Experts (MoE) layer @@ -222,4 +225,5 @@ def deep_gemm_moe_fp8( w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, + apply_router_weight_on_input=apply_router_weight_on_input, ) diff --git a/vllm/model_executor/layers/fused_moe/dispatch_combine.py b/vllm/model_executor/layers/fused_moe/dispatch_combine.py index 398aab60c66..9b647a70d5e 100644 --- a/vllm/model_executor/layers/fused_moe/dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/dispatch_combine.py @@ -6,15 +6,20 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( _moe_unpermute_and_reduce) -from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) class StandardDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - def __init__(self, - quant_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None): + def __init__( + self, + quant_dtype: Optional[torch.dtype] = None, + per_channel_quant: bool = False, + block_shape: Optional[list[int]] = None, + ): super().__init__() + self.per_channel_quant = per_channel_quant self.block_shape = block_shape self.quant_dtype = quant_dtype @@ -23,24 +28,23 @@ def dispatch( a1: torch.Tensor, a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - if self.quant_dtype == torch.float8_e4m3fn: - per_act_token = a1_scale.numel( - ) != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) - - a1q, a1q_scale = _fp8_quantize( - a1, - a1_scale, - self.block_shape, - per_act_token, - ) - else: - a1q = a1 - a1q_scale = a1_scale + if apply_router_weight_on_input: + topk = topk_ids.shape[1] + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + a1.mul_(topk_weights.to(a1.dtype)) + + a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale, + self.quant_dtype, + self.per_channel_quant, + self.block_shape) return a1q, a1q_scale, None @@ -50,6 +54,7 @@ def combine( fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, ) -> None: _moe_unpermute_and_reduce(output, fused_expert_output, None, - topk_weights) + topk_weights, apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 92d6d24799b..59a765f900c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -17,12 +17,8 @@ StandardDispatchCombine) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) -from vllm.model_executor.layers.quantization.utils.int8_utils import ( - per_token_group_quant_int8, per_token_quant_int8) -from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, - _resize_cache) +from vllm.model_executor.layers.fused_moe.utils import ( + _resize_cache, moe_kernel_quantize_input) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op @@ -967,6 +963,20 @@ def get_config_dtype_str( return None +# TODO: use scalar_type? +def get_config_qtype( + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, +) -> Optional[torch.dtype]: + if use_fp8_w8a8: + return torch.float8_e4m3fn + elif use_int8_w8a8: + return torch.int8 + return None + + def inplace_fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -1156,6 +1166,7 @@ def fused_experts(hidden_states: torch.Tensor, w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, + apply_router_weight_on_input=apply_router_weight_on_input, ) else: return dispatch_fused_experts_func(inplace)( @@ -1182,59 +1193,6 @@ def fused_experts(hidden_states: torch.Tensor, block_shape=block_shape) -def moe_kernel_prepare_input( - A: torch.Tensor, - B: torch.Tensor, - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - per_channel_quant: bool, - block_shape: Optional[List[int]] = None, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if use_fp8_w8a8: - assert B_scale is not None - if block_shape is None: - # If weights are per-channel (per_channel_quant=True), then - # activations apply per-token quantization. Otherwise, assume - # activation tensor-wise fp8 quantization, dynamic or static - A, A_scale = ops.scaled_fp8_quant( - A, A_scale, use_per_token_if_dynamic=per_channel_quant) - else: - # activation block-wise fp8 quantization - assert len(block_shape) == 2 - _, block_k = block_shape[0], block_shape[1] - A, A_scale = per_token_group_quant_fp8(A, block_k) - assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] - # assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] - # assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] - elif use_int8_w8a8: - assert B_scale is not None - if block_shape is None: - # activation channel-wise int8 quantization - assert (per_channel_quant - ), "int8 quantization only supports block or channel-wise" - A, A_scale = per_token_quant_int8(A) - else: - # activation block-wise int8 quantization - assert len(block_shape) == 2 - _, block_k = block_shape[0], block_shape[1] - A, A_scale = per_token_group_quant_int8(A, block_k) - assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] - # assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] - # assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] - elif use_int8_w8a16 or use_int4_w4a16: - assert B_scale is not None - assert block_shape is None or block_shape[0] == 0 - else: - assert A_scale is None - assert B_scale is None - - return A, A_scale - - def fused_experts_impl(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -1288,6 +1246,11 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype) + qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16) + get_config_func = functools.partial( try_get_optimal_moe_config, w1.shape, @@ -1350,15 +1313,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] - qcurr_hidden_states, qa1_scale = moe_kernel_prepare_input( + qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input( A=curr_hidden_states, - B=w1, A_scale=a1_scale, - B_scale=w1_scale, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, + qtype=qtype, per_channel_quant=per_channel_quant, block_shape=block_shape) @@ -1369,7 +1327,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, invoke_fused_moe_kernel(qcurr_hidden_states, w1, intermediate_cache1, - qa1_scale, + a1q_scale, w1_scale, w1_zp, curr_topk_weights, @@ -1396,22 +1354,17 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: raise ValueError(f"Unsupported FusedMoe activation: {activation}") - qintermediate_cache2, qa2_scale = moe_kernel_prepare_input( + qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( A=intermediate_cache2, - B=w2, A_scale=a2_scale, - B_scale=w2_scale, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, + qtype=qtype, per_channel_quant=per_channel_quant, block_shape=block_shape) invoke_fused_moe_kernel(qintermediate_cache2, w2, intermediate_cache3, - qa2_scale, + a2q_scale, w2_scale, w2_zp, curr_topk_weights, @@ -1553,17 +1506,25 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, use_fp8_w8a8: bool, + use_int8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, + per_channel_quant: bool, block_shape: Optional[List[int]] = None, block_m: Optional[int] = None, ): super().__init__() self.use_fp8_w8a8 = use_fp8_w8a8 self.use_int4_w4a16 = use_int4_w4a16 + self.use_int8_w8a8 = use_int8_w8a8 self.use_int8_w8a16 = use_int8_w8a16 self.block_shape = block_shape self.block_m = block_m + self.qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16) + self.per_channel_quant = per_channel_quant def workspace_shapes( self, @@ -1674,8 +1635,10 @@ def apply( config, compute_type=compute_type, use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a8=self.use_int8_w8a8, use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, + per_channel_quant=self.per_channel_quant, block_shape=self.block_shape) self.activation(activation, intermediate_cache2, @@ -1683,12 +1646,9 @@ def apply( a2q_scale: Optional[torch.Tensor] = None - if self.use_fp8_w8a8: - qintermediate_cache2, a2q_scale = _fp8_quantize( - intermediate_cache2, a2_scale, self.block_shape) - else: - qintermediate_cache2 = intermediate_cache2 - a2q_scale = a2_scale + qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( + intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant, + self.block_shape) invoke_fused_moe_kernel(qintermediate_cache2, w2, @@ -1705,8 +1665,10 @@ def apply( config, compute_type=compute_type, use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a8=self.use_int8_w8a8, use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, + per_channel_quant=self.per_channel_quant, block_shape=self.block_shape) return intermediate_cache3 @@ -1714,18 +1676,30 @@ def apply( def modular_triton_fused_moe( use_fp8_w8a8: bool, + use_int8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, + per_channel_quant: bool, block_shape: Optional[List[int]] = None, ) -> mk.FusedMoEModularKernel: + qtype = get_config_qtype( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + ) return mk.FusedMoEModularKernel( StandardDispatchCombine( - quant_dtype=torch.float8_e4m3fn if use_fp8_w8a8 else None, - block_shape=block_shape), + quant_dtype=qtype, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + ), TritonExperts( - use_fp8_w8a8, - use_int8_w8a16, - use_int4_w4a16, - block_shape, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape, ), ) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index b517f6ee13c..aab7658ae64 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -84,9 +84,11 @@ def dispatch( a1: torch.Tensor, a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Perform any quantization (and/or) dispatching needed @@ -95,7 +97,8 @@ def dispatch( - a1_scale: Optional scales for a1 - a2_scale: Optional scales for the second MoE gemm. Required to make sure the quantization is consistent for both gemms. - - topk_ids: The topk_ids. + - topk_ids: The topk ids. + - topk_weights: The topk weights. - num_experts: The total number of experts in the global expert space. - expert_map: A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard. @@ -113,6 +116,7 @@ def combine( fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, ) -> None: """ Perform any combine plus apply weights and perform a reduction on the @@ -169,7 +173,7 @@ def activation(self, activation: str, output: torch.Tensor, @abstractmethod def apply( self, - a1q: torch.Tensor, + hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor, @@ -191,7 +195,8 @@ def apply( (MoE) layer using two sets of weights, w1 and w2. Parameters: - - a1q: (torch.Tensor): The (quantized) input tensor to the MoE layer. + - hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE + layer. - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - topk_ids (torch.Tensor): A map of row to expert id. @@ -263,6 +268,7 @@ def forward( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets @@ -292,6 +298,9 @@ def forward( w2. - a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1. - a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2. + - apply_router_weight_on_input (bool): When true, the topk weights are + applied directly on the inputs. This is only applicable when topk is + 1. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -318,34 +327,29 @@ def forward( dtype=workspace_dtype) a1q, a1q_scale, expert_num_tokens = self.dispatch_combine.dispatch( - a1, - a1_scale, - a2_scale, - topk_ids, - global_num_experts, - expert_map, - ) + a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts, + expert_map, apply_router_weight_on_input) fused_out = self.fused_experts.apply( a1q, w1, w2, topk_ids, - activation, - global_num_experts, - expert_map, - w1_scale, - w2_scale, - w1_zp, - w2_zp, - a1q_scale, - a2_scale, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=a1q_scale, + a2_scale=a2_scale, workspace13=workspace13, workspace2=workspace2, expert_num_tokens=expert_num_tokens, ) self.dispatch_combine.combine(output, fused_out, topk_weights, - topk_ids) + topk_ids, apply_router_weight_on_input) return output diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index ad9e149f5d7..cba1e0ef506 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -55,6 +55,7 @@ def _moe_unpermute_and_reduce( curr_hidden: torch.Tensor, inv_perm: Optional[torch.Tensor], topk_weight: torch.Tensor, + apply_router_weight_on_input: bool, ) -> None: """ Unpermute the final result and apply topk_weights, then perform the final @@ -65,7 +66,8 @@ def _moe_unpermute_and_reduce( if inv_perm is not None: curr_hidden = curr_hidden[inv_perm, ...] curr_hidden = curr_hidden.view(-1, topk, K) - curr_hidden.mul_(topk_weight.view(M, -1, 1)) + if not apply_router_weight_on_input: + curr_hidden.mul_(topk_weight.view(M, -1, 1)) ops.moe_sum(curr_hidden, out) diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index d35cfaccd39..90a4833948f 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -6,7 +6,8 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) # Note use: layer.get_all_to_all() to get an AllToAll instance @@ -34,27 +35,33 @@ def dispatch( a1: torch.Tensor, a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], + rank_topk_weights: torch.Tensor, rank_topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Is this always going to be a1.device? device = a1.device - if self.quant_dtype == torch.float8_e4m3fn: - per_act_token = a1_scale.numel( - ) != 1 if a1_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) + assert expert_map is None, "NYI" - a1q, a1q_scale = _fp8_quantize( - a1, - a1_scale, - self.block_shape, - per_act_token, - ) - else: - a1q = a1 - a1q_scale = a1_scale + # TBD + assert not apply_router_weight_on_input + if apply_router_weight_on_input: + topk = rank_topk_ids.shape[1] + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + a1 = a1 * rank_topk_weights.to(a1.dtype) + + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + + a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale, + self.quant_dtype, + per_act_token, + self.block_shape) expert_num_tokens = torch.empty( num_experts, @@ -103,6 +110,7 @@ def combine( fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, ) -> None: # This argument is optional bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens @@ -110,6 +118,11 @@ def combine( assert output.shape[0] == self.max_num_tokens assert output.shape[1] == fused_expert_output.shape[-1] + # Set weights to 1? + assert not apply_router_weight_on_input + if apply_router_weight_on_input: + topk_weights = torch.ones_like(topk_weights) + self.a2a.combine(out_tokens=output, indices=topk_ids, weights=topk_weights, diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index eff39c0f792..d53da1d7926 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -7,6 +7,8 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) +from vllm.model_executor.layers.quantization.utils.int8_utils import ( + per_token_group_quant_int8, per_token_quant_int8) from vllm.utils import cdiv @@ -22,8 +24,8 @@ def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor: def _fp8_quantize( A: torch.Tensor, A_scale: Optional[torch.Tensor], + per_act_token: bool, block_shape: Optional[List[int]] = None, - per_act_token: bool = False, # make sure this is the same default as op ) -> Tuple[torch.Tensor, torch.Tensor]: """ Perform fp8 quantization on the inputs. If a block_shape @@ -37,9 +39,53 @@ def _fp8_quantize( _, block_k = block_shape[0], block_shape[1] A, A_scale = per_token_group_quant_fp8(A, block_k) assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + + return A, A_scale + + +def _int8_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + per_act_token: bool, + block_shape: Optional[List[int]] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform int8 quantization on the inputs. If a block_shape + is provided, the output will be blocked. + """ + + # If weights are per-channel (per_channel_quant=True), then + # activations apply per-token quantization. Otherwise, assume + # activation tensor-wise fp8/int8 quantization, dynamic or static + if block_shape is None: + assert per_act_token, \ + "int8 quantization only supports block or channel-wise" + A, A_scale = per_token_quant_int8(A) + else: + assert len(block_shape) == 2 + _, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_int8(A, block_k) + assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + return A, A_scale +def moe_kernel_quantize_input( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + qtype: Optional[torch.dtype], + per_channel_quant: bool, + block_shape: Optional[List[int]] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if qtype == torch.float8_e4m3fn: + return _fp8_quantize(A, A_scale, per_channel_quant, block_shape) + elif qtype == torch.int8: + return _int8_quantize(A, A_scale, per_channel_quant, block_shape) + else: + assert A_scale is None + return A, A_scale + + def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: """ A permutation routine that works on fp8 types. From bdcaae24e75fc9313fb123260e6461f933d28854 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 02:28:00 +0000 Subject: [PATCH 027/190] fix lint Signed-off-by: Bill Nell --- .../layers/fused_moe/cutlass_moe.py | 2 + .../layers/fused_moe/pplx_dispatch_combine.py | 43 ++++++++++++------- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index d718ac1f3f3..e52751eddf2 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -64,6 +64,8 @@ def apply( ) -> torch.Tensor: a1q = hidden_states + assert w1_scale is not None + assert w2_scale is not None assert w1.dtype == torch.float8_e4m3fn assert w2.dtype == torch.float8_e4m3fn assert a1q.shape[1] == w1.shape[1], "Hidden size mismatch w1" diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 90a4833948f..658705515b4 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -5,15 +5,13 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) # Note use: layer.get_all_to_all() to get an AllToAll instance # The max_num_tokens, world_size and dp_size must be the same -# as the ones used to create the AllToAll. Unfortunately, there's -# no way(?) to extract this info from AllToAll +# as the ones used to create the AllToAll. class PplxDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): def __init__(self, @@ -21,13 +19,16 @@ def __init__(self, max_num_tokens: int, world_size: int, dp_size: int, + rank: int, quant_dtype: Optional[torch.dtype] = None, block_shape: Optional[List[int]] = None): super().__init__() self.a2a = a2a self.block_shape = block_shape self.max_num_tokens = max_num_tokens - self.dp_num_tokens = max_num_tokens * (world_size // dp_size) + self.world_size = world_size + self.dp_size = dp_size + self.rank = rank self.quant_dtype = quant_dtype def dispatch( @@ -39,8 +40,8 @@ def dispatch( rank_topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + apply_router_weight_on_input: bool, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: # Is this always going to be a1.device? device = a1.device @@ -63,14 +64,19 @@ def dispatch( per_act_token, self.block_shape) + rem_experts = num_experts % self.world_size + num_local_experts = ((num_experts // self.world_size) + + (1 if self.rank < rem_experts else 0)) + expert_num_tokens = torch.empty( - num_experts, + num_local_experts, dtype=torch.int32, device=device, ) + num_dp = self.world_size // self.dp_size expert_x = torch.empty( - (num_experts, self.dp_num_tokens, a1q.shape[-1]), + (num_local_experts, self.max_num_tokens * num_dp, a1q.shape[-1]), dtype=a1q.dtype, device=device, ) @@ -90,8 +96,14 @@ def dispatch( device=device, ) - # This argument is optional - bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens + # This argument is optional, defaults to indices.shape[0] + # This causes a deadlock???? + #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens + #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) + bound_m = None + + # TODO: optimize this? + indices = rank_topk_ids.to(dtype=torch.uint32) self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, @@ -99,10 +111,10 @@ def dispatch( out_expert_x_scale=expert_x_scale, dp_x=a1q, dp_x_scale=a1q_scale, - indices=rank_topk_ids, + indices=indices, bound_m=bound_m, ) - return expert_x, expert_x_scale + return expert_x, expert_x_scale, expert_num_tokens def combine( self, @@ -113,9 +125,10 @@ def combine( apply_router_weight_on_input: bool, ) -> None: # This argument is optional - bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens + #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens + bound_m = None - assert output.shape[0] == self.max_num_tokens + assert output.shape[0] <= self.max_num_tokens assert output.shape[1] == fused_expert_output.shape[-1] # Set weights to 1? @@ -124,7 +137,7 @@ def combine( topk_weights = torch.ones_like(topk_weights) self.a2a.combine(out_tokens=output, - indices=topk_ids, + indices=topk_ids.to(torch.uint32), weights=topk_weights, expert_y=fused_expert_output, bound_m=bound_m) From e3ba64fd9fad5c945a1eb4dadb76e5cee6870f7e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 03:21:26 +0000 Subject: [PATCH 028/190] add pplx tests Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 598 ++++++++++++++++++ .../layers/fused_moe/fused_batched_moe.py | 175 +++++ .../layers/fused_moe/fused_moe.py | 14 + 3 files changed, 787 insertions(+) create mode 100644 tests/kernels/moe/test_pplx_moe.py create mode 100644 vllm/model_executor/layers/fused_moe/fused_batched_moe.py diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py new file mode 100644 index 00000000000..cab9990b16b --- /dev/null +++ b/tests/kernels/moe/test_pplx_moe.py @@ -0,0 +1,598 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the MOE layers. + +Run `pytest tests/kernels/test_pplx_moe.py`. +""" +import dataclasses +import os +import pytest +import torch +import traceback + +from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] +from typing import Callable, Concatenate, Optional, ParamSpec, Tuple + +from pplx_kernels import AllToAll +from pplx_kernels.nvshmem import ( + nvshmem_alloc_empty_unique_id, + nvshmem_finalize, + nvshmem_get_unique_id, + nvshmem_init, +) + +import vllm.model_executor.layers.fused_moe # noqa +from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, + torch_moe, torch_moe_single) +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size) +from vllm.platforms import current_platform + +from vllm.model_executor.layers.activation import SiluAndMul + +from vllm.model_executor.layers.fused_moe.fused_moe import ( + TritonExperts, fused_experts) +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedDispatchCombine, BatchedExperts) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel +from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import PplxDispatchCombine + +NUM_EXPERTS = [8, 64] +EP_SIZE = [1, 4] +TOP_KS = [2, 6] + +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + +P = ParamSpec("P") + +require_multi_node = pytest.mark.skipif( + "MASTER_ADDR" not in os.environ, + reason="Requires multi-node environment", +) + + +@dataclasses.dataclass +class ProcessGroupInfo: + world_size: int + world_local_size: int + rank: int + node_rank: int + local_rank: int + device: torch.device + + +def _worker_parallel_launch( + local_rank: int, + world_size: int, + world_local_size: int, + node_rank: int, + init_method: str, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + rank = node_rank * world_local_size + local_rank + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl", + init_method=init_method, + rank=rank, + world_size=world_size, + device_id=device, + ) + barrier = torch.tensor([rank], device=device) + torch.distributed.all_reduce(barrier) + + try: + worker( + ProcessGroupInfo( + world_size=world_size, + world_local_size=world_local_size, + rank=rank, + node_rank=node_rank, + local_rank=local_rank, + device=device, + ), + *args, + **kwargs, + ) + except Exception as ex: + print(ex) + traceback.print_exception(ex) + raise + finally: + torch.distributed.destroy_process_group() + + +def parallel_launch( + world_size: int, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + assert not kwargs + spawn( + _worker_parallel_launch, + args=( + world_size, + world_size, + 0, + "tcp://localhost:29500", + worker, + ) + + args, + nprocs=world_size, + join=True, + ) + + +def parallel_launch_from_env( + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + """ + Launches a worker function in parallel across all processes in the current + environment. The environment must have the following variables set: + - WORLD_SIZE: The total number of processes. + - WORLD_LOCAL_SIZE: The number of processes on the current node. + - NODE_RANK: The rank of the current + - MASTER_ADDR: The address of the master process. + - MASTER_PORT: The port of the master process. + """ + assert not kwargs + world_size = int(os.environ["WORLD_SIZE"]) + world_local_size = int(os.environ["WORLD_LOCAL_SIZE"]) + node_rank = int(os.environ["NODE_RANK"]) + assert "MASTER_ADDR" in os.environ + assert "MASTER_PORT" in os.environ + spawn( + _worker_parallel_launch, + args=( + world_size, + world_local_size, + node_rank, + "env://", + worker, + ) + + args, + nprocs=world_local_size, + join=True, + ) + + +def torch_dispatch( + a: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + max_num_tokens: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert topk_ids.dim() == 2 + assert topk_ids.shape[0] == a.shape[0] + + num_tokens = a.shape[0] + topk = topk_ids.shape[1] + + tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + if max_num_tokens is None: + max_num_tokens = tokens_per_expert.max() + + b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), + dtype=a.dtype, device=a.device) + + #print(f"b_a shape {b_a.shape}") + + token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) + + for token in range(num_tokens): + for j in range(topk): + expert_id = topk_ids[token, j] + idx = token_counts[expert_id] + b_a[expert_id, idx:idx+1, :] = a[token, :] + token_counts[expert_id] = token_counts[expert_id] + 1 + + return b_a, tokens_per_expert + + +def torch_combine(b_out, topk_weight, topk_ids): + num_tokens, topk = topk_ids.shape + num_experts = b_out.shape[0] + K = b_out.shape[-1] + out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) + expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) + for token in range(num_tokens): + expert_ids = topk_ids[token] + for i in range(expert_ids.numel()): + expert_id = expert_ids[i] + idx = expert_counts[expert_id] + out[token, :] = out[token, :] + b_out[expert_id, idx:idx+1, :] * topk_weight[token, i] + expert_counts[expert_id] = expert_counts[expert_id] + 1 + + return out + + +def torch_batched_moe(a, w1, w2, topk_weight, topk_ids): + num_experts = w1.shape[0] + b_a, tokens_per_expert = torch_dispatch(a, topk_ids, num_experts) + assert b_a.dim() == 3 + num_tokens, topk = topk_ids.shape + _, max_num_tokens, K = b_a.shape + assert num_experts == b_a.shape[0] and K == w2.shape[1] + out = torch.zeros((num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device) + tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device) + for expert in range(num_experts): + num = tokens_per_expert[expert] + if num > 0: + torch.ops._C.silu_and_mul(tmp[:num], b_a[expert,:num,:] @ w1[expert].transpose(0, 1)) + out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1) + + return torch_combine(out, topk_weight, topk_ids) + + +# TODO: same as torch_moe but with fused_topk factored out. +def torch_moe2(a, w1, w2, topk_weight, topk_ids): + M, K = a.shape + topk = topk_ids.shape[1] + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) + num_experts = w1.shape[0] + for i in range(num_experts): + mask = (topk_ids == i).view(-1) + if mask.sum(): + out[mask] = SiluAndMul()( + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + + return (out.view(M, -1, w2.shape[1]) * + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + + +@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_moe_batched_experts( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids = fused_topk(a, score, topk, False) + + torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + + if True: + triton_output = torch_batched_moe(a, + w1, + w2, + topk_weight, + topk_ids) + else: + b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) + triton_output = fused_batched_experts( + b_a, + w1, + w2, + topk_weight, + topk_ids, + global_num_experts=e + ) + + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + + +def rank_chunk(num, r, w): + rem = num % w + return (num // w) + (1 if r < rem else 0) + + +def chunk_by_rank(t, r, w): + chunk = rank_chunk(t.shape[0], r, w) + #print(f"chunk {t.shape}, {w}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}") + return t[(r * chunk):(r + 1)*chunk] + + +def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): + assert torch.cuda.current_device() == pgi.local_rank + + num_tokens, hidden_dim = a.shape + num_experts = w1.shape[0] + block_size = 128 + device = pgi.device + rank = pgi.rank + world_size = pgi.world_size + rank_num_tokens = rank_chunk(num_tokens, rank, world_size) + max_num_tokens = num_tokens + #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") + + ata = AllToAll( + max_num_tokens=max_num_tokens, + num_experts=num_experts, + experts_per_token=topk, + rank=rank, + world_size=pgi.world_size, + dp_size=dp_size, + hidden_dim=hidden_dim, + hidden_dim_bytes=hidden_dim * a.dtype.itemsize, + hidden_dim_scale_bytes=( + 0 + if a.dtype.itemsize != 1 + else ( + (hidden_dim + block_size - 1) + // block_size + * torch.float32.itemsize + ) + ), + ) + + dispatch_combine = PplxDispatchCombine( + ata, + max_num_tokens, + pgi.world_size, + dp_size, + rank, + a.dtype, + ) + + a_chunk = chunk_by_rank(a, rank, world_size).to(device) + score_chunk = chunk_by_rank(scores, rank, world_size).to(device) + chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) + + b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( + a_chunk, + None, + None, + chunk_topk_weight, + chunk_topk_ids, + num_experts, # store at PplxDispatchCombine creation? + None, + False, + ) + + naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, num_experts) + + torch.distributed.all_reduce(tokens_per_expert) + tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, world_size).to(dtype=torch.int32) + + torch.testing.assert_close(tokens_per_expert, expert_num_tokens, atol=0, rtol=0) + + b_a = b_a * 1.5 + + out = torch.full( + (rank_num_tokens * world_size, hidden_dim), + torch.nan, + dtype=a.dtype, + device=device, + ) + + dispatch_combine.combine( + out, + b_a, + chunk_topk_weight, + chunk_topk_ids, + False, + ) + torch.cuda.synchronize() + + ata.destroy() + + return out[:rank_num_tokens] + + +def _pplx_dispatch_combine( + pgi: ProcessGroupInfo, + dp_size: int, + m, n, k, e, + topk: int, + dtype: torch.dtype, +): + uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + device = pgi.device + + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10 + score = torch.randn((m, e), device=device, dtype=dtype) + + topk_weight, topk_ids = fused_topk(a, score, topk, False) + + a_rep = torch.repeat_interleave(a, topk, dim=0) + + torch_output = (a_rep.view(-1, topk, k) * 1.5 * topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype) + + pplx_output = torch_pplx_dispatch_combine(pgi, + dp_size, + a, + w1, + w2, + score, + topk) + + torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) + + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) + + nvshmem_finalize() + + +@pytest.mark.parametrize("m", [4, 32, 64, 222]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 512, 1024]) # restrictions? % 128? +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [[4, 2]]) +def test_pplx_dispatch_combine( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, + world_dp_size: Tuple[int, int], +): + current_platform.seed_everything(7) + world_size, dp_size = world_dp_size + + parallel_launch( + world_size, _pplx_dispatch_combine, dp_size, m, n, k, e, topk, dtype + ) + + +def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): + assert torch.cuda.current_device() == pgi.local_rank + + num_tokens, hidden_dim = a.shape + num_experts = w1.shape[0] + block_size = 128 + device = pgi.device + rank = pgi.rank + world_size = pgi.world_size + rank_num_tokens = rank_chunk(num_tokens, rank, world_size) + max_num_tokens = num_tokens + + ata = AllToAll( + max_num_tokens=max_num_tokens, + num_experts=num_experts, + experts_per_token=topk, + rank=rank, + world_size=pgi.world_size, + dp_size=dp_size, + hidden_dim=hidden_dim, + hidden_dim_bytes=hidden_dim * a.dtype.itemsize, + hidden_dim_scale_bytes=( + 0 + if a.dtype.itemsize != 1 + else ( + (hidden_dim + block_size - 1) + // block_size + * torch.float32.itemsize + ) + ), + ) + + w1 = w1.to(device) + w2 = w2.to(device) + + dispatch_combine = PplxDispatchCombine( + ata, + max_num_tokens, + pgi.world_size, + dp_size, + rank, + a.dtype, + ) + + experts = BatchedExperts(rank, pgi.world_size, max_num_tokens) + + fused_experts = FusedMoEModularKernel( + dispatch_combine, + experts, + ) + + a_chunk = chunk_by_rank(a, rank, world_size).to(device) + score_chunk = chunk_by_rank(scores, rank, world_size).to(device) + chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) + + out = fused_experts( + a_chunk, + # Chunking weights like this only works for batched format + chunk_by_rank(w1, rank, world_size), + chunk_by_rank(w2, rank, world_size), + #w1, + #w2, + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts #? num_local_experts? + ) + + torch.cuda.synchronize() + + ata.destroy() + + return out[:rank_num_tokens] + + +def _pplx_moe( + pgi: ProcessGroupInfo, + dp_size: int, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + dtype: torch.dtype, +): + uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + + m, k = a.shape + e, _, n = w2.shape + + torch.set_printoptions(profile="full") + + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids = fused_topk(a, score, topk, False) + + torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + + pplx_output = torch_pplx_moe(pgi, + dp_size, + a, + w1, + w2, + score, + topk) + + torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) + + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) + + nvshmem_finalize() + + +# TODO: M == 1 doesn't work +@pytest.mark.parametrize("m", [2, 3, 32, 45, 64, 222]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 512, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) +def test_pplx_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, + world_dp_size: Tuple[int, int], +): + current_platform.seed_everything(7) + world_size, dp_size = world_dp_size + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + parallel_launch( + world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, dtype + ) + diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py new file mode 100644 index 00000000000..a39d08b8376 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Fused batched MoE kernel.""" +from typing import List, Optional, Tuple + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.utils import _resize_cache + + +class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): + def __init__(self, + world_size: int, + rank: int): + super().__init__() + self.world_size = world_size + self.rank = rank + + def dispatch( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + assert topk_ids.dim() == 2 + assert topk_ids.shape[0] == a1.shape[0] + + if apply_router_weight_on_input: + topk = topk_ids.shape[1] + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + a1.mul_(topk_weights.to(a1.dtype)) + + num_tokens = a1.shape[0] + topk = topk_ids.shape[1] + + tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + max_num_tokens = tokens_per_expert.max() + expert_counts = torch.zeros(num_experts, dtype=torch.int, device=a1.device) + + b_a1 = torch.zeros((num_experts, max_num_tokens, a1.shape[1]), + dtype=a1.dtype, device=a1.device) + + for token in range(num_tokens): + for j in range(topk): + expert_id = topk_ids[token, j] + idx = expert_counts[expert_id] + b_a1[expert_id, idx:idx+1, :] = a1[token, :] + expert_counts[expert_id] = expert_counts[expert_id] + 1 + + return b_a1, a1_scale, tokens_per_expert + + def combine( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> None: + num_tokens = topk_ids.shape[0] + num_experts = fused_expert_output.shape[0] + expert_counts = torch.zeros(num_experts, dtype=torch.int, device=fused_expert_output.device) + for token in range(num_tokens): + expert_ids = topk_ids[token] + for i in range(topk_ids.shape[1]): + expert_id = expert_ids[i] + if expert_id < num_experts: + idx = expert_counts[expert_id] + if apply_router_weight_on_input: + output[token, :] = output[token, :] + fused_expert_output[expert_id, idx:idx+1, :] + else: + output[token, :] = output[token, :] + fused_expert_output[expert_id, idx:idx+1, :] * topk_weights[token, i] + expert_counts[expert_id] = expert_counts[expert_id] + 1 + + +def rank_chunk(num, r, w): + rem = num % w + return (num // w) + (1 if r < rem else 0) + + +class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + rank: int = 0, + world_size: int = 1, + max_num_tokens: Optional[int] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + block_shape: Optional[List[int]] = None, + block_m: Optional[int] = None, + ): + super().__init__() + assert not use_fp8_w8a8 + assert not use_int4_w4a16 + assert not use_int8_w8a16 + assert block_shape is None + assert block_m is None + self.max_num_tokens = max_num_tokens + self.rank = rank + self.world_size = world_size + assert not use_fp8_w8a8, "NYI" + assert not use_int8_w8a8, "NYI" + assert not use_int8_w8a16, "NYI" + assert not use_int4_w4a16, "NYI" + + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> Tuple[int, int, torch.dtype]: + max_num_tokens = a.shape[1] if self.max_num_tokens is None else self.max_num_tokens + workspace13 = num_experts * max_num_tokens * K * topk * 2 # TODO: *2 is a hack + workspace2 = max_num_tokens * N + return (workspace13, workspace2, a.dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + assert hidden_states.dim() == 3 + assert expert_num_tokens is not None + num_tokens = topk_ids.shape[0] + _, tmp_max_num_tokens, K = hidden_states.shape + max_num_tokens = tmp_max_num_tokens if self.max_num_tokens is None else self.max_num_tokens + num_experts = global_num_experts + out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) + num_local_experts = expert_num_tokens.numel() + + # TODO: don't need world_size or rank if expert_base always == 0 + #assert w1.shape[0] == num_experts, f"{w1.shape} == {num_experts}" + #expert_base = rank_chunk(w1.shape[0], self.rank, self.world_size) * self.rank + expert_base = 0 + + for expert in range(num_local_experts): + num = expert_num_tokens[expert] + assert num <= max_num_tokens, f"{num}, {max_num_tokens}" + if num > 0: + tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) + self.activation( + activation, + tmp, + hidden_states[expert,:num,:] @ w1[expert_base + expert].transpose(0, 1) + ) + out[expert, :num, :] = tmp @ w2[expert_base + expert].transpose(0, 1) + + return out diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 59a765f900c..2650f5a6064 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -485,6 +485,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor, assert topk_weights is None or topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 + if use_fp8_w8a8: + assert B_scale is not None + assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0]) + == B_scale.shape[-2]) + assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1]) + == B_scale.shape[-1]) + + elif use_int8_w8a16 or use_int4_w4a16: + assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 + else: + assert A_scale is None + assert B_scale is None + M = A.shape[0] num_tokens = M * top_k From 897f3a1d32739985626e5268bf4198763410daf2 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 03:26:54 +0000 Subject: [PATCH 029/190] lint Signed-off-by: Bill Nell --- .../cutlass_benchmarks/w8a8_benchmarks.py | 2 +- tests/kernels/moe/test_pplx_moe.py | 197 ++++++++---------- .../layers/fused_moe/fused_batched_moe.py | 60 ++++-- 3 files changed, 123 insertions(+), 136 deletions(-) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index e7b742d8bec..09462560f40 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -11,9 +11,9 @@ import torch import torch.utils.benchmark as TBenchmark from torch.utils.benchmark import Measurement as TMeasurement -from utils import make_rand_tensors from weight_shapes import WEIGHT_SHAPES +from utils import make_rand_tensors from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( w8a8_block_fp8_matmul) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index cab9990b16b..97ecf141851 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -5,37 +5,29 @@ """ import dataclasses import os -import pytest -import torch import traceback +from typing import Callable, Concatenate, Optional, ParamSpec -from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] -from typing import Callable, Concatenate, Optional, ParamSpec, Tuple - +import pytest +import torch from pplx_kernels import AllToAll -from pplx_kernels.nvshmem import ( - nvshmem_alloc_empty_unique_id, - nvshmem_finalize, - nvshmem_get_unique_id, - nvshmem_init, -) +from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_finalize, nvshmem_get_unique_id, + nvshmem_init) +from torch.multiprocessing import ( + spawn) # pyright: ignore[reportPrivateImportUsage] import vllm.model_executor.layers.fused_moe # noqa -from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, - torch_moe, torch_moe_single) from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, moe_align_block_size) -from vllm.platforms import current_platform - from vllm.model_executor.layers.activation import SiluAndMul - -from vllm.model_executor.layers.fused_moe.fused_moe import ( - TritonExperts, fused_experts) from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedDispatchCombine, BatchedExperts) -from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel -from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import PplxDispatchCombine + BatchedExperts) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import ( + PplxDispatchCombine) +from vllm.platforms import current_platform NUM_EXPERTS = [8, 64] EP_SIZE = [1, 4] @@ -122,8 +114,7 @@ def parallel_launch( 0, "tcp://localhost:29500", worker, - ) - + args, + ) + args, nprocs=world_size, join=True, ) @@ -157,8 +148,7 @@ def parallel_launch_from_env( node_rank, "env://", worker, - ) - + args, + ) + args, nprocs=world_local_size, join=True, ) @@ -169,19 +159,21 @@ def torch_dispatch( topk_ids: torch.Tensor, num_experts: int, max_num_tokens: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: assert topk_ids.dim() == 2 assert topk_ids.shape[0] == a.shape[0] num_tokens = a.shape[0] topk = topk_ids.shape[1] - tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + tokens_per_expert = torch.bincount(topk_ids.view(-1), + minlength=num_experts) if max_num_tokens is None: max_num_tokens = tokens_per_expert.max() b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), - dtype=a.dtype, device=a.device) + dtype=a.dtype, + device=a.device) #print(f"b_a shape {b_a.shape}") @@ -191,7 +183,7 @@ def torch_dispatch( for j in range(topk): expert_id = topk_ids[token, j] idx = token_counts[expert_id] - b_a[expert_id, idx:idx+1, :] = a[token, :] + b_a[expert_id, idx:idx + 1, :] = a[token, :] token_counts[expert_id] = token_counts[expert_id] + 1 return b_a, tokens_per_expert @@ -202,13 +194,16 @@ def torch_combine(b_out, topk_weight, topk_ids): num_experts = b_out.shape[0] K = b_out.shape[-1] out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) - expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) + expert_counts = torch.zeros(num_experts, + dtype=torch.int, + device=b_out.device) for token in range(num_tokens): expert_ids = topk_ids[token] for i in range(expert_ids.numel()): expert_id = expert_ids[i] idx = expert_counts[expert_id] - out[token, :] = out[token, :] + b_out[expert_id, idx:idx+1, :] * topk_weight[token, i] + out[token, :] = out[token, :] + b_out[expert_id, idx:idx + + 1, :] * topk_weight[token, i] expert_counts[expert_id] = expert_counts[expert_id] + 1 return out @@ -220,13 +215,18 @@ def torch_batched_moe(a, w1, w2, topk_weight, topk_ids): assert b_a.dim() == 3 num_tokens, topk = topk_ids.shape _, max_num_tokens, K = b_a.shape - assert num_experts == b_a.shape[0] and K == w2.shape[1] - out = torch.zeros((num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device) - tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device) + assert num_experts == b_a.shape[0] and w2.shape[1] == K + out = torch.zeros((num_experts, max_num_tokens, K), + dtype=b_a.dtype, + device=b_a.device) + tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), + dtype=b_a.dtype, + device=b_a.device) for expert in range(num_experts): num = tokens_per_expert[expert] if num > 0: - torch.ops._C.silu_and_mul(tmp[:num], b_a[expert,:num,:] @ w1[expert].transpose(0, 1)) + torch.ops._C.silu_and_mul( + tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1)) out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1) return torch_combine(out, topk_weight, topk_ids) @@ -249,7 +249,7 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids): topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) -@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @@ -272,25 +272,8 @@ def test_fused_moe_batched_experts( with set_current_vllm_config(vllm_config): topk_weight, topk_ids = fused_topk(a, score, topk, False) - torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - - if True: - triton_output = torch_batched_moe(a, - w1, - w2, - topk_weight, - topk_ids) - else: - b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) - triton_output = fused_batched_experts( - b_a, - w1, - w2, - topk_weight, - topk_ids, - global_num_experts=e - ) + triton_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) @@ -303,7 +286,7 @@ def rank_chunk(num, r, w): def chunk_by_rank(t, r, w): chunk = rank_chunk(t.shape[0], r, w) #print(f"chunk {t.shape}, {w}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}") - return t[(r * chunk):(r + 1)*chunk] + return t[(r * chunk):(r + 1) * chunk] def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): @@ -317,7 +300,6 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): world_size = pgi.world_size rank_num_tokens = rank_chunk(num_tokens, rank, world_size) max_num_tokens = num_tokens - #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") ata = AllToAll( max_num_tokens=max_num_tokens, @@ -328,15 +310,9 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): dp_size=dp_size, hidden_dim=hidden_dim, hidden_dim_bytes=hidden_dim * a.dtype.itemsize, - hidden_dim_scale_bytes=( - 0 - if a.dtype.itemsize != 1 - else ( - (hidden_dim + block_size - 1) - // block_size - * torch.float32.itemsize - ) - ), + hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else + ((hidden_dim + block_size - 1) // block_size * + torch.float32.itemsize)), ) dispatch_combine = PplxDispatchCombine( @@ -350,7 +326,8 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): a_chunk = chunk_by_rank(a, rank, world_size).to(device) score_chunk = chunk_by_rank(scores, rank, world_size).to(device) - chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) + chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, + False) b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( a_chunk, @@ -358,17 +335,22 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): None, chunk_topk_weight, chunk_topk_ids, - num_experts, # store at PplxDispatchCombine creation? + num_experts, # store at PplxDispatchCombine creation? None, False, ) - naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, num_experts) + naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, + num_experts) torch.distributed.all_reduce(tokens_per_expert) - tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, world_size).to(dtype=torch.int32) + tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, + world_size).to(dtype=torch.int32) - torch.testing.assert_close(tokens_per_expert, expert_num_tokens, atol=0, rtol=0) + torch.testing.assert_close(tokens_per_expert, + expert_num_tokens, + atol=0, + rtol=0) b_a = b_a * 1.5 @@ -396,11 +378,15 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): def _pplx_dispatch_combine( pgi: ProcessGroupInfo, dp_size: int, - m, n, k, e, + m, + n, + k, + e, topk: int, dtype: torch.dtype, ): - uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, pgi.rank, pgi.world_size) device = pgi.device @@ -414,17 +400,14 @@ def _pplx_dispatch_combine( a_rep = torch.repeat_interleave(a, topk, dim=0) - torch_output = (a_rep.view(-1, topk, k) * 1.5 * topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype) + torch_output = (a_rep.view(-1, topk, k) * 1.5 * + topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype) - pplx_output = torch_pplx_dispatch_combine(pgi, - dp_size, - a, - w1, - w2, - score, + pplx_output = torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, score, topk) - torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) + torch_output = chunk_by_rank(torch_output, pgi.rank, + pgi.world_size).to(pplx_output.device) torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) @@ -437,7 +420,7 @@ def _pplx_dispatch_combine( @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [[4, 2]]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [[4, 2]]) def test_pplx_dispatch_combine( m: int, n: int, @@ -445,14 +428,13 @@ def test_pplx_dispatch_combine( e: int, topk: int, dtype: torch.dtype, - world_dp_size: Tuple[int, int], + world_dp_size: tuple[int, int], ): current_platform.seed_everything(7) world_size, dp_size = world_dp_size - parallel_launch( - world_size, _pplx_dispatch_combine, dp_size, m, n, k, e, topk, dtype - ) + parallel_launch(world_size, _pplx_dispatch_combine, dp_size, m, n, k, e, + topk, dtype) def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): @@ -476,15 +458,9 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): dp_size=dp_size, hidden_dim=hidden_dim, hidden_dim_bytes=hidden_dim * a.dtype.itemsize, - hidden_dim_scale_bytes=( - 0 - if a.dtype.itemsize != 1 - else ( - (hidden_dim + block_size - 1) - // block_size - * torch.float32.itemsize - ) - ), + hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else + ((hidden_dim + block_size - 1) // block_size * + torch.float32.itemsize)), ) w1 = w1.to(device) @@ -508,7 +484,8 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): a_chunk = chunk_by_rank(a, rank, world_size).to(device) score_chunk = chunk_by_rank(scores, rank, world_size).to(device) - chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) + chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, + False) out = fused_experts( a_chunk, @@ -519,7 +496,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): #w2, chunk_topk_weight, chunk_topk_ids, - global_num_experts=num_experts #? num_local_experts? + global_num_experts=num_experts #? num_local_experts? ) torch.cuda.synchronize() @@ -539,7 +516,8 @@ def _pplx_moe( topk: int, dtype: torch.dtype, ): - uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, pgi.rank, pgi.world_size) @@ -553,15 +531,10 @@ def _pplx_moe( torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - pplx_output = torch_pplx_moe(pgi, - dp_size, - a, - w1, - w2, - score, - topk) + pplx_output = torch_pplx_moe(pgi, dp_size, a, w1, w2, score, topk) - torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) + torch_output = chunk_by_rank(torch_output, pgi.rank, + pgi.world_size).to(pplx_output.device) torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) @@ -575,7 +548,7 @@ def _pplx_moe( @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) def test_pplx_moe( m: int, n: int, @@ -583,7 +556,7 @@ def test_pplx_moe( e: int, topk: int, dtype: torch.dtype, - world_dp_size: Tuple[int, int], + world_dp_size: tuple[int, int], ): current_platform.seed_everything(7) world_size, dp_size = world_dp_size @@ -592,7 +565,5 @@ def test_pplx_moe( w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 score = torch.randn((m, e), device="cuda", dtype=dtype) - parallel_launch( - world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, dtype - ) - + parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, + dtype) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index a39d08b8376..56b1b343c86 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -9,9 +9,8 @@ class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - def __init__(self, - world_size: int, - rank: int): + + def __init__(self, world_size: int, rank: int): super().__init__() self.world_size = world_size self.rank = rank @@ -40,18 +39,22 @@ def dispatch( num_tokens = a1.shape[0] topk = topk_ids.shape[1] - tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + tokens_per_expert = torch.bincount(topk_ids.view(-1), + minlength=num_experts) max_num_tokens = tokens_per_expert.max() - expert_counts = torch.zeros(num_experts, dtype=torch.int, device=a1.device) + expert_counts = torch.zeros(num_experts, + dtype=torch.int, + device=a1.device) b_a1 = torch.zeros((num_experts, max_num_tokens, a1.shape[1]), - dtype=a1.dtype, device=a1.device) + dtype=a1.dtype, + device=a1.device) for token in range(num_tokens): for j in range(topk): expert_id = topk_ids[token, j] idx = expert_counts[expert_id] - b_a1[expert_id, idx:idx+1, :] = a1[token, :] + b_a1[expert_id, idx:idx + 1, :] = a1[token, :] expert_counts[expert_id] = expert_counts[expert_id] + 1 return b_a1, a1_scale, tokens_per_expert @@ -66,7 +69,9 @@ def combine( ) -> None: num_tokens = topk_ids.shape[0] num_experts = fused_expert_output.shape[0] - expert_counts = torch.zeros(num_experts, dtype=torch.int, device=fused_expert_output.device) + expert_counts = torch.zeros(num_experts, + dtype=torch.int, + device=fused_expert_output.device) for token in range(num_tokens): expert_ids = topk_ids[token] for i in range(topk_ids.shape[1]): @@ -74,9 +79,14 @@ def combine( if expert_id < num_experts: idx = expert_counts[expert_id] if apply_router_weight_on_input: - output[token, :] = output[token, :] + fused_expert_output[expert_id, idx:idx+1, :] + output[token, :] = output[ + token, :] + fused_expert_output[expert_id, + idx:idx + 1, :] else: - output[token, :] = output[token, :] + fused_expert_output[expert_id, idx:idx+1, :] * topk_weights[token, i] + output[ + token, :] = output[token, :] + fused_expert_output[ + expert_id, + idx:idx + 1, :] * topk_weights[token, i] expert_counts[expert_id] = expert_counts[expert_id] + 1 @@ -122,8 +132,10 @@ def workspace_shapes( topk: int, num_experts: int, ) -> Tuple[int, int, torch.dtype]: - max_num_tokens = a.shape[1] if self.max_num_tokens is None else self.max_num_tokens - workspace13 = num_experts * max_num_tokens * K * topk * 2 # TODO: *2 is a hack + max_num_tokens = a.shape[ + 1] if self.max_num_tokens is None else self.max_num_tokens + # TODO: *2 is a hack + workspace13 = num_experts * max_num_tokens * K * topk * 2 workspace2 = max_num_tokens * N return (workspace13, workspace2, a.dtype) @@ -148,16 +160,21 @@ def apply( ) -> torch.Tensor: assert hidden_states.dim() == 3 assert expert_num_tokens is not None - num_tokens = topk_ids.shape[0] - _, tmp_max_num_tokens, K = hidden_states.shape - max_num_tokens = tmp_max_num_tokens if self.max_num_tokens is None else self.max_num_tokens + + if self.max_num_tokens is None: + max_num_tokens = hidden_states.shape[1] + else: + max_num_tokens = self.max_num_tokens + num_experts = global_num_experts - out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) + out = _resize_cache(workspace13, + (num_experts, max_num_tokens, w2.shape[1])) num_local_experts = expert_num_tokens.numel() # TODO: don't need world_size or rank if expert_base always == 0 #assert w1.shape[0] == num_experts, f"{w1.shape} == {num_experts}" - #expert_base = rank_chunk(w1.shape[0], self.rank, self.world_size) * self.rank + #expert_base = rank_chunk(w1.shape[0], self.rank, + # self.world_size) * self.rank expert_base = 0 for expert in range(num_local_experts): @@ -166,10 +183,9 @@ def apply( if num > 0: tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) self.activation( - activation, - tmp, - hidden_states[expert,:num,:] @ w1[expert_base + expert].transpose(0, 1) - ) - out[expert, :num, :] = tmp @ w2[expert_base + expert].transpose(0, 1) + activation, tmp, hidden_states[expert, :num, :] + @ w1[expert_base + expert].transpose(0, 1)) + out[expert, :num, :] = tmp @ w2[expert_base + + expert].transpose(0, 1) return out From 3feab7222d5324f02db2d80fa3c44521ffe57659 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 13:14:05 +0000 Subject: [PATCH 030/190] undo random lint changes Signed-off-by: Bill Nell --- benchmarks/cutlass_benchmarks/w8a8_benchmarks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 09462560f40..e7b742d8bec 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -11,9 +11,9 @@ import torch import torch.utils.benchmark as TBenchmark from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_tensors from weight_shapes import WEIGHT_SHAPES -from utils import make_rand_tensors from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( w8a8_block_fp8_matmul) From cf3697797d33d34c38c5a436fb32b7e29afc00d2 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 13:34:40 +0000 Subject: [PATCH 031/190] more lint Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 97ecf141851..f0dabd66fea 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -6,7 +6,7 @@ import dataclasses import os import traceback -from typing import Callable, Concatenate, Optional, ParamSpec +from typing import Callable, Optional import pytest import torch @@ -16,6 +16,7 @@ nvshmem_init) from torch.multiprocessing import ( spawn) # pyright: ignore[reportPrivateImportUsage] +from typing_extensions import Concatenate, ParamSpec import vllm.model_executor.layers.fused_moe # noqa from vllm.config import VllmConfig, set_current_vllm_config @@ -169,7 +170,7 @@ def torch_dispatch( tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) if max_num_tokens is None: - max_num_tokens = tokens_per_expert.max() + max_num_tokens = int(tokens_per_expert.max().item()) b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), dtype=a.dtype, From f861876ff37ca028fcfe1ad60010acda33a82253 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 13:46:13 +0000 Subject: [PATCH 032/190] more lint nonsense Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index f0dabd66fea..405ced54d2e 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -94,7 +94,7 @@ def _worker_parallel_launch( ) except Exception as ex: print(ex) - traceback.print_exception(ex) + traceback.print_exc() raise finally: torch.distributed.destroy_process_group() @@ -176,8 +176,6 @@ def torch_dispatch( dtype=a.dtype, device=a.device) - #print(f"b_a shape {b_a.shape}") - token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) for token in range(num_tokens): From f147062889054170363616966b94421632988114 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Sat, 15 Mar 2025 01:11:06 +0000 Subject: [PATCH 033/190] WIP torch while Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/cuda_graph_utils.py | 0 vllm/forward_context.py | 3 + vllm/model_executor/layers/fused_moe/layer.py | 74 +++++++++++++++++++ 3 files changed, 77 insertions(+) create mode 100644 vllm/cuda_graph_utils.py diff --git a/vllm/cuda_graph_utils.py b/vllm/cuda_graph_utils.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm/forward_context.py b/vllm/forward_context.py index a9815aba1d6..1e4a77f1da9 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -31,6 +31,7 @@ @dataclass class DPMetadata: + max_tokens_across_dp: torch.Tensor cu_tokens_across_dp_cpu: torch.Tensor dp_rank_num_tokens: torch.Tensor @@ -95,6 +96,8 @@ def set_forward_context(attn_metadata: Any, dtype=torch.int32) from vllm.distributed.parallel_state import get_dp_group dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) + #TODO device? + max_tokens_across_dp = torch.max(num_tokens_tensor).to(device="cuda") cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0) dp_rank_num_tokens = torch.tensor( [num_tokens], diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 35994c8ac6a..5caa53ffd26 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -841,6 +841,80 @@ def forward(self, hidden_states: torch.Tensor, return torch.ops.vllm.moe_forward(hidden_states, router_logits, self.layer_name) + def forward_impl_while(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + max_tokens_across_dp = get_forward_context( + ).dp_metadata.max_tokens_across_dp + + #TODO: we need to define a couple of ranges: + # 1. the range within this rank's M dimension that we are looping over + # 2. the range within the workspace buffer that our current chunk maps to. + + moe_dp_chunk_size = 256 + my_dp_chunk_size = moe_dp_chunk_size // self.dp_size + chunk_start = torch.tensor(0, device=hidden_states.device) + + def padded_allgather(self, x: torch.Tensor): + assert (len(x.shape) == 2) + buffer = torch.zeros((moe_dp_chunk_size, x.shape[1]), + device=x.device, + dtype=x.dtype) + + buffer[:x.shape[0], :].copy_(x) + get_dp_group().all_gather(buffer, 0) + return buffer + + def cond_fn(chunk_range, max_tokens_across_dp, hidden_states, + router_logits): + return chunk_range[0] < max_tokens_across_dp + + def body_fn(chunk_range, max_tokens_across_dp, full_hidden_states, + full_router_logits): + hidden_states = full_hidden_states[chunk_range] + router_logits = full_router_logits[chunk_range] + + if self.dp_size > 1: + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + + hidden_states = self.padded_allgather(hidden_states) + router_logits = self.padded_allgather(router_logits) + + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, + ) + + if self.dp_size > 1: + all_hidden_states = get_dp_group().all_reduce( + final_hidden_states) + final_hidden_states[chunk_range] = all_hidden_states[ + start:end, :] + + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + # Default set to False. (May have to add shared expert outputs.) + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + chunk_range[0] = min(hidden_states.shape[0], + chunk_range[0] + moe_dp_chunk_size) + chunk_range[1] = min(hidden_states.shape[0], + chunk_range[1] + moe_dp_chunk_size) + return chunk_start, hidden_states + def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None From eb58491b27f5d3337cddbe24e5590aaaed3d5689 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 25 Mar 2025 13:10:57 +0000 Subject: [PATCH 034/190] wip Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 5caa53ffd26..27cc9993635 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1007,7 +1007,7 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, self = forward_context.no_compile_layers[layer_name] assert self.quant_method is not None - return self.forward_impl(hidden_states, router_logits) + return self.forward_impl_while(hidden_states, router_logits) def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, From 7a1b6e1968bdc8deb9b182fee8ae7421fb14d4cc Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 25 Mar 2025 21:32:43 +0000 Subject: [PATCH 035/190] wip Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 27cc9993635..7d6edc50adb 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -845,10 +845,19 @@ def forward_impl_while(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): max_tokens_across_dp = get_forward_context( ).dp_metadata.max_tokens_across_dp + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu - #TODO: we need to define a couple of ranges: - # 1. the range within this rank's M dimension that we are looping over - # 2. the range within the workspace buffer that our current chunk maps to. + #In this function we define two ranges: + # 1. chunk_range - The current iteration of the loops's range over the DP world tokens + # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP rank owns. + + chunk_range = torch.zeros(2, device=hidden_states.device) + chunk_range[1] = min(moe_dp_chunk_size, cu_tokens_across_dp_cpu[-1]) + + my_tokens_in_chunk = torch.zeros(2, device=hidden_states.device) + my_tokens_in_chunk[1] = min(my_dp_chunk_size, + chunk_range[1] - chunk_range[0]) moe_dp_chunk_size = 256 my_dp_chunk_size = moe_dp_chunk_size // self.dp_size From b1bef6f9fa7729a75ed8fecc8a71adc43c8ef0c9 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 27 Mar 2025 13:48:42 +0000 Subject: [PATCH 036/190] wip Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/forward_context.py | 6 +- vllm/model_executor/layers/fused_moe/layer.py | 80 +++++++++---------- 2 files changed, 45 insertions(+), 41 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 1e4a77f1da9..f6a036d228d 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -32,6 +32,7 @@ @dataclass class DPMetadata: max_tokens_across_dp: torch.Tensor + num_tokens_across_dp: torch.Tensor cu_tokens_across_dp_cpu: torch.Tensor dp_rank_num_tokens: torch.Tensor @@ -103,7 +104,10 @@ def set_forward_context(attn_metadata: Any, [num_tokens], dtype=torch.uint32, device=vllm_config.device_config.device) - dp_metadata = DPMetadata(cu_tokens_across_dp_cpu, dp_rank_num_tokens) + dp_metadata = DPMetadata(max_tokens_across_dp, + num_tokens_tensor, + cu_tokens_across_dp_cpu, + dp_rank_num_tokens) global _forward_context prev_context = _forward_context diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 7d6edc50adb..e31cf848a93 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -841,53 +841,43 @@ def forward(self, hidden_states: torch.Tensor, return torch.ops.vllm.moe_forward(hidden_states, router_logits, self.layer_name) - def forward_impl_while(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): + def forward_impl_while(self, full_hidden_states: torch.Tensor, + full_router_logits: torch.Tensor): max_tokens_across_dp = get_forward_context( ).dp_metadata.max_tokens_across_dp cu_tokens_across_dp_cpu = get_forward_context( ).dp_metadata.cu_tokens_across_dp_cpu + num_tokens_across_dp = get_forward_context( + ).dp_metadata.num_tokens_across_dp - #In this function we define two ranges: - # 1. chunk_range - The current iteration of the loops's range over the DP world tokens - # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP rank owns. - - chunk_range = torch.zeros(2, device=hidden_states.device) - chunk_range[1] = min(moe_dp_chunk_size, cu_tokens_across_dp_cpu[-1]) - - my_tokens_in_chunk = torch.zeros(2, device=hidden_states.device) - my_tokens_in_chunk[1] = min(my_dp_chunk_size, - chunk_range[1] - chunk_range[0]) - - moe_dp_chunk_size = 256 - my_dp_chunk_size = moe_dp_chunk_size // self.dp_size - chunk_start = torch.tensor(0, device=hidden_states.device) - - def padded_allgather(self, x: torch.Tensor): + def padded_allgather(x: torch.Tensor): assert (len(x.shape) == 2) buffer = torch.zeros((moe_dp_chunk_size, x.shape[1]), device=x.device, dtype=x.dtype) - buffer[:x.shape[0], :].copy_(x) get_dp_group().all_gather(buffer, 0) return buffer - def cond_fn(chunk_range, max_tokens_across_dp, hidden_states, - router_logits): - return chunk_range[0] < max_tokens_across_dp + #In this function we define two ranges: + # 1. chunk_range - The current iteration of the loops's range over the DP world tokens + # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP rank owns. + + moe_dp_chunk_size = 256 + moe_dp_chunk_size_per_rank = moe_dp_chunk_size // self.dp_size + + num_tokens_remaining_across_dp = num_tokens_across_dp + chunk_start = 0 + chunk_end = min(moe_dp_chunk_size_per_rank, full_hidden_states.shape[0]) + full_final_hidden_states = torch.empty_like(full_hidden_states) - def body_fn(chunk_range, max_tokens_across_dp, full_hidden_states, - full_router_logits): - hidden_states = full_hidden_states[chunk_range] - router_logits = full_router_logits[chunk_range] + for _ in range(max_tokens_across_dp, moe_dp_chunk_size_per_rank): + hidden_states = full_hidden_states[chunk_start:chunk_end,:] + router_logits = full_router_logits[chunk_start:chunk_end,:] if self.dp_size > 1: - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu - - hidden_states = self.padded_allgather(hidden_states) - router_logits = self.padded_allgather(router_logits) + hidden_states = padded_allgather(hidden_states) + router_logits = padded_allgather(router_logits) # Matrix multiply. final_hidden_states = self.quant_method.apply( @@ -907,22 +897,32 @@ def body_fn(chunk_range, max_tokens_across_dp, full_hidden_states, activation=self.activation, ) + cu_tokens_across_dp_this_iter = torch.cumsum( + num_tokens_remaining_across_dp.clamp(max=moe_dp_chunk_size_per_rank), + dim=0) + if self.dp_size > 1: + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[self.dp_rank-1] + end = cu_tokens_across_dp_this_iter[self.dp_rank] + all_hidden_states = get_dp_group().all_reduce( final_hidden_states) - final_hidden_states[chunk_range] = all_hidden_states[ - start:end, :] + final_hidden_states = all_hidden_states[start:end, :] if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + + full_final_hidden_states[chunk_start:chunk_end,:].copy_(final_hidden_states) + + num_tokens_remaining_across_dp = torch.clamp(num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank, min=0) + chunk_start = min(chunk_start + moe_dp_chunk_size_per_rank, + full_hidden_states.shape[0]) + chunk_end = min(chunk_end + moe_dp_chunk_size_per_rank, + full_hidden_states.shape[0]) + + return full_final_hidden_states - chunk_range[0] = min(hidden_states.shape[0], - chunk_range[0] + moe_dp_chunk_size) - chunk_range[1] = min(hidden_states.shape[0], - chunk_range[1] + moe_dp_chunk_size) - return chunk_start, hidden_states def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): From fb73ea72f1e483105bf6a35b3a456f36e3dec896 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 27 Mar 2025 13:41:18 -0400 Subject: [PATCH 037/190] wip Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_moe.py | 4 +- vllm/model_executor/layers/fused_moe/layer.py | 41 ++++++++----------- 2 files changed, 19 insertions(+), 26 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 2650f5a6064..ebbed9b6eac 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1462,8 +1462,8 @@ def fused_moe( Defaults to False. - global_num_experts (int): The total number of experts in the global expert space. - - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert parallel shard. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index e31cf848a93..1f262490644 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -841,7 +841,7 @@ def forward(self, hidden_states: torch.Tensor, return torch.ops.vllm.moe_forward(hidden_states, router_logits, self.layer_name) - def forward_impl_while(self, full_hidden_states: torch.Tensor, + def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_router_logits: torch.Tensor): max_tokens_across_dp = get_forward_context( ).dp_metadata.max_tokens_across_dp @@ -850,15 +850,6 @@ def forward_impl_while(self, full_hidden_states: torch.Tensor, num_tokens_across_dp = get_forward_context( ).dp_metadata.num_tokens_across_dp - def padded_allgather(x: torch.Tensor): - assert (len(x.shape) == 2) - buffer = torch.zeros((moe_dp_chunk_size, x.shape[1]), - device=x.device, - dtype=x.dtype) - buffer[:x.shape[0], :].copy_(x) - get_dp_group().all_gather(buffer, 0) - return buffer - #In this function we define two ranges: # 1. chunk_range - The current iteration of the loops's range over the DP world tokens # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP rank owns. @@ -871,13 +862,18 @@ def padded_allgather(x: torch.Tensor): chunk_end = min(moe_dp_chunk_size_per_rank, full_hidden_states.shape[0]) full_final_hidden_states = torch.empty_like(full_hidden_states) - for _ in range(max_tokens_across_dp, moe_dp_chunk_size_per_rank): + for _ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): hidden_states = full_hidden_states[chunk_start:chunk_end,:] router_logits = full_router_logits[chunk_start:chunk_end,:] - if self.dp_size > 1: - hidden_states = padded_allgather(hidden_states) - router_logits = padded_allgather(router_logits) + cu_tokens_across_dp_this_iter = torch.cumsum( + num_tokens_remaining_across_dp.clamp(max=moe_dp_chunk_size_per_rank), + dim=0) + + hidden_states = self.naive_multicast(hidden_states, + cu_tokens_across_dp_this_iter) + router_logits = self.naive_multicast(router_logits, + cu_tokens_across_dp_this_iter) # Matrix multiply. final_hidden_states = self.quant_method.apply( @@ -897,10 +893,6 @@ def padded_allgather(x: torch.Tensor): activation=self.activation, ) - cu_tokens_across_dp_this_iter = torch.cumsum( - num_tokens_remaining_across_dp.clamp(max=moe_dp_chunk_size_per_rank), - dim=0) - if self.dp_size > 1: start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[self.dp_rank-1] end = cu_tokens_across_dp_this_iter[self.dp_rank] @@ -913,13 +905,14 @@ def padded_allgather(x: torch.Tensor): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) - full_final_hidden_states[chunk_start:chunk_end,:].copy_(final_hidden_states) + full_final_hidden_states[chunk_start:chunk_end, :].copy_(final_hidden_states) + # Update bounds num_tokens_remaining_across_dp = torch.clamp(num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank, min=0) - chunk_start = min(chunk_start + moe_dp_chunk_size_per_rank, - full_hidden_states.shape[0]) - chunk_end = min(chunk_end + moe_dp_chunk_size_per_rank, - full_hidden_states.shape[0]) + def update_chunk_bound(x: int): + return min(x + moe_dp_chunk_size_per_rank, full_hidden_states.shape[0]) + chunk_start = update_chunk_bound(chunk_start) + chunk_end = update_chunk_bound(chunk_end) return full_final_hidden_states @@ -1016,7 +1009,7 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, self = forward_context.no_compile_layers[layer_name] assert self.quant_method is not None - return self.forward_impl_while(hidden_states, router_logits) + return self.forward_impl_chunked(hidden_states, router_logits) def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, From 1cf583174191404ffefc2d833fbca774668c3b91 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 28 Mar 2025 16:35:28 -0400 Subject: [PATCH 038/190] WIP integration Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 38 +++++++++++++++++-- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 1f262490644..c36e7d0ad3d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -3,11 +3,14 @@ from abc import abstractmethod from enum import Enum from typing import Callable, List, Optional, Tuple +from dataclasses import dataclass import torch import torch.nn.functional as F from torch.nn.parameter import UninitializedParameter +import pplx_kernels as pplx + import vllm.envs as envs from vllm.config import get_current_vllm_config from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank, @@ -34,6 +37,24 @@ fused_moe_pallas = None # type: ignore logger = init_logger(__name__) +MOE_DP_CHUNK_SIZE = 256 + +# Adapted from pplx-kernels tests/all_to_all_utils.py +@dataclass +class MoEConfig: + num_experts: int + experts_per_token: int + hidden_dim: int + + num_local_experts: int + dp_size: int + dp_rank: int + ep_size: int + ep_rank: int + + in_dtype: torch.dtype = torch.bfloat16 + out_dtype: torch.dtype = torch.bfloat16 + block_size: int = 128 class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" @@ -71,10 +92,22 @@ def apply( ) -> torch.Tensor: raise NotImplementedError - @CustomOp.register("unquantized_fused_moe") class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" + def __init__(self, moe: MoEConfig): + self.all_to_all = pplx.AllToAll( + max_num_tokens=MOE_DP_CHUNK_SIZE // moe.dp_size, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, + rank=moe.ep_rank, + world_size=moe.ep_size, + dp_size=moe.dp_size, + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, + hidden_dim_scale_bytes=0, + ) + def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -854,8 +887,7 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, # 1. chunk_range - The current iteration of the loops's range over the DP world tokens # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP rank owns. - moe_dp_chunk_size = 256 - moe_dp_chunk_size_per_rank = moe_dp_chunk_size // self.dp_size + moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self.dp_size num_tokens_remaining_across_dp = num_tokens_across_dp chunk_start = 0 From 91222dc2e4586a15e19e0f09971d35103a3c5313 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 26 Feb 2025 23:09:34 +0000 Subject: [PATCH 039/190] Add test for deep gemm matmul Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 347 ++++++++++++++++++++++++++++++++ 1 file changed, 347 insertions(+) create mode 100644 tests/kernels/test_block_fp8.py diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py new file mode 100644 index 00000000000..bebc77dcec9 --- /dev/null +++ b/tests/kernels/test_block_fp8.py @@ -0,0 +1,347 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from https://github.com/sgl-project/sglang/pull/2575 +import deep_gemm + +import itertools +import pytest +import torch + +from typing import Tuple + +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8, w8a8_block_fp8_matmul) +from vllm.platforms import current_platform + +if current_platform.get_device_capability() < (9, 0): + pytest.skip("FP8 Triton requires CUDA 9.0 or higher", + allow_module_level=True) + +# Test configurations +DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] +NUM_TOKENS = [7, 83, 2048] +D = [512, 4096, 5120, 13824] +GROUP_SIZE = [64, 128, 256, 512] +M = [1, 7, 83, 512, 2048] +N = [128, 512, 1024, 4096, 7748, 13824] +K = [256, 4096, 5120, 3884, 13824] +# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 +# and its hidden size is 7168. +M_moe = [1, 7, 83, 512, 2048] +N_moe = [4608] # [128, 4608, 13824] +K_moe = [7168] # [256, 7168, 13824] +BLOCK_SIZE = [[128, 128]] +E = [8, 24] # [8, 24, 128, 256] +TOP_KS = [2] # [1, 2, 6] +OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] +SEEDS = [0] + + +def native_per_token_group_quant_fp8(x, + group_size, + eps=1e-10, + dtype=torch.float8_e4m3fn): + """Function to perform per-token-group quantization on an input tensor + `x` using native torch.""" + assert x.shape[-1] % group_size == 0, ("the last dimension of `x` cannot " + "be divisible by `group_size`") + assert x.is_contiguous(), "`x` is not contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_ = x.reshape(x.numel() // group_size, group_size) + amax = x_.abs().max(dim=-1, + keepdim=True)[0].clamp(min=eps).to(torch.float32) + x_s = amax / fp8_max + x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype) + x_q = x_q.reshape(x.shape) + x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) + + return x_q, x_s + + +def native_w8a8_block_fp8_matmul(A, + B, + As, + Bs, + block_size, + output_dtype=torch.float16): + """Matrix multiplication with block-wise quantization using native torch.""" + A = A.to(torch.float32) + B = B.to(torch.float32) + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] + assert A.shape[:-1] == As.shape[:-1] + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N, ) + A = A.reshape(M, A.shape[-1]) + As = As.reshape(M, As.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [ + A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) + ] + B_tiles = [[ + B[ + j * block_n:min((j + 1) * block_n, N), + i * block_k:min((i + 1) * block_k, K), + ] for i in range(k_tiles) + ] for j in range(n_tiles)] + C_tiles = [ + C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) + ] + As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C + + +def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): + """Fused moe with block-wise quantization using native torch.""" + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + _, block_k = block_shape[0], block_shape[1] + a_q, a_s = native_per_token_group_quant_fp8(a, block_k) + a_q = a_q.to(torch.float32) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + inter_out = native_w8a8_block_fp8_matmul(a_q[mask], + w1[i], + a_s[mask], + w1_s[i], + block_shape, + output_dtype=a.dtype) + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = native_per_token_group_quant_fp8( + act_out, block_k) + act_out = act_out.to(torch.float32) + out[mask] = native_w8a8_block_fp8_matmul(act_out_q, + w2[i], + act_out_s, + w2_s[i], + block_shape, + output_dtype=a.dtype) + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + +# Skip all tests if CUDA is not available +pytest.importorskip("torch.cuda") + + +@pytest.fixture(autouse=True) +def setup_cuda(): + torch.set_default_device("cuda") + + +@pytest.mark.parametrize( + "num_tokens,d,dtype,group_size,seed", + itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS)) +@torch.inference_mode() +def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): + torch.manual_seed(seed) + x = torch.rand(num_tokens, d, dtype=dtype) + + ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size) + out, scale = per_token_group_quant_fp8(x, group_size) + + assert torch.allclose(out.to(torch.float32), + ref_out.to(torch.float32), + rtol=0.15) + assert torch.allclose(scale, ref_scale) + + +@pytest.mark.parametrize( + "M,N,K,block_size,out_dtype,seed", + itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) +@torch.inference_mode() +def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): + torch.manual_seed(seed) + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale + Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale + + ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, + out_dtype) + out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.001 + + +@pytest.mark.parametrize( + "M,N,K,E,topk,block_size,dtype,seed", + itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, + SEEDS)) +@torch.inference_mode() +def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): + torch.manual_seed(seed) + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_bf16 = (torch.rand( + (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + del w1_bf16 + + w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + del w2_bf16 + + block_n, block_k = block_size[0], block_size[1] + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + + w1_s = torch.rand( + (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale + w2_s = torch.rand( + (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale + + score = torch.randn((M, E), dtype=dtype) + + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) + + print(f"{out.sum()=}") + print(f"{ref_out.sum()=}") + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.03 + + +######################################################################################### + +def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) + + +def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros((deep_gemm.cell_div(m, 128) * 128, deep_gemm.cell_div(n, 128) * 128), dtype=x.dtype, device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + + +@pytest.mark.parametrize( + "M,N,K,block_size,out_dtype,seed", + itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) +@torch.inference_mode() +def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): + torch.manual_seed(seed) + + # only aligned sizes + if M % 4 != 0 or K % 128 != 0 or N % 64 != 0: + return + + # weird max diff errors + if False and (M == 512 or M == 2048): + return + + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale + Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale + + ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, + out_dtype) + + A_fp8_dg, As_dg = per_token_group_quant_fp8(A_fp32, block_k) + B_fp8_dg, Bs_dg = per_block_cast_to_fp8(B_fp32) + + # Transpose earlier so that the testing will not trigger transposing kernels + As_dg = deep_gemm.get_col_major_tma_aligned_tensor(As_dg) + + out = torch.empty((M, N), device='cuda', dtype=out_dtype) + + assert As_dg.shape == (M, (K + 127) // 128), f"{As_dg.shape} != {(M, (K + 127) // 128)}" + + deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8_dg, As_dg), (B_fp8_dg, Bs_dg), out) + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.001 From 0e5081ed30b6a481175f2e2523f5cba5f7c43149 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 27 Feb 2025 03:01:01 +0000 Subject: [PATCH 040/190] fix matmul test Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 171 +++++++++++++++++++++++++++++--- 1 file changed, 157 insertions(+), 14 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index bebc77dcec9..249da81b32a 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Adapted from https://github.com/sgl-project/sglang/pull/2575 +# TODO: try/catch this? import deep_gemm import itertools @@ -24,12 +25,14 @@ NUM_TOKENS = [7, 83, 2048] D = [512, 4096, 5120, 13824] GROUP_SIZE = [64, 128, 256, 512] -M = [1, 7, 83, 512, 2048] +#M = [1, 7, 83, 512, 2048] +M = [1, 8, 84, 512, 2048] N = [128, 512, 1024, 4096, 7748, 13824] K = [256, 4096, 5120, 3884, 13824] # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. -M_moe = [1, 7, 83, 512, 2048] +#M_moe = [1, 7, 83, 512, 2048] +M_moe = [1, 8, 84, 512, 2048] N_moe = [4608] # [128, 4608, 13824] K_moe = [7168] # [256, 7168, 13824] BLOCK_SIZE = [[128, 128]] @@ -299,16 +302,11 @@ def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): - torch.manual_seed(seed) - # only aligned sizes if M % 4 != 0 or K % 128 != 0 or N % 64 != 0: return - # weird max diff errors - if False and (M == 512 or M == 2048): - return - + torch.manual_seed(seed) factor_for_scale = 1e-2 fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_max, fp8_min = fp8_info.max, fp8_info.min @@ -323,19 +321,22 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): n_tiles = (N + block_n - 1) // block_n k_tiles = (K + block_k - 1) // block_k - As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale - Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale + A_fp8_dg, As_dg = per_token_group_quant_fp8(A_fp32, block_k) + B_fp8_dg, Bs_dg = per_block_cast_to_fp8(B_fp32) - ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, + As = As_dg.to(torch.float32) + Bs = Bs_dg.to(torch.float32) + + ref_out = native_w8a8_block_fp8_matmul(A_fp8_dg, B_fp8_dg, As, Bs, block_size, out_dtype) - A_fp8_dg, As_dg = per_token_group_quant_fp8(A_fp32, block_k) - B_fp8_dg, Bs_dg = per_block_cast_to_fp8(B_fp32) + #A_fp8_dg, As_dg = per_token_group_quant_fp8(A_fp32, block_k) + #B_fp8_dg, Bs_dg = per_block_cast_to_fp8(B_fp32) # Transpose earlier so that the testing will not trigger transposing kernels As_dg = deep_gemm.get_col_major_tma_aligned_tensor(As_dg) - out = torch.empty((M, N), device='cuda', dtype=out_dtype) + out = torch.zeros((M, N), device='cuda', dtype=out_dtype) assert As_dg.shape == (M, (K + 127) // 128), f"{As_dg.shape} != {(M, (K + 127) // 128)}" @@ -345,3 +346,145 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.001 + + +################################################################################### + +def construct_grouped( + num_groups: int, + m: int, + k: int, + n: int, + is_masked: bool +) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: + x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16) + y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) + out = torch.empty((num_groups, m, n), device='cuda', dtype=torch.bfloat16) + + assert m % 4 == 0, f'TMA alignment error: {m}' + x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float)) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float)) + for i in range(num_groups): + x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) + + # For non-masked input, we must merge the group and M dims + if not is_masked: + x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1]) + out, ref_out = out.view(-1, n), ref_out.view(-1, n) + + # Transpose earlier so that the testing will not trigger transposing kernels + x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) + return x_fp8, y_fp8, out + + +# ref_out = torch.einsum('gmk,gnk->gmn', x, y) + +from vllm.model_executor.layers.fused_moe import fused_topk, grouped_topk + +def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): + """Fused moe with block-wise quantization using native torch.""" + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + _, block_k = block_shape[0], block_shape[1] + a_q, a_s = per_token_group_quant_fp8(a, block_k) + w1, w1_s = per_block_cast_to_fp8(w1) + w2, w2_s = per_block_cast_to_fp8(w2) + + num_groups = w1.shape[0] # ??? + + m_indices = torch.arange(0, num_groups, device=a.device, dtype=torch.int) + m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1) + + inter_out = torch.zeros(a_q.shape[0], w1.shape[1], dtype=torch.bfloat16, device=a.device) + + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), + (w1, w1_s), + inter_out, + m_indices) + + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) + + num_groups2 = w2.shape[0] # ??? + + m_indices2 = torch.arange(0, num_groups2, device=a.device, dtype=torch.int) + m_indices2 = m_indices2.unsqueeze(-1).expand(num_groups2, n).contiguous().view(-1) + out = torch.zeros(B * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) + + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((act_out_q, act_out_s), + (w2, w2_s), + out, + m_indices2) + + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + +@pytest.mark.parametrize( + "M,N,K,E,topk,block_size,dtype,seed", + itertools.product(M_moe, N, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, + SEEDS)) +@torch.inference_mode() +def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): + + # only aligned sizes + if M % 4 != 0 or K % 128 != 0 or N % 64 != 0: + return + + torch.manual_seed(seed) + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_bf16 = (torch.rand( + (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + del w1_bf16 + + w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + del w2_bf16 + + block_n, block_k = block_size[0], block_size[1] + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + + w1_s = torch.rand( + (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale + w2_s = torch.rand( + (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale + + score = torch.randn((M, E), dtype=dtype) + + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + ref_out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) + + print(f"{out.sum()=}") + print(f"{ref_out.sum()=}") + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.03 From 9e58fee85d064a4c44977307a66a3bb2e4805290 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 27 Feb 2025 19:55:59 +0000 Subject: [PATCH 041/190] running Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 59 +++++++++++-------- .../layers/fused_moe/fused_moe.py | 1 + 2 files changed, 37 insertions(+), 23 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 249da81b32a..1028310b5ca 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -26,9 +26,15 @@ D = [512, 4096, 5120, 13824] GROUP_SIZE = [64, 128, 256, 512] #M = [1, 7, 83, 512, 2048] -M = [1, 8, 84, 512, 2048] -N = [128, 512, 1024, 4096, 7748, 13824] -K = [256, 4096, 5120, 3884, 13824] + +M = [1, 8, 84, 512, 2048, 4096] +N = [128, 512, 1024, 4096, 7748, 13824, 7168] +K = [256, 4096, 5120, 3884, 13824, 16384] + +#M = [128] +#N = [24576] +#K = [1536] + # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. #M_moe = [1, 7, 83, 512, 2048] @@ -384,46 +390,50 @@ def construct_grouped( def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using native torch.""" - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + M, K = a.shape + print(f"before {a.shape}") + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) score = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight, topk_ids = torch.topk(score, topk) topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) + topk_ids = topk_ids.to(dtype=torch.int32).view(-1) _, block_k = block_shape[0], block_shape[1] a_q, a_s = per_token_group_quant_fp8(a, block_k) - w1, w1_s = per_block_cast_to_fp8(w1) - w2, w2_s = per_block_cast_to_fp8(w2) - num_groups = w1.shape[0] # ??? + num_groups = w1.shape[0] + for i in range(num_groups): + w1[i], w1_s[i] = per_block_cast_to_fp8(w1[i].to(dtype=torch.bfloat16)) + w2[i], w2_s[i] = per_block_cast_to_fp8(w2[i].to(dtype=torch.bfloat16)) + + print(f"{M}, {num_groups}, {a.shape}") m_indices = torch.arange(0, num_groups, device=a.device, dtype=torch.int) - m_indices = m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1) + m_indices = m_indices.unsqueeze(-1).expand(num_groups, a.shape[0]//num_groups).contiguous().view(-1) inter_out = torch.zeros(a_q.shape[0], w1.shape[1], dtype=torch.bfloat16, device=a.device) + print("FIRST GEMM") + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), inter_out, - m_indices) + topk_ids) act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - num_groups2 = w2.shape[0] # ??? + out = torch.zeros(M * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) - m_indices2 = torch.arange(0, num_groups2, device=a.device, dtype=torch.int) - m_indices2 = m_indices2.unsqueeze(-1).expand(num_groups2, n).contiguous().view(-1) - out = torch.zeros(B * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) + print("SECOND GEMM") deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((act_out_q, act_out_s), (w2, w2_s), out, - m_indices2) + topk_ids) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + return (out.view(M, -1, w2.shape[1]) * + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) @pytest.mark.parametrize( @@ -446,11 +456,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, w1_bf16 = (torch.rand( (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max) del w1_bf16 w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max) del w2_bf16 block_n, block_k = block_size[0], block_size[1] @@ -466,6 +476,12 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, score = torch.randn((M, E), dtype=dtype) + w1 = w1.to(torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fn) + + ref_out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) + out = fused_moe( a, w1, @@ -478,9 +494,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, w2_scale=w2_s, block_shape=block_size, ) - ref_out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) - print(f"{out.sum()=}") print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index ebbed9b6eac..585d3de78f4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -502,6 +502,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, M = A.shape[0] num_tokens = M * top_k + # EM = num_groups EM = sorted_token_ids.shape[0] if A.shape[0] < config["BLOCK_SIZE_M"]: # optimize for small batch_size. From a1ccb7849199b0adc0751e6b844641515da6f49c Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 27 Feb 2025 20:54:25 +0000 Subject: [PATCH 042/190] wip Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 143 ++++++++++++++++---------------- 1 file changed, 70 insertions(+), 73 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 1028310b5ca..7a9a46291ae 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -10,6 +10,7 @@ from typing import Tuple +from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -292,12 +293,12 @@ def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) -def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def per_block_cast_to_fp8(x: torch.Tensor, block_size_n: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape - x_padded = torch.zeros((deep_gemm.cell_div(m, 128) * 128, deep_gemm.cell_div(n, 128) * 128), dtype=x.dtype, device=x.device) + x_padded = torch.zeros((deep_gemm.cell_div(m, 128) * 128, deep_gemm.cell_div(n, block_size_n) * block_size_n), dtype=x.dtype, device=x.device) x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) @@ -388,32 +389,40 @@ def construct_grouped( from vllm.model_executor.layers.fused_moe import fused_topk, grouped_topk -def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): +def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, score, topk, block_shape): """Fused moe with block-wise quantization using native torch.""" - M, K = a.shape - print(f"before {a.shape}") - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) score = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight, topk_ids = torch.topk(score, topk) topk_weight = topk_weight.view(-1) topk_ids = topk_ids.to(dtype=torch.int32).view(-1) - _, block_k = block_shape[0], block_shape[1] + M, K = a.shape + N = w2.shape[-1] + num_groups = w1.shape[0] + + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + + block_n, block_k = block_shape[0], block_shape[1] + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + + w1_s = torch.empty((num_groups, n_tiles_w1, k_tiles_w1), dtype=torch.float32) + w2_s = torch.empty((num_groups, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + a_q, a_s = per_token_group_quant_fp8(a, block_k) - num_groups = w1.shape[0] for i in range(num_groups): - w1[i], w1_s[i] = per_block_cast_to_fp8(w1[i].to(dtype=torch.bfloat16)) + w1[i], w1_s[i] = per_block_cast_to_fp8(w1[i].to(dtype=torch.bfloat16), block_n) w2[i], w2_s[i] = per_block_cast_to_fp8(w2[i].to(dtype=torch.bfloat16)) - print(f"{M}, {num_groups}, {a.shape}") + inter_out = torch.empty(a_q.shape[0], w1.shape[1], dtype=torch.bfloat16, device=a.device) - m_indices = torch.arange(0, num_groups, device=a.device, dtype=torch.int) - m_indices = m_indices.unsqueeze(-1).expand(num_groups, a.shape[0]//num_groups).contiguous().view(-1) + #print("FIRST GEMM") - inter_out = torch.zeros(a_q.shape[0], w1.shape[1], dtype=torch.bfloat16, device=a.device) - - print("FIRST GEMM") + w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() + w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), @@ -425,7 +434,7 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape out = torch.zeros(M * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) - print("SECOND GEMM") + #print("SECOND GEMM") deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((act_out_q, act_out_s), (w2, w2_s), @@ -433,7 +442,7 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape topk_ids) return (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1), w1_s, w2_s @pytest.mark.parametrize( @@ -444,60 +453,48 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): # only aligned sizes - if M % 4 != 0 or K % 128 != 0 or N % 64 != 0: + if M % 4 != 0 or K % 128 != 0 or N % 128 != 0: return - torch.manual_seed(seed) - factor_for_scale = 1e-2 - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - a = torch.randn((M, K), dtype=dtype) / 10 - - w1_bf16 = (torch.rand( - (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max) - del w1_bf16 - - w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max) - del w2_bf16 - - block_n, block_k = block_size[0], block_size[1] - n_tiles_w1 = (2 * N + block_n - 1) // block_n - n_tiles_w2 = (K + block_n - 1) // block_n - k_tiles_w1 = (K + block_k - 1) // block_k - k_tiles_w2 = (N + block_k - 1) // block_k - - w1_s = torch.rand( - (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale - w2_s = torch.rand( - (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale - - score = torch.randn((M, E), dtype=dtype) - - w1 = w1.to(torch.float8_e4m3fn) - w2 = w2.to(torch.float8_e4m3fn) - - ref_out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) - - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - ) - print(f"{out.sum()=}") - print(f"{ref_out.sum()=}") - - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - assert rel_diff < 0.03 + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + torch.manual_seed(seed) + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_bf16 = (torch.rand( + (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + del w1_bf16 + + w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max + w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + del w2_bf16 + + score = torch.randn((M, E), dtype=dtype) + + # TODO: move out scale setup + ref_out, w1_s, w2_s = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, score, topk, block_size) + + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + #print(f"{out.sum()=}") + #print(f"{ref_out.sum()=}") + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.03 From a1b033e24c12896856030575a4114aa86f1442d5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 28 Feb 2025 04:23:16 +0000 Subject: [PATCH 043/190] wip Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 290 +++++++++++++++++--------------- 1 file changed, 151 insertions(+), 139 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 7a9a46291ae..97b99445536 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -2,14 +2,13 @@ # Adapted from https://github.com/sgl-project/sglang/pull/2575 # TODO: try/catch this? -import deep_gemm - import itertools +from typing import Tuple + +import deep_gemm import pytest import torch -from typing import Tuple - from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe @@ -43,7 +42,8 @@ N_moe = [4608] # [128, 4608, 13824] K_moe = [7168] # [256, 7168, 13824] BLOCK_SIZE = [[128, 128]] -E = [8, 24] # [8, 24, 128, 256] +#E = [8, 24] # [8, 24, 128, 256] +E = [8, 16] # [8, 24, 128, 256] TOP_KS = [2] # [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] @@ -285,23 +285,33 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): ######################################################################################### -def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + +def per_token_cast_to_fp8( + x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 and x.size(1) % 128 == 0 m, n = x.shape x_view = x.view(m, -1, 128) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) + return (x_view * (448.0 / x_amax.unsqueeze(2))).to( + torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) -def per_block_cast_to_fp8(x: torch.Tensor, block_size_n: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: +def per_block_cast_to_fp8( + x: torch.Tensor, + block_size_n: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape - x_padded = torch.zeros((deep_gemm.cell_div(m, 128) * 128, deep_gemm.cell_div(n, block_size_n) * block_size_n), dtype=x.dtype, device=x.device) + x_padded = torch.zeros( + (deep_gemm.cell_div(m, 128) * 128, + deep_gemm.cell_div(n, block_size_n) * block_size_n), + dtype=x.dtype, + device=x.device) x_padded[:m, :n] = x x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), ( + x_amax / 448.0).view(x_view.size(0), x_view.size(2)) @pytest.mark.parametrize( @@ -314,40 +324,32 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): return torch.manual_seed(seed) - factor_for_scale = 1e-2 fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min + fp8_max = fp8_info.max A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max - A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max - B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - block_n, block_k = block_size[0], block_size[1] - n_tiles = (N + block_n - 1) // block_n - k_tiles = (K + block_k - 1) // block_k + _, block_k = block_size[0], block_size[1] - A_fp8_dg, As_dg = per_token_group_quant_fp8(A_fp32, block_k) - B_fp8_dg, Bs_dg = per_block_cast_to_fp8(B_fp32) + A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k) + B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32) - As = As_dg.to(torch.float32) - Bs = Bs_dg.to(torch.float32) + As = As_fp8.to(torch.float32) + Bs = Bs_fp8.to(torch.float32) - ref_out = native_w8a8_block_fp8_matmul(A_fp8_dg, B_fp8_dg, As, Bs, block_size, + ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - #A_fp8_dg, As_dg = per_token_group_quant_fp8(A_fp32, block_k) - #B_fp8_dg, Bs_dg = per_block_cast_to_fp8(B_fp32) - # Transpose earlier so that the testing will not trigger transposing kernels - As_dg = deep_gemm.get_col_major_tma_aligned_tensor(As_dg) + As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8) out = torch.zeros((M, N), device='cuda', dtype=out_dtype) - assert As_dg.shape == (M, (K + 127) // 128), f"{As_dg.shape} != {(M, (K + 127) // 128)}" + assert As_fp8.shape == (M, (K + 127) // + 128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}" - deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8_dg, As_dg), (B_fp8_dg, Bs_dg), out) + deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out) rel_diff = (torch.mean( torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / @@ -357,144 +359,154 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): ################################################################################### -def construct_grouped( - num_groups: int, - m: int, - k: int, - n: int, - is_masked: bool -) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: - x = torch.randn((num_groups, m, k), device='cuda', dtype=torch.bfloat16) - y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) - out = torch.empty((num_groups, m, n), device='cuda', dtype=torch.bfloat16) - - assert m % 4 == 0, f'TMA alignment error: {m}' - x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), torch.empty((num_groups, m, k // 128), device='cuda', dtype=torch.float)) - y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, (n + 127) // 128, k // 128), device='cuda', dtype=torch.float)) - for i in range(num_groups): - x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) - y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) - - # For non-masked input, we must merge the group and M dims - if not is_masked: - x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1]) - out, ref_out = out.view(-1, n), ref_out.view(-1, n) - - # Transpose earlier so that the testing will not trigger transposing kernels - x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) - return x_fp8, y_fp8, out - - # ref_out = torch.einsum('gmk,gnk->gmn', x, y) -from vllm.model_executor.layers.fused_moe import fused_topk, grouped_topk -def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, score, topk, block_shape): - """Fused moe with block-wise quantization using native torch.""" +def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_shape): + """Fused moe with block-wise quantization using DeepGemm torch.""" + M = a.numel() // a.shape[-1] + K = w1.shape[-1] + num_groups = w1.shape[0] + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + inter_out = torch.zeros(a.shape[0], + w1.shape[1], + dtype=torch.bfloat16, + device=a.device) score = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight, topk_ids = torch.topk(score, topk) topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.to(dtype=torch.int32).view(-1) - - M, K = a.shape - N = w2.shape[-1] - num_groups = w1.shape[0] - - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - - block_n, block_k = block_shape[0], block_shape[1] - n_tiles_w1 = (2 * N + block_n - 1) // block_n - n_tiles_w2 = (K + block_n - 1) // block_n - k_tiles_w1 = (K + block_k - 1) // block_k - k_tiles_w2 = (N + block_k - 1) // block_k - - w1_s = torch.empty((num_groups, n_tiles_w1, k_tiles_w1), dtype=torch.float32) - w2_s = torch.empty((num_groups, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + topk_ids = topk_ids.view(-1) + _, block_k = block_shape[0], block_shape[1] a_q, a_s = per_token_group_quant_fp8(a, block_k) - for i in range(num_groups): - w1[i], w1_s[i] = per_block_cast_to_fp8(w1[i].to(dtype=torch.bfloat16), block_n) - w2[i], w2_s[i] = per_block_cast_to_fp8(w2[i].to(dtype=torch.bfloat16)) + #assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) + #print(f"FIRST GEMM {a_q.shape}") - inter_out = torch.empty(a_q.shape[0], w1.shape[1], dtype=torch.bfloat16, device=a.device) + m_indices = torch.arange(0, num_groups, dtype=torch.int) + m_indices = m_indices.unsqueeze(-1).expand( + num_groups, (2 * M) // num_groups).contiguous().view(-1) + #print(f"m_indices {m_indices.shape}, ng={num_groups}") - #print("FIRST GEMM") - - w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() - w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() - - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), - (w1, w1_s), - inter_out, - topk_ids) + if True: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (a_q, a_s), (w1, w1_s), inter_out, m_indices) + else: + topk_ids = topk_ids.to(dtype=torch.int32) + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a_q, a_s), (w1, w1_s), + inter_out, topk_ids, M) act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - out = torch.zeros(M * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) + out = torch.zeros(M * topk, + w2.shape[1], + dtype=torch.bfloat16, + device=a.device) #print("SECOND GEMM") - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((act_out_q, act_out_s), - (w2, w2_s), - out, - topk_ids) + if True: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (act_out_q, act_out_s), (w2, w2_s), out, m_indices) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( + (act_out_q, act_out_s), (w2, w2_s), out, topk_ids, M) return (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1), w1_s, w2_s + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, - SEEDS)) + itertools.product(M_moe, N, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() -def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): +def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, + dtype, seed): # only aligned sizes - if M % 4 != 0 or K % 128 != 0 or N % 128 != 0: + if (M % 4 != 0 or N % 128 != 0 or K % 128 != 0): return vllm_config = VllmConfig() + + torch.manual_seed(seed) + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * + fp8_max).clamp(min=fp8_min, max=fp8_max) + + w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * + fp8_max).clamp(min=fp8_min, max=fp8_max) + + score = torch.randn((M, E), dtype=dtype) + + num_groups = E + block_n, block_k = block_size[0], block_size[1] + n_tiles_w1 = ((2 * N) + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w2 = (N + block_k - 1) // block_k + + w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) + w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) + + w1_s = torch.empty((num_groups, n_tiles_w1, k_tiles_w1), + dtype=torch.float32) + w2_s = torch.empty((num_groups, n_tiles_w2, k_tiles_w2), + dtype=torch.float32) + + assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) + assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] + for i in range(num_groups): + w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) + w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) + + w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() + w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() + with set_current_vllm_config(vllm_config): - torch.manual_seed(seed) - factor_for_scale = 1e-2 - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - a = torch.randn((M, K), dtype=dtype) / 10 - - w1_bf16 = (torch.rand( - (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - del w1_bf16 - - w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - del w2_bf16 - - score = torch.randn((M, E), dtype=dtype) - - # TODO: move out scale setup - ref_out, w1_s, w2_s = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, score, topk, block_size) - - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - ) - #print(f"{out.sum()=}") - #print(f"{ref_out.sum()=}") - - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - assert rel_diff < 0.03 + if False: + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, + topk, block_size) + else: + out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, + topk, block_size) + + ref_out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + + #print(f"{out.sum()=}") + #print(f"{ref_out.sum()=}") + + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.03 From 6b0aa02f3eb91e7c7ffed6c0eeeaed8dfd649552 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 28 Feb 2025 22:03:27 +0000 Subject: [PATCH 044/190] debugging Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 99 ++++++++++++++++++++++----------- 1 file changed, 68 insertions(+), 31 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 97b99445536..0093d74efa7 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform +from vllm.utils import cdiv if current_platform.get_device_capability() < (9, 0): pytest.skip("FP8 Triton requires CUDA 9.0 or higher", @@ -223,11 +224,13 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.001 +def p(s, t): + print(f"{s}: {t.shape}, {t.dtype}") @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, - SEEDS)) + #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product([4], [128], [128], [8], [2], [[128, 128]], DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): torch.manual_seed(seed) @@ -235,6 +238,8 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_max, fp8_min = fp8_info.max, fp8_info.min + vllm_config = VllmConfig() + a = torch.randn((M, K), dtype=dtype) / 10 w1_bf16 = (torch.rand( @@ -259,20 +264,27 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - ) - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) + p("a", a) + p("w1", w1) + p("w1_s", w1_s) + p("w2", w2) + p("w2_s", w2_s) + + with set_current_vllm_config(vllm_config): + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) print(f"{out.sum()=}") print(f"{ref_out.sum()=}") @@ -310,8 +322,9 @@ def per_block_cast_to_fp8( x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - return x_scaled.view_as(x_padded)[:m, :n].contiguous(), ( - x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() + scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + return x_scaled_sub, scales @pytest.mark.parametrize( @@ -369,7 +382,7 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, K = w1.shape[-1] num_groups = w1.shape[0] a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - inter_out = torch.zeros(a.shape[0], + inter_out = torch.empty(a.shape[0], w1.shape[1], dtype=torch.bfloat16, device=a.device) @@ -386,8 +399,8 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, m_indices = torch.arange(0, num_groups, dtype=torch.int) m_indices = m_indices.unsqueeze(-1).expand( - num_groups, (2 * M) // num_groups).contiguous().view(-1) - #print(f"m_indices {m_indices.shape}, ng={num_groups}") + num_groups, max(M // num_groups, 1)).contiguous().view(-1) + p("m_indices", m_indices) if True: deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( @@ -400,13 +413,13 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - out = torch.zeros(M * topk, + #print("SECOND GEMM") + + out = torch.empty(act_out.shape[0], w2.shape[1], dtype=torch.bfloat16, device=a.device) - #print("SECOND GEMM") - if True: deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) @@ -420,13 +433,15 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product([2], [256], [512], [2], [1], [[128, 128]], DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): # only aligned sizes - if (M % 4 != 0 or N % 128 != 0 or K % 128 != 0): + if (N % 128 != 0 or K % 128 != 0): + print(f"skip {N}, {K}") return vllm_config = VllmConfig() @@ -460,14 +475,35 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_s = torch.empty((num_groups, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + print(f"NUM_GROUPS = {num_groups}") + p("before w1_s", w1_s) + p("before w2_s", w2_s) + assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] for i in range(num_groups): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() - w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() + p("imm w1_s", w1_s) + + w1_sa = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() + w2_sa = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() + + if w1_sa.shape != w1_s.shape or w2_sa.shape != w2_s.shape: + p("w1_sa", w1_sa) + p("w2_sa", w2_sa) + print(f"UNALIGNED") + return + + w1_s = w1_sa + w2_s = w2_sa + + p("a", a) + p("w1", w1) + p("final w1_s", w1_s) + p("w2", w2) + p("w2_s", w2_s) with set_current_vllm_config(vllm_config): if False: @@ -487,9 +523,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) else: - out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, - topk, block_size) - ref_out = fused_moe( a, w1, @@ -503,6 +536,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) + out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, + topk, block_size) + + #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") From 4b91cd43bd7ad3cbc0f1f9891060e41c3ad4088e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 28 Feb 2025 22:04:31 +0000 Subject: [PATCH 045/190] debugging Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 585d3de78f4..47626045a85 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1360,6 +1360,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) + print(intermediate_cache2) + if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) From 06acd02e777918efb301efdc20e76937cc2e86a9 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 28 Feb 2025 22:04:45 +0000 Subject: [PATCH 046/190] fix Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 47626045a85..585d3de78f4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1360,8 +1360,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - print(intermediate_cache2) - if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) From 24a90b9eece1826e52b1748eef5c72da7177bd9c Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 28 Feb 2025 23:41:36 +0000 Subject: [PATCH 047/190] update deep gemm Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 30 ++++++++++++------- .../layers/fused_moe/fused_moe.py | 2 ++ 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 0093d74efa7..ea96724f559 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -229,8 +229,7 @@ def p(s, t): @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product([4], [128], [128], [8], [2], [[128, 128]], DTYPES, SEEDS)) + itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): torch.manual_seed(seed) @@ -314,8 +313,8 @@ def per_block_cast_to_fp8( assert x.dim() == 2 m, n = x.shape x_padded = torch.zeros( - (deep_gemm.cell_div(m, 128) * 128, - deep_gemm.cell_div(n, block_size_n) * block_size_n), + (deep_gemm.ceil_div(m, 128) * 128, + deep_gemm.ceil_div(n, block_size_n) * block_size_n), dtype=x.dtype, device=x.device) x_padded[:m, :n] = x @@ -334,7 +333,7 @@ def per_block_cast_to_fp8( def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): # only aligned sizes if M % 4 != 0 or K % 128 != 0 or N % 64 != 0: - return + pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") torch.manual_seed(seed) fp8_info = torch.finfo(torch.float8_e4m3fn) @@ -399,8 +398,15 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, m_indices = torch.arange(0, num_groups, dtype=torch.int) m_indices = m_indices.unsqueeze(-1).expand( - num_groups, max(M // num_groups, 1)).contiguous().view(-1) + num_groups, max((topk * M) // num_groups, 1)).contiguous().view(-1) + #m_indices = torch.IntTensor([0, 1]) p("m_indices", m_indices) + print(m_indices) + + print("topk", topk_ids) + print(topk_ids) + print("topk_weight", topk_weight) + print(topk_weight) if True: deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( @@ -410,6 +416,8 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a_q, a_s), (w1, w1_s), inter_out, topk_ids, M) + print(f"DG {inter_out.shape} {inter_out}") + act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) @@ -441,8 +449,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, # only aligned sizes if (N % 128 != 0 or K % 128 != 0): - print(f"skip {N}, {K}") - return + pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") + + torch.set_printoptions(profile="full") vllm_config = VllmConfig() @@ -490,11 +499,12 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1_sa = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() w2_sa = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() + # TODO: move size alignment further up when setting up all shapes if w1_sa.shape != w1_s.shape or w2_sa.shape != w2_s.shape: p("w1_sa", w1_sa) p("w2_sa", w2_sa) - print(f"UNALIGNED") - return + print("UNALIGNED") + pytest.skip("UNALIGNED") w1_s = w1_sa w2_s = w2_sa diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 585d3de78f4..d390553d619 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1360,6 +1360,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) + print(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") + if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) From 49ec1c6319798eec78113b731a4839208fadbd07 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 1 Mar 2025 00:21:16 +0000 Subject: [PATCH 048/190] update deep gemm + small test case Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index ea96724f559..cc2d1d8673f 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -15,7 +15,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform -from vllm.utils import cdiv if current_platform.get_device_capability() < (9, 0): pytest.skip("FP8 Triton requires CUDA 9.0 or higher", @@ -224,12 +223,15 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.001 + def p(s, t): print(f"{s}: {t.shape}, {t.dtype}") + @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, + SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): torch.manual_seed(seed) @@ -399,7 +401,7 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, m_indices = torch.arange(0, num_groups, dtype=torch.int) m_indices = m_indices.unsqueeze(-1).expand( num_groups, max((topk * M) // num_groups, 1)).contiguous().view(-1) - #m_indices = torch.IntTensor([0, 1]) + #m_indices = torch.IntTensor([1, 0]).to(dtype=torch.int32, device=a.device) p("m_indices", m_indices) print(m_indices) @@ -442,7 +444,8 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product([2], [256], [512], [2], [1], [[128, 128]], DTYPES, SEEDS)) + itertools.product([2], [256], [512], [2], [1], [[128, 128]], DTYPES, + SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): @@ -485,8 +488,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype=torch.float32) print(f"NUM_GROUPS = {num_groups}") - p("before w1_s", w1_s) - p("before w2_s", w2_s) assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] @@ -494,8 +495,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - p("imm w1_s", w1_s) - w1_sa = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() w2_sa = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() @@ -511,7 +510,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, p("a", a) p("w1", w1) - p("final w1_s", w1_s) + #print(w1) + p("w1_s", w1_s) + #print(w1_s) p("w2", w2) p("w2_s", w2_s) @@ -549,7 +550,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) - #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") From ceba47638dfc330a7338c60dd7d6f757ed99a736 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 2 Mar 2025 20:28:35 +0000 Subject: [PATCH 049/190] wip Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 81 +++++++++++++++++++++++++-------- 1 file changed, 62 insertions(+), 19 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index cc2d1d8673f..cdb4b601a1c 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -295,10 +295,8 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.03 - ######################################################################################### - def per_token_cast_to_fp8( x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 and x.size(1) % 128 == 0 @@ -375,16 +373,50 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): # ref_out = torch.einsum('gmk,gnk->gmn', x, y) +def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): + """Fused moe with block-wise quantization using native torch.""" + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + _, block_k = block_shape[0], block_shape[1] + a_q, a_s = per_token_group_quant_fp8(a, block_k) + a_q = a_q.to(dtype=torch.float32) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + inter_out = torch.empty((a_q[mask].shape[0], w1[i].shape[0]), + device=a_q.device, dtype=torch.bfloat16) + deep_gemm.gemm_fp8_fp8_bf16_nt((a_q[mask].to(dtype=torch.float8_e4m3fn), a_s[mask]), + (w1[i], w1_s[i]), + inter_out) + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) + tmp_out = torch.empty((act_out.shape[0], w2[i].shape[0]), + device=a_q.device, dtype=torch.bfloat16) + deep_gemm.gemm_fp8_fp8_bf16_nt((act_out_q, act_out_s), + (w2[i], w2_s[i]), + tmp_out) + out[mask] = tmp_out + + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using DeepGemm torch.""" - M = a.numel() // a.shape[-1] - K = w1.shape[-1] num_groups = w1.shape[0] + M = a.numel() // a.shape[-1] # * num_groups) + M_sum = M # * num_groups + K = w1.shape[-1] a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - inter_out = torch.empty(a.shape[0], - w1.shape[1], + inter_out = torch.empty((M_sum, K), dtype=torch.bfloat16, device=a.device) score = torch.softmax(score, dim=-1, dtype=torch.float32) @@ -392,8 +424,15 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, topk_weight = topk_weight.view(-1) topk_ids = topk_ids.view(-1) + block_m = deep_gemm.get_m_alignment_for_contiguous_layout() + print(f"BLOCK_M {block_m}") + p("A", a) + _, block_k = block_shape[0], block_shape[1] - a_q, a_s = per_token_group_quant_fp8(a, block_k) + a_q, a_s = per_token_group_quant_fp8(a, block_m) + + p("A_q", a_q) + p("A_s", a_s) #assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) #print(f"FIRST GEMM {a_q.shape}") @@ -437,8 +476,8 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( (act_out_q, act_out_s), (w2, w2_s), out, topk_ids, M) - return (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + return (out.view(M_sum, -1, w2.shape[1]) * + topk_weight.view(M_sum, -1, 1).to(out.dtype)).sum(dim=1) @pytest.mark.parametrize( @@ -479,18 +518,22 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, n_tiles_w2 = (K + block_n - 1) // block_n k_tiles_w2 = (N + block_k - 1) // block_k - w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) - w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) + # TODO: turn these back to empty calls + w1 = torch.zeros_like(w1_bf16, dtype=torch.float8_e4m3fn) + w2 = torch.zeros_like(w2_bf16, dtype=torch.float8_e4m3fn) - w1_s = torch.empty((num_groups, n_tiles_w1, k_tiles_w1), + w1_s = torch.zeros((num_groups, n_tiles_w1, k_tiles_w1), dtype=torch.float32) - w2_s = torch.empty((num_groups, n_tiles_w2, k_tiles_w2), + w2_s = torch.zeros((num_groups, n_tiles_w2, k_tiles_w2), dtype=torch.float32) print(f"NUM_GROUPS = {num_groups}") assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] + + # TODO: fix later + print("For now, only convert the first group, the rest will be 0") for i in range(num_groups): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) @@ -517,7 +560,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, p("w2_s", w2_s) with set_current_vllm_config(vllm_config): - if False: + if True: out = fused_moe( a, w1, @@ -531,9 +574,12 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, - topk, block_size) + ref_out = deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, + topk, block_size) else: + out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, + topk, block_size) + ref_out = fused_moe( a, w1, @@ -547,9 +593,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) - out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, - topk, block_size) - #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") From 9ac041c346b4fb861170e1236c663909e54d3a05 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 2 Mar 2025 20:40:35 +0000 Subject: [PATCH 050/190] wip Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index cdb4b601a1c..d63bbd2e1bb 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -412,9 +412,11 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using DeepGemm torch.""" num_groups = w1.shape[0] - M = a.numel() // a.shape[-1] # * num_groups) - M_sum = M # * num_groups - K = w1.shape[-1] + M = a.shape[0] + M_sum = M * topk + N = w1.shape[1] // 2 + K = w1.shape[2] + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) inter_out = torch.empty((M_sum, K), dtype=torch.bfloat16, @@ -437,10 +439,15 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, #assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) #print(f"FIRST GEMM {a_q.shape}") - m_indices = torch.arange(0, num_groups, dtype=torch.int) - m_indices = m_indices.unsqueeze(-1).expand( - num_groups, max((topk * M) // num_groups, 1)).contiguous().view(-1) - #m_indices = torch.IntTensor([1, 0]).to(dtype=torch.int32, device=a.device) + # use topk_ids?? + if True: + m_indices = torch.arange(0, num_groups, dtype=torch.int) + m_indices = m_indices.unsqueeze(-1).expand( + num_groups, max(M_sum // num_groups, 1)).contiguous().view(-1) + #m_indices = torch.IntTensor([1, 0]).to(dtype=torch.int32, device=a.device) + else: + pass + p("m_indices", m_indices) print(m_indices) @@ -560,7 +567,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, p("w2_s", w2_s) with set_current_vllm_config(vllm_config): - if True: + if False: out = fused_moe( a, w1, From 696e6a29aed5824671110106975b550399f62f05 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 2 Mar 2025 22:52:51 +0000 Subject: [PATCH 051/190] problem with scores Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 43 ++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index d63bbd2e1bb..8f63f16f332 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -12,6 +12,7 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform @@ -414,9 +415,7 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, num_groups = w1.shape[0] M = a.shape[0] M_sum = M * topk - N = w1.shape[1] // 2 - K = w1.shape[2] - + K = w1.shape[2] # w2.shape[1] a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) inter_out = torch.empty((M_sum, K), dtype=torch.bfloat16, @@ -430,28 +429,31 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, print(f"BLOCK_M {block_m}") p("A", a) + row_size = max(M_sum // num_groups, 1) + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(topk_ids, row_size, num_groups, None) + ) + m_indices = expert_ids + assert m_indices.numel() == M_sum + print(f"num_tokens_post_padded = {num_tokens_post_padded}") + p("expert ids", expert_ids) + _, block_k = block_shape[0], block_shape[1] a_q, a_s = per_token_group_quant_fp8(a, block_m) - p("A_q", a_q) - p("A_s", a_s) - #assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) #print(f"FIRST GEMM {a_q.shape}") - # use topk_ids?? - if True: - m_indices = torch.arange(0, num_groups, dtype=torch.int) - m_indices = m_indices.unsqueeze(-1).expand( - num_groups, max(M_sum // num_groups, 1)).contiguous().view(-1) - #m_indices = torch.IntTensor([1, 0]).to(dtype=torch.int32, device=a.device) - else: - pass + # m_indices maps to expert_ids + #m_indices = torch.arange(0, num_groups, dtype=torch.int) + #m_indices = m_indices.unsqueeze(-1).expand( + # num_groups, row_size).contiguous().view(-1) p("m_indices", m_indices) print(m_indices) - print("topk", topk_ids) + print("topk_ids", topk_ids) print(topk_ids) print("topk_weight", topk_weight) print(topk_weight) @@ -483,8 +485,8 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( (act_out_q, act_out_s), (w2, w2_s), out, topk_ids, M) - return (out.view(M_sum, -1, w2.shape[1]) * - topk_weight.view(M_sum, -1, 1).to(out.dtype)).sum(dim=1) + return (out.view(M, -1, w2.shape[1]) * + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) @pytest.mark.parametrize( @@ -516,7 +518,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) - score = torch.randn((M, E), dtype=dtype) + #score = torch.randn((M, E), dtype=dtype) + score = torch.zeros((M, E), dtype=dtype) num_groups = E block_n, block_k = block_size[0], block_size[1] @@ -600,8 +603,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) - #print(f"{out.sum()=}") - #print(f"{ref_out.sum()=}") + print(f"{out.sum()=}") + print(f"{ref_out.sum()=}") rel_diff = (torch.mean( torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / From 383c3644de18321db07379c100110b5760776cb9 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 3 Mar 2025 19:39:40 +0000 Subject: [PATCH 052/190] some passing tests Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 81 +++++++++++-------- .../layers/fused_moe/fused_moe.py | 2 +- 2 files changed, 48 insertions(+), 35 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 8f63f16f332..2b625b838e8 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -39,13 +39,13 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. #M_moe = [1, 7, 83, 512, 2048] -M_moe = [1, 8, 84, 512, 2048] -N_moe = [4608] # [128, 4608, 13824] -K_moe = [7168] # [256, 7168, 13824] +M_moe = [1, 2, 8, 84, 512] #, 2048] +N_moe = [128, 256, 4608] # [128, 4608, 13824] +K_moe = [256, 512, 7168] # [256, 7168, 13824] BLOCK_SIZE = [[128, 128]] #E = [8, 24] # [8, 24, 128, 256] -E = [8, 16] # [8, 24, 128, 256] -TOP_KS = [2] # [1, 2, 6] +E = [2] #, 8] #, 16] # [8, 24, 128, 256] +TOP_KS = [1] # [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] @@ -227,7 +227,11 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): def p(s, t): print(f"{s}: {t.shape}, {t.dtype}") + pass +def pp(x): + print(x) + pass @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", @@ -413,11 +417,10 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using DeepGemm torch.""" num_groups = w1.shape[0] - M = a.shape[0] - M_sum = M * topk - K = w1.shape[2] # w2.shape[1] + M, K = a.shape + N = w2.shape[-1] a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - inter_out = torch.empty((M_sum, K), + inter_out = torch.empty((a.shape[0], w1[0].shape[0]), dtype=torch.bfloat16, device=a.device) score = torch.softmax(score, dim=-1, dtype=torch.float32) @@ -426,18 +429,18 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, topk_ids = topk_ids.view(-1) block_m = deep_gemm.get_m_alignment_for_contiguous_layout() - print(f"BLOCK_M {block_m}") + pp(f"BLOCK_M {block_m}") p("A", a) - row_size = max(M_sum // num_groups, 1) + row_size = max((topk * M) // num_groups, 1) # 2 *? sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(topk_ids, row_size, num_groups, None) + moe_align_block_size(topk_ids, M * topk, num_groups, None) ) m_indices = expert_ids - assert m_indices.numel() == M_sum - print(f"num_tokens_post_padded = {num_tokens_post_padded}") - p("expert ids", expert_ids) + #assert m_indices.numel() == num_groups * M * topk + #pp(f"num_tokens_post_padded = {num_tokens_post_padded}") + #p("expert ids", expert_ids) _, block_k = block_shape[0], block_shape[1] a_q, a_s = per_token_group_quant_fp8(a, block_m) @@ -446,17 +449,16 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, #print(f"FIRST GEMM {a_q.shape}") # m_indices maps to expert_ids - #m_indices = torch.arange(0, num_groups, dtype=torch.int) - #m_indices = m_indices.unsqueeze(-1).expand( - # num_groups, row_size).contiguous().view(-1) - + m_indices = torch.arange(0, M, dtype=torch.int) + m_indices = m_indices.unsqueeze(-1).expand(M, topk).contiguous().view(-1) p("m_indices", m_indices) - print(m_indices) + pp(m_indices) + p("topk_ids", topk_ids) + #pp(topk_ids) + p("topk_weight", topk_weight) + #pp(topk_weight) - print("topk_ids", topk_ids) - print(topk_ids) - print("topk_weight", topk_weight) - print(topk_weight) + pp("FIRST GEMM") if True: deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( @@ -466,12 +468,14 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a_q, a_s), (w1, w1_s), inter_out, topk_ids, M) - print(f"DG {inter_out.shape} {inter_out}") + pp("FIRST GEMM DONE") + + #pp(f"DG {inter_out.shape} {inter_out}") act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - #print("SECOND GEMM") + pp("SECOND GEMM") out = torch.empty(act_out.shape[0], w2.shape[1], @@ -485,15 +489,16 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( (act_out_q, act_out_s), (w2, w2_s), out, topk_ids, M) + pp("SECOND GEMM DONE") + return (out.view(M, -1, w2.shape[1]) * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product([2], [256], [512], [2], [1], [[128, 128]], DTYPES, - SEEDS)) + itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([2], [256], [512], [2], [1], [[128, 128]], DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): @@ -502,6 +507,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, if (N % 128 != 0 or K % 128 != 0): pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") + pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") + torch.set_printoptions(profile="full") vllm_config = VllmConfig() @@ -519,7 +526,15 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, fp8_max).clamp(min=fp8_min, max=fp8_max) #score = torch.randn((M, E), dtype=dtype) + if False: + score = torch.empty((M, E), dtype=dtype) + for i in range(M): + score[i] = torch.full((E,), 1.0/(i+1), dtype=dtype) + for i in range(score.numel()): + score.view(-1)[i] = 1.0/(i+1) score = torch.zeros((M, E), dtype=dtype) + p("score", score) + #pp(score) num_groups = E block_n, block_k = block_size[0], block_size[1] @@ -537,13 +552,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_s = torch.zeros((num_groups, n_tiles_w2, k_tiles_w2), dtype=torch.float32) - print(f"NUM_GROUPS = {num_groups}") - assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] # TODO: fix later - print("For now, only convert the first group, the rest will be 0") + pp("For now, only convert the first group, the rest will be 0") for i in range(num_groups): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) @@ -603,8 +616,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) - print(f"{out.sum()=}") - print(f"{ref_out.sum()=}") + #print(f"{out.sum()=}") + #print(f"{ref_out.sum()=}") rel_diff = (torch.mean( torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d390553d619..9cc8bbd68fd 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1360,7 +1360,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - print(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") + #print(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, From 4ff5a1a2ea16cc4a1194130a4493cdfeff604d50 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 3 Mar 2025 19:40:04 +0000 Subject: [PATCH 053/190] some passing tests Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 2b625b838e8..42709535fea 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -45,7 +45,7 @@ BLOCK_SIZE = [[128, 128]] #E = [8, 24] # [8, 24, 128, 256] E = [2] #, 8] #, 16] # [8, 24, 128, 256] -TOP_KS = [1] # [1, 2, 6] +TOP_KS = [2] # [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] From 6c49e2d615508ffca6f6f06b4f217210a521518a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 3 Mar 2025 19:46:48 +0000 Subject: [PATCH 054/190] topk > 1 doesn't work. prune oom-ing tests Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 42709535fea..308956678e1 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -45,7 +45,7 @@ BLOCK_SIZE = [[128, 128]] #E = [8, 24] # [8, 24, 128, 256] E = [2] #, 8] #, 16] # [8, 24, 128, 256] -TOP_KS = [2] # [1, 2, 6] +TOP_KS = [1] # [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] @@ -495,6 +495,7 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) +# topk > 1 does not work @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) From bada057c3f5153de429a71a97974fb6cacd5ba91 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 3 Mar 2025 20:07:51 +0000 Subject: [PATCH 055/190] fix indices Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 308956678e1..1c5f9c2ce64 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -449,8 +449,8 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, #print(f"FIRST GEMM {a_q.shape}") # m_indices maps to expert_ids - m_indices = torch.arange(0, M, dtype=torch.int) - m_indices = m_indices.unsqueeze(-1).expand(M, topk).contiguous().view(-1) + m_indices = torch.arange(0, topk, dtype=torch.int) + m_indices = m_indices.unsqueeze(-1).expand(topk, M).contiguous().view(-1) p("m_indices", m_indices) pp(m_indices) p("topk_ids", topk_ids) @@ -499,6 +499,7 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([512], [128], [256], [2], [1], [[128, 128]], DTYPES, SEEDS)) #itertools.product([2], [256], [512], [2], [1], [[128, 128]], DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, From 69788af518dbe7c1e98c386cd0a0d5f248bb608e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 3 Mar 2025 20:23:10 +0000 Subject: [PATCH 056/190] enable more tests Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 1c5f9c2ce64..6831ab139b1 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -38,14 +38,12 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. -#M_moe = [1, 7, 83, 512, 2048] -M_moe = [1, 2, 8, 84, 512] #, 2048] +M_moe = [1, 2, 7, 83, 512, 2048] N_moe = [128, 256, 4608] # [128, 4608, 13824] K_moe = [256, 512, 7168] # [256, 7168, 13824] BLOCK_SIZE = [[128, 128]] -#E = [8, 24] # [8, 24, 128, 256] -E = [2] #, 8] #, 16] # [8, 24, 128, 256] -TOP_KS = [1] # [1, 2, 6] +E = [2, 8, 16] # 24 # [8, 24, 128, 256] +TOP_KS = [1, 2] # [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] @@ -226,11 +224,11 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): def p(s, t): - print(f"{s}: {t.shape}, {t.dtype}") + #print(f"{s}: {t.shape}, {t.dtype}") pass def pp(x): - print(x) + #print(x) pass @pytest.mark.parametrize( @@ -505,9 +503,9 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): - # only aligned sizes - if (N % 128 != 0 or K % 128 != 0): - pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") + # only aligned sizes or supported topk + if (N % 128 != 0 or K % 128 != 0 or topk > 1): + pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}, {topk}") pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") From cf83822e5c41e1aac36313ea1b03121fd7c4cbfe Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 3 Mar 2025 20:37:22 +0000 Subject: [PATCH 057/190] format Signed-off-by: Bill Nell --- requirements/test.txt | 6 ++++ tests/kernels/test_block_fp8.py | 52 ++++++++++++++++++++------------- 2 files changed, 37 insertions(+), 21 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index 9a15d9a0d82..60b8faa0fa2 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -126,6 +126,10 @@ encodec==0.1.1 # via vocos evaluate==0.4.3 # via lm-eval +exceptiongroup==1.2.2 + # via + # anyio + # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -759,9 +763,11 @@ typing-extensions==4.12.2 # huggingface-hub # librosa # mistral-common + # multidict # pqdm # pydantic # pydantic-core + # rich # torch # typer tzdata==2024.2 diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 6831ab139b1..08f620789f7 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -42,7 +42,7 @@ N_moe = [128, 256, 4608] # [128, 4608, 13824] K_moe = [256, 512, 7168] # [256, 7168, 13824] BLOCK_SIZE = [[128, 128]] -E = [2, 8, 16] # 24 # [8, 24, 128, 256] +E = [2, 8, 16] # 24 # [8, 24, 128, 256] TOP_KS = [1, 2] # [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] @@ -227,10 +227,12 @@ def p(s, t): #print(f"{s}: {t.shape}, {t.dtype}") pass + def pp(x): #print(x) pass + @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, @@ -298,8 +300,10 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.03 + ######################################################################################### + def per_token_cast_to_fp8( x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 and x.size(1) % 128 == 0 @@ -376,11 +380,16 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): # ref_out = torch.einsum('gmk,gnk->gmn', x, y) -def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): + +def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_shape): """Fused moe with block-wise quantization using native torch.""" B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) + out = torch.zeros(B * topk, + w2.shape[1], + dtype=torch.bfloat16, + device=a.device) score = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight, topk_ids = torch.topk(score, topk) topk_weight = topk_weight.view(-1) @@ -393,24 +402,24 @@ def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, bloc mask = topk_ids == i if mask.sum(): inter_out = torch.empty((a_q[mask].shape[0], w1[i].shape[0]), - device=a_q.device, dtype=torch.bfloat16) - deep_gemm.gemm_fp8_fp8_bf16_nt((a_q[mask].to(dtype=torch.float8_e4m3fn), a_s[mask]), - (w1[i], w1_s[i]), - inter_out) + device=a_q.device, + dtype=torch.bfloat16) + deep_gemm.gemm_fp8_fp8_bf16_nt( + (a_q[mask].to(dtype=torch.float8_e4m3fn), a_s[mask]), + (w1[i], w1_s[i]), inter_out) act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) tmp_out = torch.empty((act_out.shape[0], w2[i].shape[0]), - device=a_q.device, dtype=torch.bfloat16) + device=a_q.device, + dtype=torch.bfloat16) deep_gemm.gemm_fp8_fp8_bf16_nt((act_out_q, act_out_s), - (w2[i], w2_s[i]), - tmp_out) + (w2[i], w2_s[i]), tmp_out) out[mask] = tmp_out return (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) - def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using DeepGemm torch.""" @@ -433,8 +442,7 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, row_size = max((topk * M) // num_groups, 1) # 2 *? sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(topk_ids, M * topk, num_groups, None) - ) + moe_align_block_size(topk_ids, M * topk, num_groups, None)) m_indices = expert_ids #assert m_indices.numel() == num_groups * M * topk #pp(f"num_tokens_post_padded = {num_tokens_post_padded}") @@ -496,9 +504,10 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, # topk > 1 does not work @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product([512], [128], [256], [2], [1], [[128, 128]], DTYPES, SEEDS)) - #itertools.product([2], [256], [512], [2], [1], [[128, 128]], DTYPES, SEEDS)) + itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, + SEEDS)) +#itertools.product([512], [128], [256], [2], [1], [[128, 128]], DTYPES, SEEDS)) +#itertools.product([2], [256], [512], [2], [1], [[128, 128]], DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): @@ -507,7 +516,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, if (N % 128 != 0 or K % 128 != 0 or topk > 1): pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}, {topk}") - pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") + pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}" + ) torch.set_printoptions(profile="full") @@ -529,9 +539,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, if False: score = torch.empty((M, E), dtype=dtype) for i in range(M): - score[i] = torch.full((E,), 1.0/(i+1), dtype=dtype) + score[i] = torch.full((E, ), 1.0 / (i + 1), dtype=dtype) for i in range(score.numel()): - score.view(-1)[i] = 1.0/(i+1) + score.view(-1)[i] = 1.0 / (i + 1) score = torch.zeros((M, E), dtype=dtype) p("score", score) #pp(score) @@ -597,8 +607,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) - ref_out = deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, - topk, block_size) + ref_out = deep_gemm_matmul_w8a8_block_fp8_moe( + a, w1, w2, w1_s, w2_s, score, topk, block_size) else: out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) From 3b57bf9a6011e0cb8ca6a82b6327e129b74c9ea8 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 4 Mar 2025 21:59:00 +0000 Subject: [PATCH 058/190] use fused_topk for unit test Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 159 +++++++++++------- .../layers/fused_moe/fused_moe.py | 2 +- 2 files changed, 103 insertions(+), 58 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 08f620789f7..05e4de3e3f7 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -12,7 +12,7 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size +from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size, fused_topk from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform @@ -38,12 +38,16 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. -M_moe = [1, 2, 7, 83, 512, 2048] +#M_moe = [1, 2, 7, 83] #, 512, 2048] +M_moe = [128, 512, 2048] N_moe = [128, 256, 4608] # [128, 4608, 13824] K_moe = [256, 512, 7168] # [256, 7168, 13824] +M_moe_small = [128, 512] +N_moe_small = [128, 256] +K_moe_small = [256, 512] BLOCK_SIZE = [[128, 128]] -E = [2, 8, 16] # 24 # [8, 24, 128, 256] -TOP_KS = [1, 2] # [1, 2, 6] +E = [2, 8] #, 16] # 24 # [8, 24, 128, 256] +TOP_KS = [1, 2, 6] # [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] @@ -224,7 +228,7 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): def p(s, t): - #print(f"{s}: {t.shape}, {t.dtype}") + print(f"{s}: {t.shape}, {t.dtype}") pass @@ -385,13 +389,18 @@ def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using native torch.""" B, D = a.shape + pre_a = a a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) out = torch.zeros(B * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) + if False: + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + else: + topk_weight, topk_ids = fused_topk(pre_a, score.float(), topk, False) + del pre_a topk_weight = topk_weight.view(-1) topk_ids = topk_ids.view(-1) @@ -420,18 +429,25 @@ def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) -def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, +def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using DeepGemm torch.""" num_groups = w1.shape[0] M, K = a.shape N = w2.shape[-1] + pre_a = a a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + inter_out = torch.empty((a.shape[0], w1[0].shape[0]), dtype=torch.bfloat16, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) + + if True: + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + else: + topk_weight, topk_ids = fused_topk(pre_a, score.float(), topk, False) + del pre_a topk_weight = topk_weight.view(-1) topk_ids = topk_ids.view(-1) @@ -439,26 +455,39 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, pp(f"BLOCK_M {block_m}") p("A", a) - row_size = max((topk * M) // num_groups, 1) # 2 *? - - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(topk_ids, M * topk, num_groups, None)) - m_indices = expert_ids - #assert m_indices.numel() == num_groups * M * topk - #pp(f"num_tokens_post_padded = {num_tokens_post_padded}") - #p("expert ids", expert_ids) - _, block_k = block_shape[0], block_shape[1] a_q, a_s = per_token_group_quant_fp8(a, block_m) #assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) #print(f"FIRST GEMM {a_q.shape}") - # m_indices maps to expert_ids - m_indices = torch.arange(0, topk, dtype=torch.int) - m_indices = m_indices.unsqueeze(-1).expand(topk, M).contiguous().view(-1) + if False: + m_indices = torch.arange(0, M * topk, dtype=torch.int) + #m_indices = m_indices.unsqueeze(-1).expand(M, topk).contiguous().view(-1) + m_indices = m_indices.unsqueeze(-1).contiguous().view(-1) + else: + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(topk_ids, 1, M, None)) + #sorted_token_ids, _ = torch.sort(sorted_token_ids, 0, descending=False) + #sorted_token_ids = sorted_token_ids.unsqueeze(-1).expand(num_groups, M).contiguous().view(-1) + # ??? + #sorted_token_ids = sorted_token_ids.unsqueeze(-1).expand(M, topk).contiguous().view(-1) + p("SORTED", sorted_token_ids) + pp(sorted_token_ids) + print(sorted_token_ids) + pp(f"mask = {sorted_token_ids == M}") + #sorted_token_ids[sorted_token_ids == 2*M] = -1 + pp(sorted_token_ids) + print(f"max = {torch.max(sorted_token_ids)}, M={M}, topk={topk}") + assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 + m_indices = sorted_token_ids + #assert m_indices.numel() == num_groups * M * topk + #pp(f"num_tokens_post_padded = {num_tokens_post_padded}") + #p("expert ids", expert_ids) + p("m_indices", m_indices) - pp(m_indices) + #print(f"m_indices {m_indices.shape} {sorted_token_ids.shape}") + #pp(m_indices) p("topk_ids", topk_ids) #pp(topk_ids) p("topk_weight", topk_weight) @@ -476,11 +505,13 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, pp("FIRST GEMM DONE") - #pp(f"DG {inter_out.shape} {inter_out}") + pp(f"DG {inter_out.shape} {inter_out}") act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) + p("act_out", act_out) + pp("SECOND GEMM") out = torch.empty(act_out.shape[0], @@ -501,23 +532,36 @@ def deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) +def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: + dimensions = [] + + for index, _ in enumerate(shape): + if index != dim: + dimension = 1 + else: + dimension = shape[index] + + dimensions = [*dimensions, dimension] + + return torch.arange(shape[dim], **kwargs).view(*dimensions).expand(*shape) + + # topk > 1 does not work @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, - SEEDS)) -#itertools.product([512], [128], [256], [2], [1], [[128, 128]], DTYPES, SEEDS)) -#itertools.product([2], [256], [512], [2], [1], [[128, 128]], DTYPES, SEEDS)) + #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product(M_moe_small, N_moe_small, K_moe_small, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([512], [128], [256], [2], [1], [[128, 128]], DTYPES, SEEDS)) + #itertools.product([1], [128], [256], [3], [3], [[128, 128]], DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): # only aligned sizes or supported topk - if (N % 128 != 0 or K % 128 != 0 or topk > 1): + if (M % 128 != 0 or N % 128 != 0 or K % 128 != 0 or topk == 1 or topk > E): pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}, {topk}") - pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}" - ) + print(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") torch.set_printoptions(profile="full") @@ -535,39 +579,39 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) - #score = torch.randn((M, E), dtype=dtype) - if False: - score = torch.empty((M, E), dtype=dtype) - for i in range(M): - score[i] = torch.full((E, ), 1.0 / (i + 1), dtype=dtype) - for i in range(score.numel()): - score.view(-1)[i] = 1.0 / (i + 1) - score = torch.zeros((M, E), dtype=dtype) + #score = torch.randn((M, E), dtype=dtype) # does not work + #score = torch.ones((M, E), dtype=dtype) # works + #score = torch.zeros((M, E), dtype=dtype) # works + #score = torch.full((M, E), 0.5, dtype=dtype) # works + #score = torch.empty((M, E), dtype=dtype) + #for i in range(M): # works + # score[i] = torch.full((E, ), 1.0 / (i + 1), dtype=dtype) + #score = torch.empty((M, E), dtype=dtype) + #for i in range(score.numel()): # works + # score.view(-1)[i] = 1.0 / (i + 1) + score = iota((M, E), dtype=dtype) p("score", score) #pp(score) - num_groups = E block_n, block_k = block_size[0], block_size[1] n_tiles_w1 = ((2 * N) + block_n - 1) // block_n k_tiles_w1 = (K + block_k - 1) // block_k n_tiles_w2 = (K + block_n - 1) // block_n k_tiles_w2 = (N + block_k - 1) // block_k - # TODO: turn these back to empty calls - w1 = torch.zeros_like(w1_bf16, dtype=torch.float8_e4m3fn) - w2 = torch.zeros_like(w2_bf16, dtype=torch.float8_e4m3fn) + # TODO: change these to zeros to test out groups + w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) + w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) - w1_s = torch.zeros((num_groups, n_tiles_w1, k_tiles_w1), - dtype=torch.float32) - w2_s = torch.zeros((num_groups, n_tiles_w2, k_tiles_w2), - dtype=torch.float32) + w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) + w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) - assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) + assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128) assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] # TODO: fix later - pp("For now, only convert the first group, the rest will be 0") - for i in range(num_groups): + #pp("For now, only convert the first group, the rest will be 0") + for i in range(E): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) @@ -595,10 +639,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, with set_current_vllm_config(vllm_config): if False: out = fused_moe( - a, + a, #hidden w1, w2, - score, + score, #gating topk, renormalize=False, use_fp8_w8a8=True, @@ -610,14 +654,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, ref_out = deep_gemm_matmul_w8a8_block_fp8_moe( a, w1, w2, w1_s, w2_s, score, topk, block_size) else: - out = deep_gemm_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, - topk, block_size) - ref_out = fused_moe( - a, + a, #hidden w1, w2, - score, + score, #gating topk, renormalize=False, use_fp8_w8a8=True, @@ -626,6 +667,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) + out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, + topk, block_size) + + #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 9cc8bbd68fd..e69a93d9fa5 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1451,7 +1451,7 @@ def fused_moe( MoE layer. - num_expert_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk - - use_grouped_topk: If True, use grouped_topk instead of fused_topk + - use_grouped_topk: If True, use grouped_topk instead of fused_top note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. From 4b312179109455a11dd78f5defbfb2391ba28de1 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 5 Mar 2025 04:18:32 +0000 Subject: [PATCH 059/190] every other block correct Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 80 +++++++++++++------ .../layers/fused_moe/fused_moe.py | 10 ++- 2 files changed, 64 insertions(+), 26 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 05e4de3e3f7..98eef3475db 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -228,12 +228,12 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): def p(s, t): - print(f"{s}: {t.shape}, {t.dtype}") + print(f"{s}: {t.shape}, {t.dtype}\n{t}") pass def pp(x): - #print(x) + print(x) pass @@ -436,35 +436,49 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, M, K = a.shape N = w2.shape[-1] pre_a = a - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + # to try: turn into 3d view here, do not flatten until after quantization + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) # orig + p("A'", a) + print(a) - inter_out = torch.empty((a.shape[0], w1[0].shape[0]), - dtype=torch.bfloat16, - device=a.device) - - if True: - score = torch.softmax(score, dim=-1, dtype=torch.float32) + if False: + scpore = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight, topk_ids = torch.topk(score, topk) + topk_ids, w_sort = topk_ids.sort() + topk_weight = torch.gather(topk_weight, dim=1, index=w_sort) else: topk_weight, topk_ids = fused_topk(pre_a, score.float(), topk, False) - del pre_a - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) + #del pre_a + + # pre_a.shape[0] * topk_ids.shape[1] + inter_out = torch.empty((pre_a.shape[0] * topk, w1[0].shape[0]), + dtype=torch.bfloat16, + device=a.device) block_m = deep_gemm.get_m_alignment_for_contiguous_layout() - pp(f"BLOCK_M {block_m}") - p("A", a) + pp(f"M {M}, BLOCK_M {block_m}") + #p("A", a) _, block_k = block_shape[0], block_shape[1] a_q, a_s = per_token_group_quant_fp8(a, block_m) + #a_q, a_s = per_token_cast_to_fp8(a) + + #a_q = a_q.view(a_q.shape[0], -1, a_q.shape[1]).repeat(topk, 1, 1).reshape(-1, a_q.shape[1]) + #a_s = a_s.view(a_s.shape[0], -1, a_s.shape[1]).repeat(topk, 1, 1).reshape(-1, a_s.shape[1]) #assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) #print(f"FIRST GEMM {a_q.shape}") - if False: - m_indices = torch.arange(0, M * topk, dtype=torch.int) - #m_indices = m_indices.unsqueeze(-1).expand(M, topk).contiguous().view(-1) - m_indices = m_indices.unsqueeze(-1).contiguous().view(-1) + if True: + m_indices = torch.arange(0, topk, dtype=torch.int) + m_indices = m_indices.unsqueeze(-1).expand(topk, M).contiguous().view(-1) + #m_indices = m_indices.unsqueeze(-1).contiguous().view(-1) + elif True: + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, 1, num_groups, None) + #assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 + m_indices = sorted_token_ids + p("SORTED", m_indices) + print(m_indices) else: sorted_token_ids, expert_ids, num_tokens_post_padded = ( moe_align_block_size(topk_ids, 1, M, None)) @@ -485,6 +499,9 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, #pp(f"num_tokens_post_padded = {num_tokens_post_padded}") #p("expert ids", expert_ids) + # must happen after align block size + #topk_weight = topk_weight.view(-1) + p("m_indices", m_indices) #print(f"m_indices {m_indices.shape} {sorted_token_ids.shape}") #pp(m_indices) @@ -494,6 +511,12 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, #pp(topk_weight) pp("FIRST GEMM") + pp(f"E = {num_groups}") + p("A", a_q) + p("A_s", a_s) + p("B", w1) + p("B_s", w1_s) + p("m_indices", m_indices) if True: deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( @@ -503,22 +526,28 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a_q, a_s), (w1, w1_s), inter_out, topk_ids, M) + p("out", inter_out) pp("FIRST GEMM DONE") - pp(f"DG {inter_out.shape} {inter_out}") + #pp(f"DG {inter_out.shape} {inter_out}") act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - p("act_out", act_out) - - pp("SECOND GEMM") - out = torch.empty(act_out.shape[0], w2.shape[1], dtype=torch.bfloat16, device=a.device) + pp("SECOND GEMM") + pp(f"E = {num_groups}") + p("A", act_out) + p("A_s", act_out_s) + p("B", w2) + p("B_s", w2_s) + p("topk_weights", topk_weight) + p("m_indices", m_indices) + if True: deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) @@ -526,6 +555,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( (act_out_q, act_out_s), (w2, w2_s), out, topk_ids, M) + p("out", out) pp("SECOND GEMM DONE") return (out.view(M, -1, w2.shape[1]) * @@ -550,9 +580,9 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product(M_moe_small, N_moe_small, K_moe_small, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product(M_moe_small, N_moe_small, K_moe_small, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) #itertools.product([512], [128], [256], [2], [1], [[128, 128]], DTYPES, SEEDS)) - #itertools.product([1], [128], [256], [3], [3], [[128, 128]], DTYPES, SEEDS)) + itertools.product([128], [128], [256], [2], [2], [[128, 128]], DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index e69a93d9fa5..84eb8720231 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -633,6 +633,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor, **config, ) + p("fused_out", C) + print(f"END {'SECOND' if mul_routed_weight else 'FIRST'} FUSED_GEMM") + # Adapted from: https://github.com/sgl-project/sglang/pull/2628 def get_config_file_name(E: int, @@ -1304,6 +1307,9 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) + print(f"NUM CHUNKS = {(num_tokens // CHUNK_SIZE) + 1}") + print(f"FUSED A {hidden_states.shape}, {hidden_states}") + for chunk in range((num_tokens // CHUNK_SIZE) + 1): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, @@ -1360,7 +1366,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - #print(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") + print(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, @@ -1483,6 +1489,8 @@ def fused_moe( - torch.Tensor: The output tensor after applying the MoE layer. """ + print(f"FUSED SCORES {hidden_states.shape} {gating_output.shape}") + if use_grouped_topk: assert num_expert_group is not None and topk_group is not None topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, From 396cfa0a276123c0c32371e6b66c315d81777554 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 5 Mar 2025 21:14:46 +0000 Subject: [PATCH 060/190] working Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 50 ++++++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 10 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 98eef3475db..3af652dae6a 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -229,6 +229,7 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): def p(s, t): print(f"{s}: {t.shape}, {t.dtype}\n{t}") + #print(f"{s}: {t.shape}, {t.dtype}\n{t.flatten()}") pass @@ -429,6 +430,12 @@ def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) +# repeat_interleaved. +# shuffle input by token ids +# unshuffle output by argsorted token ids +# argsort token ids + + def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using DeepGemm torch.""" @@ -437,9 +444,10 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, N = w2.shape[-1] pre_a = a # to try: turn into 3d view here, do not flatten until after quantization - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) # orig + #a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) # orig + #a = a.view(M, -1, K).repeat_interleave(topk, dim=0).reshape(-1, K) # orig p("A'", a) - print(a) + #print(a) if False: scpore = torch.softmax(score, dim=-1, dtype=torch.float32) @@ -460,25 +468,26 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, #p("A", a) _, block_k = block_shape[0], block_shape[1] - a_q, a_s = per_token_group_quant_fp8(a, block_m) - #a_q, a_s = per_token_cast_to_fp8(a) + #a_q, a_s = per_token_group_quant_fp8(a, block_m) + #a_q, a_s = per_token_cast_to_fp8(a) #a_q = a_q.view(a_q.shape[0], -1, a_q.shape[1]).repeat(topk, 1, 1).reshape(-1, a_q.shape[1]) #a_s = a_s.view(a_s.shape[0], -1, a_s.shape[1]).repeat(topk, 1, 1).reshape(-1, a_s.shape[1]) + #p("A_q", a_q) + #assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) #print(f"FIRST GEMM {a_q.shape}") - if True: + if False: m_indices = torch.arange(0, topk, dtype=torch.int) m_indices = m_indices.unsqueeze(-1).expand(topk, M).contiguous().view(-1) #m_indices = m_indices.unsqueeze(-1).contiguous().view(-1) elif True: - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, 1, num_groups, None) + sorted_token_ids, expert_ids, _ = moe_align_block_size(topk_ids, 1, num_groups, None) #assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 m_indices = sorted_token_ids - p("SORTED", m_indices) - print(m_indices) + p("SORTED", sorted_token_ids) else: sorted_token_ids, expert_ids, num_tokens_post_padded = ( moe_align_block_size(topk_ids, 1, M, None)) @@ -499,6 +508,25 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, #pp(f"num_tokens_post_padded = {num_tokens_post_padded}") #p("expert ids", expert_ids) + #a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) # orig + + a_q, a_s = per_token_group_quant_fp8(a, block_m) + p("a_s_0", a_s) + + a_q = a_q.view(a_q.shape[0], -1, a_q.shape[1]).repeat(1, topk, 1).reshape(-1, a_q.shape[1]) # orig + a_s = a_s.view(a_s.shape[0], -1, a_s.shape[1]).repeat(1, topk, 1).reshape(-1, a_s.shape[1]) # orig + + print(f"max = {topk*M}") + # gather? + a_q = a_q.view(dtype=torch.uint8)[sorted_token_ids, ...].view(dtype=torch.float8_e4m3fn) + a_s = a_s[sorted_token_ids] + #a_s = torch.gather(a_s, dim=0, index=sorted_token_ids.clamp((topk*M)-1).view(-1, 1).to(dtype=torch.int64)) + + m_indices = expert_ids # torch.repeat_interleave(expert_ids, topk, dim=0) + + p("a_q_s", a_q) + p("a_s_s", a_s) + # must happen after align block size #topk_weight = topk_weight.view(-1) @@ -526,7 +554,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a_q, a_s), (w1, w1_s), inter_out, topk_ids, M) - p("out", inter_out) + p("inter_out", inter_out) pp("FIRST GEMM DONE") #pp(f"DG {inter_out.shape} {inter_out}") @@ -558,7 +586,9 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, p("out", out) pp("SECOND GEMM DONE") - return (out.view(M, -1, w2.shape[1]) * + inv_perm = torch.argsort(sorted_token_ids) + + return (out[inv_perm].view(M, -1, w2.shape[1]) * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) From 25945cc6c687f1beae90f8dfd7b83bd788f9ec5b Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 5 Mar 2025 21:46:41 +0000 Subject: [PATCH 061/190] enable more tests Signed-off-by: Bill Nell --- requirements/test.txt | 6 ----- tests/kernels/test_block_fp8.py | 18 ++++++------- .../layers/fused_moe/fused_moe.py | 26 +++++++++++++------ 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index 60b8faa0fa2..9a15d9a0d82 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -126,10 +126,6 @@ encodec==0.1.1 # via vocos evaluate==0.4.3 # via lm-eval -exceptiongroup==1.2.2 - # via - # anyio - # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -763,11 +759,9 @@ typing-extensions==4.12.2 # huggingface-hub # librosa # mistral-common - # multidict # pqdm # pydantic # pydantic-core - # rich # torch # typer tzdata==2024.2 diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 3af652dae6a..df63fd52073 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -42,9 +42,9 @@ M_moe = [128, 512, 2048] N_moe = [128, 256, 4608] # [128, 4608, 13824] K_moe = [256, 512, 7168] # [256, 7168, 13824] -M_moe_small = [128, 512] -N_moe_small = [128, 256] -K_moe_small = [256, 512] +M_moe_small = [128, 512, 2048] +N_moe_small = [128, 256, 4608] +K_moe_small = [256, 512, 7168] BLOCK_SIZE = [[128, 128]] E = [2, 8] #, 16] # 24 # [8, 24, 128, 256] TOP_KS = [1, 2, 6] # [1, 2, 6] @@ -228,13 +228,13 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): def p(s, t): - print(f"{s}: {t.shape}, {t.dtype}\n{t}") + #print(f"{s}: {t.shape}, {t.dtype}\n{t}") #print(f"{s}: {t.shape}, {t.dtype}\n{t.flatten()}") pass def pp(x): - print(x) + #print(x) pass @@ -516,7 +516,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, a_q = a_q.view(a_q.shape[0], -1, a_q.shape[1]).repeat(1, topk, 1).reshape(-1, a_q.shape[1]) # orig a_s = a_s.view(a_s.shape[0], -1, a_s.shape[1]).repeat(1, topk, 1).reshape(-1, a_s.shape[1]) # orig - print(f"max = {topk*M}") + pp(f"max = {topk*M}") # gather? a_q = a_q.view(dtype=torch.uint8)[sorted_token_ids, ...].view(dtype=torch.float8_e4m3fn) a_s = a_s[sorted_token_ids] @@ -610,9 +610,9 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product(M_moe_small, N_moe_small, K_moe_small, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product(M_moe_small, N_moe_small, K_moe_small, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) #itertools.product([512], [128], [256], [2], [1], [[128, 128]], DTYPES, SEEDS)) - itertools.product([128], [128], [256], [2], [2], [[128, 128]], DTYPES, SEEDS)) + #itertools.product([128], [128], [256], [2], [2], [[128, 128]], DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): @@ -621,7 +621,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, if (M % 128 != 0 or N % 128 != 0 or K % 128 != 0 or topk == 1 or topk > E): pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}, {topk}") - print(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") + pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") torch.set_printoptions(profile="full") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 84eb8720231..1b97fcfadf1 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -28,6 +28,17 @@ logger = init_logger(__name__) +def p(s, t): + #print(f"{s}: {t.shape}, {t.dtype}\n{t.flatten()}") + #print(f"{s}: {t.shape}, {t.dtype}\n{t}") + pass + + +def pp(x): + #print(x) + pass + + @triton.jit def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, @@ -502,7 +513,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, M = A.shape[0] num_tokens = M * top_k - # EM = num_groups EM = sorted_token_ids.shape[0] if A.shape[0] < config["BLOCK_SIZE_M"]: # optimize for small batch_size. @@ -634,7 +644,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ) p("fused_out", C) - print(f"END {'SECOND' if mul_routed_weight else 'FIRST'} FUSED_GEMM") + pp(f"END {'SECOND' if mul_routed_weight else 'FIRST'} FUSED_GEMM") # Adapted from: https://github.com/sgl-project/sglang/pull/2628 @@ -1232,7 +1242,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None): + block_shape: Optional[List[int]] = None) -> torch.Tensor: # Check constraints. if use_int4_w4a16: assert hidden_states.shape[1] // 2 == w1.shape[ @@ -1307,8 +1317,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) - print(f"NUM CHUNKS = {(num_tokens // CHUNK_SIZE) + 1}") - print(f"FUSED A {hidden_states.shape}, {hidden_states}") + pp(f"NUM CHUNKS = {(num_tokens // CHUNK_SIZE) + 1}") + pp(f"FUSED A {hidden_states.shape}, {hidden_states}") for chunk in range((num_tokens // CHUNK_SIZE) + 1): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, @@ -1366,7 +1376,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - print(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") + pp(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, @@ -1457,7 +1467,7 @@ def fused_moe( MoE layer. - num_expert_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk - - use_grouped_topk: If True, use grouped_topk instead of fused_top + - use_grouped_topk: If True, use grouped_topk instead of fused_topk note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. @@ -1489,7 +1499,7 @@ def fused_moe( - torch.Tensor: The output tensor after applying the MoE layer. """ - print(f"FUSED SCORES {hidden_states.shape} {gating_output.shape}") + pp(f"FUSED SCORES {hidden_states.shape} {gating_output.shape}") if use_grouped_topk: assert num_expert_group is not None and topk_group is not None From 4c3f4915bcf5d55a5ea7d6aaf3821ca5d7b0f549 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 5 Mar 2025 22:04:44 +0000 Subject: [PATCH 062/190] working tests w/permute Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 8 ++++---- vllm/model_executor/layers/fused_moe/fused_moe.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index df63fd52073..26b455ad146 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -609,8 +609,8 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: # topk > 1 does not work @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product(M_moe_small, N_moe_small, K_moe_small, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product(M_moe_small, N_moe_small, K_moe_small, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) #itertools.product([512], [128], [256], [2], [1], [[128, 128]], DTYPES, SEEDS)) #itertools.product([128], [128], [256], [2], [2], [[128, 128]], DTYPES, SEEDS)) @torch.inference_mode() @@ -618,8 +618,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): # only aligned sizes or supported topk - if (M % 128 != 0 or N % 128 != 0 or K % 128 != 0 or topk == 1 or topk > E): - pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}, {topk}") + if (M % 128 != 0 or N % 128 != 0 or K % 128 != 0): + pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1b97fcfadf1..19452626e49 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1318,7 +1318,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, out_hidden_states = torch.empty_like(hidden_states) pp(f"NUM CHUNKS = {(num_tokens // CHUNK_SIZE) + 1}") - pp(f"FUSED A {hidden_states.shape}, {hidden_states}") + #pp(f"FUSED A {hidden_states.shape}, {hidden_states}") for chunk in range((num_tokens // CHUNK_SIZE) + 1): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, @@ -1376,7 +1376,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - pp(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") + #pp(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, @@ -1499,7 +1499,7 @@ def fused_moe( - torch.Tensor: The output tensor after applying the MoE layer. """ - pp(f"FUSED SCORES {hidden_states.shape} {gating_output.shape}") + #pp(f"FUSED SCORES {hidden_states.shape} {gating_output.shape}") if use_grouped_topk: assert num_expert_group is not None and topk_group is not None From 16e7d17ff4329adde742467585bf3e3b770c26fc Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 5 Mar 2025 22:18:10 +0000 Subject: [PATCH 063/190] cleanups Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 257 +++--------------- .../layers/fused_moe/fused_moe.py | 21 -- 2 files changed, 43 insertions(+), 235 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 26b455ad146..a10c7cc905c 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -12,7 +12,8 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size, fused_topk +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform @@ -26,28 +27,17 @@ NUM_TOKENS = [7, 83, 2048] D = [512, 4096, 5120, 13824] GROUP_SIZE = [64, 128, 256, 512] -#M = [1, 7, 83, 512, 2048] - -M = [1, 8, 84, 512, 2048, 4096] +M = [1, 7, 8, 83, 84, 512, 2048, 4096] N = [128, 512, 1024, 4096, 7748, 13824, 7168] K = [256, 4096, 5120, 3884, 13824, 16384] - -#M = [128] -#N = [24576] -#K = [1536] - # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. -#M_moe = [1, 2, 7, 83] #, 512, 2048] -M_moe = [128, 512, 2048] +M_moe = [1, 2, 7, 83, 128, 512, 2048] N_moe = [128, 256, 4608] # [128, 4608, 13824] K_moe = [256, 512, 7168] # [256, 7168, 13824] -M_moe_small = [128, 512, 2048] -N_moe_small = [128, 256, 4608] -K_moe_small = [256, 512, 7168] BLOCK_SIZE = [[128, 128]] -E = [2, 8] #, 16] # 24 # [8, 24, 128, 256] -TOP_KS = [1, 2, 6] # [1, 2, 6] +E = [2, 8, 16, 24] +TOP_KS = [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] @@ -227,17 +217,6 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 -def p(s, t): - #print(f"{s}: {t.shape}, {t.dtype}\n{t}") - #print(f"{s}: {t.shape}, {t.dtype}\n{t.flatten()}") - pass - - -def pp(x): - #print(x) - pass - - @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, @@ -275,12 +254,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) - p("a", a) - p("w1", w1) - p("w1_s", w1_s) - p("w2", w2) - p("w2_s", w2_s) - with set_current_vllm_config(vllm_config): out = fused_moe( a, @@ -306,19 +279,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): assert rel_diff < 0.03 -######################################################################################### - - -def per_token_cast_to_fp8( - x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 and x.size(1) % 128 == 0 - m, n = x.shape - x_view = x.view(m, -1, 128) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - return (x_view * (448.0 / x_amax.unsqueeze(2))).to( - torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) - - def per_block_cast_to_fp8( x: torch.Tensor, block_size_n: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: @@ -381,29 +341,19 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 -################################################################################### - -# ref_out = torch.einsum('gmk,gnk->gmn', x, y) - - def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): - """Fused moe with block-wise quantization using native torch.""" + """Fused moe with block-wise quantization using DeepGemm.""" + topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + B, D = a.shape - pre_a = a a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) out = torch.zeros(B * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device) - if False: - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - else: - topk_weight, topk_ids = fused_topk(pre_a, score.float(), topk, False) - del pre_a - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) _, block_k = block_shape[0], block_shape[1] a_q, a_s = per_token_group_quant_fp8(a, block_k) @@ -430,134 +380,45 @@ def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) -# repeat_interleaved. -# shuffle input by token ids -# unshuffle output by argsorted token ids -# argsort token ids - - def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using DeepGemm torch.""" num_groups = w1.shape[0] M, K = a.shape N = w2.shape[-1] - pre_a = a - # to try: turn into 3d view here, do not flatten until after quantization - #a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) # orig - #a = a.view(M, -1, K).repeat_interleave(topk, dim=0).reshape(-1, K) # orig - p("A'", a) - #print(a) - - if False: - scpore = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - topk_ids, w_sort = topk_ids.sort() - topk_weight = torch.gather(topk_weight, dim=1, index=w_sort) - else: - topk_weight, topk_ids = fused_topk(pre_a, score.float(), topk, False) - #del pre_a - - # pre_a.shape[0] * topk_ids.shape[1] - inter_out = torch.empty((pre_a.shape[0] * topk, w1[0].shape[0]), + + topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) + + inter_out = torch.empty((M * topk, w1[0].shape[0]), dtype=torch.bfloat16, device=a.device) block_m = deep_gemm.get_m_alignment_for_contiguous_layout() - pp(f"M {M}, BLOCK_M {block_m}") - #p("A", a) _, block_k = block_shape[0], block_shape[1] - #a_q, a_s = per_token_group_quant_fp8(a, block_m) - #a_q, a_s = per_token_cast_to_fp8(a) - #a_q = a_q.view(a_q.shape[0], -1, a_q.shape[1]).repeat(topk, 1, 1).reshape(-1, a_q.shape[1]) - #a_s = a_s.view(a_s.shape[0], -1, a_s.shape[1]).repeat(topk, 1, 1).reshape(-1, a_s.shape[1]) - - #p("A_q", a_q) - - #assert w1_s.shape == (num_groups, (2 * N + 127) // 128, (K + 127) // 128) - #print(f"FIRST GEMM {a_q.shape}") - - if False: - m_indices = torch.arange(0, topk, dtype=torch.int) - m_indices = m_indices.unsqueeze(-1).expand(topk, M).contiguous().view(-1) - #m_indices = m_indices.unsqueeze(-1).contiguous().view(-1) - elif True: - sorted_token_ids, expert_ids, _ = moe_align_block_size(topk_ids, 1, num_groups, None) - #assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 - m_indices = sorted_token_ids - p("SORTED", sorted_token_ids) - else: - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(topk_ids, 1, M, None)) - #sorted_token_ids, _ = torch.sort(sorted_token_ids, 0, descending=False) - #sorted_token_ids = sorted_token_ids.unsqueeze(-1).expand(num_groups, M).contiguous().view(-1) - # ??? - #sorted_token_ids = sorted_token_ids.unsqueeze(-1).expand(M, topk).contiguous().view(-1) - p("SORTED", sorted_token_ids) - pp(sorted_token_ids) - print(sorted_token_ids) - pp(f"mask = {sorted_token_ids == M}") - #sorted_token_ids[sorted_token_ids == 2*M] = -1 - pp(sorted_token_ids) - print(f"max = {torch.max(sorted_token_ids)}, M={M}, topk={topk}") - assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 - m_indices = sorted_token_ids - #assert m_indices.numel() == num_groups * M * topk - #pp(f"num_tokens_post_padded = {num_tokens_post_padded}") - #p("expert ids", expert_ids) - - #a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) # orig + sorted_token_ids, expert_ids, _ = moe_align_block_size( + topk_ids, 1, num_groups, None) + #assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 + m_indices = sorted_token_ids a_q, a_s = per_token_group_quant_fp8(a, block_m) - p("a_s_0", a_s) - a_q = a_q.view(a_q.shape[0], -1, a_q.shape[1]).repeat(1, topk, 1).reshape(-1, a_q.shape[1]) # orig - a_s = a_s.view(a_s.shape[0], -1, a_s.shape[1]).repeat(1, topk, 1).reshape(-1, a_s.shape[1]) # orig + a_q = a_q.view(a_q.shape[0], -1, + a_q.shape[1]).repeat(1, topk, + 1).reshape(-1, a_q.shape[1]) # orig + a_s = a_s.view(a_s.shape[0], -1, + a_s.shape[1]).repeat(1, topk, + 1).reshape(-1, a_s.shape[1]) # orig - pp(f"max = {topk*M}") - # gather? - a_q = a_q.view(dtype=torch.uint8)[sorted_token_ids, ...].view(dtype=torch.float8_e4m3fn) + a_q = a_q.view(dtype=torch.uint8)[sorted_token_ids, + ...].view(dtype=torch.float8_e4m3fn) a_s = a_s[sorted_token_ids] - #a_s = torch.gather(a_s, dim=0, index=sorted_token_ids.clamp((topk*M)-1).view(-1, 1).to(dtype=torch.int64)) - - m_indices = expert_ids # torch.repeat_interleave(expert_ids, topk, dim=0) - p("a_q_s", a_q) - p("a_s_s", a_s) + m_indices = expert_ids # torch.repeat_interleave(expert_ids, topk, dim=0) - # must happen after align block size - #topk_weight = topk_weight.view(-1) - - p("m_indices", m_indices) - #print(f"m_indices {m_indices.shape} {sorted_token_ids.shape}") - #pp(m_indices) - p("topk_ids", topk_ids) - #pp(topk_ids) - p("topk_weight", topk_weight) - #pp(topk_weight) - - pp("FIRST GEMM") - pp(f"E = {num_groups}") - p("A", a_q) - p("A_s", a_s) - p("B", w1) - p("B_s", w1_s) - p("m_indices", m_indices) - - if True: - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (a_q, a_s), (w1, w1_s), inter_out, m_indices) - else: - topk_ids = topk_ids.to(dtype=torch.int32) - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a_q, a_s), (w1, w1_s), - inter_out, topk_ids, M) - - p("inter_out", inter_out) - pp("FIRST GEMM DONE") - - #pp(f"DG {inter_out.shape} {inter_out}") + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), + inter_out, m_indices) act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) @@ -567,24 +428,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, dtype=torch.bfloat16, device=a.device) - pp("SECOND GEMM") - pp(f"E = {num_groups}") - p("A", act_out) - p("A_s", act_out_s) - p("B", w2) - p("B_s", w2_s) - p("topk_weights", topk_weight) - p("m_indices", m_indices) - - if True: - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - else: - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( - (act_out_q, act_out_s), (w2, w2_s), out, topk_ids, M) - - p("out", out) - pp("SECOND GEMM DONE") + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (act_out_q, act_out_s), (w2, w2_s), out, m_indices) inv_perm = torch.argsort(sorted_token_ids) @@ -606,13 +451,10 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: return torch.arange(shape[dim], **kwargs).view(*dimensions).expand(*shape) -# topk > 1 does not work @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product(M_moe_small, N_moe_small, K_moe_small, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product([512], [128], [256], [2], [1], [[128, 128]], DTYPES, SEEDS)) - #itertools.product([128], [128], [256], [2], [2], [[128, 128]], DTYPES, SEEDS)) + itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, + SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): @@ -621,7 +463,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, if (M % 128 != 0 or N % 128 != 0 or K % 128 != 0): pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") - pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") + #pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") torch.set_printoptions(profile="full") @@ -639,6 +481,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) + # TODO!!!!!!!!!!!! #score = torch.randn((M, E), dtype=dtype) # does not work #score = torch.ones((M, E), dtype=dtype) # works #score = torch.zeros((M, E), dtype=dtype) # works @@ -650,7 +493,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, #for i in range(score.numel()): # works # score.view(-1)[i] = 1.0 / (i + 1) score = iota((M, E), dtype=dtype) - p("score", score) + #p("score", score) #pp(score) block_n, block_k = block_size[0], block_size[1] @@ -659,7 +502,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, n_tiles_w2 = (K + block_n - 1) // block_n k_tiles_w2 = (N + block_k - 1) // block_k - # TODO: change these to zeros to test out groups w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) @@ -669,8 +511,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128) assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] - # TODO: fix later - #pp("For now, only convert the first group, the rest will be 0") for i in range(E): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) @@ -680,29 +520,19 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, # TODO: move size alignment further up when setting up all shapes if w1_sa.shape != w1_s.shape or w2_sa.shape != w2_s.shape: - p("w1_sa", w1_sa) - p("w2_sa", w2_sa) print("UNALIGNED") pytest.skip("UNALIGNED") w1_s = w1_sa w2_s = w2_sa - p("a", a) - p("w1", w1) - #print(w1) - p("w1_s", w1_s) - #print(w1_s) - p("w2", w2) - p("w2_s", w2_s) - with set_current_vllm_config(vllm_config): if False: out = fused_moe( - a, #hidden + a, w1, w2, - score, #gating + score, topk, renormalize=False, use_fp8_w8a8=True, @@ -715,10 +545,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, a, w1, w2, w1_s, w2_s, score, topk, block_size) else: ref_out = fused_moe( - a, #hidden + a, w1, w2, - score, #gating + score, topk, renormalize=False, use_fp8_w8a8=True, @@ -727,9 +557,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) - out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, - topk, block_size) - + out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, + score, topk, block_size) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 19452626e49..f21e1085841 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -28,17 +28,6 @@ logger = init_logger(__name__) -def p(s, t): - #print(f"{s}: {t.shape}, {t.dtype}\n{t.flatten()}") - #print(f"{s}: {t.shape}, {t.dtype}\n{t}") - pass - - -def pp(x): - #print(x) - pass - - @triton.jit def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, @@ -643,9 +632,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, **config, ) - p("fused_out", C) - pp(f"END {'SECOND' if mul_routed_weight else 'FIRST'} FUSED_GEMM") - # Adapted from: https://github.com/sgl-project/sglang/pull/2628 def get_config_file_name(E: int, @@ -1317,9 +1303,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) - pp(f"NUM CHUNKS = {(num_tokens // CHUNK_SIZE) + 1}") - #pp(f"FUSED A {hidden_states.shape}, {hidden_states}") - for chunk in range((num_tokens // CHUNK_SIZE) + 1): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, @@ -1376,8 +1359,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - #pp(f"FUSED_MOE {intermediate_cache1.shape} {intermediate_cache1}") - if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -1499,8 +1480,6 @@ def fused_moe( - torch.Tensor: The output tensor after applying the MoE layer. """ - #pp(f"FUSED SCORES {hidden_states.shape} {gating_output.shape}") - if use_grouped_topk: assert num_expert_group is not None and topk_group is not None topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, From b05c810cda2ac567ee1228f80724a49bc3ba1d95 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 6 Mar 2025 20:56:30 +0000 Subject: [PATCH 064/190] wip Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 130 ++++++++---------- .../layers/fused_moe/fused_moe.py | 48 ++++++- 2 files changed, 102 insertions(+), 76 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index a10c7cc905c..2d3bc98d490 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -42,6 +42,15 @@ SEEDS = [0] +def p(s, t): + #print(f"{s}: {t.shape}\n{t}") + pass + +def pp(x): + #print(x) + pass + + def native_per_token_group_quant_fp8(x, group_size, eps=1e-10, @@ -341,45 +350,6 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 -def deep_gemm_matmul_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_shape): - """Fused moe with block-wise quantization using DeepGemm.""" - topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) - - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, - w2.shape[1], - dtype=torch.bfloat16, - device=a.device) - - _, block_k = block_shape[0], block_shape[1] - a_q, a_s = per_token_group_quant_fp8(a, block_k) - a_q = a_q.to(dtype=torch.float32) - for i in range(w1.shape[0]): - mask = topk_ids == i - if mask.sum(): - inter_out = torch.empty((a_q[mask].shape[0], w1[i].shape[0]), - device=a_q.device, - dtype=torch.bfloat16) - deep_gemm.gemm_fp8_fp8_bf16_nt( - (a_q[mask].to(dtype=torch.float8_e4m3fn), a_s[mask]), - (w1[i], w1_s[i]), inter_out) - act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - tmp_out = torch.empty((act_out.shape[0], w2[i].shape[0]), - device=a_q.device, - dtype=torch.bfloat16) - deep_gemm.gemm_fp8_fp8_bf16_nt((act_out_q, act_out_s), - (w2[i], w2_s[i]), tmp_out) - out[mask] = tmp_out - - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) - - def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using DeepGemm torch.""" @@ -397,32 +367,53 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, _, block_k = block_shape[0], block_shape[1] - sorted_token_ids, expert_ids, _ = moe_align_block_size( - topk_ids, 1, num_groups, None) + sorted_token_ids, m_indices, num_pad = moe_align_block_size( + topk_ids, 1, num_groups, None) # topk? #assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 - m_indices = sorted_token_ids + + pp(f"num_pad = {num_pad}") + p("orig sorted", sorted_token_ids) + + oob_idx = (sorted_token_ids == M*topk).nonzero() + p("oob_idx", oob_idx) + + sorted_token_ids = sorted_token_ids.clamp(max=(M*topk)-1)[:M*num_groups] + inv_perm = torch.argsort(sorted_token_ids) + + p("m_indices", m_indices) + assert m_indices.numel() == M * topk a_q, a_s = per_token_group_quant_fp8(a, block_m) + # Replicate activations and scales a_q = a_q.view(a_q.shape[0], -1, a_q.shape[1]).repeat(1, topk, - 1).reshape(-1, a_q.shape[1]) # orig + 1).reshape(-1, a_q.shape[1]) a_s = a_s.view(a_s.shape[0], -1, a_s.shape[1]).repeat(1, topk, - 1).reshape(-1, a_s.shape[1]) # orig + 1).reshape(-1, a_s.shape[1]) + # Permute activations according to sorted token ids a_q = a_q.view(dtype=torch.uint8)[sorted_token_ids, ...].view(dtype=torch.float8_e4m3fn) a_s = a_s[sorted_token_ids] - m_indices = expert_ids # torch.repeat_interleave(expert_ids, topk, dim=0) + p("topk_ids", topk_ids) + p("sorted", sorted_token_ids) + p("m_indices", m_indices) + p("topk_weight", topk_weight) deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), inter_out, m_indices) + #inter_out = inter_out[inv_perm, ...] + act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) +# act_out_q = act_out_q.view(dtype=torch.uint8)[sorted_token_ids, ...].view(dtype=torch.float8_e4m3fn) +# act_out_s = act_out_s[sorted_token_ids] + out = torch.empty(act_out.shape[0], w2.shape[1], dtype=torch.bfloat16, @@ -431,11 +422,22 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - inv_perm = torch.argsort(sorted_token_ids) + out = out[inv_perm,...] + #topk_weight = topk_weight[inv_perm] + #out[:,num_pad:] = 0 + + #p("inter_out", inter_out) + p("out", out) - return (out[inv_perm].view(M, -1, w2.shape[1]) * + final_out = (out.view(M, -1, w2.shape[1]) * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + p("final_out", final_out) + + # TODO use moe_sum + + return final_out + def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: dimensions = [] @@ -453,17 +455,17 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, - SEEDS)) + itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): - # only aligned sizes or supported topk - if (M % 128 != 0 or N % 128 != 0 or K % 128 != 0): - pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") + # only aligned sizes + if (M % 128 != 0 or N % 128 != 0 or K % 128 != 0 or topk > E): + pytest.skip(f"Skipping test; invalid size m={M}, n={N}, k={K}, topk={topk}, E={E}") - #pp(f"\nTEST M={M}, N={N}, K={K}, E/num_groups={E}, topk={topk}, block_size={block_size}, dtype={dtype}") + pp(f"\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") torch.set_printoptions(profile="full") @@ -481,20 +483,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) - # TODO!!!!!!!!!!!! - #score = torch.randn((M, E), dtype=dtype) # does not work - #score = torch.ones((M, E), dtype=dtype) # works - #score = torch.zeros((M, E), dtype=dtype) # works - #score = torch.full((M, E), 0.5, dtype=dtype) # works - #score = torch.empty((M, E), dtype=dtype) - #for i in range(M): # works - # score[i] = torch.full((E, ), 1.0 / (i + 1), dtype=dtype) - #score = torch.empty((M, E), dtype=dtype) - #for i in range(score.numel()): # works - # score.view(-1)[i] = 1.0 / (i + 1) + score = torch.randn((M, E), dtype=dtype) # does not work score = iota((M, E), dtype=dtype) - #p("score", score) - #pp(score) block_n, block_k = block_size[0], block_size[1] n_tiles_w1 = ((2 * N) + block_n - 1) // block_n @@ -541,9 +531,12 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) - ref_out = deep_gemm_matmul_w8a8_block_fp8_moe( + ref_out = torch_w8a8_block_fp8_moe( a, w1, w2, w1_s, w2_s, score, topk, block_size) else: + out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, + score, topk, block_size) + ref_out = fused_moe( a, w1, @@ -557,9 +550,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, ) - out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, - score, topk, block_size) - #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f21e1085841..995ccc02102 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -27,6 +27,23 @@ logger = init_logger(__name__) +use_deep_gemm = False +if True or envs.VLLM_USE_DEEP_GEMM: + try: + import deep_gemm as dg + use_deep_gemm = True + except ImportError: + logger.warning("Failed to import DeepGemm kernels.") + + +def p(s, t): + #print(f"{s}: {t.shape}\n{t}") + pass + +def pp(x): + #print(x) + pass + @triton.jit def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, @@ -510,6 +527,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, # and we can skip some invalid blocks. EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config['BLOCK_SIZE_M']) + grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( B.shape[1], META['BLOCK_SIZE_N']), ) @@ -765,7 +783,7 @@ def get_default_config( # num_stages=3 can cause triton.runtime.errors.OutOfResources # on ROCm, set it to 2 instead. config = { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 64 if not use_deep_gemm else dg.get_m_alignment_for_contiguous_layout(), "BLOCK_SIZE_N": block_shape[0], "BLOCK_SIZE_K": block_shape[1], "GROUP_SIZE_M": 32, @@ -800,10 +818,11 @@ def get_default_config( "GROUP_SIZE_M": 1, } else: + dg_config = use_deep_gemm and dtype == "fp8_w8a8" config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64 if not dg_config else dg.get_m_alignment_for_contiguous_layout(), + "BLOCK_SIZE_N": 64 if not dg_config else 128, + "BLOCK_SIZE_K": 32 if not dg_config else 128, "GROUP_SIZE_M": 8, } return config @@ -1303,7 +1322,20 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) - for chunk in range((num_tokens // CHUNK_SIZE) + 1): + use_dg = False and valid_deep_gemm(hidden_states, w1, w2, config, use_fp8_w8a8) + + if use_dg: + print("USE_DG!!!!!!!!!!!!!") + num_chunks = 1 + assert w1_scale is not None + assert w2_scale is not None + # TODO: do this offline + w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() + w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() + else: + num_chunks = (num_tokens // CHUNK_SIZE) + 1 + + for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, num_tokens)) @@ -1335,7 +1367,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, block_shape=block_shape) sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'] if not use_dg else 1, global_num_experts, expert_map)) invoke_fused_moe_kernel(qcurr_hidden_states, @@ -1396,6 +1428,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) + p("fused topk", topk_ids) + p("fused sorted", sorted_token_ids) + p("fused topk_weight", topk_weights) + ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx]) From 0d15f37df6b1fb8085090d30d4c35fb5290b2db5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 6 Mar 2025 21:13:59 +0000 Subject: [PATCH 065/190] not crashing Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 4 ++-- vllm/model_executor/layers/fused_moe/fused_moe.py | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 2d3bc98d490..8a9eb674c15 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -455,8 +455,8 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 995ccc02102..adcd2817280 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -37,11 +37,11 @@ def p(s, t): - #print(f"{s}: {t.shape}\n{t}") + print(f"{s}: {t.shape}\n{t}") pass def pp(x): - #print(x) + print(x) pass @@ -527,7 +527,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, # and we can skip some invalid blocks. EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config['BLOCK_SIZE_M']) - grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( B.shape[1], META['BLOCK_SIZE_N']), ) @@ -1322,7 +1321,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) - use_dg = False and valid_deep_gemm(hidden_states, w1, w2, config, use_fp8_w8a8) + use_dg = valid_deep_gemm(hidden_states, w1, w2, config, use_fp8_w8a8) if use_dg: print("USE_DG!!!!!!!!!!!!!") From 9349503bf8947ffa30d26b8a120f7ff4a68e5d81 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 6 Mar 2025 22:37:58 +0000 Subject: [PATCH 066/190] baseline working integration Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 7 ++++--- vllm/model_executor/layers/fused_moe/fused_moe.py | 9 ++++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 8a9eb674c15..f6de12d6564 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -33,6 +33,7 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. M_moe = [1, 2, 7, 83, 128, 512, 2048] +M_moe_dg = [128, 512, 2048] N_moe = [128, 256, 4608] # [128, 4608, 13824] K_moe = [256, 512, 7168] # [256, 7168, 13824] BLOCK_SIZE = [[128, 128]] @@ -369,7 +370,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, sorted_token_ids, m_indices, num_pad = moe_align_block_size( topk_ids, 1, num_groups, None) # topk? - #assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 + assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 pp(f"num_pad = {num_pad}") p("orig sorted", sorted_token_ids) @@ -455,8 +456,8 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - #itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index adcd2817280..93004c8d84f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -37,11 +37,11 @@ def p(s, t): - print(f"{s}: {t.shape}\n{t}") + #print(f"{s}: {t.shape}\n{t}") pass def pp(x): - print(x) + #print(x) pass @@ -1324,13 +1324,16 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_dg = valid_deep_gemm(hidden_states, w1, w2, config, use_fp8_w8a8) if use_dg: - print("USE_DG!!!!!!!!!!!!!") + #print("USE_DG!!!!!!!!!!!!!") num_chunks = 1 + CHUNK_SIZE = num_tokens assert w1_scale is not None assert w2_scale is not None # TODO: do this offline + #print("GOT HERE A") w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() + #print("GOT HERE B") else: num_chunks = (num_tokens // CHUNK_SIZE) + 1 From 3a744b363064e14bc889ca8bebc70f6df87c5289 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 6 Mar 2025 22:52:35 +0000 Subject: [PATCH 067/190] add allow_deep_gemm flag Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 1 + .../layers/fused_moe/fused_moe.py | 26 +++++++++++++------ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index f6de12d6564..5e968d784c5 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -549,6 +549,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1_scale=w1_s, w2_scale=w2_s, block_shape=block_size, + allow_deep_gemm=True ) #print(f"{out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 93004c8d84f..74925aa92f2 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1029,13 +1029,14 @@ def inplace_fused_experts(hidden_states: torch.Tensor, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> None: + block_shape: Optional[List[int]] = None, + allow_deep_gemm: bool = False) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, - block_shape) + block_shape, allow_deep_gemm) def inplace_fused_experts_fake( @@ -1059,7 +1060,8 @@ def inplace_fused_experts_fake( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> None: + block_shape: Optional[List[int]] = None, + allow_deep_gemm: bool = False) -> None: pass @@ -1093,7 +1095,8 @@ def outplace_fused_experts( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> torch.Tensor: + block_shape: Optional[List[int]] = None, + allow_deep_gemm: bool = False) -> torch.Tensor: return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, False, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, @@ -1123,7 +1126,8 @@ def outplace_fused_experts_fake( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> torch.Tensor: + block_shape: Optional[List[int]] = None, + allow_deep_gemm: bool = False) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1321,12 +1325,14 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) - use_dg = valid_deep_gemm(hidden_states, w1, w2, config, use_fp8_w8a8) + use_dg = allow_deep_gemm and valid_deep_gemm(hidden_states, w1, w2, config, use_fp8_w8a8) if use_dg: #print("USE_DG!!!!!!!!!!!!!") - num_chunks = 1 - CHUNK_SIZE = num_tokens + # TODO: how to test chunks? + #num_chunks = 1 + #CHUNK_SIZE = num_tokens + num_chunks = (num_tokens // CHUNK_SIZE) + 1 assert w1_scale is not None assert w2_scale is not None # TODO: do this offline @@ -1337,6 +1343,9 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: num_chunks = (num_tokens // CHUNK_SIZE) + 1 + if num_chunks > 1: + print("CHUNKS!!!!!!!!!!!!!!!!!!") + for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, @@ -1467,6 +1476,7 @@ def fused_moe( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, + allow_deep_gemm: bool = False, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of From 449c6a19bca403af48765249690928ab87ab55c2 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 7 Mar 2025 21:48:28 +0000 Subject: [PATCH 068/190] wip Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 118 +++++++++++++----- .../layers/fused_moe/fused_moe.py | 14 ++- 2 files changed, 97 insertions(+), 35 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 5e968d784c5..708cf61352d 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -44,11 +44,11 @@ def p(s, t): - #print(f"{s}: {t.shape}\n{t}") + print(f"{s}: {t.shape}\n{t}") pass def pp(x): - #print(x) + print(x) pass @@ -168,6 +168,48 @@ def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) +def torch2_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): + """Fused moe with block-wise quantization using native torch.""" + B, D = a.shape + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + _, block_k = block_shape[0], block_shape[1] + a_q, a_s = native_per_token_group_quant_fp8(a, block_k) + a_q = a_q.to(torch.float32) + + a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + a_s = a_s.view(a_s.shape[0], -1, a_s.shape[1]).repeat(1, topk, 1).reshape(-1, a_s.shape[1]) + + assert topk_ids.numel() == a_q.shape[0] == B * topk + + for i in range(w1.shape[0]): + mask = topk_ids == i + print(f"sum = {mask.numel()}, {mask.nonzero()}") + if mask.sum(): + inter_out = native_w8a8_block_fp8_matmul(a_q[mask], + w1[i], + a_s[mask], + w1_s[i], + block_shape, + output_dtype=a.dtype) + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = native_per_token_group_quant_fp8( + act_out, block_k) + act_out = act_out.to(torch.float32) + out[mask] = native_w8a8_block_fp8_matmul(act_out_q, + w2[i], + act_out_s, + w2_s[i], + block_shape, + output_dtype=a.dtype) + return (out.view(B, -1, w2.shape[1]) * + topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + # Skip all tests if CUDA is not available pytest.importorskip("torch.cuda") @@ -360,39 +402,49 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) - inter_out = torch.empty((M * topk, w1[0].shape[0]), - dtype=torch.bfloat16, - device=a.device) - block_m = deep_gemm.get_m_alignment_for_contiguous_layout() _, block_k = block_shape[0], block_shape[1] + #sorted_token_ids, m_indices, num_pad = moe_align_block_size( + # topk_ids, 1, num_groups, None) + sorted_token_ids, m_indices, num_pad = moe_align_block_size( - topk_ids, 1, num_groups, None) # topk? - assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 + topk_ids, M, num_groups, None) + + pp(f"num_pad = {num_pad}, {topk_ids.numel()}, {M*topk}, {M*num_groups}") + + #sorted_token_ids = sorted_token_ids[:num_pad] + pad_size = (m_indices.numel() * M) - sorted_token_ids.numel() + sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, (0, pad_size), "constant", topk*M) + p("sorted_token_ids2", sorted_token_ids) + p("orig m_indices", m_indices) + m_indices = torch.repeat_interleave(m_indices, M, dim=0) #[:num_pad] + + # M * topk + #assert topk_ids.numel() == sorted_token_ids.numel() == num_pad - pp(f"num_pad = {num_pad}") - p("orig sorted", sorted_token_ids) + mask = sorted_token_ids == topk*M # zero out a_q[mask]? + + sorted_token_ids = sorted_token_ids.clamp(max=(M*topk)-1)#[:num_pad] + + assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 - oob_idx = (sorted_token_ids == M*topk).nonzero() - p("oob_idx", oob_idx) - sorted_token_ids = sorted_token_ids.clamp(max=(M*topk)-1)[:M*num_groups] inv_perm = torch.argsort(sorted_token_ids) p("m_indices", m_indices) - assert m_indices.numel() == M * topk + #assert m_indices.numel() == M * topk a_q, a_s = per_token_group_quant_fp8(a, block_m) # Replicate activations and scales - a_q = a_q.view(a_q.shape[0], -1, - a_q.shape[1]).repeat(1, topk, - 1).reshape(-1, a_q.shape[1]) - a_s = a_s.view(a_s.shape[0], -1, - a_s.shape[1]).repeat(1, topk, - 1).reshape(-1, a_s.shape[1]) +# a_q = a_q.view(a_q.shape[0], -1, +# a_q.shape[1]).repeat(1, topk, +# 1).reshape(-1, a_q.shape[1]) +# a_s = a_s.view(a_s.shape[0], -1, +# a_s.shape[1]).repeat(1, topk, +# 1).reshape(-1, a_s.shape[1]) # Permute activations according to sorted token ids a_q = a_q.view(dtype=torch.uint8)[sorted_token_ids, @@ -401,9 +453,16 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, p("topk_ids", topk_ids) p("sorted", sorted_token_ids) - p("m_indices", m_indices) p("topk_weight", topk_weight) + p("a_q", a_q) + p("a_s", a_s) + p("m_indices", m_indices) + + inter_out = torch.zeros((a_q.shape[0], w1[0].shape[0]), + dtype=torch.bfloat16, + device=a.device) + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), inter_out, m_indices) @@ -415,7 +474,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, # act_out_q = act_out_q.view(dtype=torch.uint8)[sorted_token_ids, ...].view(dtype=torch.float8_e4m3fn) # act_out_s = act_out_s[sorted_token_ids] - out = torch.empty(act_out.shape[0], + out = torch.zeros(act_out.shape[0], w2.shape[1], dtype=torch.bfloat16, device=a.device) @@ -427,11 +486,11 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, #topk_weight = topk_weight[inv_perm] #out[:,num_pad:] = 0 - #p("inter_out", inter_out) - p("out", out) + p("inter_out", inter_out) + #p("out", out) final_out = (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + topk_weight.view(M, -1, 1).to(out.dtype))[:topk*M].sum(dim=1) p("final_out", final_out) @@ -456,8 +515,8 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): @@ -485,7 +544,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, fp8_max).clamp(min=fp8_min, max=fp8_max) score = torch.randn((M, E), dtype=dtype) # does not work - score = iota((M, E), dtype=dtype) + #score = iota((M, E), dtype=dtype) block_n, block_k = block_size[0], block_size[1] n_tiles_w1 = ((2 * N) + block_n - 1) // block_n @@ -530,6 +589,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1_scale=w1_s, w2_scale=w2_s, block_shape=block_size, + allow_deep_gemm=False ) ref_out = torch_w8a8_block_fp8_moe( @@ -549,7 +609,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1_scale=w1_s, w2_scale=w2_s, block_shape=block_size, - allow_deep_gemm=True + allow_deep_gemm=False ) #print(f"{out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 74925aa92f2..9f317c13ddc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -37,11 +37,11 @@ def p(s, t): - #print(f"{s}: {t.shape}\n{t}") + print(f"{s}: {t.shape}\n{t}") pass def pp(x): - #print(x) + print(x) pass @@ -530,6 +530,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( B.shape[1], META['BLOCK_SIZE_N']), ) + p("fused a_q", A) + p("fused a_s", A_scale) + p("fused expert ids", expert_ids) + if (use_int8_w8a16 or use_int4_w4a16) and \ block_shape is not None and block_shape[1] > 0: assert B_scale is not None and B_scale.ndim == 3 @@ -1402,6 +1406,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) + p("fused inter_out", intermediate_cache1) + if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -1439,10 +1445,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - p("fused topk", topk_ids) - p("fused sorted", sorted_token_ids) - p("fused topk_weight", topk_weights) - ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx]) From 5f21c962858fecd87476194721ae12e7acce5bc2 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 7 Mar 2025 22:02:44 +0000 Subject: [PATCH 069/190] better Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 23 +++++++++++++------ .../layers/fused_moe/fused_moe.py | 4 ++-- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 708cf61352d..3b7b34aa91b 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -44,11 +44,11 @@ def p(s, t): - print(f"{s}: {t.shape}\n{t}") + #print(f"{s}: {t.shape}\n{t}") pass def pp(x): - print(x) + #print(x) pass @@ -393,6 +393,11 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 +# dtype=torch.float8_e4m3fn +def fp8_perm(m, idx): + return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) + + def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, block_shape): """Fused moe with block-wise quantization using DeepGemm torch.""" @@ -447,15 +452,16 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, # 1).reshape(-1, a_s.shape[1]) # Permute activations according to sorted token ids - a_q = a_q.view(dtype=torch.uint8)[sorted_token_ids, - ...].view(dtype=torch.float8_e4m3fn) + a_q = fp8_perm(a_q, sorted_token_ids) a_s = a_s[sorted_token_ids] + #a_q.view(dtype=torch.uint8)[mask] = 0 + p("topk_ids", topk_ids) p("sorted", sorted_token_ids) p("topk_weight", topk_weight) - p("a_q", a_q) + p("a_q", fp8_perm(a_q, inv_perm)) p("a_s", a_s) p("m_indices", m_indices) @@ -489,8 +495,11 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, p("inter_out", inter_out) #p("out", out) - final_out = (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype))[:topk*M].sum(dim=1) + #final_out = (out.view(M, -1, w2.shape[1]) * + # topk_weight.view(M, -1, 1).to(out.dtype))[:topk*M].sum(dim=1) + + final_out = (out.view(-1, topk, w2.shape[1])[:topk*M] * + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) p("final_out", final_out) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 9f317c13ddc..3162713af07 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -37,11 +37,11 @@ def p(s, t): - print(f"{s}: {t.shape}\n{t}") + #print(f"{s}: {t.shape}\n{t}") pass def pp(x): - print(x) + #print(x) pass From 5bcbd931d65c2118b23e796bc937d6d137c311c6 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 8 Mar 2025 01:15:42 +0000 Subject: [PATCH 070/190] fix some stuff Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 95 +++++++++++++------ .../layers/fused_moe/fused_moe.py | 2 +- 2 files changed, 69 insertions(+), 28 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 3b7b34aa91b..6bb1b24b120 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -44,11 +44,11 @@ def p(s, t): - #print(f"{s}: {t.shape}\n{t}") + print(f"{s}: {t.shape}\n{t}") pass def pp(x): - #print(x) + print(x) pass @@ -411,57 +411,69 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, _, block_k = block_shape[0], block_shape[1] - #sorted_token_ids, m_indices, num_pad = moe_align_block_size( - # topk_ids, 1, num_groups, None) sorted_token_ids, m_indices, num_pad = moe_align_block_size( - topk_ids, M, num_groups, None) + topk_ids, block_m, num_groups, None) pp(f"num_pad = {num_pad}, {topk_ids.numel()}, {M*topk}, {M*num_groups}") #sorted_token_ids = sorted_token_ids[:num_pad] - pad_size = (m_indices.numel() * M) - sorted_token_ids.numel() - sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, (0, pad_size), "constant", topk*M) - p("sorted_token_ids2", sorted_token_ids) - p("orig m_indices", m_indices) - m_indices = torch.repeat_interleave(m_indices, M, dim=0) #[:num_pad] - # M * topk - #assert topk_ids.numel() == sorted_token_ids.numel() == num_pad + print("GOT HERE1") + + num_tokens = topk * M + + pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m) - sorted_token_ids.numel() + if pad_size > 0: + sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, (0, pad_size), "constant", num_tokens) - mask = sorted_token_ids == topk*M # zero out a_q[mask]? + sorted_token_ids = sorted_token_ids.clamp(max=num_tokens-1) + + #m_indices = m_indices[(sorted_token_ids.numel() // 128):] + + p("sorted_token_ids", sorted_token_ids) + p("sorted_token_ids[:num_pad]", sorted_token_ids[:num_pad]) + #sorted_token_ids = sorted_token_ids[:num_pad] + p("orig m_indices", m_indices) + m_indices = torch.repeat_interleave(m_indices, M, dim=0) - sorted_token_ids = sorted_token_ids.clamp(max=(M*topk)-1)#[:num_pad] + print("GOT HERE2") - assert sorted_token_ids[sorted_token_ids >= topk*M].sum() == 0 + assert sorted_token_ids[sorted_token_ids >= num_tokens].sum() == 0 + print("GOT HERE2A") inv_perm = torch.argsort(sorted_token_ids) p("m_indices", m_indices) - #assert m_indices.numel() == M * topk + + print("GOT HERE2B") a_q, a_s = per_token_group_quant_fp8(a, block_m) # Replicate activations and scales -# a_q = a_q.view(a_q.shape[0], -1, -# a_q.shape[1]).repeat(1, topk, -# 1).reshape(-1, a_q.shape[1]) -# a_s = a_s.view(a_s.shape[0], -1, -# a_s.shape[1]).repeat(1, topk, -# 1).reshape(-1, a_s.shape[1]) + a_q = a_q.view(a_q.shape[0], -1, + a_q.shape[1]).repeat(1, topk, + 1).reshape(-1, a_q.shape[1]) + a_s = a_s.view(a_s.shape[0], -1, + a_s.shape[1]).repeat(1, topk, + 1).reshape(-1, a_s.shape[1]) + + print("GOT HERE2C") # Permute activations according to sorted token ids a_q = fp8_perm(a_q, sorted_token_ids) a_s = a_s[sorted_token_ids] + print("GOT HERE3") + #a_q.view(dtype=torch.uint8)[mask] = 0 p("topk_ids", topk_ids) p("sorted", sorted_token_ids) p("topk_weight", topk_weight) - p("a_q", fp8_perm(a_q, inv_perm)) + p("a_q", a_q) p("a_s", a_s) p("m_indices", m_indices) @@ -469,9 +481,15 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, dtype=torch.bfloat16, device=a.device) + + print("GOT HERE4") + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), inter_out, m_indices) + + print("GOT HERE5") + #inter_out = inter_out[inv_perm, ...] act_out = SiluAndMul().forward_native(inter_out) @@ -485,10 +503,17 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, dtype=torch.bfloat16, device=a.device) + print("GOT HERE6") + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) + print("GOT HERE7") + out = out[inv_perm,...] + + print("GOT HERE8") + #topk_weight = topk_weight[inv_perm] #out[:,num_pad:] = 0 @@ -498,8 +523,20 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, #final_out = (out.view(M, -1, w2.shape[1]) * # topk_weight.view(M, -1, 1).to(out.dtype))[:topk*M].sum(dim=1) - final_out = (out.view(-1, topk, w2.shape[1])[:topk*M] * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + print(f"GOT HERE9 {out.shape}, {M}, {num_tokens}") + + TT = topk_weight.shape[0] + tmp_out = out.view(-1, topk, w2.shape[1])[:M, ...] + #tmp_out = out[:M, ...].view(M, -1, w2.shape[1]) + + print(f"GOT HERE10 {tmp_out.shape}, {topk_weight.shape}") + + final_out = (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + + #final_out = (out.view(-1, topk, w2.shape[1])[:topk*M] * + # topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + + print("GOT HERE11") p("final_out", final_out) @@ -521,11 +558,14 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: return torch.arange(shape[dim], **kwargs).view(*dimensions).expand(*shape) - +# topk 6 broken/slow @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", #itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product(M_moe_dg, N_moe, K_moe, E, [1, 2, 4], BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product([512], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([128], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): @@ -535,6 +575,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, pytest.skip(f"Skipping test; invalid size m={M}, n={N}, k={K}, topk={topk}, E={E}") pp(f"\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") + print(f"\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") torch.set_printoptions(profile="full") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 3162713af07..d14eac4a613 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1382,7 +1382,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, block_shape=block_shape) sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'] if not use_dg else 1, + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map)) invoke_fused_moe_kernel(qcurr_hidden_states, From 2eeeedf957ae5ec1ea56b8837fa070e9690806a0 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 8 Mar 2025 01:20:07 +0000 Subject: [PATCH 071/190] fix more stuff Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 6bb1b24b120..ea1b127f3fb 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -44,11 +44,11 @@ def p(s, t): - print(f"{s}: {t.shape}\n{t}") + #print(f"{s}: {t.shape}\n{t}") pass def pp(x): - print(x) + #print(x) pass @@ -435,7 +435,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, p("sorted_token_ids[:num_pad]", sorted_token_ids[:num_pad]) #sorted_token_ids = sorted_token_ids[:num_pad] p("orig m_indices", m_indices) - m_indices = torch.repeat_interleave(m_indices, M, dim=0) + m_indices = torch.repeat_interleave(m_indices, block_m, dim=0) print("GOT HERE2") @@ -525,7 +525,6 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, print(f"GOT HERE9 {out.shape}, {M}, {num_tokens}") - TT = topk_weight.shape[0] tmp_out = out.view(-1, topk, w2.shape[1])[:M, ...] #tmp_out = out[:M, ...].view(M, -1, w2.shape[1]) @@ -562,8 +561,8 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", #itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product(M_moe_dg, N_moe, K_moe, E, [1, 2, 4], BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product([512], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product(M_moe_dg, N_moe, K_moe, E, [1, 2, 4], BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([512], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) #itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) #itertools.product([128], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() From 8fce359136a9d756ca575acbca77bd46ba22f815 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 8 Mar 2025 01:29:57 +0000 Subject: [PATCH 072/190] cleanups Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 32 +------------------------------- 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index ea1b127f3fb..ab2d3652bde 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -419,8 +419,6 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, #sorted_token_ids = sorted_token_ids[:num_pad] - print("GOT HERE1") - num_tokens = topk * M pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m) - sorted_token_ids.numel() @@ -437,18 +435,12 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, p("orig m_indices", m_indices) m_indices = torch.repeat_interleave(m_indices, block_m, dim=0) - print("GOT HERE2") - assert sorted_token_ids[sorted_token_ids >= num_tokens].sum() == 0 - print("GOT HERE2A") - inv_perm = torch.argsort(sorted_token_ids) p("m_indices", m_indices) - print("GOT HERE2B") - a_q, a_s = per_token_group_quant_fp8(a, block_m) # Replicate activations and scales @@ -459,14 +451,10 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, a_s.shape[1]).repeat(1, topk, 1).reshape(-1, a_s.shape[1]) - print("GOT HERE2C") - # Permute activations according to sorted token ids a_q = fp8_perm(a_q, sorted_token_ids) a_s = a_s[sorted_token_ids] - print("GOT HERE3") - #a_q.view(dtype=torch.uint8)[mask] = 0 p("topk_ids", topk_ids) @@ -482,14 +470,10 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, device=a.device) - print("GOT HERE4") - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), inter_out, m_indices) - print("GOT HERE5") - #inter_out = inter_out[inv_perm, ...] act_out = SiluAndMul().forward_native(inter_out) @@ -503,17 +487,11 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, dtype=torch.bfloat16, device=a.device) - print("GOT HERE6") - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - print("GOT HERE7") - out = out[inv_perm,...] - print("GOT HERE8") - #topk_weight = topk_weight[inv_perm] #out[:,num_pad:] = 0 @@ -523,20 +501,14 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, #final_out = (out.view(M, -1, w2.shape[1]) * # topk_weight.view(M, -1, 1).to(out.dtype))[:topk*M].sum(dim=1) - print(f"GOT HERE9 {out.shape}, {M}, {num_tokens}") - tmp_out = out.view(-1, topk, w2.shape[1])[:M, ...] #tmp_out = out[:M, ...].view(M, -1, w2.shape[1]) - print(f"GOT HERE10 {tmp_out.shape}, {topk_weight.shape}") - final_out = (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) #final_out = (out.view(-1, topk, w2.shape[1])[:topk*M] * # topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - print("GOT HERE11") - p("final_out", final_out) # TODO use moe_sum @@ -574,7 +546,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, pytest.skip(f"Skipping test; invalid size m={M}, n={N}, k={K}, topk={topk}, E={E}") pp(f"\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") - print(f"\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") torch.set_printoptions(profile="full") @@ -592,8 +563,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) - score = torch.randn((M, E), dtype=dtype) # does not work - #score = iota((M, E), dtype=dtype) + score = torch.randn((M, E), dtype=dtype) block_n, block_k = block_size[0], block_size[1] n_tiles_w1 = ((2 * N) + block_n - 1) // block_n From 25501739fd6287ab6c744fd39c0ce3170ed1c094 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 8 Mar 2025 04:21:38 +0000 Subject: [PATCH 073/190] some integration tests working Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 44 +++++-------------- .../layers/fused_moe/fused_moe.py | 3 +- 2 files changed, 14 insertions(+), 33 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index ab2d3652bde..75afbb9b029 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -188,7 +188,7 @@ def torch2_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): for i in range(w1.shape[0]): mask = topk_ids == i - print(f"sum = {mask.numel()}, {mask.nonzero()}") + #print(f"sum = {mask.numel()}, {mask.nonzero()}") if mask.sum(): inter_out = native_w8a8_block_fp8_matmul(a_q[mask], w1[i], @@ -411,14 +411,11 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, _, block_k = block_shape[0], block_shape[1] - sorted_token_ids, m_indices, num_pad = moe_align_block_size( topk_ids, block_m, num_groups, None) pp(f"num_pad = {num_pad}, {topk_ids.numel()}, {M*topk}, {M*num_groups}") - #sorted_token_ids = sorted_token_ids[:num_pad] - num_tokens = topk * M pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m) - sorted_token_ids.numel() @@ -427,11 +424,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, sorted_token_ids = sorted_token_ids.clamp(max=num_tokens-1) - #m_indices = m_indices[(sorted_token_ids.numel() // 128):] - p("sorted_token_ids", sorted_token_ids) - p("sorted_token_ids[:num_pad]", sorted_token_ids[:num_pad]) - #sorted_token_ids = sorted_token_ids[:num_pad] p("orig m_indices", m_indices) m_indices = torch.repeat_interleave(m_indices, block_m, dim=0) @@ -439,8 +432,6 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, inv_perm = torch.argsort(sorted_token_ids) - p("m_indices", m_indices) - a_q, a_s = per_token_group_quant_fp8(a, block_m) # Replicate activations and scales @@ -455,8 +446,6 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, a_q = fp8_perm(a_q, sorted_token_ids) a_s = a_s[sorted_token_ids] - #a_q.view(dtype=torch.uint8)[mask] = 0 - p("topk_ids", topk_ids) p("sorted", sorted_token_ids) p("topk_weight", topk_weight) @@ -469,19 +458,15 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, dtype=torch.bfloat16, device=a.device) + #print(f"inter_out {inter_out.shape}") deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), inter_out, m_indices) - #inter_out = inter_out[inv_perm, ...] - act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) -# act_out_q = act_out_q.view(dtype=torch.uint8)[sorted_token_ids, ...].view(dtype=torch.float8_e4m3fn) -# act_out_s = act_out_s[sorted_token_ids] - out = torch.zeros(act_out.shape[0], w2.shape[1], dtype=torch.bfloat16, @@ -492,22 +477,16 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, out = out[inv_perm,...] - #topk_weight = topk_weight[inv_perm] - #out[:,num_pad:] = 0 - p("inter_out", inter_out) - #p("out", out) - - #final_out = (out.view(M, -1, w2.shape[1]) * - # topk_weight.view(M, -1, 1).to(out.dtype))[:topk*M].sum(dim=1) tmp_out = out.view(-1, topk, w2.shape[1])[:M, ...] - #tmp_out = out[:M, ...].view(M, -1, w2.shape[1]) + + #print(f"tk {topk_weight.shape}, M={M} topk={topk}, N={w2.shape[1]}, out_C={out.shape}") + #print(f"tmp_out {tmp_out.shape}") final_out = (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - #final_out = (out.view(-1, topk, w2.shape[1])[:topk*M] * - # topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + #print(f"final_out {final_out.shape}") p("final_out", final_out) @@ -546,6 +525,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, pytest.skip(f"Skipping test; invalid size m={M}, n={N}, k={K}, topk={topk}, E={E}") pp(f"\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") + #print(f"\n\n\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") torch.set_printoptions(profile="full") @@ -597,6 +577,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, with set_current_vllm_config(vllm_config): if False: + ref_out = torch_w8a8_block_fp8_moe( + a, w1, w2, w1_s, w2_s, score, topk, block_size) + out = fused_moe( a, w1, @@ -608,11 +591,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1_scale=w1_s, w2_scale=w2_s, block_shape=block_size, - allow_deep_gemm=False + allow_deep_gemm=True ) - - ref_out = torch_w8a8_block_fp8_moe( - a, w1, w2, w1_s, w2_s, score, topk, block_size) else: out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, block_size) @@ -628,7 +608,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1_scale=w1_s, w2_scale=w2_s, block_shape=block_size, - allow_deep_gemm=False + allow_deep_gemm=True ) #print(f"{out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d14eac4a613..aab2f7a4cd8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1361,6 +1361,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, break if tokens_in_chunk < CHUNK_SIZE and chunk > 0: + assert False # for now # Adjust the intermediate cache size and config for the last # chunk. Note that in most cases we only have one chunk # so the cache size and config are already set correctly and @@ -1382,7 +1383,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, block_shape=block_shape) sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], + moe_align_block_size(curr_topk_ids, block_m, global_num_experts, expert_map)) invoke_fused_moe_kernel(qcurr_hidden_states, From 1b19a9f1bf7cd36aafe18a8d2bfe5f30841dad08 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 10 Mar 2025 16:59:11 +0000 Subject: [PATCH 074/190] almost all tests passing Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 3 +- .../layers/fused_moe/fused_moe.py | 41 +++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 75afbb9b029..142fd368083 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -516,6 +516,7 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: #itertools.product([512], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) #itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) #itertools.product([128], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([512], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): @@ -576,7 +577,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_s = w2_sa with set_current_vllm_config(vllm_config): - if False: + if True: ref_out = torch_w8a8_block_fp8_moe( a, w1, w2, w1_s, w2_s, score, topk, block_size) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index aab2f7a4cd8..2c074fe7e51 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1302,6 +1302,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, config = get_config_func(M) + # XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX + # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 cache13 = torch.empty(M * top_k_num * max(N, K), @@ -1315,6 +1317,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, device=hidden_states.device, dtype=hidden_states.dtype) + # XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX + if hidden_states.dtype == torch.bfloat16: compute_type = tl.bfloat16 elif hidden_states.dtype == torch.float16: @@ -1331,6 +1335,9 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_dg = allow_deep_gemm and valid_deep_gemm(hidden_states, w1, w2, config, use_fp8_w8a8) + block_m = config['BLOCK_SIZE_M'] + assert not use_dg or block_m == 128 + if use_dg: #print("USE_DG!!!!!!!!!!!!!") # TODO: how to test chunks? @@ -1344,7 +1351,41 @@ def fused_experts_impl(hidden_states: torch.Tensor, w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() #print("GOT HERE B") + + # BIG HACK + sorted_token_ids, _, _ = ( + moe_align_block_size(topk_ids, block_m, + global_num_experts, expert_map)) + + num_tokens = top_k_num * M + pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m) - sorted_token_ids.numel() + if pad_size > 0: + sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, (0, pad_size), "constant", num_tokens) + sorted_token_ids = sorted_token_ids.clamp(max=num_tokens-1) + + new_M = sorted_token_ids.numel()//top_k_num + #print(f"fused2 m={M}, new_M={new_M}, sort={sorted_token_ids.shape}, hs={hidden_states.shape}, hs[sort]={hidden_states.view(num_tokens, -1)[sorted_token_ids, ...].shape}") + + intermediate_cache1 = torch.empty((new_M, top_k_num, N), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache2 = torch.empty((new_M * top_k_num, N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache3 = torch.empty((new_M, top_k_num, w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype) else: + intermediate_cache1 = torch.empty((M, top_k_num, N), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache2 = torch.empty((M * top_k_num, N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache3 = torch.empty((M, top_k_num, w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype) + num_chunks = (num_tokens // CHUNK_SIZE) + 1 if num_chunks > 1: From e24d2c115af51a67df6a0fc384320601fcb8a758 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 10 Mar 2025 18:19:10 +0000 Subject: [PATCH 075/190] cleanup temp construction a bit Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 4 ++-- vllm/model_executor/layers/fused_moe/fused_moe.py | 13 ++++++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 142fd368083..828f258a877 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -512,9 +512,9 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", #itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product(M_moe_dg, N_moe, K_moe, E, [1, 2, 4], BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product(M_moe_dg, N_moe, K_moe, E, [1, 2, 4], BLOCK_SIZE, DTYPES, SEEDS)) # all work #itertools.product([512], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product([128], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([128], [128], [256], [8], [6], BLOCK_SIZE, DTYPES, SEEDS)) #itertools.product([128], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) #itertools.product([512], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 2c074fe7e51..93075fc78cb 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1353,7 +1353,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, #print("GOT HERE B") # BIG HACK - sorted_token_ids, _, _ = ( + sorted_token_ids, _, pad = ( moe_align_block_size(topk_ids, block_m, global_num_experts, expert_map)) @@ -1363,8 +1363,15 @@ def fused_experts_impl(hidden_states: torch.Tensor, sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, (0, pad_size), "constant", num_tokens) sorted_token_ids = sorted_token_ids.clamp(max=num_tokens-1) - new_M = sorted_token_ids.numel()//top_k_num - #print(f"fused2 m={M}, new_M={new_M}, sort={sorted_token_ids.shape}, hs={hidden_states.shape}, hs[sort]={hidden_states.view(num_tokens, -1)[sorted_token_ids, ...].shape}") + #new_M = sorted_token_ids.numel()//top_k_num + #print(f"fused2 m={M}, sort={sorted_token_ids.shape}, pad={pad}, hs={hidden_states.shape}, num_tok={num_tokens}") + #print(f"hs[sort]={torch.repeat_interleave(hidden_states, top_k_num, dim=0)[sorted_token_ids, ...].shape}") + new_S = torch.repeat_interleave(hidden_states, top_k_num, dim=0)[sorted_token_ids, ...].shape + #new_top_k = new_S[0] // M + new_M = new_S[0] // top_k_num + #new_M = ((new_M + block_m - 1) // block_m) * block_m + #print(f"fused2 new_M_b={new_M} top_k = {top_k_num}, new_top_k={new_top_k}") + #top_k_num = new_top_k intermediate_cache1 = torch.empty((new_M, top_k_num, N), device=hidden_states.device, From c4a1a2cea661e02d9a1f6597dbb3c41edf831b02 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 10 Mar 2025 20:09:59 +0000 Subject: [PATCH 076/190] fix rest of tests Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 8 +--- .../layers/fused_moe/fused_moe.py | 38 +++++++------------ 2 files changed, 14 insertions(+), 32 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 828f258a877..1bd4d20f112 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -511,12 +511,7 @@ def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: # topk 6 broken/slow @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - #itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - itertools.product(M_moe_dg, N_moe, K_moe, E, [1, 2, 4], BLOCK_SIZE, DTYPES, SEEDS)) # all work - #itertools.product([512], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product([128], [128], [256], [8], [6], BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product([128], [128], [256], [2], [2], BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product([512], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): @@ -526,7 +521,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, pytest.skip(f"Skipping test; invalid size m={M}, n={N}, k={K}, topk={topk}, E={E}") pp(f"\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") - #print(f"\n\n\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") torch.set_printoptions(profile="full") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 93075fc78cb..d0b303ea042 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -530,10 +530,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( B.shape[1], META['BLOCK_SIZE_N']), ) - p("fused a_q", A) - p("fused a_s", A_scale) - p("fused expert ids", expert_ids) - if (use_int8_w8a16 or use_int4_w4a16) and \ block_shape is not None and block_shape[1] > 0: assert B_scale is not None and B_scale.ndim == 3 @@ -1339,20 +1335,22 @@ def fused_experts_impl(hidden_states: torch.Tensor, assert not use_dg or block_m == 128 if use_dg: - #print("USE_DG!!!!!!!!!!!!!") # TODO: how to test chunks? - #num_chunks = 1 - #CHUNK_SIZE = num_tokens - num_chunks = (num_tokens // CHUNK_SIZE) + 1 + if False: + num_chunks = 1 + CHUNK_SIZE = num_tokens + else: + num_chunks = (num_tokens // CHUNK_SIZE) + 1 + assert w1_scale is not None assert w2_scale is not None + # TODO: do this offline - #print("GOT HERE A") w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() - #print("GOT HERE B") - # BIG HACK + + # TODO: this could be smarter sorted_token_ids, _, pad = ( moe_align_block_size(topk_ids, block_m, global_num_experts, expert_map)) @@ -1362,24 +1360,16 @@ def fused_experts_impl(hidden_states: torch.Tensor, if pad_size > 0: sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, (0, pad_size), "constant", num_tokens) sorted_token_ids = sorted_token_ids.clamp(max=num_tokens-1) - - #new_M = sorted_token_ids.numel()//top_k_num - #print(f"fused2 m={M}, sort={sorted_token_ids.shape}, pad={pad}, hs={hidden_states.shape}, num_tok={num_tokens}") - #print(f"hs[sort]={torch.repeat_interleave(hidden_states, top_k_num, dim=0)[sorted_token_ids, ...].shape}") new_S = torch.repeat_interleave(hidden_states, top_k_num, dim=0)[sorted_token_ids, ...].shape - #new_top_k = new_S[0] // M - new_M = new_S[0] // top_k_num - #new_M = ((new_M + block_m - 1) // block_m) * block_m - #print(f"fused2 new_M_b={new_M} top_k = {top_k_num}, new_top_k={new_top_k}") - #top_k_num = new_top_k + new_M = new_S[0] - intermediate_cache1 = torch.empty((new_M, top_k_num, N), + intermediate_cache1 = torch.empty((new_M, N), device=hidden_states.device, dtype=hidden_states.dtype) - intermediate_cache2 = torch.empty((new_M * top_k_num, N // 2), + intermediate_cache2 = torch.empty((new_M, N // 2), device=hidden_states.device, dtype=hidden_states.dtype) - intermediate_cache3 = torch.empty((new_M, top_k_num, w2.shape[1]), + intermediate_cache3 = torch.empty((new_M, w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype) else: @@ -1455,8 +1445,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - p("fused inter_out", intermediate_cache1) - if activation == "silu": torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) From ce573fd6c9c1bd8551c9ef21c34bc94a39aa19ba Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 10 Mar 2025 20:52:10 +0000 Subject: [PATCH 077/190] cleanups + format Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 212 ++++++------------ .../layers/fused_moe/fused_moe.py | 45 ++-- 2 files changed, 90 insertions(+), 167 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 1bd4d20f112..3fe432e61b1 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -1,11 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # Adapted from https://github.com/sgl-project/sglang/pull/2575 -# TODO: try/catch this? + import itertools from typing import Tuple -import deep_gemm +dg_available = False +try: + import deep_gemm + dg_available = True +except: + pass + import pytest import torch @@ -28,30 +34,21 @@ D = [512, 4096, 5120, 13824] GROUP_SIZE = [64, 128, 256, 512] M = [1, 7, 8, 83, 84, 512, 2048, 4096] -N = [128, 512, 1024, 4096, 7748, 13824, 7168] +N = [128, 512, 1024, 4096, 7168, 7748, 13824] K = [256, 4096, 5120, 3884, 13824, 16384] # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. M_moe = [1, 2, 7, 83, 128, 512, 2048] M_moe_dg = [128, 512, 2048] -N_moe = [128, 256, 4608] # [128, 4608, 13824] -K_moe = [256, 512, 7168] # [256, 7168, 13824] +N_moe = [128, 256, 4608] # [13824] +K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] -E = [2, 8, 16, 24] +E = [2, 8, 16, 24] # [128, 256] TOP_KS = [1, 2, 6] OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] SEEDS = [0] -def p(s, t): - #print(f"{s}: {t.shape}\n{t}") - pass - -def pp(x): - #print(x) - pass - - def native_per_token_group_quant_fp8(x, group_size, eps=1e-10, @@ -168,48 +165,6 @@ def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) -def torch2_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): - """Fused moe with block-wise quantization using native torch.""" - B, D = a.shape - out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) - - _, block_k = block_shape[0], block_shape[1] - a_q, a_s = native_per_token_group_quant_fp8(a, block_k) - a_q = a_q.to(torch.float32) - - a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - a_s = a_s.view(a_s.shape[0], -1, a_s.shape[1]).repeat(1, topk, 1).reshape(-1, a_s.shape[1]) - - assert topk_ids.numel() == a_q.shape[0] == B * topk - - for i in range(w1.shape[0]): - mask = topk_ids == i - #print(f"sum = {mask.numel()}, {mask.nonzero()}") - if mask.sum(): - inter_out = native_w8a8_block_fp8_matmul(a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - block_shape, - output_dtype=a.dtype) - act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = native_per_token_group_quant_fp8( - act_out, block_k) - act_out = act_out.to(torch.float32) - out[mask] = native_w8a8_block_fp8_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - block_shape, - output_dtype=a.dtype) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) - - # Skip all tests if CUDA is not available pytest.importorskip("torch.cuda") @@ -306,6 +261,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) + # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): out = fused_moe( a, @@ -393,14 +349,13 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): assert rel_diff < 0.001 -# dtype=torch.float8_e4m3fn def fp8_perm(m, idx): - return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) + return m.view(dtype=torch.uint8)[idx, ...].view(dtype=torch.float8_e4m3fn) def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, block_shape): - """Fused moe with block-wise quantization using DeepGemm torch.""" + """Fused moe with block-wise quantization using DeepGemm grouped gemm.""" num_groups = w1.shape[0] M, K = a.shape N = w2.shape[-1] @@ -414,18 +369,17 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, sorted_token_ids, m_indices, num_pad = moe_align_block_size( topk_ids, block_m, num_groups, None) - pp(f"num_pad = {num_pad}, {topk_ids.numel()}, {M*topk}, {M*num_groups}") - num_tokens = topk * M - pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m) - sorted_token_ids.numel() + pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * + block_m) - sorted_token_ids.numel() if pad_size > 0: - sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, (0, pad_size), "constant", num_tokens) + sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, + (0, pad_size), "constant", + num_tokens) - sorted_token_ids = sorted_token_ids.clamp(max=num_tokens-1) + sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) - p("sorted_token_ids", sorted_token_ids) - p("orig m_indices", m_indices) m_indices = torch.repeat_interleave(m_indices, block_m, dim=0) assert sorted_token_ids[sorted_token_ids >= num_tokens].sum() == 0 @@ -436,34 +390,21 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, # Replicate activations and scales a_q = a_q.view(a_q.shape[0], -1, - a_q.shape[1]).repeat(1, topk, - 1).reshape(-1, a_q.shape[1]) + a_q.shape[1]).repeat(1, topk, 1).reshape(-1, a_q.shape[1]) a_s = a_s.view(a_s.shape[0], -1, - a_s.shape[1]).repeat(1, topk, - 1).reshape(-1, a_s.shape[1]) + a_s.shape[1]).repeat(1, topk, 1).reshape(-1, a_s.shape[1]) # Permute activations according to sorted token ids a_q = fp8_perm(a_q, sorted_token_ids) a_s = a_s[sorted_token_ids] - p("topk_ids", topk_ids) - p("sorted", sorted_token_ids) - p("topk_weight", topk_weight) - - p("a_q", a_q) - p("a_s", a_s) - p("m_indices", m_indices) - inter_out = torch.zeros((a_q.shape[0], w1[0].shape[0]), dtype=torch.bfloat16, device=a.device) - #print(f"inter_out {inter_out.shape}") - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), inter_out, m_indices) - act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) @@ -475,54 +416,31 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - out = out[inv_perm,...] - - p("inter_out", inter_out) - - tmp_out = out.view(-1, topk, w2.shape[1])[:M, ...] + out = out[inv_perm, ...] - #print(f"tk {topk_weight.shape}, M={M} topk={topk}, N={w2.shape[1]}, out_C={out.shape}") - #print(f"tmp_out {tmp_out.shape}") + tmp_out = out[:(M * topk), ...].view(-1, topk, w2.shape[1]) final_out = (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - #print(f"final_out {final_out.shape}") - - p("final_out", final_out) - - # TODO use moe_sum + # TODO use moe_sum? return final_out -def iota(shape: Tuple[int, ...], dim: int = 0, **kwargs) -> torch.Tensor: - dimensions = [] - - for index, _ in enumerate(shape): - if index != dim: - dimension = 1 - else: - dimension = shape[index] - - dimensions = [*dimensions, dimension] - - return torch.arange(shape[dim], **kwargs).view(*dimensions).expand(*shape) - -# topk 6 broken/slow @pytest.mark.parametrize( - "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + "M,N,K,E,topk,block_size,dtype,seed,test_baseline", + itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, + SEEDS, [True, False])) +@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, - dtype, seed): + dtype, seed, test_baseline): # only aligned sizes if (M % 128 != 0 or N % 128 != 0 or K % 128 != 0 or topk > E): - pytest.skip(f"Skipping test; invalid size m={M}, n={N}, k={K}, topk={topk}, E={E}") - - pp(f"\nTEST M={M}, N={N}, K={K}, E={E}, topk={topk}, block_size={block_size}, dtype={dtype}") - - torch.set_printoptions(profile="full") + pytest.skip( + f"Skipping test; invalid size m={M}, n={N}, k={K}, topk={topk}, E={E}" + ) vllm_config = VllmConfig() @@ -571,40 +489,36 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_s = w2_sa with set_current_vllm_config(vllm_config): - if True: - ref_out = torch_w8a8_block_fp8_moe( - a, w1, w2, w1_s, w2_s, score, topk, block_size) - - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - allow_deep_gemm=True - ) + if not test_baseline: + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, + topk, block_size) + + out = fused_moe(a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + allow_deep_gemm=True) else: - out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, - score, topk, block_size) - - ref_out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - allow_deep_gemm=True - ) + ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, + score, topk, block_size) + + out = fused_moe(a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + allow_deep_gemm=True) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d0b303ea042..99010bc0c23 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -40,6 +40,7 @@ def p(s, t): #print(f"{s}: {t.shape}\n{t}") pass + def pp(x): #print(x) pass @@ -819,10 +820,15 @@ def get_default_config( else: dg_config = use_deep_gemm and dtype == "fp8_w8a8" config = { - "BLOCK_SIZE_M": 64 if not dg_config else dg.get_m_alignment_for_contiguous_layout(), - "BLOCK_SIZE_N": 64 if not dg_config else 128, - "BLOCK_SIZE_K": 32 if not dg_config else 128, - "GROUP_SIZE_M": 8, + "BLOCK_SIZE_M": + 64 + if not dg_config else dg.get_m_alignment_for_contiguous_layout(), + "BLOCK_SIZE_N": + 64 if not dg_config else 128, + "BLOCK_SIZE_K": + 32 if not dg_config else 128, + "GROUP_SIZE_M": + 8, } return config @@ -1329,14 +1335,15 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) - use_dg = allow_deep_gemm and valid_deep_gemm(hidden_states, w1, w2, config, use_fp8_w8a8) + use_dg = allow_deep_gemm and valid_deep_gemm(hidden_states, w1, w2, config, + use_fp8_w8a8) block_m = config['BLOCK_SIZE_M'] assert not use_dg or block_m == 128 if use_dg: # TODO: how to test chunks? - if False: + if True: num_chunks = 1 CHUNK_SIZE = num_tokens else: @@ -1349,18 +1356,21 @@ def fused_experts_impl(hidden_states: torch.Tensor, w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() - # TODO: this could be smarter - sorted_token_ids, _, pad = ( - moe_align_block_size(topk_ids, block_m, - global_num_experts, expert_map)) + sorted_token_ids, _, _ = (moe_align_block_size(topk_ids, block_m, + global_num_experts, + expert_map)) num_tokens = top_k_num * M - pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m) - sorted_token_ids.numel() + pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * + block_m) - sorted_token_ids.numel() if pad_size > 0: - sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, (0, pad_size), "constant", num_tokens) - sorted_token_ids = sorted_token_ids.clamp(max=num_tokens-1) - new_S = torch.repeat_interleave(hidden_states, top_k_num, dim=0)[sorted_token_ids, ...].shape + sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, + (0, pad_size), + "constant", num_tokens) + sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) + new_S = torch.repeat_interleave(hidden_states, top_k_num, + dim=0)[sorted_token_ids, ...].shape new_M = new_S[0] intermediate_cache1 = torch.empty((new_M, N), @@ -1399,7 +1409,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, break if tokens_in_chunk < CHUNK_SIZE and chunk > 0: - assert False # for now + assert not use_dg # for now # Adjust the intermediate cache size and config for the last # chunk. Note that in most cases we only have one chunk # so the cache size and config are already set correctly and @@ -1421,8 +1431,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, block_shape=block_shape) sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, block_m, - global_num_experts, expert_map)) + moe_align_block_size(curr_topk_ids, block_m, global_num_experts, + expert_map)) invoke_fused_moe_kernel(qcurr_hidden_states, w1, @@ -1484,7 +1494,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx]) - return out_hidden_states From b402f573150b9564da618fe203ec947a9308933e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 10 Mar 2025 22:52:59 +0000 Subject: [PATCH 078/190] do more of output computation in place Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 99010bc0c23..6fbf2542ebd 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1524,7 +1524,7 @@ def fused_moe( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, - allow_deep_gemm: bool = False, + allow_deep_gemm: bool = True, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of From 4494c1746b274c01c56322e1839f0a61641cb406 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 10 Mar 2025 22:57:19 +0000 Subject: [PATCH 079/190] add env var Signed-off-by: Bill Nell --- .../model_executor/layers/fused_moe/fused_moe.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 6fbf2542ebd..3beeae16b18 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -28,24 +28,15 @@ logger = init_logger(__name__) use_deep_gemm = False -if True or envs.VLLM_USE_DEEP_GEMM: +if envs.VLLM_USE_DEEP_GEMM: try: import deep_gemm as dg + logger.info("Using DeepGemm for fused MoE.") use_deep_gemm = True except ImportError: logger.warning("Failed to import DeepGemm kernels.") -def p(s, t): - #print(f"{s}: {t.shape}\n{t}") - pass - - -def pp(x): - #print(x) - pass - - @triton.jit def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, @@ -1395,9 +1386,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, num_chunks = (num_tokens // CHUNK_SIZE) + 1 - if num_chunks > 1: - print("CHUNKS!!!!!!!!!!!!!!!!!!") - for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, From afee1df6903259f706b719184ce44dc53e36c392 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 12 Mar 2025 04:23:27 +0000 Subject: [PATCH 080/190] formatting, remove some blocking restrictions Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 24 +++++++++---------- .../layers/fused_moe/fused_moe.py | 15 ++++++------ .../compressed_tensors_moe.py | 2 ++ .../model_executor/layers/quantization/fp8.py | 1 + 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 3fe432e61b1..9e6bfc4018e 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -1,17 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # Adapted from https://github.com/sgl-project/sglang/pull/2575 - import itertools from typing import Tuple -dg_available = False -try: - import deep_gemm - dg_available = True -except: - pass - import pytest import torch @@ -24,6 +16,13 @@ per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform +dg_available = False +try: + import deep_gemm + dg_available = True +except ImportError: + pass + if current_platform.get_device_capability() < (9, 0): pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) @@ -39,7 +38,7 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. M_moe = [1, 2, 7, 83, 128, 512, 2048] -M_moe_dg = [128, 512, 2048] +M_moe_dg = [128, 192, 512, 2048] N_moe = [128, 256, 4608] # [13824] K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] @@ -358,7 +357,6 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, """Fused moe with block-wise quantization using DeepGemm grouped gemm.""" num_groups = w1.shape[0] M, K = a.shape - N = w2.shape[-1] topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) @@ -437,10 +435,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed, test_baseline): # only aligned sizes - if (M % 128 != 0 or N % 128 != 0 or K % 128 != 0 or topk > E): + if ((M % 128 != 0 and test_baseline) or N % 128 != 0 or K % 128 != 0 + or topk > E): pytest.skip( - f"Skipping test; invalid size m={M}, n={N}, k={K}, topk={topk}, E={E}" - ) + f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") vllm_config = VllmConfig() diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 3beeae16b18..19f95638ecc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1326,24 +1326,23 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) - use_dg = allow_deep_gemm and valid_deep_gemm(hidden_states, w1, w2, config, + use_dg = allow_deep_gemm and valid_deep_gemm(hidden_states, w1, w2, use_fp8_w8a8) block_m = config['BLOCK_SIZE_M'] assert not use_dg or block_m == 128 if use_dg: - # TODO: how to test chunks? - if True: - num_chunks = 1 - CHUNK_SIZE = num_tokens - else: - num_chunks = (num_tokens // CHUNK_SIZE) + 1 + if M % 128 != 0: + CHUNK_SIZE = (M // 128) * 128 + num_chunks = (num_tokens // CHUNK_SIZE) + 1 assert w1_scale is not None assert w2_scale is not None - # TODO: do this offline + # We attempt to do this offline in Fp8MoEMethod, in which case these + # calls will be nops. Otherwise, they'll be performed every time the + # layer is executed. w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index ae16a20cfaa..ec7e93b754c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -275,6 +275,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: from vllm.model_executor.layers.fused_moe import fused_experts self.fused_experts_func = fused_experts + # TODO: do we need to do deep gemm alignment here? + def apply( self, layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f7056016fe8..4549bee01ef 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -428,6 +428,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None + self.allow_deep_gemm = use_deep_gemm # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization From 6872005f17c5794362f36bc909c5fddffb19eeda Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 12 Mar 2025 12:59:57 +0000 Subject: [PATCH 081/190] wip Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 2 +- vllm/model_executor/layers/fused_moe/fused_moe.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 9e6bfc4018e..036baace5d7 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -38,7 +38,7 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. M_moe = [1, 2, 7, 83, 128, 512, 2048] -M_moe_dg = [128, 192, 512, 2048] +M_moe_dg = [128, 512, 2048] #192 N_moe = [128, 256, 4608] # [13824] K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 19f95638ecc..16b429a010e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1333,6 +1333,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, assert not use_dg or block_m == 128 if use_dg: + #print("USE_DG") if M % 128 != 0: CHUNK_SIZE = (M // 128) * 128 num_chunks = (num_tokens // CHUNK_SIZE) + 1 @@ -1347,11 +1348,11 @@ def fused_experts_impl(hidden_states: torch.Tensor, w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() # TODO: this could be smarter + num_tokens = top_k_num * M sorted_token_ids, _, _ = (moe_align_block_size(topk_ids, block_m, global_num_experts, expert_map)) - num_tokens = top_k_num * M pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m) - sorted_token_ids.numel() if pad_size > 0: @@ -1361,6 +1362,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) new_S = torch.repeat_interleave(hidden_states, top_k_num, dim=0)[sorted_token_ids, ...].shape + #new_M = hidden_states.shape[0] * top_k_num * global_num_experts new_M = new_S[0] intermediate_cache1 = torch.empty((new_M, N), From 98b3256ce589f4fb50a4c94e7da7e7004ab33474 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 12 Mar 2025 13:19:08 +0000 Subject: [PATCH 082/190] fix resizing of output Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_moe.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 16b429a010e..fe6d85d0522 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -27,14 +27,12 @@ logger = init_logger(__name__) -use_deep_gemm = False -if envs.VLLM_USE_DEEP_GEMM: - try: - import deep_gemm as dg - logger.info("Using DeepGemm for fused MoE.") - use_deep_gemm = True - except ImportError: - logger.warning("Failed to import DeepGemm kernels.") +has_deep_gemm = False +try: + import deep_gemm as dg + has_deep_gemm = True +except ImportError: + pass @triton.jit @@ -767,6 +765,7 @@ def get_default_config( dtype: Optional[str], is_marlin: bool, block_shape: Optional[List[int]] = None, + use_deep_gemm: bool = False, ) -> Dict[str, int]: if dtype == "fp8_w8a8" and block_shape is not None: # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] @@ -832,6 +831,7 @@ def try_get_optimal_moe_config( M: int, is_marlin: bool = False, block_shape: Optional[List[int]] = None, + use_deep_gemm: bool = False, ): from vllm.model_executor.layers.fused_moe import get_config override_config = get_config() @@ -853,7 +853,7 @@ def try_get_optimal_moe_config( else: # Else use the default config config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, - is_marlin, block_shape) + is_marlin, block_shape, use_deep_gemm) return config @@ -1291,6 +1291,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, top_k_num, config_dtype, block_shape=block_shape, + use_deep_gemm=has_deep_gemm and allow_deep_gemm, # hacky ) config = get_config_func(M) @@ -1333,8 +1334,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, assert not use_dg or block_m == 128 if use_dg: - #print("USE_DG") - if M % 128 != 0: + if False and M % 128 != 0: CHUNK_SIZE = (M // 128) * 128 num_chunks = (num_tokens // CHUNK_SIZE) + 1 From 9060108ea8178dcb74f6e80f81b41f8d66427870 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 12 Mar 2025 13:19:41 +0000 Subject: [PATCH 083/190] fix resizing of output Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 2 +- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 036baace5d7..9e6bfc4018e 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -38,7 +38,7 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. M_moe = [1, 2, 7, 83, 128, 512, 2048] -M_moe_dg = [128, 512, 2048] #192 +M_moe_dg = [128, 192, 512, 2048] N_moe = [128, 256, 4608] # [13824] K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index fe6d85d0522..7b7a3f1f48d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1334,7 +1334,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, assert not use_dg or block_m == 128 if use_dg: - if False and M % 128 != 0: + if M % 128 != 0: CHUNK_SIZE = (M // 128) * 128 num_chunks = (num_tokens // CHUNK_SIZE) + 1 From 411fc7ada88c83e4e0730dfcb4ff6589b28b2d93 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 12 Mar 2025 13:32:40 +0000 Subject: [PATCH 084/190] fixes Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 2 +- .../layers/fused_moe/fused_moe.py | 18 +++++------------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 9e6bfc4018e..4855fdb6995 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -38,7 +38,7 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. M_moe = [1, 2, 7, 83, 128, 512, 2048] -M_moe_dg = [128, 192, 512, 2048] +M_moe_dg = [128, 512, 2048] # 192 N_moe = [128, 256, 4608] # [13824] K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 7b7a3f1f48d..ce0ab20a316 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1347,23 +1347,12 @@ def fused_experts_impl(hidden_states: torch.Tensor, w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() - # TODO: this could be smarter - num_tokens = top_k_num * M + # TODO: computing new_M could be smarter sorted_token_ids, _, _ = (moe_align_block_size(topk_ids, block_m, global_num_experts, expert_map)) - pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * - block_m) - sorted_token_ids.numel() - if pad_size > 0: - sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, - (0, pad_size), - "constant", num_tokens) - sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) - new_S = torch.repeat_interleave(hidden_states, top_k_num, - dim=0)[sorted_token_ids, ...].shape - #new_M = hidden_states.shape[0] * top_k_num * global_num_experts - new_M = new_S[0] + new_M = ((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m intermediate_cache1 = torch.empty((new_M, N), device=hidden_states.device, @@ -1387,6 +1376,9 @@ def fused_experts_impl(hidden_states: torch.Tensor, num_chunks = (num_tokens // CHUNK_SIZE) + 1 + # TODO: modify CHUNK_SIZE to be % 128 == 0 and check if each chunk is + # valid dg. fall back to old kernel if not + for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, From b773fdc6a70f8d74856928ccd5f51f8082d3bec4 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 12 Mar 2025 22:29:32 +0000 Subject: [PATCH 085/190] aligned chunking working for deep gemm Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 9 ++--- .../layers/fused_moe/fused_moe.py | 34 ++++++++++++------- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 4855fdb6995..599909a7056 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -427,15 +427,16 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed,test_baseline", - itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, - SEEDS, [True, False])) + itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS, [True, False])) + #itertools.product([254],[128],[256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS, [True])) + #itertools.product([512],[128],[256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS, [True, False])) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed, test_baseline): # only aligned sizes - if ((M % 128 != 0 and test_baseline) or N % 128 != 0 or K % 128 != 0 + if ((M % 128 != 0 and not test_baseline) or N % 128 != 0 or K % 128 != 0 or topk > E): pytest.skip( f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") @@ -487,7 +488,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_s = w2_sa with set_current_vllm_config(vllm_config): - if not test_baseline: + if test_baseline: ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index ce0ab20a316..c934fa55df9 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1333,10 +1333,16 @@ def fused_experts_impl(hidden_states: torch.Tensor, block_m = config['BLOCK_SIZE_M'] assert not use_dg or block_m == 128 + chunked_dg = False if use_dg: + #print("USE_DG") + #CHUNK_SIZE = 128 if M % 128 != 0: CHUNK_SIZE = (M // 128) * 128 + #print(f"DG_CHUNK {CHUNK_SIZE}") + num_chunks = (num_tokens // CHUNK_SIZE) + 1 + chunked_dg = num_chunks > 1 assert w1_scale is not None assert w2_scale is not None @@ -1386,20 +1392,12 @@ def fused_experts_impl(hidden_states: torch.Tensor, curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] tokens_in_chunk, _ = curr_hidden_states.shape + skip_dg = tokens_in_chunk % 128 != 0 + if tokens_in_chunk == 0: break - if tokens_in_chunk < CHUNK_SIZE and chunk > 0: - assert not use_dg # for now - # Adjust the intermediate cache size and config for the last - # chunk. Note that in most cases we only have one chunk - # so the cache size and config are already set correctly and - # do not need to be adjusted. - intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] - intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * - topk_ids.shape[1]] - intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] - config = get_config_func(tokens_in_chunk) + #print(f"LOOP skip={skip_dg} tic={tokens_in_chunk}, chunk={chunk}") curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] @@ -1473,8 +1471,20 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), + if use_dg and not skip_dg: + assert inv_perm is not None + M = curr_topk_weights.shape[0] + out_C = intermediate_cache3[inv_perm, ...] + out_C = out_C[:(M * top_k_num), ...] + out_C = out_C.view(-1, top_k_num, w2.shape[1]) + out_C.mul_(curr_topk_weights.view(M, -1, 1)) + tmp_cache3 = out_C + else: + tmp_cache3 = intermediate_cache3.view(*intermediate_cache3.shape) + + ops.moe_sum(tmp_cache3, out_hidden_states[begin_chunk_idx:end_chunk_idx]) + return out_hidden_states From ae2c7916d5a016999c006aded02496f738053a48 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 13 Mar 2025 00:03:55 +0000 Subject: [PATCH 086/190] unaligned chunking for deep gemm Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 4 +--- vllm/model_executor/layers/fused_moe/fused_moe.py | 7 +------ 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 599909a7056..b4fe9135a88 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -38,7 +38,7 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. M_moe = [1, 2, 7, 83, 128, 512, 2048] -M_moe_dg = [128, 512, 2048] # 192 +M_moe_dg = [128, 192, 512, 1335, 2048] N_moe = [128, 256, 4608] # [13824] K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] @@ -428,8 +428,6 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed,test_baseline", itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS, [True, False])) - #itertools.product([254],[128],[256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS, [True])) - #itertools.product([512],[128],[256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS, [True, False])) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index c934fa55df9..e3140af1b6f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1335,11 +1335,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, chunked_dg = False if use_dg: - #print("USE_DG") - #CHUNK_SIZE = 128 if M % 128 != 0: - CHUNK_SIZE = (M // 128) * 128 - #print(f"DG_CHUNK {CHUNK_SIZE}") + CHUNK_SIZE = (M // 128) * 128 # min with env? num_chunks = (num_tokens // CHUNK_SIZE) + 1 chunked_dg = num_chunks > 1 @@ -1397,8 +1394,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, if tokens_in_chunk == 0: break - #print(f"LOOP skip={skip_dg} tic={tokens_in_chunk}, chunk={chunk}") - curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] From 269f18b9a2b8b97320b5e60e8549d689b3b6ccd2 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 13 Mar 2025 17:17:27 +0000 Subject: [PATCH 087/190] cleanup wip Signed-off-by: Bill Nell --- requirements/test.txt | 6 +++ tests/kernels/test_block_fp8.py | 3 +- .../layers/fused_moe/fused_moe.py | 41 +++++++++---------- 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index 9a15d9a0d82..60b8faa0fa2 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -126,6 +126,10 @@ encodec==0.1.1 # via vocos evaluate==0.4.3 # via lm-eval +exceptiongroup==1.2.2 + # via + # anyio + # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -759,9 +763,11 @@ typing-extensions==4.12.2 # huggingface-hub # librosa # mistral-common + # multidict # pqdm # pydantic # pydantic-core + # rich # torch # typer tzdata==2024.2 diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index b4fe9135a88..9ba3a105cc5 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -427,7 +427,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed,test_baseline", - itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS, [True, False])) + #itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS, [True, False])) + itertools.product([192], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS, [True, False])) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index e3140af1b6f..07ba51f0c89 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -520,7 +520,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor, grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( B.shape[1], META['BLOCK_SIZE_N']), ) - if (use_int8_w8a16 or use_int4_w4a16) and \ + if use_dg: + # Note: we do not apply weights here since it requires + # resizing the output. + dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (A, A_scale), (B, B_scale), C, expert_ids) + + elif (use_int8_w8a16 or use_int4_w4a16) and \ block_shape is not None and block_shape[1] > 0: assert B_scale is not None and B_scale.ndim == 3 assert B_zp is None or B_zp.ndim == 3 @@ -808,17 +814,11 @@ def get_default_config( "GROUP_SIZE_M": 1, } else: - dg_config = use_deep_gemm and dtype == "fp8_w8a8" config = { - "BLOCK_SIZE_M": - 64 - if not dg_config else dg.get_m_alignment_for_contiguous_layout(), - "BLOCK_SIZE_N": - 64 if not dg_config else 128, - "BLOCK_SIZE_K": - 32 if not dg_config else 128, - "GROUP_SIZE_M": - 8, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, } return config @@ -831,7 +831,6 @@ def try_get_optimal_moe_config( M: int, is_marlin: bool = False, block_shape: Optional[List[int]] = None, - use_deep_gemm: bool = False, ): from vllm.model_executor.layers.fused_moe import get_config override_config = get_config() @@ -853,7 +852,7 @@ def try_get_optimal_moe_config( else: # Else use the default config config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, - is_marlin, block_shape, use_deep_gemm) + is_marlin, block_shape) return config @@ -1291,7 +1290,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, top_k_num, config_dtype, block_shape=block_shape, - use_deep_gemm=has_deep_gemm and allow_deep_gemm, # hacky ) config = get_config_func(M) @@ -1330,13 +1328,15 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_dg = allow_deep_gemm and valid_deep_gemm(hidden_states, w1, w2, use_fp8_w8a8) - block_m = config['BLOCK_SIZE_M'] - assert not use_dg or block_m == 128 + config_block_m = config['BLOCK_SIZE_M'] + block_m = config_block_m if not use_dg else dg.get_m_alignment_for_contiguous_layout() + + assert not use_dg or block_m == dg.get_m_alignment_for_contiguous_layout() chunked_dg = False if use_dg: - if M % 128 != 0: - CHUNK_SIZE = (M // 128) * 128 # min with env? + if M % block_m != 0: + CHUNK_SIZE = min((M //block_m) * block_m, CHUNK_SIZE) num_chunks = (num_tokens // CHUNK_SIZE) + 1 chunked_dg = num_chunks > 1 @@ -1379,9 +1379,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, num_chunks = (num_tokens // CHUNK_SIZE) + 1 - # TODO: modify CHUNK_SIZE to be % 128 == 0 and check if each chunk is - # valid dg. fall back to old kernel if not - for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, @@ -1389,7 +1386,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] tokens_in_chunk, _ = curr_hidden_states.shape - skip_dg = tokens_in_chunk % 128 != 0 + skip_dg = use_dg and tokens_in_chunk % 128 != 0 #block_m != 0 if tokens_in_chunk == 0: break From 0ea9a5d7239d1fce1d2f5ec38ee27535a114964e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 13 Mar 2025 20:51:59 +0000 Subject: [PATCH 088/190] clean up some blocking stuff Signed-off-by: Bill Nell --- requirements/test.txt | 6 -- tests/kernels/test_block_fp8.py | 67 +++++++++---------- .../layers/fused_moe/fused_moe.py | 19 +++--- 3 files changed, 42 insertions(+), 50 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index 60b8faa0fa2..9a15d9a0d82 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -126,10 +126,6 @@ encodec==0.1.1 # via vocos evaluate==0.4.3 # via lm-eval -exceptiongroup==1.2.2 - # via - # anyio - # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -763,11 +759,9 @@ typing-extensions==4.12.2 # huggingface-hub # librosa # mistral-common - # multidict # pqdm # pydantic # pydantic-core - # rich # torch # typer tzdata==2024.2 diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 9ba3a105cc5..ec16ac30a77 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -38,7 +38,7 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. M_moe = [1, 2, 7, 83, 128, 512, 2048] -M_moe_dg = [128, 192, 512, 1335, 2048] +M_moe_dg = [1, 128, 192, 512, 1335, 2048] N_moe = [128, 256, 4608] # [13824] K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] @@ -426,17 +426,16 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, @pytest.mark.parametrize( - "M,N,K,E,topk,block_size,dtype,seed,test_baseline", - #itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS, [True, False])) - itertools.product([192], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS, [True, False])) + "M,N,K,E,topk,block_size,dtype,seed", + itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) + #itertools.product([192], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, - dtype, seed, test_baseline): + dtype, seed): # only aligned sizes - if ((M % 128 != 0 and not test_baseline) or N % 128 != 0 or K % 128 != 0 - or topk > E): + if (N % 128 != 0 or K % 128 != 0 or topk > E): pytest.skip( f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") @@ -487,36 +486,26 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_s = w2_sa with set_current_vllm_config(vllm_config): - if test_baseline: - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, - topk, block_size) - - out = fused_moe(a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - allow_deep_gemm=True) + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, + topk, block_size) + + if M % 128 == 0: + ref_out2 = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, + score, topk, block_size) else: - ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, - score, topk, block_size) - - out = fused_moe(a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - allow_deep_gemm=True) + ref_out2 = None + + out = fused_moe(a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + allow_deep_gemm=True) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") @@ -525,3 +514,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.03 + + if ref_out2 is not None: + rel_diff = (torch.mean( + torch.abs(out.to(torch.float32) - ref_out2.to(torch.float32))) / + torch.mean(torch.abs(ref_out2.to(torch.float32)))) + assert rel_diff < 0.03 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 07ba51f0c89..a546ada6476 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -521,6 +521,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B.shape[1], META['BLOCK_SIZE_N']), ) if use_dg: + assert use_fp8_w8a8 # Note: we do not apply weights here since it requires # resizing the output. dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( @@ -771,7 +772,6 @@ def get_default_config( dtype: Optional[str], is_marlin: bool, block_shape: Optional[List[int]] = None, - use_deep_gemm: bool = False, ) -> Dict[str, int]: if dtype == "fp8_w8a8" and block_shape is not None: # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] @@ -831,6 +831,7 @@ def try_get_optimal_moe_config( M: int, is_marlin: bool = False, block_shape: Optional[List[int]] = None, + use_deep_gemm: bool = False, ): from vllm.model_executor.layers.fused_moe import get_config override_config = get_config() @@ -853,6 +854,12 @@ def try_get_optimal_moe_config( # Else use the default config config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape) + + + # Remove this + if use_deep_gemm: + config['BLOCK_SIZE_M'] = 128 + return config @@ -1325,12 +1332,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) - use_dg = allow_deep_gemm and valid_deep_gemm(hidden_states, w1, w2, - use_fp8_w8a8) - - config_block_m = config['BLOCK_SIZE_M'] - block_m = config_block_m if not use_dg else dg.get_m_alignment_for_contiguous_layout() - + block_m = config['BLOCK_SIZE_M'] assert not use_dg or block_m == dg.get_m_alignment_for_contiguous_layout() chunked_dg = False @@ -1379,6 +1381,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, num_chunks = (num_tokens // CHUNK_SIZE) + 1 + for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, @@ -1386,7 +1389,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] tokens_in_chunk, _ = curr_hidden_states.shape - skip_dg = use_dg and tokens_in_chunk % 128 != 0 #block_m != 0 + skip_dg = use_dg and tokens_in_chunk % block_m != 0 if tokens_in_chunk == 0: break From 52f53b3943dda22e820353b1216418285c546696 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 13 Mar 2025 20:58:38 +0000 Subject: [PATCH 089/190] clean up some blocking stuff Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a546ada6476..9d757f63c47 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -855,10 +855,9 @@ def try_get_optimal_moe_config( config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape) - - # Remove this + # Try to remove this if use_deep_gemm: - config['BLOCK_SIZE_M'] = 128 + config['BLOCK_SIZE_M'] = dg.get_m_alignment_for_contiguous_layout() return config @@ -1389,11 +1388,11 @@ def fused_experts_impl(hidden_states: torch.Tensor, curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] tokens_in_chunk, _ = curr_hidden_states.shape - skip_dg = use_dg and tokens_in_chunk % block_m != 0 - if tokens_in_chunk == 0: break + skip_dg = use_dg and tokens_in_chunk % block_m != 0 + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] From de0135c45460103939ece26e49632fb1dbeabb70 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 14 Mar 2025 23:40:03 +0000 Subject: [PATCH 090/190] tweaks Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 18 ++++++++++-------- .../layers/fused_moe/fused_moe.py | 3 +-- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index ec16ac30a77..11d35ec345a 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -427,17 +427,18 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, @pytest.mark.parametrize( "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS)) - #itertools.product([192], [128], [256], [2], [1], BLOCK_SIZE, DTYPES, SEEDS)) + itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, + SEEDS)) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): # only aligned sizes - if (N % 128 != 0 or K % 128 != 0 or topk > E): + if (N % 128 != 0 or K % 128 != 0 or topk > E or block_size != [128, 128]): pytest.skip( - f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") + f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}, " + f"block_size={block_size}") vllm_config = VllmConfig() @@ -486,12 +487,13 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_s = w2_sa with set_current_vllm_config(vllm_config): - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, - topk, block_size) + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) if M % 128 == 0: - ref_out2 = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, - score, topk, block_size) + ref_out2 = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, + w2_s, score, topk, + block_size) else: ref_out2 = None diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 9d757f63c47..c7447a74c48 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1337,7 +1337,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, chunked_dg = False if use_dg: if M % block_m != 0: - CHUNK_SIZE = min((M //block_m) * block_m, CHUNK_SIZE) + CHUNK_SIZE = min((M // block_m) * block_m, CHUNK_SIZE) num_chunks = (num_tokens // CHUNK_SIZE) + 1 chunked_dg = num_chunks > 1 @@ -1380,7 +1380,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, num_chunks = (num_tokens // CHUNK_SIZE) + 1 - for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, From a4a0719f055ce02847c1e15e56f52b7d5888b5cf Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 15 Mar 2025 00:02:28 +0000 Subject: [PATCH 091/190] fix rebase Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_moe.py | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index c7447a74c48..f205eb515f3 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1311,9 +1311,9 @@ def fused_experts_impl(hidden_states: torch.Tensor, intermediate_cache3 = cache13[:M * top_k_num * K].view(M, top_k_num, K) # This needs separate memory since it's used concurrently with cache1 - intermediate_cache2 = torch.empty((M * top_k_num, N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) + #intermediate_cache2 = torch.empty((M * top_k_num, N // 2), + # device=hidden_states.device, + # dtype=hidden_states.dtype) # XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX @@ -1358,25 +1358,31 @@ def fused_experts_impl(hidden_states: torch.Tensor, new_M = ((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m - intermediate_cache1 = torch.empty((new_M, N), - device=hidden_states.device, - dtype=hidden_states.dtype) + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + cache13 = torch.empty(new_M * max(N, w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype) + + intermediate_cache1 = cache13[:(new_M * N)].view(new_M, N) intermediate_cache2 = torch.empty((new_M, N // 2), device=hidden_states.device, dtype=hidden_states.dtype) - intermediate_cache3 = torch.empty((new_M, w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype) + intermediate_cache3 = cache13[:(new_M * w2.shape[1])].view(new_M, w2.shape[1]) else: - intermediate_cache1 = torch.empty((M, top_k_num, N), - device=hidden_states.device, - dtype=hidden_states.dtype) + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype) + + intermediate_cache1 = cache13[:M * top_k_num * N].view( + (M, topk_ids.shape[1], N)) intermediate_cache2 = torch.empty((M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype) - intermediate_cache3 = torch.empty((M, top_k_num, w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype) + intermediate_cache3 = cache13[:M * top_k_num * w2.shape[1]].view( + (M, topk_ids.shape[1], w2.shape[1])) num_chunks = (num_tokens // CHUNK_SIZE) + 1 From 58c733b03342870d985667a2b81a7b1e31557058 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 17 Mar 2025 16:15:15 +0000 Subject: [PATCH 092/190] rebase Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 3 +-- .../layers/fused_moe/fused_moe.py | 20 ++----------------- 2 files changed, 3 insertions(+), 20 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 11d35ec345a..d787eb0044a 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -2,7 +2,6 @@ # Adapted from https://github.com/sgl-project/sglang/pull/2575 import itertools -from typing import Tuple import pytest import torch @@ -288,7 +287,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): def per_block_cast_to_fp8( x: torch.Tensor, - block_size_n: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape x_padded = torch.zeros( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f205eb515f3..cca91a8d20c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1300,23 +1300,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, config = get_config_func(M) - # XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX - - # We can reuse the memory between these because by the time we need - # cache3, we're done with cache1 - cache13 = torch.empty(M * top_k_num * max(N, K), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N) - intermediate_cache3 = cache13[:M * top_k_num * K].view(M, top_k_num, K) - - # This needs separate memory since it's used concurrently with cache1 - #intermediate_cache2 = torch.empty((M * top_k_num, N // 2), - # device=hidden_states.device, - # dtype=hidden_states.dtype) - - # XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX - if hidden_states.dtype == torch.bfloat16: compute_type = tl.bfloat16 elif hidden_states.dtype == torch.float16: @@ -1368,7 +1351,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, intermediate_cache2 = torch.empty((new_M, N // 2), device=hidden_states.device, dtype=hidden_states.dtype) - intermediate_cache3 = cache13[:(new_M * w2.shape[1])].view(new_M, w2.shape[1]) + intermediate_cache3 = cache13[:(new_M * w2.shape[1])].view( + new_M, w2.shape[1]) else: # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 From 3f397b032b7d8a47e9bbf0c78a674df7f1adf324 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 21 Mar 2025 21:52:56 +0000 Subject: [PATCH 093/190] refactoring + minor perf improvements Signed-off-by: Bill Nell --- benchmarks/kernels/benchmark_moe.py | 26 ++-- tests/kernels/test_block_fp8.py | 118 ++++++++++++------ .../layers/fused_moe/fused_moe.py | 24 ++-- 3 files changed, 107 insertions(+), 61 deletions(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 1884a80a407..6e6b29b4c75 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -30,18 +30,20 @@ class BenchmarkConfig(TypedDict): num_stages: int -def benchmark_config(config: BenchmarkConfig, - num_tokens: int, - num_experts: int, - shard_intermediate_size: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - num_iters: int = 100, - block_quant_shape: List[int] = None, - use_deep_gemm: bool = False) -> float: +def benchmark_config( + config: BenchmarkConfig, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + block_quant_shape: List[int] = None, + use_deep_gemm: bool = False +) -> float: init_dtype = torch.float16 if use_fp8_w8a8 else dtype x = torch.randn(num_tokens, hidden_size, dtype=dtype) if use_int8_w8a16: diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index d787eb0044a..f75f2f2f5f5 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -228,6 +228,9 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): SEEDS)) @torch.inference_mode() def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): + if topk > E: + pytest.skip(f"Skipping test; topk={K} > E={E}") + torch.manual_seed(seed) factor_for_scale = 1e-2 fp8_info = torch.finfo(torch.float8_e4m3fn) @@ -276,8 +279,8 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) - print(f"{out.sum()=}") - print(f"{ref_out.sum()=}") + #print(f"{out.sum()=}") + #print(f"{ref_out.sum()=}") rel_diff = (torch.mean( torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / @@ -348,21 +351,16 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): def fp8_perm(m, idx): - return m.view(dtype=torch.uint8)[idx, ...].view(dtype=torch.float8_e4m3fn) + if m.dtype == torch.float8_e4m3fn: + return m.view(dtype=torch.uint8)[idx, + ...].view(dtype=torch.float8_e4m3fn) + else: + return m[idx, ...] -def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, - block_shape): - """Fused moe with block-wise quantization using DeepGemm grouped gemm.""" - num_groups = w1.shape[0] +def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): M, K = a.shape - topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) - - block_m = deep_gemm.get_m_alignment_for_contiguous_layout() - - _, block_k = block_shape[0], block_shape[1] - sorted_token_ids, m_indices, num_pad = moe_align_block_size( topk_ids, block_m, num_groups, None) @@ -381,19 +379,54 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, assert sorted_token_ids[sorted_token_ids >= num_tokens].sum() == 0 - inv_perm = torch.argsort(sorted_token_ids) + #print(f"sti {sorted_token_ids}") + + inv_perm = torch.argsort(sorted_token_ids)[:M*topk] + + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + a = fp8_perm(a, sorted_token_ids) + + if a_s is not None: + a_s = a_s.view(M, -1, K // 128).repeat(1, topk, + 1).reshape(-1, K // 128) + a_s = a_s[sorted_token_ids] + + return a, a_s, m_indices, inv_perm + + +def test_moe_unpermute(out, inv_perm, m_indices, topk, num_groups, M, K, + topk_weight, topk_ids): + # TODO use moe_sum? + out = out[inv_perm, ...] + tmp_out = out.view(-1, topk, K) + return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + - a_q, a_s = per_token_group_quant_fp8(a, block_m) +def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, + block_shape): + """Fused moe with block-wise quantization using DeepGemm grouped gemm.""" + num_groups = w1.shape[0] + M, K = a.shape - # Replicate activations and scales - a_q = a_q.view(a_q.shape[0], -1, - a_q.shape[1]).repeat(1, topk, 1).reshape(-1, a_q.shape[1]) - a_s = a_s.view(a_s.shape[0], -1, - a_s.shape[1]).repeat(1, topk, 1).reshape(-1, a_s.shape[1]) + topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) + + block_m = deep_gemm.get_m_alignment_for_contiguous_layout() + + _, block_k = block_shape[0], block_shape[1] - # Permute activations according to sorted token ids - a_q = fp8_perm(a_q, sorted_token_ids) - a_s = a_s[sorted_token_ids] + if False: + # quantize before permute + a_q, a_s = per_token_group_quant_fp8(a, block_m) + a_q, a_s, m_indices, inv_perm = test_moe_permute( + a_q, a_s, topk_ids, num_groups, topk, block_m) + else: + # quantize after permute + a_q, a_s, m_indices, inv_perm = test_moe_permute( + a, None, topk_ids, num_groups, topk, block_m) + a_q, a_s = per_token_group_quant_fp8(a_q, block_m) + + # Fix this assert + #assert a_s.shape[1] == K // 128 and a_q.shape[0] == a_s.shape[0] == M * topk inter_out = torch.zeros((a_q.shape[0], w1[0].shape[0]), dtype=torch.bfloat16, @@ -413,13 +446,22 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - out = out[inv_perm, ...] + if True: + final_out = test_moe_unpermute(out, inv_perm, m_indices, topk, + num_groups, M, K, topk_weight, topk_ids) + else: + m_indices = torch.arange(0, + M * (topk + 1), + block_m, + dtype=torch.int, + device=out.device) - tmp_out = out[:(M * topk), ...].view(-1, topk, w2.shape[1]) + print(f"inv_perm {inv_perm}") + print(f"inv_perm[:m*topk] {inv_perm[:M*topk]}") - final_out = (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - - # TODO use moe_sum? + final_out = test_moe_unpermute_op(out, inv_perm, m_indices, topk, + num_groups, M, K, topk_weight, + topk_ids) return final_out @@ -489,13 +531,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) - if M % 128 == 0: - ref_out2 = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, - w2_s, score, topk, - block_size) - else: - ref_out2 = None - out = fused_moe(a, w1, w2, @@ -508,6 +543,13 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, allow_deep_gemm=True) + if M % 128 == 0: + out2 = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, + w2_s, score, topk, + block_size) + else: + out2 = None + #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") @@ -516,8 +558,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.03 - if ref_out2 is not None: + if out2 is not None: rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out2.to(torch.float32))) / - torch.mean(torch.abs(ref_out2.to(torch.float32)))) + torch.abs(ref_out.to(torch.float32) - out2.to(torch.float32))) / + torch.mean(torch.abs(out2.to(torch.float32)))) assert rel_diff < 0.03 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index cca91a8d20c..db16d12476c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1331,6 +1331,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, # We attempt to do this offline in Fp8MoEMethod, in which case these # calls will be nops. Otherwise, they'll be performed every time the # layer is executed. + print(f"SHAPES {w1_scale.shape}, {w2_scale.shape}") w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() @@ -1454,19 +1455,20 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) + # is this correct in the loop? TODO: fold in moe_sum? if use_dg and not skip_dg: - assert inv_perm is not None - M = curr_topk_weights.shape[0] - out_C = intermediate_cache3[inv_perm, ...] - out_C = out_C[:(M * top_k_num), ...] - out_C = out_C.view(-1, top_k_num, w2.shape[1]) - out_C.mul_(curr_topk_weights.view(M, -1, 1)) - tmp_cache3 = out_C + _moe_unpermute(out_hidden_states[begin_chunk_idx:end_chunk_idx], + intermediate_cache3, + inv_perm, + expert_ids, + top_k_num, + global_num_experts, + w2.shape[1], + curr_topk_weights, + curr_topk_ids) else: - tmp_cache3 = intermediate_cache3.view(*intermediate_cache3.shape) - - ops.moe_sum(tmp_cache3, - out_hidden_states[begin_chunk_idx:end_chunk_idx]) + ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx]) return out_hidden_states From 48b55c4a944587b45f3ff238d59838deea3c58ff Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 22 Mar 2025 03:57:02 +0000 Subject: [PATCH 094/190] refactoring + perf tweaks Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index db16d12476c..a006f2a447c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1319,6 +1319,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, chunked_dg = False if use_dg: + #print("USE_DG!") if M % block_m != 0: CHUNK_SIZE = min((M // block_m) * block_m, CHUNK_SIZE) @@ -1331,7 +1332,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, # We attempt to do this offline in Fp8MoEMethod, in which case these # calls will be nops. Otherwise, they'll be performed every time the # layer is executed. - print(f"SHAPES {w1_scale.shape}, {w2_scale.shape}") + #print(f"SHAPES {w1_scale.shape}, {w2_scale.shape}") w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() @@ -1355,6 +1356,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, intermediate_cache3 = cache13[:(new_M * w2.shape[1])].view( new_M, w2.shape[1]) else: + #print(f"TRITON {allow_deep_gemm}") # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]), From 3c8704c968c717e1a5a749139f5cb7dd541c73f3 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 24 Mar 2025 15:26:41 +0000 Subject: [PATCH 095/190] remove debugging cruft Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a006f2a447c..976885ad2b7 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1319,7 +1319,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, chunked_dg = False if use_dg: - #print("USE_DG!") if M % block_m != 0: CHUNK_SIZE = min((M // block_m) * block_m, CHUNK_SIZE) @@ -1332,7 +1331,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, # We attempt to do this offline in Fp8MoEMethod, in which case these # calls will be nops. Otherwise, they'll be performed every time the # layer is executed. - #print(f"SHAPES {w1_scale.shape}, {w2_scale.shape}") w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() @@ -1356,7 +1354,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, intermediate_cache3 = cache13[:(new_M * w2.shape[1])].view( new_M, w2.shape[1]) else: - #print(f"TRITON {allow_deep_gemm}") # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]), From e885027cb3c84302e872659eb569b00077fe878f Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 24 Mar 2025 22:28:40 +0000 Subject: [PATCH 096/190] cache resize refactoring Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 15 ++--- .../layers/fused_moe/fused_moe.py | 58 +++++++------------ 2 files changed, 26 insertions(+), 47 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index f75f2f2f5f5..188ea972302 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -396,7 +396,6 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): def test_moe_unpermute(out, inv_perm, m_indices, topk, num_groups, M, K, topk_weight, topk_ids): - # TODO use moe_sum? out = out[inv_perm, ...] tmp_out = out.view(-1, topk, K) return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) @@ -414,16 +413,10 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, _, block_k = block_shape[0], block_shape[1] - if False: - # quantize before permute - a_q, a_s = per_token_group_quant_fp8(a, block_m) - a_q, a_s, m_indices, inv_perm = test_moe_permute( - a_q, a_s, topk_ids, num_groups, topk, block_m) - else: - # quantize after permute - a_q, a_s, m_indices, inv_perm = test_moe_permute( - a, None, topk_ids, num_groups, topk, block_m) - a_q, a_s = per_token_group_quant_fp8(a_q, block_m) + a_q, a_s = per_token_group_quant_fp8(a, block_m) + + a_q, a_s, m_indices, inv_perm = test_moe_permute( + a_q, a_s, topk_ids, num_groups, topk, block_m) # Fix this assert #assert a_s.shape[1] == K // 128 and a_q.shape[0] == a_s.shape[0] == M * topk diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 976885ad2b7..997c13c6bbb 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -3,6 +3,7 @@ import functools import json import os +from math import prod from typing import Any, Callable, Dict, List, Optional, Tuple import torch @@ -1317,14 +1318,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, block_m = config['BLOCK_SIZE_M'] assert not use_dg or block_m == dg.get_m_alignment_for_contiguous_layout() - chunked_dg = False if use_dg: if M % block_m != 0: CHUNK_SIZE = min((M // block_m) * block_m, CHUNK_SIZE) - num_chunks = (num_tokens // CHUNK_SIZE) + 1 - chunked_dg = num_chunks > 1 - assert w1_scale is not None assert w2_scale is not None @@ -1334,41 +1331,30 @@ def fused_experts_impl(hidden_states: torch.Tensor, w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() - # TODO: computing new_M could be smarter - sorted_token_ids, _, _ = (moe_align_block_size(topk_ids, block_m, - global_num_experts, - expert_map)) - - new_M = ((sorted_token_ids.numel() + block_m - 1) // block_m) * block_m + M_sum = topk_ids.numel() + global_num_experts * (block_m - 1) + M_sum = ((M_sum + block_m - 1) // block_m) * block_m - # We can reuse the memory between these because by the time we need - # cache3, we're done with cache1 - cache13 = torch.empty(new_M * max(N, w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype) - - intermediate_cache1 = cache13[:(new_M * N)].view(new_M, N) - intermediate_cache2 = torch.empty((new_M, N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache3 = cache13[:(new_M * w2.shape[1])].view( - new_M, w2.shape[1]) + cache1_view = (M_sum, N) + cache3_view = (M_sum, K) else: - # We can reuse the memory between these because by the time we need - # cache3, we're done with cache1 - cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype) + M_sum = M * top_k_num + cache1_view = (M, top_k_num, N) + cache3_view = (M, top_k_num, K) + + num_chunks = (num_tokens // CHUNK_SIZE) + 1 + + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + cache13 = torch.empty(M_sum * max(N, K), + device=hidden_states.device, + dtype=hidden_states.dtype) - intermediate_cache1 = cache13[:M * top_k_num * N].view( - (M, topk_ids.shape[1], N)) - intermediate_cache2 = torch.empty((M * top_k_num, N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache3 = cache13[:M * top_k_num * w2.shape[1]].view( - (M, topk_ids.shape[1], w2.shape[1])) + intermediate_cache1 = cache13[:M_sum * N].view(*cache1_view) + intermediate_cache2 = torch.empty((M_sum, N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache3 = cache13[:M_sum * K].view(*cache3_view) - num_chunks = (num_tokens // CHUNK_SIZE) + 1 for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, @@ -1462,7 +1448,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, expert_ids, top_k_num, global_num_experts, - w2.shape[1], + K, curr_topk_weights, curr_topk_ids) else: From e8e6b6dc8d96271b8be30669970a274d594128f1 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 25 Mar 2025 14:23:17 +0000 Subject: [PATCH 097/190] cleanups Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 42 +++------------ .../layers/fused_moe/fused_moe.py | 51 ++++++++++--------- 2 files changed, 35 insertions(+), 58 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 188ea972302..30ab50ddf79 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -229,7 +229,7 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): @torch.inference_mode() def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): if topk > E: - pytest.skip(f"Skipping test; topk={K} > E={E}") + pytest.skip(f"Skipping test; topk={topk} > E={E}") torch.manual_seed(seed) factor_for_scale = 1e-2 @@ -351,7 +351,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): def fp8_perm(m, idx): - if m.dtype == torch.float8_e4m3fn: + if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8: return m.view(dtype=torch.uint8)[idx, ...].view(dtype=torch.float8_e4m3fn) else: @@ -379,8 +379,6 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): assert sorted_token_ids[sorted_token_ids >= num_tokens].sum() == 0 - #print(f"sti {sorted_token_ids}") - inv_perm = torch.argsort(sorted_token_ids)[:M*topk] a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) @@ -418,9 +416,6 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, a_q, a_s, m_indices, inv_perm = test_moe_permute( a_q, a_s, topk_ids, num_groups, topk, block_m) - # Fix this assert - #assert a_s.shape[1] == K // 128 and a_q.shape[0] == a_s.shape[0] == M * topk - inter_out = torch.zeros((a_q.shape[0], w1[0].shape[0]), dtype=torch.bfloat16, device=a.device) @@ -439,22 +434,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - if True: - final_out = test_moe_unpermute(out, inv_perm, m_indices, topk, - num_groups, M, K, topk_weight, topk_ids) - else: - m_indices = torch.arange(0, - M * (topk + 1), - block_m, - dtype=torch.int, - device=out.device) - - print(f"inv_perm {inv_perm}") - print(f"inv_perm[:m*topk] {inv_perm[:M*topk]}") - - final_out = test_moe_unpermute_op(out, inv_perm, m_indices, topk, - num_groups, M, K, topk_weight, - topk_ids) + final_out = test_moe_unpermute(out, inv_perm, m_indices, topk, + num_groups, M, K, topk_weight, topk_ids) return final_out @@ -502,6 +483,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() + w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() + assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128) assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] @@ -509,17 +493,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - w1_sa = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() - w2_sa = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() - - # TODO: move size alignment further up when setting up all shapes - if w1_sa.shape != w1_s.shape or w2_sa.shape != w2_s.shape: - print("UNALIGNED") - pytest.skip("UNALIGNED") - - w1_s = w1_sa - w2_s = w2_sa - + # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 997c13c6bbb..e1588cbdab4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -523,8 +523,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, if use_dg: assert use_fp8_w8a8 - # Note: we do not apply weights here since it requires - # resizing the output. + # Note: we never apply the topk_weights here since it requires + # unpermuting and resizing the output. This goes against the + # existing interface as the `mul_routed_weight` argument is + # ignored. The weights are applied in _moe_unpermute. dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (A, A_scale), (B, B_scale), C, expert_ids) @@ -856,7 +858,7 @@ def try_get_optimal_moe_config( config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape) - # Try to remove this + # Enforce DeepGemm M blocking no matter what the config says. if use_deep_gemm: config['BLOCK_SIZE_M'] = dg.get_m_alignment_for_contiguous_layout() @@ -905,10 +907,10 @@ def fused_topk( topk, dtype=torch.int32, device=hidden_states.device) - token_expert_indices = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) + token_expert_indicies = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) gating_output_float = gating_output.float() # TODO(woosuk): Optimize this. @@ -1319,15 +1321,18 @@ def fused_experts_impl(hidden_states: torch.Tensor, assert not use_dg or block_m == dg.get_m_alignment_for_contiguous_layout() if use_dg: + # If M is not divisible by the block size we run the largest + # chunk we can using DeepGemm, the remainder is handed off to + # the Triton kernels. if M % block_m != 0: CHUNK_SIZE = min((M // block_m) * block_m, CHUNK_SIZE) assert w1_scale is not None assert w2_scale is not None - # We attempt to do this offline in Fp8MoEMethod, in which case these - # calls will be nops. Otherwise, they'll be performed every time the - # layer is executed. + # We attempt to transpose and align offline in Fp8MoEMethod, in which + # case these calls will be nops. Otherwise, they'll be performed every + # time the layer is executed. w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() @@ -1366,6 +1371,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, if tokens_in_chunk == 0: break + # Even if we are using DeepGemm, we must defer any chunks + # that are not blocked to Triton. skip_dg = use_dg and tokens_in_chunk % block_m != 0 curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] @@ -1440,20 +1447,16 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - # is this correct in the loop? TODO: fold in moe_sum? - if use_dg and not skip_dg: - _moe_unpermute(out_hidden_states[begin_chunk_idx:end_chunk_idx], - intermediate_cache3, - inv_perm, - expert_ids, - top_k_num, - global_num_experts, - K, - curr_topk_weights, - curr_topk_ids) - else: - ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), - out_hidden_states[begin_chunk_idx:end_chunk_idx]) + _moe_unpermute_and_reduce(out_hidden_states[begin_chunk_idx:end_chunk_idx], + intermediate_cache3.view(*intermediate_cache3.shape), + inv_perm, + expert_ids, + top_k_num, + global_num_experts, + K, + curr_topk_weights, + curr_topk_ids, + use_dg and not skip_dg) return out_hidden_states From 4c64246c9a807e5fd2a83d3a896f8bf9f5cadd07 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 25 Mar 2025 14:34:43 +0000 Subject: [PATCH 098/190] format Signed-off-by: Bill Nell --- benchmarks/kernels/benchmark_moe.py | 26 ++++++++-------- tests/kernels/test_block_fp8.py | 15 +++++----- .../layers/fused_moe/fused_moe.py | 30 ++++++++++++------- 3 files changed, 38 insertions(+), 33 deletions(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 6e6b29b4c75..1884a80a407 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -30,20 +30,18 @@ class BenchmarkConfig(TypedDict): num_stages: int -def benchmark_config( - config: BenchmarkConfig, - num_tokens: int, - num_experts: int, - shard_intermediate_size: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - num_iters: int = 100, - block_quant_shape: List[int] = None, - use_deep_gemm: bool = False -) -> float: +def benchmark_config(config: BenchmarkConfig, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + block_quant_shape: List[int] = None, + use_deep_gemm: bool = False) -> float: init_dtype = torch.float16 if use_fp8_w8a8 else dtype x = torch.randn(num_tokens, hidden_size, dtype=dtype) if use_int8_w8a16: diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 30ab50ddf79..88e7e2bcdba 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -379,7 +379,7 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): assert sorted_token_ids[sorted_token_ids >= num_tokens].sum() == 0 - inv_perm = torch.argsort(sorted_token_ids)[:M*topk] + inv_perm = torch.argsort(sorted_token_ids)[:M * topk] a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) a = fp8_perm(a, sorted_token_ids) @@ -413,8 +413,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, a_q, a_s = per_token_group_quant_fp8(a, block_m) - a_q, a_s, m_indices, inv_perm = test_moe_permute( - a_q, a_s, topk_ids, num_groups, topk, block_m) + a_q, a_s, m_indices, inv_perm = test_moe_permute(a_q, a_s, topk_ids, + num_groups, topk, block_m) inter_out = torch.zeros((a_q.shape[0], w1[0].shape[0]), dtype=torch.bfloat16, @@ -434,8 +434,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - final_out = test_moe_unpermute(out, inv_perm, m_indices, topk, - num_groups, M, K, topk_weight, topk_ids) + final_out = test_moe_unpermute(out, inv_perm, m_indices, topk, num_groups, + M, K, topk_weight, topk_ids) return final_out @@ -511,9 +511,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, allow_deep_gemm=True) if M % 128 == 0: - out2 = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, - w2_s, score, topk, - block_size) + out2 = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, + score, topk, block_size) else: out2 = None diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index e1588cbdab4..37cce8de28b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -510,6 +510,20 @@ def invoke_fused_moe_kernel(A: torch.Tensor, M = A.shape[0] num_tokens = M * top_k + if use_fp8_w8a8: + assert B_scale is not None + assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0]) + == B_scale.shape[-2]) + assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1]) + == B_scale.shape[-1]) + + elif use_int8_w8a16 or use_int4_w4a16: + assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 + else: + assert A_scale is None + assert B_scale is None + EM = sorted_token_ids.shape[0] if A.shape[0] < config["BLOCK_SIZE_M"]: # optimize for small batch_size. @@ -1360,7 +1374,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, dtype=hidden_states.dtype) intermediate_cache3 = cache13[:M_sum * K].view(*cache3_view) - for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, @@ -1447,16 +1460,11 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - _moe_unpermute_and_reduce(out_hidden_states[begin_chunk_idx:end_chunk_idx], - intermediate_cache3.view(*intermediate_cache3.shape), - inv_perm, - expert_ids, - top_k_num, - global_num_experts, - K, - curr_topk_weights, - curr_topk_ids, - use_dg and not skip_dg) + _moe_unpermute_and_reduce( + out_hidden_states[begin_chunk_idx:end_chunk_idx], + intermediate_cache3.view(*intermediate_cache3.shape), inv_perm, + expert_ids, top_k_num, global_num_experts, K, curr_topk_weights, + curr_topk_ids, use_dg and not skip_dg) return out_hidden_states From f3ff692545ceb9d0477dbf0c47e6df4ed3a426f1 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 25 Mar 2025 16:36:14 +0000 Subject: [PATCH 099/190] revert test.txt, fix mypy errors Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 37cce8de28b..b47cd8e70c8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1334,6 +1334,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, block_m = config['BLOCK_SIZE_M'] assert not use_dg or block_m == dg.get_m_alignment_for_contiguous_layout() + cache1_view: Tuple[int, ...] = () + cache2_view: Tuple[int, ...] = () + cache3_view: Tuple[int, ...] = () + if use_dg: # If M is not divisible by the block size we run the largest # chunk we can using DeepGemm, the remainder is handed off to From 9d048ec6ebd3ab32bfa52b49baf57384e6ae1b6a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 26 Mar 2025 22:15:33 +0000 Subject: [PATCH 100/190] review comments Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index b47cd8e70c8..7c3ddf83f74 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -22,7 +22,7 @@ _resize_cache, moe_kernel_quantize_input) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op +from vllm.utils import direct_register_custom_op, round_up from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled @@ -1355,7 +1355,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() M_sum = topk_ids.numel() + global_num_experts * (block_m - 1) - M_sum = ((M_sum + block_m - 1) // block_m) * block_m + M_sum = round_up(M_sum, block_m) cache1_view = (M_sum, N) cache3_view = (M_sum, K) From 9504dead15abb2dd1b9fb0c30d17aa4f309666cd Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 27 Mar 2025 02:21:17 +0000 Subject: [PATCH 101/190] review comments Signed-off-by: Bill Nell --- .../quantization/compressed_tensors/compressed_tensors_moe.py | 2 -- vllm/model_executor/layers/quantization/fp8.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index ec7e93b754c..ae16a20cfaa 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -275,8 +275,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: from vllm.model_executor.layers.fused_moe import fused_experts self.fused_experts_func = fused_experts - # TODO: do we need to do deep gemm alignment here? - def apply( self, layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 4549bee01ef..f4eef830457 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -428,7 +428,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None - self.allow_deep_gemm = use_deep_gemm + self.allow_deep_gemm = allow_deep_gemm # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization From ef0eee999fcd29745af21abef016a2ca3fc672cd Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 27 Mar 2025 03:14:23 +0000 Subject: [PATCH 102/190] clean up use_dg flags Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 20 ++++++------------- .../layers/fused_moe/fused_moe.py | 14 +++++++------ 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 88e7e2bcdba..fdb5f4c3a5b 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -495,8 +495,12 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) + if M % 128 == 0: + ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, + score, topk, block_size) + else: + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, + block_size) out = fused_moe(a, w1, @@ -510,12 +514,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, block_shape=block_size, allow_deep_gemm=True) - if M % 128 == 0: - out2 = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, - score, topk, block_size) - else: - out2 = None - #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") @@ -523,9 +521,3 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.03 - - if out2 is not None: - rel_diff = (torch.mean( - torch.abs(ref_out.to(torch.float32) - out2.to(torch.float32))) / - torch.mean(torch.abs(out2.to(torch.float32)))) - assert rel_diff < 0.03 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 7c3ddf83f74..a9e9939028c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1366,8 +1366,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, num_chunks = (num_tokens // CHUNK_SIZE) + 1 - # We can reuse the memory between these because by the time we need - # cache3, we're done with cache1 + # We can reuse the memory between cache1 and cache3 because by the time + # we need cache3, we're done with cache1 cache13 = torch.empty(M_sum * max(N, K), device=hidden_states.device, dtype=hidden_states.dtype) @@ -1378,6 +1378,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, dtype=hidden_states.dtype) intermediate_cache3 = cache13[:M_sum * K].view(*cache3_view) + needs_fp8_quantization = use_fp8_w8a8 or use_dg + for chunk in range(num_chunks): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, @@ -1388,9 +1390,9 @@ def fused_experts_impl(hidden_states: torch.Tensor, if tokens_in_chunk == 0: break - # Even if we are using DeepGemm, we must defer any chunks - # that are not blocked to Triton. - skip_dg = use_dg and tokens_in_chunk % block_m != 0 + # If we are using DeepGemm, only operate on chunks that are + # blocked, otherwise defer to Triton. + use_dg_for_chunk = use_dg and tokens_in_chunk % block_m == 0 curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] @@ -1468,7 +1470,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, out_hidden_states[begin_chunk_idx:end_chunk_idx], intermediate_cache3.view(*intermediate_cache3.shape), inv_perm, expert_ids, top_k_num, global_num_experts, K, curr_topk_weights, - curr_topk_ids, use_dg and not skip_dg) + curr_topk_ids, use_dg_for_chunk) return out_hidden_states From e9c5c27bd8e793f61b52f65c0068a86eb5a710fe Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 27 Mar 2025 15:24:26 +0000 Subject: [PATCH 103/190] remove check for aligned M Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a9e9939028c..71da3953f3c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1339,12 +1339,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, cache3_view: Tuple[int, ...] = () if use_dg: - # If M is not divisible by the block size we run the largest - # chunk we can using DeepGemm, the remainder is handed off to - # the Triton kernels. - if M % block_m != 0: - CHUNK_SIZE = min((M // block_m) * block_m, CHUNK_SIZE) - assert w1_scale is not None assert w2_scale is not None @@ -1390,10 +1384,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, if tokens_in_chunk == 0: break - # If we are using DeepGemm, only operate on chunks that are - # blocked, otherwise defer to Triton. - use_dg_for_chunk = use_dg and tokens_in_chunk % block_m == 0 - curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] @@ -1470,7 +1460,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, out_hidden_states[begin_chunk_idx:end_chunk_idx], intermediate_cache3.view(*intermediate_cache3.shape), inv_perm, expert_ids, top_k_num, global_num_experts, K, curr_topk_weights, - curr_topk_ids, use_dg_for_chunk) + curr_topk_ids, use_dg) return out_hidden_states From b3287ac06b724188c08a2d7b251cd1aa92d490f0 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 28 Mar 2025 18:31:23 +0000 Subject: [PATCH 104/190] rebase + clean up test Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 35 ++++++++++++--------------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index fdb5f4c3a5b..ce80abb4499 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform +from vllm.utils import round_up dg_available = False try: @@ -352,8 +353,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): def fp8_perm(m, idx): if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8: - return m.view(dtype=torch.uint8)[idx, - ...].view(dtype=torch.float8_e4m3fn) + return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) else: return m[idx, ...] @@ -366,34 +366,26 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): num_tokens = topk * M - pad_size = (((sorted_token_ids.numel() + block_m - 1) // block_m) * - block_m) - sorted_token_ids.numel() + pad_size = (round_up(sorted_token_ids.numel(), block_m) - + sorted_token_ids.numel()) if pad_size > 0: sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, (0, pad_size), "constant", num_tokens) sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) - m_indices = torch.repeat_interleave(m_indices, block_m, dim=0) - - assert sorted_token_ids[sorted_token_ids >= num_tokens].sum() == 0 - inv_perm = torch.argsort(sorted_token_ids)[:M * topk] - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - a = fp8_perm(a, sorted_token_ids) - + a = fp8_perm(a, sorted_token_ids // topk) if a_s is not None: - a_s = a_s.view(M, -1, K // 128).repeat(1, topk, - 1).reshape(-1, K // 128) - a_s = a_s[sorted_token_ids] + a_s = a_s[sorted_token_ids // topk] return a, a_s, m_indices, inv_perm -def test_moe_unpermute(out, inv_perm, m_indices, topk, num_groups, M, K, - topk_weight, topk_ids): +def test_moe_unpermute(out, inv_perm, topk, K, topk_weight): + M = topk_weight.shape[0] out = out[inv_perm, ...] tmp_out = out.view(-1, topk, K) return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) @@ -404,6 +396,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, """Fused moe with block-wise quantization using DeepGemm grouped gemm.""" num_groups = w1.shape[0] M, K = a.shape + N = w2.shape[-1] topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) @@ -416,7 +409,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, a_q, a_s, m_indices, inv_perm = test_moe_permute(a_q, a_s, topk_ids, num_groups, topk, block_m) - inter_out = torch.zeros((a_q.shape[0], w1[0].shape[0]), + inter_out = torch.zeros((a_q.shape[0], N * 2), dtype=torch.bfloat16, device=a.device) @@ -426,16 +419,14 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - out = torch.zeros(act_out.shape[0], - w2.shape[1], + out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device) deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - final_out = test_moe_unpermute(out, inv_perm, m_indices, topk, num_groups, - M, K, topk_weight, topk_ids) + final_out = test_moe_unpermute(out, inv_perm, topk, K, topk_weight) return final_out @@ -495,7 +486,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): - if M % 128 == 0: + if M >= 128: ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, block_size) else: From c820cfe142cc33a6700c8f5b53e01e69a10493b0 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 28 Mar 2025 20:32:18 +0000 Subject: [PATCH 105/190] fix format Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index ce80abb4499..89e9a073acf 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -14,7 +14,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform -from vllm.utils import round_up dg_available = False try: @@ -362,17 +361,10 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): M, K = a.shape sorted_token_ids, m_indices, num_pad = moe_align_block_size( - topk_ids, block_m, num_groups, None) + topk_ids, block_m, num_groups, None, pad_sorted_ids=True) num_tokens = topk * M - pad_size = (round_up(sorted_token_ids.numel(), block_m) - - sorted_token_ids.numel()) - if pad_size > 0: - sorted_token_ids = torch.nn.functional.pad(sorted_token_ids, - (0, pad_size), "constant", - num_tokens) - sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) m_indices = torch.repeat_interleave(m_indices, block_m, dim=0) inv_perm = torch.argsort(sorted_token_ids)[:M * topk] @@ -419,9 +411,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - out = torch.zeros(a_q.shape[0], K, - dtype=torch.bfloat16, - device=a.device) + out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device) deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) @@ -490,8 +480,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, block_size) else: - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) + ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, + topk, block_size) out = fused_moe(a, w1, From 30cfab4731b59b170f00dce84ac7c221f665d84c Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 31 Mar 2025 18:34:07 +0000 Subject: [PATCH 106/190] Clean up diff Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/cuda_graph_utils.py | 0 .../layers/fused_moe/fused_moe.py | 27 ++-------- vllm/model_executor/layers/fused_moe/layer.py | 50 ++++++++++++------- 3 files changed, 36 insertions(+), 41 deletions(-) delete mode 100644 vllm/cuda_graph_utils.py diff --git a/vllm/cuda_graph_utils.py b/vllm/cuda_graph_utils.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 71da3953f3c..ded05a0baa1 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -535,16 +535,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( B.shape[1], META['BLOCK_SIZE_N']), ) - if use_dg: - assert use_fp8_w8a8 - # Note: we never apply the topk_weights here since it requires - # unpermuting and resizing the output. This goes against the - # existing interface as the `mul_routed_weight` argument is - # ignored. The weights are applied in _moe_unpermute. - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (A, A_scale), (B, B_scale), C, expert_ids) - - elif (use_int8_w8a16 or use_int4_w4a16) and \ + if (use_int8_w8a16 or use_int4_w4a16) and \ block_shape is not None and block_shape[1] > 0: assert B_scale is not None and B_scale.ndim == 3 assert B_zp is None or B_zp.ndim == 3 @@ -848,7 +839,6 @@ def try_get_optimal_moe_config( M: int, is_marlin: bool = False, block_shape: Optional[List[int]] = None, - use_deep_gemm: bool = False, ): from vllm.model_executor.layers.fused_moe import get_config override_config = get_config() @@ -871,11 +861,6 @@ def try_get_optimal_moe_config( # Else use the default config config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape) - - # Enforce DeepGemm M blocking no matter what the config says. - if use_deep_gemm: - config['BLOCK_SIZE_M'] = dg.get_m_alignment_for_contiguous_layout() - return config @@ -1048,14 +1033,13 @@ def inplace_fused_experts(hidden_states: torch.Tensor, w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, - allow_deep_gemm: bool = False) -> None: + block_shape: Optional[List[int]] = None) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, - block_shape, allow_deep_gemm) + block_shape) def inplace_fused_experts_fake( @@ -1492,7 +1476,6 @@ def fused_moe( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, - allow_deep_gemm: bool = True, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -1526,8 +1509,8 @@ def fused_moe( Defaults to False. - global_num_experts (int): The total number of experts in the global expert space. - - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert parallel shard. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c36e7d0ad3d..ba0dd0f98d8 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,16 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 from abc import abstractmethod +from dataclasses import dataclass from enum import Enum from typing import Callable, List, Optional, Tuple -from dataclasses import dataclass +import pplx_kernels as pplx import torch import torch.nn.functional as F from torch.nn.parameter import UninitializedParameter -import pplx_kernels as pplx - import vllm.envs as envs from vllm.config import get_current_vllm_config from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank, @@ -39,6 +38,7 @@ MOE_DP_CHUNK_SIZE = 256 + # Adapted from pplx-kernels tests/all_to_all_utils.py @dataclass class MoEConfig: @@ -56,6 +56,7 @@ class MoEConfig: out_dtype: torch.dtype = torch.bfloat16 block_size: int = 128 + class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" CHANNEL = "channel" @@ -92,9 +93,12 @@ def apply( ) -> torch.Tensor: raise NotImplementedError + +#TODO: Every change in this class is a broken hack!! @CustomOp.register("unquantized_fused_moe") class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" + def __init__(self, moe: MoEConfig): self.all_to_all = pplx.AllToAll( max_num_tokens=MOE_DP_CHUNK_SIZE // moe.dp_size, @@ -108,7 +112,6 @@ def __init__(self, moe: MoEConfig): hidden_dim_scale_bytes=0, ) - def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -875,7 +878,7 @@ def forward(self, hidden_states: torch.Tensor, self.layer_name) def forward_impl_chunked(self, full_hidden_states: torch.Tensor, - full_router_logits: torch.Tensor): + full_router_logits: torch.Tensor): max_tokens_across_dp = get_forward_context( ).dp_metadata.max_tokens_across_dp cu_tokens_across_dp_cpu = get_forward_context( @@ -891,21 +894,23 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, num_tokens_remaining_across_dp = num_tokens_across_dp chunk_start = 0 - chunk_end = min(moe_dp_chunk_size_per_rank, full_hidden_states.shape[0]) + chunk_end = min(moe_dp_chunk_size_per_rank, + full_hidden_states.shape[0]) full_final_hidden_states = torch.empty_like(full_hidden_states) for _ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): - hidden_states = full_hidden_states[chunk_start:chunk_end,:] - router_logits = full_router_logits[chunk_start:chunk_end,:] + hidden_states = full_hidden_states[chunk_start:chunk_end, :] + router_logits = full_router_logits[chunk_start:chunk_end, :] cu_tokens_across_dp_this_iter = torch.cumsum( - num_tokens_remaining_across_dp.clamp(max=moe_dp_chunk_size_per_rank), + num_tokens_remaining_across_dp.clamp( + max=moe_dp_chunk_size_per_rank), dim=0) - hidden_states = self.naive_multicast(hidden_states, - cu_tokens_across_dp_this_iter) - router_logits = self.naive_multicast(router_logits, - cu_tokens_across_dp_this_iter) + hidden_states = self.naive_multicast( + hidden_states, cu_tokens_across_dp_this_iter) + router_logits = self.naive_multicast( + router_logits, cu_tokens_across_dp_this_iter) # Matrix multiply. final_hidden_states = self.quant_method.apply( @@ -926,7 +931,8 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, ) if self.dp_size > 1: - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[self.dp_rank-1] + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[ + self.dp_rank - 1] end = cu_tokens_across_dp_this_iter[self.dp_rank] all_hidden_states = get_dp_group().all_reduce( @@ -935,20 +941,26 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) - final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) - full_final_hidden_states[chunk_start:chunk_end, :].copy_(final_hidden_states) + full_final_hidden_states[chunk_start:chunk_end, :].copy_( + final_hidden_states) # Update bounds - num_tokens_remaining_across_dp = torch.clamp(num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank, min=0) + num_tokens_remaining_across_dp = torch.clamp( + num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank, + min=0) + def update_chunk_bound(x: int): - return min(x + moe_dp_chunk_size_per_rank, full_hidden_states.shape[0]) + return min(x + moe_dp_chunk_size_per_rank, + full_hidden_states.shape[0]) + chunk_start = update_chunk_bound(chunk_start) chunk_end = update_chunk_bound(chunk_end) return full_final_hidden_states - def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None From 24399c3c21ef322cf60a4a42ee731e2155dcc596 Mon Sep 17 00:00:00 2001 From: Ilya Markov Date: Tue, 1 Apr 2025 07:49:12 +0200 Subject: [PATCH 107/190] [Distributed] Add custom allreduce support for ROCM (#14125) Signed-off-by: ilmarkov Co-authored-by: ilmarkov Signed-off-by: Bill Nell --- csrc/custom_all_reduce.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 44709b45977..186abf4712f 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -602,4 +602,4 @@ class CustomAllreduce { * template void vllm::CustomAllreduce::allreduce(cudaStream_t, half *, half *, int, int, int); */ -} // namespace vllm \ No newline at end of file +} // namespace vllm From e58944ec1d152b37ade46be3bac925742dbc3b3a Mon Sep 17 00:00:00 2001 From: Yan Ma Date: Tue, 1 Apr 2025 13:53:37 +0800 Subject: [PATCH 108/190] [Bugfix][Model] fix mllama multi-image (#14883) Signed-off-by: yan ma Signed-off-by: Bill Nell --- vllm/model_executor/models/mllama.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 0c1d61c01f9..971a4e695da 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -1245,6 +1245,31 @@ def unpack_data(self, output_tensor[i, :t.size(0)] = t return output_tensor + def unpack_data(self, + image_data: Union[List[torch.Tensor], torch.Tensor], + padding_value=0) -> torch.Tensor: + if isinstance(image_data, torch.Tensor): + # torch.Tensor + return image_data + else: + assert isinstance( + image_data[0], + torch.Tensor), "Image data is not properly batched." + # List[torch.Tensor] + bsz = len(image_data) + max_length = max(t.size(0) for t in image_data) + trailing_dims = image_data[0].shape[1:] + for data in image_data: + cur_trailing_dims = data.shape[1:] + assert cur_trailing_dims == trailing_dims + output_tensor = torch.full((bsz, max_length, *trailing_dims), + padding_value, + dtype=image_data[0].dtype, + device=image_data[0].device) + for i, t in enumerate(image_data): + output_tensor[i, :t.size(0)] = t + return output_tensor + def _parse_and_validate_image_input(self, **kwargs: object): # tensor with the same shape will be batched together by # MultiModalKwargs.batch, so pixel_values here can be: From 9fbd1a919e6d38e11f0b30dd7e2b83cdb0c4d5b6 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 1 Apr 2025 22:06:39 +0000 Subject: [PATCH 109/190] module deepgemm moe working Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 39 +++++++++++-------- .../layers/fused_moe/modular_kernel.py | 1 + 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 89e9a073acf..e747a96abf1 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -9,8 +9,12 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, moe_align_block_size) +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + deep_gemm_moe_fp8, + modular_deep_gemm_fused_moe_fp8) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul) from vllm.platforms import current_platform @@ -430,11 +434,13 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): - # only aligned sizes - if (N % 128 != 0 or K % 128 != 0 or topk > E or block_size != [128, 128]): + # only aligned sizes TODO: use _valid_deep_gemm here instead? + if (N % block_m != 0 or K % block_m != 0 or topk > E): pytest.skip( - f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}, " - f"block_size={block_size}") + f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") + + if False and N <= 512: + pytest.skip("Skipping N <= 512 until performance issues solved.") vllm_config = VllmConfig() @@ -474,6 +480,13 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) + if True: + dgm = modular_deep_gemm_fused_moe_fp8() + def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids): + return dgm(a, w1, w2, topk_weights, topk_ids, w1_scale=w1_s, w2_scale=w2_s) + else: + deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 + # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): if M >= 128: @@ -483,17 +496,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) - out = fused_moe(a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - allow_deep_gemm=True) + topk_weights, topk_ids = fused_topk(a, score.float(), topk, False) + + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index aab7658ae64..c386d5ec1dc 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -126,6 +126,7 @@ def combine( experts, it will have (M, topk, K) shape. - topk_weights: The weights to be applied to the fused_experts_output. - topk_ids: The topk_ids. + - apply_router_weight_on_input: When False, apply the weights to fused_expert_output. """ raise NotImplementedError From a1eecd507af18e33c6da677b85c83223c130b32c Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 01:13:38 +0000 Subject: [PATCH 110/190] working deep gemm, wip cutlass Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index e747a96abf1..2511d817fe7 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -10,6 +10,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + _valid_deep_gemm, deep_gemm_moe_fp8, modular_deep_gemm_fused_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk @@ -456,6 +457,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) +# if not _valid_deep_gemm(a, w1_bf16, w2_bf16, None): +# pytest.skip( +# f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") + score = torch.randn((M, E), dtype=dtype) block_n, block_k = block_size[0], block_size[1] From 6efacf189de04ce8655d243eb743de19f17b1d3b Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 13:49:41 +0000 Subject: [PATCH 111/190] working cutlass Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 266ba3bfa07..38f8072ac40 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import torch From e1dd8184c895ce04d7c670ee0dcc4eaadf6de612 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 20:33:59 +0000 Subject: [PATCH 112/190] deepgemm working again Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 38f8072ac40..4fa7139afa4 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util -from typing import Any, Optional, Tuple +from typing import Any, List, Optional, Tuple import torch From bf02e1cc83255772e232f4c0b9c81bdda4356f77 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 2 Apr 2025 22:14:54 +0000 Subject: [PATCH 113/190] fix inplace, format and name cleanups Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 21 +++++++++++++------ .../layers/fused_moe/deep_gemm_moe.py | 2 +- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 2511d817fe7..fafc7c18254 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -457,9 +457,9 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) -# if not _valid_deep_gemm(a, w1_bf16, w2_bf16, None): -# pytest.skip( -# f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") + # if not _valid_deep_gemm(a, w1_bf16, w2_bf16, None): + # pytest.skip( + # f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") score = torch.randn((M, E), dtype=dtype) @@ -487,8 +487,16 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, if True: dgm = modular_deep_gemm_fused_moe_fp8() - def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids): - return dgm(a, w1, w2, topk_weights, topk_ids, w1_scale=w1_s, w2_scale=w2_s) + + def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids): + return dgm(a, + w1, + w2, + topk_weights, + topk_ids, + w1_scale=w1_s, + w2_scale=w2_s) else: deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 @@ -503,7 +511,8 @@ def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids): topk_weights, topk_ids = fused_topk(a, score.float(), topk, False) - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 4fa7139afa4..28050a5dd9e 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util -from typing import Any, List, Optional, Tuple +from typing import List, Optional, Tuple import torch From 0abc8e5dc092ac75cb265015bf2276b1a3f6a11f Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 01:18:53 +0000 Subject: [PATCH 114/190] test improvements Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 18 ++++++------------ .../layers/fused_moe/deep_gemm_moe.py | 4 ++-- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index fafc7c18254..ed861054b4b 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -10,9 +10,9 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm, deep_gemm_moe_fp8, - modular_deep_gemm_fused_moe_fp8) + modular_deep_gemm_fused_moe_fp8, + _valid_deep_gemm_shape) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) @@ -435,13 +435,11 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, dtype, seed): - # only aligned sizes TODO: use _valid_deep_gemm here instead? - if (N % block_m != 0 or K % block_m != 0 or topk > E): - pytest.skip( - f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") + if topk > E: + pytest.skip(f"Skipping test: topk={topk} > E={E}") - if False and N <= 512: - pytest.skip("Skipping N <= 512 until performance issues solved.") + if not _valid_deep_gemm_shape(M, N, K): + pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") vllm_config = VllmConfig() @@ -457,10 +455,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max).clamp(min=fp8_min, max=fp8_max) - # if not _valid_deep_gemm(a, w1_bf16, w2_bf16, None): - # pytest.skip( - # f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") - score = torch.randn((M, E), dtype=dtype) block_n, block_k = block_size[0], block_size[1] diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 28050a5dd9e..facbba40c3e 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util -from typing import List, Optional, Tuple +from typing import Optional, Tuple import torch @@ -28,7 +28,7 @@ def deep_gemm_block_shape() -> list[int]: def _valid_deep_gemm_shape(M: int, N: int, K: int): align = deep_gemm_block_shape()[0] - return align <= M and N % align == 0 and K % align == 0 + return M >= align and N % align == 0 and K % align == 0 def _valid_deep_gemm(hidden_states: torch.Tensor, From 78967958ed6704956dee9768b88aa87d03a3a949 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 04:41:01 +0000 Subject: [PATCH 115/190] make modular triton classes, fix edge cases Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_moe.py | 32 ++++++++++++++++--- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index ded05a0baa1..f858bf45065 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1207,6 +1207,30 @@ def fused_experts(hidden_states: torch.Tensor, a2_scale=a2_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) + elif hidden_states.shape[0] <= envs.VLLM_FUSED_MOE_CHUNK_SIZE: + fe = modular_triton_fused_moe( + use_fp8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + block_shape, + ) + return fe( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace, + activation, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + ) else: return dispatch_fused_experts_func(inplace)( hidden_states=hidden_states, @@ -1214,6 +1238,7 @@ def fused_experts(hidden_states: torch.Tensor, w2=w2, topk_weights=topk_weights, topk_ids=topk_ids, + inplace=inplace, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, use_fp8_w8a8=use_fp8_w8a8, @@ -1440,11 +1465,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, per_channel_quant=per_channel_quant, block_shape=block_shape) - _moe_unpermute_and_reduce( - out_hidden_states[begin_chunk_idx:end_chunk_idx], - intermediate_cache3.view(*intermediate_cache3.shape), inv_perm, - expert_ids, top_k_num, global_num_experts, K, curr_topk_weights, - curr_topk_ids, use_dg) + ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx]) return out_hidden_states From e42df0f1a38a6bf4edc41f3fd4c88cfe3703a4cf Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 17:17:16 +0000 Subject: [PATCH 116/190] refactor dispatch/combine stuff Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 3 +++ vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index e52751eddf2..19ca505a256 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -9,6 +9,9 @@ from vllm.model_executor.layers.fused_moe.dispatch_combine import ( StandardDispatchCombine) from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache +from vllm.model_executor.layers.fused_moe.dispatch_combine import ( + StandardDispatchCombine +) class CutlassExperts(mk.FusedMoEPermuteExpertsUnpermute): diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index facbba40c3e..ab355c7d53e 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -12,6 +12,13 @@ _moe_permute) from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, _resize_cache) +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + _moe_permute, + _moe_unpermute_and_reduce +) +from vllm.model_executor.layers.fused_moe.dispatch_combine import ( + StandardDispatchCombine +) from vllm.utils import round_up logger = init_logger(__name__) From f1cb9204185fa312b4dd78563493382d10ce4f63 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 19:39:35 +0000 Subject: [PATCH 117/190] initial pplx dispatch/combine class Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index ed861054b4b..2f9315f1952 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -362,7 +362,7 @@ def fp8_perm(m, idx): return m[idx, ...] -def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): +def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): M, K = a.shape sorted_token_ids, m_indices, num_pad = moe_align_block_size( @@ -381,7 +381,7 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): return a, a_s, m_indices, inv_perm -def test_moe_unpermute(out, inv_perm, topk, K, topk_weight): +def _moe_unpermute(out, inv_perm, topk, K, topk_weight): M = topk_weight.shape[0] out = out[inv_perm, ...] tmp_out = out.view(-1, topk, K) @@ -403,8 +403,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, a_q, a_s = per_token_group_quant_fp8(a, block_m) - a_q, a_s, m_indices, inv_perm = test_moe_permute(a_q, a_s, topk_ids, - num_groups, topk, block_m) + a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids, + num_groups, topk, block_m) inter_out = torch.zeros((a_q.shape[0], N * 2), dtype=torch.bfloat16, @@ -421,7 +421,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - final_out = test_moe_unpermute(out, inv_perm, topk, K, topk_weight) + final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight) return final_out From a023439b643cbbb911f14ad1737ea2bd2ba85a58 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 20:41:40 +0000 Subject: [PATCH 118/190] merge triton dispatch into standard, add some comments Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/modular_kernel.py | 1 + .../model_executor/layers/fused_moe/pplx_dispatch_combine.py | 5 +---- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index c386d5ec1dc..9cc8131a5d8 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -102,6 +102,7 @@ def dispatch( - num_experts: The total number of experts in the global expert space. - expert_map: A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard. + - apply_router_weight_on_input: When True, apply the weights to the activations, before quantization + dispatching. Returns a tuple of: - quantized + dispatched a. diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 658705515b4..8eac4fd3f5e 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -47,8 +47,6 @@ def dispatch( assert expert_map is None, "NYI" - # TBD - assert not apply_router_weight_on_input if apply_router_weight_on_input: topk = rank_topk_ids.shape[1] # TODO: this only works for topK=1, will need to update for topK>1 @@ -131,8 +129,7 @@ def combine( assert output.shape[0] <= self.max_num_tokens assert output.shape[1] == fused_expert_output.shape[-1] - # Set weights to 1? - assert not apply_router_weight_on_input + # Set weights to 1 if we did them in dispatch. This is hacky. if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) From 913c0178adfb4b26256ba1f547069ae3907356db Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 20:47:41 +0000 Subject: [PATCH 119/190] format Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 4 +--- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 3 --- vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 9 +-------- vllm/model_executor/layers/fused_moe/modular_kernel.py | 1 + 4 files changed, 3 insertions(+), 14 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 2f9315f1952..3939f4b7bab 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -10,9 +10,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - deep_gemm_moe_fp8, - modular_deep_gemm_fused_moe_fp8, - _valid_deep_gemm_shape) + _valid_deep_gemm_shape, deep_gemm_moe_fp8, modular_deep_gemm_fused_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 19ca505a256..e52751eddf2 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -9,9 +9,6 @@ from vllm.model_executor.layers.fused_moe.dispatch_combine import ( StandardDispatchCombine) from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache -from vllm.model_executor.layers.fused_moe.dispatch_combine import ( - StandardDispatchCombine -) class CutlassExperts(mk.FusedMoEPermuteExpertsUnpermute): diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index ab355c7d53e..266ba3bfa07 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -12,13 +12,6 @@ _moe_permute) from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, _resize_cache) -from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - _moe_permute, - _moe_unpermute_and_reduce -) -from vllm.model_executor.layers.fused_moe.dispatch_combine import ( - StandardDispatchCombine -) from vllm.utils import round_up logger = init_logger(__name__) @@ -35,7 +28,7 @@ def deep_gemm_block_shape() -> list[int]: def _valid_deep_gemm_shape(M: int, N: int, K: int): align = deep_gemm_block_shape()[0] - return M >= align and N % align == 0 and K % align == 0 + return align <= M and N % align == 0 and K % align == 0 def _valid_deep_gemm(hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 9cc8131a5d8..a3086dee4b3 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -72,6 +72,7 @@ def _moe_problem_size( return E, M, N, K, topk + class FusedMoEQuantizeDispatchCombine(ABC): """ An abstract base class for the [Quantize-Dispatch] and [Combine] steps From 8e3284a9cc47ade6bfbb2759af95d9e20b864884 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 3 Apr 2025 23:20:14 +0000 Subject: [PATCH 120/190] cleanup for review Signed-off-by: Bill Nell --- tests/kernels/test_block_fp8.py | 20 ++-------------- .../layers/fused_moe/fused_moe.py | 24 ------------------- 2 files changed, 2 insertions(+), 42 deletions(-) diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index 3939f4b7bab..762d0239408 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -10,7 +10,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm_shape, deep_gemm_moe_fp8, modular_deep_gemm_fused_moe_fp8) + _valid_deep_gemm_shape, deep_gemm_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) @@ -477,21 +477,6 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - if True: - dgm = modular_deep_gemm_fused_moe_fp8() - - def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids): - return dgm(a, - w1, - w2, - topk_weights, - topk_ids, - w1_scale=w1_s, - w2_scale=w2_s) - else: - deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 - # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): if M >= 128: @@ -503,8 +488,7 @@ def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_weights, topk_ids = fused_topk(a, score.float(), topk, False) - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids) + out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f858bf45065..1197de1a5f6 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1207,30 +1207,6 @@ def fused_experts(hidden_states: torch.Tensor, a2_scale=a2_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) - elif hidden_states.shape[0] <= envs.VLLM_FUSED_MOE_CHUNK_SIZE: - fe = modular_triton_fused_moe( - use_fp8_w8a8, - use_int8_w8a16, - use_int4_w4a16, - block_shape, - ) - return fe( - hidden_states, - w1, - w2, - topk_weights, - topk_ids, - inplace, - activation, - global_num_experts, - expert_map, - w1_scale, - w2_scale, - w1_zp, - w2_zp, - a1_scale, - a2_scale, - ) else: return dispatch_fused_experts_func(inplace)( hidden_states=hidden_states, From cc1a878d9b11feb955fa817de7762dd7c2f011a3 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 4 Apr 2025 23:17:43 +0000 Subject: [PATCH 121/190] hacking Signed-off-by: Bill Nell --- .../layers/fused_moe/__init__.py | 5 +-- vllm/model_executor/layers/fused_moe/layer.py | 34 +++++++++++++++++-- .../layers/fused_moe/pplx_dispatch_combine.py | 6 ++-- 3 files changed, 37 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 9829ccdb384..b55a3ede12d 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -38,8 +38,8 @@ def get_config() -> Optional[Dict[str, Any]]: from vllm.model_executor.layers.fused_moe.cutlass_moe import ( cutlass_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts, fused_moe, fused_topk, get_config_file_name, - grouped_topk) + TritonExperts, fused_experts, fused_moe, fused_topk, + get_config_file_name, grouped_topk) __all__ += [ "fused_moe", @@ -48,4 +48,5 @@ def get_config() -> Optional[Dict[str, Any]]: "get_config_file_name", "grouped_topk", "cutlass_moe_fp8", + "TritonExperts", ] diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index ba0dd0f98d8..ed156bcc53b 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -8,6 +8,8 @@ import pplx_kernels as pplx import torch import torch.nn.functional as F +from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_get_unique_id, nvshmem_init) from torch.nn.parameter import UninitializedParameter import vllm.envs as envs @@ -23,10 +25,13 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from vllm.utils import direct_register_custom_op +from vllm.utils import direct_register_custom_op, run_once if current_platform.is_cuda_alike(): - from .fused_moe import fused_experts + #from .pplx_dispatch_combine import PplxDispatchCombine + from .dispatch_combine import StandardDispatchCombine + from .fused_moe import TritonExperts, fused_experts + from .modular_kernel import FusedMoEModularKernel else: fused_experts = None # type: ignore if current_platform.is_tpu(): @@ -405,6 +410,14 @@ def determine_expert_map( return (local_num_experts, expert_map) +@run_once +def pplx_init(rank, world_size): + uid = nvshmem_get_unique_id( + ) if rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, rank, world_size) + + class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. @@ -528,8 +541,23 @@ def __init__( # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. if quant_config is None: + pplx_init(self.dp_rank, self.dp_size) + + moe = MoEConfig( + num_experts=self.global_num_experts, + experts_per_token=0, + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + dp_size=self.dp_size, + dp_rank=self.dp_rank, + ep_size=self.ep_size, + ep_rank=self.ep_rank, + #in_dtype = 0, + #out_dtype = 0, + ) + self.quant_method: Optional[QuantizeMethodBase] = ( - UnquantizedFusedMoEMethod()) + UnquantizedFusedMoEMethod(moe)) else: self.quant_method = quant_config.get_quant_method(self, prefix) assert self.quant_method is not None diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 8eac4fd3f5e..0302524fe1c 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -69,14 +69,14 @@ def dispatch( expert_num_tokens = torch.empty( num_local_experts, dtype=torch.int32, - device=device, + device=a1.device, ) num_dp = self.world_size // self.dp_size expert_x = torch.empty( (num_local_experts, self.max_num_tokens * num_dp, a1q.shape[-1]), dtype=a1q.dtype, - device=device, + device=a1.device, ) expert_x_scale: Optional[torch.Tensor] = None @@ -91,7 +91,7 @@ def dispatch( (expert_x.size(2) + block_size - 1) // block_size, ), dtype=torch.float32, - device=device, + device=a1.device, ) # This argument is optional, defaults to indices.shape[0] From 6c320cf562ac1742d03e243835659926657a9708 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 7 Apr 2025 15:04:28 +0000 Subject: [PATCH 122/190] hacking Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 62 ++++++++++++++++++- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index ed156bcc53b..7f6ca1d7ce4 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +import threading +import weakref from abc import abstractmethod from dataclasses import dataclass from enum import Enum @@ -99,13 +101,67 @@ def apply( raise NotImplementedError +class AllToAllCache: + + def __init__(self): + self._cache = {} + self._lock = threading.RLock() # Reentrant lock for thread safety + + def get_or_create(self, **kwargs): + # Create a hashable key from the kwargs + key = tuple(sorted((k, v) for k, v in kwargs.items())) + + with self._lock: + if key in self._cache: + instance, refs = self._cache[key] + new_ref = weakref.ref(object(), + lambda _: self._decrement_ref_count(key)) + refs.append(new_ref) + return instance + else: + # Create new instance + instance = pplx.AllToAll(**kwargs) + # Use a weakref.ref with a callback when reference is collected + refs = [ + weakref.ref(object(), + lambda _: self._decrement_ref_count(key)) + ] + self._cache[key] = (instance, refs) + return instance + + def _decrement_ref_count(self, key): + with self._lock: + if key in self._cache: + instance, refs = self._cache[key] + # Remove dead references + refs = [ref for ref in refs if ref() is not None] + if not refs: + # No more references, clean up the instance + instance.destroy() + del self._cache[key] + else: + # Update refs + self._cache[key] = (instance, refs) + + +# Global singleton +_all_to_all_cache = AllToAllCache() + + +# Factory function as a cleaner interface +def get_all_to_all(**kwargs): + return _all_to_all_cache.get_or_create(**kwargs) + + #TODO: Every change in this class is a broken hack!! @CustomOp.register("unquantized_fused_moe") class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" def __init__(self, moe: MoEConfig): - self.all_to_all = pplx.AllToAll( + pplx_init(moe.ep_rank, moe.ep_size) + + self.all_to_all = get_all_to_all( max_num_tokens=MOE_DP_CHUNK_SIZE // moe.dp_size, num_experts=moe.num_experts, experts_per_token=moe.experts_per_token, @@ -412,9 +468,11 @@ def determine_expert_map( @run_once def pplx_init(rank, world_size): + print(f"PPLX_INIT {rank} {world_size}") uid = nvshmem_get_unique_id( ) if rank == 0 else nvshmem_alloc_empty_unique_id() - torch.distributed.broadcast(uid, src=0) + print(f"PPLX_INIT UID={uid}") + torch.distributed.broadcast(uid.cuda(), src=0) nvshmem_init(uid, rank, world_size) From ddc0b9937b409ebbff14d7b630149c333f650e84 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 7 Apr 2025 19:41:44 +0000 Subject: [PATCH 123/190] init stuff Signed-off-by: Bill Nell --- vllm/distributed/parallel_state.py | 18 +++++++- vllm/model_executor/layers/fused_moe/layer.py | 44 ++++++++++--------- 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index cb9658ce100..0bb4835939a 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -42,7 +42,7 @@ from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname, - supports_custom_op) + run_once, supports_custom_op) @dataclass @@ -912,6 +912,20 @@ def init_distributed_environment( "world group already initialized with a different world size") +@run_once +def pplx_init(rank, world_size): + from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_get_unique_id, nvshmem_init) + print(f"PPLX_INIT {rank} {world_size}") + uid = nvshmem_get_unique_id( + ) if rank == 0 else nvshmem_alloc_empty_unique_id() + uid_gpu = uid.cuda() + get_world_group().broadcast(uid_gpu, src=0) + print(f"PPLX_INIT UID={uid_gpu}") + uid = uid_gpu.to(device='cpu') + nvshmem_init(uid, rank, world_size) + + def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, @@ -1006,6 +1020,8 @@ def initialize_model_parallel( "DP rank %s, PP rank %s, TP rank %s", rank, world_size, _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group) + pplx_init(rank, world_size) + def ensure_model_parallel_initialized( tensor_model_parallel_size: int, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 7f6ca1d7ce4..b68c8bbc6ad 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -10,8 +10,6 @@ import pplx_kernels as pplx import torch import torch.nn.functional as F -from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_get_unique_id, nvshmem_init) from torch.nn.parameter import UninitializedParameter import vllm.envs as envs @@ -27,13 +25,13 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from vllm.utils import direct_register_custom_op, run_once +from vllm.utils import direct_register_custom_op if current_platform.is_cuda_alike(): - #from .pplx_dispatch_combine import PplxDispatchCombine from .dispatch_combine import StandardDispatchCombine from .fused_moe import TritonExperts, fused_experts from .modular_kernel import FusedMoEModularKernel + from .pplx_dispatch_combine import PplxDispatchCombine else: fused_experts = None # type: ignore if current_platform.is_tpu(): @@ -101,7 +99,7 @@ def apply( raise NotImplementedError -class AllToAllCache: +class AllToAllCacheThreadSafe: def __init__(self): self._cache = {} @@ -120,6 +118,7 @@ def get_or_create(self, **kwargs): return instance else: # Create new instance + print("CREATE AllToAll") instance = pplx.AllToAll(**kwargs) # Use a weakref.ref with a callback when reference is collected refs = [ @@ -144,6 +143,25 @@ def _decrement_ref_count(self, key): self._cache[key] = (instance, refs) +class AllToAllCache: + + def __init__(self): + self._cache = {} + + def get_or_create(self, **kwargs): + # Create a hashable key from the kwargs + key = tuple(sorted((k, v) for k, v in kwargs.items())) + + if key in self._cache: + return self._cache[key] + else: + # Create new instance + print("CREATE AllToAll") + instance = pplx.AllToAll(**kwargs) + self._cache[key] = instance + return instance + + # Global singleton _all_to_all_cache = AllToAllCache() @@ -159,8 +177,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" def __init__(self, moe: MoEConfig): - pplx_init(moe.ep_rank, moe.ep_size) - self.all_to_all = get_all_to_all( max_num_tokens=MOE_DP_CHUNK_SIZE // moe.dp_size, num_experts=moe.num_experts, @@ -301,7 +317,7 @@ def forward_cuda( e_score_correction_bias=e_score_correction_bias) return fused_experts( - hidden_states=x, + a1=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, @@ -466,16 +482,6 @@ def determine_expert_map( return (local_num_experts, expert_map) -@run_once -def pplx_init(rank, world_size): - print(f"PPLX_INIT {rank} {world_size}") - uid = nvshmem_get_unique_id( - ) if rank == 0 else nvshmem_alloc_empty_unique_id() - print(f"PPLX_INIT UID={uid}") - torch.distributed.broadcast(uid.cuda(), src=0) - nvshmem_init(uid, rank, world_size) - - class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. @@ -599,8 +605,6 @@ def __init__( # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. if quant_config is None: - pplx_init(self.dp_rank, self.dp_size) - moe = MoEConfig( num_experts=self.global_num_experts, experts_per_token=0, From 66ae9855510431c217050b3a679017878b895768 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 7 Apr 2025 19:52:03 +0000 Subject: [PATCH 124/190] call super ctor + fix random stuff Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index b68c8bbc6ad..213c9ae6289 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -177,7 +177,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" def __init__(self, moe: MoEConfig): - self.all_to_all = get_all_to_all( + super().__init__() + self._moe = moe + self._all_to_all = get_all_to_all( max_num_tokens=MOE_DP_CHUNK_SIZE // moe.dp_size, num_experts=moe.num_experts, experts_per_token=moe.experts_per_token, From 4c237849cbd3d73cf981a0cdcdc627a5535b20e1 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 7 Apr 2025 15:53:52 -0400 Subject: [PATCH 125/190] fix use_ep bug Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 213c9ae6289..6dd91311241 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -548,7 +548,7 @@ def __init__( # Use expert parallelism instead of tensor parallelism? vllm_config = get_current_vllm_config() use_ep = (vllm_config.parallel_config.enable_expert_parallel - and self.tp_size * self.dp_size > 1) + and (self.tp_size * self.dp_size) > 1) # For smuggling this layer into the fused moe custom op self.use_direct_call = self.dp_size == 1 From 64e4281705a6b205292a78275f74314ffd73eccd Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 7 Apr 2025 16:30:28 -0400 Subject: [PATCH 126/190] Fix dp_size Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 6dd91311241..5b192720f78 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -185,7 +185,7 @@ def __init__(self, moe: MoEConfig): experts_per_token=moe.experts_per_token, rank=moe.ep_rank, world_size=moe.ep_size, - dp_size=moe.dp_size, + dp_size=moe.ep_size // moe.dp_size, hidden_dim=moe.hidden_dim, hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, hidden_dim_scale_bytes=0, From 169403a00908a17c2c156282b40819a31e4da168 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 7 Apr 2025 16:39:36 -0400 Subject: [PATCH 127/190] add comment Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 5b192720f78..81389f7b6d7 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -185,7 +185,7 @@ def __init__(self, moe: MoEConfig): experts_per_token=moe.experts_per_token, rank=moe.ep_rank, world_size=moe.ep_size, - dp_size=moe.ep_size // moe.dp_size, + dp_size=moe.ep_size // moe.dp_size, # dp_size actually means TP. hidden_dim=moe.hidden_dim, hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, hidden_dim_scale_bytes=0, From c4110cb891cb20ca077c1fe635aa336a476b0cb1 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 7 Apr 2025 16:48:50 -0400 Subject: [PATCH 128/190] fixes Signed-off-by: Tyler Michael Smith Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 0302524fe1c..86fa17561f2 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -103,6 +103,9 @@ def dispatch( # TODO: optimize this? indices = rank_topk_ids.to(dtype=torch.uint32) + # TODO: optimize this? + rank_topk_ids = rank_topk_ids.to(dtype=torch.uint32) + self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, From f4ae47fbe0ae82caf3cad2b19985e805c2a8522e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 7 Apr 2025 20:46:49 +0000 Subject: [PATCH 129/190] get a bit further Signed-off-by: Bill Nell --- vllm/distributed/parallel_state.py | 8 ++++++-- vllm/model_executor/layers/fused_moe/layer.py | 5 +++++ .../layers/fused_moe/pplx_dispatch_combine.py | 3 --- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 0bb4835939a..088dc49bf3f 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -35,6 +35,9 @@ import torch import torch.distributed from torch.distributed import Backend, ProcessGroup +from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_get_unique_id, nvshmem_init, + nvshmem_finalize) import vllm.envs as envs from vllm.distributed.device_communicators.base_device_communicator import ( @@ -914,8 +917,6 @@ def init_distributed_environment( @run_once def pplx_init(rank, world_size): - from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_get_unique_id, nvshmem_init) print(f"PPLX_INIT {rank} {world_size}") uid = nvshmem_get_unique_id( ) if rank == 0 else nvshmem_alloc_empty_unique_id() @@ -1097,6 +1098,8 @@ def get_tensor_model_parallel_rank(): def destroy_model_parallel(): """Set the groups to none and destroy them.""" global _TP + nvshmem_finalize() + if _TP: _TP.destroy() _TP = None @@ -1112,6 +1115,7 @@ def destroy_model_parallel(): _DP = None + def destroy_distributed_environment(): global _WORLD if _WORLD: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 81389f7b6d7..494b1d955e8 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -161,6 +161,11 @@ def get_or_create(self, **kwargs): self._cache[key] = instance return instance + def clear(): + for k, v in self._cache.items(): + v.destroy() + del self._cache + # Global singleton _all_to_all_cache = AllToAllCache() diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 86fa17561f2..0302524fe1c 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -103,9 +103,6 @@ def dispatch( # TODO: optimize this? indices = rank_topk_ids.to(dtype=torch.uint32) - # TODO: optimize this? - rank_topk_ids = rank_topk_ids.to(dtype=torch.uint32) - self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, From 1d7aa87bf5369b07f8f5248c6d07441d0573f34e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 9 Apr 2025 23:06:46 +0000 Subject: [PATCH 130/190] hacking in dispatch_combine Signed-off-by: Bill Nell --- .../layers/fused_moe/fused_moe.py | 10 +- vllm/model_executor/layers/fused_moe/layer.py | 122 +++++++++++++----- .../layers/fused_moe/modular_kernel.py | 18 ++- .../layers/fused_moe/pplx_dispatch_combine.py | 6 + .../layers/fused_moe/triton_deep_gemm_moe.py | 104 +++++++++++++++ .../model_executor/layers/quantization/fp8.py | 78 +++++++---- 6 files changed, 273 insertions(+), 65 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1197de1a5f6..a8b32945664 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1666,6 +1666,9 @@ def apply( raise ValueError( f"Unsupported compute_type: {hidden_states.dtype}") + #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}") + #print(f"BLOCK_M = {self.block_m}") + # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 intermediate_cache1 = _resize_cache(workspace13, @@ -1676,8 +1679,11 @@ def apply( (num_tokens, top_k_num, K)) sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], - global_num_experts, expert_map)) + moe_align_block_size( + topk_ids, + config['BLOCK_SIZE_M'] if self.block_m is None else self.block_m, + global_num_experts, expert_map + )) invoke_fused_moe_kernel(hidden_states, w1, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 494b1d955e8..61b18cfd603 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -30,7 +30,7 @@ if current_platform.is_cuda_alike(): from .dispatch_combine import StandardDispatchCombine from .fused_moe import TritonExperts, fused_experts - from .modular_kernel import FusedMoEModularKernel + from .modular_kernel import FusedMoEModularKernel, FusedMoEQuantizeDispatchCombine from .pplx_dispatch_combine import PplxDispatchCombine else: fused_experts = None # type: ignore @@ -77,6 +77,9 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, params_dtype: torch.dtype, **extra_weight_attrs): raise NotImplementedError + def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: + return False + @abstractmethod def apply( self, @@ -118,7 +121,6 @@ def get_or_create(self, **kwargs): return instance else: # Create new instance - print("CREATE AllToAll") instance = pplx.AllToAll(**kwargs) # Use a weakref.ref with a callback when reference is collected refs = [ @@ -183,18 +185,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def __init__(self, moe: MoEConfig): super().__init__() - self._moe = moe - self._all_to_all = get_all_to_all( - max_num_tokens=MOE_DP_CHUNK_SIZE // moe.dp_size, - num_experts=moe.num_experts, - experts_per_token=moe.experts_per_token, - rank=moe.ep_rank, - world_size=moe.ep_size, - dp_size=moe.ep_size // moe.dp_size, # dp_size actually means TP. - hidden_dim=moe.hidden_dim, - hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, - hidden_dim_scale_bytes=0, - ) + self.fused_experts = fused_experts + self.moe = moe def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -293,6 +285,26 @@ def apply( activation=activation, apply_router_weight_on_input=apply_router_weight_on_input) + # Maybe extra args + def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: + block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) + print(f"block_m = {block_m}") + + experts = TritonExperts( + use_fp8_w8a8 = False, + use_int8_w8a16 = False, + use_int4_w4a16 = False, + block_shape = None, + block_m = None, #block_m, + ) + + self.fused_experts = FusedMoEModularKernel( + dispatch_combine, + experts, + ) + + return True + def forward_cuda( self, layer: torch.nn.Module, @@ -323,8 +335,8 @@ def forward_cuda( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) - return fused_experts( - a1=x, + return self.fused_experts( + hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, @@ -333,7 +345,8 @@ def forward_cuda( activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, - expert_map=expert_map) + expert_map=expert_map, + ) def forward_cpu( self, @@ -609,27 +622,67 @@ def __init__( from vllm_hpu_extension.ops import DynamicFusedMOE self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) + moe = MoEConfig( + num_experts=self.global_num_experts, + experts_per_token=top_k, # ? must be same as topk_ids.shape[1] + hidden_dim=hidden_size, + num_local_experts=self.local_num_experts, + dp_size=self.dp_size, + dp_rank=self.dp_rank, + ep_size=self.ep_size, + ep_rank=self.ep_rank, + #in_dtype = 0, + #out_dtype = 0, + ) + # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. + quant_method: Optional[FusedMoEMethodBase] = None + if quant_config is None: - moe = MoEConfig( - num_experts=self.global_num_experts, - experts_per_token=0, - hidden_dim=hidden_size, - num_local_experts=self.local_num_experts, - dp_size=self.dp_size, - dp_rank=self.dp_rank, - ep_size=self.ep_size, - ep_rank=self.ep_rank, - #in_dtype = 0, - #out_dtype = 0, + quant_method = UnquantizedFusedMoEMethod(moe) + else: + # moe? + # TODO: setup dispatcher on FusedMoE. callees of this + # function can grab dispatcher from there? Or add + # supports_dispatcher/set_dispatcher method on FusedMoeMethodBase + quant_method = quant_config.get_quant_method(self, prefix) + assert isinstance(quant_method, FusedMoEMethodBase) + + assert quant_method is not None + self.quant_method = quant_method + + # TODO: move to method? + if self.dp_size > 1: + all_to_all = get_all_to_all( + max_num_tokens=MOE_DP_CHUNK_SIZE, # // moe.dp_size, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, # has to be same as topk_ids.shape[1] + rank=moe.ep_rank, + world_size=moe.ep_size, + dp_size=moe.ep_size // moe.dp_size, # dp_size actually means TP. + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, + hidden_dim_scale_bytes=0, ) - self.quant_method: Optional[QuantizeMethodBase] = ( - UnquantizedFusedMoEMethod(moe)) - else: - self.quant_method = quant_config.get_quant_method(self, prefix) - assert self.quant_method is not None + if False: + dispatch_combine = PplxDispatchCombine( + all_to_all, + MOE_DP_CHUNK_SIZE, + moe.ep_size, + moe.dp_size, + moe.in_dtype, + ) + else: + dispatch_combine = StandardDispatchCombine( + moe.in_dtype, + quant_config.weight_block_size if quant_config is not None else None, + ) + + success = self.quant_method.set_dispatch_combine(dispatch_combine) + if not success: + logger.warning("DP+EP not supported for %s.", type(self.quant_method)) self.apply_router_weight_on_input = apply_router_weight_on_input moe_quant_params = { @@ -976,6 +1029,7 @@ def forward(self, hidden_states: torch.Tensor, def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_router_logits: torch.Tensor): + max_tokens_across_dp = get_forward_context( ).dp_metadata.max_tokens_across_dp cu_tokens_across_dp_cpu = get_forward_context( @@ -983,6 +1037,8 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, num_tokens_across_dp = get_forward_context( ).dp_metadata.num_tokens_across_dp + #print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{get_forward_context().dp_metadata.dp_rank_num_tokens}") + #In this function we define two ranges: # 1. chunk_range - The current iteration of the loops's range over the DP world tokens # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP rank owns. diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index a3086dee4b3..f7b3f7899dd 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -60,15 +60,19 @@ def _moe_problem_size( E, N, _ = w1.shape K = w2.shape[1] - assert a1.dim() == 2 assert topk_ids.dim() == 2 - # Make sure we are using the correct a1 (pre-permute). - assert topk_ids.shape[0] == a1.shape[ - 0], f"{topk_ids.shape[0]} != {a1.shape[0]}" - - M = a1.shape[0] topk = topk_ids.shape[1] + if a1.dim() == 2: + # Make sure we are using the correct a1 (pre-permute). + assert topk_ids.shape[0] == a1.shape[0], \ + f"{topk_ids.shape[0]} != {a1.shape[0]}" + M = a1.shape[0] + else: + assert a1.dim() == 3 + assert E == a1.shape[0] + M = a1.shape[1] # This is max_num_tokens + return E, M, N, K, topk @@ -311,6 +315,8 @@ def forward( a1 = hidden_states E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids) + #print(f"INIT shape: E={E}, M={M}, N={N}, K={K}, top_k={top_k}") + if global_num_experts == -1: global_num_experts = E diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 0302524fe1c..fa717c40c77 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -30,6 +30,10 @@ def __init__(self, self.dp_size = dp_size self.rank = rank self.quant_dtype = quant_dtype + print(f"max_num_tokens = {max_num_tokens}") + print(f"dp_num_tokens = {self.dp_num_tokens}") + print(f"world_size = {world_size}") + print(f"dp_size = {dp_size}") def dispatch( self, @@ -71,6 +75,7 @@ def dispatch( dtype=torch.int32, device=a1.device, ) + expert_num_tokens.fill_(-1) num_dp = self.world_size // self.dp_size expert_x = torch.empty( @@ -78,6 +83,7 @@ def dispatch( dtype=a1q.dtype, device=a1.device, ) + expert_x.fill_(torch.nan) expert_x_scale: Optional[torch.Tensor] = None if a1q.dtype.itemsize == 1: diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py new file mode 100644 index 00000000000..f3a13e44296 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +import importlib.util +from typing import List, Optional, Tuple + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + DeepGemmExperts, + _valid_deep_gemm_shape, + _valid_deep_gemm, +) +from vllm.model_executor.layers.fused_moe.fused_moe import TritonExpert + +class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, + block_m: Optional[int] = None, + allow_deep_gemm: bool = False + ): + super().__init__() + self.triton_expert = TritonExpert( + use_fp8_w8a8, + use_int4_w4a16, + use_int8_w8a16, + block_shape, + block_m + ) + self.deep_gemm_expert = DeepGemmExperts() + self.allow_deep_gemm = allow_deep_gemm + self.use_fp8_w8a8 = use_fp8_w8a8 + + def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, + topk: int, + num_experts: int) -> Tuple[int, int, torch.dtype]: + # Note: the deep gemm workspaces are strictly larger than the triton + # workspaces so we can be pessimistic here and allocate for DeepGemm + # even if we fall back to triton later, e.g. if expert maps are set. + if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): + return self.deep_gemm_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts) + else: + return self.triton_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + ) -> torch.Tensor: + N = w1.shape[1] + if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512 + and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): + return self.deep_gemm_expert( + hidden_states, + w1, + w2, + topk_ids, + activation, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1q_scale, + a2_scale, + workspace13, + workspace2, + ) + else: + return self.triton_expert( + hidden_states, + w1, + w2, + topk_ids, + activation, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1q_scale, + a2_scale, + workspace13, + workspace2, + ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f4eef830457..a88827dacce 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import functools import importlib.util from typing import Any, Callable, Dict, List, Optional @@ -10,6 +11,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, @@ -426,6 +428,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): """ def __init__(self, quant_config: Fp8Config): + from vllm.model_executor.layers.fused_moe import fused_experts self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None self.allow_deep_gemm = allow_deep_gemm @@ -451,6 +454,11 @@ def __init__(self, quant_config: Fp8Config): logger.warning_once( "DeepGemm not supported on the current platform.") + self.fused_experts = functools.partial( + fused_experts, + block_shape=self.quant_config.weight_block_size, + allow_deep_gemm=self.allow_deep_gemm) + def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -770,6 +778,32 @@ def process_weights_after_loading(self, layer: Module) -> None: del layer.w13_input_scale del layer.w2_input_scale + # Maybe extra args + def set_dispatch_combine(self, dispatch_combine: mk.FusedMoEQuantizeDispatchCombine) -> bool: + if self.use_marlin: + return False + + from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import TritonOrDeepGemmExperts + + #block_m = MOE_DP_CHUNK_SIZE * (moe.ep_size // moe.dp_size) + #print(f"block_m = {block_m}") + + experts = TritonOrDeepGemmExperts( + use_fp8_w8a8 = True, + use_int8_w8a16 = False, + use_int4_w4a16 = False, + block_shape = self.quant_config.weight_block_size, + block_m = None, # TODO + allow_deep_gemm=self.allow_deep_gemm, + ) + + self.fused_experts = mk.FusedMoEModularKernel( + dispatch_combine, + experts, + ) + + return True + def apply( self, layer: torch.nn.Module, @@ -788,8 +822,6 @@ def apply( apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -816,28 +848,26 @@ def apply( quant_type_id=scalar_types.float8_e4m3fn.id, global_num_experts=global_num_experts, expert_map=expert_map) - - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - use_fp8_w8a8=True, - global_num_experts=global_num_experts, - apply_router_weight_on_input=apply_router_weight_on_input, - expert_map=expert_map, - w1_scale=(layer.w13_weight_scale_inv - if self.block_quant else layer.w13_weight_scale), - w2_scale=(layer.w2_weight_scale_inv - if self.block_quant else layer.w2_weight_scale), - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - block_shape=self.quant_config.weight_block_size, - allow_deep_gemm=self.allow_deep_gemm, - ) + else: + return self.fused_experts( + hidden_states=x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + use_fp8_w8a8=True, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=(layer.w13_weight_scale_inv + if self.block_quant else layer.w13_weight_scale), + w2_scale=(layer.w2_weight_scale_inv + if self.block_quant else layer.w2_weight_scale), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) class Fp8KVCacheMethod(BaseKVCacheMethod): From 13163179ffd386f73a3aa3fb0aed6c9f384e9856 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 10 Apr 2025 14:47:37 +0000 Subject: [PATCH 131/190] hook up some wires Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 130 +++++++----------- .../layers/fused_moe/pplx_dispatch_combine.py | 2 + 2 files changed, 54 insertions(+), 78 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 61b18cfd603..b5872a00b95 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -57,8 +57,10 @@ class MoEConfig: ep_size: int ep_rank: int - in_dtype: torch.dtype = torch.bfloat16 - out_dtype: torch.dtype = torch.bfloat16 + in_dtype: torch.dtype + out_dtype: torch.dtype + + # TODO: add more quantization params, blocked, per-token, etc. block_size: int = 128 @@ -102,10 +104,10 @@ def apply( raise NotImplementedError -class AllToAllCacheThreadSafe: +class AllToAllCache: def __init__(self): - self._cache = {} + self._cache = weakref.WeakValueDictionary() self._lock = threading.RLock() # Reentrant lock for thread safety def get_or_create(self, **kwargs): @@ -113,61 +115,12 @@ def get_or_create(self, **kwargs): key = tuple(sorted((k, v) for k, v in kwargs.items())) with self._lock: - if key in self._cache: - instance, refs = self._cache[key] - new_ref = weakref.ref(object(), - lambda _: self._decrement_ref_count(key)) - refs.append(new_ref) - return instance - else: - # Create new instance + instance = self._cache.get(key) + if instance is None: instance = pplx.AllToAll(**kwargs) - # Use a weakref.ref with a callback when reference is collected - refs = [ - weakref.ref(object(), - lambda _: self._decrement_ref_count(key)) - ] - self._cache[key] = (instance, refs) - return instance - - def _decrement_ref_count(self, key): - with self._lock: - if key in self._cache: - instance, refs = self._cache[key] - # Remove dead references - refs = [ref for ref in refs if ref() is not None] - if not refs: - # No more references, clean up the instance - instance.destroy() - del self._cache[key] - else: - # Update refs - self._cache[key] = (instance, refs) - - -class AllToAllCache: - - def __init__(self): - self._cache = {} - - def get_or_create(self, **kwargs): - # Create a hashable key from the kwargs - key = tuple(sorted((k, v) for k, v in kwargs.items())) - - if key in self._cache: - return self._cache[key] - else: - # Create new instance - print("CREATE AllToAll") - instance = pplx.AllToAll(**kwargs) - self._cache[key] = instance + self._cache[key] = instance return instance - def clear(): - for k, v in self._cache.items(): - v.destroy() - del self._cache - # Global singleton _all_to_all_cache = AllToAllCache() @@ -622,6 +575,8 @@ def __init__( from vllm_hpu_extension.ops import DynamicFusedMOE self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) + print(f"params dtype= {params_dtype}") + moe = MoEConfig( num_experts=self.global_num_experts, experts_per_token=top_k, # ? must be same as topk_ids.shape[1] @@ -631,8 +586,8 @@ def __init__( dp_rank=self.dp_rank, ep_size=self.ep_size, ep_rank=self.ep_rank, - #in_dtype = 0, - #out_dtype = 0, + in_dtype = params_dtype, # this is probably not right, where to get? + out_dtype = params_dtype, # ditto. ) # Note: get_quant_method will look at the layer's local_num_experts @@ -642,10 +597,6 @@ def __init__( if quant_config is None: quant_method = UnquantizedFusedMoEMethod(moe) else: - # moe? - # TODO: setup dispatcher on FusedMoE. callees of this - # function can grab dispatcher from there? Or add - # supports_dispatcher/set_dispatcher method on FusedMoeMethodBase quant_method = quant_config.get_quant_method(self, prefix) assert isinstance(quant_method, FusedMoEMethodBase) @@ -654,24 +605,47 @@ def __init__( # TODO: move to method? if self.dp_size > 1: - all_to_all = get_all_to_all( - max_num_tokens=MOE_DP_CHUNK_SIZE, # // moe.dp_size, - num_experts=moe.num_experts, - experts_per_token=moe.experts_per_token, # has to be same as topk_ids.shape[1] - rank=moe.ep_rank, - world_size=moe.ep_size, - dp_size=moe.ep_size // moe.dp_size, # dp_size actually means TP. - hidden_dim=moe.hidden_dim, - hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, - hidden_dim_scale_bytes=0, - ) + if True: + max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size + world_size = moe.ep_size + dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. + rank = moe.ep_rank + + print(f"max num = {max_num_tokens}") + print(f"world size = {world_size}") + print(f"moe ep size = {moe.ep_size}") + print(f"moe dp size = {moe.dp_size}") + print(f"dp size = {dp_size}") + print(f"rank= {rank}") + + all_to_all = get_all_to_all( + max_num_tokens=max_num_tokens, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, # topk + rank=rank, + world_size=world_size, + dp_size=dp_size, + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, + # For blocked per token: set to ceil_div(hidden_dim, block_size) * sizeof(float32) + # For per-token: set to sizeof(float32) + hidden_dim_scale_bytes=( + 0 + if moe.in_dtype.itemsize != 1 + else ( + (moe.hidden_dim + moe.block_size - 1) + // moe.block_size + * torch.float32.itemsize + ) + ) + ) - if False: dispatch_combine = PplxDispatchCombine( all_to_all, - MOE_DP_CHUNK_SIZE, - moe.ep_size, - moe.dp_size, + max_num_tokens, + world_size, + dp_size, + rank, # just for debugging moe.in_dtype, ) else: @@ -1037,7 +1011,7 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, num_tokens_across_dp = get_forward_context( ).dp_metadata.num_tokens_across_dp - #print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{get_forward_context().dp_metadata.dp_rank_num_tokens}") + print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{get_forward_context().dp_metadata.dp_rank_num_tokens}") #In this function we define two ranges: # 1. chunk_range - The current iteration of the loops's range over the DP world tokens diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index fa717c40c77..fd1fbb16751 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -9,6 +9,8 @@ moe_kernel_quantize_input) +logger = init_logger(__name__) + # Note use: layer.get_all_to_all() to get an AllToAll instance # The max_num_tokens, world_size and dp_size must be the same # as the ones used to create the AllToAll. From b4668540360ea3d13f6875d00a43c3c109d41a0a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 10 Apr 2025 21:48:22 +0000 Subject: [PATCH 132/190] seems to be working Signed-off-by: Bill Nell --- .../layers/fused_moe/deep_gemm_moe.py | 4 +- .../layers/fused_moe/fused_moe.py | 23 +++-- vllm/model_executor/layers/fused_moe/layer.py | 85 ++++++++++--------- .../layers/fused_moe/modular_kernel.py | 6 +- .../layers/fused_moe/pplx_dispatch_combine.py | 11 ++- 5 files changed, 70 insertions(+), 59 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 266ba3bfa07..a694c53d9f3 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -134,7 +134,9 @@ def apply( dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids) - self.activation(activation, workspace2, workspace1.view(-1, N)) + self.activation(activation, + workspace2, + workspace1.view(-1, N)) a2q_scale: Optional[torch.Tensor] = None diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a8b32945664..6cb0262cc2a 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1678,12 +1678,20 @@ def apply( intermediate_cache3 = _resize_cache(workspace13, (num_tokens, top_k_num, K)) - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size( - topk_ids, - config['BLOCK_SIZE_M'] if self.block_m is None else self.block_m, - global_num_experts, expert_map - )) + if hidden_states.dim() == 2: #block_m is None: + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size( + topk_ids, + config['BLOCK_SIZE_M'], + global_num_experts, expert_map + )) + else: + stride = hidden_states.shape[1] + sorted_token_ids = torch.arange(0, hidden_states.shape[0], device=hidden_states.device, dtype=torch.int) + sorted_token_ids = sorted_token_ids * stride + expert_ids = torch.logical_not(torch.isnan(hidden_states)).sum(dim=(1,2)).nonzero() + num_tokens_post_padded = torch.zeros(1, device=hidden_states.device, dtype=torch.int) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) invoke_fused_moe_kernel(hidden_states, w1, @@ -1706,7 +1714,8 @@ def apply( per_channel_quant=self.per_channel_quant, block_shape=self.block_shape) - self.activation(activation, intermediate_cache2, + self.activation(activation, + intermediate_cache2, intermediate_cache1.view(-1, N)) a2q_scale: Optional[torch.Tensor] = None diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index b5872a00b95..a0bb84e7fb1 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -241,7 +241,7 @@ def apply( # Maybe extra args def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) - print(f"block_m = {block_m}") + #print(f"block_m = {block_m}") experts = TritonExperts( use_fp8_w8a8 = False, @@ -550,8 +550,8 @@ def __init__( self.ep_size = 1 self.local_num_experts = self.global_num_experts self.expert_map = None + #self.global_num_experts = num_experts redundant? self.top_k = top_k - self.global_num_experts = num_experts assert intermediate_size % self.tp_size == 0 self.hidden_size = hidden_size @@ -571,11 +571,12 @@ def __init__( if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " "non-grouped topk.") + if current_platform.is_hpu(): from vllm_hpu_extension.ops import DynamicFusedMOE self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) - print(f"params dtype= {params_dtype}") + #print(f"params dtype= {params_dtype}") moe = MoEConfig( num_experts=self.global_num_experts, @@ -604,13 +605,13 @@ def __init__( self.quant_method = quant_method # TODO: move to method? - if self.dp_size > 1: - if True: - max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size - world_size = moe.ep_size - dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. - rank = moe.ep_rank + if False and self.dp_size > 1: + max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size + world_size = moe.ep_size + dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. + rank = moe.ep_rank + if False: print(f"max num = {max_num_tokens}") print(f"world size = {world_size}") print(f"moe ep size = {moe.ep_size}") @@ -618,45 +619,45 @@ def __init__( print(f"dp size = {dp_size}") print(f"rank= {rank}") - all_to_all = get_all_to_all( - max_num_tokens=max_num_tokens, - num_experts=moe.num_experts, - experts_per_token=moe.experts_per_token, # topk - rank=rank, - world_size=world_size, - dp_size=dp_size, - hidden_dim=moe.hidden_dim, - hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, - # For blocked per token: set to ceil_div(hidden_dim, block_size) * sizeof(float32) - # For per-token: set to sizeof(float32) - hidden_dim_scale_bytes=( - 0 - if moe.in_dtype.itemsize != 1 - else ( - (moe.hidden_dim + moe.block_size - 1) - // moe.block_size - * torch.float32.itemsize - ) + all_to_all = get_all_to_all( + max_num_tokens=max_num_tokens, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, # topk + rank=rank, + world_size=world_size, + dp_size=dp_size, + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, + # For blocked per token: set to ceil_div(hidden_dim, block_size) * sizeof(float32) + # For per-token: set to sizeof(float32) + hidden_dim_scale_bytes=( + 0 + if moe.in_dtype.itemsize != 1 + else ( + (moe.hidden_dim + moe.block_size - 1) + // moe.block_size + * torch.float32.itemsize ) ) + ) - dispatch_combine = PplxDispatchCombine( - all_to_all, - max_num_tokens, - world_size, - dp_size, - rank, # just for debugging - moe.in_dtype, - ) - else: - dispatch_combine = StandardDispatchCombine( - moe.in_dtype, - quant_config.weight_block_size if quant_config is not None else None, - ) + dispatch_combine = PplxDispatchCombine( + all_to_all, + max_num_tokens, + world_size, + dp_size, + rank, # just for debugging + moe.in_dtype, + ) success = self.quant_method.set_dispatch_combine(dispatch_combine) if not success: logger.warning("DP+EP not supported for %s.", type(self.quant_method)) + else: + dispatch_combine = StandardDispatchCombine( + moe.in_dtype, + quant_config.weight_block_size if quant_config is not None else None, + ) self.apply_router_weight_on_input = apply_router_weight_on_input moe_quant_params = { @@ -1011,7 +1012,7 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, num_tokens_across_dp = get_forward_context( ).dp_metadata.num_tokens_across_dp - print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{get_forward_context().dp_metadata.dp_rank_num_tokens}") + #print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{get_forward_context().dp_metadata.dp_rank_num_tokens}") #In this function we define two ranges: # 1. chunk_range - The current iteration of the loops's range over the DP world tokens diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index f7b3f7899dd..a8b8ba65237 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -60,9 +60,6 @@ def _moe_problem_size( E, N, _ = w1.shape K = w2.shape[1] - assert topk_ids.dim() == 2 - topk = topk_ids.shape[1] - if a1.dim() == 2: # Make sure we are using the correct a1 (pre-permute). assert topk_ids.shape[0] == a1.shape[0], \ @@ -73,6 +70,9 @@ def _moe_problem_size( assert E == a1.shape[0] M = a1.shape[1] # This is max_num_tokens + assert topk_ids.dim() == 2 + topk = topk_ids.shape[1] + return E, M, N, K, topk diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index fd1fbb16751..983cc894ffe 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -32,10 +32,6 @@ def __init__(self, self.dp_size = dp_size self.rank = rank self.quant_dtype = quant_dtype - print(f"max_num_tokens = {max_num_tokens}") - print(f"dp_num_tokens = {self.dp_num_tokens}") - print(f"world_size = {world_size}") - print(f"dp_size = {dp_size}") def dispatch( self, @@ -77,7 +73,7 @@ def dispatch( dtype=torch.int32, device=a1.device, ) - expert_num_tokens.fill_(-1) + expert_num_tokens.fill_(-1) # debugging remove num_dp = self.world_size // self.dp_size expert_x = torch.empty( @@ -85,7 +81,7 @@ def dispatch( dtype=a1q.dtype, device=a1.device, ) - expert_x.fill_(torch.nan) + expert_x.fill_(torch.nan) # debugging remove expert_x_scale: Optional[torch.Tensor] = None if a1q.dtype.itemsize == 1: @@ -146,3 +142,6 @@ def combine( weights=topk_weights, expert_y=fused_expert_output, bound_m=bound_m) + + #print("END COMBINE") + From f005ce6ac089a32d76faedafb458e9f119a7e6d1 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 11 Apr 2025 20:33:42 +0000 Subject: [PATCH 133/190] wip Signed-off-by: Bill Nell --- vllm/distributed/parallel_state.py | 1 + .../layers/fused_moe/fused_moe.py | 16 +++++++++----- vllm/model_executor/layers/fused_moe/layer.py | 6 +++-- .../layers/fused_moe/modular_kernel.py | 5 +++++ .../layers/fused_moe/pplx_dispatch_combine.py | 22 ++++++++++++++----- 5 files changed, 37 insertions(+), 13 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 088dc49bf3f..59ca9899ba0 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1098,6 +1098,7 @@ def get_tensor_model_parallel_rank(): def destroy_model_parallel(): """Set the groups to none and destroy them.""" global _TP + nvshmem_finalize() if _TP: diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 6cb0262cc2a..8b138f8a579 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1686,12 +1686,18 @@ def apply( global_num_experts, expert_map )) else: - stride = hidden_states.shape[1] - sorted_token_ids = torch.arange(0, hidden_states.shape[0], device=hidden_states.device, dtype=torch.int) - sorted_token_ids = sorted_token_ids * stride - expert_ids = torch.logical_not(torch.isnan(hidden_states)).sum(dim=(1,2)).nonzero() - num_tokens_post_padded = torch.zeros(1, device=hidden_states.device, dtype=torch.int) + #stride = hidden_states.shape[1] + sorted_token_ids = torch.arange(0, num_tokens*hidden_states.shape[1], device=hidden_states.device, dtype=torch.int) + sorted_token_ids = sorted_token_ids.flatten() + nans = torch.isnan(hidden_states).sum(dim=(1,2)) + expert_ids = torch.where((nans > 0).flatten(), -1, torch.arange(0, nans.numel(), device=hidden_states.device, dtype=torch.int32)) + #expert_ids = torch.repeat_interleave(expert_ids, hidden_states.shape[1], dim=0) + #print(f"EXPERT_IDS {nans.shape} {expert_ids}") + #num_tokens_post_padded = torch.tensor([num_tokens], device=hidden_states.device, dtype=torch.int32) + num_tokens_post_padded = torch.zeros(1, device=hidden_states.device, dtype=torch.int32) + num_tokens_post_padded.fill_(num_tokens) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + #print(f"P = {sorted_token_ids}, {hidden_states.shape}") invoke_fused_moe_kernel(hidden_states, w1, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a0bb84e7fb1..eec7bb814b0 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -116,7 +116,7 @@ def get_or_create(self, **kwargs): with self._lock: instance = self._cache.get(key) - if instance is None: + if True or instance is None: instance = pplx.AllToAll(**kwargs) self._cache[key] = instance return instance @@ -605,7 +605,7 @@ def __init__( self.quant_method = quant_method # TODO: move to method? - if False and self.dp_size > 1: + if self.dp_size > 1: max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size world_size = moe.ep_size dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. @@ -1030,6 +1030,8 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] + print(f"loop {chunk_start}:{chunk_end}") + cu_tokens_across_dp_this_iter = torch.cumsum( num_tokens_remaining_across_dp.clamp( max=moe_dp_chunk_size_per_rank), diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index a8b8ba65237..76ece80ba47 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -312,6 +312,9 @@ def forward( Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ + from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank) + print(f"START {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}") + a1 = hidden_states E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids) @@ -361,4 +364,6 @@ def forward( self.dispatch_combine.combine(output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input) + print(f"DONE {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}") + return output diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 983cc894ffe..223b5d3d2aa 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -46,6 +46,8 @@ def dispatch( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: # Is this always going to be a1.device? device = a1.device + num_tokens = a1.shape[0] # M + hidden_dim = a1.shape[-1] # K assert expert_map is None, "NYI" @@ -71,7 +73,7 @@ def dispatch( expert_num_tokens = torch.empty( num_local_experts, dtype=torch.int32, - device=a1.device, + device=device, ) expert_num_tokens.fill_(-1) # debugging remove @@ -79,7 +81,7 @@ def dispatch( expert_x = torch.empty( (num_local_experts, self.max_num_tokens * num_dp, a1q.shape[-1]), dtype=a1q.dtype, - device=a1.device, + device=device, ) expert_x.fill_(torch.nan) # debugging remove @@ -95,7 +97,7 @@ def dispatch( (expert_x.size(2) + block_size - 1) // block_size, ), dtype=torch.float32, - device=a1.device, + device=device, ) # This argument is optional, defaults to indices.shape[0] @@ -105,7 +107,7 @@ def dispatch( bound_m = None # TODO: optimize this? - indices = rank_topk_ids.to(dtype=torch.uint32) + indices = rank_topk_ids.to(dtype=torch.uint32).to(device) self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, @@ -126,8 +128,17 @@ def combine( topk_ids: torch.Tensor, apply_router_weight_on_input: bool, ) -> None: + device = fused_expert_output.device + #device = torch.device("cuda", self.rank) + #device = get_dp_group().device + #assert fused_expert_output.device == device + + print(f"COMBINE START {self.rank}") + # This argument is optional #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens + #num_tokens = fused_expert_output.shape[0] # M + #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) bound_m = None assert output.shape[0] <= self.max_num_tokens @@ -143,5 +154,4 @@ def combine( expert_y=fused_expert_output, bound_m=bound_m) - #print("END COMBINE") - + print(f"COMBINE END {self.rank}") From b4dc0b18cd8c392b5cbe72586d5746d3a18f1672 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 14 Apr 2025 21:35:58 +0000 Subject: [PATCH 134/190] batched moe test Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 138 ++++++++++++++++++++++++++++- vllm/distributed/parallel_state.py | 33 +++++-- 2 files changed, 161 insertions(+), 10 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 5250dd82fa1..6a58397947b 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -14,7 +14,7 @@ from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe, torch_moe_single) from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import fused_moe, fused_experts from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) @@ -27,6 +27,7 @@ from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types +from vllm.model_executor.layers.activation import SiluAndMul NUM_EXPERTS = [8, 64] EP_SIZE = [1, 4] @@ -108,6 +109,141 @@ def test_fused_moe( rtol=0) +def batch_by_experts( + a: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int +) -> torch.Tensor: + #print(topk_ids.shape, topk_ids) + assert topk_ids.dim() == 2 + assert topk_ids.shape[0] == a.shape[0] + + tokens_per_expert = torch.zeros(num_experts, dtype=torch.int, device=a.device) + for i in range(topk_ids.shape[0]): + for j in range(topk_ids.shape[1]): + expert_id = topk_ids[i, j] + tokens_per_expert[expert_id] = tokens_per_expert[expert_id] + 1 + + #print(f"token_per_expert {tokens_per_expert.max()}") + max_num_tokens = tokens_per_expert.max() + b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), + dtype=a.dtype, device=a.device) + #print(f"b_a shape {b_a.shape}") + + #experts_per_token = torch.zeros(a.shape[0], dtype=torch.int, device=a.device) + + for i in range(topk_ids.shape[0]): + for j in range(topk_ids.shape[1]): + expert_id = topk_ids[i, j] + #idx = experts_per_token[i] + b_a[expert_id, j:j+1, :] = a[i, :] + #experts_per_token[i] = experts_per_token[i] + 1 + + return b_a, tokens_per_expert + + +def unbatch_output(b_out, topk_ids, K): + num_tokens, topk = topk_ids.shape + + #print(f"b_out = {b_out.shape} M={num_tokens}, K={K}, topk={topk}") + num_experts = b_out.shape[0] + out = torch.zeros((num_tokens, topk, K), dtype=b_out.dtype, device=b_out.device) + expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) + for token in range(num_tokens): + expert_ids = topk_ids[token] + #print(f"b_out[0] = {b_out[0].shape}") + for i in range(expert_ids.numel()): + expert_id = expert_ids[i] + idx = expert_counts[expert_id] + out[token, i:i+1, :] = b_out[expert_id, idx:idx+1, :] + idx = idx + 1 + expert_counts[expert_id] = idx + + return out + + +def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids): + assert a.dim() == 3 + #print(f"A = {a.shape} {a[0, :, :].shape}") + num_tokens, topk = topk_ids.shape + _, max_num_tokens, K = a.shape + num_experts = w1.shape[0] + out = torch.zeros((num_experts, max_num_tokens, w2.shape[1]), dtype=a.dtype, device=a.device) + for expert in range(num_experts): + num = tokens_per_expert[expert] + if num > 0: + #out[expert, :num, :] = SiluAndMul()(a[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) + out[expert, :, :] = SiluAndMul()(a[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) + + out = unbatch_output(out, topk_ids, w2.shape[1]) + + return (out * topk_weight.view(num_tokens, -1, 1).to(out.dtype)).sum(dim=1) + + +def torch_moe2(a, w1, w2, topk_weight, topk_ids): + M, K = a.shape + topk = topk_ids.shape[1] + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) + num_experts = w1.shape[0] + for i in range(num_experts): + mask = (topk_ids == i).view(-1) + if mask.sum(): + out[mask] = SiluAndMul()( + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + + return (out.view(M, -1, w2.shape[1]) * + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + + +@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_moe_batched_experts( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + score = torch.randn((m, e), device="cuda", dtype=dtype) + e_map = None + + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids = fused_topk(a, score, topk, False) + + torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + + b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) + + if True: + triton_output = torch_batched_moe(b_a, + w1, + w2, + tokens_per_expert, + topk_weight, + topk_ids) + else: + triton_output = fused_experts(a, # b_a + w1, + w2, + topk_weight, + topk_ids, + global_num_experts=e) + + #torch.testing.assert_close(triton_b_output, torch_b_output, atol=2e-2, rtol=0) + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + + @pytest.mark.parametrize("m", [1, 32, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 1024]) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 59ca9899ba0..ade7b5183dd 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -915,16 +915,31 @@ def init_distributed_environment( "world group already initialized with a different world size") +PPLX_DID_INIT: bool = False + @run_once def pplx_init(rank, world_size): - print(f"PPLX_INIT {rank} {world_size}") - uid = nvshmem_get_unique_id( - ) if rank == 0 else nvshmem_alloc_empty_unique_id() - uid_gpu = uid.cuda() - get_world_group().broadcast(uid_gpu, src=0) - print(f"PPLX_INIT UID={uid_gpu}") - uid = uid_gpu.to(device='cpu') - nvshmem_init(uid, rank, world_size) + if world_size > 1: + try: + global PPLX_DID_INIT + print(f"PPLX_INIT {rank} {world_size}") + uid = nvshmem_get_unique_id( + ) if rank == 0 else nvshmem_alloc_empty_unique_id() + uid_gpu = uid.cuda() + get_world_group().broadcast(uid_gpu, src=0) + print(f"PPLX_INIT UID={uid_gpu}") + uid = uid_gpu.to(device='cpu') + nvshmem_init(uid, rank, world_size) + PPLX_DID_INIT = True + except Exception as ex: + logger.error("Failed to initialize nvshmem for pplx: %s", ex) + + +@run_once +def pplx_finalize(): + global PPLX_DID_INIT + if PPLX_DID_INIT: + nvshmem_finalize() def initialize_model_parallel( @@ -1099,7 +1114,7 @@ def destroy_model_parallel(): """Set the groups to none and destroy them.""" global _TP - nvshmem_finalize() + pplx_finalize() if _TP: _TP.destroy() From 7490b67e2ed2b0d644f1949eb339abe38392629a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 15 Apr 2025 02:13:33 +0000 Subject: [PATCH 135/190] simple test Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 58 +++++++++++++++++++++++++---------- 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 6a58397947b..9bd928fc737 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -118,9 +118,12 @@ def batch_by_experts( assert topk_ids.dim() == 2 assert topk_ids.shape[0] == a.shape[0] + num_tokens = a.shape[0] + topk = topk_ids.shape[1] + tokens_per_expert = torch.zeros(num_experts, dtype=torch.int, device=a.device) - for i in range(topk_ids.shape[0]): - for j in range(topk_ids.shape[1]): + for i in range(num_tokens): + for j in range(topk): expert_id = topk_ids[i, j] tokens_per_expert[expert_id] = tokens_per_expert[expert_id] + 1 @@ -130,34 +133,41 @@ def batch_by_experts( dtype=a.dtype, device=a.device) #print(f"b_a shape {b_a.shape}") - #experts_per_token = torch.zeros(a.shape[0], dtype=torch.int, device=a.device) + experts_per_token = torch.zeros(num_experts, dtype=torch.int, device=a.device) - for i in range(topk_ids.shape[0]): - for j in range(topk_ids.shape[1]): - expert_id = topk_ids[i, j] - #idx = experts_per_token[i] - b_a[expert_id, j:j+1, :] = a[i, :] - #experts_per_token[i] = experts_per_token[i] + 1 + for token in range(num_tokens): + for j in range(topk): + expert_id = topk_ids[token, j] + idx = experts_per_token[expert_id] + b_a[expert_id, idx:idx+1, :] = a[token, :] + experts_per_token[expert_id] = experts_per_token[expert_id] + 1 + + if False: + print(f"topk_ids = {topk_ids}") + print(f"tokens_per_expert = {tokens_per_expert}") + print(f"experts_per_token = {experts_per_token}") return b_a, tokens_per_expert -def unbatch_output(b_out, topk_ids, K): +def unbatch_output(b_out, topk_weight, topk_ids, K): num_tokens, topk = topk_ids.shape #print(f"b_out = {b_out.shape} M={num_tokens}, K={K}, topk={topk}") num_experts = b_out.shape[0] - out = torch.zeros((num_tokens, topk, K), dtype=b_out.dtype, device=b_out.device) + topk = topk_ids.shape[1] + out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) + experts = torch.arange(0, num_experts, dtype=torch.int, device=b_out.device) for token in range(num_tokens): expert_ids = topk_ids[token] #print(f"b_out[0] = {b_out[0].shape}") for i in range(expert_ids.numel()): expert_id = expert_ids[i] idx = expert_counts[expert_id] - out[token, i:i+1, :] = b_out[expert_id, idx:idx+1, :] - idx = idx + 1 - expert_counts[expert_id] = idx + #print(f"out = {out[token, :].shape}, b_out = {b_out[expert_id, idx:idx+1, :].shape}, topk_w = {topk_weight[token, i]}") + out[token, :] = out[token, :] + b_out[expert_id, idx:idx+1, :] * topk_weight[token, i] + expert_counts[expert_id] = expert_counts[expert_id] + 1 return out @@ -175,9 +185,9 @@ def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids): #out[expert, :num, :] = SiluAndMul()(a[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) out[expert, :, :] = SiluAndMul()(a[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) - out = unbatch_output(out, topk_ids, w2.shape[1]) + out = unbatch_output(out, topk_weight, topk_ids, K) - return (out * topk_weight.view(num_tokens, -1, 1).to(out.dtype)).sum(dim=1) + return out #(out * topk_weight.view(num_tokens, -1, 1).to(out.dtype)).sum(dim=1) def torch_moe2(a, w1, w2, topk_weight, topk_ids): @@ -202,6 +212,12 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids): @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +#@pytest.mark.parametrize("m", [33]) +#@pytest.mark.parametrize("n", [128]) +#@pytest.mark.parametrize("k", [128]) +#@pytest.mark.parametrize("e", [8]) +#@pytest.mark.parametrize("topk", [2]) +#@pytest.mark.parametrize("dtype", [torch.float16]) def test_fused_moe_batched_experts( m: int, n: int, @@ -210,12 +226,13 @@ def test_fused_moe_batched_experts( topk: int, dtype: torch.dtype, ): + current_platform.seed_everything(7) + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 score = torch.randn((m, e), device="cuda", dtype=dtype) - e_map = None vllm_config = VllmConfig() with set_current_vllm_config(vllm_config): @@ -240,6 +257,13 @@ def test_fused_moe_batched_experts( topk_ids, global_num_experts=e) + if False: + torch.set_printoptions(profile="full") + print("BASELINE") + print(torch_output) + print("OUTPUT") + print(triton_output) + #torch.testing.assert_close(triton_b_output, torch_b_output, atol=2e-2, rtol=0) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) From 11edcc119b3a5244ab8c9d764e01d3da76e0fca4 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 15 Apr 2025 02:16:23 +0000 Subject: [PATCH 136/190] cleanup Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 32 +++++++------------------------- 1 file changed, 7 insertions(+), 25 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 9bd928fc737..ac1abdbe88c 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -114,7 +114,6 @@ def batch_by_experts( topk_ids: torch.Tensor, num_experts: int ) -> torch.Tensor: - #print(topk_ids.shape, topk_ids) assert topk_ids.dim() == 2 assert topk_ids.shape[0] == a.shape[0] @@ -127,25 +126,19 @@ def batch_by_experts( expert_id = topk_ids[i, j] tokens_per_expert[expert_id] = tokens_per_expert[expert_id] + 1 - #print(f"token_per_expert {tokens_per_expert.max()}") max_num_tokens = tokens_per_expert.max() b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), dtype=a.dtype, device=a.device) #print(f"b_a shape {b_a.shape}") - experts_per_token = torch.zeros(num_experts, dtype=torch.int, device=a.device) + token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) for token in range(num_tokens): for j in range(topk): expert_id = topk_ids[token, j] - idx = experts_per_token[expert_id] + idx = token_counts[expert_id] b_a[expert_id, idx:idx+1, :] = a[token, :] - experts_per_token[expert_id] = experts_per_token[expert_id] + 1 - - if False: - print(f"topk_ids = {topk_ids}") - print(f"tokens_per_expert = {tokens_per_expert}") - print(f"experts_per_token = {experts_per_token}") + token_counts[expert_id] = token_counts[expert_id] + 1 return b_a, tokens_per_expert @@ -153,7 +146,6 @@ def batch_by_experts( def unbatch_output(b_out, topk_weight, topk_ids, K): num_tokens, topk = topk_ids.shape - #print(f"b_out = {b_out.shape} M={num_tokens}, K={K}, topk={topk}") num_experts = b_out.shape[0] topk = topk_ids.shape[1] out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) @@ -161,11 +153,9 @@ def unbatch_output(b_out, topk_weight, topk_ids, K): experts = torch.arange(0, num_experts, dtype=torch.int, device=b_out.device) for token in range(num_tokens): expert_ids = topk_ids[token] - #print(f"b_out[0] = {b_out[0].shape}") for i in range(expert_ids.numel()): expert_id = expert_ids[i] idx = expert_counts[expert_id] - #print(f"out = {out[token, :].shape}, b_out = {b_out[expert_id, idx:idx+1, :].shape}, topk_w = {topk_weight[token, i]}") out[token, :] = out[token, :] + b_out[expert_id, idx:idx+1, :] * topk_weight[token, i] expert_counts[expert_id] = expert_counts[expert_id] + 1 @@ -174,7 +164,6 @@ def unbatch_output(b_out, topk_weight, topk_ids, K): def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids): assert a.dim() == 3 - #print(f"A = {a.shape} {a[0, :, :].shape}") num_tokens, topk = topk_ids.shape _, max_num_tokens, K = a.shape num_experts = w1.shape[0] @@ -182,12 +171,12 @@ def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids): for expert in range(num_experts): num = tokens_per_expert[expert] if num > 0: - #out[expert, :num, :] = SiluAndMul()(a[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) - out[expert, :, :] = SiluAndMul()(a[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) + out[expert, :num, :] = SiluAndMul()(a[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) + #out[expert, :, :] = SiluAndMul()(a[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) out = unbatch_output(out, topk_weight, topk_ids, K) - return out #(out * topk_weight.view(num_tokens, -1, 1).to(out.dtype)).sum(dim=1) + return out def torch_moe2(a, w1, w2, topk_weight, topk_ids): @@ -212,12 +201,6 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids): @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -#@pytest.mark.parametrize("m", [33]) -#@pytest.mark.parametrize("n", [128]) -#@pytest.mark.parametrize("k", [128]) -#@pytest.mark.parametrize("e", [8]) -#@pytest.mark.parametrize("topk", [2]) -#@pytest.mark.parametrize("dtype", [torch.float16]) def test_fused_moe_batched_experts( m: int, n: int, @@ -250,7 +233,7 @@ def test_fused_moe_batched_experts( topk_weight, topk_ids) else: - triton_output = fused_experts(a, # b_a + triton_output = fused_experts(b_a, w1, w2, topk_weight, @@ -264,7 +247,6 @@ def test_fused_moe_batched_experts( print("OUTPUT") print(triton_output) - #torch.testing.assert_close(triton_b_output, torch_b_output, atol=2e-2, rtol=0) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) From 74f0a54333b998ce90a06c855135ed51f906e7ec Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 15 Apr 2025 15:01:31 +0000 Subject: [PATCH 137/190] test pplx w/naive implementation Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 21 +++--- .../layers/fused_moe/fused_moe.py | 66 +++++++++++++++++++ vllm/model_executor/layers/fused_moe/layer.py | 21 +++--- .../layers/fused_moe/triton_deep_gemm_moe.py | 17 +++-- 4 files changed, 99 insertions(+), 26 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index ac1abdbe88c..a01fa46ae9a 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -120,11 +120,7 @@ def batch_by_experts( num_tokens = a.shape[0] topk = topk_ids.shape[1] - tokens_per_expert = torch.zeros(num_experts, dtype=torch.int, device=a.device) - for i in range(num_tokens): - for j in range(topk): - expert_id = topk_ids[i, j] - tokens_per_expert[expert_id] = tokens_per_expert[expert_id] + 1 + tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) max_num_tokens = tokens_per_expert.max() b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), @@ -172,7 +168,6 @@ def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids): num = tokens_per_expert[expert] if num > 0: out[expert, :num, :] = SiluAndMul()(a[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) - #out[expert, :, :] = SiluAndMul()(a[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) out = unbatch_output(out, topk_weight, topk_ids, K) @@ -233,12 +228,14 @@ def test_fused_moe_batched_experts( topk_weight, topk_ids) else: - triton_output = fused_experts(b_a, - w1, - w2, - topk_weight, - topk_ids, - global_num_experts=e) + triton_output = fused_batched_experts( + b_a, + w1, + w2, + topk_weight, + topk_ids, + global_num_experts=e + ) if False: torch.set_printoptions(profile="full") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 8b138f8a579..3bd8430ac73 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1754,6 +1754,72 @@ def apply( return intermediate_cache3 +class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + block_shape: Optional[List[int]] = None, + block_m: Optional[int] = None, + ): + super().__init__() + assert not use_fp8_w8a8 + assert not use_int4_w4a16 + assert not use_int8_w8a16 + assert block_shape is None + assert block_m is None + + def workspace_shapes( + self, + a_dtype: torch.dtype, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + a: torch.Tensor, + ) -> Tuple[int, int, torch.dtype]: + max_num_tokens = a.shape[1] + workspace13 = num_experts * max_num_tokens * K + workspace2 = M * topk * N * num_experts + return (workspace13, workspace2, a_dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + ) -> torch.Tensor: + from vllm.model_executor.layers.activation import SiluAndMul + assert hidden_states.dim() == 3 + num_tokens, topk = topk_ids.shape + _, max_num_tokens, K = hidden_states.shape + num_experts = w1.shape[0] + out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) + #tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + for expert in range(num_experts): + num = 1 #tokens_per_expert[expert] + if num > 0: + #out[expert, :num, :] = SiluAndMul(hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) + out[expert, :, :] = SiluAndMul()(hidden_states[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) + + return out + + def modular_triton_fused_moe( use_fp8_w8a8: bool, use_int8_w8a8: bool, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index eec7bb814b0..c1f2655ba78 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -29,7 +29,7 @@ if current_platform.is_cuda_alike(): from .dispatch_combine import StandardDispatchCombine - from .fused_moe import TritonExperts, fused_experts + from .fused_moe import TritonExperts, BatchedExperts, fused_experts from .modular_kernel import FusedMoEModularKernel, FusedMoEQuantizeDispatchCombine from .pplx_dispatch_combine import PplxDispatchCombine else: @@ -243,13 +243,16 @@ def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) #print(f"block_m = {block_m}") - experts = TritonExperts( - use_fp8_w8a8 = False, - use_int8_w8a16 = False, - use_int4_w4a16 = False, - block_shape = None, - block_m = None, #block_m, - ) + if False: + experts = TritonExperts( + use_fp8_w8a8 = False, + use_int8_w8a16 = False, + use_int4_w4a16 = False, + block_shape = None, + block_m = None, #block_m, + ) + else: + experts = BatchedExperts() self.fused_experts = FusedMoEModularKernel( dispatch_combine, @@ -1030,7 +1033,7 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] - print(f"loop {chunk_start}:{chunk_end}") + #print(f"loop {chunk_start}:{chunk_end}") cu_tokens_across_dp_this_iter = torch.cumsum( num_tokens_remaining_across_dp.clamp( diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index f3a13e44296..21cba37478e 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -35,16 +35,23 @@ def __init__( self.allow_deep_gemm = allow_deep_gemm self.use_fp8_w8a8 = use_fp8_w8a8 - def workspace_shapes(self, a_dtype: torch.dtype, M: int, N: int, K: int, - topk: int, - num_experts: int) -> Tuple[int, int, torch.dtype]: + def workspace_shapes( + self, + a_dtype: torch.dtype, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + a: torch.Tensor, + ) -> Tuple[int, int, torch.dtype]: # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): - return self.deep_gemm_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts) + return self.deep_gemm_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts, a) else: - return self.triton_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts) + return self.triton_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts, a) def apply( self, From 406924d266d2486e10eb5fc99a4dd2694201545f Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 15 Apr 2025 15:02:05 +0000 Subject: [PATCH 138/190] test pplx w/naive implementation Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index a01fa46ae9a..75d8c9f857b 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -14,8 +14,9 @@ from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe, torch_moe_single) from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe import fused_moe, fused_experts -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size) from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( From 80164b9ae1a48d5802e7bdb0af352583a5b9d6ff Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 15 Apr 2025 17:13:17 +0000 Subject: [PATCH 139/190] hack fix for chunking loop Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 33 ++++++++++--------- .../layers/fused_moe/fused_moe.py | 10 +++--- vllm/model_executor/layers/fused_moe/layer.py | 22 +++++++++++-- .../layers/fused_moe/modular_kernel.py | 6 ++-- .../layers/fused_moe/pplx_dispatch_combine.py | 4 +-- 5 files changed, 48 insertions(+), 27 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 75d8c9f857b..5018bc33bfb 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -110,7 +110,7 @@ def test_fused_moe( rtol=0) -def batch_by_experts( +def torch_dispatch( a: torch.Tensor, topk_ids: torch.Tensor, num_experts: int @@ -140,14 +140,14 @@ def batch_by_experts( return b_a, tokens_per_expert -def unbatch_output(b_out, topk_weight, topk_ids, K): +def torch_combine(b_out, topk_weight, topk_ids): num_tokens, topk = topk_ids.shape num_experts = b_out.shape[0] topk = topk_ids.shape[1] + K = b_out.shape[-1] out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) - experts = torch.arange(0, num_experts, dtype=torch.int, device=b_out.device) for token in range(num_tokens): expert_ids = topk_ids[token] for i in range(expert_ids.numel()): @@ -159,22 +159,25 @@ def unbatch_output(b_out, topk_weight, topk_ids, K): return out -def torch_batched_moe(a, w1, w2, tokens_per_expert, topk_weight, topk_ids): - assert a.dim() == 3 - num_tokens, topk = topk_ids.shape - _, max_num_tokens, K = a.shape +def torch_batched_moe(a, w1, w2, topk_weight, topk_ids): num_experts = w1.shape[0] - out = torch.zeros((num_experts, max_num_tokens, w2.shape[1]), dtype=a.dtype, device=a.device) + b_a, tokens_per_expert = torch_dispatch(a, topk_ids, num_experts) + assert b_a.dim() == 3 + num_tokens, topk = topk_ids.shape + _, max_num_tokens, K = b_a.shape + assert num_experts == b_a.shape[0] and K == w2.shape[1] + out = torch.zeros((num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device) + tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device) for expert in range(num_experts): num = tokens_per_expert[expert] if num > 0: - out[expert, :num, :] = SiluAndMul()(a[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) + torch.ops._C.silu_and_mul(tmp[:num], b_a[expert,:num,:] @ w1[expert].transpose(0, 1)) + out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1) - out = unbatch_output(out, topk_weight, topk_ids, K) - - return out + return torch_combine(out, topk_weight, topk_ids) +# TODO: same as torch_moe but with fused_topk factored out. def torch_moe2(a, w1, w2, topk_weight, topk_ids): M, K = a.shape topk = topk_ids.shape[1] @@ -219,16 +222,14 @@ def test_fused_moe_batched_experts( torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) - if True: - triton_output = torch_batched_moe(b_a, + triton_output = torch_batched_moe(a, w1, w2, - tokens_per_expert, topk_weight, topk_ids) else: + b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) triton_output = fused_batched_experts( b_a, w1, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 3bd8430ac73..e0422fe7bae 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1783,7 +1783,7 @@ def workspace_shapes( ) -> Tuple[int, int, torch.dtype]: max_num_tokens = a.shape[1] workspace13 = num_experts * max_num_tokens * K - workspace2 = M * topk * N * num_experts + workspace2 = max_num_tokens * (N // 2) return (workspace13, workspace2, a_dtype) def apply( @@ -1810,12 +1810,14 @@ def apply( _, max_num_tokens, K = hidden_states.shape num_experts = w1.shape[0] out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) + # causes deadlock #tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) for expert in range(num_experts): - num = 1 #tokens_per_expert[expert] + num = max_num_tokens #tokens_per_expert[expert] if num > 0: - #out[expert, :num, :] = SiluAndMul(hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) - out[expert, :, :] = SiluAndMul()(hidden_states[expert,:,:] @ w1[expert].transpose(0, 1)) @ w2[expert].transpose(0, 1) + tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) + torch.ops._C.silu_and_mul(tmp, hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1)) + out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) return out diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c1f2655ba78..410c1e6176e 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1029,11 +1029,15 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_hidden_states.shape[0]) full_final_hidden_states = torch.empty_like(full_hidden_states) - for _ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): + #print(f"ORIGINAL SHAPE {full_hidden_states.shape}") + + #print(f"moe_dp_chunk_size_per_rank = {moe_dp_chunk_size_per_rank}") + + for iter in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] - #print(f"loop {chunk_start}:{chunk_end}") + #print(f"loop {iter}: {chunk_start}:{chunk_end}, {hidden_states.shape}") cu_tokens_across_dp_this_iter = torch.cumsum( num_tokens_remaining_across_dp.clamp( @@ -1063,6 +1067,8 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, activation=self.activation, ) + #print(f"final1 = {final_hidden_states.shape}") + if self.dp_size > 1: start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[ self.dp_rank - 1] @@ -1072,19 +1078,31 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, final_hidden_states) final_hidden_states = all_hidden_states[start:end, :] + #print(f"final2 (AR) = {final_hidden_states.shape}") + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) + #print(f"final3 (AR) = {final_hidden_states.shape}") + full_final_hidden_states[chunk_start:chunk_end, :].copy_( final_hidden_states) + #print(f"full final = {full_final_hidden_states.shape}") + # Update bounds num_tokens_remaining_across_dp = torch.clamp( num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank, min=0) + #print(f"num remaining = {num_tokens_remaining_across_dp}") + + # HACK FIX + if num_tokens_remaining_across_dp.sum() == 0: + break + def update_chunk_bound(x: int): return min(x + moe_dp_chunk_size_per_rank, full_hidden_states.shape[0]) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 76ece80ba47..35f8b829277 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -312,8 +312,8 @@ def forward( Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ - from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank) - print(f"START {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}") + #from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank) + #print(f"START {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}") a1 = hidden_states E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids) @@ -364,6 +364,6 @@ def forward( self.dispatch_combine.combine(output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input) - print(f"DONE {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}") + #print(f"DONE {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}") return output diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 223b5d3d2aa..9377d6d6331 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -133,7 +133,7 @@ def combine( #device = get_dp_group().device #assert fused_expert_output.device == device - print(f"COMBINE START {self.rank}") + #print(f"COMBINE START {self.rank}") # This argument is optional #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens @@ -154,4 +154,4 @@ def combine( expert_y=fused_expert_output, bound_m=bound_m) - print(f"COMBINE END {self.rank}") + #print(f"COMBINE END {self.rank}") From 84ea0bddffdf9af6354fc7a1513efed9fc01d223 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 16 Apr 2025 20:34:49 +0000 Subject: [PATCH 140/190] wip. add pplx unit test Signed-off-by: Bill Nell --- examples/offline_inference/data_parallel.py | 2 +- tests/kernels/moe/test_moe.py | 2 - tests/kernels/test_pplx_moe.py | 432 ++++++++++++++++++ .../layers/fused_moe/fused_moe.py | 93 +++- vllm/model_executor/layers/fused_moe/layer.py | 36 +- .../layers/fused_moe/modular_kernel.py | 2 +- .../layers/fused_moe/pplx_dispatch_combine.py | 2 +- .../layers/fused_moe/triton_deep_gemm_moe.py | 3 + 8 files changed, 550 insertions(+), 22 deletions(-) create mode 100644 tests/kernels/test_pplx_moe.py diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 965915beaf5..1c070105189 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -160,7 +160,7 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, procs.append(proc) exit_code = 0 for proc in procs: - proc.join(timeout=300) + proc.join(timeout=3000) if proc.exitcode is None: print(f"Killing process {proc.pid} that " f"didn't stop within 5 minutes.") diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 5018bc33bfb..d24039749db 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -142,9 +142,7 @@ def torch_dispatch( def torch_combine(b_out, topk_weight, topk_ids): num_tokens, topk = topk_ids.shape - num_experts = b_out.shape[0] - topk = topk_ids.shape[1] K = b_out.shape[-1] out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py new file mode 100644 index 00000000000..b3b8817c69c --- /dev/null +++ b/tests/kernels/test_pplx_moe.py @@ -0,0 +1,432 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the MOE layers. + +Run `pytest tests/kernels/test_pplx_moe.py`. +""" +import dataclasses +import os +import pytest +import torch +from torch.nn import Parameter +from torch.nn import functional as F +from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] +from typing import Callable, Concatenate, ParamSpec + +from pplx_kernels import AllToAll +from pplx_kernels.nvshmem import ( + nvshmem_alloc_empty_unique_id, + nvshmem_finalize, + nvshmem_get_unique_id, + nvshmem_init, +) + +import vllm.model_executor.layers.fused_moe # noqa +from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, + torch_moe, torch_moe_single) +from vllm import _custom_ops as ops +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe import fused_moe +#from vllm.model_executor.layers.fused_moe.fused_batched_moe import fused_batched_experts +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size) +from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( + fused_moe as iterative_moe) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + marlin_quantize) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + quantize_weights) +from vllm.model_executor.models.mixtral import MixtralMoE +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types +from vllm.model_executor.layers.activation import SiluAndMul + +from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts, BatchedDispatchCombine, BatchedExperts, fused_experts +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel, FusedMoEQuantizeDispatchCombine +from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import PplxDispatchCombine + +NUM_EXPERTS = [8, 64] +EP_SIZE = [1, 4] +TOP_KS = [2, 6] + +P = ParamSpec("P") + +require_multi_node = pytest.mark.skipif( + "MASTER_ADDR" not in os.environ, + reason="Requires multi-node environment", +) + + +@dataclasses.dataclass +class ProcessGroupInfo: + world_size: int + world_local_size: int + rank: int + node_rank: int + local_rank: int + device: torch.device + + +def _worker_parallel_launch( + local_rank: int, + world_size: int, + world_local_size: int, + node_rank: int, + init_method: str, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + rank = node_rank * world_local_size + local_rank + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl", + init_method=init_method, + rank=rank, + world_size=world_size, + device_id=device, + ) + barrier = torch.tensor([rank], device=device) + torch.distributed.all_reduce(barrier) + + try: + worker( + ProcessGroupInfo( + world_size=world_size, + world_local_size=world_local_size, + rank=rank, + node_rank=node_rank, + local_rank=local_rank, + device=device, + ), + *args, + **kwargs, + ) + except Exception: + raise + finally: + torch.distributed.destroy_process_group() + + +def parallel_launch( + world_size: int, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + assert not kwargs + spawn( + _worker_parallel_launch, + args=( + world_size, + world_size, + 0, + "tcp://localhost:29500", + worker, + ) + + args, + nprocs=world_size, + join=True, + ) + + +def parallel_launch_from_env( + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + """ + Launches a worker function in parallel across all processes in the current + environment. The environment must have the following variables set: + - WORLD_SIZE: The total number of processes. + - WORLD_LOCAL_SIZE: The number of processes on the current node. + - NODE_RANK: The rank of the current + - MASTER_ADDR: The address of the master process. + - MASTER_PORT: The port of the master process. + """ + assert not kwargs + world_size = int(os.environ["WORLD_SIZE"]) + world_local_size = int(os.environ["WORLD_LOCAL_SIZE"]) + node_rank = int(os.environ["NODE_RANK"]) + assert "MASTER_ADDR" in os.environ + assert "MASTER_PORT" in os.environ + spawn( + _worker_parallel_launch, + args=( + world_size, + world_local_size, + node_rank, + "env://", + worker, + ) + + args, + nprocs=world_local_size, + join=True, + ) + + +def torch_dispatch( + a: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int +) -> torch.Tensor: + assert topk_ids.dim() == 2 + assert topk_ids.shape[0] == a.shape[0] + + num_tokens = a.shape[0] + topk = topk_ids.shape[1] + + tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + + max_num_tokens = tokens_per_expert.max() + b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), + dtype=a.dtype, device=a.device) + #print(f"b_a shape {b_a.shape}") + + token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) + + for token in range(num_tokens): + for j in range(topk): + expert_id = topk_ids[token, j] + idx = token_counts[expert_id] + b_a[expert_id, idx:idx+1, :] = a[token, :] + token_counts[expert_id] = token_counts[expert_id] + 1 + + return b_a, tokens_per_expert + + +def torch_combine(b_out, topk_weight, topk_ids): + num_tokens, topk = topk_ids.shape + num_experts = b_out.shape[0] + K = b_out.shape[-1] + out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) + expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) + for token in range(num_tokens): + expert_ids = topk_ids[token] + for i in range(expert_ids.numel()): + expert_id = expert_ids[i] + idx = expert_counts[expert_id] + out[token, :] = out[token, :] + b_out[expert_id, idx:idx+1, :] * topk_weight[token, i] + expert_counts[expert_id] = expert_counts[expert_id] + 1 + + return out + + +def torch_batched_moe(a, w1, w2, topk_weight, topk_ids): + num_experts = w1.shape[0] + b_a, tokens_per_expert = torch_dispatch(a, topk_ids, num_experts) + assert b_a.dim() == 3 + num_tokens, topk = topk_ids.shape + _, max_num_tokens, K = b_a.shape + assert num_experts == b_a.shape[0] and K == w2.shape[1] + out = torch.zeros((num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device) + tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device) + for expert in range(num_experts): + num = tokens_per_expert[expert] + if num > 0: + torch.ops._C.silu_and_mul(tmp[:num], b_a[expert,:num,:] @ w1[expert].transpose(0, 1)) + out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1) + + return torch_combine(out, topk_weight, topk_ids) + + +# TODO: same as torch_moe but with fused_topk factored out. +def torch_moe2(a, w1, w2, topk_weight, topk_ids): + M, K = a.shape + topk = topk_ids.shape[1] + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) + num_experts = w1.shape[0] + for i in range(num_experts): + mask = (topk_ids == i).view(-1) + if mask.sum(): + out[mask] = SiluAndMul()( + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + + return (out.view(M, -1, w2.shape[1]) * + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + + +@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_moe_batched_experts( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + score = torch.randn((m, e), device="cuda", dtype=dtype) + + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids = fused_topk(a, score, topk, False) + + torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + + if True: + triton_output = torch_batched_moe(a, + w1, + w2, + topk_weight, + topk_ids) + else: + b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) + triton_output = fused_batched_experts( + b_a, + w1, + w2, + topk_weight, + topk_ids, + global_num_experts=e + ) + + if False: + torch.set_printoptions(profile="full") + print("BASELINE") + print(torch_output) + print("OUTPUT") + print(triton_output) + + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + + + +def torch_pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): + hidden_dim = a.shape[-1] + num_experts = w1.shape[0] + num_local_experts = num_experts // pgi.world_size + block_size = 128 + topk = topk_ids.shape[1] + + tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + max_num_tokens = tokens_per_expert.max() + rank = pgi.rank + + ata = AllToAll( + max_num_tokens=max_num_tokens, + num_experts=num_experts, + experts_per_token=topk, + rank=rank, + world_size=pgi.world_size, + dp_size=dp_size, + hidden_dim=hidden_dim, + hidden_dim_bytes=hidden_dim * a.dtype.itemsize, + hidden_dim_scale_bytes=( + 0 + if a.dtype.itemsize != 1 + else ( + (hidden_dim + block_size - 1) + // block_size + * torch.float32.itemsize + ) + ), + ) + + dispatch_combine = PplxDispatchCombine( + ata, + max_num_tokens, + pgi.world_size, + dp_size, + rank, + a.dtype, + ) + + experts = BatchedExperts() + + fused_experts = FusedMoEModularKernel( + dispatch_combine, + experts, + ) + + out = fused_experts( + a, + w1, + w2, + topk_weight, + topk_ids + ) + + ata.destroy() + + return out + + + +def _pplx_moe( + pgi: ProcessGroupInfo, + dp_size: int, + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + score = torch.randn((m, e), device="cuda", dtype=dtype) + + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids = fused_topk(a, score, topk, False) + + torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + + triton_output = torch_pplx_moe(pgi, + a, + w1, + w2, + topk_weight, + topk_ids) + + if False: + torch.set_printoptions(profile="full") + print("BASELINE") + print(torch_output) + print("OUTPUT") + print(triton_output) + + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + + nvshmem_finalize() + + +@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_pplx_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + world_size = 4 + dp_size = 2 + parallel_launch( + world_size, _pplx_moe, dp_size, m, n, k, e, topk, dtype + ) + diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index e0422fe7bae..7ee1ee93304 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1754,6 +1754,82 @@ def apply( return intermediate_cache3 +class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): + def __init__(self, + world_size: int, + rank: int): + super().__init__() + self.world_size = world_size + self.rank = rank + + def dispatch( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + assert topk_ids.dim() == 2 + assert topk_ids.shape[0] == a1.shape[0] + + num_tokens = a1.shape[0] + topk = topk_ids.shape[1] + + #assert num_experts % self.world_size == 0 + #num_local_experts = num_experts // self.world_size + + tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + max_num_tokens = tokens_per_expert.max() + expert_counts = torch.zeros(num_experts, dtype=torch.int, device=a1.device) + + b_a1 = torch.zeros((num_experts, max_num_tokens, a1.shape[1]), + dtype=a1.dtype, device=a1.device) + + #print(f"START DISPATCH {hex(id(self))}") + + for token in range(num_tokens): + for j in range(topk): + expert_id = topk_ids[token, j] + idx = expert_counts[expert_id] + b_a1[expert_id, idx:idx+1, :] = a1[token, :] + expert_counts[expert_id] = expert_counts[expert_id] + 1 + + #print(f"END DISPATCH {hex(id(self))}: tokens_per_expert {(tokens_per_expert > 0).nonzero().view(-1)}") + + return b_a1, a1_scale, tokens_per_expert + + def combine( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ) -> None: + if False: + print(f"topk_ids {topk_ids.shape}") + print(f"fused_expert_output {fused_expert_output.shape}") + print(f"output {output.shape}") + print(f"counts {self.expert_counts.shape}") + + #print(f"START COMBINE {hex(id(self))}") + + num_tokens, topk = topk_ids.shape + num_experts, _, K = fused_expert_output.shape + expert_counts = torch.zeros(num_experts, dtype=torch.int, device=fused_expert_output.device) + for token in range(num_tokens): + expert_ids = topk_ids[token] + for i in range(topk_ids.shape[1]): + expert_id = expert_ids[i] + if expert_id < num_experts: + idx = expert_counts[expert_id] + output[token, :] = output[token, :] + fused_expert_output[expert_id, idx:idx+1, :] * topk_weights[token, i] + expert_counts[expert_id] = expert_counts[expert_id] + 1 + + #print(f"END COMBINE {hex(id(self))}") + + class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( @@ -1803,21 +1879,28 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], ) -> torch.Tensor: - from vllm.model_executor.layers.activation import SiluAndMul + #print("START EXPERTS") assert hidden_states.dim() == 3 + assert expert_num_tokens is not None num_tokens, topk = topk_ids.shape _, max_num_tokens, K = hidden_states.shape num_experts = w1.shape[0] out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) - # causes deadlock - #tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) for expert in range(num_experts): - num = max_num_tokens #tokens_per_expert[expert] + num = expert_num_tokens[expert] if num > 0: tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) - torch.ops._C.silu_and_mul(tmp, hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1)) + self.activation(activation, tmp, hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1)) out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) + # fill remainder with 0??? + #out[expert, num:, :].fill_(0) + else: + #out[expert, :, :].fill_(0) # ?? + pass + + #print("END EXPERTS") return out diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 410c1e6176e..0f6eccab5a4 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -29,7 +29,7 @@ if current_platform.is_cuda_alike(): from .dispatch_combine import StandardDispatchCombine - from .fused_moe import TritonExperts, BatchedExperts, fused_experts + from .fused_moe import TritonExperts, BatchedDispatchCombine, BatchedExperts, fused_experts from .modular_kernel import FusedMoEModularKernel, FusedMoEQuantizeDispatchCombine from .pplx_dispatch_combine import PplxDispatchCombine else: @@ -116,7 +116,7 @@ def get_or_create(self, **kwargs): with self._lock: instance = self._cache.get(key) - if True or instance is None: + if instance is None: instance = pplx.AllToAll(**kwargs) self._cache[key] = instance return instance @@ -240,10 +240,15 @@ def apply( # Maybe extra args def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: + assert self.fused_experts == fused_experts + block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) #print(f"block_m = {block_m}") - if False: + if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)): + logger.info("BatchedExperts") + experts = BatchedExperts() + else: experts = TritonExperts( use_fp8_w8a8 = False, use_int8_w8a16 = False, @@ -251,8 +256,6 @@ def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine block_shape = None, block_m = None, #block_m, ) - else: - experts = BatchedExperts() self.fused_experts = FusedMoEModularKernel( dispatch_combine, @@ -609,6 +612,7 @@ def __init__( # TODO: move to method? if self.dp_size > 1: + logger.info("using pplx dispatch") max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size world_size = moe.ep_size dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. @@ -652,15 +656,22 @@ def __init__( rank, # just for debugging moe.in_dtype, ) - - success = self.quant_method.set_dispatch_combine(dispatch_combine) - if not success: - logger.warning("DP+EP not supported for %s.", type(self.quant_method)) - else: + elif False: + logger.info("using standard dispatch") dispatch_combine = StandardDispatchCombine( moe.in_dtype, quant_config.weight_block_size if quant_config is not None else None, ) + else: + logger.info("using batched dispatch") + dispatch_combine = BatchedDispatchCombine( + moe.ep_size, + moe.ep_rank, + ) + + success = self.quant_method.set_dispatch_combine(dispatch_combine) + if not success: + logger.warning("DP+EP not supported for %s.", type(self.quant_method)) self.apply_router_weight_on_input = apply_router_weight_on_input moe_quant_params = { @@ -1030,7 +1041,6 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_final_hidden_states = torch.empty_like(full_hidden_states) #print(f"ORIGINAL SHAPE {full_hidden_states.shape}") - #print(f"moe_dp_chunk_size_per_rank = {moe_dp_chunk_size_per_rank}") for iter in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): @@ -1090,7 +1100,7 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_final_hidden_states[chunk_start:chunk_end, :].copy_( final_hidden_states) - #print(f"full final = {full_final_hidden_states.shape}") + #print(f"partial final = {full_final_hidden_states.shape}") # Update bounds num_tokens_remaining_across_dp = torch.clamp( @@ -1110,6 +1120,8 @@ def update_chunk_bound(x: int): chunk_start = update_chunk_bound(chunk_start) chunk_end = update_chunk_bound(chunk_end) + #print(f"full final shape {full_final_hidden_states.shape}") + return full_final_hidden_states def forward_impl(self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 35f8b829277..96ecf5990a6 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -323,7 +323,7 @@ def forward( if global_num_experts == -1: global_num_experts = E - output = a1 if inplace else torch.empty_like(a1) + output = a1 if inplace else torch.zeros_like(a1) workspace13_shape, workspace2_shape, workspace_dtype = ( self.fused_experts.workspace_shapes(a1, M, N, K, top_k, diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 9377d6d6331..a36c825d9e7 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -75,7 +75,7 @@ def dispatch( dtype=torch.int32, device=device, ) - expert_num_tokens.fill_(-1) # debugging remove + #expert_num_tokens.fill_(-1) # debugging remove num_dp = self.world_size // self.dp_size expert_x = torch.empty( diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 21cba37478e..be28d620f47 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -70,6 +70,7 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], ) -> torch.Tensor: N = w1.shape[1] if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512 @@ -90,6 +91,7 @@ def apply( a2_scale, workspace13, workspace2, + expert_num_tokens, ) else: return self.triton_expert( @@ -108,4 +110,5 @@ def apply( a2_scale, workspace13, workspace2, + expert_num_tokens, ) From 8a9895c3d36091176829a17d8ff6c8bfabfe4522 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 17 Apr 2025 00:10:05 +0000 Subject: [PATCH 141/190] work on unit test Signed-off-by: Bill Nell --- tests/kernels/test_pplx_moe.py | 323 +++++++++++++++--- .../layers/fused_moe/fused_moe.py | 3 +- .../layers/fused_moe/pplx_dispatch_combine.py | 17 +- 3 files changed, 286 insertions(+), 57 deletions(-) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py index b3b8817c69c..0156253d680 100644 --- a/tests/kernels/test_pplx_moe.py +++ b/tests/kernels/test_pplx_moe.py @@ -7,6 +7,8 @@ import os import pytest import torch +import traceback + from torch.nn import Parameter from torch.nn import functional as F from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] @@ -38,6 +40,8 @@ from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform from vllm.scalar_type import scalar_types +from vllm.utils import round_up + from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts, BatchedDispatchCombine, BatchedExperts, fused_experts @@ -102,7 +106,9 @@ def _worker_parallel_launch( *args, **kwargs, ) - except Exception: + except Exception as ex: + print(ex) + traceback.print_exception(ex) raise finally: torch.distributed.destroy_process_group() @@ -247,13 +253,150 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids): topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) -@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 511, 1024]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_fused_moe_batched_experts( +# @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +# @pytest.mark.parametrize("n", [128, 1024, 2048]) +# @pytest.mark.parametrize("k", [128, 511, 1024]) +# @pytest.mark.parametrize("e", NUM_EXPERTS) +# @pytest.mark.parametrize("topk", TOP_KS) +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# def test_fused_moe_batched_experts( +# m: int, +# n: int, +# k: int, +# e: int, +# topk: int, +# dtype: torch.dtype, +# ): +# current_platform.seed_everything(7) + +# a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 +# w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 +# w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + +# score = torch.randn((m, e), device="cuda", dtype=dtype) + +# vllm_config = VllmConfig() +# with set_current_vllm_config(vllm_config): +# topk_weight, topk_ids = fused_topk(a, score, topk, False) + +# torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + +# if True: +# triton_output = torch_batched_moe(a, +# w1, +# w2, +# topk_weight, +# topk_ids) +# else: +# b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) +# triton_output = fused_batched_experts( +# b_a, +# w1, +# w2, +# topk_weight, +# topk_ids, +# global_num_experts=e +# ) + +# if False: +# torch.set_printoptions(profile="full") +# print("BASELINE") +# print(torch_output) +# print("OUTPUT") +# print(triton_output) + +# torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + + +def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): + assert torch.cuda.current_device() == pgi.local_rank + + num_tokens, hidden_dim = a.shape + num_experts = w1.shape[0] + block_size = 128 + device = pgi.device + + max_num_tokens = num_tokens + print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") + rank = pgi.rank + + ata = AllToAll( + max_num_tokens=max_num_tokens, + num_experts=num_experts, + experts_per_token=topk, + rank=rank, + world_size=pgi.world_size, + dp_size=dp_size, + hidden_dim=hidden_dim, + hidden_dim_bytes=hidden_dim * a.dtype.itemsize, + hidden_dim_scale_bytes=( + 0 + if a.dtype.itemsize != 1 + else ( + (hidden_dim + block_size - 1) + // block_size + * torch.float32.itemsize + ) + ), + ) + + dispatch_combine = PplxDispatchCombine( + ata, + max_num_tokens, + pgi.world_size, + dp_size, + rank, + a.dtype, + ) + + def chunk_by_rank(t, r): + num = t.shape[0] + assert num % pgi.world_size == 0, f"{num}, {pgi.world_size}" # for now + chunk = num // pgi.world_size + print(f"chunk {t.shape}, {pgi.world_size}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}") + return t[(r * chunk):(r + 1)*chunk] + + a_chunk = chunk_by_rank(a, rank).to(device) + score_chunk = chunk_by_rank(scores, rank).to(device) + chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) + + #print(f"chunk_topk_ids = {chunk_topk_ids}") + + b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( + a_chunk, + None, + None, + chunk_topk_ids, + num_experts, # store at PplxDispatchCombine creation? + None + ) + torch.cuda.synchronize() # necessary? + + out = torch.full( + (max_num_tokens, hidden_dim), + torch.nan, + dtype=a.dtype, + device=device, + ) + + dispatch_combine.combine( + out, + b_a, + chunk_topk_weight, + chunk_topk_ids, + ) + torch.cuda.synchronize() + + ata.destroy() + + torch.distributed.barrier() + + return out[:num_tokens] + + +def _pplx_dispatch_combine( + pgi: ProcessGroupInfo, + dp_size: int, m: int, n: int, k: int, @@ -261,7 +404,9 @@ def test_fused_moe_batched_experts( topk: int, dtype: torch.dtype, ): - current_platform.seed_everything(7) + uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 @@ -269,49 +414,74 @@ def test_fused_moe_batched_experts( score = torch.randn((m, e), device="cuda", dtype=dtype) - vllm_config = VllmConfig() - with set_current_vllm_config(vllm_config): - topk_weight, topk_ids = fused_topk(a, score, topk, False) + topk_weight, topk_ids = fused_topk(a, score, topk, False) - torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + print(f"a {a.shape}") + a_rep = torch.repeat_interleave(a, topk, dim=1) + print(f"a_rep {a_rep.shape}") + + torch_output = (a_rep.view(-1, topk, k) * topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype) - if True: - triton_output = torch_batched_moe(a, + pplx_output = torch_pplx_dispatch_combine(pgi, + dp_size, + a, w1, w2, - topk_weight, - topk_ids) - else: - b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) - triton_output = fused_batched_experts( - b_a, - w1, - w2, - topk_weight, - topk_ids, - global_num_experts=e - ) + score, + topk) if False: torch.set_printoptions(profile="full") print("BASELINE") print(torch_output) print("OUTPUT") - print(triton_output) + print(pplx_output) + + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) + + nvshmem_finalize() - torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) +# @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +# @pytest.mark.parametrize("n", [128, 1024, 2048]) +# @pytest.mark.parametrize("k", [128, 511, 1024]) +# @pytest.mark.parametrize("e", NUM_EXPERTS) +# @pytest.mark.parametrize("topk", TOP_KS) +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("m", [128]) ##, 32]) #, 1024 * 128]) +@pytest.mark.parametrize("n", [128]) +@pytest.mark.parametrize("k", [128]) +@pytest.mark.parametrize("e", [8]) #NUM_EXPERTS) +@pytest.mark.parametrize("topk", [2]) #TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_pplx_dispatch_combine( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + if False: + world_size = 4 + dp_size = 2 + else: + world_size = 2 + dp_size = 1 + parallel_launch( + world_size, _pplx_dispatch_combine, dp_size, m, n, k, e, topk, dtype + ) -def torch_pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): +def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): hidden_dim = a.shape[-1] num_experts = w1.shape[0] num_local_experts = num_experts // pgi.world_size block_size = 128 - topk = topk_ids.shape[1] - tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) - max_num_tokens = tokens_per_expert.max() + max_num_tokens = round_up(a.shape[0], 128) #tokens_per_expert.max() + print(f"max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}/{num_local_experts}") rank = pgi.rank ata = AllToAll( @@ -350,20 +520,60 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): experts, ) - out = fused_experts( - a, - w1, - w2, - topk_weight, - topk_ids - ) + def chunk_by_rank(t, r): + num = t.shape[0] + assert num % pgi.world_size == 0, f"{num}, {dp_size}" # for now + chunk = num // pgi.world_size + return t[(r * chunk):(r + 1)*chunk] + + a_chunk = chunk_by_rank(a, rank) + chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, chunk_by_rank(scores, rank), topk, False) + + print(f"chunk_topk_ids = {chunk_topk_ids}") + + # TODO: chunk up by rank + if False: + out = fused_experts( + a_chunk, + w1, # chunk? + w2, # chunk? + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_local_experts + ) + # reduce outputs? + else: + b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( + a_chunk, + None, + None, + chunk_topk_ids, + num_experts, + None + ) + torch.cuda.synchronize() + + out = torch.full( + (max_num_tokens, hidden_dim), + torch.nan, + dtype=a.dtype, + device=a.device, + ) + + dispatch_combine.combine( + out, + b_a, + chunk_topk_weight, + chunk_topk_ids, + ) + + torch.cuda.synchronize() ata.destroy() return out - def _pplx_moe( pgi: ProcessGroupInfo, dp_size: int, @@ -391,11 +601,12 @@ def _pplx_moe( torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) triton_output = torch_pplx_moe(pgi, + dp_size, a, w1, w2, - topk_weight, - topk_ids) + score, + topk) if False: torch.set_printoptions(profile="full") @@ -409,12 +620,18 @@ def _pplx_moe( nvshmem_finalize() -@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 511, 1024]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +# @pytest.mark.parametrize("n", [128, 1024, 2048]) +# @pytest.mark.parametrize("k", [128, 511, 1024]) +# @pytest.mark.parametrize("e", NUM_EXPERTS) +# @pytest.mark.parametrize("topk", TOP_KS) +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("m", [128]) ##, 32]) #, 1024 * 128]) +@pytest.mark.parametrize("n", [128]) +@pytest.mark.parametrize("k", [128]) +@pytest.mark.parametrize("e", [8]) #NUM_EXPERTS) +@pytest.mark.parametrize("topk", [2]) #TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) def test_pplx_moe( m: int, n: int, @@ -424,8 +641,12 @@ def test_pplx_moe( dtype: torch.dtype, ): current_platform.seed_everything(7) - world_size = 4 - dp_size = 2 + if False: + world_size = 4 + dp_size = 2 + else: + world_size = 2 + dp_size = 1 parallel_launch( world_size, _pplx_moe, dp_size, m, n, k, e, topk, dtype ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 7ee1ee93304..63e32ae68a0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1886,7 +1886,8 @@ def apply( assert expert_num_tokens is not None num_tokens, topk = topk_ids.shape _, max_num_tokens, K = hidden_states.shape - num_experts = w1.shape[0] + print(f"global_num_experts = {global_num_experts}") + num_experts = global_num_experts out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) for expert in range(num_experts): num = expert_num_tokens[expert] diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index a36c825d9e7..dd8fe4a36fb 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -75,15 +75,18 @@ def dispatch( dtype=torch.int32, device=device, ) - #expert_num_tokens.fill_(-1) # debugging remove + #expert_num_tokens.fill_(-1) # debugging, remove later num_dp = self.world_size // self.dp_size + print(f"GOT HERE A {self.rank}: {self.max_num_tokens} {num_dp} {hidden_dim}") expert_x = torch.empty( (num_local_experts, self.max_num_tokens * num_dp, a1q.shape[-1]), dtype=a1q.dtype, device=device, ) - expert_x.fill_(torch.nan) # debugging remove + expert_x.fill_(torch.nan) # debugging, remove later + + print(f"GOT HERE B {self.rank}") expert_x_scale: Optional[torch.Tensor] = None if a1q.dtype.itemsize == 1: @@ -100,6 +103,8 @@ def dispatch( device=device, ) + print(f"GOT HERE C {self.rank}") + # This argument is optional, defaults to indices.shape[0] # This causes a deadlock???? #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens @@ -107,7 +112,9 @@ def dispatch( bound_m = None # TODO: optimize this? - indices = rank_topk_ids.to(dtype=torch.uint32).to(device) + indices = rank_topk_ids.to(dtype=torch.uint32) + + print(f"GOT HERE D {self.rank}") self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, @@ -133,7 +140,7 @@ def combine( #device = get_dp_group().device #assert fused_expert_output.device == device - #print(f"COMBINE START {self.rank}") + print(f"COMBINE START {self.rank}") # This argument is optional #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens @@ -154,4 +161,4 @@ def combine( expert_y=fused_expert_output, bound_m=bound_m) - #print(f"COMBINE END {self.rank}") + print(f"COMBINE END {self.rank}") From d37a3010e89d94e0ea814a0bed37b77f37bc80bd Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 17 Apr 2025 03:45:09 +0000 Subject: [PATCH 142/190] dispatch/combine unit test Signed-off-by: Bill Nell --- tests/kernels/test_pplx_moe.py | 104 ++++++++++++++++++--------------- 1 file changed, 58 insertions(+), 46 deletions(-) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py index 0156253d680..afb0b885866 100644 --- a/tests/kernels/test_pplx_moe.py +++ b/tests/kernels/test_pplx_moe.py @@ -308,6 +308,14 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids): # torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) +def chunk_by_rank(t, r, w): + num = t.shape[0] + assert num % w == 0, f"{num}, {w}" # for now + chunk = num // w + #print(f"chunk {t.shape}, {w}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}") + return t[(r * chunk):(r + 1)*chunk] + + def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): assert torch.cuda.current_device() == pgi.local_rank @@ -315,10 +323,12 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): num_experts = w1.shape[0] block_size = 128 device = pgi.device + rank_num_tokens = num_tokens // pgi.world_size max_num_tokens = num_tokens - print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") + #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") rank = pgi.rank + world_size = pgi.world_size ata = AllToAll( max_num_tokens=max_num_tokens, @@ -342,22 +352,15 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): dispatch_combine = PplxDispatchCombine( ata, - max_num_tokens, + max_num_tokens, # // world_size? pgi.world_size, dp_size, rank, a.dtype, ) - def chunk_by_rank(t, r): - num = t.shape[0] - assert num % pgi.world_size == 0, f"{num}, {pgi.world_size}" # for now - chunk = num // pgi.world_size - print(f"chunk {t.shape}, {pgi.world_size}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}") - return t[(r * chunk):(r + 1)*chunk] - - a_chunk = chunk_by_rank(a, rank).to(device) - score_chunk = chunk_by_rank(scores, rank).to(device) + a_chunk = chunk_by_rank(a, rank, world_size).to(device) + score_chunk = chunk_by_rank(scores, rank, world_size).to(device) chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) #print(f"chunk_topk_ids = {chunk_topk_ids}") @@ -391,16 +394,22 @@ def chunk_by_rank(t, r): torch.distributed.barrier() - return out[:num_tokens] + #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}") + + #torch.distributed.all_reduce(out) + + #print(f"AR OUT {rank}: {out.shape} {out}") + + return out[:rank_num_tokens] def _pplx_dispatch_combine( pgi: ProcessGroupInfo, dp_size: int, - m: int, - n: int, - k: int, - e: int, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, topk: int, dtype: torch.dtype, ): @@ -408,19 +417,18 @@ def _pplx_dispatch_combine( torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, pgi.rank, pgi.world_size) - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - - score = torch.randn((m, e), device="cuda", dtype=dtype) + m, k = a.shape + e, _, n = w2.shape topk_weight, topk_ids = fused_topk(a, score, topk, False) - print(f"a {a.shape}") - a_rep = torch.repeat_interleave(a, topk, dim=1) - print(f"a_rep {a_rep.shape}") + #print(f"a {a.shape}") + a_rep = torch.repeat_interleave(a, topk, dim=0) + #print(f"a_rep {a_rep.shape} {a_rep.view(-1, topk, k)}") + + torch_output = (a_rep.view(-1, topk, k) * topk_weight.view(-1, topk, 1)).to(a.dtype).sum(dim=1) - torch_output = (a_rep.view(-1, topk, k) * topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype) + #print(f"torch_output {pgi.rank}: {torch_output.shape} {torch_output}") pplx_output = torch_pplx_dispatch_combine(pgi, dp_size, @@ -437,23 +445,25 @@ def _pplx_dispatch_combine( print("OUTPUT") print(pplx_output) + torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) nvshmem_finalize() -# @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) -# @pytest.mark.parametrize("n", [128, 1024, 2048]) -# @pytest.mark.parametrize("k", [128, 511, 1024]) -# @pytest.mark.parametrize("e", NUM_EXPERTS) -# @pytest.mark.parametrize("topk", TOP_KS) -# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("m", [128]) ##, 32]) #, 1024 * 128]) -@pytest.mark.parametrize("n", [128]) -@pytest.mark.parametrize("k", [128]) -@pytest.mark.parametrize("e", [8]) #NUM_EXPERTS) -@pytest.mark.parametrize("topk", [2]) #TOP_KS) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("m", [2, 32, 64, 222]) #, 1024 * 128]) # what is restriction on this? +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 512, 1024]) # restrictions here? +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("m", [2]) ##, 32]) #, 1024 * 128]) +# @pytest.mark.parametrize("n", [128]) +# @pytest.mark.parametrize("k", [128]) +# @pytest.mark.parametrize("e", [8]) #NUM_EXPERTS) +# @pytest.mark.parametrize("topk", [2]) #TOP_KS) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) def test_pplx_dispatch_combine( m: int, n: int, @@ -469,8 +479,14 @@ def test_pplx_dispatch_combine( else: world_size = 2 dp_size = 1 + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + parallel_launch( - world_size, _pplx_dispatch_combine, dp_size, m, n, k, e, topk, dtype + world_size, _pplx_dispatch_combine, dp_size, a, w1, w2, score, topk, dtype ) @@ -483,6 +499,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): max_num_tokens = round_up(a.shape[0], 128) #tokens_per_expert.max() print(f"max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}/{num_local_experts}") rank = pgi.rank + world_size = pgi.world_size ata = AllToAll( max_num_tokens=max_num_tokens, @@ -520,14 +537,9 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): experts, ) - def chunk_by_rank(t, r): - num = t.shape[0] - assert num % pgi.world_size == 0, f"{num}, {dp_size}" # for now - chunk = num // pgi.world_size - return t[(r * chunk):(r + 1)*chunk] - - a_chunk = chunk_by_rank(a, rank) - chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, chunk_by_rank(scores, rank), topk, False) + a_chunk = chunk_by_rank(a, rank, world_size) + score_chunk = chunk_by_rank(scores, rank, world_size) + chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) print(f"chunk_topk_ids = {chunk_topk_ids}") From be9423272d5361d5df0c821c1c6a7173751f23d2 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 17 Apr 2025 13:08:04 +0000 Subject: [PATCH 143/190] forgot file Signed-off-by: Bill Nell --- tests/kernels/test_pplx_moe.py | 104 +++++++++++++-------------------- 1 file changed, 41 insertions(+), 63 deletions(-) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py index afb0b885866..87c6d42862b 100644 --- a/tests/kernels/test_pplx_moe.py +++ b/tests/kernels/test_pplx_moe.py @@ -373,7 +373,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): num_experts, # store at PplxDispatchCombine creation? None ) - torch.cuda.synchronize() # necessary? + #torch.cuda.synchronize() # necessary? out = torch.full( (max_num_tokens, hidden_dim), @@ -452,18 +452,12 @@ def _pplx_dispatch_combine( nvshmem_finalize() -@pytest.mark.parametrize("m", [2, 32, 64, 222]) #, 1024 * 128]) # what is restriction on this? +@pytest.mark.parametrize("m", [2, 32, 64, 222]) #, 1024 * 128]) @pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 512, 1024]) # restrictions here? +@pytest.mark.parametrize("k", [128, 512, 1024]) # restrictions? % 128? @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize("m", [2]) ##, 32]) #, 1024 * 128]) -# @pytest.mark.parametrize("n", [128]) -# @pytest.mark.parametrize("k", [128]) -# @pytest.mark.parametrize("e", [8]) #NUM_EXPERTS) -# @pytest.mark.parametrize("topk", [2]) #TOP_KS) -# @pytest.mark.parametrize("dtype", [torch.bfloat16]) def test_pplx_dispatch_combine( m: int, n: int, @@ -491,13 +485,16 @@ def test_pplx_dispatch_combine( def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): - hidden_dim = a.shape[-1] + assert torch.cuda.current_device() == pgi.local_rank + + num_tokens, hidden_dim = a.shape num_experts = w1.shape[0] - num_local_experts = num_experts // pgi.world_size block_size = 128 + device = pgi.device + rank_num_tokens = num_tokens // pgi.world_size - max_num_tokens = round_up(a.shape[0], 128) #tokens_per_expert.max() - print(f"max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}/{num_local_experts}") + max_num_tokens = num_tokens + #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") rank = pgi.rank world_size = pgi.world_size @@ -523,7 +520,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): dispatch_combine = PplxDispatchCombine( ata, - max_num_tokens, + max_num_tokens, # // world_size? pgi.world_size, dp_size, rank, @@ -537,53 +534,34 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): experts, ) - a_chunk = chunk_by_rank(a, rank, world_size) - score_chunk = chunk_by_rank(scores, rank, world_size) + a_chunk = chunk_by_rank(a, rank, world_size).to(device) + score_chunk = chunk_by_rank(scores, rank, world_size).to(device) chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) - print(f"chunk_topk_ids = {chunk_topk_ids}") + #print(f"chunk_topk_ids = {chunk_topk_ids}") - # TODO: chunk up by rank - if False: - out = fused_experts( - a_chunk, - w1, # chunk? - w2, # chunk? - chunk_topk_weight, - chunk_topk_ids, - global_num_experts=num_local_experts - ) - # reduce outputs? - else: - b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( - a_chunk, - None, - None, - chunk_topk_ids, - num_experts, - None - ) - torch.cuda.synchronize() + out = fused_experts( + a_chunk, + w1, # chunk? + w2, # chunk? + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts #? num_local_experts? + ) - out = torch.full( - (max_num_tokens, hidden_dim), - torch.nan, - dtype=a.dtype, - device=a.device, - ) + torch.cuda.synchronize() - dispatch_combine.combine( - out, - b_a, - chunk_topk_weight, - chunk_topk_ids, - ) + ata.destroy() - torch.cuda.synchronize() + torch.distributed.barrier() - ata.destroy() + #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}") - return out + #torch.distributed.all_reduce(out) + + print(f"OUT {rank}: {out.shape} {out}") + + return out[:rank_num_tokens] def _pplx_moe( @@ -612,29 +590,29 @@ def _pplx_moe( torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - triton_output = torch_pplx_moe(pgi, - dp_size, - a, - w1, - w2, - score, - topk) + pplxd_output = torch_pplx_moe(pgi, + dp_size, + a, + w1, + w2, + score, + topk) if False: torch.set_printoptions(profile="full") print("BASELINE") print(torch_output) print("OUTPUT") - print(triton_output) + print(pplx_output) - torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) nvshmem_finalize() # @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) # @pytest.mark.parametrize("n", [128, 1024, 2048]) -# @pytest.mark.parametrize("k", [128, 511, 1024]) +# @pytest.mark.parametrize("k", [128, 512, 1024]) # @pytest.mark.parametrize("e", NUM_EXPERTS) # @pytest.mark.parametrize("topk", TOP_KS) # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) From 5e5b3ad904b88ff7e9e0d6d1a8de6effd910e390 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 18 Apr 2025 02:22:24 +0000 Subject: [PATCH 144/190] somewhat working unit test Signed-off-by: Bill Nell --- tests/kernels/test_pplx_moe.py | 137 +++++++++--------- .../layers/fused_moe/fused_moe.py | 5 +- .../layers/fused_moe/modular_kernel.py | 2 +- .../layers/fused_moe/pplx_dispatch_combine.py | 12 +- 4 files changed, 78 insertions(+), 78 deletions(-) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py index 87c6d42862b..f6443187f14 100644 --- a/tests/kernels/test_pplx_moe.py +++ b/tests/kernels/test_pplx_moe.py @@ -9,10 +9,8 @@ import torch import traceback -from torch.nn import Parameter -from torch.nn import functional as F from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] -from typing import Callable, Concatenate, ParamSpec +from typing import Callable, Concatenate, ParamSpec, Tuple from pplx_kernels import AllToAll from pplx_kernels.nvshmem import ( @@ -25,27 +23,18 @@ import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, torch_moe, torch_moe_single) -from vllm import _custom_ops as ops +#from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe import fused_moe +#from vllm.model_executor.layers.fused_moe import fused_moe #from vllm.model_executor.layers.fused_moe.fused_batched_moe import fused_batched_experts from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size) -from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( - fused_moe as iterative_moe) -from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - marlin_quantize) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - quantize_weights) -from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform -from vllm.scalar_type import scalar_types -from vllm.utils import round_up from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts, BatchedDispatchCombine, BatchedExperts, fused_experts -from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel, FusedMoEQuantizeDispatchCombine +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import PplxDispatchCombine NUM_EXPERTS = [8, 64] @@ -373,7 +362,8 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): num_experts, # store at PplxDispatchCombine creation? None ) - #torch.cuda.synchronize() # necessary? + + b_a = b_a * 1.5 out = torch.full( (max_num_tokens, hidden_dim), @@ -392,7 +382,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): ata.destroy() - torch.distributed.barrier() + #torch.distributed.barrier() #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}") @@ -406,19 +396,26 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): def _pplx_dispatch_combine( pgi: ProcessGroupInfo, dp_size: int, - a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - score: torch.Tensor, + m, n, k, e, + #a: torch.Tensor, + #w1: torch.Tensor, + #w2: torch.Tensor, + #score: torch.Tensor, topk: int, dtype: torch.dtype, ): uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, pgi.rank, pgi.world_size) + device = pgi.device - m, k = a.shape - e, _, n = w2.shape + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10 + score = torch.randn((m, e), device=device, dtype=dtype) + + #m, k = a.shape + #e, _, n = w2.shape topk_weight, topk_ids = fused_topk(a, score, topk, False) @@ -426,7 +423,7 @@ def _pplx_dispatch_combine( a_rep = torch.repeat_interleave(a, topk, dim=0) #print(f"a_rep {a_rep.shape} {a_rep.view(-1, topk, k)}") - torch_output = (a_rep.view(-1, topk, k) * topk_weight.view(-1, topk, 1)).to(a.dtype).sum(dim=1) + torch_output = (a_rep.view(-1, topk, k) * 1.5 * topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype) #print(f"torch_output {pgi.rank}: {torch_output.shape} {torch_output}") @@ -452,12 +449,13 @@ def _pplx_dispatch_combine( nvshmem_finalize() -@pytest.mark.parametrize("m", [2, 32, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("m", [4, 32, 64, 222]) #, 1024 * 128]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 512, 1024]) # restrictions? % 128? @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [[4, 2]]) def test_pplx_dispatch_combine( m: int, n: int, @@ -465,22 +463,14 @@ def test_pplx_dispatch_combine( e: int, topk: int, dtype: torch.dtype, + world_dp_size: Tuple[int, int], ): current_platform.seed_everything(7) - if False: - world_size = 4 - dp_size = 2 - else: - world_size = 2 - dp_size = 1 - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device="cuda", dtype=dtype) + world_size, dp_size = world_dp_size parallel_launch( - world_size, _pplx_dispatch_combine, dp_size, a, w1, w2, score, topk, dtype + #world_size, _pplx_dispatch_combine, dp_size, a, w1, w2, score, topk, dtype + world_size, _pplx_dispatch_combine, dp_size, m, n, k, e, topk, dtype ) @@ -489,9 +479,10 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): num_tokens, hidden_dim = a.shape num_experts = w1.shape[0] + num_local_experts = num_experts // pgi.world_size block_size = 128 device = pgi.device - rank_num_tokens = num_tokens // pgi.world_size + rank_num_tokens = num_tokens // pgi.world_size # TODO even divide max_num_tokens = num_tokens #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") @@ -518,6 +509,9 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): ), ) + w1 = w1.to(device) + w2 = w2.to(device) + dispatch_combine = PplxDispatchCombine( ata, max_num_tokens, # // world_size? @@ -538,28 +532,28 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): score_chunk = chunk_by_rank(scores, rank, world_size).to(device) chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) - #print(f"chunk_topk_ids = {chunk_topk_ids}") + #print(f"chunk_topk_ids {rank} {chunk_topk_ids.shape} {chunk_topk_ids.view(-1)}") out = fused_experts( a_chunk, - w1, # chunk? - w2, # chunk? + w1, + w2, chunk_topk_weight, chunk_topk_ids, - global_num_experts=num_experts #? num_local_experts? + global_num_experts=num_local_experts #? num_local_experts? ) torch.cuda.synchronize() ata.destroy() - torch.distributed.barrier() + #torch.distributed.barrier() #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}") #torch.distributed.all_reduce(out) - print(f"OUT {rank}: {out.shape} {out}") + #print(f"OUT {rank}: {out.shape} {out}") return out[:rank_num_tokens] @@ -567,10 +561,10 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): def _pplx_moe( pgi: ProcessGroupInfo, dp_size: int, - m: int, - n: int, - k: int, - e: int, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, topk: int, dtype: torch.dtype, ): @@ -578,33 +572,37 @@ def _pplx_moe( torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, pgi.rank, pgi.world_size) - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + m, k = a.shape + e, _, n = w2.shape - score = torch.randn((m, e), device="cuda", dtype=dtype) + torch.set_printoptions(profile="full") vllm_config = VllmConfig() with set_current_vllm_config(vllm_config): topk_weight, topk_ids = fused_topk(a, score, topk, False) + #print(f"topk_ids {pgi.rank} {topk_ids.shape} {topk_ids.view(-1)}") + torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - pplxd_output = torch_pplx_moe(pgi, - dp_size, - a, - w1, - w2, - score, - topk) + pplx_output = torch_pplx_moe(pgi, + dp_size, + a, + w1, + w2, + score, + topk) + + #print(f"torch_output {pgi.rank}: {torch_output}") if False: - torch.set_printoptions(profile="full") print("BASELINE") print(torch_output) print("OUTPUT") print(pplx_output) + torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) nvshmem_finalize() @@ -616,12 +614,13 @@ def _pplx_moe( # @pytest.mark.parametrize("e", NUM_EXPERTS) # @pytest.mark.parametrize("topk", TOP_KS) # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("m", [128]) ##, 32]) #, 1024 * 128]) +@pytest.mark.parametrize("m", [64]) ##, 32]) #, 1024 * 128]) @pytest.mark.parametrize("n", [128]) @pytest.mark.parametrize("k", [128]) @pytest.mark.parametrize("e", [8]) #NUM_EXPERTS) @pytest.mark.parametrize("topk", [2]) #TOP_KS) @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) def test_pplx_moe( m: int, n: int, @@ -629,15 +628,17 @@ def test_pplx_moe( e: int, topk: int, dtype: torch.dtype, + world_dp_size: Tuple[int, int], ): current_platform.seed_everything(7) - if False: - world_size = 4 - dp_size = 2 - else: - world_size = 2 - dp_size = 1 + world_size, dp_size = world_dp_size + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + parallel_launch( - world_size, _pplx_moe, dp_size, m, n, k, e, topk, dtype + world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, dtype + #world_size, _pplx_moe, dp_size, m, n, k, e, topk, dtype ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 63e32ae68a0..f1fcdc46293 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1858,7 +1858,7 @@ def workspace_shapes( a: torch.Tensor, ) -> Tuple[int, int, torch.dtype]: max_num_tokens = a.shape[1] - workspace13 = num_experts * max_num_tokens * K + workspace13 = num_experts * max_num_tokens * K * 2 # *2 = HACK!!!!! workspace2 = max_num_tokens * (N // 2) return (workspace13, workspace2, a_dtype) @@ -1889,7 +1889,8 @@ def apply( print(f"global_num_experts = {global_num_experts}") num_experts = global_num_experts out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) - for expert in range(num_experts): + num_local_experts = expert_num_tokens.numel() + for expert in range(num_local_experts): # num_experts num = expert_num_tokens[expert] if num > 0: tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 96ecf5990a6..35f8b829277 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -323,7 +323,7 @@ def forward( if global_num_experts == -1: global_num_experts = E - output = a1 if inplace else torch.zeros_like(a1) + output = a1 if inplace else torch.empty_like(a1) workspace13_shape, workspace2_shape, workspace_dtype = ( self.fused_experts.workspace_shapes(a1, M, N, K, top_k, diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index dd8fe4a36fb..682935e2c68 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -78,7 +78,7 @@ def dispatch( #expert_num_tokens.fill_(-1) # debugging, remove later num_dp = self.world_size // self.dp_size - print(f"GOT HERE A {self.rank}: {self.max_num_tokens} {num_dp} {hidden_dim}") + logger.debug(f"GOT HERE A {self.rank}: {self.max_num_tokens} {num_dp} {hidden_dim}") expert_x = torch.empty( (num_local_experts, self.max_num_tokens * num_dp, a1q.shape[-1]), dtype=a1q.dtype, @@ -86,7 +86,7 @@ def dispatch( ) expert_x.fill_(torch.nan) # debugging, remove later - print(f"GOT HERE B {self.rank}") + logger.debug(f"GOT HERE B {self.rank}") expert_x_scale: Optional[torch.Tensor] = None if a1q.dtype.itemsize == 1: @@ -103,7 +103,7 @@ def dispatch( device=device, ) - print(f"GOT HERE C {self.rank}") + logger.debug(f"GOT HERE C {self.rank}") # This argument is optional, defaults to indices.shape[0] # This causes a deadlock???? @@ -114,8 +114,6 @@ def dispatch( # TODO: optimize this? indices = rank_topk_ids.to(dtype=torch.uint32) - print(f"GOT HERE D {self.rank}") - self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, @@ -140,7 +138,7 @@ def combine( #device = get_dp_group().device #assert fused_expert_output.device == device - print(f"COMBINE START {self.rank}") + logger.debug(f"COMBINE START {self.rank}") # This argument is optional #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens @@ -161,4 +159,4 @@ def combine( expert_y=fused_expert_output, bound_m=bound_m) - print(f"COMBINE END {self.rank}") + logger.debug(f"COMBINE END {self.rank}") From 669e4f3ebed146c565e5dfeda3e7178387355157 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 18 Apr 2025 19:31:31 +0000 Subject: [PATCH 145/190] wip Signed-off-by: Bill Nell --- tests/kernels/test_pplx_moe.py | 164 ++++++++++-------- .../layers/fused_moe/fused_moe.py | 4 +- .../layers/fused_moe/pplx_dispatch_combine.py | 2 +- 3 files changed, 93 insertions(+), 77 deletions(-) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py index f6443187f14..b80ebfd64a0 100644 --- a/tests/kernels/test_pplx_moe.py +++ b/tests/kernels/test_pplx_moe.py @@ -164,7 +164,7 @@ def torch_dispatch( a: torch.Tensor, topk_ids: torch.Tensor, num_experts: int -) -> torch.Tensor: +) -> Tuple[torch.Tensor, torch.Tensor]: assert topk_ids.dim() == 2 assert topk_ids.shape[0] == a.shape[0] @@ -172,10 +172,11 @@ def torch_dispatch( topk = topk_ids.shape[1] tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) - max_num_tokens = tokens_per_expert.max() + b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), dtype=a.dtype, device=a.device) + #print(f"b_a shape {b_a.shape}") token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) @@ -242,59 +243,58 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids): topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) -# @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) -# @pytest.mark.parametrize("n", [128, 1024, 2048]) -# @pytest.mark.parametrize("k", [128, 511, 1024]) -# @pytest.mark.parametrize("e", NUM_EXPERTS) -# @pytest.mark.parametrize("topk", TOP_KS) -# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -# def test_fused_moe_batched_experts( -# m: int, -# n: int, -# k: int, -# e: int, -# topk: int, -# dtype: torch.dtype, -# ): -# current_platform.seed_everything(7) - -# a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 -# w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 -# w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - -# score = torch.randn((m, e), device="cuda", dtype=dtype) - -# vllm_config = VllmConfig() -# with set_current_vllm_config(vllm_config): -# topk_weight, topk_ids = fused_topk(a, score, topk, False) - -# torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - -# if True: -# triton_output = torch_batched_moe(a, -# w1, -# w2, -# topk_weight, -# topk_ids) -# else: -# b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) -# triton_output = fused_batched_experts( -# b_a, -# w1, -# w2, -# topk_weight, -# topk_ids, -# global_num_experts=e -# ) - -# if False: -# torch.set_printoptions(profile="full") -# print("BASELINE") -# print(torch_output) -# print("OUTPUT") -# print(triton_output) - -# torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) +@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_moe_batched_experts( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids = fused_topk(a, score, topk, False) + + torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + + if True: + triton_output = torch_batched_moe(a, + w1, + w2, + topk_weight, + topk_ids) + else: + b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) + triton_output = fused_batched_experts( + b_a, + w1, + w2, + topk_weight, + topk_ids, + global_num_experts=e + ) + + if False: + torch.set_printoptions(profile="full") + print("BASELINE") + print(torch_output) + print("OUTPUT") + print(triton_output) + + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) def chunk_by_rank(t, r, w): @@ -310,6 +310,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): num_tokens, hidden_dim = a.shape num_experts = w1.shape[0] + num_local_experts = w1.shape[0] // pgi.world_size block_size = 128 device = pgi.device rank_num_tokens = num_tokens // pgi.world_size @@ -352,7 +353,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): score_chunk = chunk_by_rank(scores, rank, world_size).to(device) chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) - #print(f"chunk_topk_ids = {chunk_topk_ids}") + #print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}") b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( a_chunk, @@ -363,6 +364,25 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): None ) + #topk_weight, topk_ids = fused_topk(a_chunk, score_chunk, topk, False) + naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, num_experts) + + torch.distributed.all_reduce(tokens_per_expert) + #max_num = tokens_per_expert.max() + tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, world_size).to(dtype=torch.int32) + + #print(f"tpe {tokens_per_expert}") + #print(f"ent {expert_num_tokens}") + + #naive_b_a = chunk_by_rank(naive_b_a, rank, world_size) + + #torch.set_printoptions(profile="full") + #print("b_a", b_a[:naive_b_a.shape[1]]) + #print("naive_b_a", naive_b_a) + + torch.testing.assert_close(tokens_per_expert, expert_num_tokens, atol=0, rtol=0) + #torch.testing.assert_close(b_a[:, :naive_b_a.shape[1]], naive_b_a, atol=2e-2, rtol=0) + b_a = b_a * 1.5 out = torch.full( @@ -382,8 +402,6 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): ata.destroy() - #torch.distributed.barrier() - #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}") #torch.distributed.all_reduce(out) @@ -547,8 +565,6 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): ata.destroy() - #torch.distributed.barrier() - #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}") #torch.distributed.all_reduce(out) @@ -593,8 +609,6 @@ def _pplx_moe( score, topk) - #print(f"torch_output {pgi.rank}: {torch_output}") - if False: print("BASELINE") print(torch_output) @@ -603,23 +617,25 @@ def _pplx_moe( torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) + #print(f"torch_output {pgi.rank}: {torch_output.shape} {torch_output}") + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) nvshmem_finalize() -# @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) -# @pytest.mark.parametrize("n", [128, 1024, 2048]) -# @pytest.mark.parametrize("k", [128, 512, 1024]) -# @pytest.mark.parametrize("e", NUM_EXPERTS) -# @pytest.mark.parametrize("topk", TOP_KS) -# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("m", [64]) ##, 32]) #, 1024 * 128]) -@pytest.mark.parametrize("n", [128]) -@pytest.mark.parametrize("k", [128]) -@pytest.mark.parametrize("e", [8]) #NUM_EXPERTS) -@pytest.mark.parametrize("topk", [2]) #TOP_KS) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("m", [2, 32, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 512, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("m", [64]) ##, 32]) #, 1024 * 128]) +# @pytest.mark.parametrize("n", [128]) +# @pytest.mark.parametrize("k", [128]) +# @pytest.mark.parametrize("e", [8]) #NUM_EXPERTS) +# @pytest.mark.parametrize("topk", [2]) #TOP_KS) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) def test_pplx_moe( m: int, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f1fcdc46293..7603fa5e06c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1858,8 +1858,8 @@ def workspace_shapes( a: torch.Tensor, ) -> Tuple[int, int, torch.dtype]: max_num_tokens = a.shape[1] - workspace13 = num_experts * max_num_tokens * K * 2 # *2 = HACK!!!!! - workspace2 = max_num_tokens * (N // 2) + workspace13 = num_experts * max_num_tokens * K * topk * 2 # TODO: *2 is a hack + workspace2 = max_num_tokens * N return (workspace13, workspace2, a_dtype) def apply( diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 682935e2c68..10c02fb2ff2 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -84,7 +84,7 @@ def dispatch( dtype=a1q.dtype, device=device, ) - expert_x.fill_(torch.nan) # debugging, remove later + expert_x.fill_(0) #torch.nan # debugging, remove later logger.debug(f"GOT HERE B {self.rank}") From cbcf12afb0db933dc599a40663b283e49d15e439 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 18 Apr 2025 22:37:31 +0000 Subject: [PATCH 146/190] fix test Signed-off-by: Bill Nell --- tests/kernels/test_pplx_moe.py | 46 +++++++++---------- .../layers/fused_moe/fused_moe.py | 18 +++++++- 2 files changed, 39 insertions(+), 25 deletions(-) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py index b80ebfd64a0..a62dbbcc4cd 100644 --- a/tests/kernels/test_pplx_moe.py +++ b/tests/kernels/test_pplx_moe.py @@ -10,7 +10,7 @@ import traceback from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] -from typing import Callable, Concatenate, ParamSpec, Tuple +from typing import Callable, Concatenate, Optional, ParamSpec, Tuple from pplx_kernels import AllToAll from pplx_kernels.nvshmem import ( @@ -163,7 +163,8 @@ def parallel_launch_from_env( def torch_dispatch( a: torch.Tensor, topk_ids: torch.Tensor, - num_experts: int + num_experts: int, + max_num_tokens: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: assert topk_ids.dim() == 2 assert topk_ids.shape[0] == a.shape[0] @@ -172,7 +173,8 @@ def torch_dispatch( topk = topk_ids.shape[1] tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) - max_num_tokens = tokens_per_expert.max() + if max_num_tokens is None: + max_num_tokens = tokens_per_expert.max() b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), dtype=a.dtype, device=a.device) @@ -314,11 +316,10 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): block_size = 128 device = pgi.device rank_num_tokens = num_tokens // pgi.world_size - - max_num_tokens = num_tokens - #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") rank = pgi.rank world_size = pgi.world_size + max_num_tokens = num_tokens + #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") ata = AllToAll( max_num_tokens=max_num_tokens, @@ -342,7 +343,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): dispatch_combine = PplxDispatchCombine( ata, - max_num_tokens, # // world_size? + max_num_tokens, pgi.world_size, dp_size, rank, @@ -353,7 +354,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): score_chunk = chunk_by_rank(scores, rank, world_size).to(device) chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) - #print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}") + print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}") b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( a_chunk, @@ -371,14 +372,17 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): #max_num = tokens_per_expert.max() tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, world_size).to(dtype=torch.int32) - #print(f"tpe {tokens_per_expert}") - #print(f"ent {expert_num_tokens}") + print(f"tpe {tokens_per_expert}") + print(f"ent {expert_num_tokens}") + + #torch.set_printoptions(profile="full") + #torch.distributed.all_reduce(naive_b_a, op=torch.distributed.ReduceOp.MAX) + #torch.distributed.broadcast(naive_b_a, src=rank) #naive_b_a = chunk_by_rank(naive_b_a, rank, world_size) - #torch.set_printoptions(profile="full") - #print("b_a", b_a[:naive_b_a.shape[1]]) - #print("naive_b_a", naive_b_a) + #print("b_a", b_a.shape, b_a) #[:, :naive_b_a.shape[1]]) + #print("naive_b_a", naive_b_a.shape, naive_b_a) torch.testing.assert_close(tokens_per_expert, expert_num_tokens, atol=0, rtol=0) #torch.testing.assert_close(b_a[:, :naive_b_a.shape[1]], naive_b_a, atol=2e-2, rtol=0) @@ -386,7 +390,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): b_a = b_a * 1.5 out = torch.full( - (max_num_tokens, hidden_dim), + (rank_num_tokens * world_size, hidden_dim), torch.nan, dtype=a.dtype, device=device, @@ -539,7 +543,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): a.dtype, ) - experts = BatchedExperts() + experts = BatchedExperts(max_num_tokens, rank) fused_experts = FusedMoEModularKernel( dispatch_combine, @@ -554,24 +558,20 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): out = fused_experts( a_chunk, - w1, - w2, + chunk_by_rank(w1, rank, world_size), + chunk_by_rank(w2, rank, world_size), chunk_topk_weight, chunk_topk_ids, - global_num_experts=num_local_experts #? num_local_experts? + global_num_experts=num_experts #? num_local_experts? ) torch.cuda.synchronize() ata.destroy() - #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}") - - #torch.distributed.all_reduce(out) - #print(f"OUT {rank}: {out.shape} {out}") - return out[:rank_num_tokens] + return out[:rank_num_tokens] # chunk_by_rank? def _pplx_moe( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 7603fa5e06c..1fecff34106 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1834,6 +1834,8 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, + max_num_tokens: Optional[int] = None, + rank: int = 0, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1846,6 +1848,8 @@ def __init__( assert not use_int8_w8a16 assert block_shape is None assert block_m is None + self.max_num_tokens = max_num_tokens + self.rank = rank def workspace_shapes( self, @@ -1857,7 +1861,8 @@ def workspace_shapes( num_experts: int, a: torch.Tensor, ) -> Tuple[int, int, torch.dtype]: - max_num_tokens = a.shape[1] + #assert self.max_num_tokens >= a.shape[1] + max_num_tokens = a.shape[1] if self.max_num_tokens is None else self.max_num_tokens workspace13 = num_experts * max_num_tokens * K * topk * 2 # TODO: *2 is a hack workspace2 = max_num_tokens * N return (workspace13, workspace2, a_dtype) @@ -1885,13 +1890,20 @@ def apply( assert hidden_states.dim() == 3 assert expert_num_tokens is not None num_tokens, topk = topk_ids.shape - _, max_num_tokens, K = hidden_states.shape + _, tmp_max_num_tokens, K = hidden_states.shape + max_num_tokens = tmp_max_num_tokens if self.max_num_tokens is None else self.max_num_tokens print(f"global_num_experts = {global_num_experts}") num_experts = global_num_experts out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) num_local_experts = expert_num_tokens.numel() + #assert num_local_experts >= topk_ids.view(-1).max() + #print(f"apply a={hidden_states}") + #print(f"apply topk={topk_ids}") + #print(f"apply num_tokens={expert_num_tokens}") + for expert in range(num_local_experts): # num_experts num = expert_num_tokens[expert] + assert num <= max_num_tokens if num > 0: tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) self.activation(activation, tmp, hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1)) @@ -1904,6 +1916,8 @@ def apply( #print("END EXPERTS") + #print(f"apply out={out}") + return out From 9a800ad6b401b8e917e09e59e01ea66500881ec4 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 19 Apr 2025 01:08:54 +0000 Subject: [PATCH 147/190] some cleanup Signed-off-by: Bill Nell --- tests/kernels/test_pplx_moe.py | 41 ++++++++----------- .../layers/fused_moe/fused_moe.py | 22 ++-------- .../layers/fused_moe/pplx_dispatch_combine.py | 2 +- 3 files changed, 23 insertions(+), 42 deletions(-) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py index a62dbbcc4cd..0e5e0cd7728 100644 --- a/tests/kernels/test_pplx_moe.py +++ b/tests/kernels/test_pplx_moe.py @@ -299,10 +299,13 @@ def test_fused_moe_batched_experts( torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) +def rank_chunk(num, r, w): + rem = num % w + return (num // w) + (1 if r < rem else 0) + + def chunk_by_rank(t, r, w): - num = t.shape[0] - assert num % w == 0, f"{num}, {w}" # for now - chunk = num // w + chunk = rank_chunk(t.shape[0], r, w) #print(f"chunk {t.shape}, {w}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}") return t[(r * chunk):(r + 1)*chunk] @@ -312,12 +315,11 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): num_tokens, hidden_dim = a.shape num_experts = w1.shape[0] - num_local_experts = w1.shape[0] // pgi.world_size block_size = 128 device = pgi.device - rank_num_tokens = num_tokens // pgi.world_size rank = pgi.rank world_size = pgi.world_size + rank_num_tokens = rank_chunk(num_tokens, rank, world_size) max_num_tokens = num_tokens #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") @@ -354,7 +356,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): score_chunk = chunk_by_rank(scores, rank, world_size).to(device) chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) - print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}") + #print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}") b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( a_chunk, @@ -372,8 +374,8 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): #max_num = tokens_per_expert.max() tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, world_size).to(dtype=torch.int32) - print(f"tpe {tokens_per_expert}") - print(f"ent {expert_num_tokens}") + #print(f"tpe {tokens_per_expert}") + #print(f"ent {expert_num_tokens}") #torch.set_printoptions(profile="full") #torch.distributed.all_reduce(naive_b_a, op=torch.distributed.ReduceOp.MAX) @@ -501,15 +503,12 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): num_tokens, hidden_dim = a.shape num_experts = w1.shape[0] - num_local_experts = num_experts // pgi.world_size block_size = 128 device = pgi.device - rank_num_tokens = num_tokens // pgi.world_size # TODO even divide - - max_num_tokens = num_tokens - #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") rank = pgi.rank world_size = pgi.world_size + rank_num_tokens = rank_chunk(num_tokens, rank, world_size) + max_num_tokens = num_tokens ata = AllToAll( max_num_tokens=max_num_tokens, @@ -558,6 +557,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): out = fused_experts( a_chunk, + # Chunking weights like this only works for batched format chunk_by_rank(w1, rank, world_size), chunk_by_rank(w2, rank, world_size), chunk_topk_weight, @@ -571,7 +571,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): #print(f"OUT {rank}: {out.shape} {out}") - return out[:rank_num_tokens] # chunk_by_rank? + return out[:rank_num_tokens] def _pplx_moe( @@ -624,18 +624,13 @@ def _pplx_moe( nvshmem_finalize() -@pytest.mark.parametrize("m", [2, 32, 64, 222]) #, 1024 * 128]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 512, 1024]) +# TODO: M == 1 doesn't work +@pytest.mark.parametrize("m", [2, 3, 32, 45, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("n", [128, 1024])# , 2048]) +@pytest.mark.parametrize("k", [128, 512]) # , 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize("m", [64]) ##, 32]) #, 1024 * 128]) -# @pytest.mark.parametrize("n", [128]) -# @pytest.mark.parametrize("k", [128]) -# @pytest.mark.parametrize("e", [8]) #NUM_EXPERTS) -# @pytest.mark.parametrize("topk", [2]) #TOP_KS) -# @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) def test_pplx_moe( m: int, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1fecff34106..23f903d33dc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1777,9 +1777,6 @@ def dispatch( num_tokens = a1.shape[0] topk = topk_ids.shape[1] - #assert num_experts % self.world_size == 0 - #num_local_experts = num_experts // self.world_size - tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) max_num_tokens = tokens_per_expert.max() expert_counts = torch.zeros(num_experts, dtype=torch.int, device=a1.device) @@ -1892,31 +1889,20 @@ def apply( num_tokens, topk = topk_ids.shape _, tmp_max_num_tokens, K = hidden_states.shape max_num_tokens = tmp_max_num_tokens if self.max_num_tokens is None else self.max_num_tokens - print(f"global_num_experts = {global_num_experts}") + #print(f"global_num_experts = {global_num_experts}") num_experts = global_num_experts out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) num_local_experts = expert_num_tokens.numel() - #assert num_local_experts >= topk_ids.view(-1).max() - #print(f"apply a={hidden_states}") - #print(f"apply topk={topk_ids}") - #print(f"apply num_tokens={expert_num_tokens}") + #print(f"shapes = {hidden_states.shape}, {w1.shape}, {w2.shape}, {out.shape} {expert_num_tokens.shape} {workspace2.shape} {num_experts}") for expert in range(num_local_experts): # num_experts num = expert_num_tokens[expert] - assert num <= max_num_tokens + assert num <= max_num_tokens, f"{num}, {max_num_tokens}" + #print(f"{type(num)}, {num}, {max_num_tokens}") if num > 0: tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) self.activation(activation, tmp, hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1)) out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) - # fill remainder with 0??? - #out[expert, num:, :].fill_(0) - else: - #out[expert, :, :].fill_(0) # ?? - pass - - #print("END EXPERTS") - - #print(f"apply out={out}") return out diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 10c02fb2ff2..90bfa385dac 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -84,7 +84,7 @@ def dispatch( dtype=a1q.dtype, device=device, ) - expert_x.fill_(0) #torch.nan # debugging, remove later + #expert_x.fill_(0) #torch.nan # debugging, remove later logger.debug(f"GOT HERE B {self.rank}") From 76882334fd2d64a0c7a8ac0e4f8ed82555b5ff17 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 19 Apr 2025 01:49:57 +0000 Subject: [PATCH 148/190] wip Signed-off-by: Bill Nell --- tests/kernels/test_pplx_moe.py | 6 ++++-- .../layers/fused_moe/fused_moe.py | 18 +++++++++++++++--- vllm/model_executor/layers/fused_moe/layer.py | 4 ++-- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py index 0e5e0cd7728..a8ce6c6dc2b 100644 --- a/tests/kernels/test_pplx_moe.py +++ b/tests/kernels/test_pplx_moe.py @@ -535,14 +535,14 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): dispatch_combine = PplxDispatchCombine( ata, - max_num_tokens, # // world_size? + max_num_tokens, pgi.world_size, dp_size, rank, a.dtype, ) - experts = BatchedExperts(max_num_tokens, rank) + experts = BatchedExperts(rank, pgi.world_size, max_num_tokens) fused_experts = FusedMoEModularKernel( dispatch_combine, @@ -560,6 +560,8 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): # Chunking weights like this only works for batched format chunk_by_rank(w1, rank, world_size), chunk_by_rank(w2, rank, world_size), + #w1, + #w2, chunk_topk_weight, chunk_topk_ids, global_num_experts=num_experts #? num_local_experts? diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 23f903d33dc..07aa024289b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1827,12 +1827,18 @@ def combine( #print(f"END COMBINE {hex(id(self))}") +def rank_chunk(num, r, w): + rem = num % w + return (num // w) + (1 if r < rem else 0) + + class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - max_num_tokens: Optional[int] = None, rank: int = 0, + world_size: int = 1, + max_num_tokens: Optional[int] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, @@ -1847,6 +1853,7 @@ def __init__( assert block_m is None self.max_num_tokens = max_num_tokens self.rank = rank + self.world_size = world_size def workspace_shapes( self, @@ -1895,14 +1902,19 @@ def apply( num_local_experts = expert_num_tokens.numel() #print(f"shapes = {hidden_states.shape}, {w1.shape}, {w2.shape}, {out.shape} {expert_num_tokens.shape} {workspace2.shape} {num_experts}") + # TODO: don't need world_size or rank if expert_base always == 0 + #assert w1.shape[0] == num_experts, f"{w1.shape} == {num_experts}" + #expert_base = rank_chunk(w1.shape[0], self.rank, self.world_size) * self.rank + expert_base = 0 + for expert in range(num_local_experts): # num_experts num = expert_num_tokens[expert] assert num <= max_num_tokens, f"{num}, {max_num_tokens}" #print(f"{type(num)}, {num}, {max_num_tokens}") if num > 0: tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) - self.activation(activation, tmp, hidden_states[expert,:num,:] @ w1[expert].transpose(0, 1)) - out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) + self.activation(activation, tmp, hidden_states[expert,:num,:] @ w1[expert_base + expert].transpose(0, 1)) + out[expert, :num, :] = tmp @ w2[expert_base + expert].transpose(0, 1) return out diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 0f6eccab5a4..80304303626 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -246,8 +246,8 @@ def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine #print(f"block_m = {block_m}") if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)): - logger.info("BatchedExperts") - experts = BatchedExperts() + logger.info(f"BatchedExperts {self.moe}") + experts = BatchedExperts() #rank=self.moe.ep_rank, world_size=self.moe.ep_size) else: experts = TritonExperts( use_fp8_w8a8 = False, From 02d42a7b5bb432921dce3500dd39e54ca590519a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 13:40:35 +0000 Subject: [PATCH 149/190] wip Signed-off-by: Bill Nell --- tests/kernels/test_pplx_moe.py | 3 -- vllm/forward_context.py | 2 +- .../layers/fused_moe/fused_moe.py | 23 +++++++------- vllm/model_executor/layers/fused_moe/layer.py | 31 ++++++++++--------- .../layers/fused_moe/triton_deep_gemm_moe.py | 7 ++--- 5 files changed, 32 insertions(+), 34 deletions(-) diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py index a8ce6c6dc2b..97fc74e3bd3 100644 --- a/tests/kernels/test_pplx_moe.py +++ b/tests/kernels/test_pplx_moe.py @@ -23,10 +23,7 @@ import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, torch_moe, torch_moe_single) -#from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config -#from vllm.model_executor.layers.fused_moe import fused_moe -#from vllm.model_executor.layers.fused_moe.fused_batched_moe import fused_batched_experts from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size) from vllm.platforms import current_platform diff --git a/vllm/forward_context.py b/vllm/forward_context.py index f6a036d228d..ded15df1f94 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -98,7 +98,7 @@ def set_forward_context(attn_metadata: Any, from vllm.distributed.parallel_state import get_dp_group dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) #TODO device? - max_tokens_across_dp = torch.max(num_tokens_tensor).to(device="cuda") + max_tokens_across_dp = torch.max(num_tokens_tensor) #.to(device="cuda") cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0) dp_rank_num_tokens = torch.tensor( [num_tokens], diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 07aa024289b..e8e297a37d2 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1594,8 +1594,9 @@ def workspace_shapes( topk: int, num_experts: int, ) -> Tuple[int, int, torch.dtype]: - workspace1 = M * topk * max(N * 2, K) - workspace2 = M * topk * N + factor = num_experts if a.dim() == 3 else 1 + workspace1 = M * topk * max(N * 2, K) * factor + workspace2 = M * topk * N * factor return (workspace1, workspace2, a.dtype) def apply( @@ -1686,16 +1687,15 @@ def apply( global_num_experts, expert_map )) else: - #stride = hidden_states.shape[1] - sorted_token_ids = torch.arange(0, num_tokens*hidden_states.shape[1], device=hidden_states.device, dtype=torch.int) + max_num_tokens = hidden_states.shape[1] + sorted_token_ids = torch.arange(0, hidden_states.shape[0] * max_num_tokens, device=hidden_states.device, dtype=torch.int) sorted_token_ids = sorted_token_ids.flatten() - nans = torch.isnan(hidden_states).sum(dim=(1,2)) - expert_ids = torch.where((nans > 0).flatten(), -1, torch.arange(0, nans.numel(), device=hidden_states.device, dtype=torch.int32)) - #expert_ids = torch.repeat_interleave(expert_ids, hidden_states.shape[1], dim=0) - #print(f"EXPERT_IDS {nans.shape} {expert_ids}") + expert_ids = torch.arange(0, global_num_experts, device=hidden_states.device, dtype=torch.int) + expert_ids = torch.repeat_interleave(expert_ids, max_num_tokens, dim=0) + print(f"EXPERT_IDS {expert_ids}") #num_tokens_post_padded = torch.tensor([num_tokens], device=hidden_states.device, dtype=torch.int32) num_tokens_post_padded = torch.zeros(1, device=hidden_states.device, dtype=torch.int32) - num_tokens_post_padded.fill_(num_tokens) + num_tokens_post_padded.fill_(max_num_tokens) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) #print(f"P = {sorted_token_ids}, {hidden_states.shape}") @@ -1857,19 +1857,18 @@ def __init__( def workspace_shapes( self, - a_dtype: torch.dtype, + a: torch.Tensor, M: int, N: int, K: int, topk: int, num_experts: int, - a: torch.Tensor, ) -> Tuple[int, int, torch.dtype]: #assert self.max_num_tokens >= a.shape[1] max_num_tokens = a.shape[1] if self.max_num_tokens is None else self.max_num_tokens workspace13 = num_experts * max_num_tokens * K * topk * 2 # TODO: *2 is a hack workspace2 = max_num_tokens * N - return (workspace13, workspace2, a_dtype) + return (workspace13, workspace2, a.dtype) def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 80304303626..ad48fe74dc3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -249,6 +249,7 @@ def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine logger.info(f"BatchedExperts {self.moe}") experts = BatchedExperts() #rank=self.moe.ep_rank, world_size=self.moe.ep_size) else: + logger.info(f"TritonExperts {self.moe}") experts = TritonExperts( use_fp8_w8a8 = False, use_int8_w8a16 = False, @@ -1012,21 +1013,20 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): if self.use_direct_call: return self.forward_impl(hidden_states, router_logits) - else: + elif True: return torch.ops.vllm.moe_forward(hidden_states, router_logits, self.layer_name) def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_router_logits: torch.Tensor): - max_tokens_across_dp = get_forward_context( - ).dp_metadata.max_tokens_across_dp - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu - num_tokens_across_dp = get_forward_context( - ).dp_metadata.num_tokens_across_dp + ctx = get_forward_context() + + max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp + #cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu + num_tokens_across_dp = ctx.dp_metadata.num_tokens_across_dp - #print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{get_forward_context().dp_metadata.dp_rank_num_tokens}") + #print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{ctx.dp_metadata.dp_rank_num_tokens}") #In this function we define two ranges: # 1. chunk_range - The current iteration of the loops's range over the DP world tokens @@ -1043,17 +1043,19 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, #print(f"ORIGINAL SHAPE {full_hidden_states.shape}") #print(f"moe_dp_chunk_size_per_rank = {moe_dp_chunk_size_per_rank}") + assert full_hidden_states.shape[0] == full_router_logits.shape[0] + for iter in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] - #print(f"loop {iter}: {chunk_start}:{chunk_end}, {hidden_states.shape}") - cu_tokens_across_dp_this_iter = torch.cumsum( num_tokens_remaining_across_dp.clamp( max=moe_dp_chunk_size_per_rank), dim=0) + print(f"loop {iter}: {chunk_start}:{chunk_end}, {hidden_states.shape} {cu_tokens_across_dp_this_iter}") + hidden_states = self.naive_multicast( hidden_states, cu_tokens_across_dp_this_iter) router_logits = self.naive_multicast( @@ -1088,14 +1090,14 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, final_hidden_states) final_hidden_states = all_hidden_states[start:end, :] - #print(f"final2 (AR) = {final_hidden_states.shape}") + print(f"final2 (AR) = {final_hidden_states.shape}") if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) - #print(f"final3 (AR) = {final_hidden_states.shape}") + print(f"final3 (AR) = {final_hidden_states.shape}") full_final_hidden_states[chunk_start:chunk_end, :].copy_( final_hidden_states) @@ -1129,8 +1131,9 @@ def forward_impl(self, hidden_states: torch.Tensor, assert self.quant_method is not None if self.dp_size > 1: - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu + print("FORWARD_IMPL") + ctx = get_forward_context() + cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu hidden_states = self.naive_multicast(hidden_states, cu_tokens_across_dp_cpu) diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index be28d620f47..e85f3514160 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -37,21 +37,20 @@ def __init__( def workspace_shapes( self, - a_dtype: torch.dtype, + a: torch.Tensor, M: int, N: int, K: int, topk: int, num_experts: int, - a: torch.Tensor, ) -> Tuple[int, int, torch.dtype]: # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): - return self.deep_gemm_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts, a) + return self.deep_gemm_expert.workspace_shapes(a, M, N, K, topk, num_experts) else: - return self.triton_expert.workspace_shapes(a_dtype, M, N, K, topk, num_experts, a) + return self.triton_expert.workspace_shapes(a, M, N, K, topk, num_experts) def apply( self, From 27bee28b5c26e8463016062fafc01d54db0e53cf Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 16:44:42 +0000 Subject: [PATCH 150/190] undo random changes Signed-off-by: Bill Nell --- csrc/custom_all_reduce.cuh | 2 +- vllm/distributed/parallel_state.py | 8 ++++---- vllm/model_executor/models/mllama.py | 25 ------------------------- 3 files changed, 5 insertions(+), 30 deletions(-) diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 186abf4712f..44709b45977 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -602,4 +602,4 @@ class CustomAllreduce { * template void vllm::CustomAllreduce::allreduce(cudaStream_t, half *, half *, int, int, int); */ -} // namespace vllm +} // namespace vllm \ No newline at end of file diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index ade7b5183dd..904f07dac5d 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -34,10 +34,10 @@ import torch import torch.distributed -from torch.distributed import Backend, ProcessGroup from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_get_unique_id, nvshmem_init, - nvshmem_finalize) + nvshmem_finalize, nvshmem_get_unique_id, + nvshmem_init) +from torch.distributed import Backend, ProcessGroup import vllm.envs as envs from vllm.distributed.device_communicators.base_device_communicator import ( @@ -917,6 +917,7 @@ def init_distributed_environment( PPLX_DID_INIT: bool = False + @run_once def pplx_init(rank, world_size): if world_size > 1: @@ -1131,7 +1132,6 @@ def destroy_model_parallel(): _DP = None - def destroy_distributed_environment(): global _WORLD if _WORLD: diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 971a4e695da..0c1d61c01f9 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -1245,31 +1245,6 @@ def unpack_data(self, output_tensor[i, :t.size(0)] = t return output_tensor - def unpack_data(self, - image_data: Union[List[torch.Tensor], torch.Tensor], - padding_value=0) -> torch.Tensor: - if isinstance(image_data, torch.Tensor): - # torch.Tensor - return image_data - else: - assert isinstance( - image_data[0], - torch.Tensor), "Image data is not properly batched." - # List[torch.Tensor] - bsz = len(image_data) - max_length = max(t.size(0) for t in image_data) - trailing_dims = image_data[0].shape[1:] - for data in image_data: - cur_trailing_dims = data.shape[1:] - assert cur_trailing_dims == trailing_dims - output_tensor = torch.full((bsz, max_length, *trailing_dims), - padding_value, - dtype=image_data[0].dtype, - device=image_data[0].device) - for i, t in enumerate(image_data): - output_tensor[i, :t.size(0)] = t - return output_tensor - def _parse_and_validate_image_input(self, **kwargs: object): # tensor with the same shape will be batched together by # MultiModalKwargs.batch, so pixel_values here can be: From 089a71dde1d70740d9b1bb9de72f615c3d517d2b Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 20:29:06 +0000 Subject: [PATCH 151/190] merge Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 140 +----------------- tests/kernels/moe/test_triton_moe_ptpc_fp8.py | 34 +++-- tests/kernels/quantization/test_block_fp8.py | 32 +--- .../layers/fused_moe/fused_batched_moe.py | 17 +-- .../layers/fused_moe/fused_moe.py | 113 ++++---------- vllm/model_executor/layers/fused_moe/layer.py | 51 ++----- .../layers/fused_moe/modular_kernel.py | 14 +- .../layers/fused_moe/pplx_dispatch_combine.py | 29 +--- 8 files changed, 84 insertions(+), 346 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index d24039749db..976234083fd 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -110,143 +110,6 @@ def test_fused_moe( rtol=0) -def torch_dispatch( - a: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int -) -> torch.Tensor: - assert topk_ids.dim() == 2 - assert topk_ids.shape[0] == a.shape[0] - - num_tokens = a.shape[0] - topk = topk_ids.shape[1] - - tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) - - max_num_tokens = tokens_per_expert.max() - b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), - dtype=a.dtype, device=a.device) - #print(f"b_a shape {b_a.shape}") - - token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) - - for token in range(num_tokens): - for j in range(topk): - expert_id = topk_ids[token, j] - idx = token_counts[expert_id] - b_a[expert_id, idx:idx+1, :] = a[token, :] - token_counts[expert_id] = token_counts[expert_id] + 1 - - return b_a, tokens_per_expert - - -def torch_combine(b_out, topk_weight, topk_ids): - num_tokens, topk = topk_ids.shape - num_experts = b_out.shape[0] - K = b_out.shape[-1] - out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) - expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) - for token in range(num_tokens): - expert_ids = topk_ids[token] - for i in range(expert_ids.numel()): - expert_id = expert_ids[i] - idx = expert_counts[expert_id] - out[token, :] = out[token, :] + b_out[expert_id, idx:idx+1, :] * topk_weight[token, i] - expert_counts[expert_id] = expert_counts[expert_id] + 1 - - return out - - -def torch_batched_moe(a, w1, w2, topk_weight, topk_ids): - num_experts = w1.shape[0] - b_a, tokens_per_expert = torch_dispatch(a, topk_ids, num_experts) - assert b_a.dim() == 3 - num_tokens, topk = topk_ids.shape - _, max_num_tokens, K = b_a.shape - assert num_experts == b_a.shape[0] and K == w2.shape[1] - out = torch.zeros((num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device) - tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device) - for expert in range(num_experts): - num = tokens_per_expert[expert] - if num > 0: - torch.ops._C.silu_and_mul(tmp[:num], b_a[expert,:num,:] @ w1[expert].transpose(0, 1)) - out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1) - - return torch_combine(out, topk_weight, topk_ids) - - -# TODO: same as torch_moe but with fused_topk factored out. -def torch_moe2(a, w1, w2, topk_weight, topk_ids): - M, K = a.shape - topk = topk_ids.shape[1] - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) - num_experts = w1.shape[0] - for i in range(num_experts): - mask = (topk_ids == i).view(-1) - if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - - return (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - - -@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 511, 1024]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_fused_moe_batched_experts( - m: int, - n: int, - k: int, - e: int, - topk: int, - dtype: torch.dtype, -): - current_platform.seed_everything(7) - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - - score = torch.randn((m, e), device="cuda", dtype=dtype) - - vllm_config = VllmConfig() - with set_current_vllm_config(vllm_config): - topk_weight, topk_ids = fused_topk(a, score, topk, False) - - torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - - if True: - triton_output = torch_batched_moe(a, - w1, - w2, - topk_weight, - topk_ids) - else: - b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) - triton_output = fused_batched_experts( - b_a, - w1, - w2, - topk_weight, - topk_ids, - global_num_experts=e - ) - - if False: - torch.set_printoptions(profile="full") - print("BASELINE") - print(torch_output) - print("OUTPUT") - print(triton_output) - - torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) - - @pytest.mark.parametrize("m", [1, 32, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 1024]) @@ -587,7 +450,8 @@ def test_fused_marlin_moe( topk_weights, topk_ids, token_expert_indices = fused_topk( a, score, topk, False) - torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) + with set_current_vllm_config(vllm_config): + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) marlin_output = torch.ops.vllm.fused_marlin_moe( a, diff --git a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py index 44734e9340a..3b5838a99fa 100644 --- a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py +++ b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py @@ -7,6 +7,7 @@ import torch from vllm import _custom_ops as ops +from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.platforms import current_platform @@ -15,6 +16,10 @@ pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): """Matrix multiplication function that supports per-token input @@ -137,20 +142,21 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale score = torch.randn((M, E), dtype=dtype) - ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, # using fp8 - per_channel_quant=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=None, # Not using block quantization - ) + with set_current_vllm_config(vllm_config): + ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, # using fp8 + per_channel_quant=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=None, # Not using block quantization + ) # Check results rel_diff = (torch.mean( diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 11fb5000713..c06e1821c82 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -30,6 +30,10 @@ pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + # Test configurations DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] NUM_TOKENS = [7, 83, 2048] @@ -210,10 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) # Set the context to avoid lots of warning spam. - vllm_config = VllmConfig() - vllm_config.scheduler_config.max_num_seqs = 128 - vllm_config.scheduler_config.max_model_len = 8192 - with set_current_vllm_config(vllm_config): out = fused_moe( a, @@ -261,6 +261,7 @@ def per_block_cast_to_fp8( @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) +@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): # only aligned sizes @@ -426,26 +427,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - if True: - dgm = modular_deep_gemm_fused_moe_fp8() - - def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids): - return dgm(a, - w1, - w2, - topk_weights, - topk_ids, - w1_scale=w1_s, - w2_scale=w2_s) - else: - deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 - # Set the context to avoid lots of warning spam. - vllm_config = VllmConfig() - vllm_config.scheduler_config.max_num_seqs = 128 - vllm_config.scheduler_config.max_model_len = 8192 - with set_current_vllm_config(vllm_config): if M >= 128: ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, @@ -457,8 +439,8 @@ def deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_weights, topk_ids, token_expert_indices = fused_topk( a, score.float(), topk, False) - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids) + out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, + topk_ids) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 56b1b343c86..e3279cd37f2 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -24,7 +24,7 @@ def dispatch( topk_ids: torch.Tensor, num_experts: int, expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool = False, + apply_router_weight_on_input: bool, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: assert topk_ids.dim() == 2 assert topk_ids.shape[0] == a1.shape[0] @@ -99,8 +99,6 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - rank: int = 0, - world_size: int = 1, max_num_tokens: Optional[int] = None, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, @@ -116,8 +114,6 @@ def __init__( assert block_shape is None assert block_m is None self.max_num_tokens = max_num_tokens - self.rank = rank - self.world_size = world_size assert not use_fp8_w8a8, "NYI" assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" @@ -171,12 +167,6 @@ def apply( (num_experts, max_num_tokens, w2.shape[1])) num_local_experts = expert_num_tokens.numel() - # TODO: don't need world_size or rank if expert_base always == 0 - #assert w1.shape[0] == num_experts, f"{w1.shape} == {num_experts}" - #expert_base = rank_chunk(w1.shape[0], self.rank, - # self.world_size) * self.rank - expert_base = 0 - for expert in range(num_local_experts): num = expert_num_tokens[expert] assert num <= max_num_tokens, f"{num}, {max_num_tokens}" @@ -184,8 +174,7 @@ def apply( tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) self.activation( activation, tmp, hidden_states[expert, :num, :] - @ w1[expert_base + expert].transpose(0, 1)) - out[expert, :num, :] = tmp @ w2[expert_base + - expert].transpose(0, 1) + @ w1[expert].transpose(0, 1)) + out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) return out diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index e8e297a37d2..ec3587973e1 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -3,7 +3,6 @@ import functools import json import os -from math import prod from typing import Any, Callable, Dict, List, Optional, Tuple import torch @@ -28,13 +27,6 @@ logger = init_logger(__name__) -has_deep_gemm = False -try: - import deep_gemm as dg - has_deep_gemm = True -except ImportError: - pass - @triton.jit def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, @@ -493,7 +485,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, assert topk_weights is None or topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 - if use_fp8_w8a8: + if use_fp8_w8a8 or use_int8_w8a8: assert B_scale is not None assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0]) == B_scale.shape[-2]) @@ -510,20 +502,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, M = A.shape[0] num_tokens = M * top_k - if use_fp8_w8a8: - assert B_scale is not None - assert (block_shape is None or triton.cdiv(B.shape[-2], block_shape[0]) - == B_scale.shape[-2]) - assert (block_shape is None or triton.cdiv(B.shape[-1], block_shape[1]) - == B_scale.shape[-1]) - - elif use_int8_w8a16 or use_int4_w4a16: - assert B_scale is not None - assert block_shape is None or block_shape[0] == 0 - else: - assert A_scale is None - assert B_scale is None - EM = sorted_token_ids.shape[0] if A.shape[0] < config["BLOCK_SIZE_M"]: # optimize for small batch_size. @@ -1063,8 +1041,7 @@ def inplace_fused_experts_fake( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, - allow_deep_gemm: bool = False) -> None: + block_shape: Optional[List[int]] = None) -> None: pass @@ -1098,8 +1075,7 @@ def outplace_fused_experts( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, - allow_deep_gemm: bool = False) -> torch.Tensor: + block_shape: Optional[List[int]] = None) -> torch.Tensor: return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, False, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, @@ -1129,8 +1105,7 @@ def outplace_fused_experts_fake( w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None, - allow_deep_gemm: bool = False) -> torch.Tensor: + block_shape: Optional[List[int]] = None) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1214,7 +1189,6 @@ def fused_experts(hidden_states: torch.Tensor, w2=w2, topk_weights=topk_weights, topk_ids=topk_ids, - inplace=inplace, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, use_fp8_w8a8=use_fp8_w8a8, @@ -1302,6 +1276,19 @@ def fused_experts_impl(hidden_states: torch.Tensor, config = get_config_func(M) + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + cache13 = torch.empty(M * top_k_num * max(N, K), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N) + intermediate_cache3 = cache13[:M * top_k_num * K].view(M, top_k_num, K) + + # This needs separate memory since it's used concurrently with cache1 + intermediate_cache2 = torch.empty((M * top_k_num, N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + if hidden_states.dtype == torch.bfloat16: compute_type = tl.bfloat16 elif hidden_states.dtype == torch.float16: @@ -1316,50 +1303,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, else: out_hidden_states = torch.empty_like(hidden_states) - block_m = config['BLOCK_SIZE_M'] - assert not use_dg or block_m == dg.get_m_alignment_for_contiguous_layout() - - cache1_view: Tuple[int, ...] = () - cache2_view: Tuple[int, ...] = () - cache3_view: Tuple[int, ...] = () - - if use_dg: - assert w1_scale is not None - assert w2_scale is not None - - # We attempt to transpose and align offline in Fp8MoEMethod, in which - # case these calls will be nops. Otherwise, they'll be performed every - # time the layer is executed. - w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous() - w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous() - - M_sum = topk_ids.numel() + global_num_experts * (block_m - 1) - M_sum = round_up(M_sum, block_m) - - cache1_view = (M_sum, N) - cache3_view = (M_sum, K) - else: - M_sum = M * top_k_num - cache1_view = (M, top_k_num, N) - cache3_view = (M, top_k_num, K) - - num_chunks = (num_tokens // CHUNK_SIZE) + 1 - - # We can reuse the memory between cache1 and cache3 because by the time - # we need cache3, we're done with cache1 - cache13 = torch.empty(M_sum * max(N, K), - device=hidden_states.device, - dtype=hidden_states.dtype) - - intermediate_cache1 = cache13[:M_sum * N].view(*cache1_view) - intermediate_cache2 = torch.empty((M_sum, N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache3 = cache13[:M_sum * K].view(*cache3_view) - - needs_fp8_quantization = use_fp8_w8a8 or use_dg - - for chunk in range(num_chunks): + for chunk in range((num_tokens // CHUNK_SIZE) + 1): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, num_tokens)) @@ -1369,6 +1313,17 @@ def fused_experts_impl(hidden_states: torch.Tensor, if tokens_in_chunk == 0: break + if tokens_in_chunk < CHUNK_SIZE and chunk > 0: + # Adjust the intermediate cache size and config for the last + # chunk. Note that in most cases we only have one chunk + # so the cache size and config are already set correctly and + # do not need to be adjusted. + intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] + intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * + topk_ids.shape[1]] + intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] + config = get_config_func(tokens_in_chunk) + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] @@ -1380,8 +1335,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, block_shape=block_shape) sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, block_m, global_num_experts, - expert_map)) + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], + global_num_experts, expert_map)) invoke_fused_moe_kernel(qcurr_hidden_states, w1, @@ -1667,9 +1622,6 @@ def apply( raise ValueError( f"Unsupported compute_type: {hidden_states.dtype}") - #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}") - #print(f"BLOCK_M = {self.block_m}") - # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 intermediate_cache1 = _resize_cache(workspace13, @@ -1720,8 +1672,7 @@ def apply( per_channel_quant=self.per_channel_quant, block_shape=self.block_shape) - self.activation(activation, - intermediate_cache2, + self.activation(activation, intermediate_cache2, intermediate_cache1.view(-1, N)) a2q_scale: Optional[torch.Tensor] = None diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index ad48fe74dc3..eac1ae900d0 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -7,7 +7,7 @@ from enum import Enum from typing import Callable, List, Optional, Tuple -import pplx_kernels as pplx +import pplx_kernels as pplx # TODO: guard this import torch import torch.nn.functional as F from torch.nn.parameter import UninitializedParameter @@ -243,19 +243,20 @@ def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine assert self.fused_experts == fused_experts block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) - #print(f"block_m = {block_m}") if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)): logger.info(f"BatchedExperts {self.moe}") - experts = BatchedExperts() #rank=self.moe.ep_rank, world_size=self.moe.ep_size) + experts = BatchedExperts() else: logger.info(f"TritonExperts {self.moe}") experts = TritonExperts( use_fp8_w8a8 = False, + use_int8_w8a8 = False, use_int8_w8a16 = False, use_int4_w4a16 = False, block_shape = None, block_m = None, #block_m, + per_channel_quant = False, ) self.fused_experts = FusedMoEModularKernel( @@ -526,7 +527,7 @@ def __init__( # Use expert parallelism instead of tensor parallelism? vllm_config = get_current_vllm_config() use_ep = (vllm_config.parallel_config.enable_expert_parallel - and (self.tp_size * self.dp_size) > 1) + and self.tp_size * self.dp_size > 1) # For smuggling this layer into the fused moe custom op self.use_direct_call = self.dp_size == 1 @@ -557,7 +558,6 @@ def __init__( self.ep_size = 1 self.local_num_experts = self.global_num_experts self.expert_map = None - #self.global_num_experts = num_experts redundant? self.top_k = top_k assert intermediate_size % self.tp_size == 0 @@ -578,23 +578,20 @@ def __init__( if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " "non-grouped topk.") - if current_platform.is_hpu(): from vllm_hpu_extension.ops import DynamicFusedMOE self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) - #print(f"params dtype= {params_dtype}") - moe = MoEConfig( num_experts=self.global_num_experts, - experts_per_token=top_k, # ? must be same as topk_ids.shape[1] + experts_per_token=top_k, hidden_dim=hidden_size, num_local_experts=self.local_num_experts, dp_size=self.dp_size, dp_rank=self.dp_rank, ep_size=self.ep_size, ep_rank=self.ep_rank, - in_dtype = params_dtype, # this is probably not right, where to get? + in_dtype = params_dtype, # this is probably not right, where to get? out_dtype = params_dtype, # ditto. ) @@ -619,14 +616,6 @@ def __init__( dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. rank = moe.ep_rank - if False: - print(f"max num = {max_num_tokens}") - print(f"world size = {world_size}") - print(f"moe ep size = {moe.ep_size}") - print(f"moe dp size = {moe.dp_size}") - print(f"dp size = {dp_size}") - print(f"rank= {rank}") - all_to_all = get_all_to_all( max_num_tokens=max_num_tokens, num_experts=moe.num_experts, @@ -657,7 +646,7 @@ def __init__( rank, # just for debugging moe.in_dtype, ) - elif False: + elif True: logger.info("using standard dispatch") dispatch_combine = StandardDispatchCombine( moe.in_dtype, @@ -1013,7 +1002,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): if self.use_direct_call: return self.forward_impl(hidden_states, router_logits) - elif True: + else: return torch.ops.vllm.moe_forward(hidden_states, router_logits, self.layer_name) @@ -1023,11 +1012,9 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, ctx = get_forward_context() max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp - #cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu + cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu num_tokens_across_dp = ctx.dp_metadata.num_tokens_across_dp - #print(f"max/num/rank_num = {max_tokens_across_dp}/{num_tokens_across_dp}/{ctx.dp_metadata.dp_rank_num_tokens}") - #In this function we define two ranges: # 1. chunk_range - The current iteration of the loops's range over the DP world tokens # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP rank owns. @@ -1040,9 +1027,6 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_hidden_states.shape[0]) full_final_hidden_states = torch.empty_like(full_hidden_states) - #print(f"ORIGINAL SHAPE {full_hidden_states.shape}") - #print(f"moe_dp_chunk_size_per_rank = {moe_dp_chunk_size_per_rank}") - assert full_hidden_states.shape[0] == full_router_logits.shape[0] for iter in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): @@ -1054,8 +1038,6 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, max=moe_dp_chunk_size_per_rank), dim=0) - print(f"loop {iter}: {chunk_start}:{chunk_end}, {hidden_states.shape} {cu_tokens_across_dp_this_iter}") - hidden_states = self.naive_multicast( hidden_states, cu_tokens_across_dp_this_iter) router_logits = self.naive_multicast( @@ -1079,8 +1061,6 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, activation=self.activation, ) - #print(f"final1 = {final_hidden_states.shape}") - if self.dp_size > 1: start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[ self.dp_rank - 1] @@ -1090,27 +1070,19 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, final_hidden_states) final_hidden_states = all_hidden_states[start:end, :] - print(f"final2 (AR) = {final_hidden_states.shape}") - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) - print(f"final3 (AR) = {final_hidden_states.shape}") - full_final_hidden_states[chunk_start:chunk_end, :].copy_( final_hidden_states) - #print(f"partial final = {full_final_hidden_states.shape}") - # Update bounds num_tokens_remaining_across_dp = torch.clamp( num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank, min=0) - #print(f"num remaining = {num_tokens_remaining_across_dp}") - # HACK FIX if num_tokens_remaining_across_dp.sum() == 0: break @@ -1122,8 +1094,6 @@ def update_chunk_bound(x: int): chunk_start = update_chunk_bound(chunk_start) chunk_end = update_chunk_bound(chunk_end) - #print(f"full final shape {full_final_hidden_states.shape}") - return full_final_hidden_states def forward_impl(self, hidden_states: torch.Tensor, @@ -1131,7 +1101,6 @@ def forward_impl(self, hidden_states: torch.Tensor, assert self.quant_method is not None if self.dp_size > 1: - print("FORWARD_IMPL") ctx = get_forward_context() cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 35f8b829277..d550c8b040c 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -76,7 +76,6 @@ def _moe_problem_size( return E, M, N, K, topk - class FusedMoEQuantizeDispatchCombine(ABC): """ An abstract base class for the [Quantize-Dispatch] and [Combine] steps @@ -107,7 +106,8 @@ def dispatch( - num_experts: The total number of experts in the global expert space. - expert_map: A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard. - - apply_router_weight_on_input: When True, apply the weights to the activations, before quantization + dispatching. + - apply_router_weight_on_input: When True, apply the weights to the + activations, before quantization + dispatching. Returns a tuple of: - quantized + dispatched a. @@ -132,7 +132,8 @@ def combine( experts, it will have (M, topk, K) shape. - topk_weights: The weights to be applied to the fused_experts_output. - topk_ids: The topk_ids. - - apply_router_weight_on_input: When False, apply the weights to fused_expert_output. + - apply_router_weight_on_input: When False, apply the weights to + fused_expert_output. """ raise NotImplementedError @@ -312,14 +313,9 @@ def forward( Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ - #from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank) - #print(f"START {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}") - a1 = hidden_states E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids) - #print(f"INIT shape: E={E}, M={M}, N={N}, K={K}, top_k={top_k}") - if global_num_experts == -1: global_num_experts = E @@ -364,6 +360,4 @@ def forward( self.dispatch_combine.combine(output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input) - #print(f"DONE {hidden_states.shape} {topk_ids.shape} {get_tensor_model_parallel_rank()}/{get_dp_group().rank_in_group}") - return output diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 90bfa385dac..ef5da7a5d9e 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -9,8 +9,6 @@ moe_kernel_quantize_input) -logger = init_logger(__name__) - # Note use: layer.get_all_to_all() to get an AllToAll instance # The max_num_tokens, world_size and dp_size must be the same # as the ones used to create the AllToAll. @@ -46,7 +44,6 @@ def dispatch( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: # Is this always going to be a1.device? device = a1.device - num_tokens = a1.shape[0] # M hidden_dim = a1.shape[-1] # K assert expert_map is None, "NYI" @@ -75,18 +72,13 @@ def dispatch( dtype=torch.int32, device=device, ) - #expert_num_tokens.fill_(-1) # debugging, remove later num_dp = self.world_size // self.dp_size - logger.debug(f"GOT HERE A {self.rank}: {self.max_num_tokens} {num_dp} {hidden_dim}") expert_x = torch.empty( - (num_local_experts, self.max_num_tokens * num_dp, a1q.shape[-1]), + (num_local_experts, self.max_num_tokens * num_dp, hidden_dim), dtype=a1q.dtype, device=device, ) - #expert_x.fill_(0) #torch.nan # debugging, remove later - - logger.debug(f"GOT HERE B {self.rank}") expert_x_scale: Optional[torch.Tensor] = None if a1q.dtype.itemsize == 1: @@ -103,11 +95,10 @@ def dispatch( device=device, ) - logger.debug(f"GOT HERE C {self.rank}") - # This argument is optional, defaults to indices.shape[0] - # This causes a deadlock???? + # This causes a deadlock? #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens + #num_tokens = a1.shape[0] # M #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) bound_m = None @@ -133,23 +124,17 @@ def combine( topk_ids: torch.Tensor, apply_router_weight_on_input: bool, ) -> None: - device = fused_expert_output.device - #device = torch.device("cuda", self.rank) - #device = get_dp_group().device - #assert fused_expert_output.device == device - - logger.debug(f"COMBINE START {self.rank}") - # This argument is optional #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens #num_tokens = fused_expert_output.shape[0] # M - #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) + #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, + # device=fused_expert_output.device) bound_m = None assert output.shape[0] <= self.max_num_tokens assert output.shape[1] == fused_expert_output.shape[-1] - # Set weights to 1 if we did them in dispatch. This is hacky. + # Set weights to 1 if we did them in dispatch. This is hacky. if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) @@ -158,5 +143,3 @@ def combine( weights=topk_weights, expert_y=fused_expert_output, bound_m=bound_m) - - logger.debug(f"COMBINE END {self.rank}") From db3f01dd24ae0ce0e73ee951725295d74a52dd12 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 20:37:43 +0000 Subject: [PATCH 152/190] tweak Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 405ced54d2e..696a1cb4d60 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -523,13 +523,9 @@ def _pplx_moe( m, k = a.shape e, _, n = w2.shape - torch.set_printoptions(profile="full") - with set_current_vllm_config(vllm_config): topk_weight, topk_ids = fused_topk(a, score, topk, False) - torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - pplx_output = torch_pplx_moe(pgi, dp_size, a, w1, w2, score, topk) torch_output = chunk_by_rank(torch_output, pgi.rank, From 17a978ba8f0dc023d105f80c858025926ab647a7 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 21:22:58 +0000 Subject: [PATCH 153/190] revert hack Signed-off-by: Bill Nell --- examples/offline_inference/data_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 1c070105189..965915beaf5 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -160,7 +160,7 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, procs.append(proc) exit_code = 0 for proc in procs: - proc.join(timeout=3000) + proc.join(timeout=300) if proc.exitcode is None: print(f"Killing process {proc.pid} that " f"didn't stop within 5 minutes.") From 323d12ffad0fb809e04fe07c38413a3a72ed9af4 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 21:39:44 +0000 Subject: [PATCH 154/190] fixes Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 696a1cb4d60..b58c2d2c6d3 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -471,10 +471,9 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): pgi.world_size, dp_size, rank, - a.dtype, ) - experts = BatchedExperts(rank, pgi.world_size, max_num_tokens) + experts = BatchedExperts(max_num_tokens) fused_experts = FusedMoEModularKernel( dispatch_combine, From bb4e896553c4f658e728cc96c771ff1c1d5533e2 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 21:45:57 +0000 Subject: [PATCH 155/190] pplx update Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index b58c2d2c6d3..aeedadea385 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -300,7 +300,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): rank_num_tokens = rank_chunk(num_tokens, rank, world_size) max_num_tokens = num_tokens - ata = AllToAll( + ata = AllToAll.internode( max_num_tokens=max_num_tokens, num_experts=num_experts, experts_per_token=topk, @@ -448,7 +448,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): rank_num_tokens = rank_chunk(num_tokens, rank, world_size) max_num_tokens = num_tokens - ata = AllToAll( + ata = AllToAll.internode( max_num_tokens=max_num_tokens, num_experts=num_experts, experts_per_token=topk, From 1260147570d06bcdadf4d3f0e690f9cd99f5403c Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 22:17:50 +0000 Subject: [PATCH 156/190] varun's fixes Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 158 +++++ vllm/distributed/parallel_state.py | 4 +- .../layers/fused_moe/fused_batched_moe.py | 627 +++++++++++++++++- vllm/model_executor/layers/fused_moe/layer.py | 52 +- .../layers/fused_moe/pplx_dispatch_combine.py | 18 +- vllm/model_executor/models/deepseek_v2.py | 4 +- 6 files changed, 824 insertions(+), 39 deletions(-) create mode 100644 tests/kernels/moe/test_batched_moe.py diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py new file mode 100644 index 00000000000..ffd69935b46 --- /dev/null +++ b/tests/kernels/moe/test_batched_moe.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + +import pytest +from dataclasses import dataclass + +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + invoke_moe_batched_triton_kernel, + invoke_batched_silu_and_mul) + + +@dataclass +class BatchedMMConfig: + dtype: torch.dtype + num_experts: int + max_tokens_per_expert: int + K: int + N: int + +@dataclass +class BatchedMMTensors: + A: torch.Tensor # [E, max_tokens, K] + B: torch.Tensor # [E, K, N] - column major + C: torch.Tensor # [E, max_tokens, N] + num_expert_tokens: torch.Tensor # [E] + + @staticmethod + def make_tensors(config: BatchedMMConfig): + A = torch.randn((config.num_experts, config.max_tokens_per_expert, config.K), device="cuda", dtype=config.dtype) / 50.0 + B = torch.randn((config.num_experts, config.N, config.K), device="cuda", dtype=config.dtype) / 50.0 + C = torch.zeros((config.num_experts, config.max_tokens_per_expert, config.N), device="cuda", dtype=config.dtype) + num_expert_tokens=torch.randint(low = 0, high = config.max_tokens_per_expert, size=(config.num_experts,), device="cuda", dtype=torch.int32) + return BatchedMMTensors(A,B,C, num_expert_tokens) + + +def ref_impl(A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + num_expert_tokens: torch.Tensor) -> torch.Tensor: + + num_expert_tokens_cpu = num_expert_tokens.clone() + num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu") + num_experts = num_expert_tokens.size(0) + + for e in range(num_experts): + num_tokens = num_expert_tokens_cpu[e] + C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1) + + + return C + +@pytest.mark.parametrize("num_experts", [16, 32]) +@pytest.mark.parametrize("max_tokens_per_expert", [512]) +@pytest.mark.parametrize("K", [256]) +@pytest.mark.parametrize("N", [512]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_batched_mm(num_experts: int, + max_tokens_per_expert: int, + K: int, + N: int, + dtype: torch.dtype): + + config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N) + tensors = BatchedMMTensors.make_tensors(config) + + test_output = tensors.C + ref_output = test_output.clone() + + + compute_tl_dtype = {torch.float16 : tl.float16, + torch.bfloat16 : tl.bfloat16, + torch.float32 : tl.float32}[test_output.dtype] + invoke_moe_batched_triton_kernel(tensors.A, + tensors.B, + test_output, + tensors.num_expert_tokens, + compute_tl_dtype, + # Quantization data + None, + None, + None, + # Quantization schemes + False, + False, + False, + config = {"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 16}) + + + ref_output = ref_impl(tensors.A, tensors.B, ref_output, tensors.num_expert_tokens) + #torch.cuda.synchronize() + #print (f"ref output {ref_output}") + #print (f"test output {test_output}") + + torch.testing.assert_close(test_output, ref_output, atol=1e-3, rtol=1e-3) + + +@dataclass +class BatchedSiluMulConfig: + dtype: torch.dtype + num_experts: int + max_tokens_per_expert: int + D: int + +@dataclass +class BatchedSiluMulTensors: + input: torch.Tensor + output: torch.Tensor + expert_num_tokens: torch.Tensor + + @staticmethod + def make_tensors(config: BatchedSiluMulConfig): + input = torch.randn((config.num_experts, config.max_tokens_per_expert, config.D * 2), device="cuda", dtype=config.dtype) / 50.0 + output = torch.zeros((config.num_experts, config.max_tokens_per_expert, config.D), device="cuda", dtype=config.dtype) + num_expert_tokens=torch.randint(low = 0, high = config.max_tokens_per_expert, size=(config.num_experts,), device="cuda", dtype=torch.int32) + return BatchedSiluMulTensors(input, output, num_expert_tokens) + + +def ref_batched_silu_mul( + output: torch.Tensor, + input: torch.Tensor, + num_expert_tokens: torch.Tensor) -> torch.Tensor: + + num_expert_tokens_cpu = num_expert_tokens.clone() + num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu") + num_experts = num_expert_tokens.size(0) + + for e in range(num_experts): + num_tokens = num_expert_tokens_cpu[e].item() + out_part = output[e, :num_tokens, :] + in_part = input[e, :num_tokens, :] + torch.ops._C.silu_and_mul(out_part, in_part) + + +@pytest.mark.parametrize("num_experts", [16, 32]) +@pytest.mark.parametrize("max_tokens_per_expert", [128]) +@pytest.mark.parametrize("D", [128, 256]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_batched_silu_mul(num_experts: int, + max_tokens_per_expert: int, + D: int, + dtype: torch.dtype): + + config = BatchedSiluMulConfig(dtype, num_experts, max_tokens_per_expert, D) + tensors = BatchedSiluMulTensors.make_tensors(config) + + test_out = tensors.output + ref_out = torch.zeros_like(test_out) + + ref_batched_silu_mul(ref_out, tensors.input, tensors.expert_num_tokens) + + invoke_batched_silu_and_mul(test_out, tensors.input, tensors.expert_num_tokens) + + torch.testing.assert_close(test_out, ref_out) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 904f07dac5d..cf715681c87 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -923,12 +923,12 @@ def pplx_init(rank, world_size): if world_size > 1: try: global PPLX_DID_INIT - print(f"PPLX_INIT {rank} {world_size}") + logger.debug(f"PPLX_INIT {rank} {world_size}") uid = nvshmem_get_unique_id( ) if rank == 0 else nvshmem_alloc_empty_unique_id() uid_gpu = uid.cuda() get_world_group().broadcast(uid_gpu, src=0) - print(f"PPLX_INIT UID={uid_gpu}") + logger.debug(f"PPLX_INIT UID={uid_gpu}") uid = uid_gpu.to(device='cpu') nvshmem_init(uid, rank, world_size) PPLX_DID_INIT = True diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index e3279cd37f2..907670cbb7b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -3,9 +3,465 @@ from typing import List, Optional, Tuple import torch +import triton +import triton.language as tl import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.fused_moe.fused_moe import ( + get_config_dtype_str, + try_get_optimal_moe_config, +) + +@triton.jit +def batched_silu_and_mul_kernel(output, # [E, MAX_NUM_TOKENS, D] + input, # [E, MAX_NUM_TOKENS, D * 2] + expert_num_tokens, # [E] + stride_oe, + stride_om, + stride_ie, + stride_im, + compute_type: tl.constexpr, + D, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr): + + expert_id = tl.program_id(axis=0) + e_num_tokens = tl.load(expert_num_tokens + expert_id) + if e_num_tokens == 0: + # early exit + return + + pid_m = tl.program_id(axis=1) + cta_m_start = pid_m * BLOCK_M + if cta_m_start >= e_num_tokens: + # early exit + return + + cta_input_ptr = input + expert_id * stride_ie + cta_m_start * stride_im + cta_output_ptr = output + expert_id * stride_oe + cta_m_start * stride_om + + cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start) + offs_m = tl.arange(0, BLOCK_M)[:, None] + mask_m = offs_m < cta_m_size + + cta_input_ptrs = cta_input_ptr + offs_m * stride_im + cta_output_ptrs = cta_output_ptr + offs_m * stride_om + + # offset by D + offs_D = tl.arange(0, BLOCK_D) + cta_input_ptrs = cta_input_ptrs + offs_D + cta_output_ptrs = cta_output_ptrs + offs_D + + for d in range(0, tl.cdiv(D, BLOCK_D)): + mask_D = offs_D < (D - (d * BLOCK_D)) + mask_tile = mask_m & mask_D + + x_tile = tl.load(cta_input_ptrs, mask=mask_tile, other=0.0).to(dtype=tl.float32) + y_tile = tl.load(cta_input_ptrs + D, mask=mask_tile, other=0.0) + + # silu and mul + out_tile = (x_tile * (1.0 / (1.0 + tl.exp(-x_tile)))).to(dtype=compute_type) + out_tile = out_tile * y_tile + tl.store(cta_output_ptrs, out_tile, mask=mask_tile) + + cta_input_ptrs = cta_input_ptrs + BLOCK_D + cta_output_ptrs = cta_output_ptrs + BLOCK_D + +@triton.jit +def moe_mmk( + a_ptrs, + b_ptrs, + K, + expert_id, + a_scale_ptr, + b_scale_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ak, + stride_bk, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Offsets and masks + offs_m, + offs_n, + mask_m, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + compute_type: tl.constexpr, + use_w8a8: tl.constexpr, + use_w8a16: tl.constexpr): + + offs_k = tl.arange(0, BLOCK_K) + + if use_w8a16: + b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[ + None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + + if use_w8a8: + # block-wise + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + offs_m * stride_asm + offs_bsn = offs_n // group_n + b_scale_ptrs = (b_scale_ptr + expert_id * stride_bse + + offs_bsn * stride_bsn) + # tensor-wise + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + expert_id) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + a = tl.load(a_ptrs, + mask=mask_m[:, None] & + (offs_k[None, :] < K - k * BLOCK_K), + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_K, + other=0.0) + # We accumulate along the K dimension. + if use_w8a16: + accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) + elif use_w8a8: + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_K + offs_ks = k_start // group_k + a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask, + mask=mask_m, + other=0.0) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, + None] * b_scale[None, :] + else: + if use_w8a8: + # acc used to enable fp8_fast_accum + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + else: + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + + if use_w8a16: + accumulator = (accumulator * b_scale).to(compute_type) + elif use_w8a8: + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + + return accumulator + + +@triton.jit +def expert_triton_kernel(a_ptr, #[max_tokens, K] + b_ptr, #[K, N] + c_ptr, #[max_tokens, N] + expert_id, + compute_type: tl.constexpr, + # Dimensions + M, + N, + K, + # Quantization data + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # strides + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n, + group_k, + # Quantization schemes + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + # Kernel config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr): + + offs_m = tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) % N + offs_k = tl.arange(0, BLOCK_K) + mask_m = offs_m < M + + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn + + + accumulator = moe_mmk( + a_ptrs, + b_ptrs, + K, + expert_id, + a_scale_ptr, + b_scale_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ak, + stride_bk, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Offsets and masks + offs_m, + offs_n, + mask_m, + # Block size for block-wise quantization + group_n, + group_k, + # Meta-parameters + BLOCK_M, + BLOCK_N, + BLOCK_K, + compute_type, + use_fp8_w8a8, + use_int8_w8a16) + + # store in C + offs_cn = tl.arange(0, BLOCK_N) + c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn + c_mask = mask_m[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + +@triton.jit +def batched_triton_kernel(a_ptr, # [E, max_num_tokens, K] + b_ptr, # [E, K, N] + c_ptr, # [E, max_num_tokens, N] + expert_num_tokens, # [E] + compute_type: tl.constexpr, + # Dimensions + max_num_tokens, + K, + N, + # Quantization data + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ae, + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_ce, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n: tl.constexpr, + group_k: tl.constexpr, + # Quantization schemes + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + # Kernel config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr): + expert_id = tl.program_id(axis=0) + e_num_tokens = tl.load(expert_num_tokens + expert_id) + if e_num_tokens == 0: + # Early exit + return + + pid_mn = tl.program_id(axis=1) + num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_m = pid_mn // num_pid_n + pid_n = pid_mn % num_pid_n + + cta_m_start = pid_m * BLOCK_M + cta_n_start = pid_n * BLOCK_N + if cta_m_start >= e_num_tokens: + # Early exit + return + + cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start) + cta_n_size = min(BLOCK_N, N - cta_n_start) + + a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am + b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn + c_ptr = c_ptr + expert_id * stride_ce + cta_m_start * stride_cm + cta_n_start * stride_cn + + expert_triton_kernel(a_ptr, + b_ptr, + c_ptr, + expert_id, + compute_type, + cta_m_size, # M + cta_n_size, # N + K, # K + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # Strides + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n, + group_k, + # Quantization schemes + use_fp8_w8a8, + use_int8_w8a16, + # Kernel config + BLOCK_M, + BLOCK_N, + BLOCK_K) + + +def invoke_moe_batched_triton_kernel(A: torch.Tensor, # [E, max_tokens, K] + B: torch.Tensor, # [E, K, N] + C: torch.Tensor, # [E, max_tokens, N] + expert_num_tokens: torch.Tensor, # [E] + compute_type: tl.dtype, + # Quantization data + A_scale: torch.Tensor, + B_scale: torch.Tensor, + B_zp: torch.Tensor, + # Quantization schemes + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + config: dict[str, int], + block_shape: Optional[list[int]] = None): + + assert not use_int4_w4a16 + max_num_tokens = A.size(1) + K = A.size(2) + N = C.size(2) + + BLOCK_M = config['BLOCK_SIZE_M'] + BLOCK_N = config['BLOCK_SIZE_N'] + BLOCK_K = config['BLOCK_SIZE_K'] + assert max_num_tokens % BLOCK_M == 0 + + grid = (expert_num_tokens.size(0), + triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(B.shape[1], BLOCK_N)) + + batched_triton_kernel[grid](A, + B, + C, + expert_num_tokens, + compute_type, + # Dimensions + max_num_tokens, + K, + N, + # Quantization data + A_scale, + B_scale, + B_zp, + # Strides + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(0), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + # Blockwise quantization data + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + # Quantization schemes + use_fp8_w8a8, + use_int8_w8a16, + # Kernel config + BLOCK_M = BLOCK_M, + BLOCK_N = BLOCK_N, + BLOCK_K = BLOCK_K) + + +def invoke_batched_silu_and_mul(output : torch.Tensor, #[E, MAX_TOKENS, D] + input: torch.Tensor, #[E, MAX_TOKENS, D * 2] + expert_num_tokens: torch.Tensor): + + + num_experts = output.size(0) + max_num_tokens = output.size(1) + D = output.size(2) + + BLOCK_D = 1024 + BLOCK_M = 1 + + compute_tl_dtype = {torch.float16 : tl.float16, + torch.float32 : tl.float32, + torch.bfloat16 : tl.bfloat16}[output.dtype] + + #print(f"compute type {compute_tl_dtype}") + + grid = (num_experts, triton.cdiv(max_num_tokens, BLOCK_M)) + batched_silu_and_mul_kernel[grid](output, + input, + expert_num_tokens, + output.stride(0), + output.stride(1), + input.stride(0), + input.stride(1), + compute_tl_dtype, + D, + BLOCK_M, + BLOCK_D) class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): @@ -90,11 +546,6 @@ def combine( expert_counts[expert_id] = expert_counts[expert_id] + 1 -def rank_chunk(num, r, w): - rem = num % w - return (num // w) + (1 if r < rem else 0) - - class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( @@ -108,16 +559,13 @@ def __init__( block_m: Optional[int] = None, ): super().__init__() - assert not use_fp8_w8a8 - assert not use_int4_w4a16 - assert not use_int8_w8a16 assert block_shape is None assert block_m is None - self.max_num_tokens = max_num_tokens assert not use_fp8_w8a8, "NYI" assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" assert not use_int4_w4a16, "NYI" + self.max_num_tokens = max_num_tokens def workspace_shapes( self, @@ -178,3 +626,164 @@ def apply( out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) return out + + +class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + max_num_tokens: Optional[int] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + block_shape: Optional[List[int]] = None, + ): + super().__init__() + self.use_fp8_w8a8 = use_fp8_w8a8 + self.use_int8_w8a8 = use_int8_w8a8 + self.use_int4_w4a16 = use_int4_w4a16 + self.use_int8_w8a16 = use_int8_w8a16 + self.block_shape = block_shape + self.max_num_tokens = max_num_tokens + assert not use_int8_w8a8, "NYI" + assert not use_int4_w4a16, "NYI" + + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> Tuple[int, int, torch.dtype]: + max_num_tokens = a.shape[ + 1] if self.max_num_tokens is None else self.max_num_tokens + workspace13 = num_experts * max_num_tokens * max(K, N) + workspace2 = num_experts * max_num_tokens * (N // 2) + return (workspace13, workspace2, a.dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + + num_tokens = topk_ids.size(0) + #print_debug = expert_map[0] != -1 and num_tokens < 50 and num_tokens != 1 and False + + # Check constraints. + if self.use_int4_w4a16: + assert hidden_states.shape[-1] // 2 == w1.shape[ + 2], "Hidden size mismatch" + else: + assert hidden_states.shape[-1] == w1.shape[ + 2], f"Hidden size mismatch {hidden_states.shape[-1]} != {w1.shape[2]}" + + assert hidden_states.is_contiguous( + ), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn + ] + + E, num_tokens, N, K, top_k_num = mk._moe_problem_size( + hidden_states, w1, w2, topk_ids) + + assert w1.shape[0] == E + assert w2.shape[0] == E + + config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + dtype=hidden_states.dtype) + + config = try_get_optimal_moe_config( + w1.shape, + w2.shape, + top_k_num, + config_dtype, + num_tokens, + block_shape=self.block_shape, + ) + + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + elif hidden_states.dtype == torch.float8_e4m3fn: + compute_type = tl.bfloat16 + else: + raise ValueError( + f"Unsupported compute_type: {hidden_states.dtype}") + + #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}") + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N)) + intermediate_cache2 = _resize_cache(workspace2, + (E, num_tokens, N // 2)) + intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K)) + + # MM1 + invoke_moe_batched_triton_kernel(A=hidden_states, + B=w1, + C=intermediate_cache1, + expert_num_tokens=expert_num_tokens, + compute_type=compute_type, + A_scale=a1q_scale, + B_scale=w1_scale, + B_zp=w1_zp, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + config=config, + block_shape=self.block_shape) + + # Fix activations + assert activation == "silu" + invoke_batched_silu_and_mul(output=intermediate_cache2, + input=intermediate_cache1, + expert_num_tokens=expert_num_tokens) + + qintermediate_cache2 = intermediate_cache2 + a2q_scale = a2_scale + # TODO (varun) : support w8a8 + assert not self.use_fp8_w8a8 + #if self.use_fp8_w8a8: + # qintermediate_cache2, a2q_scale = _fp8_quantize( + # intermediate_cache2, a2_scale, self.block_shape) + + invoke_moe_batched_triton_kernel(A=intermediate_cache2, + B=w2, + C=intermediate_cache3, + expert_num_tokens=expert_num_tokens, + compute_type=compute_type, + A_scale=a2q_scale, + B_scale=w2_scale, + B_zp=w2_zp, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + config=config, + block_shape=self.block_shape) + + return intermediate_cache3 diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index eac1ae900d0..6147b17127f 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -29,7 +29,8 @@ if current_platform.is_cuda_alike(): from .dispatch_combine import StandardDispatchCombine - from .fused_moe import TritonExperts, BatchedDispatchCombine, BatchedExperts, fused_experts + from .fused_moe import TritonExperts, fused_experts + from .fused_batched_moe import BatchedDispatchCombine, BatchedTritonExperts from .modular_kernel import FusedMoEModularKernel, FusedMoEQuantizeDispatchCombine from .pplx_dispatch_combine import PplxDispatchCombine else: @@ -117,7 +118,8 @@ def get_or_create(self, **kwargs): with self._lock: instance = self._cache.get(key) if instance is None: - instance = pplx.AllToAll(**kwargs) + # TODO: should be intranode + instance = pplx.AllToAll.internode(**kwargs) self._cache[key] = instance return instance @@ -245,8 +247,14 @@ def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)): - logger.info(f"BatchedExperts {self.moe}") - experts = BatchedExperts() + logger.info(f"BatchedTritonExperts {self.moe}") + experts = BatchedTritonExperts( + use_fp8_w8a8 = False, + use_int8_w8a8 = False, + use_int8_w8a16 = False, + use_int4_w4a16 = False, + block_shape = None, + ) else: logger.info(f"TritonExperts {self.moe}") experts = TritonExperts( @@ -255,7 +263,6 @@ def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine use_int8_w8a16 = False, use_int4_w4a16 = False, block_shape = None, - block_m = None, #block_m, per_channel_quant = False, ) @@ -1038,10 +1045,12 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, max=moe_dp_chunk_size_per_rank), dim=0) - hidden_states = self.naive_multicast( - hidden_states, cu_tokens_across_dp_this_iter) - router_logits = self.naive_multicast( - router_logits, cu_tokens_across_dp_this_iter) + # TODO: still may be needed for non-pplx, put into dispatcher class. + if False: + hidden_states = self.naive_multicast( + hidden_states, cu_tokens_across_dp_this_iter) + router_logits = self.naive_multicast( + router_logits, cu_tokens_across_dp_this_iter) # Matrix multiply. final_hidden_states = self.quant_method.apply( @@ -1061,7 +1070,8 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, activation=self.activation, ) - if self.dp_size > 1: + # TODO: needed for non-pplx? + if False and self.dp_size > 1: start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[ self.dp_rank - 1] end = cu_tokens_across_dp_this_iter[self.dp_rank] @@ -1070,7 +1080,8 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, final_hidden_states) final_hidden_states = all_hidden_states[start:end, :] - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + # TODO: needed for non-pplx? + if False and self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) @@ -1091,8 +1102,14 @@ def update_chunk_bound(x: int): return min(x + moe_dp_chunk_size_per_rank, full_hidden_states.shape[0]) - chunk_start = update_chunk_bound(chunk_start) - chunk_end = update_chunk_bound(chunk_end) + #chunk_start = update_chunk_bound(chunk_start) + #chunk_end = update_chunk_bound(chunk_end) + if chunk_end == full_hidden_states.shape[0]: + # simply redo computation + pass + else: + chunk_start = update_chunk_bound(chunk_start) + chunk_end = update_chunk_bound(chunk_end) return full_final_hidden_states @@ -1100,7 +1117,8 @@ def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None - if self.dp_size > 1: + # TODO: still may be needed for non-pplx + if False and self.dp_size > 1: ctx = get_forward_context() cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu @@ -1128,7 +1146,8 @@ def forward_impl(self, hidden_states: torch.Tensor, apply_router_weight_on_input=self.apply_router_weight_on_input, ) - if self.dp_size > 1: + # TODO: needed for non-pplx? + if False and self.dp_size > 1: start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ self.dp_rank - 1] end = cu_tokens_across_dp_cpu[self.dp_rank] @@ -1136,7 +1155,8 @@ def forward_impl(self, hidden_states: torch.Tensor, all_hidden_states = get_dp_group().all_reduce(final_hidden_states) final_hidden_states = all_hidden_states[start:end, :] - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + # TODO: needed for non-pplx? + if False and self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index ef5da7a5d9e..576c454ec31 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -46,7 +46,8 @@ def dispatch( device = a1.device hidden_dim = a1.shape[-1] # K - assert expert_map is None, "NYI" + # ?? + # assert expert_map is None, "NYI" if apply_router_weight_on_input: topk = rank_topk_ids.shape[1] @@ -96,11 +97,8 @@ def dispatch( ) # This argument is optional, defaults to indices.shape[0] - # This causes a deadlock? - #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens - #num_tokens = a1.shape[0] # M - #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) - bound_m = None + num_tokens = a1.shape[0] # M + bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) # TODO: optimize this? indices = rank_topk_ids.to(dtype=torch.uint32) @@ -125,11 +123,9 @@ def combine( apply_router_weight_on_input: bool, ) -> None: # This argument is optional - #bound_m = get_forward_context().dp_metadata.dp_rank_num_tokens - #num_tokens = fused_expert_output.shape[0] # M - #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, - # device=fused_expert_output.device) - bound_m = None + num_tokens = output.shape[0] # M + bound_m = torch.tensor([num_tokens], dtype=torch.uint32, + device=fused_expert_output.device) assert output.shape[0] <= self.max_num_tokens assert output.shape[1] == fused_expert_output.shape[-1] diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index ce86b9b2c4f..9e9b8d336de 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -171,7 +171,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # See DeepseekV2DecoderLayer for more details. final_hidden_states = final_hidden_states + shared_output \ * (1. / self.routed_scaling_factor) - if self.tp_size > 1: + + # TODO: check if needed for non-pplx? + if False and self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) From cb7bec9db83c4bf5e7b350339c6426c138e8b5a8 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 22:23:35 +0000 Subject: [PATCH 157/190] varun's fixes Signed-off-by: Bill Nell --- .../layers/fused_moe/pplx_dispatch_combine.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 576c454ec31..f88044da020 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -123,9 +123,10 @@ def combine( apply_router_weight_on_input: bool, ) -> None: # This argument is optional - num_tokens = output.shape[0] # M - bound_m = torch.tensor([num_tokens], dtype=torch.uint32, - device=fused_expert_output.device) + #num_tokens = output.shape[0] # M + #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, + # device=fused_expert_output.device) + bound_m = None assert output.shape[0] <= self.max_num_tokens assert output.shape[1] == fused_expert_output.shape[-1] From a403f7cc6a2e728249462f6bcf13e04357e34b3e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 22:25:51 +0000 Subject: [PATCH 158/190] tweak bound_m Signed-off-by: Bill Nell --- .../layers/fused_moe/pplx_dispatch_combine.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index f88044da020..576c454ec31 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -123,10 +123,9 @@ def combine( apply_router_weight_on_input: bool, ) -> None: # This argument is optional - #num_tokens = output.shape[0] # M - #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, - # device=fused_expert_output.device) - bound_m = None + num_tokens = output.shape[0] # M + bound_m = torch.tensor([num_tokens], dtype=torch.uint32, + device=fused_expert_output.device) assert output.shape[0] <= self.max_num_tokens assert output.shape[1] == fused_expert_output.shape[-1] From 501d04cfac177bcc2e464df457fc661f9b6ea87c Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 22:59:42 +0000 Subject: [PATCH 159/190] run linter Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 134 ++-- tests/kernels/moe/test_moe.py | 3 +- tests/kernels/quantization/test_block_fp8.py | 5 +- tests/kernels/test_block_fp8.py | 499 ------------- tests/kernels/test_pplx_moe.py | 654 ------------------ vllm/forward_context.py | 9 +- .../layers/fused_moe/deep_gemm_moe.py | 4 +- .../layers/fused_moe/fused_batched_moe.py | 500 ++++++------- .../layers/fused_moe/fused_moe.py | 241 ++----- vllm/model_executor/layers/fused_moe/layer.py | 97 +-- .../layers/fused_moe/modular_kernel.py | 4 +- .../layers/fused_moe/pplx_dispatch_combine.py | 12 +- .../layers/fused_moe/triton_deep_gemm_moe.py | 40 +- .../model_executor/layers/quantization/fp8.py | 26 +- 14 files changed, 472 insertions(+), 1756 deletions(-) delete mode 100644 tests/kernels/test_block_fp8.py delete mode 100644 tests/kernels/test_pplx_moe.py diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index ffd69935b46..1bb8f4e09dd 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -1,15 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 -import torch -import triton -import triton.language as tl +from dataclasses import dataclass import pytest -from dataclasses import dataclass +import torch +import triton.language as tl from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - invoke_moe_batched_triton_kernel, - invoke_batched_silu_and_mul) + invoke_batched_silu_and_mul, invoke_moe_batched_triton_kernel) @dataclass @@ -20,25 +18,36 @@ class BatchedMMConfig: K: int N: int + @dataclass class BatchedMMTensors: A: torch.Tensor # [E, max_tokens, K] B: torch.Tensor # [E, K, N] - column major C: torch.Tensor # [E, max_tokens, N] - num_expert_tokens: torch.Tensor # [E] + num_expert_tokens: torch.Tensor # [E] @staticmethod def make_tensors(config: BatchedMMConfig): - A = torch.randn((config.num_experts, config.max_tokens_per_expert, config.K), device="cuda", dtype=config.dtype) / 50.0 - B = torch.randn((config.num_experts, config.N, config.K), device="cuda", dtype=config.dtype) / 50.0 - C = torch.zeros((config.num_experts, config.max_tokens_per_expert, config.N), device="cuda", dtype=config.dtype) - num_expert_tokens=torch.randint(low = 0, high = config.max_tokens_per_expert, size=(config.num_experts,), device="cuda", dtype=torch.int32) - return BatchedMMTensors(A,B,C, num_expert_tokens) - - -def ref_impl(A: torch.Tensor, - B: torch.Tensor, - C: torch.Tensor, + A = torch.randn( + (config.num_experts, config.max_tokens_per_expert, config.K), + device="cuda", + dtype=config.dtype) / 50.0 + B = torch.randn((config.num_experts, config.N, config.K), + device="cuda", + dtype=config.dtype) / 50.0 + C = torch.zeros( + (config.num_experts, config.max_tokens_per_expert, config.N), + device="cuda", + dtype=config.dtype) + num_expert_tokens = torch.randint(low=0, + high=config.max_tokens_per_expert, + size=(config.num_experts, ), + device="cuda", + dtype=torch.int32) + return BatchedMMTensors(A, B, C, num_expert_tokens) + + +def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, num_expert_tokens: torch.Tensor) -> torch.Tensor: num_expert_tokens_cpu = num_expert_tokens.clone() @@ -49,19 +58,16 @@ def ref_impl(A: torch.Tensor, num_tokens = num_expert_tokens_cpu[e] C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1) - return C + @pytest.mark.parametrize("num_experts", [16, 32]) @pytest.mark.parametrize("max_tokens_per_expert", [512]) @pytest.mark.parametrize("K", [256]) @pytest.mark.parametrize("N", [512]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_batched_mm(num_experts: int, - max_tokens_per_expert: int, - K: int, - N: int, - dtype: torch.dtype): +def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, + N: int, dtype: torch.dtype): config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N) tensors = BatchedMMTensors.make_tensors(config) @@ -69,29 +75,33 @@ def test_batched_mm(num_experts: int, test_output = tensors.C ref_output = test_output.clone() - - compute_tl_dtype = {torch.float16 : tl.float16, - torch.bfloat16 : tl.bfloat16, - torch.float32 : tl.float32}[test_output.dtype] - invoke_moe_batched_triton_kernel(tensors.A, - tensors.B, - test_output, - tensors.num_expert_tokens, - compute_tl_dtype, - # Quantization data - None, - None, - None, - # Quantization schemes - False, - False, - False, - config = {"BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 16}) - - - ref_output = ref_impl(tensors.A, tensors.B, ref_output, tensors.num_expert_tokens) + compute_tl_dtype = { + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, + torch.float32: tl.float32 + }[test_output.dtype] + invoke_moe_batched_triton_kernel( + tensors.A, + tensors.B, + test_output, + tensors.num_expert_tokens, + compute_tl_dtype, + # Quantization data + None, + None, + None, + # Quantization schemes + False, + False, + False, + config={ + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 16 + }) + + ref_output = ref_impl(tensors.A, tensors.B, ref_output, + tensors.num_expert_tokens) #torch.cuda.synchronize() #print (f"ref output {ref_output}") #print (f"test output {test_output}") @@ -106,6 +116,7 @@ class BatchedSiluMulConfig: max_tokens_per_expert: int D: int + @dataclass class BatchedSiluMulTensors: input: torch.Tensor @@ -114,16 +125,24 @@ class BatchedSiluMulTensors: @staticmethod def make_tensors(config: BatchedSiluMulConfig): - input = torch.randn((config.num_experts, config.max_tokens_per_expert, config.D * 2), device="cuda", dtype=config.dtype) / 50.0 - output = torch.zeros((config.num_experts, config.max_tokens_per_expert, config.D), device="cuda", dtype=config.dtype) - num_expert_tokens=torch.randint(low = 0, high = config.max_tokens_per_expert, size=(config.num_experts,), device="cuda", dtype=torch.int32) + input = torch.randn( + (config.num_experts, config.max_tokens_per_expert, config.D * 2), + device="cuda", + dtype=config.dtype) / 50.0 + output = torch.zeros( + (config.num_experts, config.max_tokens_per_expert, config.D), + device="cuda", + dtype=config.dtype) + num_expert_tokens = torch.randint(low=0, + high=config.max_tokens_per_expert, + size=(config.num_experts, ), + device="cuda", + dtype=torch.int32) return BatchedSiluMulTensors(input, output, num_expert_tokens) -def ref_batched_silu_mul( - output: torch.Tensor, - input: torch.Tensor, - num_expert_tokens: torch.Tensor) -> torch.Tensor: +def ref_batched_silu_mul(output: torch.Tensor, input: torch.Tensor, + num_expert_tokens: torch.Tensor) -> torch.Tensor: num_expert_tokens_cpu = num_expert_tokens.clone() num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu") @@ -140,10 +159,8 @@ def ref_batched_silu_mul( @pytest.mark.parametrize("max_tokens_per_expert", [128]) @pytest.mark.parametrize("D", [128, 256]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_batched_silu_mul(num_experts: int, - max_tokens_per_expert: int, - D: int, - dtype: torch.dtype): +def test_batched_silu_mul(num_experts: int, max_tokens_per_expert: int, D: int, + dtype: torch.dtype): config = BatchedSiluMulConfig(dtype, num_experts, max_tokens_per_expert, D) tensors = BatchedSiluMulTensors.make_tensors(config) @@ -153,6 +170,7 @@ def test_batched_silu_mul(num_experts: int, ref_batched_silu_mul(ref_out, tensors.input, tensors.expert_num_tokens) - invoke_batched_silu_and_mul(test_out, tensors.input, tensors.expert_num_tokens) + invoke_batched_silu_and_mul(test_out, tensors.input, + tensors.expert_num_tokens) torch.testing.assert_close(test_out, ref_out) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 976234083fd..85df55d0ac1 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -15,8 +15,7 @@ torch_moe_single) from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, moe_align_block_size) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index c06e1821c82..ef1d7e47ef8 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm_shape, deep_gemm_moe_fp8, modular_deep_gemm_fused_moe_fp8) + _valid_deep_gemm_shape, deep_gemm_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) @@ -439,8 +439,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): topk_weights, topk_ids, token_expert_indices = fused_topk( a, score.float(), topk, False) - out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids) + out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py deleted file mode 100644 index 762d0239408..00000000000 --- a/tests/kernels/test_block_fp8.py +++ /dev/null @@ -1,499 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# Adapted from https://github.com/sgl-project/sglang/pull/2575 -import itertools - -import pytest -import torch - -from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm_shape, deep_gemm_moe_fp8) -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk -from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( - moe_align_block_size) -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8, w8a8_block_fp8_matmul) -from vllm.platforms import current_platform - -dg_available = False -try: - import deep_gemm - dg_available = True -except ImportError: - pass - -if current_platform.get_device_capability() < (9, 0): - pytest.skip("FP8 Triton requires CUDA 9.0 or higher", - allow_module_level=True) - -# Test configurations -DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] -NUM_TOKENS = [7, 83, 2048] -D = [512, 4096, 5120, 13824] -GROUP_SIZE = [64, 128, 256, 512] -M = [1, 7, 8, 83, 84, 512, 2048, 4096] -N = [128, 512, 1024, 4096, 7168, 7748, 13824] -K = [256, 4096, 5120, 3884, 13824, 16384] -# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 -# and its hidden size is 7168. -M_moe = [1, 2, 7, 83, 128, 512, 2048] -M_moe_dg = [1, 128, 192, 512, 1335, 2048] -N_moe = [128, 256, 4608] # [13824] -K_moe = [256, 512, 7168] # [13824] -BLOCK_SIZE = [[128, 128]] -E = [2, 8, 16, 24] # [128, 256] -TOP_KS = [1, 2, 6] -OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16] -SEEDS = [0] - - -def native_per_token_group_quant_fp8(x, - group_size, - eps=1e-10, - dtype=torch.float8_e4m3fn): - """Function to perform per-token-group quantization on an input tensor - `x` using native torch.""" - assert x.shape[-1] % group_size == 0, ("the last dimension of `x` cannot " - "be divisible by `group_size`") - assert x.is_contiguous(), "`x` is not contiguous" - - finfo = torch.finfo(dtype) - fp8_min = finfo.min - fp8_max = finfo.max - - x_ = x.reshape(x.numel() // group_size, group_size) - amax = x_.abs().max(dim=-1, - keepdim=True)[0].clamp(min=eps).to(torch.float32) - x_s = amax / fp8_max - x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype) - x_q = x_q.reshape(x.shape) - x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) - - return x_q, x_s - - -def native_w8a8_block_fp8_matmul(A, - B, - As, - Bs, - block_size, - output_dtype=torch.float16): - """Matrix multiplication with block-wise quantization using native torch.""" - A = A.to(torch.float32) - B = B.to(torch.float32) - assert A.shape[-1] == B.shape[-1] - assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 - assert len(block_size) == 2 - block_n, block_k = block_size[0], block_size[1] - assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] - assert A.shape[:-1] == As.shape[:-1] - - M = A.numel() // A.shape[-1] - N, K = B.shape - origin_C_shape = A.shape[:-1] + (N, ) - A = A.reshape(M, A.shape[-1]) - As = As.reshape(M, As.shape[-1]) - n_tiles = (N + block_n - 1) // block_n - k_tiles = (K + block_k - 1) // block_k - assert n_tiles == Bs.shape[0] - assert k_tiles == Bs.shape[1] - - C_shape = (M, N) - C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) - - A_tiles = [ - A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) - ] - B_tiles = [[ - B[ - j * block_n:min((j + 1) * block_n, N), - i * block_k:min((i + 1) * block_k, K), - ] for i in range(k_tiles) - ] for j in range(n_tiles)] - C_tiles = [ - C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) - ] - As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] - - for i in range(k_tiles): - for j in range(n_tiles): - a = A_tiles[i] - b = B_tiles[j][i] - c = C_tiles[j] - s = As_tiles[i] * Bs[j][i] - c[:, :] += torch.matmul(a, b.t()) * s - - C = C.reshape(origin_C_shape).to(output_dtype) - return C - - -def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): - """Fused moe with block-wise quantization using native torch.""" - B, D = a.shape - a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) - out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) - score = torch.softmax(score, dim=-1, dtype=torch.float32) - topk_weight, topk_ids = torch.topk(score, topk) - topk_weight = topk_weight.view(-1) - topk_ids = topk_ids.view(-1) - - _, block_k = block_shape[0], block_shape[1] - a_q, a_s = native_per_token_group_quant_fp8(a, block_k) - a_q = a_q.to(torch.float32) - for i in range(w1.shape[0]): - mask = topk_ids == i - if mask.sum(): - inter_out = native_w8a8_block_fp8_matmul(a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - block_shape, - output_dtype=a.dtype) - act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = native_per_token_group_quant_fp8( - act_out, block_k) - act_out = act_out.to(torch.float32) - out[mask] = native_w8a8_block_fp8_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - block_shape, - output_dtype=a.dtype) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) - - -# Skip all tests if CUDA is not available -pytest.importorskip("torch.cuda") - - -@pytest.fixture(autouse=True) -def setup_cuda(): - torch.set_default_device("cuda") - - -@pytest.mark.parametrize( - "num_tokens,d,dtype,group_size,seed", - itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS)) -@torch.inference_mode() -def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): - torch.manual_seed(seed) - x = torch.rand(num_tokens, d, dtype=dtype) - - ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size) - out, scale = per_token_group_quant_fp8(x, group_size) - - assert torch.allclose(out.to(torch.float32), - ref_out.to(torch.float32), - rtol=0.15) - assert torch.allclose(scale, ref_scale) - - -@pytest.mark.parametrize( - "M,N,K,block_size,out_dtype,seed", - itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) -@torch.inference_mode() -def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): - torch.manual_seed(seed) - factor_for_scale = 1e-2 - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max - A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - - B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max - B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - - block_n, block_k = block_size[0], block_size[1] - n_tiles = (N + block_n - 1) // block_n - k_tiles = (K + block_k - 1) // block_k - - As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale - Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale - - ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, - out_dtype) - out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - assert rel_diff < 0.001 - - -@pytest.mark.parametrize( - "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, - SEEDS)) -@torch.inference_mode() -def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): - if topk > E: - pytest.skip(f"Skipping test; topk={topk} > E={E}") - - torch.manual_seed(seed) - factor_for_scale = 1e-2 - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - vllm_config = VllmConfig() - - a = torch.randn((M, K), dtype=dtype) / 10 - - w1_bf16 = (torch.rand( - (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - del w1_bf16 - - w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max - w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) - del w2_bf16 - - block_n, block_k = block_size[0], block_size[1] - n_tiles_w1 = (2 * N + block_n - 1) // block_n - n_tiles_w2 = (K + block_n - 1) // block_n - k_tiles_w1 = (K + block_k - 1) // block_k - k_tiles_w2 = (N + block_k - 1) // block_k - - w1_s = torch.rand( - (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale - w2_s = torch.rand( - (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale - - score = torch.randn((M, E), dtype=dtype) - - # Set the context to avoid lots of warning spam. - with set_current_vllm_config(vllm_config): - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=block_size, - ) - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, - block_size) - - #print(f"{out.sum()=}") - #print(f"{ref_out.sum()=}") - - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - assert rel_diff < 0.03 - - -def per_block_cast_to_fp8( - x: torch.Tensor, - block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - x_padded = torch.zeros( - (deep_gemm.ceil_div(m, 128) * 128, - deep_gemm.ceil_div(n, block_size_n) * block_size_n), - dtype=x.dtype, - device=x.device) - x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) - x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() - scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) - return x_scaled_sub, scales - - -@pytest.mark.parametrize( - "M,N,K,block_size,out_dtype,seed", - itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) -@torch.inference_mode() -def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): - # only aligned sizes - if M % 4 != 0 or K % 128 != 0 or N % 64 != 0: - pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}") - - torch.manual_seed(seed) - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max = fp8_info.max - - A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max - B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max - - _, block_k = block_size[0], block_size[1] - - A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k) - B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32) - - As = As_fp8.to(torch.float32) - Bs = Bs_fp8.to(torch.float32) - - ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, - out_dtype) - - # Transpose earlier so that the testing will not trigger transposing kernels - As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8) - - out = torch.zeros((M, N), device='cuda', dtype=out_dtype) - - assert As_fp8.shape == (M, (K + 127) // - 128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}" - - deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out) - - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - assert rel_diff < 0.001 - - -def fp8_perm(m, idx): - if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8: - return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype) - else: - return m[idx, ...] - - -def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): - M, K = a.shape - - sorted_token_ids, m_indices, num_pad = moe_align_block_size( - topk_ids, block_m, num_groups, None, pad_sorted_ids=True) - - num_tokens = topk * M - - sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) - m_indices = torch.repeat_interleave(m_indices, block_m, dim=0) - inv_perm = torch.argsort(sorted_token_ids)[:M * topk] - - a = fp8_perm(a, sorted_token_ids // topk) - if a_s is not None: - a_s = a_s[sorted_token_ids // topk] - - return a, a_s, m_indices, inv_perm - - -def _moe_unpermute(out, inv_perm, topk, K, topk_weight): - M = topk_weight.shape[0] - out = out[inv_perm, ...] - tmp_out = out.view(-1, topk, K) - return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - - -def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, - block_shape): - """Fused moe with block-wise quantization using DeepGemm grouped gemm.""" - num_groups = w1.shape[0] - M, K = a.shape - N = w2.shape[-1] - - topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) - - block_m = deep_gemm.get_m_alignment_for_contiguous_layout() - - _, block_k = block_shape[0], block_shape[1] - - a_q, a_s = per_token_group_quant_fp8(a, block_m) - - a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids, - num_groups, topk, block_m) - - inter_out = torch.zeros((a_q.shape[0], N * 2), - dtype=torch.bfloat16, - device=a.device) - - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), - inter_out, m_indices) - - act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) - - out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device) - - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - - final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight) - - return final_out - - -@pytest.mark.parametrize( - "M,N,K,E,topk,block_size,dtype,seed", - itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES, - SEEDS)) -@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") -@torch.inference_mode() -def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size, - dtype, seed): - - if topk > E: - pytest.skip(f"Skipping test: topk={topk} > E={E}") - - if not _valid_deep_gemm_shape(M, N, K): - pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") - - vllm_config = VllmConfig() - - torch.manual_seed(seed) - fp8_info = torch.finfo(torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - - a = torch.randn((M, K), dtype=dtype) / 10 - - w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * - fp8_max).clamp(min=fp8_min, max=fp8_max) - - w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * - fp8_max).clamp(min=fp8_min, max=fp8_max) - - score = torch.randn((M, E), dtype=dtype) - - block_n, block_k = block_size[0], block_size[1] - n_tiles_w1 = ((2 * N) + block_n - 1) // block_n - k_tiles_w1 = (K + block_k - 1) // block_k - n_tiles_w2 = (K + block_n - 1) // block_n - k_tiles_w2 = (N + block_k - 1) // block_k - - w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) - w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) - - w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) - w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) - - w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() - w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() - - assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128) - assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] - - for i in range(E): - w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) - w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) - - # Set the context to avoid lots of warning spam. - with set_current_vllm_config(vllm_config): - if M >= 128: - ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, - score, topk, block_size) - else: - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, - topk, block_size) - - topk_weights, topk_ids = fused_topk(a, score.float(), topk, False) - - out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) - - #print(f"{out.sum()=}") - #print(f"{ref_out.sum()=}") - - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) - assert rel_diff < 0.03 diff --git a/tests/kernels/test_pplx_moe.py b/tests/kernels/test_pplx_moe.py deleted file mode 100644 index 97fc74e3bd3..00000000000 --- a/tests/kernels/test_pplx_moe.py +++ /dev/null @@ -1,654 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Tests for the MOE layers. - -Run `pytest tests/kernels/test_pplx_moe.py`. -""" -import dataclasses -import os -import pytest -import torch -import traceback - -from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] -from typing import Callable, Concatenate, Optional, ParamSpec, Tuple - -from pplx_kernels import AllToAll -from pplx_kernels.nvshmem import ( - nvshmem_alloc_empty_unique_id, - nvshmem_finalize, - nvshmem_get_unique_id, - nvshmem_init, -) - -import vllm.model_executor.layers.fused_moe # noqa -from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, - torch_moe, torch_moe_single) -from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, moe_align_block_size) -from vllm.platforms import current_platform - -from vllm.model_executor.layers.activation import SiluAndMul - -from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts, BatchedDispatchCombine, BatchedExperts, fused_experts -from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel -from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import PplxDispatchCombine - -NUM_EXPERTS = [8, 64] -EP_SIZE = [1, 4] -TOP_KS = [2, 6] - -P = ParamSpec("P") - -require_multi_node = pytest.mark.skipif( - "MASTER_ADDR" not in os.environ, - reason="Requires multi-node environment", -) - - -@dataclasses.dataclass -class ProcessGroupInfo: - world_size: int - world_local_size: int - rank: int - node_rank: int - local_rank: int - device: torch.device - - -def _worker_parallel_launch( - local_rank: int, - world_size: int, - world_local_size: int, - node_rank: int, - init_method: str, - worker: Callable[Concatenate[ProcessGroupInfo, P], None], - *args: P.args, - **kwargs: P.kwargs, -) -> None: - rank = node_rank * world_local_size + local_rank - torch.cuda.set_device(local_rank) - device = torch.device("cuda", local_rank) - torch.distributed.init_process_group( - backend="cpu:gloo,cuda:nccl", - init_method=init_method, - rank=rank, - world_size=world_size, - device_id=device, - ) - barrier = torch.tensor([rank], device=device) - torch.distributed.all_reduce(barrier) - - try: - worker( - ProcessGroupInfo( - world_size=world_size, - world_local_size=world_local_size, - rank=rank, - node_rank=node_rank, - local_rank=local_rank, - device=device, - ), - *args, - **kwargs, - ) - except Exception as ex: - print(ex) - traceback.print_exception(ex) - raise - finally: - torch.distributed.destroy_process_group() - - -def parallel_launch( - world_size: int, - worker: Callable[Concatenate[ProcessGroupInfo, P], None], - *args: P.args, - **kwargs: P.kwargs, -) -> None: - assert not kwargs - spawn( - _worker_parallel_launch, - args=( - world_size, - world_size, - 0, - "tcp://localhost:29500", - worker, - ) - + args, - nprocs=world_size, - join=True, - ) - - -def parallel_launch_from_env( - worker: Callable[Concatenate[ProcessGroupInfo, P], None], - *args: P.args, - **kwargs: P.kwargs, -) -> None: - """ - Launches a worker function in parallel across all processes in the current - environment. The environment must have the following variables set: - - WORLD_SIZE: The total number of processes. - - WORLD_LOCAL_SIZE: The number of processes on the current node. - - NODE_RANK: The rank of the current - - MASTER_ADDR: The address of the master process. - - MASTER_PORT: The port of the master process. - """ - assert not kwargs - world_size = int(os.environ["WORLD_SIZE"]) - world_local_size = int(os.environ["WORLD_LOCAL_SIZE"]) - node_rank = int(os.environ["NODE_RANK"]) - assert "MASTER_ADDR" in os.environ - assert "MASTER_PORT" in os.environ - spawn( - _worker_parallel_launch, - args=( - world_size, - world_local_size, - node_rank, - "env://", - worker, - ) - + args, - nprocs=world_local_size, - join=True, - ) - - -def torch_dispatch( - a: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - max_num_tokens: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - assert topk_ids.dim() == 2 - assert topk_ids.shape[0] == a.shape[0] - - num_tokens = a.shape[0] - topk = topk_ids.shape[1] - - tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) - if max_num_tokens is None: - max_num_tokens = tokens_per_expert.max() - - b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), - dtype=a.dtype, device=a.device) - - #print(f"b_a shape {b_a.shape}") - - token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) - - for token in range(num_tokens): - for j in range(topk): - expert_id = topk_ids[token, j] - idx = token_counts[expert_id] - b_a[expert_id, idx:idx+1, :] = a[token, :] - token_counts[expert_id] = token_counts[expert_id] + 1 - - return b_a, tokens_per_expert - - -def torch_combine(b_out, topk_weight, topk_ids): - num_tokens, topk = topk_ids.shape - num_experts = b_out.shape[0] - K = b_out.shape[-1] - out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) - expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) - for token in range(num_tokens): - expert_ids = topk_ids[token] - for i in range(expert_ids.numel()): - expert_id = expert_ids[i] - idx = expert_counts[expert_id] - out[token, :] = out[token, :] + b_out[expert_id, idx:idx+1, :] * topk_weight[token, i] - expert_counts[expert_id] = expert_counts[expert_id] + 1 - - return out - - -def torch_batched_moe(a, w1, w2, topk_weight, topk_ids): - num_experts = w1.shape[0] - b_a, tokens_per_expert = torch_dispatch(a, topk_ids, num_experts) - assert b_a.dim() == 3 - num_tokens, topk = topk_ids.shape - _, max_num_tokens, K = b_a.shape - assert num_experts == b_a.shape[0] and K == w2.shape[1] - out = torch.zeros((num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device) - tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device) - for expert in range(num_experts): - num = tokens_per_expert[expert] - if num > 0: - torch.ops._C.silu_and_mul(tmp[:num], b_a[expert,:num,:] @ w1[expert].transpose(0, 1)) - out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1) - - return torch_combine(out, topk_weight, topk_ids) - - -# TODO: same as torch_moe but with fused_topk factored out. -def torch_moe2(a, w1, w2, topk_weight, topk_ids): - M, K = a.shape - topk = topk_ids.shape[1] - a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) - out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) - num_experts = w1.shape[0] - for i in range(num_experts): - mask = (topk_ids == i).view(-1) - if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - - return (out.view(M, -1, w2.shape[1]) * - topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) - - -@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 511, 1024]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_fused_moe_batched_experts( - m: int, - n: int, - k: int, - e: int, - topk: int, - dtype: torch.dtype, -): - current_platform.seed_everything(7) - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device="cuda", dtype=dtype) - - vllm_config = VllmConfig() - with set_current_vllm_config(vllm_config): - topk_weight, topk_ids = fused_topk(a, score, topk, False) - - torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - - if True: - triton_output = torch_batched_moe(a, - w1, - w2, - topk_weight, - topk_ids) - else: - b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e) - triton_output = fused_batched_experts( - b_a, - w1, - w2, - topk_weight, - topk_ids, - global_num_experts=e - ) - - if False: - torch.set_printoptions(profile="full") - print("BASELINE") - print(torch_output) - print("OUTPUT") - print(triton_output) - - torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) - - -def rank_chunk(num, r, w): - rem = num % w - return (num // w) + (1 if r < rem else 0) - - -def chunk_by_rank(t, r, w): - chunk = rank_chunk(t.shape[0], r, w) - #print(f"chunk {t.shape}, {w}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}") - return t[(r * chunk):(r + 1)*chunk] - - -def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): - assert torch.cuda.current_device() == pgi.local_rank - - num_tokens, hidden_dim = a.shape - num_experts = w1.shape[0] - block_size = 128 - device = pgi.device - rank = pgi.rank - world_size = pgi.world_size - rank_num_tokens = rank_chunk(num_tokens, rank, world_size) - max_num_tokens = num_tokens - #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}") - - ata = AllToAll( - max_num_tokens=max_num_tokens, - num_experts=num_experts, - experts_per_token=topk, - rank=rank, - world_size=pgi.world_size, - dp_size=dp_size, - hidden_dim=hidden_dim, - hidden_dim_bytes=hidden_dim * a.dtype.itemsize, - hidden_dim_scale_bytes=( - 0 - if a.dtype.itemsize != 1 - else ( - (hidden_dim + block_size - 1) - // block_size - * torch.float32.itemsize - ) - ), - ) - - dispatch_combine = PplxDispatchCombine( - ata, - max_num_tokens, - pgi.world_size, - dp_size, - rank, - a.dtype, - ) - - a_chunk = chunk_by_rank(a, rank, world_size).to(device) - score_chunk = chunk_by_rank(scores, rank, world_size).to(device) - chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) - - #print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}") - - b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( - a_chunk, - None, - None, - chunk_topk_ids, - num_experts, # store at PplxDispatchCombine creation? - None - ) - - #topk_weight, topk_ids = fused_topk(a_chunk, score_chunk, topk, False) - naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, num_experts) - - torch.distributed.all_reduce(tokens_per_expert) - #max_num = tokens_per_expert.max() - tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, world_size).to(dtype=torch.int32) - - #print(f"tpe {tokens_per_expert}") - #print(f"ent {expert_num_tokens}") - - #torch.set_printoptions(profile="full") - #torch.distributed.all_reduce(naive_b_a, op=torch.distributed.ReduceOp.MAX) - #torch.distributed.broadcast(naive_b_a, src=rank) - - #naive_b_a = chunk_by_rank(naive_b_a, rank, world_size) - - #print("b_a", b_a.shape, b_a) #[:, :naive_b_a.shape[1]]) - #print("naive_b_a", naive_b_a.shape, naive_b_a) - - torch.testing.assert_close(tokens_per_expert, expert_num_tokens, atol=0, rtol=0) - #torch.testing.assert_close(b_a[:, :naive_b_a.shape[1]], naive_b_a, atol=2e-2, rtol=0) - - b_a = b_a * 1.5 - - out = torch.full( - (rank_num_tokens * world_size, hidden_dim), - torch.nan, - dtype=a.dtype, - device=device, - ) - - dispatch_combine.combine( - out, - b_a, - chunk_topk_weight, - chunk_topk_ids, - ) - torch.cuda.synchronize() - - ata.destroy() - - #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}") - - #torch.distributed.all_reduce(out) - - #print(f"AR OUT {rank}: {out.shape} {out}") - - return out[:rank_num_tokens] - - -def _pplx_dispatch_combine( - pgi: ProcessGroupInfo, - dp_size: int, - m, n, k, e, - #a: torch.Tensor, - #w1: torch.Tensor, - #w2: torch.Tensor, - #score: torch.Tensor, - topk: int, - dtype: torch.dtype, -): - uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() - torch.distributed.broadcast(uid, src=0) - nvshmem_init(uid, pgi.rank, pgi.world_size) - device = pgi.device - - a = torch.randn((m, k), device=device, dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10 - score = torch.randn((m, e), device=device, dtype=dtype) - - #m, k = a.shape - #e, _, n = w2.shape - - topk_weight, topk_ids = fused_topk(a, score, topk, False) - - #print(f"a {a.shape}") - a_rep = torch.repeat_interleave(a, topk, dim=0) - #print(f"a_rep {a_rep.shape} {a_rep.view(-1, topk, k)}") - - torch_output = (a_rep.view(-1, topk, k) * 1.5 * topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype) - - #print(f"torch_output {pgi.rank}: {torch_output.shape} {torch_output}") - - pplx_output = torch_pplx_dispatch_combine(pgi, - dp_size, - a, - w1, - w2, - score, - topk) - - if False: - torch.set_printoptions(profile="full") - print("BASELINE") - print(torch_output) - print("OUTPUT") - print(pplx_output) - - torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) - - torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) - - nvshmem_finalize() - - -@pytest.mark.parametrize("m", [4, 32, 64, 222]) #, 1024 * 128]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 512, 1024]) # restrictions? % 128? -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [[4, 2]]) -def test_pplx_dispatch_combine( - m: int, - n: int, - k: int, - e: int, - topk: int, - dtype: torch.dtype, - world_dp_size: Tuple[int, int], -): - current_platform.seed_everything(7) - world_size, dp_size = world_dp_size - - parallel_launch( - #world_size, _pplx_dispatch_combine, dp_size, a, w1, w2, score, topk, dtype - world_size, _pplx_dispatch_combine, dp_size, m, n, k, e, topk, dtype - ) - - -def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): - assert torch.cuda.current_device() == pgi.local_rank - - num_tokens, hidden_dim = a.shape - num_experts = w1.shape[0] - block_size = 128 - device = pgi.device - rank = pgi.rank - world_size = pgi.world_size - rank_num_tokens = rank_chunk(num_tokens, rank, world_size) - max_num_tokens = num_tokens - - ata = AllToAll( - max_num_tokens=max_num_tokens, - num_experts=num_experts, - experts_per_token=topk, - rank=rank, - world_size=pgi.world_size, - dp_size=dp_size, - hidden_dim=hidden_dim, - hidden_dim_bytes=hidden_dim * a.dtype.itemsize, - hidden_dim_scale_bytes=( - 0 - if a.dtype.itemsize != 1 - else ( - (hidden_dim + block_size - 1) - // block_size - * torch.float32.itemsize - ) - ), - ) - - w1 = w1.to(device) - w2 = w2.to(device) - - dispatch_combine = PplxDispatchCombine( - ata, - max_num_tokens, - pgi.world_size, - dp_size, - rank, - a.dtype, - ) - - experts = BatchedExperts(rank, pgi.world_size, max_num_tokens) - - fused_experts = FusedMoEModularKernel( - dispatch_combine, - experts, - ) - - a_chunk = chunk_by_rank(a, rank, world_size).to(device) - score_chunk = chunk_by_rank(scores, rank, world_size).to(device) - chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False) - - #print(f"chunk_topk_ids {rank} {chunk_topk_ids.shape} {chunk_topk_ids.view(-1)}") - - out = fused_experts( - a_chunk, - # Chunking weights like this only works for batched format - chunk_by_rank(w1, rank, world_size), - chunk_by_rank(w2, rank, world_size), - #w1, - #w2, - chunk_topk_weight, - chunk_topk_ids, - global_num_experts=num_experts #? num_local_experts? - ) - - torch.cuda.synchronize() - - ata.destroy() - - #print(f"OUT {rank}: {out.shape} {out}") - - return out[:rank_num_tokens] - - -def _pplx_moe( - pgi: ProcessGroupInfo, - dp_size: int, - a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - score: torch.Tensor, - topk: int, - dtype: torch.dtype, -): - uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() - torch.distributed.broadcast(uid, src=0) - nvshmem_init(uid, pgi.rank, pgi.world_size) - - m, k = a.shape - e, _, n = w2.shape - - torch.set_printoptions(profile="full") - - vllm_config = VllmConfig() - with set_current_vllm_config(vllm_config): - topk_weight, topk_ids = fused_topk(a, score, topk, False) - - #print(f"topk_ids {pgi.rank} {topk_ids.shape} {topk_ids.view(-1)}") - - torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - - pplx_output = torch_pplx_moe(pgi, - dp_size, - a, - w1, - w2, - score, - topk) - - if False: - print("BASELINE") - print(torch_output) - print("OUTPUT") - print(pplx_output) - - torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) - - #print(f"torch_output {pgi.rank}: {torch_output.shape} {torch_output}") - - torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) - - nvshmem_finalize() - - -# TODO: M == 1 doesn't work -@pytest.mark.parametrize("m", [2, 3, 32, 45, 64, 222]) #, 1024 * 128]) -@pytest.mark.parametrize("n", [128, 1024])# , 2048]) -@pytest.mark.parametrize("k", [128, 512]) # , 1024]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) -def test_pplx_moe( - m: int, - n: int, - k: int, - e: int, - topk: int, - dtype: torch.dtype, - world_dp_size: Tuple[int, int], -): - current_platform.seed_everything(7) - world_size, dp_size = world_dp_size - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device="cuda", dtype=dtype) - - parallel_launch( - world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, dtype - #world_size, _pplx_moe, dp_size, m, n, k, e, topk, dtype - ) - diff --git a/vllm/forward_context.py b/vllm/forward_context.py index ded15df1f94..c1aa875c7a2 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -98,16 +98,15 @@ def set_forward_context(attn_metadata: Any, from vllm.distributed.parallel_state import get_dp_group dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) #TODO device? - max_tokens_across_dp = torch.max(num_tokens_tensor) #.to(device="cuda") + max_tokens_across_dp = torch.max( + num_tokens_tensor) #.to(device="cuda") cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0) dp_rank_num_tokens = torch.tensor( [num_tokens], dtype=torch.uint32, device=vllm_config.device_config.device) - dp_metadata = DPMetadata(max_tokens_across_dp, - num_tokens_tensor, - cu_tokens_across_dp_cpu, - dp_rank_num_tokens) + dp_metadata = DPMetadata(max_tokens_across_dp, num_tokens_tensor, + cu_tokens_across_dp_cpu, dp_rank_num_tokens) global _forward_context prev_context = _forward_context diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index a694c53d9f3..266ba3bfa07 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -134,9 +134,7 @@ def apply( dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids) - self.activation(activation, - workspace2, - workspace1.view(-1, N)) + self.activation(activation, workspace2, workspace1.view(-1, N)) a2q_scale: Optional[torch.Tensor] = None diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 907670cbb7b..be700f7b2e9 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -7,24 +7,24 @@ import triton.language as tl import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.fused_moe import ( - get_config_dtype_str, - try_get_optimal_moe_config, -) + get_config_dtype_str, try_get_optimal_moe_config) +from vllm.model_executor.layers.fused_moe.utils import _resize_cache + @triton.jit -def batched_silu_and_mul_kernel(output, # [E, MAX_NUM_TOKENS, D] - input, # [E, MAX_NUM_TOKENS, D * 2] - expert_num_tokens, # [E] - stride_oe, - stride_om, - stride_ie, - stride_im, - compute_type: tl.constexpr, - D, - BLOCK_M: tl.constexpr, - BLOCK_D: tl.constexpr): +def batched_silu_and_mul_kernel( + output, # [E, MAX_NUM_TOKENS, D] + input, # [E, MAX_NUM_TOKENS, D * 2] + expert_num_tokens, # [E] + stride_oe, + stride_om, + stride_ie, + stride_im, + compute_type: tl.constexpr, + D, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr): expert_id = tl.program_id(axis=0) e_num_tokens = tl.load(expert_num_tokens + expert_id) @@ -57,50 +57,53 @@ def batched_silu_and_mul_kernel(output, # [E, MAX_NUM_TOKENS, D] mask_D = offs_D < (D - (d * BLOCK_D)) mask_tile = mask_m & mask_D - x_tile = tl.load(cta_input_ptrs, mask=mask_tile, other=0.0).to(dtype=tl.float32) + x_tile = tl.load(cta_input_ptrs, mask=mask_tile, + other=0.0).to(dtype=tl.float32) y_tile = tl.load(cta_input_ptrs + D, mask=mask_tile, other=0.0) # silu and mul - out_tile = (x_tile * (1.0 / (1.0 + tl.exp(-x_tile)))).to(dtype=compute_type) + out_tile = (x_tile * (1.0 / + (1.0 + tl.exp(-x_tile)))).to(dtype=compute_type) out_tile = out_tile * y_tile tl.store(cta_output_ptrs, out_tile, mask=mask_tile) cta_input_ptrs = cta_input_ptrs + BLOCK_D cta_output_ptrs = cta_output_ptrs + BLOCK_D + @triton.jit def moe_mmk( - a_ptrs, - b_ptrs, - K, - expert_id, - a_scale_ptr, - b_scale_ptr, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). - stride_ak, - stride_bk, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, - # Offsets and masks - offs_m, - offs_n, - mask_m, - # Block size for block-wise quantization - group_n: tl.constexpr, - group_k: tl.constexpr, - # Meta-parameters - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - compute_type: tl.constexpr, - use_w8a8: tl.constexpr, - use_w8a16: tl.constexpr): + a_ptrs, + b_ptrs, + K, + expert_id, + a_scale_ptr, + b_scale_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ak, + stride_bk, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Offsets and masks + offs_m, + offs_n, + mask_m, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + compute_type: tl.constexpr, + use_w8a8: tl.constexpr, + use_w8a16: tl.constexpr): offs_k = tl.arange(0, BLOCK_K) @@ -131,12 +134,9 @@ def moe_mmk( # Load the next block of A and B, generate a mask by checking the # K dimension. a = tl.load(a_ptrs, - mask=mask_m[:, None] & - (offs_k[None, :] < K - k * BLOCK_K), - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[:, None] < K - k * BLOCK_K, + mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K), other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) # We accumulate along the K dimension. if use_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) @@ -177,41 +177,42 @@ def moe_mmk( @triton.jit -def expert_triton_kernel(a_ptr, #[max_tokens, K] - b_ptr, #[K, N] - c_ptr, #[max_tokens, N] - expert_id, - compute_type: tl.constexpr, - # Dimensions - M, - N, - K, - # Quantization data - a_scale_ptr, - b_scale_ptr, - b_zp_ptr, - # strides - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, - # Blockwise quantization data - group_n, - group_k, - # Quantization schemes - use_fp8_w8a8: tl.constexpr, - use_int8_w8a16: tl.constexpr, - # Kernel config - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr): +def expert_triton_kernel( + a_ptr, #[max_tokens, K] + b_ptr, #[K, N] + c_ptr, #[max_tokens, N] + expert_id, + compute_type: tl.constexpr, + # Dimensions + M, + N, + K, + # Quantization data + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # strides + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n, + group_k, + # Quantization schemes + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + # Kernel config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr): offs_m = tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) % N @@ -221,7 +222,6 @@ def expert_triton_kernel(a_ptr, #[max_tokens, K] a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn - accumulator = moe_mmk( a_ptrs, b_ptrs, @@ -261,48 +261,50 @@ def expert_triton_kernel(a_ptr, #[max_tokens, K] c_mask = mask_m[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) + @triton.jit -def batched_triton_kernel(a_ptr, # [E, max_num_tokens, K] - b_ptr, # [E, K, N] - c_ptr, # [E, max_num_tokens, N] - expert_num_tokens, # [E] - compute_type: tl.constexpr, - # Dimensions - max_num_tokens, - K, - N, - # Quantization data - a_scale_ptr, - b_scale_ptr, - b_zp_ptr, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). - stride_ae, - stride_am, - stride_ak, - stride_be, - stride_bk, - stride_bn, - stride_ce, - stride_cm, - stride_cn, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, - # Blockwise quantization data - group_n: tl.constexpr, - group_k: tl.constexpr, - # Quantization schemes - use_fp8_w8a8: tl.constexpr, - use_int8_w8a16: tl.constexpr, - # Kernel config - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr): +def batched_triton_kernel( + a_ptr, # [E, max_num_tokens, K] + b_ptr, # [E, K, N] + c_ptr, # [E, max_num_tokens, N] + expert_num_tokens, # [E] + compute_type: tl.constexpr, + # Dimensions + max_num_tokens, + K, + N, + # Quantization data + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ae, + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_ce, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n: tl.constexpr, + group_k: tl.constexpr, + # Quantization schemes + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr, + # Kernel config + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr): expert_id = tl.program_id(axis=0) e_num_tokens = tl.load(expert_num_tokens + expert_id) if e_num_tokens == 0: @@ -310,7 +312,7 @@ def batched_triton_kernel(a_ptr, # [E, max_num_tokens, K] return pid_mn = tl.program_id(axis=1) - num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M) + #num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M) num_pid_n = tl.cdiv(N, BLOCK_N) pid_m = pid_mn // num_pid_n pid_n = pid_mn % num_pid_n @@ -326,58 +328,61 @@ def batched_triton_kernel(a_ptr, # [E, max_num_tokens, K] a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn - c_ptr = c_ptr + expert_id * stride_ce + cta_m_start * stride_cm + cta_n_start * stride_cn - - expert_triton_kernel(a_ptr, - b_ptr, - c_ptr, - expert_id, - compute_type, - cta_m_size, # M - cta_n_size, # N - K, # K - a_scale_ptr, - b_scale_ptr, - b_zp_ptr, - # Strides - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, - # Blockwise quantization data - group_n, - group_k, - # Quantization schemes - use_fp8_w8a8, - use_int8_w8a16, - # Kernel config - BLOCK_M, - BLOCK_N, - BLOCK_K) - - -def invoke_moe_batched_triton_kernel(A: torch.Tensor, # [E, max_tokens, K] - B: torch.Tensor, # [E, K, N] - C: torch.Tensor, # [E, max_tokens, N] - expert_num_tokens: torch.Tensor, # [E] - compute_type: tl.dtype, - # Quantization data - A_scale: torch.Tensor, - B_scale: torch.Tensor, - B_zp: torch.Tensor, - # Quantization schemes - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - config: dict[str, int], - block_shape: Optional[list[int]] = None): + c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm + + cta_n_start * stride_cn) + + expert_triton_kernel( + a_ptr, + b_ptr, + c_ptr, + expert_id, + compute_type, + cta_m_size, # M + cta_n_size, # N + K, # K + a_scale_ptr, + b_scale_ptr, + b_zp_ptr, + # Strides + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Blockwise quantization data + group_n, + group_k, + # Quantization schemes + use_fp8_w8a8, + use_int8_w8a16, + # Kernel config + BLOCK_M, + BLOCK_N, + BLOCK_K) + + +def invoke_moe_batched_triton_kernel( + A: torch.Tensor, # [E, max_tokens, K] + B: torch.Tensor, # [E, K, N] + C: torch.Tensor, # [E, max_tokens, N] + expert_num_tokens: torch.Tensor, # [E] + compute_type: tl.dtype, + # Quantization data + A_scale: torch.Tensor, + B_scale: torch.Tensor, + B_zp: torch.Tensor, + # Quantization schemes + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + config: dict[str, int], + block_shape: Optional[list[int]] = None): assert not use_int4_w4a16 max_num_tokens = A.size(1) @@ -389,53 +394,54 @@ def invoke_moe_batched_triton_kernel(A: torch.Tensor, # [E, max_tokens, K] BLOCK_K = config['BLOCK_SIZE_K'] assert max_num_tokens % BLOCK_M == 0 - grid = (expert_num_tokens.size(0), - triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(B.shape[1], BLOCK_N)) - - batched_triton_kernel[grid](A, - B, - C, - expert_num_tokens, - compute_type, - # Dimensions - max_num_tokens, - K, - N, - # Quantization data - A_scale, - B_scale, - B_zp, - # Strides - A.stride(0), - A.stride(1), - A.stride(2), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(0), - C.stride(1), - C.stride(2), - A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, - A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, - B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, - B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, - B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, - # Blockwise quantization data - 0 if block_shape is None else block_shape[0], - 0 if block_shape is None else block_shape[1], - # Quantization schemes - use_fp8_w8a8, - use_int8_w8a16, - # Kernel config - BLOCK_M = BLOCK_M, - BLOCK_N = BLOCK_N, - BLOCK_K = BLOCK_K) - - -def invoke_batched_silu_and_mul(output : torch.Tensor, #[E, MAX_TOKENS, D] - input: torch.Tensor, #[E, MAX_TOKENS, D * 2] - expert_num_tokens: torch.Tensor): + grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * + triton.cdiv(B.shape[1], BLOCK_N)) + + batched_triton_kernel[grid]( + A, + B, + C, + expert_num_tokens, + compute_type, + # Dimensions + max_num_tokens, + K, + N, + # Quantization data + A_scale, + B_scale, + B_zp, + # Strides + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(0), + C.stride(1), + C.stride(2), + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + # Blockwise quantization data + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + # Quantization schemes + use_fp8_w8a8, + use_int8_w8a16, + # Kernel config + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K) + +def invoke_batched_silu_and_mul( + output: torch.Tensor, #[E, MAX_TOKENS, D] + input: torch.Tensor, #[E, MAX_TOKENS, D * 2] + expert_num_tokens: torch.Tensor): num_experts = output.size(0) max_num_tokens = output.size(1) @@ -444,24 +450,19 @@ def invoke_batched_silu_and_mul(output : torch.Tensor, #[E, MAX_TOKENS, D] BLOCK_D = 1024 BLOCK_M = 1 - compute_tl_dtype = {torch.float16 : tl.float16, - torch.float32 : tl.float32, - torch.bfloat16 : tl.bfloat16}[output.dtype] + compute_tl_dtype = { + torch.float16: tl.float16, + torch.float32: tl.float32, + torch.bfloat16: tl.bfloat16 + }[output.dtype] #print(f"compute type {compute_tl_dtype}") grid = (num_experts, triton.cdiv(max_num_tokens, BLOCK_M)) - batched_silu_and_mul_kernel[grid](output, - input, - expert_num_tokens, - output.stride(0), - output.stride(1), - input.stride(0), - input.stride(1), - compute_tl_dtype, - D, - BLOCK_M, - BLOCK_D) + batched_silu_and_mul_kernel[grid](output, input, expert_num_tokens, + output.stride(0), output.stride(1), + input.stride(0), input.stride(1), + compute_tl_dtype, D, BLOCK_M, BLOCK_D) class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): @@ -621,8 +622,9 @@ def apply( if num > 0: tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) self.activation( - activation, tmp, hidden_states[expert, :num, :] - @ w1[expert].transpose(0, 1)) + activation, tmp, + hidden_states[expert, :num, :] @ w1[expert].transpose( + 0, 1)) out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) return out @@ -685,15 +687,15 @@ def apply( ) -> torch.Tensor: num_tokens = topk_ids.size(0) - #print_debug = expert_map[0] != -1 and num_tokens < 50 and num_tokens != 1 and False # Check constraints. if self.use_int4_w4a16: assert hidden_states.shape[-1] // 2 == w1.shape[ 2], "Hidden size mismatch" else: - assert hidden_states.shape[-1] == w1.shape[ - 2], f"Hidden size mismatch {hidden_states.shape[-1]} != {w1.shape[2]}" + assert hidden_states.shape[-1] == w1.shape[2], \ + (f"Hidden size mismatch {hidden_states.shape[-1]} " + f"!= {w1.shape[2]}") assert hidden_states.is_contiguous( ), "Hidden_states must be contiguous" @@ -764,7 +766,7 @@ def apply( input=intermediate_cache1, expert_num_tokens=expert_num_tokens) - qintermediate_cache2 = intermediate_cache2 + #qintermediate_cache2 = intermediate_cache2 a2q_scale = a2_scale # TODO (varun) : support w8a8 assert not self.use_fp8_w8a8 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index ec3587973e1..ac834a89195 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1207,28 +1207,29 @@ def fused_experts(hidden_states: torch.Tensor, block_shape=block_shape) -def fused_experts_impl(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> torch.Tensor: +def fused_experts_impl( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None) -> torch.Tensor: # Check constraints. if use_int4_w4a16: assert hidden_states.shape[1] // 2 == w1.shape[ @@ -1631,22 +1632,32 @@ def apply( intermediate_cache3 = _resize_cache(workspace13, (num_tokens, top_k_num, K)) - if hidden_states.dim() == 2: #block_m is None: + if hidden_states.dim() == 2: #block_m is None: sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size( - topk_ids, - config['BLOCK_SIZE_M'], - global_num_experts, expert_map - )) + moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], + global_num_experts, expert_map)) else: max_num_tokens = hidden_states.shape[1] - sorted_token_ids = torch.arange(0, hidden_states.shape[0] * max_num_tokens, device=hidden_states.device, dtype=torch.int) + sorted_token_ids = torch.arange(0, + hidden_states.shape[0] * + max_num_tokens, + device=hidden_states.device, + dtype=torch.int) sorted_token_ids = sorted_token_ids.flatten() - expert_ids = torch.arange(0, global_num_experts, device=hidden_states.device, dtype=torch.int) - expert_ids = torch.repeat_interleave(expert_ids, max_num_tokens, dim=0) + expert_ids = torch.arange(0, + global_num_experts, + device=hidden_states.device, + dtype=torch.int) + expert_ids = torch.repeat_interleave(expert_ids, + max_num_tokens, + dim=0) print(f"EXPERT_IDS {expert_ids}") - #num_tokens_post_padded = torch.tensor([num_tokens], device=hidden_states.device, dtype=torch.int32) - num_tokens_post_padded = torch.zeros(1, device=hidden_states.device, dtype=torch.int32) + #num_tokens_post_padded = torch.tensor([num_tokens], + # device=hidden_states.device, + # dtype=torch.int32) + num_tokens_post_padded = torch.zeros(1, + device=hidden_states.device, + dtype=torch.int32) num_tokens_post_padded.fill_(max_num_tokens) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) #print(f"P = {sorted_token_ids}, {hidden_states.shape}") @@ -1705,170 +1716,6 @@ def apply( return intermediate_cache3 -class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - def __init__(self, - world_size: int, - rank: int): - super().__init__() - self.world_size = world_size - self.rank = rank - - def dispatch( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - assert topk_ids.dim() == 2 - assert topk_ids.shape[0] == a1.shape[0] - - num_tokens = a1.shape[0] - topk = topk_ids.shape[1] - - tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) - max_num_tokens = tokens_per_expert.max() - expert_counts = torch.zeros(num_experts, dtype=torch.int, device=a1.device) - - b_a1 = torch.zeros((num_experts, max_num_tokens, a1.shape[1]), - dtype=a1.dtype, device=a1.device) - - #print(f"START DISPATCH {hex(id(self))}") - - for token in range(num_tokens): - for j in range(topk): - expert_id = topk_ids[token, j] - idx = expert_counts[expert_id] - b_a1[expert_id, idx:idx+1, :] = a1[token, :] - expert_counts[expert_id] = expert_counts[expert_id] + 1 - - #print(f"END DISPATCH {hex(id(self))}: tokens_per_expert {(tokens_per_expert > 0).nonzero().view(-1)}") - - return b_a1, a1_scale, tokens_per_expert - - def combine( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - ) -> None: - if False: - print(f"topk_ids {topk_ids.shape}") - print(f"fused_expert_output {fused_expert_output.shape}") - print(f"output {output.shape}") - print(f"counts {self.expert_counts.shape}") - - #print(f"START COMBINE {hex(id(self))}") - - num_tokens, topk = topk_ids.shape - num_experts, _, K = fused_expert_output.shape - expert_counts = torch.zeros(num_experts, dtype=torch.int, device=fused_expert_output.device) - for token in range(num_tokens): - expert_ids = topk_ids[token] - for i in range(topk_ids.shape[1]): - expert_id = expert_ids[i] - if expert_id < num_experts: - idx = expert_counts[expert_id] - output[token, :] = output[token, :] + fused_expert_output[expert_id, idx:idx+1, :] * topk_weights[token, i] - expert_counts[expert_id] = expert_counts[expert_id] + 1 - - #print(f"END COMBINE {hex(id(self))}") - - -def rank_chunk(num, r, w): - rem = num % w - return (num // w) + (1 if r < rem else 0) - - -class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): - - def __init__( - self, - rank: int = 0, - world_size: int = 1, - max_num_tokens: Optional[int] = None, - use_fp8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - block_shape: Optional[List[int]] = None, - block_m: Optional[int] = None, - ): - super().__init__() - assert not use_fp8_w8a8 - assert not use_int4_w4a16 - assert not use_int8_w8a16 - assert block_shape is None - assert block_m is None - self.max_num_tokens = max_num_tokens - self.rank = rank - self.world_size = world_size - - def workspace_shapes( - self, - a: torch.Tensor, - M: int, - N: int, - K: int, - topk: int, - num_experts: int, - ) -> Tuple[int, int, torch.dtype]: - #assert self.max_num_tokens >= a.shape[1] - max_num_tokens = a.shape[1] if self.max_num_tokens is None else self.max_num_tokens - workspace13 = num_experts * max_num_tokens * K * topk * 2 # TODO: *2 is a hack - workspace2 = max_num_tokens * N - return (workspace13, workspace2, a.dtype) - - def apply( - self, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: - #print("START EXPERTS") - assert hidden_states.dim() == 3 - assert expert_num_tokens is not None - num_tokens, topk = topk_ids.shape - _, tmp_max_num_tokens, K = hidden_states.shape - max_num_tokens = tmp_max_num_tokens if self.max_num_tokens is None else self.max_num_tokens - #print(f"global_num_experts = {global_num_experts}") - num_experts = global_num_experts - out = _resize_cache(workspace13, (num_experts, max_num_tokens, w2.shape[1])) - num_local_experts = expert_num_tokens.numel() - #print(f"shapes = {hidden_states.shape}, {w1.shape}, {w2.shape}, {out.shape} {expert_num_tokens.shape} {workspace2.shape} {num_experts}") - - # TODO: don't need world_size or rank if expert_base always == 0 - #assert w1.shape[0] == num_experts, f"{w1.shape} == {num_experts}" - #expert_base = rank_chunk(w1.shape[0], self.rank, self.world_size) * self.rank - expert_base = 0 - - for expert in range(num_local_experts): # num_experts - num = expert_num_tokens[expert] - assert num <= max_num_tokens, f"{num}, {max_num_tokens}" - #print(f"{type(num)}, {num}, {max_num_tokens}") - if num > 0: - tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) - self.activation(activation, tmp, hidden_states[expert,:num,:] @ w1[expert_base + expert].transpose(0, 1)) - out[expert, :num, :] = tmp @ w2[expert_base + expert].transpose(0, 1) - - return out - - def modular_triton_fused_moe( use_fp8_w8a8: bool, use_int8_w8a8: bool, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 6147b17127f..e9023793698 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -29,9 +29,10 @@ if current_platform.is_cuda_alike(): from .dispatch_combine import StandardDispatchCombine - from .fused_moe import TritonExperts, fused_experts from .fused_batched_moe import BatchedDispatchCombine, BatchedTritonExperts - from .modular_kernel import FusedMoEModularKernel, FusedMoEQuantizeDispatchCombine + from .fused_moe import TritonExperts, fused_experts + from .modular_kernel import (FusedMoEModularKernel, + FusedMoEQuantizeDispatchCombine) from .pplx_dispatch_combine import PplxDispatchCombine else: fused_experts = None # type: ignore @@ -80,7 +81,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, params_dtype: torch.dtype, **extra_weight_attrs): raise NotImplementedError - def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: + def set_dispatch_combine( + self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: return False @abstractmethod @@ -241,29 +243,31 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input) # Maybe extra args - def set_dispatch_combine(self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: + def set_dispatch_combine( + self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: assert self.fused_experts == fused_experts - block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) + #block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) - if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)): - logger.info(f"BatchedTritonExperts {self.moe}") + if isinstance(dispatch_combine, + (BatchedDispatchCombine, PplxDispatchCombine)): + logger.info("BatchedTritonExperts %s", self.moe) experts = BatchedTritonExperts( - use_fp8_w8a8 = False, - use_int8_w8a8 = False, - use_int8_w8a16 = False, - use_int4_w4a16 = False, - block_shape = None, + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=None, ) else: - logger.info(f"TritonExperts {self.moe}") + logger.info("TritonExperts %s", self.moe) experts = TritonExperts( - use_fp8_w8a8 = False, - use_int8_w8a8 = False, - use_int8_w8a16 = False, - use_int4_w4a16 = False, - block_shape = None, - per_channel_quant = False, + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=None, + per_channel_quant=False, ) self.fused_experts = FusedMoEModularKernel( @@ -598,8 +602,8 @@ def __init__( dp_rank=self.dp_rank, ep_size=self.ep_size, ep_rank=self.ep_rank, - in_dtype = params_dtype, # this is probably not right, where to get? - out_dtype = params_dtype, # ditto. + in_dtype=params_dtype, # this is probably not right, where to get? + out_dtype=params_dtype, # ditto. ) # Note: get_quant_method will look at the layer's local_num_experts @@ -618,46 +622,41 @@ def __init__( # TODO: move to method? if self.dp_size > 1: logger.info("using pplx dispatch") - max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size + max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size world_size = moe.ep_size - dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. + dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. rank = moe.ep_rank all_to_all = get_all_to_all( max_num_tokens=max_num_tokens, num_experts=moe.num_experts, - experts_per_token=moe.experts_per_token, # topk + experts_per_token=moe.experts_per_token, # topk rank=rank, world_size=world_size, dp_size=dp_size, hidden_dim=moe.hidden_dim, hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, - # For blocked per token: set to ceil_div(hidden_dim, block_size) * sizeof(float32) + # For blocked per token: set to + # ceil_div(hidden_dim, block_size) * sizeof(float32) # For per-token: set to sizeof(float32) - hidden_dim_scale_bytes=( - 0 - if moe.in_dtype.itemsize != 1 - else ( - (moe.hidden_dim + moe.block_size - 1) - // moe.block_size - * torch.float32.itemsize - ) - ) - ) + hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else ( + (moe.hidden_dim + moe.block_size - 1) // moe.block_size * + torch.float32.itemsize))) dispatch_combine = PplxDispatchCombine( all_to_all, max_num_tokens, world_size, dp_size, - rank, # just for debugging + rank, # just for debugging moe.in_dtype, ) elif True: logger.info("using standard dispatch") dispatch_combine = StandardDispatchCombine( moe.in_dtype, - quant_config.weight_block_size if quant_config is not None else None, + quant_config.weight_block_size + if quant_config is not None else None, ) else: logger.info("using batched dispatch") @@ -668,7 +667,8 @@ def __init__( success = self.quant_method.set_dispatch_combine(dispatch_combine) if not success: - logger.warning("DP+EP not supported for %s.", type(self.quant_method)) + logger.warning("DP+EP not supported for %s.", + type(self.quant_method)) self.apply_router_weight_on_input = apply_router_weight_on_input moe_quant_params = { @@ -1019,12 +1019,14 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, ctx = get_forward_context() max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp - cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu + #cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu num_tokens_across_dp = ctx.dp_metadata.num_tokens_across_dp #In this function we define two ranges: - # 1. chunk_range - The current iteration of the loops's range over the DP world tokens - # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP rank owns. + # 1. chunk_range - The current iteration of the loops's range over the + # DP world tokens + # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP + # rank owns. moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self.dp_size @@ -1072,8 +1074,11 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, # TODO: needed for non-pplx? if False and self.dp_size > 1: - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[ - self.dp_rank - 1] + if self.dp_rank == 0: + start = 0 + else: + start = cu_tokens_across_dp_this_iter[self.dp_rank - 1] + end = cu_tokens_across_dp_this_iter[self.dp_rank] all_hidden_states = get_dp_group().all_reduce( @@ -1081,7 +1086,8 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, final_hidden_states = all_hidden_states[start:end, :] # TODO: needed for non-pplx? - if False and self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + if False and self.reduce_results and (self.tp_size > 1 + or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) @@ -1156,7 +1162,8 @@ def forward_impl(self, hidden_states: torch.Tensor, final_hidden_states = all_hidden_states[start:end, :] # TODO: needed for non-pplx? - if False and self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): + if False and self.reduce_results and (self.tp_size > 1 + or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index d550c8b040c..eec5a7406d9 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -67,8 +67,8 @@ def _moe_problem_size( M = a1.shape[0] else: assert a1.dim() == 3 - assert E == a1.shape[0] - M = a1.shape[1] # This is max_num_tokens + assert a1.shape[0] == E + M = a1.shape[1] # This is max_num_tokens assert topk_ids.dim() == 2 topk = topk_ids.shape[1] diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 576c454ec31..420a81f3f5c 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -9,6 +9,11 @@ moe_kernel_quantize_input) +def rank_chunk(num, r, w): + rem = num % w + return (num // w) + (1 if r < rem else 0) + + # Note use: layer.get_all_to_all() to get an AllToAll instance # The max_num_tokens, world_size and dp_size must be the same # as the ones used to create the AllToAll. @@ -97,7 +102,7 @@ def dispatch( ) # This argument is optional, defaults to indices.shape[0] - num_tokens = a1.shape[0] # M + num_tokens = a1.shape[0] # M bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) # TODO: optimize this? @@ -123,8 +128,9 @@ def combine( apply_router_weight_on_input: bool, ) -> None: # This argument is optional - num_tokens = output.shape[0] # M - bound_m = torch.tensor([num_tokens], dtype=torch.uint32, + num_tokens = output.shape[0] # M + bound_m = torch.tensor([num_tokens], + dtype=torch.uint32, device=fused_expert_output.device) assert output.shape[0] <= self.max_num_tokens diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index e85f3514160..0d0212b7591 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -1,36 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 -import importlib.util from typing import List, Optional, Tuple import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - DeepGemmExperts, - _valid_deep_gemm_shape, - _valid_deep_gemm, -) + DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape) from vllm.model_executor.layers.fused_moe.fused_moe import TritonExpert + class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( - self, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - block_shape: Optional[List[int]] = None, - block_m: Optional[int] = None, - allow_deep_gemm: bool = False - ): + def __init__(self, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]] = None, + block_m: Optional[int] = None, + allow_deep_gemm: bool = False): super().__init__() - self.triton_expert = TritonExpert( - use_fp8_w8a8, - use_int4_w4a16, - use_int8_w8a16, - block_shape, - block_m - ) + self.triton_expert = TritonExpert(use_fp8_w8a8, use_int4_w4a16, + use_int8_w8a16, block_shape, block_m) self.deep_gemm_expert = DeepGemmExperts() self.allow_deep_gemm = allow_deep_gemm self.use_fp8_w8a8 = use_fp8_w8a8 @@ -48,9 +38,11 @@ def workspace_shapes( # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): - return self.deep_gemm_expert.workspace_shapes(a, M, N, K, topk, num_experts) + return self.deep_gemm_expert.workspace_shapes( + a, M, N, K, topk, num_experts) else: - return self.triton_expert.workspace_shapes(a, M, N, K, topk, num_experts) + return self.triton_expert.workspace_shapes(a, M, N, K, topk, + num_experts) def apply( self, @@ -73,7 +65,7 @@ def apply( ) -> torch.Tensor: N = w1.shape[1] if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512 - and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): + and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): return self.deep_gemm_expert( hidden_states, w1, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index a88827dacce..9fabf4edadf 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -10,8 +10,8 @@ from torch.nn.parameter import Parameter import vllm.envs as envs -from vllm import _custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import _custom_ops as ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, @@ -431,7 +431,6 @@ def __init__(self, quant_config: Fp8Config): from vllm.model_executor.layers.fused_moe import fused_experts self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None - self.allow_deep_gemm = allow_deep_gemm # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization @@ -779,21 +778,24 @@ def process_weights_after_loading(self, layer: Module) -> None: del layer.w2_input_scale # Maybe extra args - def set_dispatch_combine(self, dispatch_combine: mk.FusedMoEQuantizeDispatchCombine) -> bool: + def set_dispatch_combine( + self, + dispatch_combine: mk.FusedMoEQuantizeDispatchCombine) -> bool: + from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts) + if self.use_marlin: return False - from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import TritonOrDeepGemmExperts - #block_m = MOE_DP_CHUNK_SIZE * (moe.ep_size // moe.dp_size) #print(f"block_m = {block_m}") experts = TritonOrDeepGemmExperts( - use_fp8_w8a8 = True, - use_int8_w8a16 = False, - use_int4_w4a16 = False, - block_shape = self.quant_config.weight_block_size, - block_m = None, # TODO + use_fp8_w8a8=True, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=self.quant_config.weight_block_size, + block_m=None, # TODO allow_deep_gemm=self.allow_deep_gemm, ) @@ -851,8 +853,8 @@ def apply( else: return self.fused_experts( hidden_states=x, - layer.w13_weight, - layer.w2_weight, + w1=layer.w13_weight, + w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, From 7f0973872e2a0cc1b0cbca3a13e2b62c047fb312 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Apr 2025 23:19:29 +0000 Subject: [PATCH 160/190] more lint stuff Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 5 +++++ .../layers/fused_moe/triton_deep_gemm_moe.py | 10 +++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index e9023793698..624a91ac025 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -32,6 +32,7 @@ from .fused_batched_moe import BatchedDispatchCombine, BatchedTritonExperts from .fused_moe import TritonExperts, fused_experts from .modular_kernel import (FusedMoEModularKernel, + FusedMoEPermuteExpertsUnpermute, FusedMoEQuantizeDispatchCombine) from .pplx_dispatch_combine import PplxDispatchCombine else: @@ -249,6 +250,8 @@ def set_dispatch_combine( #block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) + experts: FusedMoEPermuteExpertsUnpermute = None + if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)): logger.info("BatchedTritonExperts %s", self.moe) @@ -619,6 +622,8 @@ def __init__( assert quant_method is not None self.quant_method = quant_method + dispatch_combine: FusedMoEQuantizeDispatchCombine = None + # TODO: move to method? if self.dp_size > 1: logger.info("using pplx dispatch") diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 0d0212b7591..d24ae4768a6 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -6,21 +6,25 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape) -from vllm.model_executor.layers.fused_moe.fused_moe import TritonExpert +from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self, use_fp8_w8a8: bool, + use_int8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, + per_channel_quant: bool, block_shape: Optional[List[int]] = None, block_m: Optional[int] = None, allow_deep_gemm: bool = False): super().__init__() - self.triton_expert = TritonExpert(use_fp8_w8a8, use_int4_w4a16, - use_int8_w8a16, block_shape, block_m) + self.triton_expert = TritonExperts(use_fp8_w8a8, use_int8_w8a8, + use_int4_w4a16, use_int8_w8a16, + per_channel_quant, block_shape, + block_m) self.deep_gemm_expert = DeepGemmExperts() self.allow_deep_gemm = allow_deep_gemm self.use_fp8_w8a8 = use_fp8_w8a8 From d42186fd24b60700a371b23e0285c3aed30061ba Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 30 Apr 2025 02:26:57 +0000 Subject: [PATCH 161/190] add guards for pplx import Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 21 +++++++++++++++---- vllm/distributed/parallel_state.py | 12 +++++++---- vllm/model_executor/layers/fused_moe/layer.py | 12 ++++++++--- 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index aeedadea385..ff45c0798cf 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -10,10 +10,16 @@ import pytest import torch -from pplx_kernels import AllToAll -from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_finalize, nvshmem_get_unique_id, - nvshmem_init) + +try: + from pplx_kernels import AllToAll + from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_finalize, nvshmem_get_unique_id, + nvshmem_init) + has_pplx = False +except ImportError as ex: + has_pplx = False + from torch.multiprocessing import ( spawn) # pyright: ignore[reportPrivateImportUsage] from typing_extensions import Concatenate, ParamSpec @@ -45,6 +51,11 @@ reason="Requires multi-node environment", ) +requires_pplx = pytest.mark.skipif( + not has_pplx, + reason="Requires PPLX kernels", +) + @dataclasses.dataclass class ProcessGroupInfo: @@ -420,6 +431,7 @@ def _pplx_dispatch_combine( @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [[4, 2]]) +@pytest.mark.skipif(not has_pplx, reason="PPLX kernels not available.") def test_pplx_dispatch_combine( m: int, n: int, @@ -543,6 +555,7 @@ def _pplx_moe( @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) +@pytest.mark.skipif(not has_pplx, reason="PPLX kernels not available.") def test_pplx_moe( m: int, n: int, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index cf715681c87..c2bd6dba537 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -23,6 +23,7 @@ """ import contextlib import gc +import importlib import pickle import weakref from collections import namedtuple @@ -34,9 +35,6 @@ import torch import torch.distributed -from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_finalize, nvshmem_get_unique_id, - nvshmem_init) from torch.distributed import Backend, ProcessGroup import vllm.envs as envs @@ -920,7 +918,12 @@ def init_distributed_environment( @run_once def pplx_init(rank, world_size): - if world_size > 1: + has_pplx = importlib.util.find_spec("pplx_kernels") is not None + + if has_pplx and world_size > 1: + from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_get_unique_id, + nvshmem_init) try: global PPLX_DID_INIT logger.debug(f"PPLX_INIT {rank} {world_size}") @@ -940,6 +943,7 @@ def pplx_init(rank, world_size): def pplx_finalize(): global PPLX_DID_INIT if PPLX_DID_INIT: + from pplx_kernels.nvshmem import nvshmem_finalize nvshmem_finalize() diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 624a91ac025..a1ea38c79ae 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import importlib import threading import weakref from abc import abstractmethod @@ -7,7 +8,6 @@ from enum import Enum from typing import Callable, List, Optional, Tuple -import pplx_kernels as pplx # TODO: guard this import torch import torch.nn.functional as F from torch.nn.parameter import UninitializedParameter @@ -27,6 +27,8 @@ from vllm.platforms.interface import CpuArchEnum from vllm.utils import direct_register_custom_op +has_pplx = importlib.util.find_spec("pplx_kernels") is not None + if current_platform.is_cuda_alike(): from .dispatch_combine import StandardDispatchCombine from .fused_batched_moe import BatchedDispatchCombine, BatchedTritonExperts @@ -34,7 +36,8 @@ from .modular_kernel import (FusedMoEModularKernel, FusedMoEPermuteExpertsUnpermute, FusedMoEQuantizeDispatchCombine) - from .pplx_dispatch_combine import PplxDispatchCombine + if has_pplx: + from .pplx_dispatch_combine import PplxDispatchCombine else: fused_experts = None # type: ignore if current_platform.is_tpu(): @@ -115,6 +118,9 @@ def __init__(self): self._lock = threading.RLock() # Reentrant lock for thread safety def get_or_create(self, **kwargs): + assert has_pplx + import pplx_kernels as pplx + # Create a hashable key from the kwargs key = tuple(sorted((k, v) for k, v in kwargs.items())) @@ -625,7 +631,7 @@ def __init__( dispatch_combine: FusedMoEQuantizeDispatchCombine = None # TODO: move to method? - if self.dp_size > 1: + if self.dp_size > 1 and has_pplx: logger.info("using pplx dispatch") max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size world_size = moe.ep_size From 764e6469791103ea470802d7dbfadd8ff5ce7d5c Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 30 Apr 2025 10:55:48 -0400 Subject: [PATCH 162/190] fix forward_chunked Signed-off-by: Varun Sundar Rabindranath Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 62 +++++-------------- 1 file changed, 15 insertions(+), 47 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a1ea38c79ae..3e6e6ed9c93 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1024,40 +1024,16 @@ def forward(self, hidden_states: torch.Tensor, return torch.ops.vllm.moe_forward(hidden_states, router_logits, self.layer_name) + def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_router_logits: torch.Tensor): - ctx = get_forward_context() - - max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp - #cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu - num_tokens_across_dp = ctx.dp_metadata.num_tokens_across_dp - - #In this function we define two ranges: - # 1. chunk_range - The current iteration of the loops's range over the - # DP world tokens - # 2. my_tokens_in_chunk - The tokens within chunk_range that this DP - # rank owns. - - moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self.dp_size - - num_tokens_remaining_across_dp = num_tokens_across_dp - chunk_start = 0 - chunk_end = min(moe_dp_chunk_size_per_rank, - full_hidden_states.shape[0]) full_final_hidden_states = torch.empty_like(full_hidden_states) - assert full_hidden_states.shape[0] == full_router_logits.shape[0] - - for iter in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): + def process_chunk(chunk_start, chunk_end, skip_result_store = False): hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] - cu_tokens_across_dp_this_iter = torch.cumsum( - num_tokens_remaining_across_dp.clamp( - max=moe_dp_chunk_size_per_rank), - dim=0) - # TODO: still may be needed for non-pplx, put into dispatcher class. if False: hidden_states = self.naive_multicast( @@ -1103,30 +1079,22 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor, final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) - full_final_hidden_states[chunk_start:chunk_end, :].copy_( - final_hidden_states) - - # Update bounds - num_tokens_remaining_across_dp = torch.clamp( - num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank, - min=0) + if not skip_result_store: + full_final_hidden_states[chunk_start:chunk_end, :].copy_( + final_hidden_states) - # HACK FIX - if num_tokens_remaining_across_dp.sum() == 0: - break + max_tokens_across_dp = get_forward_context().dp_metadata.max_tokens_across_dp + moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self.dp_size - def update_chunk_bound(x: int): - return min(x + moe_dp_chunk_size_per_rank, - full_hidden_states.shape[0]) + num_tokens = full_hidden_states.size(0) + for chunk_start_ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): + chunk_start = chunk_start_ + chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dp) + # clamp start and end + chunk_start = min(chunk_start, num_tokens - 1) + chunk_end = min(chunk_end, num_tokens) - #chunk_start = update_chunk_bound(chunk_start) - #chunk_end = update_chunk_bound(chunk_end) - if chunk_end == full_hidden_states.shape[0]: - # simply redo computation - pass - else: - chunk_start = update_chunk_bound(chunk_start) - chunk_end = update_chunk_bound(chunk_end) + process_chunk(chunk_start, chunk_end, skip_result_store = chunk_start_ >= num_tokens) return full_final_hidden_states From 2a31f903581d0cf13de34e9fb47d18ceade699f1 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 30 Apr 2025 17:04:54 +0000 Subject: [PATCH 163/190] fix more lint Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 2 +- vllm/distributed/parallel_state.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index ff45c0798cf..9557758f0ed 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -17,7 +17,7 @@ nvshmem_finalize, nvshmem_get_unique_id, nvshmem_init) has_pplx = False -except ImportError as ex: +except ImportError: has_pplx = False from torch.multiprocessing import ( diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index c2bd6dba537..6a3725b88c8 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -922,16 +922,15 @@ def pplx_init(rank, world_size): if has_pplx and world_size > 1: from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_get_unique_id, - nvshmem_init) + nvshmem_get_unique_id, nvshmem_init) try: global PPLX_DID_INIT - logger.debug(f"PPLX_INIT {rank} {world_size}") + logger.debug("PPLX_INIT %s %d", rank, world_size) uid = nvshmem_get_unique_id( ) if rank == 0 else nvshmem_alloc_empty_unique_id() uid_gpu = uid.cuda() get_world_group().broadcast(uid_gpu, src=0) - logger.debug(f"PPLX_INIT UID={uid_gpu}") + logger.debug("PPLX_INIT UID = %s", uid_gpu) uid = uid_gpu.to(device='cpu') nvshmem_init(uid, rank, world_size) PPLX_DID_INIT = True @@ -944,6 +943,7 @@ def pplx_finalize(): global PPLX_DID_INIT if PPLX_DID_INIT: from pplx_kernels.nvshmem import nvshmem_finalize + logger.debug("PPLX finalize") nvshmem_finalize() From 6a3daba30956625c993e1f18d53eae581c986781 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 30 Apr 2025 21:27:29 +0000 Subject: [PATCH 164/190] cleanups Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 48 +++++----- vllm/forward_context.py | 2 +- .../layers/fused_moe/deep_gemm_moe.py | 2 + vllm/model_executor/layers/fused_moe/layer.py | 94 ++++++++++--------- .../layers/fused_moe/pplx_dispatch_combine.py | 14 ++- .../layers/fused_moe/triton_deep_gemm_moe.py | 10 +- .../model_executor/layers/quantization/fp8.py | 7 -- 7 files changed, 90 insertions(+), 87 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 9557758f0ed..6dd028894b3 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -16,7 +16,7 @@ from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, nvshmem_finalize, nvshmem_get_unique_id, nvshmem_init) - has_pplx = False + has_pplx = True except ImportError: has_pplx = False @@ -46,11 +46,6 @@ P = ParamSpec("P") -require_multi_node = pytest.mark.skipif( - "MASTER_ADDR" not in os.environ, - reason="Requires multi-node environment", -) - requires_pplx = pytest.mark.skipif( not has_pplx, reason="Requires PPLX kernels", @@ -180,6 +175,9 @@ def torch_dispatch( tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + + assert tokens_per_expert.numel() == num_experts + if max_num_tokens is None: max_num_tokens = int(tokens_per_expert.max().item()) @@ -259,7 +257,7 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids): topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) -@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128]) +@pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @@ -309,7 +307,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): rank = pgi.rank world_size = pgi.world_size rank_num_tokens = rank_chunk(num_tokens, rank, world_size) - max_num_tokens = num_tokens + max_num_tokens = max(num_tokens, 1) ata = AllToAll.internode( max_num_tokens=max_num_tokens, @@ -350,22 +348,23 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): False, ) - naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, - num_experts) + if False: + naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, + num_experts) - torch.distributed.all_reduce(tokens_per_expert) - tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, - world_size).to(dtype=torch.int32) + torch.distributed.all_reduce(tokens_per_expert) + tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, + world_size).to(dtype=torch.int32) - torch.testing.assert_close(tokens_per_expert, - expert_num_tokens, - atol=0, - rtol=0) + torch.testing.assert_close(tokens_per_expert, + expert_num_tokens, + atol=0, + rtol=0) b_a = b_a * 1.5 out = torch.full( - (rank_num_tokens * world_size, hidden_dim), + (rank_num_tokens, hidden_dim), torch.nan, dtype=a.dtype, device=device, @@ -424,14 +423,15 @@ def _pplx_dispatch_combine( nvshmem_finalize() +# TODO: M < world_size doesn't appear to be supported by pplx? @pytest.mark.parametrize("m", [4, 32, 64, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 512, 1024]) # restrictions? % 128? +@pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [[4, 2]]) -@pytest.mark.skipif(not has_pplx, reason="PPLX kernels not available.") +@requires_pplx def test_pplx_dispatch_combine( m: int, n: int, @@ -502,11 +502,9 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): # Chunking weights like this only works for batched format chunk_by_rank(w1, rank, world_size), chunk_by_rank(w2, rank, world_size), - #w1, - #w2, chunk_topk_weight, chunk_topk_ids, - global_num_experts=num_experts #? num_local_experts? + global_num_experts=num_experts ) torch.cuda.synchronize() @@ -547,7 +545,7 @@ def _pplx_moe( nvshmem_finalize() -# TODO: M == 1 doesn't work +# TODO: M < world_size doesn't appear to be supported by pplx? @pytest.mark.parametrize("m", [2, 3, 32, 45, 64, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 512, 1024]) @@ -555,7 +553,7 @@ def _pplx_moe( @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) -@pytest.mark.skipif(not has_pplx, reason="PPLX kernels not available.") +@requires_pplx def test_pplx_moe( m: int, n: int, diff --git a/vllm/forward_context.py b/vllm/forward_context.py index c1aa875c7a2..2d6153095eb 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -97,7 +97,7 @@ def set_forward_context(attn_metadata: Any, dtype=torch.int32) from vllm.distributed.parallel_state import get_dp_group dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) - #TODO device? + #TODO device? (tms) max_tokens_across_dp = torch.max( num_tokens_tensor) #.to(device="cuda") cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 266ba3bfa07..4a0fb374bd4 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +import functools import importlib.util from typing import Optional, Tuple @@ -19,6 +20,7 @@ has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None +@functools.cache def deep_gemm_block_shape() -> list[int]: # Lazy import to avoid CUDA initialization problems. import deep_gemm as dg diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3e6e6ed9c93..84f3189c4d4 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -63,8 +63,7 @@ class MoEConfig: ep_size: int ep_rank: int - in_dtype: torch.dtype - out_dtype: torch.dtype + in_dtype: torch.dtype # The activation type. # TODO: add more quantization params, blocked, per-token, etc. block_size: int = 128 @@ -142,7 +141,6 @@ def get_all_to_all(**kwargs): return _all_to_all_cache.get_or_create(**kwargs) -#TODO: Every change in this class is a broken hack!! @CustomOp.register("unquantized_fused_moe") class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" @@ -249,18 +247,15 @@ def apply( activation=activation, apply_router_weight_on_input=apply_router_weight_on_input) - # Maybe extra args def set_dispatch_combine( self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: assert self.fused_experts == fused_experts - #block_m = MOE_DP_CHUNK_SIZE * (self.moe.ep_size // self.moe.dp_size) - experts: FusedMoEPermuteExpertsUnpermute = None if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)): - logger.info("BatchedTritonExperts %s", self.moe) + logger.debug("BatchedTritonExperts %s", self.moe) experts = BatchedTritonExperts( use_fp8_w8a8=False, use_int8_w8a8=False, @@ -269,7 +264,7 @@ def set_dispatch_combine( block_shape=None, ) else: - logger.info("TritonExperts %s", self.moe) + logger.debug("TritonExperts %s", self.moe) experts = TritonExperts( use_fp8_w8a8=False, use_int8_w8a8=False, @@ -611,8 +606,7 @@ def __init__( dp_rank=self.dp_rank, ep_size=self.ep_size, ep_rank=self.ep_rank, - in_dtype=params_dtype, # this is probably not right, where to get? - out_dtype=params_dtype, # ditto. + in_dtype=params_dtype, # TODO: is this right? ) # Note: get_quant_method will look at the layer's local_num_experts @@ -628,12 +622,42 @@ def __init__( assert quant_method is not None self.quant_method = quant_method - dispatch_combine: FusedMoEQuantizeDispatchCombine = None + dispatch_combine = self._construct_dispatch_combine( + moe, quant_config) + + success = self.quant_method.set_dispatch_combine(dispatch_combine) + + if not success: + logger.warning("DP+EP not supported for %s.", + type(self.quant_method)) + + self.apply_router_weight_on_input = apply_router_weight_on_input + moe_quant_params = { + "num_experts": self.local_num_experts, + "hidden_size": hidden_size, + "intermediate_size_per_partition": + self.intermediate_size_per_partition, + "params_dtype": params_dtype, + "weight_loader": self.weight_loader, + } + # need full intermediate size pre-sharding for WNA16 act order + if (self.quant_method.__class__.__name__ + in ("GPTQMarlinMoEMethod", + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod")): + moe_quant_params["intermediate_size_full"] = intermediate_size + + self.quant_method.create_weights(layer=self, **moe_quant_params) - # TODO: move to method? + # TODO: return Optional? + def _construct_dispatch_combine( + self, + moe: MoEConfig, + quant_config: Optional[QuantizationConfig], + ) -> FusedMoEQuantizeDispatchCombine: if self.dp_size > 1 and has_pplx: - logger.info("using pplx dispatch") - max_num_tokens = MOE_DP_CHUNK_SIZE # // moe.dp_size + logger.debug("using pplx dispatch") + max_num_tokens = MOE_DP_CHUNK_SIZE world_size = moe.ep_size dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. rank = moe.ep_rank @@ -654,51 +678,28 @@ def __init__( (moe.hidden_dim + moe.block_size - 1) // moe.block_size * torch.float32.itemsize))) - dispatch_combine = PplxDispatchCombine( + return PplxDispatchCombine( all_to_all, max_num_tokens, world_size, dp_size, - rank, # just for debugging + rank, moe.in_dtype, ) elif True: - logger.info("using standard dispatch") - dispatch_combine = StandardDispatchCombine( + logger.debug("using standard dispatch") + return StandardDispatchCombine( moe.in_dtype, quant_config.weight_block_size if quant_config is not None else None, ) else: - logger.info("using batched dispatch") - dispatch_combine = BatchedDispatchCombine( + logger.debug("using batched dispatch") + return BatchedDispatchCombine( moe.ep_size, moe.ep_rank, ) - success = self.quant_method.set_dispatch_combine(dispatch_combine) - if not success: - logger.warning("DP+EP not supported for %s.", - type(self.quant_method)) - - self.apply_router_weight_on_input = apply_router_weight_on_input - moe_quant_params = { - "num_experts": self.local_num_experts, - "hidden_size": hidden_size, - "intermediate_size_per_partition": - self.intermediate_size_per_partition, - "params_dtype": params_dtype, - "weight_loader": self.weight_loader, - } - # need full intermediate size pre-sharding for WNA16 act order - if (self.quant_method.__class__.__name__ - in ("GPTQMarlinMoEMethod", - "CompressedTensorsWNA16MarlinMoEMethod", - "CompressedTensorsWNA16MoEMethod")): - moe_quant_params["intermediate_size_full"] = intermediate_size - - self.quant_method.create_weights(layer=self, **moe_quant_params) - def _load_per_tensor_weight_scale(self, shard_id: str, param: torch.nn.Parameter, loaded_weight: torch.Tensor, @@ -1016,9 +1017,14 @@ def naive_multicast(self, x: torch.Tensor, return buffer + # TODO: will this be cudagraph-able? (probably not) + # This should not be necessary. + def invalid_pplx(self, hidden_states: torch.Tensor) -> bool: + return has_pplx and hidden_states.shape[0] < self.dp_size + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): - if self.use_direct_call: + if self.use_direct_call or self.invalid_pplx(hidden_states): return self.forward_impl(hidden_states, router_logits) else: return torch.ops.vllm.moe_forward(hidden_states, router_logits, diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 420a81f3f5c..4c00edd0b3d 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -28,6 +28,7 @@ def __init__(self, quant_dtype: Optional[torch.dtype] = None, block_shape: Optional[List[int]] = None): super().__init__() + assert max_num_tokens > 0 self.a2a = a2a self.block_shape = block_shape self.max_num_tokens = max_num_tokens @@ -47,13 +48,15 @@ def dispatch( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - # Is this always going to be a1.device? - device = a1.device + num_tokens = a1.shape[0] # M hidden_dim = a1.shape[-1] # K - # ?? + assert rank_topk_ids.shape[0] == num_tokens # assert expert_map is None, "NYI" + # Is this always going to be a1.device? + device = a1.device + if apply_router_weight_on_input: topk = rank_topk_ids.shape[1] # TODO: this only works for topK=1, will need to update for topK>1 @@ -102,7 +105,6 @@ def dispatch( ) # This argument is optional, defaults to indices.shape[0] - num_tokens = a1.shape[0] # M bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) # TODO: optimize this? @@ -133,7 +135,9 @@ def combine( dtype=torch.uint32, device=fused_expert_output.device) - assert output.shape[0] <= self.max_num_tokens + assert topk_ids.shape[0] <= num_tokens + assert output.shape[0] <= self.max_num_tokens, \ + f"{output.shape[0]} <= {self.max_num_tokens}" assert output.shape[1] == fused_expert_output.shape[-1] # Set weights to 1 if we did them in dispatch. This is hacky. diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index d24ae4768a6..5ddb0e66842 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -12,11 +12,11 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self, - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - per_channel_quant: bool, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_channel_quant: bool = False, block_shape: Optional[List[int]] = None, block_m: Optional[int] = None, allow_deep_gemm: bool = False): diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 9fabf4edadf..9156cb568f9 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -777,7 +777,6 @@ def process_weights_after_loading(self, layer: Module) -> None: del layer.w13_input_scale del layer.w2_input_scale - # Maybe extra args def set_dispatch_combine( self, dispatch_combine: mk.FusedMoEQuantizeDispatchCombine) -> bool: @@ -787,15 +786,9 @@ def set_dispatch_combine( if self.use_marlin: return False - #block_m = MOE_DP_CHUNK_SIZE * (moe.ep_size // moe.dp_size) - #print(f"block_m = {block_m}") - experts = TritonOrDeepGemmExperts( use_fp8_w8a8=True, - use_int8_w8a16=False, - use_int4_w4a16=False, block_shape=self.quant_config.weight_block_size, - block_m=None, # TODO allow_deep_gemm=self.allow_deep_gemm, ) From 9590b96831cad33a602213bfae3d101fc1d9e937 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 30 Apr 2025 21:32:39 +0000 Subject: [PATCH 165/190] cleanups + lint, layer.py wip Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 3 +-- vllm/model_executor/layers/fused_moe/layer.py | 21 +++++++++++-------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 6dd028894b3..111a5a30176 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -504,8 +504,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): chunk_by_rank(w2, rank, world_size), chunk_topk_weight, chunk_topk_ids, - global_num_experts=num_experts - ) + global_num_experts=num_experts) torch.cuda.synchronize() diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 84f3189c4d4..6dbac1aac59 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -622,8 +622,7 @@ def __init__( assert quant_method is not None self.quant_method = quant_method - dispatch_combine = self._construct_dispatch_combine( - moe, quant_config) + dispatch_combine = self._construct_dispatch_combine(moe, quant_config) success = self.quant_method.set_dispatch_combine(dispatch_combine) @@ -1030,13 +1029,12 @@ def forward(self, hidden_states: torch.Tensor, return torch.ops.vllm.moe_forward(hidden_states, router_logits, self.layer_name) - def forward_impl_chunked(self, full_hidden_states: torch.Tensor, full_router_logits: torch.Tensor): full_final_hidden_states = torch.empty_like(full_hidden_states) - def process_chunk(chunk_start, chunk_end, skip_result_store = False): + def process_chunk(chunk_start, chunk_end, skip_result_store=False): hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] @@ -1089,18 +1087,23 @@ def process_chunk(chunk_start, chunk_end, skip_result_store = False): full_final_hidden_states[chunk_start:chunk_end, :].copy_( final_hidden_states) - max_tokens_across_dp = get_forward_context().dp_metadata.max_tokens_across_dp + max_tokens_across_dp = get_forward_context( + ).dp_metadata.max_tokens_across_dp moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self.dp_size num_tokens = full_hidden_states.size(0) - for chunk_start_ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): - chunk_start = chunk_start_ - chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dp) + for chunk_start_ in range(0, max_tokens_across_dp, + moe_dp_chunk_size_per_rank): + chunk_start = chunk_start_ + chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, + max_tokens_across_dp) # clamp start and end chunk_start = min(chunk_start, num_tokens - 1) chunk_end = min(chunk_end, num_tokens) - process_chunk(chunk_start, chunk_end, skip_result_store = chunk_start_ >= num_tokens) + process_chunk(chunk_start, + chunk_end, + skip_result_store=chunk_start_ >= num_tokens) return full_final_hidden_states From ff40a9c758f5c75ce455473d3eff418387147555 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 30 Apr 2025 21:43:57 +0000 Subject: [PATCH 166/190] fix parallel_state lint Signed-off-by: Bill Nell --- vllm/distributed/parallel_state.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 6a3725b88c8..2cedaa06018 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -23,7 +23,7 @@ """ import contextlib import gc -import importlib +import importlib.util import pickle import weakref from collections import namedtuple @@ -925,7 +925,7 @@ def pplx_init(rank, world_size): nvshmem_get_unique_id, nvshmem_init) try: global PPLX_DID_INIT - logger.debug("PPLX_INIT %s %d", rank, world_size) + logger.info("PPLX_INIT rank=%d world=%d", rank, world_size) uid = nvshmem_get_unique_id( ) if rank == 0 else nvshmem_alloc_empty_unique_id() uid_gpu = uid.cuda() @@ -943,7 +943,7 @@ def pplx_finalize(): global PPLX_DID_INIT if PPLX_DID_INIT: from pplx_kernels.nvshmem import nvshmem_finalize - logger.debug("PPLX finalize") + logger.info("PPLX finalize") nvshmem_finalize() From c7cb7df433f297dfd2419883b5b308d109de87c5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 1 May 2025 02:48:00 +0000 Subject: [PATCH 167/190] fix M=1 pplx test Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 106 +++++++++--------- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 2 files changed, 56 insertions(+), 52 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 111a5a30176..b6c15b1a2bb 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -297,18 +297,24 @@ def chunk_by_rank(t, r, w): return t[(r * chunk):(r + 1) * chunk] -def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): +ata = None + +def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts): assert torch.cuda.current_device() == pgi.local_rank + topk = topk_ids.shape[1] + + #tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) + num_tokens, hidden_dim = a.shape - num_experts = w1.shape[0] block_size = 128 device = pgi.device rank = pgi.rank world_size = pgi.world_size - rank_num_tokens = rank_chunk(num_tokens, rank, world_size) - max_num_tokens = max(num_tokens, 1) + max_num_tokens = rank_chunk(num_tokens, 0, world_size) + print(f"MAX_NUM_TOKENS = {max_num_tokens}") + global ata ata = AllToAll.internode( max_num_tokens=max_num_tokens, num_experts=num_experts, @@ -333,9 +339,11 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): ) a_chunk = chunk_by_rank(a, rank, world_size).to(device) - score_chunk = chunk_by_rank(scores, rank, world_size).to(device) - chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, - False) + num_tokens = a_chunk.shape[0] + chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) + chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) + + print(f"{rank}: shapes {a_chunk.shape}, {chunk_topk_weight.shape}, {chunk_topk_ids.shape}, E={num_experts}") b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( a_chunk, @@ -343,11 +351,13 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): None, chunk_topk_weight, chunk_topk_ids, - num_experts, # store at PplxDispatchCombine creation? + num_experts, None, False, ) + #torch.cuda.synchronize() + if False: naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, num_experts) @@ -364,7 +374,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): b_a = b_a * 1.5 out = torch.full( - (rank_num_tokens, hidden_dim), + (max_num_tokens, hidden_dim), torch.nan, dtype=a.dtype, device=device, @@ -377,22 +387,21 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk): chunk_topk_ids, False, ) - torch.cuda.synchronize() - ata.destroy() + #torch.cuda.synchronize() + + #ata.destroy() - return out[:rank_num_tokens] + return out[:num_tokens] def _pplx_dispatch_combine( pgi: ProcessGroupInfo, dp_size: int, - m, - n, - k, - e, - topk: int, - dtype: torch.dtype, + a, + topk_weight, + topk_ids, + num_experts, ): uid = nvshmem_get_unique_id( ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() @@ -400,37 +409,34 @@ def _pplx_dispatch_combine( nvshmem_init(uid, pgi.rank, pgi.world_size) device = pgi.device - a = torch.randn((m, k), device=device, dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10 - score = torch.randn((m, e), device=device, dtype=dtype) - - topk_weight, topk_ids = fused_topk(a, score, topk, False) + k = a.shape[1] + topk = topk_ids.shape[1] - a_rep = torch.repeat_interleave(a, topk, dim=0) + a_rep = torch.repeat_interleave(a, topk, dim=0).to(device) torch_output = (a_rep.view(-1, topk, k) * 1.5 * - topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype) + topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(a.dtype) - pplx_output = torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, score, - topk) + pplx_output = pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts) torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) + print(f"{pgi.rank}: out shapes {pplx_output.shape}, {torch_output.shape}") + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) nvshmem_finalize() # TODO: M < world_size doesn't appear to be supported by pplx? -@pytest.mark.parametrize("m", [4, 32, 64, 222]) +@pytest.mark.parametrize("m", [1, 4, 32, 64, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [[4, 2]]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #[[4, 2]]) @requires_pplx def test_pplx_dispatch_combine( m: int, @@ -443,22 +449,27 @@ def test_pplx_dispatch_combine( ): current_platform.seed_everything(7) world_size, dp_size = world_dp_size + device = "cuda" + + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + score = torch.randn((m, e), device=device, dtype=dtype) + + topk_weight, topk_ids = fused_topk(a, score, topk, False) - parallel_launch(world_size, _pplx_dispatch_combine, dp_size, m, n, k, e, - topk, dtype) + parallel_launch(world_size, _pplx_dispatch_combine, dp_size, a, topk_weight, topk_ids, e) -def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): +def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): assert torch.cuda.current_device() == pgi.local_rank - num_tokens, hidden_dim = a.shape + hidden_dim = a.shape[1] num_experts = w1.shape[0] block_size = 128 device = pgi.device rank = pgi.rank world_size = pgi.world_size - rank_num_tokens = rank_chunk(num_tokens, rank, world_size) - max_num_tokens = num_tokens + topk = topk_ids.shape[1] + max_num_tokens = rank_chunk(a.shape[0], 0, world_size) ata = AllToAll.internode( max_num_tokens=max_num_tokens, @@ -474,9 +485,6 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): torch.float32.itemsize)), ) - w1 = w1.to(device) - w2 = w2.to(device) - dispatch_combine = PplxDispatchCombine( ata, max_num_tokens, @@ -493,15 +501,14 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): ) a_chunk = chunk_by_rank(a, rank, world_size).to(device) - score_chunk = chunk_by_rank(scores, rank, world_size).to(device) - chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, - False) + chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) + chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) out = fused_experts( a_chunk, # Chunking weights like this only works for batched format - chunk_by_rank(w1, rank, world_size), - chunk_by_rank(w2, rank, world_size), + chunk_by_rank(w1, rank, world_size).to(device), + chunk_by_rank(w2, rank, world_size).to(device), chunk_topk_weight, chunk_topk_ids, global_num_experts=num_experts) @@ -510,7 +517,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk): ata.destroy() - return out[:rank_num_tokens] + return out def _pplx_moe( @@ -521,7 +528,6 @@ def _pplx_moe( w2: torch.Tensor, score: torch.Tensor, topk: int, - dtype: torch.dtype, ): uid = nvshmem_get_unique_id( ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() @@ -534,7 +540,7 @@ def _pplx_moe( with set_current_vllm_config(vllm_config): topk_weight, topk_ids = fused_topk(a, score, topk, False) torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - pplx_output = torch_pplx_moe(pgi, dp_size, a, w1, w2, score, topk) + pplx_output = pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids) torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) @@ -544,8 +550,7 @@ def _pplx_moe( nvshmem_finalize() -# TODO: M < world_size doesn't appear to be supported by pplx? -@pytest.mark.parametrize("m", [2, 3, 32, 45, 64, 222]) +@pytest.mark.parametrize("m", [1, 2, 3, 32, 45, 64, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @@ -569,5 +574,4 @@ def test_pplx_moe( w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 score = torch.randn((m, e), device="cuda", dtype=dtype) - parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, - dtype) + parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 6dbac1aac59..a5c3c7a873d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -251,7 +251,7 @@ def set_dispatch_combine( self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: assert self.fused_experts == fused_experts - experts: FusedMoEPermuteExpertsUnpermute = None + experts: Optional[FusedMoEPermuteExpertsUnpermute] = None if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)): From 4f4584ae229906a06f49b9576ba2cc01551b47ce Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 1 May 2025 04:04:24 +0000 Subject: [PATCH 168/190] fix M=1 pplx test Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 68 +++++++++--------------------- 1 file changed, 19 insertions(+), 49 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index b6c15b1a2bb..26021d20193 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -28,7 +28,7 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedExperts) + BatchedExperts, BatchedTritonExperts) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) @@ -293,34 +293,26 @@ def rank_chunk(num, r, w): def chunk_by_rank(t, r, w): chunk = rank_chunk(t.shape[0], r, w) - #print(f"chunk {t.shape}, {w}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}") return t[(r * chunk):(r + 1) * chunk] -ata = None - def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts): assert torch.cuda.current_device() == pgi.local_rank topk = topk_ids.shape[1] - - #tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) - - num_tokens, hidden_dim = a.shape + num_tokens, hidden_dim = a.shape[1] block_size = 128 device = pgi.device rank = pgi.rank world_size = pgi.world_size max_num_tokens = rank_chunk(num_tokens, 0, world_size) - print(f"MAX_NUM_TOKENS = {max_num_tokens}") - global ata ata = AllToAll.internode( max_num_tokens=max_num_tokens, num_experts=num_experts, experts_per_token=topk, rank=rank, - world_size=pgi.world_size, + world_size=world_size, dp_size=dp_size, hidden_dim=hidden_dim, hidden_dim_bytes=hidden_dim * a.dtype.itemsize, @@ -332,19 +324,15 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts): dispatch_combine = PplxDispatchCombine( ata, max_num_tokens, - pgi.world_size, + world_size, dp_size, rank, - a.dtype, ) a_chunk = chunk_by_rank(a, rank, world_size).to(device) - num_tokens = a_chunk.shape[0] chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) - print(f"{rank}: shapes {a_chunk.shape}, {chunk_topk_weight.shape}, {chunk_topk_ids.shape}, E={num_experts}") - b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch( a_chunk, None, @@ -356,21 +344,6 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts): False, ) - #torch.cuda.synchronize() - - if False: - naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, - num_experts) - - torch.distributed.all_reduce(tokens_per_expert) - tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, - world_size).to(dtype=torch.int32) - - torch.testing.assert_close(tokens_per_expert, - expert_num_tokens, - atol=0, - rtol=0) - b_a = b_a * 1.5 out = torch.full( @@ -388,9 +361,11 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts): False, ) - #torch.cuda.synchronize() + torch.cuda.synchronize() - #ata.destroy() + ata.destroy() + + num_tokens = a_chunk.shape[0] return out[:num_tokens] @@ -399,8 +374,8 @@ def _pplx_dispatch_combine( pgi: ProcessGroupInfo, dp_size: int, a, - topk_weight, - topk_ids, + score, + topk, num_experts, ): uid = nvshmem_get_unique_id( @@ -409,8 +384,8 @@ def _pplx_dispatch_combine( nvshmem_init(uid, pgi.rank, pgi.world_size) device = pgi.device + topk_weight, topk_ids = fused_topk(a, score, topk, False) k = a.shape[1] - topk = topk_ids.shape[1] a_rep = torch.repeat_interleave(a, topk, dim=0).to(device) @@ -422,21 +397,19 @@ def _pplx_dispatch_combine( torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) - print(f"{pgi.rank}: out shapes {pplx_output.shape}, {torch_output.shape}") - torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) nvshmem_finalize() -# TODO: M < world_size doesn't appear to be supported by pplx? -@pytest.mark.parametrize("m", [1, 4, 32, 64, 222]) +# TODO: this test point does not work for M == 1 +@pytest.mark.parametrize("m", [4, 32, 64, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #[[4, 2]]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) @requires_pplx def test_pplx_dispatch_combine( m: int, @@ -450,13 +423,10 @@ def test_pplx_dispatch_combine( current_platform.seed_everything(7) world_size, dp_size = world_dp_size device = "cuda" - a = torch.randn((m, k), device=device, dtype=dtype) / 10 score = torch.randn((m, e), device=device, dtype=dtype) - topk_weight, topk_ids = fused_topk(a, score, topk, False) - - parallel_launch(world_size, _pplx_dispatch_combine, dp_size, a, topk_weight, topk_ids, e) + parallel_launch(world_size, _pplx_dispatch_combine, dp_size, a, score, topk, e) def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): @@ -476,7 +446,7 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): num_experts=num_experts, experts_per_token=topk, rank=rank, - world_size=pgi.world_size, + world_size=world_size, dp_size=dp_size, hidden_dim=hidden_dim, hidden_dim_bytes=hidden_dim * a.dtype.itemsize, @@ -488,12 +458,12 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): dispatch_combine = PplxDispatchCombine( ata, max_num_tokens, - pgi.world_size, + world_size, dp_size, rank, ) - experts = BatchedExperts(max_num_tokens) + experts = BatchedExperts(a.shape[0]) fused_experts = FusedMoEModularKernel( dispatch_combine, @@ -556,7 +526,7 @@ def _pplx_moe( @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) @requires_pplx def test_pplx_moe( m: int, From 73226b9c7fcf4d657ca8be0e64db64df1bb78ea7 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 1 May 2025 04:04:45 +0000 Subject: [PATCH 169/190] fix M=1 pplx test Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 26021d20193..d7916b31d3c 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -300,7 +300,7 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts): assert torch.cuda.current_device() == pgi.local_rank topk = topk_ids.shape[1] - num_tokens, hidden_dim = a.shape[1] + num_tokens, hidden_dim = a.shape block_size = 128 device = pgi.device rank = pgi.rank From a77fb2c14ed6a7e7a069dcba86439e974d651d60 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 1 May 2025 12:47:50 +0000 Subject: [PATCH 170/190] lint Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index d7916b31d3c..5dd52ed3564 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -28,7 +28,7 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedExperts, BatchedTritonExperts) + BatchedExperts) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) @@ -390,9 +390,11 @@ def _pplx_dispatch_combine( a_rep = torch.repeat_interleave(a, topk, dim=0).to(device) torch_output = (a_rep.view(-1, topk, k) * 1.5 * - topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(a.dtype) + topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to( + a.dtype) - pplx_output = pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts) + pplx_output = pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, + num_experts) torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) @@ -426,7 +428,8 @@ def test_pplx_dispatch_combine( a = torch.randn((m, k), device=device, dtype=dtype) / 10 score = torch.randn((m, e), device=device, dtype=dtype) - parallel_launch(world_size, _pplx_dispatch_combine, dp_size, a, score, topk, e) + parallel_launch(world_size, _pplx_dispatch_combine, dp_size, a, score, + topk, e) def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): From 48ba146818d2df245f249308fd998887d171a9bd Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 1 May 2025 14:44:34 +0000 Subject: [PATCH 171/190] remove valid pplx check Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a5c3c7a873d..4b0be6f61d0 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1016,14 +1016,9 @@ def naive_multicast(self, x: torch.Tensor, return buffer - # TODO: will this be cudagraph-able? (probably not) - # This should not be necessary. - def invalid_pplx(self, hidden_states: torch.Tensor) -> bool: - return has_pplx and hidden_states.shape[0] < self.dp_size - def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): - if self.use_direct_call or self.invalid_pplx(hidden_states): + if self.use_direct_call: return self.forward_impl(hidden_states, router_logits) else: return torch.ops.vllm.moe_forward(hidden_states, router_logits, From 4c40380639d80c104a31504d3b02eebc68353213 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 2 May 2025 00:43:14 +0000 Subject: [PATCH 172/190] semi-working cudagraphs Signed-off-by: Bill Nell --- csrc/dispatch_utils.h | 13 ++++ csrc/moe/moe_align_sum_kernels.cu | 8 +-- csrc/moe/topk_softmax_kernels.cu | 63 +++++++++++++------ examples/offline_inference/data_parallel.py | 16 +++-- pyproject.toml | 4 +- vllm/compilation/compiler_interface.py | 4 +- vllm/distributed/utils.py | 10 ++- .../layers/fused_moe/fused_batched_moe.py | 22 +++---- .../layers/fused_moe/fused_moe.py | 3 +- vllm/model_executor/layers/fused_moe/layer.py | 40 +++++++----- .../layers/fused_moe/pplx_dispatch_combine.py | 16 ++--- vllm/model_executor/layers/fused_moe/utils.py | 2 +- vllm/platforms/cuda.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 1 + 14 files changed, 135 insertions(+), 69 deletions(-) diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index dc6e0769b87..10a183dc950 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -65,5 +65,18 @@ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) +#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__) + #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__)) diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index d7be769458e..6b6a9d04a60 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, } if (use_global_memory) { - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { // calc needed amount of shared mem for `tokens_cnts` and `cumsum` // tensors @@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, cumsum_buffer.data_ptr()); }); } else if (use_i16) { - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { // set dynamic shared mem auto kernel = @@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, topk_ids.numel()); }); } else { - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { auto kernel = vllm::moe::moe_align_block_size_kernel; @@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, TORCH_CHECK(num_experts == 256, "sgl_moe_align_block_size kernel only supports deepseek v3."); - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] { // calc needed amount of shared mem for `cumsum` tensors auto options_int = diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index de9747b6025..a9379032245 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__ } } -template -__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output, - int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert) +template +__launch_bounds__(TPB) __global__ void moeTopK( + const float* inputs_after_softmax, + const bool* finished, + float* output, + IndType* indices, + int* source_rows, + const int num_experts, + const int k, + const int start_expert, + const int end_expert) { using cub_kvp = cub::KeyValuePair; @@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax 2) This implementation assumes k is small, but will work for any k. */ -template +template __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ - void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices, + void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices, int* source_rows, const int k, const int start_expert, const int end_expert) { // We begin by enforcing compile time assertions and setting up compile time constants. @@ -397,8 +405,8 @@ struct TopkConstants }; } // namespace detail -template -void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices, +template +void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices, int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream) { static constexpr std::size_t MAX_BYTES_PER_LDG = 16; @@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f token_expert_indices, num_tokens, topk, 0, num_experts, \ stream); +template void topkGatingSoftmaxKernelLauncher( const float* gating_output, float* topk_weights, - int* topk_indicies, + IndType* topk_indicies, int* token_expert_indices, float* softmax_workspace, const int num_tokens, @@ -493,14 +502,32 @@ void topk_softmax( const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options()); - vllm::moe::topkGatingSoftmaxKernelLauncher( - gating_output.data_ptr(), - topk_weights.data_ptr(), - topk_indices.data_ptr(), - token_expert_indices.data_ptr(), - softmax_workspace.data_ptr(), - num_tokens, - num_experts, - topk, - stream); + + if(topk_indices.scalar_type() == at::ScalarType::Int) + { + vllm::moe::topkGatingSoftmaxKernelLauncher( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + stream); + } + else + { + assert(topk_indices.scalar_type() == at::ScalarType::UInt32); + vllm::moe::topkGatingSoftmaxKernelLauncher( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + stream); + } } diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 965915beaf5..9364924b305 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -31,6 +31,7 @@ from time import sleep from vllm import LLM, SamplingParams +from vllm.config import CompilationConfig from vllm.utils import get_open_port @@ -65,11 +66,14 @@ def parse_args(): type=int, default=0, help="Master node port") + parser.add_argument("--enforce-eager", + action='store_true', + help="Enforce eager mode execution.") return parser.parse_args() def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, - dp_master_port, GPUs_per_dp_rank): + dp_master_port, GPUs_per_dp_rank, enforce_eager): os.environ["VLLM_DP_RANK"] = str(global_dp_rank) os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) os.environ["VLLM_DP_SIZE"] = str(dp_size) @@ -109,10 +113,14 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, max_tokens=[16, 20][global_dp_rank % 2]) # Create an LLM. + cconfig = CompilationConfig( + level=0, + ) llm = LLM(model=model, tensor_parallel_size=GPUs_per_dp_rank, - enforce_eager=True, - enable_expert_parallel=True) + enforce_eager=enforce_eager, + enable_expert_parallel=True, + compilation_config=cconfig) outputs = llm.generate(prompts, sampling_params) # Print the outputs. for i, output in enumerate(outputs): @@ -155,7 +163,7 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, proc = Process(target=main, args=(args.model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, dp_master_port, - tp_size)) + tp_size, args.enforce_eager)) proc.start() procs.append(proc) exit_code = 0 diff --git a/pyproject.toml b/pyproject.toml index 069e295bfb9..4fa012145cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,8 +15,8 @@ build-backend = "setuptools.build_meta" [project] name = "vllm" authors = [{name = "vLLM Team"}] -license = "Apache-2.0" -license-files = ["LICENSE"] +#license = "Apache-2.0" +#license-files = ["LICENSE"] readme = "README.md" description = "A high-throughput and memory-efficient inference and serving engine for LLMs" classifiers = [ diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index b7e7a79bef0..0cb4a2d7c5f 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -326,9 +326,9 @@ def _get_shape_env() -> AlwaysHitShapeEnv: # compilation cache. if not envs.VLLM_DISABLE_COMPILE_CACHE: assert hash_str is not None, ( - "failed to get the hash of the compiled graph") + f"failed to get the hash of the compiled graph: {file_path}") assert file_path is not None, ( - "failed to get the file path of the compiled graph") + "failed to get the file path of the compiled graph: {file_path}") return compiled_graph, (hash_str, file_path) def load(self, diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index e4d4008cd0a..6bb39672a32 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -360,7 +360,11 @@ def stateless_destroy_torch_distributed_process_group( Destroy ProcessGroup returned by stateless_init_torch_distributed_process_group(). """ - # Lazy import for non-CUDA backends. - from torch.distributed.distributed_c10d import _shutdown_backend - _shutdown_backend(pg) + # TODO: pytorch < 2.7? + if False: + # Lazy import for non-CUDA backends. + from torch.distributed.distributed_c10d import _shutdown_backend + _shutdown_backend(pg) + else: + pg.shutdown() _unregister_process_group(pg.group_name) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index be700f7b2e9..f2d7ab0e843 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -577,11 +577,11 @@ def workspace_shapes( topk: int, num_experts: int, ) -> Tuple[int, int, torch.dtype]: + assert a.dim() == 2 max_num_tokens = a.shape[ - 1] if self.max_num_tokens is None else self.max_num_tokens - # TODO: *2 is a hack - workspace13 = num_experts * max_num_tokens * K * topk * 2 - workspace2 = max_num_tokens * N + 0] if self.max_num_tokens is None else self.max_num_tokens + workspace13 = num_experts * max_num_tokens * max(K, N) + workspace2 = max_num_tokens * (N // 2) return (workspace13, workspace2, a.dtype) def apply( @@ -605,6 +605,7 @@ def apply( ) -> torch.Tensor: assert hidden_states.dim() == 3 assert expert_num_tokens is not None + hidden_dim = hidden_states.shape[-1] if self.max_num_tokens is None: max_num_tokens = hidden_states.shape[1] @@ -613,13 +614,13 @@ def apply( num_experts = global_num_experts out = _resize_cache(workspace13, - (num_experts, max_num_tokens, w2.shape[1])) + (num_experts, max_num_tokens, hidden_dim)) num_local_experts = expert_num_tokens.numel() for expert in range(num_local_experts): num = expert_num_tokens[expert] - assert num <= max_num_tokens, f"{num}, {max_num_tokens}" - if num > 0: + #assert num <= max_num_tokens, f"{num}, {max_num_tokens}" + if True or num > 0: # CUDAGRAPH unfriendly? tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) self.activation( activation, tmp, @@ -660,8 +661,9 @@ def workspace_shapes( topk: int, num_experts: int, ) -> Tuple[int, int, torch.dtype]: + assert a.dim() == 2 max_num_tokens = a.shape[ - 1] if self.max_num_tokens is None else self.max_num_tokens + 0] if self.max_num_tokens is None else self.max_num_tokens workspace13 = num_experts * max_num_tokens * max(K, N) workspace2 = num_experts * max_num_tokens * (N // 2) return (workspace13, workspace2, a.dtype) @@ -685,9 +687,6 @@ def apply( workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], ) -> torch.Tensor: - - num_tokens = topk_ids.size(0) - # Check constraints. if self.use_int4_w4a16: assert hidden_states.shape[-1] // 2 == w1.shape[ @@ -705,6 +704,7 @@ def apply( torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn ] + # TODO: num_tokens -> max_num_tokens? E, num_tokens, N, K, top_k_num = mk._moe_problem_size( hidden_states, w1, w2, topk_ids) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index ac834a89195..f461d8c60cc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -870,6 +870,7 @@ def fused_topk( gating_output: torch.Tensor, topk: int, renormalize: bool, + indices_type: torch.dtype = torch.int32, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert hidden_states.shape[0] == gating_output.shape[0], ( "Number of tokens mismatch") @@ -882,7 +883,7 @@ def fused_topk( device=hidden_states.device) topk_ids = torch.empty(M, topk, - dtype=torch.int32, + dtype=indices_type, device=hidden_states.device) token_expert_indicies = torch.empty(M, topk, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 4b0be6f61d0..a078713f051 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -31,7 +31,10 @@ if current_platform.is_cuda_alike(): from .dispatch_combine import StandardDispatchCombine - from .fused_batched_moe import BatchedDispatchCombine, BatchedTritonExperts + from .fused_batched_moe import ( + BatchedDispatchCombine, + BatchedTritonExperts, + BatchedExperts) from .fused_moe import TritonExperts, fused_experts from .modular_kernel import (FusedMoEModularKernel, FusedMoEPermuteExpertsUnpermute, @@ -257,6 +260,7 @@ def set_dispatch_combine( (BatchedDispatchCombine, PplxDispatchCombine)): logger.debug("BatchedTritonExperts %s", self.moe) experts = BatchedTritonExperts( + max_num_tokens=MOE_DP_CHUNK_SIZE, use_fp8_w8a8=False, use_int8_w8a8=False, use_int8_w8a16=False, @@ -624,11 +628,11 @@ def __init__( dispatch_combine = self._construct_dispatch_combine(moe, quant_config) - success = self.quant_method.set_dispatch_combine(dispatch_combine) - - if not success: - logger.warning("DP+EP not supported for %s.", - type(self.quant_method)) + if dispatch_combine is not None: + success = self.quant_method.set_dispatch_combine(dispatch_combine) + if not success: + logger.warning("DP+EP not supported for %s.", + type(self.quant_method)) self.apply_router_weight_on_input = apply_router_weight_on_input moe_quant_params = { @@ -653,7 +657,7 @@ def _construct_dispatch_combine( self, moe: MoEConfig, quant_config: Optional[QuantizationConfig], - ) -> FusedMoEQuantizeDispatchCombine: + ) -> Optional[FusedMoEQuantizeDispatchCombine]: if self.dp_size > 1 and has_pplx: logger.debug("using pplx dispatch") max_num_tokens = MOE_DP_CHUNK_SIZE @@ -685,7 +689,9 @@ def _construct_dispatch_combine( rank, moe.in_dtype, ) - elif True: + elif False: + return None + elif False: logger.debug("using standard dispatch") return StandardDispatchCombine( moe.in_dtype, @@ -694,9 +700,11 @@ def _construct_dispatch_combine( ) else: logger.debug("using batched dispatch") + dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. + rank = moe.ep_rank return BatchedDispatchCombine( - moe.ep_size, - moe.ep_rank, + dp_size, + rank, ) def _load_per_tensor_weight_scale(self, shard_id: str, @@ -984,11 +992,13 @@ def select_experts(hidden_states: torch.Tensor, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) elif custom_routing_function is None: - topk_weights, topk_ids, token_expert_indices = fused_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize) + topk_weights, topk_ids, token_expert_indices = fused_topk(hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + # XXXXX how to do this? + indices_type=torch.uint32, + ) else: topk_weights, topk_ids = custom_routing_function( hidden_states=hidden_states, diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 4c00edd0b3d..d605d4d7bc2 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -105,10 +105,11 @@ def dispatch( ) # This argument is optional, defaults to indices.shape[0] - bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) + #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) + bound_m = None # TODO: optimize this? - indices = rank_topk_ids.to(dtype=torch.uint32) + #indices = rank_topk_ids.to(dtype=torch.uint32) self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, @@ -116,7 +117,7 @@ def dispatch( out_expert_x_scale=expert_x_scale, dp_x=a1q, dp_x_scale=a1q_scale, - indices=indices, + indices=rank_topk_ids, bound_m=bound_m, ) return expert_x, expert_x_scale, expert_num_tokens @@ -131,9 +132,10 @@ def combine( ) -> None: # This argument is optional num_tokens = output.shape[0] # M - bound_m = torch.tensor([num_tokens], - dtype=torch.uint32, - device=fused_expert_output.device) + #bound_m = torch.tensor([num_tokens], + # dtype=torch.uint32, + # device=fused_expert_output.device) + bound_m = None assert topk_ids.shape[0] <= num_tokens assert output.shape[0] <= self.max_num_tokens, \ @@ -145,7 +147,7 @@ def combine( topk_weights = torch.ones_like(topk_weights) self.a2a.combine(out_tokens=output, - indices=topk_ids.to(torch.uint32), + indices=topk_ids, #.to(torch.uint32), weights=topk_weights, expert_y=fused_expert_output, bound_m=bound_m) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index d53da1d7926..aab23abe136 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -17,7 +17,7 @@ def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor: Shrink the given tensor and apply the given view to it. This is used to resize the intermediate fused_moe caches. """ - assert prod(v) <= x.numel(), f"{prod(v)} <= {x.numel()}" + #assert prod(v) <= x.numel(), f"{prod(v)} <= {x.numel()}" # CUDAGRAPH unfriendly? return x.flatten()[:prod(v)].view(*v) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index ab03dece8c1..31a28f5064d 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -157,7 +157,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: logger.info( "Forcing kv cache block size to 64 for FlashMLA backend.") - if (parallel_config.data_parallel_size > 1 + if (False and parallel_config.data_parallel_size > 1 and compilation_config.use_cudagraph): logger.info( "Data Parallel: Forcing enforce eager to be True since DP is " diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e0c3d05c797..df1176212d0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1542,6 +1542,7 @@ def _dummy_run( self.drafter.dummy_run(num_tokens) logit_indices = np.cumsum(num_scheduled_tokens) - 1 + #logit_indices = torch.from_numpy(logit_indices).to(hidden_states.device) return hidden_states[logit_indices] @torch.inference_mode() From 1938bc8f6c769383a24979624e251d239d7f0733 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 2 May 2025 21:42:28 +0000 Subject: [PATCH 173/190] fix reference implementations Signed-off-by: Bill Nell --- examples/offline_inference/data_parallel.py | 5 +- tests/kernels/moe/test_batched_moe.py | 6 +- tests/kernels/moe/test_pplx_moe.py | 38 ++- .../layers/fused_moe/fused_batched_moe.py | 262 +++++++++++++++--- vllm/model_executor/layers/fused_moe/layer.py | 29 +- .../layers/fused_moe/modular_kernel.py | 41 +-- vllm/model_executor/layers/fused_moe/utils.py | 2 +- 7 files changed, 303 insertions(+), 80 deletions(-) diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 9364924b305..c813b22c4e8 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -115,12 +115,15 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, # Create an LLM. cconfig = CompilationConfig( level=0, + #cudagraph_capture_sizes=[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208], + #cudagraph_capture_sizes=[512,256,1], ) llm = LLM(model=model, tensor_parallel_size=GPUs_per_dp_rank, enforce_eager=enforce_eager, enable_expert_parallel=True, - compilation_config=cconfig) + compilation_config=cconfig, + ) outputs = llm.generate(prompts, sampling_params) # Print the outputs. for i, output in enumerate(outputs): diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 1bb8f4e09dd..39b5d5c6793 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -62,9 +62,9 @@ def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, @pytest.mark.parametrize("num_experts", [16, 32]) -@pytest.mark.parametrize("max_tokens_per_expert", [512]) -@pytest.mark.parametrize("K", [256]) -@pytest.mark.parametrize("N", [512]) +@pytest.mark.parametrize("max_tokens_per_expert", [32, 64, 128, 192, 224, 256, 512]) +@pytest.mark.parametrize("K", [128, 256, 1024]) +@pytest.mark.parametrize("N", [128, 256, 512, 1024]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, N: int, dtype: torch.dtype): diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 5dd52ed3564..4886e2879ef 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -28,6 +28,7 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedDispatchCombine, BatchedExperts) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.modular_kernel import ( @@ -170,7 +171,7 @@ def torch_dispatch( assert topk_ids.dim() == 2 assert topk_ids.shape[0] == a.shape[0] - num_tokens = a.shape[0] + num_tokens, hidden_dim = a.shape topk = topk_ids.shape[1] tokens_per_expert = torch.bincount(topk_ids.view(-1), @@ -181,7 +182,7 @@ def torch_dispatch( if max_num_tokens is None: max_num_tokens = int(tokens_per_expert.max().item()) - b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]), + b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim), dtype=a.dtype, device=a.device) @@ -198,7 +199,7 @@ def torch_dispatch( def torch_combine(b_out, topk_weight, topk_ids): - num_tokens, topk = topk_ids.shape + num_tokens = topk_ids.shape[0] num_experts = b_out.shape[0] K = b_out.shape[-1] out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) @@ -240,6 +241,22 @@ def torch_batched_moe(a, w1, w2, topk_weight, topk_ids): return torch_combine(out, topk_weight, topk_ids) +def batched_moe(a, w1, w2, topk_weight, topk_ids): + num_experts = w1.shape[0] + + fused_experts = FusedMoEModularKernel( + BatchedDispatchCombine(a.shape[0], world_size=1, dp_size=1, rank=0), + BatchedExperts(a.shape[0]) + ) + + return fused_experts(a, + w1, + w2, + topk_weight, + topk_ids, + num_experts) + + # TODO: same as torch_moe but with fused_topk factored out. def torch_moe2(a, w1, w2, topk_weight, topk_ids): M, K = a.shape @@ -262,7 +279,7 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids): @pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) def test_fused_moe_batched_experts( m: int, n: int, @@ -280,10 +297,13 @@ def test_fused_moe_batched_experts( with set_current_vllm_config(vllm_config): topk_weight, topk_ids = fused_topk(a, score, topk, False) - torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - triton_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) + baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) + batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids) - torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + torch.testing.assert_close(baseline_output, torch_output, atol=2e-2, rtol=0) + torch.set_printoptions(profile="full") + torch.testing.assert_close(baseline_output, batched_output, atol=2e-2, rtol=0) def rank_chunk(num, r, w): @@ -473,6 +493,8 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): experts, ) + # TODO: workers with the same dp_rank must use the exact same inputs. + a_chunk = chunk_by_rank(a, rank, world_size).to(device) chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) @@ -528,7 +550,7 @@ def _pplx_moe( @pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @requires_pplx def test_pplx_moe( diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index f2d7ab0e843..4159bbf0591 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( get_config_dtype_str, try_get_optimal_moe_config) from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.utils import direct_register_custom_op @triton.jit @@ -467,10 +468,12 @@ def invoke_batched_silu_and_mul( class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - def __init__(self, world_size: int, rank: int): + def __init__(self, max_num_tokens: Optional[int], world_size: int, dp_size: int, rank: int): super().__init__() self.world_size = world_size + self.dp_size = dp_size self.rank = rank + self.max_num_tokens = max_num_tokens def dispatch( self, @@ -493,26 +496,29 @@ def dispatch( "apply_router_weight_on_input is only implemented for topk=1" a1.mul_(topk_weights.to(a1.dtype)) - num_tokens = a1.shape[0] + num_tokens, hidden_dim = a1.shape topk = topk_ids.shape[1] tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) - max_num_tokens = tokens_per_expert.max() - expert_counts = torch.zeros(num_experts, - dtype=torch.int, - device=a1.device) - b_a1 = torch.zeros((num_experts, max_num_tokens, a1.shape[1]), + if self.max_num_tokens is None: + self.max_num_tokens = int(tokens_per_expert.max().item()) + + b_a1 = torch.zeros((num_experts, self.max_num_tokens, hidden_dim), dtype=a1.dtype, device=a1.device) + token_counts = torch.zeros(num_experts, + dtype=torch.int, + device=a1.device) + for token in range(num_tokens): for j in range(topk): expert_id = topk_ids[token, j] - idx = expert_counts[expert_id] + idx = token_counts[expert_id] b_a1[expert_id, idx:idx + 1, :] = a1[token, :] - expert_counts[expert_id] = expert_counts[expert_id] + 1 + token_counts[expert_id] = token_counts[expert_id] + 1 return b_a1, a1_scale, tokens_per_expert @@ -526,25 +532,26 @@ def combine( ) -> None: num_tokens = topk_ids.shape[0] num_experts = fused_expert_output.shape[0] - expert_counts = torch.zeros(num_experts, - dtype=torch.int, - device=fused_expert_output.device) + K = fused_expert_output.shape[-1] + assert output.shape[0] == num_tokens and output.shape[1] == K + expert_counts = torch.zeros( + num_experts, + dtype=torch.int, + device=fused_expert_output.device) + + output.fill_(0) + for token in range(num_tokens): expert_ids = topk_ids[token] - for i in range(topk_ids.shape[1]): + for i in range(expert_ids.numel()): expert_id = expert_ids[i] - if expert_id < num_experts: - idx = expert_counts[expert_id] - if apply_router_weight_on_input: - output[token, :] = output[ - token, :] + fused_expert_output[expert_id, - idx:idx + 1, :] - else: - output[ - token, :] = output[token, :] + fused_expert_output[ - expert_id, - idx:idx + 1, :] * topk_weights[token, i] - expert_counts[expert_id] = expert_counts[expert_id] + 1 + assert expert_id < num_experts + idx = expert_counts[expert_id] + accum = fused_expert_output[expert_id, idx:idx + 1, :] + if not apply_router_weight_on_input: + accum = accum * topk_weights[token, i] + output[token, :] = output[token, :] + accum + expert_counts[expert_id] = expert_counts[expert_id] + 1 class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -580,8 +587,8 @@ def workspace_shapes( assert a.dim() == 2 max_num_tokens = a.shape[ 0] if self.max_num_tokens is None else self.max_num_tokens - workspace13 = num_experts * max_num_tokens * max(K, N) - workspace2 = max_num_tokens * (N // 2) + workspace13 = num_experts * max_num_tokens * K + workspace2 = max_num_tokens * N return (workspace13, workspace2, a.dtype) def apply( @@ -616,21 +623,183 @@ def apply( out = _resize_cache(workspace13, (num_experts, max_num_tokens, hidden_dim)) num_local_experts = expert_num_tokens.numel() + assert num_local_experts == w1.shape[0] + + N = w1.shape[1] // 2 for expert in range(num_local_experts): - num = expert_num_tokens[expert] - #assert num <= max_num_tokens, f"{num}, {max_num_tokens}" - if True or num > 0: # CUDAGRAPH unfriendly? - tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2)) - self.activation( - activation, tmp, - hidden_states[expert, :num, :] @ w1[expert].transpose( - 0, 1)) + num = expert_num_tokens[expert].item() + assert num <= max_num_tokens, f"{num} <= {max_num_tokens}" + if num > 0: # CUDAGRAPH unfriendly + tmp = _resize_cache(workspace2, (num, N)) + input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1) + assert input.shape[1] == N * 2 + self.activation(activation, tmp, input) out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) return out +def _apply( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]], +) -> torch.Tensor: + # Check constraints. + if use_int4_w4a16: + assert hidden_states.shape[-1] // 2 == w1.shape[ + 2], "Hidden size mismatch" + else: + assert hidden_states.shape[-1] == w1.shape[2], \ + (f"Hidden size mismatch {hidden_states.shape[-1]} " + f"!= {w1.shape[2]}") + + assert hidden_states.is_contiguous( + ), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn + ] + + # TODO: num_tokens -> max_num_tokens? + E, num_tokens, N, K, top_k_num = mk._moe_problem_size( + hidden_states, w1, w2, topk_ids) + + assert w1.shape[0] == E + assert w2.shape[0] == E + + config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + dtype=hidden_states.dtype) + + config = try_get_optimal_moe_config( + w1.shape, + w2.shape, + top_k_num, + config_dtype, + num_tokens, + block_shape=block_shape, + ) + + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + elif hidden_states.dtype == torch.float8_e4m3fn: + compute_type = tl.bfloat16 + else: + raise ValueError( + f"Unsupported compute_type: {hidden_states.dtype}") + + #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}") + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N)) + intermediate_cache2 = _resize_cache(workspace2, + (E, num_tokens, N // 2)) + intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K)) + + # MM1 + invoke_moe_batched_triton_kernel(A=hidden_states, + B=w1, + C=intermediate_cache1, + expert_num_tokens=expert_num_tokens, + compute_type=compute_type, + A_scale=a1q_scale, + B_scale=w1_scale, + B_zp=w1_zp, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + config=config, + block_shape=block_shape) + + # Fix activations + assert activation == "silu" + invoke_batched_silu_and_mul(output=intermediate_cache2, + input=intermediate_cache1, + expert_num_tokens=expert_num_tokens) + + #qintermediate_cache2 = intermediate_cache2 + a2q_scale = a2_scale + # TODO (varun) : support w8a8 + assert not use_fp8_w8a8 + #if self.use_fp8_w8a8: + # qintermediate_cache2, a2q_scale = _fp8_quantize( + # intermediate_cache2, a2_scale, self.block_shape) + + invoke_moe_batched_triton_kernel(A=intermediate_cache2, + B=w2, + C=intermediate_cache3, + expert_num_tokens=expert_num_tokens, + compute_type=compute_type, + A_scale=a2q_scale, + B_scale=w2_scale, + B_zp=w2_zp, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + config=config, + block_shape=block_shape) + + return intermediate_cache3 + + +def _apply_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + block_shape: Optional[List[int]], +) -> torch.Tensor: + return torch.empty_like(hidden_states) + + +direct_register_custom_op( + op_name="_apply", + op_func=_apply, + mutates_args=[], + fake_impl=_apply_fake, + tags=(torch.Tag.needs_fixed_stride_order, ), +) + + class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( @@ -687,6 +856,29 @@ def apply( workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], ) -> torch.Tensor: + return torch.ops.vllm._apply( + hidden_states, + w1, + w2, + topk_ids, + activation, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1q_scale, + a2_scale, + workspace13, + workspace2, + expert_num_tokens, + self.use_fp8_w8a8, + self.use_int8_w8a16, + self.use_int4_w4a16, + self.block_shape, + ) + # Check constraints. if self.use_int4_w4a16: assert hidden_states.shape[-1] // 2 == w1.shape[ diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a078713f051..025f7f77b9a 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -658,10 +658,11 @@ def _construct_dispatch_combine( moe: MoEConfig, quant_config: Optional[QuantizationConfig], ) -> Optional[FusedMoEQuantizeDispatchCombine]: - if self.dp_size > 1 and has_pplx: + max_num_tokens = MOE_DP_CHUNK_SIZE + world_size = moe.ep_size + + if False and self.dp_size > 1 and has_pplx: logger.debug("using pplx dispatch") - max_num_tokens = MOE_DP_CHUNK_SIZE - world_size = moe.ep_size dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. rank = moe.ep_rank @@ -691,21 +692,23 @@ def _construct_dispatch_combine( ) elif False: return None - elif False: + elif self.dp_size > 1: + logger.debug("using batched dispatch") + dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. + rank = moe.ep_rank + return BatchedDispatchCombine( + max_num_tokens=max_num_tokens, + world_size=world_size, + dp_size=dp_size, + rank=rank, + ) + else: logger.debug("using standard dispatch") return StandardDispatchCombine( moe.in_dtype, quant_config.weight_block_size if quant_config is not None else None, ) - else: - logger.debug("using batched dispatch") - dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. - rank = moe.ep_rank - return BatchedDispatchCombine( - dp_size, - rank, - ) def _load_per_tensor_weight_scale(self, shard_id: str, param: torch.nn.Parameter, @@ -997,7 +1000,7 @@ def select_experts(hidden_states: torch.Tensor, topk=top_k, renormalize=renormalize, # XXXXX how to do this? - indices_type=torch.uint32, + #indices_type=torch.uint32, ) else: topk_weights, topk_ids = custom_routing_function( diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index eec5a7406d9..fce8bd8091d 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -67,7 +67,7 @@ def _moe_problem_size( M = a1.shape[0] else: assert a1.dim() == 3 - assert a1.shape[0] == E + assert a1.shape[0] == E, f"{a1.shape[0]} == {E}" M = a1.shape[1] # This is max_num_tokens assert topk_ids.dim() == 2 @@ -338,24 +338,27 @@ def forward( a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts, expert_map, apply_router_weight_on_input) - fused_out = self.fused_experts.apply( - a1q, - w1, - w2, - topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1q_scale=a1q_scale, - a2_scale=a2_scale, - workspace13=workspace13, - workspace2=workspace2, - expert_num_tokens=expert_num_tokens, - ) + if True: + fused_out = self.fused_experts.apply( + a1q, + w1, + w2, + topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=a1q_scale, + a2_scale=a2_scale, + workspace13=workspace13, + workspace2=workspace2, + expert_num_tokens=expert_num_tokens, + ) + else: + fused_out = torch.empty_like(a1q) self.dispatch_combine.combine(output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index aab23abe136..3e7b2b4047a 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -17,7 +17,7 @@ def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor: Shrink the given tensor and apply the given view to it. This is used to resize the intermediate fused_moe caches. """ - #assert prod(v) <= x.numel(), f"{prod(v)} <= {x.numel()}" # CUDAGRAPH unfriendly? + assert prod(v) <= x.numel(), f"{prod(v)} <= {x.numel()}" # CUDAGRAPH unfriendly? return x.flatten()[:prod(v)].view(*v) From 0f2e37ab0d082f697a0bedce49142606252e90ba Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 5 May 2025 15:37:51 +0000 Subject: [PATCH 174/190] wip ref impl Signed-off-by: Bill Nell --- csrc/activation_kernels.cu | 1 + examples/offline_inference/data_parallel.py | 6 +- tests/kernels/moe/test_pplx_moe.py | 46 ++++++++++++ .../layers/fused_moe/fused_batched_moe.py | 73 +++++++++++++------ .../layers/fused_moe/fused_moe.py | 12 +-- vllm/model_executor/layers/fused_moe/layer.py | 19 +++-- .../layers/fused_moe/pplx_dispatch_combine.py | 9 +-- 7 files changed, 125 insertions(+), 41 deletions(-) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 88275dbdd83..0c020be65ff 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -70,6 +70,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { int64_t num_tokens = input.numel() / input.size(-1); \ dim3 grid(num_tokens); \ dim3 block(std::min(d, 1024)); \ + if (num_tokens == 0) { return; } \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ VLLM_DISPATCH_FLOATING_TYPES( \ diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index c813b22c4e8..94286dcc816 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -114,9 +114,11 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, # Create an LLM. cconfig = CompilationConfig( - level=0, + level=3, #cudagraph_capture_sizes=[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208], #cudagraph_capture_sizes=[512,256,1], + #cudagraph_capture_sizes=[192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1] + #cudagraph_capture_sizes=[128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1] ) llm = LLM(model=model, tensor_parallel_size=GPUs_per_dp_rank, @@ -171,7 +173,7 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, procs.append(proc) exit_code = 0 for proc in procs: - proc.join(timeout=300) + proc.join(timeout=3000) if proc.exitcode is None: print(f"Killing process {proc.pid} that " f"didn't stop within 5 minutes.") diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 4886e2879ef..8b3fc6bf9a6 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -515,6 +515,50 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): return out +def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): + assert torch.cuda.current_device() == pgi.local_rank + + hidden_dim = a.shape[1] + num_experts = w1.shape[0] + block_size = 128 + device = pgi.device + rank = pgi.rank + world_size = pgi.world_size + topk = topk_ids.shape[1] + max_num_tokens = rank_chunk(a.shape[0], 0, world_size) + + dispatch_combine = BatchedDispatchCombine( + max_num_tokens=max_num_tokens, + world_size=world_size, + dp_size=dp_size, + rank=rank, + ) + + experts = BatchedExperts(a.shape[0]) + + fused_experts = FusedMoEModularKernel( + dispatch_combine, + experts, + ) + + # TODO: workers with the same dp_rank must use the exact same inputs. + + a_chunk = chunk_by_rank(a, rank, world_size).to(device) + chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) + chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) + + out = fused_experts( + a_chunk, + # Chunking weights like this only works for batched format + chunk_by_rank(w1, rank, world_size).to(device), + chunk_by_rank(w2, rank, world_size).to(device), + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts) + + return out + + def _pplx_moe( pgi: ProcessGroupInfo, dp_size: int, @@ -536,11 +580,13 @@ def _pplx_moe( topk_weight, topk_ids = fused_topk(a, score, topk, False) torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) pplx_output = pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids) + batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids) torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) + torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0) nvshmem_finalize() diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 4159bbf0591..0b985793071 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -466,6 +466,11 @@ def invoke_batched_silu_and_mul( compute_tl_dtype, D, BLOCK_M, BLOCK_D) +def rank_chunk(num, r, w): + rem = num % w + return (num // w) + (1 if r < rem else 0) + + class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): def __init__(self, max_num_tokens: Optional[int], world_size: int, dp_size: int, rank: int): @@ -505,20 +510,31 @@ def dispatch( if self.max_num_tokens is None: self.max_num_tokens = int(tokens_per_expert.max().item()) - b_a1 = torch.zeros((num_experts, self.max_num_tokens, hidden_dim), + rem_experts = num_experts % self.world_size + num_local_experts = ((num_experts // self.world_size) + + (1 if self.rank < rem_experts else 0)) + + b_a1 = torch.zeros((num_local_experts, self.max_num_tokens, hidden_dim), dtype=a1.dtype, device=a1.device) - token_counts = torch.zeros(num_experts, + token_counts = torch.zeros(num_local_experts, dtype=torch.int, device=a1.device) + first_expert = (((num_experts // self.world_size) * self.rank) + + rem_experts - self.rank) + last_expert = first_expert + num_local_experts + #expert_id_range = range(first_expert, last_expert) + for token in range(num_tokens): for j in range(topk): expert_id = topk_ids[token, j] - idx = token_counts[expert_id] - b_a1[expert_id, idx:idx + 1, :] = a1[token, :] - token_counts[expert_id] = token_counts[expert_id] + 1 + if expert_id >= first_expert and expert_id < last_expert: + rel_index = expert_id - first_expert + idx = token_counts[rel_index] + b_a1[rel_index, idx:idx + 1, :] = a1[token, :] + token_counts[rel_index] = token_counts[rel_index] + 1 return b_a1, a1_scale, tokens_per_expert @@ -531,7 +547,8 @@ def combine( apply_router_weight_on_input: bool, ) -> None: num_tokens = topk_ids.shape[0] - num_experts = fused_expert_output.shape[0] + num_local_experts = fused_expert_output.shape[0] + num_experts = num_local_experts * self.world_size # NOT QUITE RIGHT K = fused_expert_output.shape[-1] assert output.shape[0] == num_tokens and output.shape[1] == K expert_counts = torch.zeros( @@ -541,17 +558,21 @@ def combine( output.fill_(0) + first_expert = num_local_experts * self.rank # NOT QUITE RIGHT + last_expert = first_expert + num_local_experts + for token in range(num_tokens): expert_ids = topk_ids[token] for i in range(expert_ids.numel()): expert_id = expert_ids[i] - assert expert_id < num_experts - idx = expert_counts[expert_id] - accum = fused_expert_output[expert_id, idx:idx + 1, :] - if not apply_router_weight_on_input: - accum = accum * topk_weights[token, i] - output[token, :] = output[token, :] + accum - expert_counts[expert_id] = expert_counts[expert_id] + 1 + if expert_id >= first_expert and expert_id < last_expert: + assert expert_id < num_experts + idx = expert_counts[expert_id] + accum = fused_expert_output[expert_id - first_expert, idx:idx + 1, :] + if not apply_router_weight_on_input: + accum = accum * topk_weights[token, i] + output[token, :] = output[token, :] + accum + expert_counts[expert_id] = expert_counts[expert_id] + 1 class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -622,20 +643,26 @@ def apply( num_experts = global_num_experts out = _resize_cache(workspace13, (num_experts, max_num_tokens, hidden_dim)) - num_local_experts = expert_num_tokens.numel() - assert num_local_experts == w1.shape[0] + num_local_experts = w1.shape[0] #expert_num_tokens.numel() + assert num_local_experts == w1.shape[0], f"{num_local_experts} == {w1.shape[0]}" N = w1.shape[1] // 2 + # Not cudagraph friendly + assert (torch.cuda.is_current_stream_capturing() or + torch.all(expert_num_tokens <= max_num_tokens)), ( + f"{expert_num_tokens} <= {max_num_tokens}") + for expert in range(num_local_experts): - num = expert_num_tokens[expert].item() - assert num <= max_num_tokens, f"{num} <= {max_num_tokens}" - if num > 0: # CUDAGRAPH unfriendly - tmp = _resize_cache(workspace2, (num, N)) - input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1) - assert input.shape[1] == N * 2 - self.activation(activation, tmp, input) - out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) + # Indexing expert_num_tokens doesn't work w/cudagraphs + if torch.cuda.is_current_stream_capturing(): + num = max_num_tokens + else: + num = int(expert_num_tokens[expert].item()) + tmp = _resize_cache(workspace2, (num, N)) + input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1) + self.activation(activation, tmp, input) + out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) return out diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f461d8c60cc..1d77d3ff30a 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -870,7 +870,7 @@ def fused_topk( gating_output: torch.Tensor, topk: int, renormalize: bool, - indices_type: torch.dtype = torch.int32, + indices_type: Optional[torch.dtype] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert hidden_states.shape[0] == gating_output.shape[0], ( "Number of tokens mismatch") @@ -881,10 +881,12 @@ def fused_topk( topk, dtype=torch.float32, device=hidden_states.device) - topk_ids = torch.empty(M, - topk, - dtype=indices_type, - device=hidden_states.device) + topk_ids = torch.empty( + M, + topk, + dtype=torch.int32 if indices_type is None else indices_type, + device=hidden_states.device + ) token_expert_indicies = torch.empty(M, topk, dtype=torch.int32, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 025f7f77b9a..7d46052e3b3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -128,7 +128,7 @@ def get_or_create(self, **kwargs): with self._lock: instance = self._cache.get(key) - if instance is None: + if True or instance is None: # TODO: should be intranode instance = pplx.AllToAll.internode(**kwargs) self._cache[key] = instance @@ -152,6 +152,7 @@ def __init__(self, moe: MoEConfig): super().__init__() self.fused_experts = fused_experts self.moe = moe + self.using_pplx = False def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -256,6 +257,8 @@ def set_dispatch_combine( experts: Optional[FusedMoEPermuteExpertsUnpermute] = None + self.using_pplx = False + if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)): logger.debug("BatchedTritonExperts %s", self.moe) @@ -267,6 +270,7 @@ def set_dispatch_combine( use_int4_w4a16=False, block_shape=None, ) + self.using_pplx = isinstance(dispatch_combine, PplxDispatchCombine) else: logger.debug("TritonExperts %s", self.moe) experts = TritonExperts( @@ -313,7 +317,8 @@ def forward_cuda( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + indices_type=torch.uint32 if self.using_pplx else None) return self.fused_experts( hidden_states=x, @@ -661,7 +666,7 @@ def _construct_dispatch_combine( max_num_tokens = MOE_DP_CHUNK_SIZE world_size = moe.ep_size - if False and self.dp_size > 1 and has_pplx: + if self.dp_size > 1 and has_pplx: logger.debug("using pplx dispatch") dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. rank = moe.ep_rank @@ -977,7 +982,8 @@ def select_experts(hidden_states: torch.Tensor, num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None): + e_score_correction_bias: Optional[torch.Tensor] = None, + indices_type: Optional[torch.dtype] = None): from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, grouped_topk) @@ -985,6 +991,7 @@ def select_experts(hidden_states: torch.Tensor, if use_grouped_topk: assert topk_group is not None assert num_expert_group is not None + assert indices_type is None or indices_type == torch.int32 topk_weights, topk_ids = grouped_topk( hidden_states=hidden_states, gating_output=router_logits, @@ -999,10 +1006,10 @@ def select_experts(hidden_states: torch.Tensor, gating_output=router_logits, topk=top_k, renormalize=renormalize, - # XXXXX how to do this? - #indices_type=torch.uint32, + indices_type=indices_type, ) else: + assert indices_type is None or indices_type == torch.int32 topk_weights, topk_ids = custom_routing_function( hidden_states=hidden_states, gating_output=router_logits, diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index d605d4d7bc2..e53393afe08 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -105,12 +105,10 @@ def dispatch( ) # This argument is optional, defaults to indices.shape[0] + # There's not much point setting this unless it is != indices.shape[0] #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) bound_m = None - # TODO: optimize this? - #indices = rank_topk_ids.to(dtype=torch.uint32) - self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, @@ -130,14 +128,15 @@ def combine( topk_ids: torch.Tensor, apply_router_weight_on_input: bool, ) -> None: - # This argument is optional num_tokens = output.shape[0] # M + # This argument is optional + # There's not much point setting this unless it is != topk_ids.shape[0] #bound_m = torch.tensor([num_tokens], # dtype=torch.uint32, # device=fused_expert_output.device) bound_m = None - assert topk_ids.shape[0] <= num_tokens + assert topk_ids.shape[0] == num_tokens assert output.shape[0] <= self.max_num_tokens, \ f"{output.shape[0]} <= {self.max_num_tokens}" assert output.shape[1] == fused_expert_output.shape[-1] From a003bd8bd2f198ea4e39c78315e421ee2e4d31f3 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 6 May 2025 15:23:21 +0000 Subject: [PATCH 175/190] improve ref impl Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 2 +- .../layers/fused_moe/fused_batched_moe.py | 74 ++++++++++--------- vllm/model_executor/layers/fused_moe/layer.py | 6 +- .../layers/fused_moe/pplx_dispatch_combine.py | 1 + 4 files changed, 46 insertions(+), 37 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 8b3fc6bf9a6..a7382f3dadd 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -276,7 +276,7 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids): @pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.bfloat16]) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 0b985793071..f6dbe55cbd4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -491,6 +491,7 @@ def dispatch( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + assert a1.dim() == 2 assert topk_ids.dim() == 2 assert topk_ids.shape[0] == a1.shape[0] @@ -504,11 +505,13 @@ def dispatch( num_tokens, hidden_dim = a1.shape topk = topk_ids.shape[1] - tokens_per_expert = torch.bincount(topk_ids.view(-1), - minlength=num_experts) - if self.max_num_tokens is None: + tokens_per_expert = torch.bincount(topk_ids.view(-1), + minlength=num_experts) self.max_num_tokens = int(tokens_per_expert.max().item()) + else: + tokens_per_expert = torch.zeros(num_experts, dtype=torch.int, + device=a1.device) rem_experts = num_experts % self.world_size num_local_experts = ((num_experts // self.world_size) + @@ -518,23 +521,27 @@ def dispatch( dtype=a1.dtype, device=a1.device) - token_counts = torch.zeros(num_local_experts, - dtype=torch.int, - device=a1.device) - first_expert = (((num_experts // self.world_size) * self.rank) + rem_experts - self.rank) last_expert = first_expert + num_local_experts - #expert_id_range = range(first_expert, last_expert) - for token in range(num_tokens): - for j in range(topk): - expert_id = topk_ids[token, j] - if expert_id >= first_expert and expert_id < last_expert: - rel_index = expert_id - first_expert - idx = token_counts[rel_index] - b_a1[rel_index, idx:idx + 1, :] = a1[token, :] - token_counts[rel_index] = token_counts[rel_index] + 1 + # rhs = torch.empty((self.max_num_tokens, hidden_dim), + # dtype=a1.dtype, device=a1.device) + + # for expert_id in range(first_expert, last_expert): + # topks = torch.any(topk_ids == expert_id, dim=1).flatten() + # rows = torch.count_nonzero(topks.flatten()) + # #rhs[:rows] = a1[:topks.numel()][topks] + # topks_idx = topks.nonzero() + # torch.index_select(a1, dim=0, index=topks_idx.flatten(), out=rhs[:rows]) + # b_a1[expert_id - first_expert, :rows, :] = rhs[:rows] + # tokens_per_expert[expert_id - first_expert] = rows + + for expert_id in range(first_expert, last_expert): + topks = torch.any(topk_ids == expert_id, dim=1).flatten() + rows = torch.count_nonzero(topks.flatten()) + b_a1[expert_id - first_expert, :rows, :] = a1[:topks.numel()][topks] + tokens_per_expert[expert_id - first_expert] = rows return b_a1, a1_scale, tokens_per_expert @@ -548,31 +555,32 @@ def combine( ) -> None: num_tokens = topk_ids.shape[0] num_local_experts = fused_expert_output.shape[0] - num_experts = num_local_experts * self.world_size # NOT QUITE RIGHT + topk = topk_weights.shape[1] K = fused_expert_output.shape[-1] assert output.shape[0] == num_tokens and output.shape[1] == K - expert_counts = torch.zeros( - num_experts, - dtype=torch.int, - device=fused_expert_output.device) output.fill_(0) first_expert = num_local_experts * self.rank # NOT QUITE RIGHT last_expert = first_expert + num_local_experts - for token in range(num_tokens): - expert_ids = topk_ids[token] - for i in range(expert_ids.numel()): - expert_id = expert_ids[i] - if expert_id >= first_expert and expert_id < last_expert: - assert expert_id < num_experts - idx = expert_counts[expert_id] - accum = fused_expert_output[expert_id - first_expert, idx:idx + 1, :] - if not apply_router_weight_on_input: - accum = accum * topk_weights[token, i] - output[token, :] = output[token, :] + accum - expert_counts[expert_id] = expert_counts[expert_id] + 1 + # for expert_id in range(first_expert, last_expert): + # topkws = topk_ids == expert_id + # topks = torch.any(topkws, dim=1).flatten() + # outrhs = output[topks] + # rhs = fused_expert_output[expert_id - first_expert, :outrhs.shape[0], :] + # if not apply_router_weight_on_input: + # rhs.mul_(topk_weights[topkws].view(rhs.shape[0], 1)) + # output[topks] = outrhs + rhs + + for expert_id in range(first_expert, last_expert): + topkws = topk_ids == expert_id + topks = torch.any(topkws, dim=1).flatten() + rows = torch.count_nonzero(topks) + rhs = fused_expert_output[expert_id - first_expert, :rows, :] + if not apply_router_weight_on_input: + rhs.mul_(topk_weights[topkws].view(rhs.shape[0], 1)) + output[topks] = output[topks] + rhs class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 7d46052e3b3..ff680e06aa1 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -262,7 +262,7 @@ def set_dispatch_combine( if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)): logger.debug("BatchedTritonExperts %s", self.moe) - experts = BatchedTritonExperts( + experts = BatchedExperts( max_num_tokens=MOE_DP_CHUNK_SIZE, use_fp8_w8a8=False, use_int8_w8a8=False, @@ -695,8 +695,6 @@ def _construct_dispatch_combine( rank, moe.in_dtype, ) - elif False: - return None elif self.dp_size > 1: logger.debug("using batched dispatch") dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. @@ -707,6 +705,8 @@ def _construct_dispatch_combine( dp_size=dp_size, rank=rank, ) + elif False: + return None else: logger.debug("using standard dispatch") return StandardDispatchCombine( diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index e53393afe08..d46d76b407c 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -72,6 +72,7 @@ def dispatch( per_act_token, self.block_shape) + # TODO: does rem_experts need to be 0 for pplx to work properly? rem_experts = num_experts % self.world_size num_local_experts = ((num_experts // self.world_size) + (1 if self.rank < rem_experts else 0)) From 2bafbe0b794736de99258e5733525ab6b675a9db Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 6 May 2025 16:10:12 +0000 Subject: [PATCH 176/190] wip Signed-off-by: Bill Nell --- pyproject.toml | 4 +- tests/kernels/moe/test_cutlass_moe.py | 24 ++++----- tests/kernels/moe/test_pplx_moe.py | 6 +-- .../layers/fused_moe/fused_moe.py | 53 ++++++++++--------- vllm/model_executor/layers/fused_moe/layer.py | 13 ++--- 5 files changed, 51 insertions(+), 49 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4fa012145cf..069e295bfb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,8 +15,8 @@ build-backend = "setuptools.build_meta" [project] name = "vllm" authors = [{name = "vLLM Team"}] -#license = "Apache-2.0" -#license-files = ["LICENSE"] +license = "Apache-2.0" +license-files = ["LICENSE"] readme = "README.md" description = "A high-throughput and memory-efficient inference and serving engine for LLMs" classifiers = [ diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 7d24307e353..7db4fe0f46e 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -241,10 +241,10 @@ def test_cutlass_moe_8_bit_no_graph( per_out_ch) score = torch.randn((m, e), device="cuda", dtype=torch.half) - topk_weights, topk_ids = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, + score, + topk, + renormalize=False) # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. @@ -285,10 +285,10 @@ def test_cutlass_moe_8_bit_cuda_graph( per_out_ch) score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, + score, + topk, + renormalize=False) # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. @@ -338,10 +338,10 @@ def test_cutlass_moe_8_bit_EP( per_out_channel) score = torch.randn((m, e), device="cuda", dtype=torch.half) - topk_weights, topk_ids = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, + score, + topk, + renormalize=False) # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index a7382f3dadd..87377418cf9 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -296,7 +296,7 @@ def test_fused_moe_batched_experts( score = torch.randn((m, e), device="cuda", dtype=dtype) with set_current_vllm_config(vllm_config): - topk_weight, topk_ids = fused_topk(a, score, topk, False) + topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids) @@ -404,7 +404,7 @@ def _pplx_dispatch_combine( nvshmem_init(uid, pgi.rank, pgi.world_size) device = pgi.device - topk_weight, topk_ids = fused_topk(a, score, topk, False) + topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) k = a.shape[1] a_rep = torch.repeat_interleave(a, topk, dim=0).to(device) @@ -577,7 +577,7 @@ def _pplx_moe( e, _, n = w2.shape with set_current_vllm_config(vllm_config): - topk_weight, topk_ids = fused_topk(a, score, topk, False) + topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) pplx_output = pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids) batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1d77d3ff30a..fd2cb1e2644 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -887,10 +887,10 @@ def fused_topk( dtype=torch.int32 if indices_type is None else indices_type, device=hidden_states.device ) - token_expert_indicies = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) + token_expert_indices = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) gating_output_float = gating_output.float() # TODO(woosuk): Optimize this. @@ -1211,28 +1211,29 @@ def fused_experts(hidden_states: torch.Tensor, def fused_experts_impl( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool = False, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - use_fp8_w8a8: bool = False, - use_int8_w8a8: bool = False, - use_int8_w8a16: bool = False, - use_int4_w4a16: bool = False, - per_channel_quant: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - w1_zp: Optional[torch.Tensor] = None, - w2_zp: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, - block_shape: Optional[List[int]] = None) -> torch.Tensor: + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +) -> torch.Tensor: # Check constraints. if use_int4_w4a16: assert hidden_states.shape[1] // 2 == w1.shape[ diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index ff680e06aa1..1bd07d8eacd 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1002,12 +1002,13 @@ def select_experts(hidden_states: torch.Tensor, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) elif custom_routing_function is None: - topk_weights, topk_ids, token_expert_indices = fused_topk(hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - indices_type=indices_type, - ) + topk_weights, topk_ids, token_expert_indices = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + indices_type=indices_type, + ) else: assert indices_type is None or indices_type == torch.int32 topk_weights, topk_ids = custom_routing_function( From ca763c3d533ff41afd77e8dd9f46097510a5e380 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 6 May 2025 18:06:53 +0000 Subject: [PATCH 177/190] fix merge Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 87377418cf9..a2c31ad3c3e 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -341,6 +341,8 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts): torch.float32.itemsize)), ) + topk_ids = topk_ids.to(dtype=torch.uint32) + dispatch_combine = PplxDispatchCombine( ata, max_num_tokens, @@ -478,6 +480,8 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): torch.float32.itemsize)), ) + topk_ids = topk_ids.to(dtype=torch.uint32) + dispatch_combine = PplxDispatchCombine( ata, max_num_tokens, From 054c10a53deece370ebcbd2f9bf7872ca341fc47 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 6 May 2025 19:08:41 +0000 Subject: [PATCH 178/190] fix merge Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_moe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index a2c31ad3c3e..bacc23cdfc5 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -432,7 +432,7 @@ def _pplx_dispatch_combine( @pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @requires_pplx def test_pplx_dispatch_combine( @@ -584,13 +584,13 @@ def _pplx_moe( topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) pplx_output = pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids) - batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids) + #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids) torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) - torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0) + #torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0) nvshmem_finalize() From 0851b31eed51e3abb6ed278202fb4361003d5559 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 1 May 2025 16:05:58 -0400 Subject: [PATCH 179/190] wip Signed-off-by: Bill Nell --- examples/offline_inference/data_parallel.py | 10 +- tests/kernels/moe/test_pplx_moe.py | 9 +- .../layers/fused_moe/fused_batched_moe.py | 226 ++------- vllm/model_executor/layers/fused_moe/layer.py | 432 ++++++++++++------ .../layers/fused_moe/modular_kernel.py | 1 + .../layers/fused_moe/pplx_dispatch_combine.py | 5 +- .../model_executor/layers/quantization/fp8.py | 7 +- vllm/model_executor/models/deepseek_v2.py | 18 +- vllm/model_executor/models/granitemoe.py | 6 + vllm/model_executor/models/llama4.py | 2 +- vllm/model_executor/models/qwen2_moe.py | 2 +- vllm/model_executor/models/qwen3_moe.py | 2 +- 12 files changed, 360 insertions(+), 360 deletions(-) diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 94286dcc816..f48f64ba8e4 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -69,11 +69,14 @@ def parse_args(): parser.add_argument("--enforce-eager", action='store_true', help="Enforce eager mode execution.") + parser.add_argument("--trust-remote-code", + action='store_true', + help="Trust remote code.") return parser.parse_args() def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, - dp_master_port, GPUs_per_dp_rank, enforce_eager): + dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code): os.environ["VLLM_DP_RANK"] = str(global_dp_rank) os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) os.environ["VLLM_DP_SIZE"] = str(dp_size) @@ -125,6 +128,7 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, enforce_eager=enforce_eager, enable_expert_parallel=True, compilation_config=cconfig, + trust_remote_code=trust_remote_code, ) outputs = llm.generate(prompts, sampling_params) # Print the outputs. @@ -168,12 +172,12 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, proc = Process(target=main, args=(args.model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, dp_master_port, - tp_size, args.enforce_eager)) + tp_size, args.enforce_eager, args.trust_remote_code)) proc.start() procs.append(proc) exit_code = 0 for proc in procs: - proc.join(timeout=3000) + proc.join(timeout=300) if proc.exitcode is None: print(f"Killing process {proc.pid} that " f"didn't stop within 5 minutes.") diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index bacc23cdfc5..c9cf6bf8905 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -347,8 +347,9 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts): ata, max_num_tokens, world_size, - dp_size, rank, + dp_size, + a.dtype, ) a_chunk = chunk_by_rank(a, rank, world_size).to(device) @@ -486,8 +487,8 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): ata, max_num_tokens, world_size, - dp_size, rank, + dp_size, ) experts = BatchedExperts(a.shape[0]) @@ -584,13 +585,13 @@ def _pplx_moe( topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) pplx_output = pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids) - #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids) + batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids) torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) - #torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0) + torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0) nvshmem_finalize() diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index f6dbe55cbd4..b9732b3f68e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -587,6 +587,8 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, + world_size: int, + dp_size: int, max_num_tokens: Optional[int] = None, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, @@ -603,6 +605,8 @@ def __init__( assert not use_int8_w8a16, "NYI" assert not use_int4_w4a16, "NYI" self.max_num_tokens = max_num_tokens + self.world_size = world_size + self.dp_size = dp_size def workspace_shapes( self, @@ -614,10 +618,12 @@ def workspace_shapes( num_experts: int, ) -> Tuple[int, int, torch.dtype]: assert a.dim() == 2 + num_dp = self.world_size // self.dp_size max_num_tokens = a.shape[ 0] if self.max_num_tokens is None else self.max_num_tokens - workspace13 = num_experts * max_num_tokens * K - workspace2 = max_num_tokens * N + #print(f"WORKSPACE {max_num_tokens} {num_dp}") + workspace13 = num_experts * max_num_tokens * num_dp * K + workspace2 = max_num_tokens * num_dp * N return (workspace13, workspace2, a.dtype) def apply( @@ -648,23 +654,24 @@ def apply( else: max_num_tokens = self.max_num_tokens + num_dp = self.world_size // self.dp_size num_experts = global_num_experts out = _resize_cache(workspace13, - (num_experts, max_num_tokens, hidden_dim)) + (num_experts, max_num_tokens * num_dp, hidden_dim)) num_local_experts = w1.shape[0] #expert_num_tokens.numel() assert num_local_experts == w1.shape[0], f"{num_local_experts} == {w1.shape[0]}" N = w1.shape[1] // 2 # Not cudagraph friendly - assert (torch.cuda.is_current_stream_capturing() or - torch.all(expert_num_tokens <= max_num_tokens)), ( - f"{expert_num_tokens} <= {max_num_tokens}") + # assert (torch.cuda.is_current_stream_capturing() or + # torch.all(expert_num_tokens <= max_num_tokens)), ( + # f"{expert_num_tokens} <= {max_num_tokens}") for expert in range(num_local_experts): # Indexing expert_num_tokens doesn't work w/cudagraphs - if torch.cuda.is_current_stream_capturing(): - num = max_num_tokens + if True or torch.cuda.is_current_stream_capturing(): + num = max_num_tokens * num_dp else: num = int(expert_num_tokens[expert].item()) tmp = _resize_cache(workspace2, (num, N)) @@ -675,166 +682,6 @@ def apply( return out -def _apply( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - block_shape: Optional[List[int]], -) -> torch.Tensor: - # Check constraints. - if use_int4_w4a16: - assert hidden_states.shape[-1] // 2 == w1.shape[ - 2], "Hidden size mismatch" - else: - assert hidden_states.shape[-1] == w1.shape[2], \ - (f"Hidden size mismatch {hidden_states.shape[-1]} " - f"!= {w1.shape[2]}") - - assert hidden_states.is_contiguous( - ), "Hidden_states must be contiguous" - assert w1.stride(-1) == 1, "Stride of last dimension must be 1" - assert w2.stride(-1) == 1, "Stride of last dimension must be 1" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn - ] - - # TODO: num_tokens -> max_num_tokens? - E, num_tokens, N, K, top_k_num = mk._moe_problem_size( - hidden_states, w1, w2, topk_ids) - - assert w1.shape[0] == E - assert w2.shape[0] == E - - config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - dtype=hidden_states.dtype) - - config = try_get_optimal_moe_config( - w1.shape, - w2.shape, - top_k_num, - config_dtype, - num_tokens, - block_shape=block_shape, - ) - - if hidden_states.dtype == torch.bfloat16: - compute_type = tl.bfloat16 - elif hidden_states.dtype == torch.float16: - compute_type = tl.float16 - elif hidden_states.dtype == torch.float32: - compute_type = tl.float32 - elif hidden_states.dtype == torch.float8_e4m3fn: - compute_type = tl.bfloat16 - else: - raise ValueError( - f"Unsupported compute_type: {hidden_states.dtype}") - - #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}") - # We can reuse the memory between these because by the time we need - # cache3, we're done with cache1 - intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N)) - intermediate_cache2 = _resize_cache(workspace2, - (E, num_tokens, N // 2)) - intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K)) - - # MM1 - invoke_moe_batched_triton_kernel(A=hidden_states, - B=w1, - C=intermediate_cache1, - expert_num_tokens=expert_num_tokens, - compute_type=compute_type, - A_scale=a1q_scale, - B_scale=w1_scale, - B_zp=w1_zp, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - config=config, - block_shape=block_shape) - - # Fix activations - assert activation == "silu" - invoke_batched_silu_and_mul(output=intermediate_cache2, - input=intermediate_cache1, - expert_num_tokens=expert_num_tokens) - - #qintermediate_cache2 = intermediate_cache2 - a2q_scale = a2_scale - # TODO (varun) : support w8a8 - assert not use_fp8_w8a8 - #if self.use_fp8_w8a8: - # qintermediate_cache2, a2q_scale = _fp8_quantize( - # intermediate_cache2, a2_scale, self.block_shape) - - invoke_moe_batched_triton_kernel(A=intermediate_cache2, - B=w2, - C=intermediate_cache3, - expert_num_tokens=expert_num_tokens, - compute_type=compute_type, - A_scale=a2q_scale, - B_scale=w2_scale, - B_zp=w2_zp, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - config=config, - block_shape=block_shape) - - return intermediate_cache3 - - -def _apply_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - block_shape: Optional[List[int]], -) -> torch.Tensor: - return torch.empty_like(hidden_states) - - -direct_register_custom_op( - op_name="_apply", - op_func=_apply, - mutates_args=[], - fake_impl=_apply_fake, - tags=(torch.Tag.needs_fixed_stride_order, ), -) - - class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( @@ -845,6 +692,8 @@ def __init__( use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, block_shape: Optional[List[int]] = None, + world_size: int = 1, + dp_size: int = 1, ): super().__init__() self.use_fp8_w8a8 = use_fp8_w8a8 @@ -855,6 +704,8 @@ def __init__( self.max_num_tokens = max_num_tokens assert not use_int8_w8a8, "NYI" assert not use_int4_w4a16, "NYI" + self.world_size = world_size + self.dp_size = dp_size def workspace_shapes( self, @@ -866,10 +717,11 @@ def workspace_shapes( num_experts: int, ) -> Tuple[int, int, torch.dtype]: assert a.dim() == 2 + num_dp = self.world_size // self.dp_size max_num_tokens = a.shape[ 0] if self.max_num_tokens is None else self.max_num_tokens - workspace13 = num_experts * max_num_tokens * max(K, N) - workspace2 = num_experts * max_num_tokens * (N // 2) + workspace13 = num_experts * max_num_tokens * num_dp * max(K, N) + workspace2 = num_experts * max_num_tokens * num_dp * (N // 2) return (workspace13, workspace2, a.dtype) def apply( @@ -891,29 +743,6 @@ def apply( workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], ) -> torch.Tensor: - return torch.ops.vllm._apply( - hidden_states, - w1, - w2, - topk_ids, - activation, - global_num_experts, - expert_map, - w1_scale, - w2_scale, - w1_zp, - w2_zp, - a1q_scale, - a2_scale, - workspace13, - workspace2, - expert_num_tokens, - self.use_fp8_w8a8, - self.use_int8_w8a16, - self.use_int4_w4a16, - self.block_shape, - ) - # Check constraints. if self.use_int4_w4a16: assert hidden_states.shape[-1] // 2 == w1.shape[ @@ -988,10 +817,13 @@ def apply( block_shape=self.block_shape) # Fix activations - assert activation == "silu" - invoke_batched_silu_and_mul(output=intermediate_cache2, - input=intermediate_cache1, - expert_num_tokens=expert_num_tokens) + # assert activation == "silu" + # invoke_batched_silu_and_mul(output=intermediate_cache2, + # input=intermediate_cache1, + # expert_num_tokens=expert_num_tokens) + self.activation(activation, + intermediate_cache2.view(-1, N//2), + intermediate_cache1.view(-1, N)) #qintermediate_cache2 = intermediate_cache2 a2q_scale = a2_scale diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 1bd07d8eacd..6c337269d90 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -13,7 +13,7 @@ from torch.nn.parameter import UninitializedParameter import vllm.envs as envs -from vllm.config import get_current_vllm_config +from vllm.config import get_current_vllm_config, ParallelConfig from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -53,6 +53,112 @@ MOE_DP_CHUNK_SIZE = 256 +@dataclass +class FusedMoEParallelConfig: + tp_size: int + dp_size: int + ep_size: int + tp_rank: int + dp_rank: int + ep_rank: int + + use_ep: bool # whether to use EP or not + + @property + def use_pplx_kernels(self): + return self.use_ep and has_pplx + + @staticmethod + def make(tp_size_: int, dp_size_: int, + vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": + """ + Determine MoE parallel configuration. Based on the input tp_size_, dp_size_, + ep_size_ and vllm's parallel config, determine what level's of parallelism + to use in the fused moe layer. + + Args: + tp_size_ (int): tp_size passed into the FusedMoE constructor. + dp_size_ (int): dp_size passed into the FusedMoE constructor. + ep_size_ (int): ep_size passed into the FusedMoE constructor. + vllm_parallel_config (ParallelConfig): vllm's parallel config object. + + Examples: + When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1, + we simply return the sizes unaltered and the ranks set to 0. + + Expert Parallelism is considered only when either dp_size_ or tp_size_ is non trivial. + + When TP = 2, DP = 1 and EP = False, the configuration on different devices, + - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // legend : {size, rank} + - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0} + - Comment : Tensors are sharded across 2 devices. + + When TP = 1, DP = 2 and EP = False, the configuration on different devices, + - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0} + - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded across 2 decvices. + + When TP = 2, DP = 2 and EP = False, the configuration on different devices, + - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0} + - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0} + - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0} + - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded across 4 devices. + + When, TP = 2, DP = 1 and EP = True, the configuration on different devices, + - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1} + - Comment: The experts are split between the 2 devices. + + When, TP = 1, DP = 2 and EP = True, the configuration on different devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1} + - Comment: There are 2 engine instances and the experts are split between the 2 devices. + + When TP = 2, DP = 2 and EP = True, the configuration on different devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0} + - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1} + - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2} + - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3} + - Comment: There are 2 engine instances and the experts are split between the 4 devices. + """ + + def flatten_tp_across_dp(dp_rank: int): + tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank() + # There are actually dp_size_ * tp_size_ devices. Update tp_size + # and tp_rank so we shard across all devices. + tp_size = dp_size_ * tp_size_ + tp_rank = dp_rank * tp_size_ + tp_rank + return tp_size, tp_rank + + use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel + + dp_size = dp_size_ + dp_rank = get_dp_group().rank_in_group + tp_size, tp_rank = flatten_tp_across_dp(dp_rank) + + if not use_ep: + return FusedMoEParallelConfig(tp_size = tp_size, + tp_rank = tp_rank, + dp_size = dp_size, + dp_rank = dp_rank, + ep_size = 1, + ep_rank = 0, + use_ep = False) + # DP + EP / TP + EP / DP + TP + EP + assert use_ep + # In EP, each device owns a set of experts fully. There is no tensor parallel. + # Update tp_size, tp_rank, ep_size and ep_rank to reflect that. + ep_size = tp_size + ep_rank = tp_rank + return FusedMoEParallelConfig(tp_size = 1, + tp_rank = 0, + dp_size = dp_size, + dp_rank = dp_rank, + ep_size = ep_size, + ep_rank = ep_rank, + use_ep = True) + # Adapted from pplx-kernels tests/all_to_all_utils.py @dataclass class MoEConfig: @@ -61,16 +167,45 @@ class MoEConfig: hidden_dim: int num_local_experts: int - dp_size: int - dp_rank: int - ep_size: int - ep_rank: int + moe_parallel_config: FusedMoEParallelConfig in_dtype: torch.dtype # The activation type. # TODO: add more quantization params, blocked, per-token, etc. block_size: int = 128 + @property + def tp_size(self): + return self.moe_parallel_config.tp_size + + @property + def dp_size(self): + return self.moe_parallel_config.dp_size + + @property + def ep_size(self): + return self.moe_parallel_config.ep_size + + @property + def tp_rank(self): + return self.moe_parallel_config.tp_rank + + @property + def dp_rank(self): + return self.moe_parallel_config.dp_rank + + @property + def ep_rank(self): + return self.moe_parallel_config.ep_rank + + @property + def use_ep(self): + return self.moe_parallel_config.use_ep + + @property + def use_pplx_kernels(self): + return self.moe_parallel_config.use_pplx_kernels + class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" @@ -88,7 +223,11 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, raise NotImplementedError def set_dispatch_combine( - self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: + self, + dp_size: int, + world_size: int, + dispatch_combine: FusedMoEQuantizeDispatchCombine, + ) -> bool: return False @abstractmethod @@ -119,16 +258,22 @@ def __init__(self): self._cache = weakref.WeakValueDictionary() self._lock = threading.RLock() # Reentrant lock for thread safety + def __del__(self): + logger.info("Deleting AllToAllCache") + def get_or_create(self, **kwargs): assert has_pplx import pplx_kernels as pplx + if False: + return pplx.AllToAll.internode(**kwargs) + # Create a hashable key from the kwargs key = tuple(sorted((k, v) for k, v in kwargs.items())) with self._lock: instance = self._cache.get(key) - if True or instance is None: + if instance is None: # TODO: should be intranode instance = pplx.AllToAll.internode(**kwargs) self._cache[key] = instance @@ -252,7 +397,11 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input) def set_dispatch_combine( - self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool: + self, + dp_size: int, + world_size: int, + dispatch_combine: FusedMoEQuantizeDispatchCombine, + ) -> bool: assert self.fused_experts == fused_experts experts: Optional[FusedMoEPermuteExpertsUnpermute] = None @@ -262,8 +411,10 @@ def set_dispatch_combine( if isinstance(dispatch_combine, (BatchedDispatchCombine, PplxDispatchCombine)): logger.debug("BatchedTritonExperts %s", self.moe) - experts = BatchedExperts( + experts = BatchedTritonExperts( max_num_tokens=MOE_DP_CHUNK_SIZE, + world_size=world_size, + dp_size=dp_size, use_fp8_w8a8=False, use_int8_w8a8=False, use_int8_w8a16=False, @@ -487,6 +638,61 @@ def determine_expert_map( return (local_num_experts, expert_map) +def _construct_dispatch_combine( + moe: MoEConfig, + quant_config: Optional[QuantizationConfig] +) -> Optional[FusedMoEQuantizeDispatchCombine]: + max_num_tokens = MOE_DP_CHUNK_SIZE + world_size = moe.ep_size + dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. + rank = moe.ep_rank + + if moe.use_ep and has_pplx: + logger.debug("using pplx dispatch") + + all_to_all = get_all_to_all( + max_num_tokens=max_num_tokens, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, # topk + rank=rank, + world_size=world_size, + dp_size= dp_size, + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, + # For blocked per token: set to + # ceil_div(hidden_dim, block_size) * sizeof(float32) + # For per-token: set to sizeof(float32) + hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else ( + (moe.hidden_dim + moe.block_size - 1) // moe.block_size * + torch.float32.itemsize))) + + return PplxDispatchCombine( + all_to_all, + max_num_tokens=max_num_tokens, + world_size=world_size, + rank=rank, + dp_size=dp_size, + quant_dtype=moe.in_dtype, + ) + elif moe.use_ep: + logger.debug("using batched dispatch") + return BatchedDispatchCombine( + max_num_tokens=max_num_tokens, + world_size=world_size, + dp_size=dp_size, + rank=rank, + ) + elif True: + return None + else: + logger.debug("using standard dispatch") + return StandardDispatchCombine( + moe.in_dtype, + quant_config.weight_block_size + if quant_config is not None else None, + ) + + class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. @@ -537,21 +743,13 @@ def __init__( params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype - # Note: here we guard against accessing the TP and DP groups when - # uninitialized (this happens when testing) - self.tp_size = (tp_size if tp_size is not None else - get_tensor_model_parallel_world_size()) - tp_rank = 0 if self.tp_size == 1 else get_tensor_model_parallel_rank() - self.dp_size = (dp_size - if dp_size is not None else get_dp_group().world_size) - self.dp_rank = (0 - if self.dp_size == 1 else get_dp_group().rank_in_group) - self.global_num_experts = num_experts - - # Use expert parallelism instead of tensor parallelism? vllm_config = get_current_vllm_config() - use_ep = (vllm_config.parallel_config.enable_expert_parallel - and self.tp_size * self.dp_size > 1) + self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( + tp_size_ = (tp_size if tp_size is not None else get_tensor_model_parallel_world_size()), + dp_size_ = (dp_size if dp_size is not None else get_dp_group().world_size), + vllm_parallel_config=vllm_config.parallel_config) + + self.global_num_experts = num_experts # For smuggling this layer into the fused moe custom op self.use_direct_call = self.dp_size == 1 @@ -562,26 +760,15 @@ def __init__( compilation_config.static_forward_context[prefix] = self self.layer_name = prefix - if use_ep: - # Set TP size to 1 to adjust for EP and adjust EP size and rank - # for DP attention. - self.ep_rank = tp_rank + self.tp_size * self.dp_rank - self.tp_rank = 0 - self.ep_size = self.tp_size * self.dp_size - self.tp_size = 1 - + # Determine expert maps + if self.use_ep: self.local_num_experts, self.expert_map = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, global_num_experts=self.global_num_experts) else: - # Adjust TP size for DP attention - self.tp_rank = tp_rank + self.tp_size * self.dp_rank - self.ep_rank = 0 - self.tp_size = self.tp_size * self.dp_size - self.ep_size = 1 - self.local_num_experts = self.global_num_experts - self.expert_map = None + self.local_num_experts, self.expert_map = (self.global_num_experts, None) + self.top_k = top_k assert intermediate_size % self.tp_size == 0 @@ -611,11 +798,8 @@ def __init__( experts_per_token=top_k, hidden_dim=hidden_size, num_local_experts=self.local_num_experts, - dp_size=self.dp_size, - dp_rank=self.dp_rank, - ep_size=self.ep_size, - ep_rank=self.ep_rank, - in_dtype=params_dtype, # TODO: is this right? + moe_parallel_config=self.moe_parallel_config, + in_dtype=params_dtype, # TODO: is this right? ) # Note: get_quant_method will look at the layer's local_num_experts @@ -631,10 +815,13 @@ def __init__( assert quant_method is not None self.quant_method = quant_method - dispatch_combine = self._construct_dispatch_combine(moe, quant_config) + dispatch_combine = _construct_dispatch_combine(moe, quant_config) if dispatch_combine is not None: - success = self.quant_method.set_dispatch_combine(dispatch_combine) + world_size = moe.ep_size + dp_size = moe.ep_size // moe.dp_size + success = self.quant_method.set_dispatch_combine( + dp_size, world_size, dispatch_combine) if not success: logger.warning("DP+EP not supported for %s.", type(self.quant_method)) @@ -657,63 +844,37 @@ def __init__( self.quant_method.create_weights(layer=self, **moe_quant_params) - # TODO: return Optional? - def _construct_dispatch_combine( - self, - moe: MoEConfig, - quant_config: Optional[QuantizationConfig], - ) -> Optional[FusedMoEQuantizeDispatchCombine]: - max_num_tokens = MOE_DP_CHUNK_SIZE - world_size = moe.ep_size - - if self.dp_size > 1 and has_pplx: - logger.debug("using pplx dispatch") - dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. - rank = moe.ep_rank - - all_to_all = get_all_to_all( - max_num_tokens=max_num_tokens, - num_experts=moe.num_experts, - experts_per_token=moe.experts_per_token, # topk - rank=rank, - world_size=world_size, - dp_size=dp_size, - hidden_dim=moe.hidden_dim, - hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, - # For blocked per token: set to - # ceil_div(hidden_dim, block_size) * sizeof(float32) - # For per-token: set to sizeof(float32) - hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else ( - (moe.hidden_dim + moe.block_size - 1) // moe.block_size * - torch.float32.itemsize))) - - return PplxDispatchCombine( - all_to_all, - max_num_tokens, - world_size, - dp_size, - rank, - moe.in_dtype, - ) - elif self.dp_size > 1: - logger.debug("using batched dispatch") - dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. - rank = moe.ep_rank - return BatchedDispatchCombine( - max_num_tokens=max_num_tokens, - world_size=world_size, - dp_size=dp_size, - rank=rank, - ) - elif False: - return None - else: - logger.debug("using standard dispatch") - return StandardDispatchCombine( - moe.in_dtype, - quant_config.weight_block_size - if quant_config is not None else None, - ) + @property + def tp_size(self): + return self.moe_parallel_config.tp_size + + @property + def dp_size(self): + return self.moe_parallel_config.dp_size + + @property + def ep_size(self): + return self.moe_parallel_config.ep_size + + @property + def tp_rank(self): + return self.moe_parallel_config.tp_rank + + @property + def dp_rank(self): + return self.moe_parallel_config.dp_rank + + @property + def ep_rank(self): + return self.moe_parallel_config.ep_rank + + @property + def use_ep(self): + return self.moe_parallel_config.use_ep + + @property + def use_pplx_kernels(self): + return self.moe_parallel_config.use_pplx_kernels def _load_per_tensor_weight_scale(self, shard_id: str, param: torch.nn.Parameter, @@ -991,7 +1152,6 @@ def select_experts(hidden_states: torch.Tensor, if use_grouped_topk: assert topk_group is not None assert num_expert_group is not None - assert indices_type is None or indices_type == torch.int32 topk_weights, topk_ids = grouped_topk( hidden_states=hidden_states, gating_output=router_logits, @@ -1001,6 +1161,8 @@ def select_experts(hidden_states: torch.Tensor, topk_group=topk_group, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias) + if indices_type is not None: + topk_ids = topk_ids.to(dtype=indices_type) elif custom_routing_function is None: topk_weights, topk_ids, token_expert_indices = fused_topk( hidden_states=hidden_states, @@ -1010,12 +1172,13 @@ def select_experts(hidden_states: torch.Tensor, indices_type=indices_type, ) else: - assert indices_type is None or indices_type == torch.int32 topk_weights, topk_ids = custom_routing_function( hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize) + if indices_type is not None: + topk_ids = topk_ids.to(dtype=indices_type) return topk_weights, topk_ids @@ -1037,6 +1200,19 @@ def naive_multicast(self, x: torch.Tensor, return buffer + def must_reduce_shared_outputs(self) -> bool: + return self.dp_size > 1 and self.use_ep and has_pplx + + def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): + """ + The pplx combine kernel reduce across GPU ranks by default. The pplx kernels are + used when EP is enabled. In that case, this function is a no-op. + """ + if self.dp_size > 1 and self.use_ep and has_pplx: + return final_hidden_states + else: + return tensor_model_parallel_all_reduce(final_hidden_states) + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): if self.use_direct_call: @@ -1054,13 +1230,6 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] - # TODO: still may be needed for non-pplx, put into dispatcher class. - if False: - hidden_states = self.naive_multicast( - hidden_states, cu_tokens_across_dp_this_iter) - router_logits = self.naive_multicast( - router_logits, cu_tokens_across_dp_this_iter) - # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, @@ -1079,33 +1248,13 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): activation=self.activation, ) - # TODO: needed for non-pplx? - if False and self.dp_size > 1: - if self.dp_rank == 0: - start = 0 - else: - start = cu_tokens_across_dp_this_iter[self.dp_rank - 1] - - end = cu_tokens_across_dp_this_iter[self.dp_rank] - - all_hidden_states = get_dp_group().all_reduce( - final_hidden_states) - final_hidden_states = all_hidden_states[start:end, :] - - # TODO: needed for non-pplx? - if False and self.reduce_results and (self.tp_size > 1 - or self.ep_size > 1): - # Default set to False. (May have to add shared expert outputs.) - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) - if not skip_result_store: full_final_hidden_states[chunk_start:chunk_end, :].copy_( final_hidden_states) - max_tokens_across_dp = get_forward_context( - ).dp_metadata.max_tokens_across_dp - moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE // self.dp_size + ctx = get_forward_context() + max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp + moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE num_tokens = full_hidden_states.size(0) for chunk_start_ in range(0, max_tokens_across_dp, @@ -1126,9 +1275,10 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None + if self.dp_size > 1 and self.use_ep and has_pplx: + return self.forward_impl_chunked(hidden_states, router_logits) - # TODO: still may be needed for non-pplx - if False and self.dp_size > 1: + if self.dp_size > 1: ctx = get_forward_context() cu_tokens_across_dp_cpu = ctx.dp_metadata.cu_tokens_across_dp_cpu @@ -1156,8 +1306,7 @@ def forward_impl(self, hidden_states: torch.Tensor, apply_router_weight_on_input=self.apply_router_weight_on_input, ) - # TODO: needed for non-pplx? - if False and self.dp_size > 1: + if self.dp_size > 1: start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ self.dp_rank - 1] end = cu_tokens_across_dp_cpu[self.dp_rank] @@ -1165,9 +1314,8 @@ def forward_impl(self, hidden_states: torch.Tensor, all_hidden_states = get_dp_group().all_reduce(final_hidden_states) final_hidden_states = all_hidden_states[start:end, :] - # TODO: needed for non-pplx? - if False and self.reduce_results and (self.tp_size > 1 - or self.ep_size > 1): + if self.reduce_results and (self.tp_size > 1 + or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) @@ -1219,7 +1367,7 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, self = forward_context.no_compile_layers[layer_name] assert self.quant_method is not None - return self.forward_impl_chunked(hidden_states, router_logits) + return self.forward_impl(hidden_states, router_logits) def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index fce8bd8091d..299d98c7f15 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -171,6 +171,7 @@ def workspace_shapes( def activation(self, activation: str, output: torch.Tensor, input: torch.Tensor) -> None: + assert output.shape[-1] * 2 == input.shape[-1] if activation == "silu": torch.ops._C.silu_and_mul(output, input) elif activation == "gelu": diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index d46d76b407c..002f689d585 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -23,8 +23,8 @@ def __init__(self, a2a: pplx.AllToAll, max_num_tokens: int, world_size: int, - dp_size: int, rank: int, + dp_size: int, quant_dtype: Optional[torch.dtype] = None, block_shape: Optional[List[int]] = None): super().__init__() @@ -33,8 +33,8 @@ def __init__(self, self.block_shape = block_shape self.max_num_tokens = max_num_tokens self.world_size = world_size - self.dp_size = dp_size self.rank = rank + self.dp_size = dp_size self.quant_dtype = quant_dtype def dispatch( @@ -119,6 +119,7 @@ def dispatch( indices=rank_topk_ids, bound_m=bound_m, ) + return expert_x, expert_x_scale, expert_num_tokens def combine( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 9156cb568f9..86f48296c3a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -778,8 +778,11 @@ def process_weights_after_loading(self, layer: Module) -> None: del layer.w2_input_scale def set_dispatch_combine( - self, - dispatch_combine: mk.FusedMoEQuantizeDispatchCombine) -> bool: + self, + dp_size: int, + world_size: int, + dispatch_combine: mk.FusedMoEQuantizeDispatchCombine, + ) -> bool: from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 9e9b8d336de..bbc2645bf81 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -32,8 +32,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -143,7 +142,13 @@ def __init__( intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=False, + # When just tensor-parallel is used, it isn't required + # to reduce the shared_output result. Instead we reduce + # at the end of the forward pass. + # With EP and the pplx kernels - this is no longer viable + # as all GPU ranks in DP, produce the complete set of hidden_states. + # Therefore reduce the shared experts early. + reduce_results=self.experts.must_reduce_shared_outputs(), prefix=f"{prefix}.shared_experts", ) @@ -154,6 +159,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) + if hidden_states.dtype != torch.float16: final_hidden_states = self.experts( hidden_states=hidden_states, @@ -172,10 +178,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = final_hidden_states + shared_output \ * (1. / self.routed_scaling_factor) - # TODO: check if needed for non-pplx? - if False and self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + if self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_dim) diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 7fff14cb9f1..09bbeea9b13 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -70,6 +70,7 @@ def __init__(self, prefix: str = ""): super().__init__() self.hidden_size = hidden_size + self.tp_size = get_tensor_model_parallel_world_size() # Gate always runs at half / full precision for now. self.gate = ReplicatedLinear(hidden_size, @@ -97,6 +98,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, router_logits) + + # Needed? + if self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(final_hidden_states) + return final_hidden_states.view(orig_shape) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 0fdc30f36f9..68e427d272c 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -102,7 +102,7 @@ def forward(self, hidden_states): experts_out = routed_out + shared_out if self.tp_size > 1: - experts_out = tensor_model_parallel_all_reduce(experts_out) + experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(experts_out) return experts_out diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 47d90919ed8..aca0d658d88 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -156,7 +156,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( final_hidden_states) return final_hidden_states.view(orig_shape) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index fe6b303ba0b..655b4d25fc7 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -137,7 +137,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits=router_logits) final_hidden_states = final_hidden_states if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( final_hidden_states) return final_hidden_states.view(orig_shape) From 3e2cf4b80515f18a468a46092dd56cc179b5f6b4 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 7 May 2025 02:30:59 -0400 Subject: [PATCH 180/190] zero out attn outputs during profile run Signed-off-by: Varun Sundar Rabindranath Signed-off-by: Bill Nell --- vllm/v1/attention/backends/mla/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 0d18a5639c2..b4f8317e096 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -882,7 +882,7 @@ def forward( if attn_metadata is None: # Profiling run. - return output + return output.fill_(0) num_actual_toks = attn_metadata.num_actual_tokens From d2862c0800e176f9c73db1e499353d87efd0be86 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 7 May 2025 15:10:25 +0000 Subject: [PATCH 181/190] lint Signed-off-by: Bill Nell --- csrc/activation_kernels.cu | 4 +- csrc/dispatch_utils.h | 17 ++-- examples/offline_inference/data_parallel.py | 23 ++--- tests/kernels/moe/test_batched_moe.py | 3 +- tests/kernels/moe/test_moe.py | 4 +- tests/kernels/moe/test_pplx_moe.py | 26 +++--- .../layers/fused_moe/fused_batched_moe.py | 27 +++--- .../layers/fused_moe/fused_moe.py | 16 +--- vllm/model_executor/layers/fused_moe/layer.py | 92 +++++++------------ .../layers/fused_moe/moe_permute_unpermute.py | 21 +++-- .../layers/fused_moe/pplx_dispatch_combine.py | 23 ++--- vllm/model_executor/layers/fused_moe/utils.py | 3 +- vllm/model_executor/models/deepseek_v2.py | 8 +- vllm/model_executor/models/granitemoe.py | 4 +- vllm/model_executor/models/llama4.py | 6 +- vllm/model_executor/models/qwen2_moe.py | 4 +- vllm/model_executor/models/qwen3_moe.py | 4 +- 17 files changed, 123 insertions(+), 162 deletions(-) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 0c020be65ff..55e65967970 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -70,7 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { int64_t num_tokens = input.numel() / input.size(-1); \ dim3 grid(num_tokens); \ dim3 block(std::min(d, 1024)); \ - if (num_tokens == 0) { return; } \ + if (num_tokens == 0) { \ + return; \ + } \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ VLLM_DISPATCH_FLOATING_TYPES( \ diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 10a183dc950..f7b75c48373 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -66,17 +66,18 @@ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) #define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__) #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) #define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__)) + AT_DISPATCH_SWITCH( \ + TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__)) diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index f48f64ba8e4..f636a08c0b0 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -31,7 +31,6 @@ from time import sleep from vllm import LLM, SamplingParams -from vllm.config import CompilationConfig from vllm.utils import get_open_port @@ -116,20 +115,13 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, max_tokens=[16, 20][global_dp_rank % 2]) # Create an LLM. - cconfig = CompilationConfig( - level=3, - #cudagraph_capture_sizes=[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208], - #cudagraph_capture_sizes=[512,256,1], - #cudagraph_capture_sizes=[192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1] - #cudagraph_capture_sizes=[128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1] + llm = LLM( + model=model, + tensor_parallel_size=GPUs_per_dp_rank, + enforce_eager=enforce_eager, + enable_expert_parallel=True, + trust_remote_code=trust_remote_code, ) - llm = LLM(model=model, - tensor_parallel_size=GPUs_per_dp_rank, - enforce_eager=enforce_eager, - enable_expert_parallel=True, - compilation_config=cconfig, - trust_remote_code=trust_remote_code, - ) outputs = llm.generate(prompts, sampling_params) # Print the outputs. for i, output in enumerate(outputs): @@ -172,7 +164,8 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, proc = Process(target=main, args=(args.model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, dp_master_port, - tp_size, args.enforce_eager, args.trust_remote_code)) + tp_size, args.enforce_eager, + args.trust_remote_code)) proc.start() procs.append(proc) exit_code = 0 diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 39b5d5c6793..f9f3f6506a5 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -62,7 +62,8 @@ def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, @pytest.mark.parametrize("num_experts", [16, 32]) -@pytest.mark.parametrize("max_tokens_per_expert", [32, 64, 128, 192, 224, 256, 512]) +@pytest.mark.parametrize("max_tokens_per_expert", + [32, 64, 128, 192, 224, 256, 512]) @pytest.mark.parametrize("K", [128, 256, 1024]) @pytest.mark.parametrize("N", [128, 256, 512, 1024]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 85df55d0ac1..befd07075c9 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -11,8 +11,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock import vllm.model_executor.layers.fused_moe # noqa -from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe, - torch_moe_single) +from tests.kernels.utils import opcheck, stack_and_dev, torch_moe from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk @@ -27,7 +26,6 @@ from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types -from vllm.model_executor.layers.activation import SiluAndMul NUM_EXPERTS = [8, 64] EP_SIZE = [1, 4] diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index c9cf6bf8905..d30f4cef3bb 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -28,8 +28,7 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedDispatchCombine, - BatchedExperts) + BatchedDispatchCombine, BatchedExperts) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) @@ -246,15 +245,9 @@ def batched_moe(a, w1, w2, topk_weight, topk_ids): fused_experts = FusedMoEModularKernel( BatchedDispatchCombine(a.shape[0], world_size=1, dp_size=1, rank=0), - BatchedExperts(a.shape[0]) - ) + BatchedExperts(a.shape[0])) - return fused_experts(a, - w1, - w2, - topk_weight, - topk_ids, - num_experts) + return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts) # TODO: same as torch_moe but with fused_topk factored out. @@ -301,9 +294,15 @@ def test_fused_moe_batched_experts( torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids) - torch.testing.assert_close(baseline_output, torch_output, atol=2e-2, rtol=0) + torch.testing.assert_close(baseline_output, + torch_output, + atol=2e-2, + rtol=0) torch.set_printoptions(profile="full") - torch.testing.assert_close(baseline_output, batched_output, atol=2e-2, rtol=0) + torch.testing.assert_close(baseline_output, + batched_output, + atol=2e-2, + rtol=0) def rank_chunk(num, r, w): @@ -585,7 +584,8 @@ def _pplx_moe( topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) pplx_output = pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids) - batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids) + batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, + topk_ids) torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index b9732b3f68e..d9143619224 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -10,7 +10,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( get_config_dtype_str, try_get_optimal_moe_config) from vllm.model_executor.layers.fused_moe.utils import _resize_cache -from vllm.utils import direct_register_custom_op @triton.jit @@ -473,7 +472,8 @@ def rank_chunk(num, r, w): class BatchedDispatchCombine(mk.FusedMoEQuantizeDispatchCombine): - def __init__(self, max_num_tokens: Optional[int], world_size: int, dp_size: int, rank: int): + def __init__(self, max_num_tokens: Optional[int], world_size: int, + dp_size: int, rank: int): super().__init__() self.world_size = world_size self.dp_size = dp_size @@ -510,16 +510,18 @@ def dispatch( minlength=num_experts) self.max_num_tokens = int(tokens_per_expert.max().item()) else: - tokens_per_expert = torch.zeros(num_experts, dtype=torch.int, + tokens_per_expert = torch.zeros(num_experts, + dtype=torch.int, device=a1.device) rem_experts = num_experts % self.world_size num_local_experts = ((num_experts // self.world_size) + (1 if self.rank < rem_experts else 0)) - b_a1 = torch.zeros((num_local_experts, self.max_num_tokens, hidden_dim), - dtype=a1.dtype, - device=a1.device) + b_a1 = torch.zeros( + (num_local_experts, self.max_num_tokens, hidden_dim), + dtype=a1.dtype, + device=a1.device) first_expert = (((num_experts // self.world_size) * self.rank) + rem_experts - self.rank) @@ -540,7 +542,8 @@ def dispatch( for expert_id in range(first_expert, last_expert): topks = torch.any(topk_ids == expert_id, dim=1).flatten() rows = torch.count_nonzero(topks.flatten()) - b_a1[expert_id - first_expert, :rows, :] = a1[:topks.numel()][topks] + b_a1[expert_id - + first_expert, :rows, :] = a1[:topks.numel()][topks] tokens_per_expert[expert_id - first_expert] = rows return b_a1, a1_scale, tokens_per_expert @@ -561,7 +564,7 @@ def combine( output.fill_(0) - first_expert = num_local_experts * self.rank # NOT QUITE RIGHT + first_expert = num_local_experts * self.rank # NOT QUITE RIGHT last_expert = first_expert + num_local_experts # for expert_id in range(first_expert, last_expert): @@ -658,8 +661,9 @@ def apply( num_experts = global_num_experts out = _resize_cache(workspace13, (num_experts, max_num_tokens * num_dp, hidden_dim)) - num_local_experts = w1.shape[0] #expert_num_tokens.numel() - assert num_local_experts == w1.shape[0], f"{num_local_experts} == {w1.shape[0]}" + num_local_experts = w1.shape[0] #expert_num_tokens.numel() + assert num_local_experts == w1.shape[ + 0], f"{num_local_experts} == {w1.shape[0]}" N = w1.shape[1] // 2 @@ -821,8 +825,7 @@ def apply( # invoke_batched_silu_and_mul(output=intermediate_cache2, # input=intermediate_cache1, # expert_num_tokens=expert_num_tokens) - self.activation(activation, - intermediate_cache2.view(-1, N//2), + self.activation(activation, intermediate_cache2.view(-1, N // 2), intermediate_cache1.view(-1, N)) #qintermediate_cache2 = intermediate_cache2 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index fd2cb1e2644..80e3942c5ee 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -21,7 +21,7 @@ _resize_cache, moe_kernel_quantize_input) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op, round_up +from vllm.utils import direct_register_custom_op from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled @@ -885,8 +885,7 @@ def fused_topk( M, topk, dtype=torch.int32 if indices_type is None else indices_type, - device=hidden_states.device - ) + device=hidden_states.device) token_expert_indices = torch.empty(M, topk, dtype=torch.int32, @@ -980,7 +979,7 @@ def get_config_dtype_str( return None -# TODO: use scalar_type? +# TODO: use scalar_type instead of bools? def get_config_qtype( use_fp8_w8a8: bool, use_int8_w8a8: bool, @@ -1239,8 +1238,8 @@ def fused_experts_impl( assert hidden_states.shape[1] // 2 == w1.shape[ 2], "Hidden size mismatch" else: - assert hidden_states.shape[1] == w1.shape[2], \ - f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}" + assert hidden_states.shape[1] == w1.shape[2], ( + f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}") assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" @@ -1655,16 +1654,11 @@ def apply( expert_ids = torch.repeat_interleave(expert_ids, max_num_tokens, dim=0) - print(f"EXPERT_IDS {expert_ids}") - #num_tokens_post_padded = torch.tensor([num_tokens], - # device=hidden_states.device, - # dtype=torch.int32) num_tokens_post_padded = torch.zeros(1, device=hidden_states.device, dtype=torch.int32) num_tokens_post_padded.fill_(max_num_tokens) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - #print(f"P = {sorted_token_ids}, {hidden_states.shape}") invoke_fused_moe_kernel(hidden_states, w1, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 6c337269d90..b5d34b4b1c9 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -13,7 +13,7 @@ from torch.nn.parameter import UninitializedParameter import vllm.envs as envs -from vllm.config import get_current_vllm_config, ParallelConfig +from vllm.config import ParallelConfig, get_current_vllm_config from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -30,11 +30,7 @@ has_pplx = importlib.util.find_spec("pplx_kernels") is not None if current_platform.is_cuda_alike(): - from .dispatch_combine import StandardDispatchCombine - from .fused_batched_moe import ( - BatchedDispatchCombine, - BatchedTritonExperts, - BatchedExperts) + from .fused_batched_moe import BatchedDispatchCombine, BatchedTritonExperts from .fused_moe import TritonExperts, fused_experts from .modular_kernel import (FusedMoEModularKernel, FusedMoEPermuteExpertsUnpermute, @@ -138,26 +134,27 @@ def flatten_tp_across_dp(dp_rank: int): tp_size, tp_rank = flatten_tp_across_dp(dp_rank) if not use_ep: - return FusedMoEParallelConfig(tp_size = tp_size, - tp_rank = tp_rank, - dp_size = dp_size, - dp_rank = dp_rank, - ep_size = 1, - ep_rank = 0, - use_ep = False) + return FusedMoEParallelConfig(tp_size=tp_size, + tp_rank=tp_rank, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=1, + ep_rank=0, + use_ep=False) # DP + EP / TP + EP / DP + TP + EP assert use_ep # In EP, each device owns a set of experts fully. There is no tensor parallel. # Update tp_size, tp_rank, ep_size and ep_rank to reflect that. ep_size = tp_size ep_rank = tp_rank - return FusedMoEParallelConfig(tp_size = 1, - tp_rank = 0, - dp_size = dp_size, - dp_rank = dp_rank, - ep_size = ep_size, - ep_rank = ep_rank, - use_ep = True) + return FusedMoEParallelConfig(tp_size=1, + tp_rank=0, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=ep_size, + ep_rank=ep_rank, + use_ep=True) + # Adapted from pplx-kernels tests/all_to_all_utils.py @dataclass @@ -258,16 +255,10 @@ def __init__(self): self._cache = weakref.WeakValueDictionary() self._lock = threading.RLock() # Reentrant lock for thread safety - def __del__(self): - logger.info("Deleting AllToAllCache") - def get_or_create(self, **kwargs): assert has_pplx import pplx_kernels as pplx - if False: - return pplx.AllToAll.internode(**kwargs) - # Create a hashable key from the kwargs key = tuple(sorted((k, v) for k, v in kwargs.items())) @@ -639,12 +630,11 @@ def determine_expert_map( def _construct_dispatch_combine( - moe: MoEConfig, - quant_config: Optional[QuantizationConfig] + moe: MoEConfig, quant_config: Optional[QuantizationConfig] ) -> Optional[FusedMoEQuantizeDispatchCombine]: max_num_tokens = MOE_DP_CHUNK_SIZE world_size = moe.ep_size - dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. + dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. rank = moe.ep_rank if moe.use_ep and has_pplx: @@ -656,15 +646,15 @@ def _construct_dispatch_combine( experts_per_token=moe.experts_per_token, # topk rank=rank, world_size=world_size, - dp_size= dp_size, + dp_size=dp_size, hidden_dim=moe.hidden_dim, hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, # For blocked per token: set to # ceil_div(hidden_dim, block_size) * sizeof(float32) # For per-token: set to sizeof(float32) - hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else ( - (moe.hidden_dim + moe.block_size - 1) // moe.block_size * - torch.float32.itemsize))) + hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else + ((moe.hidden_dim + moe.block_size - 1) // + moe.block_size * torch.float32.itemsize))) return PplxDispatchCombine( all_to_all, @@ -674,23 +664,8 @@ def _construct_dispatch_combine( dp_size=dp_size, quant_dtype=moe.in_dtype, ) - elif moe.use_ep: - logger.debug("using batched dispatch") - return BatchedDispatchCombine( - max_num_tokens=max_num_tokens, - world_size=world_size, - dp_size=dp_size, - rank=rank, - ) - elif True: - return None - else: - logger.debug("using standard dispatch") - return StandardDispatchCombine( - moe.in_dtype, - quant_config.weight_block_size - if quant_config is not None else None, - ) + + return None class FusedMoE(torch.nn.Module): @@ -745,8 +720,10 @@ def __init__( vllm_config = get_current_vllm_config() self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( - tp_size_ = (tp_size if tp_size is not None else get_tensor_model_parallel_world_size()), - dp_size_ = (dp_size if dp_size is not None else get_dp_group().world_size), + tp_size_=(tp_size if tp_size is not None else + get_tensor_model_parallel_world_size()), + dp_size_=(dp_size + if dp_size is not None else get_dp_group().world_size), vllm_parallel_config=vllm_config.parallel_config) self.global_num_experts = num_experts @@ -767,7 +744,8 @@ def __init__( ep_rank=self.ep_rank, global_num_experts=self.global_num_experts) else: - self.local_num_experts, self.expert_map = (self.global_num_experts, None) + self.local_num_experts, self.expert_map = (self.global_num_experts, + None) self.top_k = top_k @@ -799,7 +777,7 @@ def __init__( hidden_dim=hidden_size, num_local_experts=self.local_num_experts, moe_parallel_config=self.moe_parallel_config, - in_dtype=params_dtype, # TODO: is this right? + in_dtype=params_dtype, # TODO: is this right? ) # Note: get_quant_method will look at the layer's local_num_experts @@ -1203,7 +1181,8 @@ def naive_multicast(self, x: torch.Tensor, def must_reduce_shared_outputs(self) -> bool: return self.dp_size > 1 and self.use_ep and has_pplx - def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): + def maybe_all_reduce_tensor_model_parallel( + self, final_hidden_states: torch.Tensor): """ The pplx combine kernel reduce across GPU ranks by default. The pplx kernels are used when EP is enabled. In that case, this function is a no-op. @@ -1314,8 +1293,7 @@ def forward_impl(self, hidden_states: torch.Tensor, all_hidden_states = get_dp_group().all_reduce(final_hidden_states) final_hidden_states = all_hidden_states[start:end, :] - if self.reduce_results and (self.tp_size > 1 - or self.ep_size > 1): + if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index cba1e0ef506..a72d0917e0a 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 -import torch from typing import Optional, Tuple +import torch + from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) @@ -84,21 +85,21 @@ def moe_permute( fill_invalid_expert: int = -1 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ - This function expands and permutes activation to gather uncontinuous tokens + This function expands and permutes activation to gather uncontinuous tokens for each expert. Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - hidden_states (torch.Tensor): The input tensor to the MoE layer. - topk_weights (torch.Tensor): topk expert route weight for each token. - topk_ids (torch.Tensor): topk expert route id for each token. - token_expert_indices (torch.Tensor): indice for expanded hidden. - topk (int): The number of top-k experts to select. - n_expert (int): The number of expert. - n_local_expert (int): The number of expert in current EP rank. - - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert parallel shard. - align_block_size (Optional[int]): align group gemm block size for deepgemm - - fill_invalid_expert(int): fill expert id in m_indices for invalid expert + - fill_invalid_expert(int): fill expert id in m_indices for invalid expert to workaround DeepGemm unsupported -1 in m_indices Returns: - permuted_hidden_states (torch.Tensor): permuted activation. @@ -106,7 +107,7 @@ def moe_permute( of each expert for standard grouped gemm. if enable 'align_block_size' expert_first_token_offset will align up to 'align_block_size'. - src_row_id2dst_row_id_map (torch.Tensor): idx map for moe_unpermute. - - m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records + - m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records the group which the j-th row of the LHS belong to.` """ n_token, n_hidden = hidden_states.shape @@ -154,7 +155,7 @@ def moe_unpermute( n_local_expert: int, ) -> torch.Tensor: """ - This function expands and permutes activation to gathering uncontinuous + This function expands and permutes activation to gathering uncontinuous tokens for each expert. Parameters: - permuted_hidden_states (torch.Tensor): permuted activation. @@ -166,8 +167,8 @@ def moe_unpermute( - n_expert (int): The number of expert. - n_local_expert (int): The number of expert in current EP rank. Returns: - - hidden_states (torch.Tensor): The reduced and unpermuted activation - tensor. + - hidden_states (torch.Tensor): The reduced and unpermuted activation + tensor. """ n_token, n_hidden = topk_weights.shape[0], permuted_hidden_states.shape[-1] assert (n_hidden * permuted_hidden_states.element_size() diff --git a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py index 002f689d585..7392fe418a4 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py +++ b/vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py @@ -9,11 +9,6 @@ moe_kernel_quantize_input) -def rank_chunk(num, r, w): - rem = num % w - return (num // w) + (1 if r < rem else 0) - - # Note use: layer.get_all_to_all() to get an AllToAll instance # The max_num_tokens, world_size and dp_size must be the same # as the ones used to create the AllToAll. @@ -72,8 +67,9 @@ def dispatch( per_act_token, self.block_shape) - # TODO: does rem_experts need to be 0 for pplx to work properly? + # rem_experts need to be 0 for pplx to work properly. rem_experts = num_experts % self.world_size + assert rem_experts == 0 num_local_experts = ((num_experts // self.world_size) + (1 if self.rank < rem_experts else 0)) @@ -107,7 +103,6 @@ def dispatch( # This argument is optional, defaults to indices.shape[0] # There's not much point setting this unless it is != indices.shape[0] - #bound_m = torch.tensor([num_tokens], dtype=torch.uint32, device=device) bound_m = None self.a2a.dispatch( @@ -133,9 +128,6 @@ def combine( num_tokens = output.shape[0] # M # This argument is optional # There's not much point setting this unless it is != topk_ids.shape[0] - #bound_m = torch.tensor([num_tokens], - # dtype=torch.uint32, - # device=fused_expert_output.device) bound_m = None assert topk_ids.shape[0] == num_tokens @@ -147,8 +139,9 @@ def combine( if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) - self.a2a.combine(out_tokens=output, - indices=topk_ids, #.to(torch.uint32), - weights=topk_weights, - expert_y=fused_expert_output, - bound_m=bound_m) + self.a2a.combine( + out_tokens=output, + indices=topk_ids, #.to(torch.uint32), + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 3e7b2b4047a..c16389c25aa 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -17,7 +17,8 @@ def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor: Shrink the given tensor and apply the given view to it. This is used to resize the intermediate fused_moe caches. """ - assert prod(v) <= x.numel(), f"{prod(v)} <= {x.numel()}" # CUDAGRAPH unfriendly? + assert prod( + v) <= x.numel(), f"{prod(v)} <= {x.numel()}" # CUDAGRAPH unfriendly? return x.flatten()[:prod(v)].view(*v) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index bbc2645bf81..acab253f731 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -31,8 +31,7 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -143,7 +142,7 @@ def __init__( hidden_act=config.hidden_act, quant_config=quant_config, # When just tensor-parallel is used, it isn't required - # to reduce the shared_output result. Instead we reduce + # to reduce the shared_output result. Instead we reduce # at the end of the forward pass. # With EP and the pplx kernels - this is no longer viable # as all GPU ranks in DP, produce the complete set of hidden_states. @@ -179,7 +178,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: * (1. / self.routed_scaling_factor) if self.tp_size > 1: - final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(final_hidden_states) + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states) return final_hidden_states.view(num_tokens, hidden_dim) diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 09bbeea9b13..3a09841e722 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -99,9 +99,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, router_logits) - # Needed? if self.tp_size > 1: - final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(final_hidden_states) + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states) return final_hidden_states.view(orig_shape) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 68e427d272c..2ba2d797883 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -25,8 +25,7 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, @@ -102,7 +101,8 @@ def forward(self, hidden_states): experts_out = routed_out + shared_out if self.tp_size > 1: - experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(experts_out) + experts_out = self.experts.maybe_all_reduce_tensor_model_parallel( + experts_out) return experts_out diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index aca0d658d88..8670e3facc2 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -33,9 +33,7 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 655b4d25fc7..fc96e329c7f 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -30,9 +30,7 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE From e478c1a3da4dddeab8514928191119c46bc99b00 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 7 May 2025 15:23:38 +0000 Subject: [PATCH 182/190] lint Signed-off-by: Bill Nell --- requirements/test.txt | 21 +++++- tests/kernels/moe/test_pplx_moe.py | 3 - vllm/compilation/compiler_interface.py | 3 +- .../layers/fused_moe/fused_batched_moe.py | 57 +++++---------- vllm/model_executor/layers/fused_moe/layer.py | 70 ++++++++++++------- vllm/model_executor/models/deepseek_v2.py | 5 +- vllm/model_executor/models/granitemoe.py | 2 +- vllm/model_executor/models/qwen2_moe.py | 2 +- vllm/model_executor/models/qwen3_moe.py | 2 +- vllm/platforms/cuda.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 1 - 11 files changed, 89 insertions(+), 79 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index 9a15d9a0d82..e2a853a1469 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -27,6 +27,10 @@ argcomplete==3.5.1 # via datamodel-code-generator arrow==1.3.0 # via isoduration +async-timeout==5.0.1 + # via + # aiohttp + # redis attrs==24.2.0 # via # aiohttp @@ -126,6 +130,11 @@ encodec==0.1.1 # via vocos evaluate==0.4.3 # via lm-eval +exceptiongroup==1.2.2 + # via + # anyio + # hypothesis + # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -623,7 +632,6 @@ setuptools==77.0.3 # via # mamba-ssm # pytablewriter - # torch # triton shellingham==1.5.4 # via typer @@ -683,8 +691,13 @@ tokenizers==0.21.1 # via # -r requirements/test.in # transformers +toml==0.10.2 + # via datamodel-code-generator tomli==2.2.1 - # via schemathesis + # via + # black + # pytest + # schemathesis tomli-w==1.2.0 # via schemathesis torch==2.7.0+cu128 @@ -756,12 +769,16 @@ types-python-dateutil==2.9.0.20241206 # via arrow typing-extensions==4.12.2 # via + # anyio + # black # huggingface-hub # librosa # mistral-common + # multidict # pqdm # pydantic # pydantic-core + # rich # torch # typer tzdata==2024.2 diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index d30f4cef3bb..542f03a01a1 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -522,13 +522,10 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): assert torch.cuda.current_device() == pgi.local_rank - hidden_dim = a.shape[1] num_experts = w1.shape[0] - block_size = 128 device = pgi.device rank = pgi.rank world_size = pgi.world_size - topk = topk_ids.shape[1] max_num_tokens = rank_chunk(a.shape[0], 0, world_size) dispatch_combine = BatchedDispatchCombine( diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 0cb4a2d7c5f..faeb6c4e73c 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -328,7 +328,8 @@ def _get_shape_env() -> AlwaysHitShapeEnv: assert hash_str is not None, ( f"failed to get the hash of the compiled graph: {file_path}") assert file_path is not None, ( - "failed to get the file path of the compiled graph: {file_path}") + "failed to get the file path of the compiled graph: {file_path}" + ) return compiled_graph, (hash_str, file_path) def load(self, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index d9143619224..f8fa55f5208 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -514,31 +514,18 @@ def dispatch( dtype=torch.int, device=a1.device) - rem_experts = num_experts % self.world_size - num_local_experts = ((num_experts // self.world_size) + - (1 if self.rank < rem_experts else 0)) + assert num_experts % self.world_size == 0 + + num_local_experts = num_experts // self.world_size b_a1 = torch.zeros( (num_local_experts, self.max_num_tokens, hidden_dim), dtype=a1.dtype, device=a1.device) - first_expert = (((num_experts // self.world_size) * self.rank) + - rem_experts - self.rank) + first_expert = num_local_experts * self.rank last_expert = first_expert + num_local_experts - # rhs = torch.empty((self.max_num_tokens, hidden_dim), - # dtype=a1.dtype, device=a1.device) - - # for expert_id in range(first_expert, last_expert): - # topks = torch.any(topk_ids == expert_id, dim=1).flatten() - # rows = torch.count_nonzero(topks.flatten()) - # #rhs[:rows] = a1[:topks.numel()][topks] - # topks_idx = topks.nonzero() - # torch.index_select(a1, dim=0, index=topks_idx.flatten(), out=rhs[:rows]) - # b_a1[expert_id - first_expert, :rows, :] = rhs[:rows] - # tokens_per_expert[expert_id - first_expert] = rows - for expert_id in range(first_expert, last_expert): topks = torch.any(topk_ids == expert_id, dim=1).flatten() rows = torch.count_nonzero(topks.flatten()) @@ -558,24 +545,14 @@ def combine( ) -> None: num_tokens = topk_ids.shape[0] num_local_experts = fused_expert_output.shape[0] - topk = topk_weights.shape[1] K = fused_expert_output.shape[-1] assert output.shape[0] == num_tokens and output.shape[1] == K output.fill_(0) - first_expert = num_local_experts * self.rank # NOT QUITE RIGHT + first_expert = num_local_experts * self.rank last_expert = first_expert + num_local_experts - # for expert_id in range(first_expert, last_expert): - # topkws = topk_ids == expert_id - # topks = torch.any(topkws, dim=1).flatten() - # outrhs = output[topks] - # rhs = fused_expert_output[expert_id - first_expert, :outrhs.shape[0], :] - # if not apply_router_weight_on_input: - # rhs.mul_(topk_weights[topkws].view(rhs.shape[0], 1)) - # output[topks] = outrhs + rhs - for expert_id in range(first_expert, last_expert): topkws = topk_ids == expert_id topks = torch.any(topkws, dim=1).flatten() @@ -661,20 +638,20 @@ def apply( num_experts = global_num_experts out = _resize_cache(workspace13, (num_experts, max_num_tokens * num_dp, hidden_dim)) - num_local_experts = w1.shape[0] #expert_num_tokens.numel() + num_local_experts = w1.shape[0] assert num_local_experts == w1.shape[ 0], f"{num_local_experts} == {w1.shape[0]}" N = w1.shape[1] // 2 # Not cudagraph friendly - # assert (torch.cuda.is_current_stream_capturing() or - # torch.all(expert_num_tokens <= max_num_tokens)), ( - # f"{expert_num_tokens} <= {max_num_tokens}") + assert (torch.cuda.is_current_stream_capturing() + or torch.all(expert_num_tokens <= max_num_tokens)), ( + f"{expert_num_tokens} <= {max_num_tokens}") for expert in range(num_local_experts): # Indexing expert_num_tokens doesn't work w/cudagraphs - if True or torch.cuda.is_current_stream_capturing(): + if torch.cuda.is_current_stream_capturing(): num = max_num_tokens * num_dp else: num = int(expert_num_tokens[expert].item()) @@ -821,12 +798,14 @@ def apply( block_shape=self.block_shape) # Fix activations - # assert activation == "silu" - # invoke_batched_silu_and_mul(output=intermediate_cache2, - # input=intermediate_cache1, - # expert_num_tokens=expert_num_tokens) - self.activation(activation, intermediate_cache2.view(-1, N // 2), - intermediate_cache1.view(-1, N)) + if True: + assert activation == "silu" + invoke_batched_silu_and_mul(output=intermediate_cache2, + input=intermediate_cache1, + expert_num_tokens=expert_num_tokens) + else: + self.activation(activation, intermediate_cache2.view(-1, N // 2), + intermediate_cache1.view(-1, N)) #qintermediate_cache2 = intermediate_cache2 a2q_scale = a2_scale diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index b5d34b4b1c9..89b932c6ad4 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -68,55 +68,68 @@ def use_pplx_kernels(self): def make(tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": """ - Determine MoE parallel configuration. Based on the input tp_size_, dp_size_, - ep_size_ and vllm's parallel config, determine what level's of parallelism - to use in the fused moe layer. + Determine MoE parallel configuration. Based on the input tp_size_, + dp_size_, ep_size_ and vllm's parallel config, determine what + level's of parallelism to use in the fused moe layer. Args: tp_size_ (int): tp_size passed into the FusedMoE constructor. dp_size_ (int): dp_size passed into the FusedMoE constructor. ep_size_ (int): ep_size passed into the FusedMoE constructor. - vllm_parallel_config (ParallelConfig): vllm's parallel config object. + vllm_parallel_config (ParallelConfig): vllm's parallel config + object. Examples: When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1, we simply return the sizes unaltered and the ranks set to 0. - Expert Parallelism is considered only when either dp_size_ or tp_size_ is non trivial. + Expert Parallelism is considered only when either dp_size_ or tp_size_ + is non trivial. - When TP = 2, DP = 1 and EP = False, the configuration on different devices, - - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // legend : {size, rank} + When TP = 2, DP = 1 and EP = False, the configuration on different + devices, + - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // + legend : {size, rank} - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0} - Comment : Tensors are sharded across 2 devices. - When TP = 1, DP = 2 and EP = False, the configuration on different devices, + When TP = 1, DP = 2 and EP = False, the configuration on different + devices, - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0} - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0} - - Comment: There are 2 engine instances and the tensors are sharded across 2 decvices. + - Comment: There are 2 engine instances and the tensors are sharded + across 2 decvices. - When TP = 2, DP = 2 and EP = False, the configuration on different devices, + When TP = 2, DP = 2 and EP = False, the configuration on different + devices, - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0} - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0} - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0} - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0} - - Comment: There are 2 engine instances and the tensors are sharded across 4 devices. + - Comment: There are 2 engine instances and the tensors are sharded + across 4 devices. - When, TP = 2, DP = 1 and EP = True, the configuration on different devices, + When, TP = 2, DP = 1 and EP = True, the configuration on different + devices, - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0} - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1} - Comment: The experts are split between the 2 devices. - When, TP = 1, DP = 2 and EP = True, the configuration on different devices, + When, TP = 1, DP = 2 and EP = True, the configuration on different + devices, - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0} - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1} - - Comment: There are 2 engine instances and the experts are split between the 2 devices. + - Comment: There are 2 engine instances and the experts are split + between the 2 devices. - When TP = 2, DP = 2 and EP = True, the configuration on different devices, + When TP = 2, DP = 2 and EP = True, the configuration on different + devices, - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0} - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1} - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2} - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3} - - Comment: There are 2 engine instances and the experts are split between the 4 devices. + - Comment: There are 2 engine instances and the experts are split + between the 4 devices. """ def flatten_tp_across_dp(dp_rank: int): @@ -127,7 +140,8 @@ def flatten_tp_across_dp(dp_rank: int): tp_rank = dp_rank * tp_size_ + tp_rank return tp_size, tp_rank - use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel + use_ep = (dp_size_ * tp_size_ > 1 + and vllm_parallel_config.enable_expert_parallel) dp_size = dp_size_ dp_rank = get_dp_group().rank_in_group @@ -143,8 +157,8 @@ def flatten_tp_across_dp(dp_rank: int): use_ep=False) # DP + EP / TP + EP / DP + TP + EP assert use_ep - # In EP, each device owns a set of experts fully. There is no tensor parallel. - # Update tp_size, tp_rank, ep_size and ep_rank to reflect that. + # In EP, each device owns a set of experts fully. There is no tensor + # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that. ep_size = tp_size ep_rank = tp_rank return FusedMoEParallelConfig(tp_size=1, @@ -719,12 +733,13 @@ def __init__( self.params_dtype = params_dtype vllm_config = get_current_vllm_config() - self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( - tp_size_=(tp_size if tp_size is not None else - get_tensor_model_parallel_world_size()), - dp_size_=(dp_size - if dp_size is not None else get_dp_group().world_size), - vllm_parallel_config=vllm_config.parallel_config) + self.moe_parallel_config: FusedMoEParallelConfig = ( + FusedMoEParallelConfig.make( + tp_size_=(tp_size if tp_size is not None else + get_tensor_model_parallel_world_size()), + dp_size_=(dp_size if dp_size is not None else + get_dp_group().world_size), + vllm_parallel_config=vllm_config.parallel_config)) self.global_num_experts = num_experts @@ -1184,8 +1199,9 @@ def must_reduce_shared_outputs(self) -> bool: def maybe_all_reduce_tensor_model_parallel( self, final_hidden_states: torch.Tensor): """ - The pplx combine kernel reduce across GPU ranks by default. The pplx kernels are - used when EP is enabled. In that case, this function is a no-op. + The pplx combine kernel reduce across GPU ranks by default. The pplx + kernels are used when EP is enabled. In that case, this function is a + no-op. """ if self.dp_size > 1 and self.use_ep and has_pplx: return final_hidden_states diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index acab253f731..96cd3315f46 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -145,7 +145,8 @@ def __init__( # to reduce the shared_output result. Instead we reduce # at the end of the forward pass. # With EP and the pplx kernels - this is no longer viable - # as all GPU ranks in DP, produce the complete set of hidden_states. + # as all GPU ranks in DP, produce the complete set of + # hidden_states. # Therefore reduce the shared experts early. reduce_results=self.experts.must_reduce_shared_outputs(), prefix=f"{prefix}.shared_experts", @@ -178,7 +179,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: * (1. / self.routed_scaling_factor) if self.tp_size > 1: - final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 final_hidden_states) return final_hidden_states.view(num_tokens, hidden_dim) diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 3a09841e722..b0c525849a2 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -100,7 +100,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = self.experts(hidden_states, router_logits) if self.tp_size > 1: - final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 final_hidden_states) return final_hidden_states.view(orig_shape) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 8670e3facc2..efbd61755c0 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -154,7 +154,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: - final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 final_hidden_states) return final_hidden_states.view(orig_shape) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index fc96e329c7f..f507a072734 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -135,7 +135,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits=router_logits) final_hidden_states = final_hidden_states if self.tp_size > 1: - final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 final_hidden_states) return final_hidden_states.view(orig_shape) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 31a28f5064d..ab03dece8c1 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -157,7 +157,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: logger.info( "Forcing kv cache block size to 64 for FlashMLA backend.") - if (False and parallel_config.data_parallel_size > 1 + if (parallel_config.data_parallel_size > 1 and compilation_config.use_cudagraph): logger.info( "Data Parallel: Forcing enforce eager to be True since DP is " diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index df1176212d0..e0c3d05c797 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1542,7 +1542,6 @@ def _dummy_run( self.drafter.dummy_run(num_tokens) logit_indices = np.cumsum(num_scheduled_tokens) - 1 - #logit_indices = torch.from_numpy(logit_indices).to(hidden_states.device) return hidden_states[logit_indices] @torch.inference_mode() From 90e9d05a1e33f8181b60fcd5bbd91f244c26e0aa Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 7 May 2025 15:26:25 +0000 Subject: [PATCH 183/190] revert lint changes to requirements/test.txt Signed-off-by: Bill Nell --- requirements/test.txt | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index e2a853a1469..9a15d9a0d82 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -27,10 +27,6 @@ argcomplete==3.5.1 # via datamodel-code-generator arrow==1.3.0 # via isoduration -async-timeout==5.0.1 - # via - # aiohttp - # redis attrs==24.2.0 # via # aiohttp @@ -130,11 +126,6 @@ encodec==0.1.1 # via vocos evaluate==0.4.3 # via lm-eval -exceptiongroup==1.2.2 - # via - # anyio - # hypothesis - # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -632,6 +623,7 @@ setuptools==77.0.3 # via # mamba-ssm # pytablewriter + # torch # triton shellingham==1.5.4 # via typer @@ -691,13 +683,8 @@ tokenizers==0.21.1 # via # -r requirements/test.in # transformers -toml==0.10.2 - # via datamodel-code-generator tomli==2.2.1 - # via - # black - # pytest - # schemathesis + # via schemathesis tomli-w==1.2.0 # via schemathesis torch==2.7.0+cu128 @@ -769,16 +756,12 @@ types-python-dateutil==2.9.0.20241206 # via arrow typing-extensions==4.12.2 # via - # anyio - # black # huggingface-hub # librosa # mistral-common - # multidict # pqdm # pydantic # pydantic-core - # rich # torch # typer tzdata==2024.2 From be1a8e5619e00c8b41f1f453eae52b871e061aee Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 7 May 2025 15:27:28 +0000 Subject: [PATCH 184/190] revert lint changes to compiler_interface.py Signed-off-by: Bill Nell --- vllm/compilation/compiler_interface.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index faeb6c4e73c..b7e7a79bef0 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -326,10 +326,9 @@ def _get_shape_env() -> AlwaysHitShapeEnv: # compilation cache. if not envs.VLLM_DISABLE_COMPILE_CACHE: assert hash_str is not None, ( - f"failed to get the hash of the compiled graph: {file_path}") + "failed to get the hash of the compiled graph") assert file_path is not None, ( - "failed to get the file path of the compiled graph: {file_path}" - ) + "failed to get the file path of the compiled graph") return compiled_graph, (hash_str, file_path) def load(self, From 916f902625fd4154e65ce5f1139ae8d1829e4afa Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 7 May 2025 15:38:16 +0000 Subject: [PATCH 185/190] fix merge Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 80e3942c5ee..93b651ecde8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -765,7 +765,7 @@ def get_default_config( # num_stages=3 can cause triton.runtime.errors.OutOfResources # on ROCm, set it to 2 instead. config = { - "BLOCK_SIZE_M": 64 if not use_deep_gemm else dg.get_m_alignment_for_contiguous_layout(), + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": block_shape[0], "BLOCK_SIZE_K": block_shape[1], "GROUP_SIZE_M": 32, From c04cb12a7e9d7170e57f78eb54b9560aea4a4583 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 7 May 2025 19:20:04 +0000 Subject: [PATCH 186/190] fix more lint errors Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 9 +++-- .../layers/fused_moe/modular_kernel.py | 39 +++++++++---------- .../layers/fused_moe/triton_deep_gemm_moe.py | 9 ++--- 3 files changed, 27 insertions(+), 30 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 89b932c6ad4..6be014ae857 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -2,11 +2,11 @@ import importlib import threading -import weakref from abc import abstractmethod from dataclasses import dataclass from enum import Enum from typing import Callable, List, Optional, Tuple +from weakref import WeakValueDictionary import torch import torch.nn.functional as F @@ -266,7 +266,7 @@ def apply( class AllToAllCache: def __init__(self): - self._cache = weakref.WeakValueDictionary() + self._cache: WeakValueDictionary = WeakValueDictionary() self._lock = threading.RLock() # Reentrant lock for thread safety def get_or_create(self, **kwargs): @@ -802,7 +802,8 @@ def __init__( if quant_config is None: quant_method = UnquantizedFusedMoEMethod(moe) else: - quant_method = quant_config.get_quant_method(self, prefix) + quant_method = quant_config.get_quant_method( + self, prefix) # type: ignore assert isinstance(quant_method, FusedMoEMethodBase) assert quant_method is not None @@ -812,7 +813,7 @@ def __init__( if dispatch_combine is not None: world_size = moe.ep_size - dp_size = moe.ep_size // moe.dp_size + dp_size = int(moe.ep_size // moe.dp_size) success = self.quant_method.set_dispatch_combine( dp_size, world_size, dispatch_combine) if not success: diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 299d98c7f15..95b0397f952 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -339,27 +339,24 @@ def forward( a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts, expert_map, apply_router_weight_on_input) - if True: - fused_out = self.fused_experts.apply( - a1q, - w1, - w2, - topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1q_scale=a1q_scale, - a2_scale=a2_scale, - workspace13=workspace13, - workspace2=workspace2, - expert_num_tokens=expert_num_tokens, - ) - else: - fused_out = torch.empty_like(a1q) + fused_out = self.fused_experts.apply( + a1q, + w1, + w2, + topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=a1q_scale, + a2_scale=a2_scale, + workspace13=workspace13, + workspace2=workspace2, + expert_num_tokens=expert_num_tokens, + ) self.dispatch_combine.combine(output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 5ddb0e66842..88edfbf0719 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -21,11 +21,10 @@ def __init__(self, block_m: Optional[int] = None, allow_deep_gemm: bool = False): super().__init__() - self.triton_expert = TritonExperts(use_fp8_w8a8, use_int8_w8a8, - use_int4_w4a16, use_int8_w8a16, - per_channel_quant, block_shape, - block_m) - self.deep_gemm_expert = DeepGemmExperts() + self.triton_expert: TritonExperts = TritonExperts( + use_fp8_w8a8, use_int8_w8a8, use_int4_w4a16, use_int8_w8a16, + per_channel_quant, block_shape, block_m) + self.deep_gemm_expert: DeepGemmExperts = DeepGemmExperts() self.allow_deep_gemm = allow_deep_gemm self.use_fp8_w8a8 = use_fp8_w8a8 From f5bcc229c9aeea51c25e7007266a26b3e4928e0a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 7 May 2025 20:13:29 +0000 Subject: [PATCH 187/190] fix lint Signed-off-by: Bill Nell --- .../layers/fused_moe/triton_deep_gemm_moe.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 88edfbf0719..e512c11933d 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -21,10 +21,14 @@ def __init__(self, block_m: Optional[int] = None, allow_deep_gemm: bool = False): super().__init__() - self.triton_expert: TritonExperts = TritonExperts( - use_fp8_w8a8, use_int8_w8a8, use_int4_w4a16, use_int8_w8a16, - per_channel_quant, block_shape, block_m) - self.deep_gemm_expert: DeepGemmExperts = DeepGemmExperts() + self.triton_expert = TritonExperts(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + block_m=block_m) + self.deep_gemm_expert = DeepGemmExperts() self.allow_deep_gemm = allow_deep_gemm self.use_fp8_w8a8 = use_fp8_w8a8 @@ -69,7 +73,7 @@ def apply( N = w1.shape[1] if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512 and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): - return self.deep_gemm_expert( + return self.deep_gemm_expert.apply( hidden_states, w1, w2, @@ -88,7 +92,7 @@ def apply( expert_num_tokens, ) else: - return self.triton_expert( + return self.triton_expert.apply( hidden_states, w1, w2, From f7b70708ec70c23cf9ede7dc160d2bff47c8fbaa Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 9 May 2025 01:57:32 -0400 Subject: [PATCH 188/190] Fixes and cleanup Signed-off-by: Varun Sundar Rabindranath --- tests/kernels/moe/test_batched_moe.py | 76 +------------- vllm/distributed/parallel_state.py | 7 +- vllm/distributed/utils.py | 35 +++++-- vllm/forward_context.py | 18 +--- .../layers/fused_moe/fused_batched_moe.py | 98 +------------------ vllm/model_executor/layers/fused_moe/layer.py | 41 +++++--- vllm/model_executor/models/deepseek_v2.py | 14 +-- vllm/model_executor/models/llama4.py | 6 +- vllm/model_executor/models/qwen2_moe.py | 7 +- vllm/v1/attention/backends/mla/common.py | 4 +- 10 files changed, 87 insertions(+), 219 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index f9f3f6506a5..90b8e1a0d99 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -3,11 +3,11 @@ from dataclasses import dataclass import pytest -import torch import triton.language as tl +import torch from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - invoke_batched_silu_and_mul, invoke_moe_batched_triton_kernel) + invoke_moe_batched_triton_kernel) @dataclass @@ -103,75 +103,5 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, ref_output = ref_impl(tensors.A, tensors.B, ref_output, tensors.num_expert_tokens) - #torch.cuda.synchronize() - #print (f"ref output {ref_output}") - #print (f"test output {test_output}") - - torch.testing.assert_close(test_output, ref_output, atol=1e-3, rtol=1e-3) - - -@dataclass -class BatchedSiluMulConfig: - dtype: torch.dtype - num_experts: int - max_tokens_per_expert: int - D: int - - -@dataclass -class BatchedSiluMulTensors: - input: torch.Tensor - output: torch.Tensor - expert_num_tokens: torch.Tensor - - @staticmethod - def make_tensors(config: BatchedSiluMulConfig): - input = torch.randn( - (config.num_experts, config.max_tokens_per_expert, config.D * 2), - device="cuda", - dtype=config.dtype) / 50.0 - output = torch.zeros( - (config.num_experts, config.max_tokens_per_expert, config.D), - device="cuda", - dtype=config.dtype) - num_expert_tokens = torch.randint(low=0, - high=config.max_tokens_per_expert, - size=(config.num_experts, ), - device="cuda", - dtype=torch.int32) - return BatchedSiluMulTensors(input, output, num_expert_tokens) - - -def ref_batched_silu_mul(output: torch.Tensor, input: torch.Tensor, - num_expert_tokens: torch.Tensor) -> torch.Tensor: - - num_expert_tokens_cpu = num_expert_tokens.clone() - num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu") - num_experts = num_expert_tokens.size(0) - - for e in range(num_experts): - num_tokens = num_expert_tokens_cpu[e].item() - out_part = output[e, :num_tokens, :] - in_part = input[e, :num_tokens, :] - torch.ops._C.silu_and_mul(out_part, in_part) - - -@pytest.mark.parametrize("num_experts", [16, 32]) -@pytest.mark.parametrize("max_tokens_per_expert", [128]) -@pytest.mark.parametrize("D", [128, 256]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_batched_silu_mul(num_experts: int, max_tokens_per_expert: int, D: int, - dtype: torch.dtype): - - config = BatchedSiluMulConfig(dtype, num_experts, max_tokens_per_expert, D) - tensors = BatchedSiluMulTensors.make_tensors(config) - - test_out = tensors.output - ref_out = torch.zeros_like(test_out) - - ref_batched_silu_mul(ref_out, tensors.input, tensors.expert_num_tokens) - - invoke_batched_silu_and_mul(test_out, tensors.input, - tensors.expert_num_tokens) - torch.testing.assert_close(test_out, ref_out) + torch.testing.assert_close(test_output, ref_output, atol=1e-3, rtol=1e-3) \ No newline at end of file diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 2cedaa06018..dbe0d06199b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -33,10 +33,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from unittest.mock import patch -import torch import torch.distributed from torch.distributed import Backend, ProcessGroup +import torch import vllm.envs as envs from vllm.distributed.device_communicators.base_device_communicator import ( DeviceCommunicatorBase) @@ -943,6 +943,11 @@ def pplx_finalize(): global PPLX_DID_INIT if PPLX_DID_INIT: from pplx_kernels.nvshmem import nvshmem_finalize + + from vllm.model_executor.layers.fused_moe.layer import ( + _all_to_all_cache) + _all_to_all_cache.destroy() + logger.info("PPLX finalize") nvshmem_finalize() diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 6bb39672a32..899ec843040 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -4,6 +4,7 @@ # Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +import contextlib import dataclasses import datetime import pickle @@ -12,7 +13,6 @@ from collections import deque from typing import Any, Deque, Dict, Optional, Sequence, Tuple -import torch from torch.distributed import ProcessGroup, TCPStore from torch.distributed.distributed_c10d import (Backend, PrefixStore, _get_default_timeout, @@ -20,6 +20,7 @@ is_nccl_available) from torch.distributed.rendezvous import rendezvous +import torch import vllm.envs as envs from vllm.logger import init_logger @@ -360,11 +361,29 @@ def stateless_destroy_torch_distributed_process_group( Destroy ProcessGroup returned by stateless_init_torch_distributed_process_group(). """ - # TODO: pytorch < 2.7? - if False: - # Lazy import for non-CUDA backends. - from torch.distributed.distributed_c10d import _shutdown_backend - _shutdown_backend(pg) - else: - pg.shutdown() + + def _shutdown_backend(pg): + # We have been using, + # torch.distributed.distributed_c10d._shutdown_backend + # for backend shutdowns. But the function has been retired + # since Torch 2.7.0. As a recourse, we copy-paste the + # `_shutdown_backend` function from <2.7.0 here. + from torch.distributed.distributed_c10d import ProcessGroupNCCL + backend = None + with contextlib.suppress(RuntimeError): + backend = pg._get_backend(torch.device("cuda")) + + if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): + # explicitly call shutdown to ensure that NCCL resources are + # released + backend._shutdown() + + torch.distributed.barrier() + if pg.rank() == 0: + # Let the other ranks finish first + # Rank 0 has the TCPStore server. Let the other ranks finish so + # they don't complain about the non-existence of the TCPStore server. + time.sleep(1) + + _shutdown_backend(pg) _unregister_process_group(pg.group_name) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 2d6153095eb..02946b1d12c 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -6,9 +6,9 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional, Union -import torch import torch.distributed as dist +import torch import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed.kv_transfer import (get_kv_transfer_group, @@ -31,10 +31,8 @@ @dataclass class DPMetadata: - max_tokens_across_dp: torch.Tensor - num_tokens_across_dp: torch.Tensor + max_tokens_across_dp_cpu: torch.Tensor cu_tokens_across_dp_cpu: torch.Tensor - dp_rank_num_tokens: torch.Tensor @dataclass @@ -97,16 +95,10 @@ def set_forward_context(attn_metadata: Any, dtype=torch.int32) from vllm.distributed.parallel_state import get_dp_group dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) - #TODO device? (tms) - max_tokens_across_dp = torch.max( - num_tokens_tensor) #.to(device="cuda") + max_tokens_across_dp_cpu = torch.max(num_tokens_tensor) cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0) - dp_rank_num_tokens = torch.tensor( - [num_tokens], - dtype=torch.uint32, - device=vllm_config.device_config.device) - dp_metadata = DPMetadata(max_tokens_across_dp, num_tokens_tensor, - cu_tokens_across_dp_cpu, dp_rank_num_tokens) + dp_metadata = DPMetadata(max_tokens_across_dp_cpu, + cu_tokens_across_dp_cpu) global _forward_context prev_context = _forward_context diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index f8fa55f5208..9107f13a8f3 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -2,75 +2,16 @@ """Fused batched MoE kernel.""" from typing import List, Optional, Tuple -import torch import triton import triton.language as tl +import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.fused_moe import ( get_config_dtype_str, try_get_optimal_moe_config) from vllm.model_executor.layers.fused_moe.utils import _resize_cache -@triton.jit -def batched_silu_and_mul_kernel( - output, # [E, MAX_NUM_TOKENS, D] - input, # [E, MAX_NUM_TOKENS, D * 2] - expert_num_tokens, # [E] - stride_oe, - stride_om, - stride_ie, - stride_im, - compute_type: tl.constexpr, - D, - BLOCK_M: tl.constexpr, - BLOCK_D: tl.constexpr): - - expert_id = tl.program_id(axis=0) - e_num_tokens = tl.load(expert_num_tokens + expert_id) - if e_num_tokens == 0: - # early exit - return - - pid_m = tl.program_id(axis=1) - cta_m_start = pid_m * BLOCK_M - if cta_m_start >= e_num_tokens: - # early exit - return - - cta_input_ptr = input + expert_id * stride_ie + cta_m_start * stride_im - cta_output_ptr = output + expert_id * stride_oe + cta_m_start * stride_om - - cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start) - offs_m = tl.arange(0, BLOCK_M)[:, None] - mask_m = offs_m < cta_m_size - - cta_input_ptrs = cta_input_ptr + offs_m * stride_im - cta_output_ptrs = cta_output_ptr + offs_m * stride_om - - # offset by D - offs_D = tl.arange(0, BLOCK_D) - cta_input_ptrs = cta_input_ptrs + offs_D - cta_output_ptrs = cta_output_ptrs + offs_D - - for d in range(0, tl.cdiv(D, BLOCK_D)): - mask_D = offs_D < (D - (d * BLOCK_D)) - mask_tile = mask_m & mask_D - - x_tile = tl.load(cta_input_ptrs, mask=mask_tile, - other=0.0).to(dtype=tl.float32) - y_tile = tl.load(cta_input_ptrs + D, mask=mask_tile, other=0.0) - - # silu and mul - out_tile = (x_tile * (1.0 / - (1.0 + tl.exp(-x_tile)))).to(dtype=compute_type) - out_tile = out_tile * y_tile - tl.store(cta_output_ptrs, out_tile, mask=mask_tile) - - cta_input_ptrs = cta_input_ptrs + BLOCK_D - cta_output_ptrs = cta_output_ptrs + BLOCK_D - - @triton.jit def moe_mmk( a_ptrs, @@ -438,33 +379,6 @@ def invoke_moe_batched_triton_kernel( BLOCK_K=BLOCK_K) -def invoke_batched_silu_and_mul( - output: torch.Tensor, #[E, MAX_TOKENS, D] - input: torch.Tensor, #[E, MAX_TOKENS, D * 2] - expert_num_tokens: torch.Tensor): - - num_experts = output.size(0) - max_num_tokens = output.size(1) - D = output.size(2) - - BLOCK_D = 1024 - BLOCK_M = 1 - - compute_tl_dtype = { - torch.float16: tl.float16, - torch.float32: tl.float32, - torch.bfloat16: tl.bfloat16 - }[output.dtype] - - #print(f"compute type {compute_tl_dtype}") - - grid = (num_experts, triton.cdiv(max_num_tokens, BLOCK_M)) - batched_silu_and_mul_kernel[grid](output, input, expert_num_tokens, - output.stride(0), output.stride(1), - input.stride(0), input.stride(1), - compute_tl_dtype, D, BLOCK_M, BLOCK_D) - - def rank_chunk(num, r, w): rem = num % w return (num // w) + (1 if r < rem else 0) @@ -798,14 +712,8 @@ def apply( block_shape=self.block_shape) # Fix activations - if True: - assert activation == "silu" - invoke_batched_silu_and_mul(output=intermediate_cache2, - input=intermediate_cache1, - expert_num_tokens=expert_num_tokens) - else: - self.activation(activation, intermediate_cache2.view(-1, N // 2), - intermediate_cache1.view(-1, N)) + self.activation(activation, intermediate_cache2.view(-1, N // 2), + intermediate_cache1.view(-1, N)) #qintermediate_cache2 = intermediate_cache2 a2q_scale = a2_scale diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 6be014ae857..3f1c21548aa 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -8,10 +8,10 @@ from typing import Callable, List, Optional, Tuple from weakref import WeakValueDictionary -import torch import torch.nn.functional as F from torch.nn.parameter import UninitializedParameter +import torch import vllm.envs as envs from vllm.config import ParallelConfig, get_current_vllm_config from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank, @@ -62,7 +62,7 @@ class FusedMoEParallelConfig: @property def use_pplx_kernels(self): - return self.use_ep and has_pplx + return self.dp_size > 1 and self.use_ep and has_pplx @staticmethod def make(tp_size_: int, dp_size_: int, @@ -269,6 +269,11 @@ def __init__(self): self._cache: WeakValueDictionary = WeakValueDictionary() self._lock = threading.RLock() # Reentrant lock for thread safety + def destroy(self): + with self._lock: + for _, a2a in self._cache.items(): + a2a.destroy() + def get_or_create(self, **kwargs): assert has_pplx import pplx_kernels as pplx @@ -279,7 +284,9 @@ def get_or_create(self, **kwargs): with self._lock: instance = self._cache.get(key) if instance is None: - # TODO: should be intranode + # TODO (varun): Add support to switch to intranode + # when all communications are within the same + # node. instance = pplx.AllToAll.internode(**kwargs) self._cache[key] = instance return instance @@ -651,7 +658,7 @@ def _construct_dispatch_combine( dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. rank = moe.ep_rank - if moe.use_ep and has_pplx: + if moe.use_pplx_kernels: logger.debug("using pplx dispatch") all_to_all = get_all_to_all( @@ -1194,17 +1201,27 @@ def naive_multicast(self, x: torch.Tensor, return buffer - def must_reduce_shared_outputs(self) -> bool: - return self.dp_size > 1 and self.use_ep and has_pplx + def must_reduce_shared_expert_outputs(self) -> bool: + """ + The shared_experts are typically computed using the RowParallelLinear + layer. The result of this function is typically used as + the reduce_results argument to the module. + When just tensor-parallel is used, it is not required to reduce + the shared_experts results immediately. Instead we reduce at the + once at the end of the MoE op. (Refer to DeepSeekV2MoE module) + With EP and the pplx kernels - this is no longer viable as all + GPU ranks in DP, produce the complete set of hidden_states. + Therefore it is required that we reduce the shared_experts output + early. + """ + return self.use_pplx_kernels def maybe_all_reduce_tensor_model_parallel( self, final_hidden_states: torch.Tensor): """ - The pplx combine kernel reduce across GPU ranks by default. The pplx - kernels are used when EP is enabled. In that case, this function is a - no-op. + The pplx combine kernel reduces across GPU ranks by default. """ - if self.dp_size > 1 and self.use_ep and has_pplx: + if self.use_pplx_kernels: return final_hidden_states else: return tensor_model_parallel_all_reduce(final_hidden_states) @@ -1249,7 +1266,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): final_hidden_states) ctx = get_forward_context() - max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp + max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE num_tokens = full_hidden_states.size(0) @@ -1271,7 +1288,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None - if self.dp_size > 1 and self.use_ep and has_pplx: + if self.moe_parallel_config.use_pplx_kernels: return self.forward_impl_chunked(hidden_states, router_logits) if self.dp_size > 1: diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 96cd3315f46..6c1a40d9933 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -24,10 +24,10 @@ """Inference-only DeepseekV2/DeepseekV3 model.""" from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union -import torch -from torch import nn from transformers import PretrainedConfig +import torch +from torch import nn from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig @@ -141,14 +141,8 @@ def __init__( intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - # When just tensor-parallel is used, it isn't required - # to reduce the shared_output result. Instead we reduce - # at the end of the forward pass. - # With EP and the pplx kernels - this is no longer viable - # as all GPU ranks in DP, produce the complete set of - # hidden_states. - # Therefore reduce the shared experts early. - reduce_results=self.experts.must_reduce_shared_outputs(), + reduce_results=self.experts.must_reduce_shared_expert_outputs( + ), prefix=f"{prefix}.shared_experts", ) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 2ba2d797883..620812de2e7 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -18,10 +18,10 @@ """Inference-only LLaMA model compatible with HuggingFace weights.""" from typing import Any, Dict, Iterable, List, Optional, Set, Tuple -import torch -from torch import nn from transformers import Llama4TextConfig +import torch +from torch import nn from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig @@ -88,7 +88,7 @@ def __init__(self, quant_config=quant_config, bias=False, prefix=f"{prefix}.shared_expert", - reduce_results=False, # We need to do scatter before reduce + reduce_results=self.experts.must_reduce_shared_expert_outputs(), ) def forward(self, hidden_states): diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index efbd61755c0..d1b9a86ecbc 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -25,11 +25,11 @@ """Inference-only Qwen2MoE model compatible with HuggingFace weights.""" from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union -import torch import torch.nn.functional as F -from torch import nn from transformers import PretrainedConfig +import torch +from torch import nn from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig @@ -127,7 +127,8 @@ def __init__( intermediate_size=config.shared_expert_intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=False, + reduce_results=self.experts.must_reduce_shared_expert_outputs( + ), ) else: self.shared_expert = None diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index b4f8317e096..4e33dd81891 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -192,7 +192,6 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar import torch - from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, @@ -882,6 +881,9 @@ def forward( if attn_metadata is None: # Profiling run. + # The zero fill is required when used with DP + EP + # to ensure all ranks within a DP group compute the + # same expert outputs. return output.fill_(0) num_actual_toks = attn_metadata.num_actual_tokens From af021672934224a77f7722079ad0908334c8fe95 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 9 May 2025 09:18:18 -0400 Subject: [PATCH 189/190] import sort Signed-off-by: Varun Sundar Rabindranath --- tests/kernels/moe/test_batched_moe.py | 2 +- vllm/distributed/parallel_state.py | 2 +- vllm/distributed/utils.py | 2 +- vllm/forward_context.py | 2 +- vllm/model_executor/layers/fused_moe/fused_batched_moe.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 2 +- vllm/model_executor/models/deepseek_v2.py | 3 ++- vllm/model_executor/models/llama4.py | 4 ++-- vllm/model_executor/models/qwen2_moe.py | 4 ++-- vllm/v1/attention/backends/mla/common.py | 1 + 10 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 90b8e1a0d99..5de87fa581e 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -3,9 +3,9 @@ from dataclasses import dataclass import pytest +import torch import triton.language as tl -import torch from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( invoke_moe_batched_triton_kernel) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index dbe0d06199b..10db10236f6 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -33,10 +33,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from unittest.mock import patch +import torch import torch.distributed from torch.distributed import Backend, ProcessGroup -import torch import vllm.envs as envs from vllm.distributed.device_communicators.base_device_communicator import ( DeviceCommunicatorBase) diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 899ec843040..2bf45e20251 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -13,6 +13,7 @@ from collections import deque from typing import Any, Deque, Dict, Optional, Sequence, Tuple +import torch from torch.distributed import ProcessGroup, TCPStore from torch.distributed.distributed_c10d import (Backend, PrefixStore, _get_default_timeout, @@ -20,7 +21,6 @@ is_nccl_available) from torch.distributed.rendezvous import rendezvous -import torch import vllm.envs as envs from vllm.logger import init_logger diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 02946b1d12c..f3732f83af5 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -6,9 +6,9 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional, Union +import torch import torch.distributed as dist -import torch import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed.kv_transfer import (get_kv_transfer_group, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 9107f13a8f3..47739375b14 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -2,10 +2,10 @@ """Fused batched MoE kernel.""" from typing import List, Optional, Tuple +import torch import triton import triton.language as tl -import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.fused_moe import ( get_config_dtype_str, try_get_optimal_moe_config) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3f1c21548aa..0feee87b4cb 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -8,10 +8,10 @@ from typing import Callable, List, Optional, Tuple from weakref import WeakValueDictionary +import torch import torch.nn.functional as F from torch.nn.parameter import UninitializedParameter -import torch import vllm.envs as envs from vllm.config import ParallelConfig, get_current_vllm_config from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 6c1a40d9933..a75c7c116b9 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -24,10 +24,11 @@ """Inference-only DeepseekV2/DeepseekV3 model.""" from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union -from transformers import PretrainedConfig import torch from torch import nn +from transformers import PretrainedConfig + from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 620812de2e7..dfd0804f21c 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -18,10 +18,10 @@ """Inference-only LLaMA model compatible with HuggingFace weights.""" from typing import Any, Dict, Iterable, List, Optional, Set, Tuple -from transformers import Llama4TextConfig - import torch from torch import nn +from transformers import Llama4TextConfig + from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index d1b9a86ecbc..e45b81cb015 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -25,11 +25,11 @@ """Inference-only Qwen2MoE model compatible with HuggingFace weights.""" from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union +import torch import torch.nn.functional as F +from torch import nn from transformers import PretrainedConfig -import torch -from torch import nn from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 4e33dd81891..f8746ca708f 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -192,6 +192,7 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar import torch + from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, From 081f11f7fd65e515067d6cd25a96227dc6ce481e Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 9 May 2025 09:59:27 -0400 Subject: [PATCH 190/190] fix shutdown Signed-off-by: Varun Sundar Rabindranath --- vllm/distributed/utils.py | 32 +++++++------------------------- 1 file changed, 7 insertions(+), 25 deletions(-) diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 2bf45e20251..18459f982c0 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -4,7 +4,6 @@ # Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -import contextlib import dataclasses import datetime import pickle @@ -361,29 +360,12 @@ def stateless_destroy_torch_distributed_process_group( Destroy ProcessGroup returned by stateless_init_torch_distributed_process_group(). """ + try: + # pytorch < 2.7 + # Lazy import for non-CUDA backends. + from torch.distributed.distributed_c10d import _shutdown_backend + _shutdown_backend(pg) + except: + pg.shutdown() - def _shutdown_backend(pg): - # We have been using, - # torch.distributed.distributed_c10d._shutdown_backend - # for backend shutdowns. But the function has been retired - # since Torch 2.7.0. As a recourse, we copy-paste the - # `_shutdown_backend` function from <2.7.0 here. - from torch.distributed.distributed_c10d import ProcessGroupNCCL - backend = None - with contextlib.suppress(RuntimeError): - backend = pg._get_backend(torch.device("cuda")) - - if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): - # explicitly call shutdown to ensure that NCCL resources are - # released - backend._shutdown() - - torch.distributed.barrier() - if pg.rank() == 0: - # Let the other ranks finish first - # Rank 0 has the TCPStore server. Let the other ranks finish so - # they don't complain about the non-existence of the TCPStore server. - time.sleep(1) - - _shutdown_backend(pg) _unregister_process_group(pg.group_name)