Skip to content

Commit

Permalink
[Feature, Hardware] Enable DeepseekV3 on AMD GPUs (sgl-project#2601)
Browse files Browse the repository at this point in the history
Co-authored-by: root <[email protected]>
Co-authored-by: HAI <[email protected]>
Co-authored-by: Bruce Xue <[email protected]>
Co-authored-by: Yineng Zhang <[email protected]>
  • Loading branch information
5 people authored and XiaotongJiang committed Jan 3, 2025
1 parent 783947b commit ab9b4c0
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,10 @@ def _decode_grouped_att_m_fwd(
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]

# [TODO] work around shmem limit on MI3xx
if is_hip_ and Lk >= 576:
BLOCK = 16

if Lk == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
Expand Down
10 changes: 5 additions & 5 deletions python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,9 +477,9 @@ def invoke_fused_moe_kernel(

padded_size = 0
if use_fp8_w8a8:
padded_size = padding_size
assert B_scale is not None
if block_shape is None:
padded_size = padding_size
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
else:
assert len(block_shape) == 2
Expand Down Expand Up @@ -614,7 +614,7 @@ def get_default_config(
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4,
"num_stages": 2 if is_hip_flag else 4,
}
if M <= E:
config = {
Expand All @@ -623,7 +623,7 @@ def get_default_config(
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
"num_stages": 2 if is_hip_flag else 4,
}
else:
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
Expand All @@ -633,7 +633,7 @@ def get_default_config(
"BLOCK_SIZE_K": block_shape[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
"num_stages": 2 if is_hip_flag else 3,
}
else:
config = {
Expand Down Expand Up @@ -878,7 +878,7 @@ def fused_experts_impl(
block_shape: Optional[List[int]] = None,
):
padded_size = padding_size
if not use_fp8_w8a8:
if not use_fp8_w8a8 or block_shape is not None:
padded_size = 0

# Check constraints.
Expand Down

0 comments on commit ab9b4c0

Please sign in to comment.