
[Feature, Hardware] Enable DeepseekV3 on AMD GPUs #2601

Merged: 27 commits merged on Jan 3, 2025
Changes from 21 commits

Commits (27):
45dfe9e  Add hip config (BruceXcluding, Dec 26, 2024)
d315402  Merge branch 'sgl-project:main' into main (BruceXcluding, Dec 27, 2024)
57a5006  Fix AMD moe_align and triton stage config (Dec 27, 2024)
3fa113b  fix fused_moe.py conflict (BruceXcluding, Dec 27, 2024)
f1c48e2  Merge branch 'main' into main (HaiShaw, Dec 27, 2024)
6fb6d7c  fix typo (BruceXcluding, Dec 27, 2024)
83a682a  Merge remote-tracking branch 'upstream/main' (BruceXcluding, Dec 27, 2024)
732c6b5  remove not_hip in fused_moe (BruceXcluding, Dec 27, 2024)
a645383  Merge branch 'sgl-project:main' into main (BruceXcluding, Dec 28, 2024)
8a62e6e  Add normalize_e4m3fnuz into block quant (Dec 28, 2024)
2b4afba  merged upstream and add amd block_shape moe config (Dec 28, 2024)
f0122b7  Fix shmem/LDS size constraint on AMD MI3xx (HaiShaw, Dec 29, 2024)
4379b5c  Lint (HaiShaw, Dec 29, 2024)
fe54618  Merge branch 'main' into main (HaiShaw, Dec 29, 2024)
0a3b5c1  fix MOE_PADDING=1 mismatch (Dec 29, 2024)
ba1597c  Merge branch 'sgl-project:main' into main (BruceXcluding, Dec 29, 2024)
1c48b3d  fix e4m3fnuz scaling max (Dec 30, 2024)
a021825  Merge branch 'sgl-project:main' into main (BruceXcluding, Dec 30, 2024)
7aad77e  refactor setup.py with rocm (Dec 30, 2024)
3dddac3  merge haishaw FP8 Numerical fix (Dec 30, 2024)
ca11e11  Merge branch 'sgl-project:main' into main (BruceXcluding, Dec 30, 2024)
abc497d  sperate sgl-kernel with amd backend (BruceXcluding, Dec 31, 2024)
4bb3332  Merge 'main' into 'main' (Jan 2, 2025)
b10c089  Clang format (BruceXcluding, Jan 2, 2025)
bf2ad5d  Merge branch 'main' into main (zhyncs, Jan 2, 2025)
3b63a5f  Merge branch 'main' into main (zhyncs, Jan 2, 2025)
7b8d375  Merge branch 'main' into main (HaiShaw, Jan 2, 2025)
2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -27,7 +27,7 @@ srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cu

# HIP (Heterogeneous-computing Interface for Portability) for AMD
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.dev13"]
srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post2.dev1+g1ef171e0.rocm624"]
Member: What issues could occur if the image isn't updated? Minimize updating the base image whenever possible.

Collaborator: @zhyncs we (AMD) will have to decide on this, so ignore it for now.

# xpu is not enabled in public vllm and torch whl,
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html install vllm
srt_xpu = ["sglang[runtime_common]"]
@@ -406,6 +406,10 @@ def _decode_grouped_att_m_fwd(
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]

# [TODO] work around shmem limit on MI3xx
if is_hip_ and Lk >= 576:
BLOCK = 16

if Lk == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
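For context on this workaround: the grouped decode kernel stages a K tile of roughly BLOCK x (BLOCK_DMODEL + BLOCK_DPE) elements in shared memory, and with DeepseekV3's 576-dim MLA heads (512 + 64, as in the hunk above) the usual block size no longer fits an MI3xx workgroup's LDS. A back-of-the-envelope sketch in Python; the 64 KB LDS budget, the fp16 element size, and the default BLOCK of 64 are assumptions, not figures taken from this diff:

# Rough LDS estimate for one program of the grouped decode kernel.
# Assumed model: one fp16 K tile of shape (BLOCK, BLOCK_DMODEL + BLOCK_DPE)
# must fit in a 64 KB per-workgroup LDS budget on MI3xx.
def k_tile_bytes(block: int, d_model: int = 512, d_pe: int = 64, elem_bytes: int = 2) -> int:
    return block * (d_model + d_pe) * elem_bytes

print(k_tile_bytes(64))  # 73728 bytes: over a 64 KB (65536-byte) budget
print(k_tile_bytes(16))  # 18432 bytes: fits comfortably

This is why the hunk caps BLOCK at 16 only when is_hip_ is set and Lk >= 576.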
21 changes: 9 additions & 12 deletions python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -11,18 +11,14 @@
import torch
import triton
import triton.language as tl
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
from vllm import _custom_ops as ops

from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip

not_hip = False
if not is_hip():
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size

not_hip = True

is_hip_ = is_hip()
logger = logging.getLogger(__name__)
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0

@@ -272,7 +268,7 @@ def moe_align_block_size(
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
)
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
if not_hip and num_experts >= 224:
if num_experts >= 224:
token_cnts_buffer = torch.empty(
(num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
)
@@ -326,11 +322,12 @@ def invoke_fused_moe_kernel(

padded_size = 0
if use_fp8_w8a8:
padded_size = padding_size
assert B_scale is not None
if block_shape is None:
padded_size = padding_size
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
else:
padding_size = 0
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
@@ -463,7 +460,7 @@ def get_default_config(
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4,
"num_stages": 2 if is_hip_ else 4,
}
if M <= E:
config = {
@@ -472,7 +469,7 @@
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
"num_stages": 2 if is_hip_ else 4,
}
else:
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
@@ -482,7 +479,7 @@
"BLOCK_SIZE_K": block_shape[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
"num_stages": 2 if is_hip_ else 3,
}
else:
config = {
@@ -727,7 +724,7 @@ def fused_experts_impl(
block_shape: Optional[List[int]] = None,
):
padded_size = padding_size
if not use_fp8_w8a8:
if not use_fp8_w8a8 or block_shape is not None:
padded_size = 0

# Check constraints.
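Taken together, these hunks make the Triton fused-MoE path HIP-aware in three places: moe_align_block_size from sgl_kernel is imported unconditionally, the default kernel configs drop to two pipeline stages on ROCm, and the MOE_PADDING trick is skipped whenever block-wise FP8 quantization is active. A condensed sketch that paraphrases the diff (surrounding code elided; not a drop-in replacement for the file):

import os

import torch

# Mirrors the module-level flags in the diff; sglang.srt.utils.is_hip()
# boils down to a ROCm-build check like this one.
is_hip_ = torch.version.hip is not None
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0


def default_num_stages(block_quant: bool) -> int:
    # ROCm defaults to 2 pipeline stages; CUDA keeps 4 (or 3 for block quant).
    if is_hip_:
        return 2
    return 3 if block_quant else 4


def effective_padding(use_fp8_w8a8: bool, block_shape) -> int:
    # Weight padding (MOE_PADDING) only applies on the per-tensor FP8 path;
    # block-quantized weights (block_shape is not None) skip it.
    if not use_fp8_w8a8 or block_shape is not None:
        return 0
    return padding_size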
98 changes: 66 additions & 32 deletions sgl-kernel/setup.py
@@ -3,6 +3,14 @@
import zipfile
from pathlib import Path

import torch


def is_hip() -> bool:
"""Return whether it is HIP on the AMD ROCm platform."""
return torch.version.hip is not None
Member (suggested change):
-    return torch.version.hip is not None
+    return torch.cuda.is_available() and torch.version.hip
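For reference, the two checks answer slightly different questions: torch.version.hip is not None only says the installed PyTorch wheel was built against ROCm, while the suggested form additionally requires a visible device at runtime. A minimal illustration; this describes standard PyTorch behavior, not anything asserted by the PR itself:

import torch

# True on any ROCm build of PyTorch, even on a machine with no GPU attached.
rocm_build = torch.version.hip is not None

# Also requires that a device is actually visible at runtime
# (on ROCm, torch.cuda.is_available() reports HIP devices).
rocm_usable = torch.cuda.is_available() and rocm_build

print(f"built for ROCm: {rocm_build}, ROCm device usable: {rocm_usable}")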


from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

@@ -58,38 +66,64 @@ def update_wheel_platform_tag():
old_wheel.rename(new_wheel)


nvcc_flags = [
"-O3",
"-Xcompiler",
"-fPIC",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
]
cxx_flags = ["-O3"]
libraries = ["c10", "torch", "torch_python"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
ext_modules = [
CUDAExtension(
name="sgl_kernel.ops._kernels",
sources=[
"src/sgl-kernel/csrc/warp_reduce_kernel.cu",
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
"src/sgl-kernel/csrc/moe_align_kernel.cu",
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
],
extra_compile_args={
"nvcc": nvcc_flags,
"cxx": cxx_flags,
},
libraries=libraries,
extra_link_args=extra_link_args,
),
]
if not is_hip():
nvcc_flags = [
"-O3",
"-Xcompiler",
"-fPIC",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
]
cxx_flags = ["-O3"]
libraries = ["c10", "torch", "torch_python"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
ext_modules = [
CUDAExtension(
name="sgl_kernel.ops._kernels",
sources=[
"src/sgl-kernel/csrc/warp_reduce_kernel.cu",
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
"src/sgl-kernel/csrc/moe_align_kernel.cu",
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
],
extra_compile_args={
"nvcc": nvcc_flags,
"cxx": cxx_flags,
},
libraries=libraries,
extra_link_args=extra_link_args,
),
]
else:
hipcc_flags = [
"-D__HIP_PLATFORM_AMD__=1",
"--amdgpu-target=gfx90a,gfx940,gfx941,gfx942",
]
ext_modules=[
CUDAExtension(
"sgl_kernel.ops.moe_align_block_size",
[
"src/sgl-kernel/csrc/moe_align_kernel.cu",
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
Member: If you need to compile for AMD, I recommend not compiling sgl_kernel_ops.cu directly. Use a separate file instead, to avoid mixing NVIDIA and AMD .cu files; it's better to keep them separate. cc @HaiShaw @ispobock @merrymercy


Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have any suggestions? @yzh119

It seems the reduce kernel also needs to be compiled here; otherwise some archs fail to import with "No module named 'sgl_kernel.ops._kernels'".

Member: May we use is_hip there?

Collaborator: @zhyncs For CUDA/HIP-compatible kernel files we don't use separate files (that is the point of HIP), and I believe this is one of those cases. We do, for sure, keep separate files for AMD-specific kernels or kernel implementations.

Collaborator: @zyeric the else: case seemingly has no impact on the NV side; can you be more specific?

Maybe it's better to separate the AMD/NV kernels into two different backends? At the moment moe_align_kernel is the only kernel the AMD backend requires, while in the near future CK kernels will be added to it.

@HaiShaw I think the root cause is that the import path is still sgl_kernel.ops._kernels at https://github.com/BruceXcluding/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py#L1
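To illustrate the point under discussion, one way to reconcile the two extension names is to choose the import at package-import time. This is only a sketch of that idea, not the file in the PR; the extension module names come from this setup.py, and the assumption that each extension exposes a moe_align_block_size symbol is mine:

# Hypothetical dispatch for sgl_kernel/ops/__init__.py (illustrative only;
# the exported symbol names are assumed, not confirmed by this diff).
import torch

if torch.version.hip is not None:
    # The HIP build in this PR only ships the moe_align extension.
    from sgl_kernel.ops.moe_align_block_size import moe_align_block_size
else:
    # The CUDA build keeps everything in the _kernels extension.
    from sgl_kernel.ops._kernels import moe_align_block_size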

zyeric (Dec 31, 2024): Current version works for me, many thanks :D

Accuracy: 0.951
Invalid: 0.000
Latency: 160.916 s
Output throughput: 869.145 token/s

],
extra_compile_args={
"nvcc": hipcc_flags
+ [
"-O3",
"-Xcompiler",
"-fPIC",
],
"cxx": ["-O3"],
},
libraries=["hiprtc", "amdhip64", "c10", "torch", "torch_python"],
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
),
]

setup(
name="sgl-kernel",