diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
index 469ab5ed24..2b4871af98 100644
--- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -406,6 +406,10 @@ def _decode_grouped_att_m_fwd(
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
 
+    # [TODO] work around shmem limit on MI3xx
+    if is_hip_ and Lk >= 576:
+        BLOCK = 16
+
     if Lk == 576:
         BLOCK_DMODEL = 512
         BLOCK_DPE = 64
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index 128881e8aa..1c8700783a 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -477,9 +477,9 @@ def invoke_fused_moe_kernel(
 
     padded_size = 0
     if use_fp8_w8a8:
-        padded_size = padding_size
         assert B_scale is not None
         if block_shape is None:
+            padded_size = padding_size
             A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         else:
             assert len(block_shape) == 2
@@ -614,7 +614,7 @@ def get_default_config(
                 "BLOCK_SIZE_K": 128,
                 "GROUP_SIZE_M": 32,
                 "num_warps": 8,
-                "num_stages": 4,
+                "num_stages": 2 if is_hip_flag else 4,
             }
             if M <= E:
                 config = {
@@ -623,7 +623,7 @@
                     "BLOCK_SIZE_K": 128,
                     "GROUP_SIZE_M": 1,
                     "num_warps": 4,
-                    "num_stages": 4,
+                    "num_stages": 2 if is_hip_flag else 4,
                 }
         else:
             # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
@@ -633,7 +633,7 @@
                 "BLOCK_SIZE_K": block_shape[1],
                 "GROUP_SIZE_M": 32,
                 "num_warps": 4,
-                "num_stages": 3,
+                "num_stages": 2 if is_hip_flag else 3,
             }
     else:
         config = {
@@ -878,7 +878,7 @@ def fused_experts_impl(
     block_shape: Optional[List[int]] = None,
 ):
     padded_size = padding_size
-    if not use_fp8_w8a8:
+    if not use_fp8_w8a8 or block_shape is not None:
         padded_size = 0
 
     # Check constraints.
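
The two fused_moe.py padding hunks (in invoke_fused_moe_kernel and fused_experts_impl) converge on one rule: the weight-dimension padding is applied only on the fp8 w8a8 path without block-wise quantization. A minimal sketch of that rule, using a hypothetical helper (effective_padded_size is not part of the patch; padding_size, use_fp8_w8a8, and block_shape are the names used in the patch):

    from typing import List, Optional


    def effective_padded_size(
        padding_size: int, use_fp8_w8a8: bool, block_shape: Optional[List[int]]
    ) -> int:
        # Block-wise quantization (block_shape given) and non-fp8 paths skip padding,
        # mirroring the condition added in fused_experts_impl.
        if not use_fp8_w8a8 or block_shape is not None:
            return 0
        return padding_size


    assert effective_padded_size(128, True, None) == 128      # fp8 per-tensor: padded
    assert effective_padded_size(128, True, [128, 128]) == 0  # fp8 block-wise: no padding
    assert effective_padded_size(128, False, None) == 0       # not fp8: no padding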