Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature, Hardware] Enable DeepseekV3 on AMD GPUs #2601

Merged
merged 27 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
45dfe9e
Add hip config
BruceXcluding Dec 26, 2024
d315402
Merge branch 'sgl-project:main' into main
BruceXcluding Dec 27, 2024
57a5006
Fix AMD moe_align and triton stage config
Dec 27, 2024
3fa113b
fix fused_moe.py conflict
BruceXcluding Dec 27, 2024
f1c48e2
Merge branch 'main' into main
HaiShaw Dec 27, 2024
6fb6d7c
fix typo
BruceXcluding Dec 27, 2024
83a682a
Merge remote-tracking branch 'upstream/main'
BruceXcluding Dec 27, 2024
732c6b5
remove not_hip in fused_moe
BruceXcluding Dec 27, 2024
a645383
Merge branch 'sgl-project:main' into main
BruceXcluding Dec 28, 2024
8a62e6e
Add normalize_e4m3fnuz into block quant
Dec 28, 2024
2b4afba
merged upstream and add amd block_shape moe config
Dec 28, 2024
f0122b7
Fix shmem/LDS size constraint on AMD MI3xx
HaiShaw Dec 29, 2024
4379b5c
Lint
HaiShaw Dec 29, 2024
fe54618
Merge branch 'main' into main
HaiShaw Dec 29, 2024
0a3b5c1
fix MOE_PADDING=1 mismatch
Dec 29, 2024
ba1597c
Merge branch 'sgl-project:main' into main
BruceXcluding Dec 29, 2024
1c48b3d
fix e4m3fnuz scaling max
Dec 30, 2024
a021825
Merge branch 'sgl-project:main' into main
BruceXcluding Dec 30, 2024
7aad77e
refactor setup.py with rocm
Dec 30, 2024
3dddac3
merge haishaw FP8 Numerical fix
Dec 30, 2024
ca11e11
Merge branch 'sgl-project:main' into main
BruceXcluding Dec 30, 2024
abc497d
sperate sgl-kernel with amd backend
BruceXcluding Dec 31, 2024
4bb3332
Merge 'main' into 'main'
Jan 2, 2025
b10c089
Clang format
BruceXcluding Jan 2, 2025
bf2ad5d
Merge branch 'main' into main
zhyncs Jan 2, 2025
3b63a5f
Merge branch 'main' into main
zhyncs Jan 2, 2025
7b8d375
Merge branch 'main' into main
HaiShaw Jan 2, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker/Dockerfile.rocm
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# docker build --build-arg SGL_BRANCH=v0.4.1 -t v0.4.1-rocm620 -f Dockerfile.rocm .

# default base image
ARG BASE_IMAGE="rocm/vllm-dev:20241022"
ARG BASE_IMAGE="rocm/vllm-dev:20241226"
BruceXcluding marked this conversation as resolved.
Show resolved Hide resolved

FROM $BASE_IMAGE AS base
USER root
Expand Down
4 changes: 2 additions & 2 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = ["requests", "tqdm", "numpy", "IPython", "setproctitle"]
[project.optional-dependencies]
runtime_common = ["aiohttp", "decord", "fastapi",
"hf_transfer", "huggingface_hub", "interegular", "modelscope",
"orjson", "outlines>=0.0.44,<0.1.0",
"orjson", "outlines>=0.1.7", "outlines-core>=0.1.17",
"packaging", "pillow", "prometheus-client>=0.20.0",
"psutil", "pydantic", "python-multipart",
"pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
Expand All @@ -27,7 +27,7 @@ srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cu

# HIP (Heterogeneous-computing Interface for Portability) for AMD
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.dev13"]
srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.5.dev414+ga2646938.rocm634"]
# xpu is not enabled in public vllm and torch whl,
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.htmlinstall vllm
srt_xpu = ["sglang[runtime_common]"]
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/constrained/outlines_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import interegular
import torch
from outlines.fsm.guide import RegexGuide
from outlines.fsm.json_schema import build_regex_from_schema
from outlines_core.fsm.json_schema import build_regex_from_schema
from outlines.models.transformers import TransformerTokenizer
from pydantic import BaseModel

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def _decode_grouped_att_m_fwd(
sm_scale,
logit_cap,
):
BLOCK = 32
BLOCK = 16 if is_hip() else 32
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should not cut by half for HIP globally here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it doesn't work well in latest vllm with BLOCK 32

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part we can not take as it is - it will cost performance of all other models in large margin.

Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]

