[Feature, Hardware] Enable DeepseekV3 on AMD GPUs #2601

Merged: 27 commits, Jan 3, 2025 (the diff below shows changes from 22 commits)

Commits:
45dfe9e  Add hip config (BruceXcluding, Dec 26, 2024)
d315402  Merge branch 'sgl-project:main' into main (BruceXcluding, Dec 27, 2024)
57a5006  Fix AMD moe_align and triton stage config (Dec 27, 2024)
3fa113b  fix fused_moe.py conflict (BruceXcluding, Dec 27, 2024)
f1c48e2  Merge branch 'main' into main (HaiShaw, Dec 27, 2024)
6fb6d7c  fix typo (BruceXcluding, Dec 27, 2024)
83a682a  Merge remote-tracking branch 'upstream/main' (BruceXcluding, Dec 27, 2024)
732c6b5  remove not_hip in fused_moe (BruceXcluding, Dec 27, 2024)
a645383  Merge branch 'sgl-project:main' into main (BruceXcluding, Dec 28, 2024)
8a62e6e  Add normalize_e4m3fnuz into block quant (Dec 28, 2024)
2b4afba  merged upstream and add amd block_shape moe config (Dec 28, 2024)
f0122b7  Fix shmem/LDS size constraint on AMD MI3xx (HaiShaw, Dec 29, 2024)
4379b5c  Lint (HaiShaw, Dec 29, 2024)
fe54618  Merge branch 'main' into main (HaiShaw, Dec 29, 2024)
0a3b5c1  fix MOE_PADDING=1 mismatch (Dec 29, 2024)
ba1597c  Merge branch 'sgl-project:main' into main (BruceXcluding, Dec 29, 2024)
1c48b3d  fix e4m3fnuz scaling max (Dec 30, 2024)
a021825  Merge branch 'sgl-project:main' into main (BruceXcluding, Dec 30, 2024)
7aad77e  refactor setup.py with rocm (Dec 30, 2024)
3dddac3  merge haishaw FP8 Numerical fix (Dec 30, 2024)
ca11e11  Merge branch 'sgl-project:main' into main (BruceXcluding, Dec 30, 2024)
abc497d  sperate sgl-kernel with amd backend (BruceXcluding, Dec 31, 2024)
4bb3332  Merge 'main' into 'main' (Jan 2, 2025)
b10c089  Clang format (BruceXcluding, Jan 2, 2025)
bf2ad5d  Merge branch 'main' into main (zhyncs, Jan 2, 2025)
3b63a5f  Merge branch 'main' into main (zhyncs, Jan 2, 2025)
7b8d375  Merge branch 'main' into main (HaiShaw, Jan 2, 2025)
python/pyproject.toml (2 changes: 1 addition & 1 deletion)

@@ -27,7 +27,7 @@ srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cu

 # HIP (Heterogeneous-computing Interface for Portability) for AMD
 # => base docker rocm/vllm-dev:20241022, not from public vllm whl
-srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.dev13"]
+srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post2.dev1+g1ef171e0.rocm624"]

 # xpu is not enabled in public vllm and torch whl,
 # need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm
 srt_xpu = ["sglang[runtime_common]"]

Review comments:

zhyncs (Member): What issues could occur if the image isn't updated? Minimize updating the base image whenever possible.

Collaborator: @zhyncs we (AMD) will have to decide on this, so ignore it for now.
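Since the srt_hip extra pins a ROCm-specific vllm build that only exists inside the rocm/vllm-dev base image, a quick way to confirm an environment actually matches the pin is a version check. This is a sketch, not part of the PR; the version string is taken from the diff above:

```python
# Sanity-check that the vllm wheel in the ROCm base image matches the pin.
import vllm

assert vllm.__version__.startswith("0.6.3.post2.dev1"), vllm.__version__
```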
@@ -406,6 +406,10 @@ def _decode_grouped_att_m_fwd(
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]

+    # [TODO] work around shmem limit on MI3xx
+    if is_hip_ and Lk >= 576:
+        BLOCK = 16
+
     if Lk == 576:
         BLOCK_DMODEL = 512
         BLOCK_DPE = 64

(HaiShaw marked this conversation as resolved.)
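For context on the BLOCK = 16 workaround: MI3xx GPUs expose 64 KB of LDS (shared memory) per workgroup, and with DeepseekV3's MLA head dims (Lk = 576, Lv = 512, visible in the hunk above) the K/V tiles staged per block no longer fit at the larger default block size. A back-of-envelope check, under assumed fp16 tile shapes rather than the kernel's actual allocation:

```python
# Rough LDS budget check for a grouped-decode attention tile on MI3xx.
# Assumptions (illustrative, not read from the kernel): one fp16 K tile of
# shape [BLOCK, Lk] and one fp16 V tile of shape [BLOCK, Lv] live in LDS.
LDS_BYTES = 64 * 1024  # per-workgroup LDS on MI3xx
FP16 = 2

def kv_tiles_fit(block: int, lk: int, lv: int) -> bool:
    return (block * lk + block * lv) * FP16 <= LDS_BYTES

print(kv_tiles_fit(32, 576, 512))  # False: 69,632 bytes overflows 65,536
print(kv_tiles_fit(16, 576, 512))  # True: 34,816 bytes fits with headroom
```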
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py (21 changes: 9 additions & 12 deletions)

@@ -11,18 +11,14 @@
 import torch
 import triton
 import triton.language as tl
+from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 from vllm import _custom_ops as ops

 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip

-not_hip = False
-if not is_hip():
-    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
-
-    not_hip = True
-
+is_hip_ = is_hip()
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
@@ -272,7 +268,7 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    if not_hip and num_experts >= 224:
+    if num_experts >= 224:
         token_cnts_buffer = torch.empty(
             (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
         )
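With the not_hip gate removed, the sgl_kernel moe_align_block_size path (which needs this token_cnts_buffer scratch space once num_experts >= 224, as with DeepseekV3's 256 routed experts) now runs on both CUDA and HIP. Conceptually, the kernel pads each expert's token count up to a multiple of the block size so every expert's segment can be tiled by fixed-size GEMM blocks; a pure-Python sketch of that layout math, for illustration only:

```python
from collections import Counter

def aligned_layout(topk_ids, num_experts, block_size):
    """Pad each expert's token count up to a multiple of block_size."""
    counts = Counter(topk_ids)
    padded = {e: -(-counts.get(e, 0) // block_size) * block_size  # ceil-multiple
              for e in range(num_experts)}
    return padded, sum(padded.values())

# 5 routed tokens over 4 experts, block_size=4: non-empty experts pad to 4.
print(aligned_layout([0, 0, 1, 3, 3], num_experts=4, block_size=4))
# ({0: 4, 1: 4, 2: 0, 3: 4}, 12)
```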
@@ -326,11 +322,12 @@ def invoke_fused_moe_kernel(

     padded_size = 0
     if use_fp8_w8a8:
-        padded_size = padding_size
         assert B_scale is not None
         if block_shape is None:
+            padded_size = padding_size
             A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         else:
+            padding_size = 0
             assert len(block_shape) == 2
             block_n, block_k = block_shape[0], block_shape[1]
             A, A_scale = per_token_group_quant_fp8(A, block_k)
@@ -463,7 +460,7 @@ def get_default_config(
             "BLOCK_SIZE_K": 128,
             "GROUP_SIZE_M": 32,
             "num_warps": 8,
-            "num_stages": 4,
+            "num_stages": 2 if is_hip_ else 4,
         }
         if M <= E:
             config = {
@@ -472,7 +469,7 @@
                 "BLOCK_SIZE_K": 128,
                 "GROUP_SIZE_M": 1,
                 "num_warps": 4,
-                "num_stages": 4,
+                "num_stages": 2 if is_hip_ else 4,
             }
     else:
         # Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1]
@@ -482,7 +479,7 @@
                 "BLOCK_SIZE_K": block_shape[1],
                 "GROUP_SIZE_M": 32,
                 "num_warps": 4,
-                "num_stages": 3,
+                "num_stages": 2 if is_hip_ else 3,
             }
     else:
         config = {
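The num_stages drop on HIP follows from the same LDS ceiling as the attention workaround: Triton's software pipelining keeps one A tile and one B tile resident per stage, so shared-memory use scales linearly with the stage count, and MI3xx's 64 KB LDS is much smaller than the shared memory on recent NVIDIA datacenter parts. A rough estimate, using assumed fp8 tile sizes rather than the compiler's real allocation:

```python
def pipeline_lds_bytes(block_m, block_n, block_k, num_stages, elem_bytes=1):
    # One A tile [block_m, block_k] plus one B tile [block_k, block_n] per stage;
    # elem_bytes=1 models the fp8 (w8a8) path.
    return num_stages * (block_m * block_k + block_k * block_n) * elem_bytes

# Assumed 64x128 (A) and 128x128 (B) tiles, matching BLOCK_SIZE_K=128 above:
print(pipeline_lds_bytes(64, 128, 128, num_stages=4))  # 98304 B: over the 64 KB LDS
print(pipeline_lds_bytes(64, 128, 128, num_stages=2))  # 49152 B: fits
```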
@@ -727,7 +724,7 @@ def fused_experts_impl(
     block_shape: Optional[List[int]] = None,
 ):
     padded_size = padding_size
-    if not use_fp8_w8a8:
+    if not use_fp8_w8a8 or block_shape is not None:
        padded_size = 0

     # Check constraints.
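This fused_experts_impl change mirrors the invoke_fused_moe_kernel fix above: the MOE_PADDING=1 weight padding (128 extra elements on the last dim) only makes sense for per-tensor FP8 quantization, since block-wise quantization ties each scale to a fixed [block_n, block_k] weight tile and a padded tail would have no matching scale. A condensed sketch of the corrected rule, not code from the PR:

```python
import os

# Same env-driven constant as the module above.
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0

def effective_padded_size(use_fp8_w8a8: bool, block_shape) -> int:
    # Padding applies only to per-tensor FP8 (block_shape is None).
    if not use_fp8_w8a8 or block_shape is not None:
        return 0
    return padding_size

print(effective_padded_size(True, None))        # padding_size (128 if MOE_PADDING=1)
print(effective_padded_size(True, [128, 128]))  # 0: block quant is never padded
```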
python/sglang/srt/server.py (5 changes: 3 additions & 2 deletions)

@@ -578,8 +578,9 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_NVLS_ENABLE"] = "0"
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
-    if "GLOO_SOCKET_IFNAME" not in os.environ:
-        os.environ["GLOO_SOCKET_IFNAME"] = "eth0"
+    # TODO(fix socket error with gpu backend)
+    # if "GLOO_SOCKET_IFNAME" not in os.environ:
+    #     os.environ["GLOO_SOCKET_IFNAME"] = "eth0"

     # Set prometheus env vars
     if server_args.enable_metrics:

Review comments:

zhyncs (Member): Why is this commented out?

Reply: Is this used for the cpu backend or a specific workstation? I get "RuntimeError: [enforce fail at pytorch/third_party/gloo/gloo/transport/tcp/device.cc:83] ifa != nullptr. Unable to find address for: eth0".

zhyncs (Member): This is used for multi-node tensor parallelism. Instead of using comments, we suggest adding an is_hip flag.

wufann: I think the value set for the GLOO_SOCKET_IFNAME environment variable should depend on the name of the network interface card in each user's system and should not be hard-coded as eth0.

zhyncs (Member): @wufann If the user's value is not eth0, they should specify it explicitly; this applies only when no setting is provided, with eth0 as the default.

Reply: @zhyncs A different network interface ("ens") may be used. Also, they may test in a single-node environment where IB is not configured; in that case IB should be disabled.
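One direction the thread points toward (a sketch of the reviewers' suggestion, not code from this PR): keep any user-provided GLOO_SOCKET_IFNAME, and otherwise fall back to the first non-loopback interface instead of a hard-coded eth0. The is_hip gating the maintainers suggested could wrap the same helper.

```python
import os
import socket

def default_gloo_ifname():
    """Respect an explicit GLOO_SOCKET_IFNAME; else pick a non-loopback NIC."""
    if "GLOO_SOCKET_IFNAME" in os.environ:
        return os.environ["GLOO_SOCKET_IFNAME"]
    # socket.if_nameindex() yields (index, name) for every interface (Unix).
    for _, name in socket.if_nameindex():
        if name != "lo":
            return name  # e.g. "eth0" or "ens5", whichever exists on the host
    return None
```

Picking the first non-loopback interface is still a heuristic; multi-NIC hosts running multi-node tensor parallelism would still want to set the variable explicitly.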
sgl-kernel/amd/CMakeLists.txt (51 additions, new file)

@@ -0,0 +1,51 @@
cmake_minimum_required(VERSION 3.18)

[Review comment from zhyncs (Member), Dec 31, 2024: Please remove this, we only use CMakeLists.txt for clangd indexing, so it's not necessary.]

project(sgl-kernel LANGUAGES CXX CUDA)
project(sgl-kernel LANGUAGES CXX CUDA)

# Basic settings
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)

find_package(Python3 COMPONENTS Interpreter Development REQUIRED)

# Find PyTorch
execute_process(
    COMMAND ${Python3_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)"
    OUTPUT_VARIABLE TORCH_CMAKE_PATH
    OUTPUT_STRIP_TRAILING_WHITESPACE
)
list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PATH}")

find_package(Torch REQUIRED)

# Kernel library (moe_align kernel + ops registration)
add_library(_kernels SHARED
    src/sgl-kernel/csrc/moe_align_kernel.cu
    src/sgl-kernel/csrc/sgl_kernel_ops.cu
)

target_include_directories(_kernels
    PRIVATE
        ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
        ${CUDA_INCLUDE_DIRS}
        ${TORCH_INCLUDE_DIRS}
)

target_link_libraries(_kernels
    PRIVATE
        ${TORCH_LIBRARIES}
        Python3::Python
)

# Set common properties for the kernel library
foreach(target _kernels)
    set_target_properties(${target} PROPERTIES
        CUDA_SEPARABLE_COMPILATION ON
        POSITION_INDEPENDENT_CODE ON
        CUDA_RESOLVE_DEVICE_SYMBOLS ON
        PREFIX ""
        SUFFIX ".so"
    )
endforeach()
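The execute_process step above wires find_package(Torch) to the active Python environment; the one-liner it runs is simply:

```python
import torch

# Printed path gets appended to CMAKE_PREFIX_PATH so find_package(Torch) resolves.
print(torch.utils.cmake_prefix_path)
```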
sgl-kernel/amd/LICENSE (201 additions, new file)

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2023-2024 SGLang Team

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
sgl-kernel/amd/Makefile (22 additions, new file)

@@ -0,0 +1,22 @@
.PHONY: tree ln install build clean test format

tree:
	@tree --prune -I "__pycache__|*.egg-info|*.so|build"

ln:
	@rm -rf build && cmake . -DCMAKE_EXPORT_COMPILE_COMMANDS=1 -DCMAKE_CUDA_COMPILER=nvcc -B build && rm -rf compile_commands.json && ln -s build/compile_commands.json compile_commands.json

install:
	@pip install -e .

build:
	@export MAX_JOBS=$$(nproc) && python3 setup.py bdist_wheel

clean:
	@rm -rf build dist *.egg-info

test:
	@pytest tests/

format:
	@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black