From 383e5e2170850dc713ac8e5b3554ec18d41599ca Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Mon, 10 Feb 2025 20:43:52 +0800
Subject: [PATCH 1/2] add triton grouped_topk

---
 lightllm/common/fused_moe/grouped_topk.py      | 231 ++++++++++++++++++
 .../common/fused_moe/test_grouped_topk.py      |  98 ++++++++
 2 files changed, 329 insertions(+)
 create mode 100644 lightllm/common/fused_moe/grouped_topk.py
 create mode 100755 unit_tests/common/fused_moe/test_grouped_topk.py

diff --git a/lightllm/common/fused_moe/grouped_topk.py b/lightllm/common/fused_moe/grouped_topk.py
new file mode 100644
index 000000000..56cd92683
--- /dev/null
+++ b/lightllm/common/fused_moe/grouped_topk.py
@@ -0,0 +1,231 @@
# adapted from https://github.com/triton-lang/triton/issues/3698#issuecomment-2067681396
import torch
import triton
import triton.language as tl
from triton.language.standard import _log2, sum, zeros_like


@triton.jit
def _compare_and_swap(x, ids, flip, i: tl.core.constexpr, n_dims: tl.core.constexpr):
    n_outer: tl.core.constexpr = x.numel >> n_dims
    shape: tl.core.constexpr = [n_outer * 2 ** i, 2, 2 ** (n_dims - i - 1)]
    y = tl.core.reshape(x, shape)
    # slice left/right with 'stride' 2**(n_dims - i - 1)
    mask = tl.core.arange(0, 2)[None, :, None]
    left = tl.core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape)
    right = tl.core.broadcast_to(sum(y * mask, 1)[:, None, :], shape)
    left = tl.core.reshape(left, x.shape)
    right = tl.core.reshape(right, x.shape)

    # gather the left/right neighbors of the index tensor in the same way
    y_idx = tl.core.reshape(ids, shape)
    left_idx = tl.core.broadcast_to(sum(y_idx * (1 - mask), 1)[:, None, :], shape)
    right_idx = tl.core.broadcast_to(sum(y_idx * mask, 1)[:, None, :], shape)
    left_idx = tl.core.reshape(left_idx, x.shape)
    right_idx = tl.core.reshape(right_idx, x.shape)

    # actual compare-and-swap, done on integer bit patterns via XOR
    idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
    ileft = left.to(idtype, bitcast=True)
    iright = right.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)

    cond = (left > right) ^ flip

    ret = ix ^ tl.core.where(cond, ileft ^ iright, zeros_like(ix))

    new_ids = ids ^ tl.core.where(cond, left_idx ^ right_idx, zeros_like(ids))

    return ret.to(x.dtype, bitcast=True), new_ids

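
# The XOR update above is the classic swap-without-a-temporary trick applied
# elementwise. For two integer bit patterns a and b, with d = a ^ b we have
# a ^ d == b and b ^ d == a, for example:
#
#   a, b = 7, 3
#   d = a ^ b      # 4
#   a ^ d          # -> 3
#   b ^ d          # -> 7
#
# Lanes where `cond` is false XOR with zero and keep their value, so a single
# vectorized expression swaps exactly the pairs that are out of order for the
# requested direction (`flip`).
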

@triton.jit
def _bitonic_merge(x, ids, stage: tl.core.constexpr, order: tl.core.constexpr, n_dims: tl.core.constexpr):
    """
    order_type 0 == ascending
    order_type 1 == descending
    order_type 2 == alternating
    """
    n_outer: tl.core.constexpr = x.numel >> n_dims
    tl.core.static_assert(stage <= n_dims)
    # flip denotes whether sub-sequences of elements are re-arranged in
    # ascending or descending order:
    # if flip = 00000000... then all elements are re-arranged ascendingly at this stage
    # if flip = 00110011... then the elements are re-arranged in alternating order
    # (with a stride of 2) at this stage
    if order == 2:
        shape: tl.core.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2 ** stage]
        flip = tl.core.reshape(tl.core.broadcast_to(tl.core.arange(0, 2)[None, :, None], shape), x.shape)
    else:
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in tl.core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids


@triton.jit
def argsort(x, ids, dim: tl.core.constexpr = None, descending: tl.core.constexpr = tl.core.CONSTEXPR_0):
    # handle the default dimension, and check that it is the most minor one
    _dim: tl.core.constexpr = len(x.shape) - 1 if dim is None else dim
    tl.core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
    # iteratively run bitonic merge-sort steps
    n_dims: tl.core.constexpr = _log2(x.shape[_dim])

    for i in tl.core.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids

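
# One program instance handles one token. At a high level the kernel below:
#   1. loads the token's gating logits and applies sigmoid or softmax;
#   2. optionally adds the per-expert correction bias;
#   3. views the scores as [group_num, group_expert_num] and takes each
#      group's max as that group's value;
#   4. keeps the `group_topk_num` best groups, masking the scores of all
#      other groups to a large negative value;
#   5. argsorts the masked scores descending and stores the first `topk_num`
#      weights and ids, renormalizing the weights to sum to 1 if requested.
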
@triton.jit
def grouped_topk_kernel(
    gating_output_ptr,
    gating_output_stride_m,
    gating_output_stride_n,
    correction_bias_ptr,
    scores_buffer_ptr,  # [token_num, total_expert_num]
    scores_stride_m,
    scores_stride_n,
    scores_stride_token_m,
    scores_stride_group,
    scores_stride_group_v,
    out_topk_weights,
    out_topk_weights_stride_m,
    out_topk_weights_stride_n,
    out_topk_ids,
    out_topk_ids_stride_m,
    out_topk_ids_stride_n,
    group_num,
    group_expert_num,
    total_expert_num,  # group_num * group_expert_num == total_expert_num
    topk_num,
    group_topk_num,
    IS_SIGMOID: tl.constexpr,
    HAS_CORRECTION_BIAS: tl.constexpr,
    EXPERT_BLOCK_SIZE: tl.constexpr,  # triton.next_power_of_2(total_expert_num)
    EXPERT_GROUP_NUM: tl.constexpr,  # triton.next_power_of_2(group_num)
    EXPERT_GROUP_SIZE: tl.constexpr,  # triton.next_power_of_2(group_expert_num)
    RENORMALIZE: tl.constexpr,
):
    token_index = tl.program_id(axis=0)
    offs_n = tl.arange(0, EXPERT_BLOCK_SIZE)
    hidden_states = tl.load(
        gating_output_ptr + token_index * gating_output_stride_m + offs_n,
        mask=offs_n < total_expert_num,
        other=-10000000.0,
    )
    if IS_SIGMOID:
        scores = tl.sigmoid(hidden_states)
    else:
        scores = tl.softmax(hidden_states)

    if HAS_CORRECTION_BIAS:
        # padding lanes are masked out of the store below, so their fill value is irrelevant
        scores += tl.load(correction_bias_ptr + offs_n, mask=offs_n < total_expert_num, other=0.0)

    offs_group = tl.arange(0, EXPERT_GROUP_NUM)
    offs_group_v = tl.arange(0, EXPERT_GROUP_SIZE)
    tl.store(scores_buffer_ptr + scores_stride_m * token_index + offs_n, scores, mask=offs_n < total_expert_num)
    group_scores = tl.load(
        scores_buffer_ptr
        + scores_stride_token_m * token_index
        + offs_group[:, None] * scores_stride_group
        + offs_group_v[None, :] * scores_stride_group_v,
        mask=(offs_group < group_num)[:, None] & (offs_group_v < group_expert_num)[None, :],
        other=-10000000.0,
    )  # [group, group_size]

    group_value = tl.max(group_scores, axis=1)  # [group,]
    sorted_group_value = tl.sort(group_value, descending=True)
    group_topk_value = tl.sum(tl.where(offs_group == group_topk_num - 1, sorted_group_value, 0.0))
    mask_group_scores = tl.where(
        ((group_value >= group_topk_value)[:, None]) & ((offs_group_v < group_expert_num)[None, :]),
        group_scores,
        -10000000.0,
    )

    tl.store(
        scores_buffer_ptr
        + scores_stride_token_m * token_index
        + offs_group[:, None] * scores_stride_group
        + offs_group_v[None, :] * scores_stride_group_v,
        mask_group_scores,
        mask=((offs_group < group_num)[:, None]) & ((offs_group_v < group_expert_num)[None, :]),
    )  # [group, group_size]

    mask_scores = tl.load(
        scores_buffer_ptr + scores_stride_m * token_index + offs_n, mask=offs_n < total_expert_num, other=-10000000.0
    )
    sorted_scores, sorted_indexes = argsort(mask_scores, offs_n, descending=True)

    if RENORMALIZE:
        sum_scores = tl.sum(tl.where(offs_n < topk_num, sorted_scores, 0.0))
        renormalize_scores = sorted_scores / sum_scores

        tl.store(
            out_topk_weights + token_index * out_topk_weights_stride_m + offs_n,
            renormalize_scores,
            mask=offs_n < topk_num,
        )
        tl.store(out_topk_ids + token_index * out_topk_ids_stride_m + offs_n, sorted_indexes, mask=offs_n < topk_num)
    else:
        tl.store(
            out_topk_weights + token_index * out_topk_weights_stride_m + offs_n, sorted_scores, mask=offs_n < topk_num
        )
        tl.store(out_topk_ids + token_index * out_topk_ids_stride_m + offs_n, sorted_indexes, mask=offs_n < topk_num)
    return


def triton_grouped_topk(
    hidden_states: torch.Tensor,  # accepted only for API compatibility; not used here
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
):
    has_correction_bias = correction_bias is not None

    token_num, total_expert_num = gating_output.shape
    dtype = torch.float64 if gating_output.dtype == torch.float64 else torch.float32

    scores_buffer = torch.empty((token_num, total_expert_num), dtype=dtype, device="cuda")
    out_topk_weights = torch.empty((token_num, topk), dtype=torch.float32, device="cuda")
    out_topk_ids = torch.empty((token_num, topk), dtype=torch.int32, device="cuda")

    assert total_expert_num % num_expert_group == 0

    grouped_topk_kernel[(token_num,)](
        gating_output,
        *gating_output.stride(),
        correction_bias,
        scores_buffer,
        *scores_buffer.stride(),
        *scores_buffer.view(token_num, num_expert_group, -1).stride(),
        out_topk_weights,
        *out_topk_weights.stride(),
        out_topk_ids,
        *out_topk_ids.stride(),
        group_num=num_expert_group,
        group_expert_num=total_expert_num // num_expert_group,
        total_expert_num=total_expert_num,
        topk_num=topk,
        group_topk_num=topk_group,
        IS_SIGMOID=scoring_func == "sigmoid",
        HAS_CORRECTION_BIAS=has_correction_bias,
        EXPERT_BLOCK_SIZE=triton.next_power_of_2(total_expert_num),
        EXPERT_GROUP_NUM=triton.next_power_of_2(num_expert_group),
        EXPERT_GROUP_SIZE=triton.next_power_of_2(total_expert_num // num_expert_group),
        RENORMALIZE=renormalize,
        num_warps=1,
        num_stages=1,
    )
    return out_topk_weights, out_topk_ids
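
For reference, the per-token selection the kernel implements corresponds to the eager
PyTorch sketch below. This is illustrative only: the helper name is made up, and the
kernel's ">= threshold" group mask may keep extra groups on exact ties, unlike topk here.

import torch

def reference_grouped_topk(gating_output, correction_bias, topk, renormalize,
                           num_expert_group, topk_group, scoring_func="softmax"):
    if scoring_func == "sigmoid":
        scores = torch.sigmoid(gating_output)
    else:
        scores = torch.softmax(gating_output, dim=-1)
    if correction_bias is not None:
        scores = scores + correction_bias
    token_num = scores.shape[0]
    group_scores = scores.view(token_num, num_expert_group, -1)
    group_value = group_scores.amax(dim=-1)                     # best expert per group
    keep_groups = group_value.topk(topk_group, dim=-1).indices  # ids of the kept groups
    group_mask = torch.zeros_like(group_value, dtype=torch.bool)
    group_mask.scatter_(1, keep_groups, True)
    masked_scores = group_scores.masked_fill(~group_mask.unsqueeze(-1), float("-inf"))
    topk_weights, topk_ids = masked_scores.reshape(token_num, -1).topk(topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
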
def test_grouped_topk(expert_num, topk_group, group_num, topk_num, scoring_func, token_num):
    print("test", expert_num, topk_group, group_num, topk_num, scoring_func, token_num)
    dtype = torch.float32
    hidden_state = torch.randn((token_num, 1), dtype=dtype, device="cuda")
    gating_output = torch.randn((token_num, expert_num), dtype=dtype, device="cuda") * 10
    correction_bias = torch.randn((expert_num,), dtype=dtype, device="cuda")
    correction_bias[correction_bias <= 0.0] = 0.0

    old_topk_weights, old_topk_ids = grouped_topk(
        hidden_state,
        gating_output=gating_output,
        correction_bias=correction_bias,
        topk=topk_num,
        renormalize=True,
        num_expert_group=group_num,
        topk_group=topk_group,
        scoring_func=scoring_func,
    )

    new_topk_weights, new_topk_ids = triton_grouped_topk(
        None,
        gating_output=gating_output,
        correction_bias=correction_bias,
        topk=topk_num,
        renormalize=True,
        num_expert_group=group_num,
        topk_group=topk_group,
        scoring_func=scoring_func,
    )

    torch.cuda.synchronize()
    start = time.time()
    for _ in range(60):
        old_topk_weights, old_topk_ids = grouped_topk(
            hidden_state,
            gating_output=gating_output,
            correction_bias=correction_bias,
            topk=topk_num,
            renormalize=True,
            num_expert_group=group_num,
            topk_group=topk_group,
            scoring_func=scoring_func,
        )
    torch.cuda.synchronize()
    print(f"old cost time {time.time() - start} s")

    torch.cuda.synchronize()
    start = time.time()
    for _ in range(60):
        new_topk_weights, new_topk_ids = triton_grouped_topk(
            None,
            gating_output=gating_output,
            correction_bias=correction_bias,
            topk=topk_num,
            renormalize=True,
            num_expert_group=group_num,
            topk_group=topk_group,
            scoring_func=scoring_func,
        )
    torch.cuda.synchronize()
    print(f"new cost time {time.time() - start} s")

    assert torch.equal(torch.sort(old_topk_ids, dim=1)[0], torch.sort(new_topk_ids, dim=1)[0])
    assert torch.allclose(
        torch.sort(old_topk_weights, dim=1)[0], torch.sort(new_topk_weights, dim=1)[0], atol=1e-4, rtol=0
    )
    return


if __name__ == "__main__":
    pytest.main()

From a38c3fd87cab403c7f2f699efbfd6d9d0eb71d44 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Tue, 11 Feb 2025 13:08:48 +0800
Subject: [PATCH 2/2] use triton grouped_topk as default

---
 lightllm/common/fused_moe/topk_select.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/lightllm/common/fused_moe/topk_select.py b/lightllm/common/fused_moe/topk_select.py
index c1cc49732..cf5f245bf 100644
--- a/lightllm/common/fused_moe/topk_select.py
+++ b/lightllm/common/fused_moe/topk_select.py
@@ -102,13 +102,14 @@ def select_experts(
     scoring_func: str = "softmax",
     custom_routing_function: Optional[Callable] = None,
 ):
-    from lightllm.common.fused_moe.topk_select import fused_topk, grouped_topk
+    from lightllm.common.fused_moe.topk_select import fused_topk
+    from lightllm.common.fused_moe.grouped_topk import triton_grouped_topk

     # DeepSeek-V2 uses grouped_topk
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
-        topk_weights, topk_ids = grouped_topk(
+        topk_weights, topk_ids = triton_grouped_topk(
             hidden_states=hidden_states,
             gating_output=router_logits,
             correction_bias=correction_bias,
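
With PATCH 2/2 applied, the grouped routing path in select_experts dispatches to the
Triton kernel by default. A minimal call sketch mirroring the unit test above (the
shape values are illustrative; a CUDA device and an installed lightllm are assumed):

import torch
from lightllm.common.fused_moe.grouped_topk import triton_grouped_topk

token_num, expert_num = 8, 256  # illustrative sizes
gating_output = torch.randn((token_num, expert_num), dtype=torch.float32, device="cuda")
correction_bias = torch.rand((expert_num,), dtype=torch.float32, device="cuda")

topk_weights, topk_ids = triton_grouped_topk(
    None,  # hidden_states is not used by the Triton path
    gating_output=gating_output,
    correction_bias=correction_bias,
    topk=8,
    renormalize=True,
    num_expert_group=8,  # 256 experts split into 8 groups of 32
    topk_group=4,        # keep the best 4 groups per token
    scoring_func="sigmoid",
)
# topk_weights: [8, 8] float32; topk_ids: [8, 8] int32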