[Feature, Hardware] Enable DeepseekV3 on AMD GPUs #2601
Changes from 21 commits
@@ -3,6 +3,14 @@
 import zipfile
 from pathlib import Path

+import torch
+
+
+def is_hip() -> bool:
+    """Return whether it is HIP on the AMD ROCm platform."""
+    return torch.version.hip is not None
+
+
 from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
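The new `is_hip()` check is what lets a single `setup.py` serve both platforms: `torch.version.hip` is `None` on CUDA builds of PyTorch and a version string on ROCm builds. A minimal sketch of the same branching pattern, with the platform value passed in as an argument (a hypothetical stand-in for `torch.version.hip`, so the sketch runs without PyTorch installed):

```python
# Sketch of the platform-branching pattern used in this PR's setup.py.
# `hip_version` stands in for torch.version.hip, which is None on CUDA
# builds of PyTorch and a version string (e.g. "6.0") on ROCm builds.
def is_hip(hip_version) -> bool:
    return hip_version is not None


def compile_flags(hip_version):
    """Pick a compiler flag set based on the detected platform."""
    if is_hip(hip_version):
        # HIP/ROCm path: define the AMD platform macro for hipcc.
        return ["-D__HIP_PLATFORM_AMD__=1", "-O3", "-fPIC"]
    # CUDA path: nvcc flags; host-compiler options go through -Xcompiler.
    return ["-O3", "-Xcompiler", "-fPIC"]
```

The real `setup.py` does the same branching but selects entire `CUDAExtension` configurations, not just flag lists.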
@@ -58,38 +66,64 @@ def update_wheel_platform_tag():
     old_wheel.rename(new_wheel)


-nvcc_flags = [
-    "-O3",
-    "-Xcompiler",
-    "-fPIC",
-    "-gencode=arch=compute_75,code=sm_75",
-    "-gencode=arch=compute_80,code=sm_80",
-    "-gencode=arch=compute_89,code=sm_89",
-    "-gencode=arch=compute_90,code=sm_90",
-    "-U__CUDA_NO_HALF_OPERATORS__",
-    "-U__CUDA_NO_HALF2_OPERATORS__",
-]
-cxx_flags = ["-O3"]
-libraries = ["c10", "torch", "torch_python"]
-extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
-ext_modules = [
-    CUDAExtension(
-        name="sgl_kernel.ops._kernels",
-        sources=[
-            "src/sgl-kernel/csrc/warp_reduce_kernel.cu",
-            "src/sgl-kernel/csrc/trt_reduce_internal.cu",
-            "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
-            "src/sgl-kernel/csrc/moe_align_kernel.cu",
-            "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
-        ],
-        extra_compile_args={
-            "nvcc": nvcc_flags,
-            "cxx": cxx_flags,
-        },
-        libraries=libraries,
-        extra_link_args=extra_link_args,
-    ),
-]
+if not is_hip():
+    nvcc_flags = [
+        "-O3",
+        "-Xcompiler",
+        "-fPIC",
+        "-gencode=arch=compute_75,code=sm_75",
+        "-gencode=arch=compute_80,code=sm_80",
+        "-gencode=arch=compute_89,code=sm_89",
+        "-gencode=arch=compute_90,code=sm_90",
+        "-U__CUDA_NO_HALF_OPERATORS__",
+        "-U__CUDA_NO_HALF2_OPERATORS__",
+    ]
+    cxx_flags = ["-O3"]
+    libraries = ["c10", "torch", "torch_python"]
+    extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
+    ext_modules = [
+        CUDAExtension(
+            name="sgl_kernel.ops._kernels",
+            sources=[
+                "src/sgl-kernel/csrc/warp_reduce_kernel.cu",
+                "src/sgl-kernel/csrc/trt_reduce_internal.cu",
+                "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
+                "src/sgl-kernel/csrc/moe_align_kernel.cu",
+                "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
+            ],
+            extra_compile_args={
+                "nvcc": nvcc_flags,
+                "cxx": cxx_flags,
+            },
+            libraries=libraries,
+            extra_link_args=extra_link_args,
+        ),
+    ]
+else:
+    hipcc_flags = [
+        "-D__HIP_PLATFORM_AMD__=1",
+        "--amdgpu-target=gfx90a,gfx940,gfx941,gfx942",
+    ]
+    ext_modules = [
+        CUDAExtension(
+            "sgl_kernel.ops.moe_align_block_size",
+            [
+                "src/sgl-kernel/csrc/moe_align_kernel.cu",
+                "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
+            ],
+            extra_compile_args={
+                "nvcc": hipcc_flags
+                + [
+                    "-O3",
+                    "-Xcompiler",
+                    "-fPIC",
+                ],
+                "cxx": ["-O3"],
+            },
+            libraries=["hiprtc", "amdhip64", "c10", "torch", "torch_python"],
+            extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
+        ),
+    ]

 setup(
     name="sgl-kernel",

Review discussion on the HIP branch:

- "If you need to use AMD for compilation, I recommend not compiling …"
- "Do you have any suggestions? @yzh119"
- "It seems we need to compile the reduce kernel here; otherwise some archs will not be imported due to …"
- "May we use is_hip there?"
- "@zhyncs In the case of CUDA/HIP-compatible kernel files, we don't use separate files (that is the point of HIP), and I believe this is one of those cases. We do, for sure, use separate files for AMD-specific kernels or kernel implementations."
- "@zyeric the …"
- "Maybe it's better to separate the AMD/NV kernels as two different backends? At this moment, moe_align_kernel is only required for the AMD backend, while in the near future CK kernels will be added to the AMD backend."
- "@HaiShaw I think the root cause is that the import path is still …"
- "The current version works for me, many thanks :D"
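The CUDA branch above spells out one `-gencode` flag per target SM architecture (sm_75 through sm_90). A hypothetical helper, not part of the PR, showing how that list could be generated from the architecture names:

```python
def gencode_flags(archs):
    """Build one NVCC -gencode flag per compute-capability string.

    Each entry compiles device code for that exact architecture
    (arch=compute_XX, code=sm_XX), matching the flag list in setup.py.
    """
    return [f"-gencode=arch=compute_{a},code=sm_{a}" for a in archs]


# Turing, Ampere, Ada, and Hopper, as targeted by the PR:
nvcc_arch_flags = gencode_flags(["75", "80", "89", "90"])
```

Keeping the architecture list in one place makes it easy to extend (e.g. adding a new SM) without repeating the `-gencode` boilerplate.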
Review discussion on the Docker base image:

- "What issues could occur if the image isn't updated? Minimize updating the base image whenever possible."
- "@zhyncs we (AMD) will have to decide on this, so ignore it for now."