From b10c08930fbac5db12219947e21c509ae822713c Mon Sep 17 00:00:00 2001
From: yigex
Date: Thu, 2 Jan 2025 15:56:53 +0000
Subject: [PATCH] Clang format

---
 .../srt/layers/moe/fused_moe_triton/fused_moe.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index e301167218..1c8700783a 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -11,15 +11,19 @@
 import torch
 import triton
 import triton.language as tl
-from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 from vllm import _custom_ops as ops
 
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
-from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 
-is_hip_flag = True if is_hip() else False
+is_hip_flag = False
+if not is_hip():
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
+    is_hip_flag = False
+else:
+    is_hip_flag = True
 
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
@@ -408,7 +412,7 @@ def moe_align_block_size(
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
     if num_experts >= 224:
-        if enable_moe_align_block_size_triton:
+        if enable_moe_align_block_size_triton or is_hip_flag:
            moe_align_block_size_triton(
                topk_ids,
                num_experts,
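
Note (not part of the patch): a minimal sketch of the platform-conditional import pattern this change applies. The sgl_kernel moe_align_block_size op is imported only on CUDA builds; on ROCm (HIP) builds the import is skipped and is_hip_flag later forces the Triton fallback. Identifiers mirror the patched file; the sketch assumes an sglang environment.

    # Sketch of the pattern introduced by this patch (assumes sglang is installed).
    from sglang.srt.utils import is_hip

    is_hip_flag = False
    if not is_hip():
        # CUDA build: the optimized sgl_kernel op is available and imported.
        from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
    else:
        # ROCm (HIP) build: sgl_kernel is unavailable, so mark the flag
        # and rely on the Triton implementation instead.
        is_hip_flag = True

    # Inside moe_align_block_size(), the flag routes HIP builds to Triton:
    #     if enable_moe_align_block_size_triton or is_hip_flag:
    #         moe_align_block_size_triton(...)
    #     else:
    #         sgl_moe_align_block_size(...)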