Expand Down
8 changes: 5 additions & 3 deletions python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.utils import direct_register_custom_op, get_device_name
from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip

is_hip_ = is_hip()

logger = logging.getLogger(__name__)
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
Expand Down Expand Up @@ -437,7 +439,7 @@ def get_default_config(
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4,
"num_stages": 2 if is_hip_ else 4,
}
if M <= E:
config = {
Expand All @@ -446,7 +448,7 @@ def get_default_config(
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
"num_stages": 2 if is_hip_ 4,
}
else:
config = {
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def create_weights(

# WEIGHT
weight_dtype = (
torch.float8_e4m3fn
torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not have this, serialized weight is always OCP (torch.float8_e4m3fn)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would encounter the error "python/sglang/srt/layers/quantization/fp8_kernel.py:176:33: error: Unsupported conversion from 'f8E4M3FN' to 'f16'
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]" with torch.float8_e4m3fn at w8a8_block_fp8_matmul

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check how normalize_e4m3fn_to_e4m3fnuz is used.
Basically - we do not expected non-OCP/e4m3fn dtype in the quantized model.

if self.quant_config.is_checkpoint_fp8_serialized
else params_dtype
)
Expand Down Expand Up @@ -432,7 +432,7 @@ def create_weights(
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
params_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same problem here - check out the previous usage from normalize_e4m3fn_to_e4m3fnuz

tp_size = get_tensor_model_parallel_world_size()
if self.block_quant:
block_n, block_k = (
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/quantization/fp8_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
import triton
import triton.language as tl

from sglang.srt.utils import is_hip

@triton.jit
def _per_token_group_quant_fp8(
Expand Down Expand Up @@ -65,7 +65,7 @@ def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = torch.float8_e4m3fn,
dtype: torch.dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try not to use dynamic binding to assign default - we can go without default to param.

) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.

Expand Down
46 changes: 26 additions & 20 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import is_flashinfer_available
from sglang.srt.utils import is_flashinfer_available, is_hip

if is_flashinfer_available():
from flashinfer import bmm_fp8
Expand Down Expand Up @@ -573,13 +573,16 @@ def forward_absorb(
)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

if self.w_kc.dtype == torch.float8_e4m3fn:
q_nope_val, q_nope_scale = input_to_float8(
q_nope.transpose(0, 1), torch.float8_e4m3fn
)
q_nope_out = bmm_fp8(
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
)
if self.w_kc.dtype == torch.float8_e4m3fn or torch.float8_e4m3fnuz:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

simplify this as orig.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

elif orig would conflict with v3 hf_config.architect

if is_hip():
q_nope_out = torch.bmm(q_nope.to(torch.bfloat16).transpose(0, 1), self.w_kc.to(torch.bfloat16))
else:
q_nope_val, q_nope_scale = input_to_float8(
q_nope.transpose(0, 1), torch.float8_e4m3fn
)
q_nope_out = bmm_fp8(
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change else below to elif self.w_kc.dtype == torch.float8_e4m3fnuz

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw, please add a TODO (lack of bmm for torch.float8_e4m3fnuz) here.

else:
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
Expand All @@ -598,17 +601,20 @@ def forward_absorb(
attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

if self.w_vc.dtype == torch.float8_e4m3fn:
attn_output_val, attn_output_scale = input_to_float8(
attn_output.transpose(0, 1), torch.float8_e4m3fn
)
attn_bmm_output = bmm_fp8(
attn_output_val,
self.w_vc,
attn_output_scale,
self.w_scale,
torch.bfloat16,
)
if self.w_vc.dtype == torch.float8_e4m3fn or torch.float8_e4m3fnuz:
if is_hip():
attn_bmm_output = torch.bmm(attn_output.to(torch.bfloat16).transpose(0, 1), self.w_vc.to(torch.bfloat16))
else:
attn_output_val, attn_output_scale = input_to_float8(
attn_output.transpose(0, 1), torch.float8_e4m3fn
)
attn_bmm_output = bmm_fp8(
attn_output_val,
self.w_vc,
attn_output_scale,
self.w_scale,
torch.bfloat16,
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment as previous

else:
attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
Expand Down Expand Up @@ -942,7 +948,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# This may affect the accuracy of fp8 model.
if (
hasattr(self.quant_config, "weight_block_size")
and w.dtype == torch.float8_e4m3fn
and w.dtype == torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use or instead of if... else...

):
weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None:
Expand Down
184 changes: 112 additions & 72 deletions sgl-kernel/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import shutil
import zipfile
from pathlib import Path
import torch
def is_hip() -> bool:
"""Return whether it is HIP on the AMD ROCm platform."""
return torch.version.hip is not None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return torch.version.hip is not None
return torch.cuda.is_available() and torch.version.hip

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
Expand Down Expand Up @@ -58,80 +62,116 @@ def update_wheel_platform_tag():
old_wheel.rename(new_wheel)


setup(
name="sgl-kernel",
version=get_version(),
packages=["sgl_kernel"],
package_dir={"": "src"},
ext_modules=[
CUDAExtension(
"sgl_kernel.ops.warp_reduce_cuda",
[
"src/sgl-kernel/csrc/warp_reduce.cc",
"src/sgl-kernel/csrc/warp_reduce_kernel.cu",
],
extra_compile_args={
"nvcc": [
"-O3",
"-Xcompiler",
"-fPIC",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",

if not is_hip():
setup(
name="sgl-kernel",
version=get_version(),
packages=["sgl_kernel"],
package_dir={"": "src"},
ext_modules=[
CUDAExtension(
"sgl_kernel.ops.warp_reduce_cuda",
[
"src/sgl-kernel/csrc/warp_reduce.cc",
"src/sgl-kernel/csrc/warp_reduce_kernel.cu",
],
extra_compile_args={
"nvcc": [
"-O3",
"-Xcompiler",
"-fPIC",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
],
"cxx": ["-O3"],
},
libraries=["c10", "torch", "torch_python"],
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
),
CUDAExtension(
"sgl_kernel.ops.custom_reduce_cuda",
[
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
"src/sgl-kernel/csrc/trt_reduce.cc",
],
"cxx": ["-O3"],
},
libraries=["c10", "torch", "torch_python"],
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
),
CUDAExtension(
"sgl_kernel.ops.custom_reduce_cuda",
[
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
"src/sgl-kernel/csrc/trt_reduce.cc",
],
extra_compile_args={
"nvcc": [
"-O3",
"-Xcompiler",
"-fPIC",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
extra_compile_args={
"nvcc": [
"-O3",
"-Xcompiler",
"-fPIC",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
],
"cxx": ["-O3"],
},
libraries=["c10", "torch", "torch_python"],
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
),
CUDAExtension(
"sgl_kernel.ops.moe_align_block_size",
[
"src/sgl-kernel/csrc/moe_align_kernel.cu",
],
"cxx": ["-O3"],
},
libraries=["c10", "torch", "torch_python"],
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
),
CUDAExtension(
"sgl_kernel.ops.moe_align_block_size",
[
"src/sgl-kernel/csrc/moe_align_kernel.cu",
],
extra_compile_args={
"nvcc": [
"-O3",
"-Xcompiler",
"-fPIC",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
extra_compile_args={
"nvcc": [
"-O3",
"-Xcompiler",
"-fPIC",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
],
"cxx": ["-O3"],
},
libraries=["c10", "torch", "torch_python"],
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
),
],
cmdclass={"build_ext": BuildExtension},
install_requires=["torch"],
)
else:
hipcc_flags = [
"-D__HIP_PLATFORM_AMD__=1",
"--amdgpu-target=gfx942",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

broad this target list

"-DENABLE_BF16", # Enable BF16 for cuda_version >= 11
"-DENABLE_FP8", # Enable FP8 for cuda_version >= 11.8
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the comments above 2 lines make no sense.

]
setup(
name="sgl-kernel",
version=get_version(),
packages=["sgl_kernel"],
package_dir={"": "src"},
ext_modules=[
CUDAExtension(
"sgl_kernel.ops.moe_align_block_size",
[
"src/sgl-kernel/csrc/moe_align_kernel.cu",
],
"cxx": ["-O3"],
},
libraries=["c10", "torch", "torch_python"],
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
),
],
cmdclass={"build_ext": BuildExtension},
install_requires=["torch"],
)
extra_compile_args={
"nvcc": hipcc_flags + [
"-O3",
"-Xcompiler",
"-fPIC",
],
"cxx": ["-O3"],
},
libraries=["hiprtc", "amdhip64", "c10", "torch", "torch_python"],
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
),
],
cmdclass={"build_ext": BuildExtension},
install_requires=["torch"],
)

update_wheel_platform_tag()