
feat: Add Moore Threads (MUSA) large-model operators and some conventional-model operators #153

Merged: 15 commits, Feb 10, 2025
14 changes: 13 additions & 1 deletion operatorspy/tests/add.py
@@ -115,6 +115,16 @@ def test_bang(lib, test_cases):
         test(lib, handle, "mlu", c_shape, a_shape, b_shape, tensor_dtype=torch.float32, inplace=inplace)
     destroy_handle(lib, handle)
 
+def test_musa(lib, test_cases):
+    import torch_musa
+
+    device = DeviceEnum.DEVICE_MUSA
+    handle = create_handle(lib, device)
+    for c_shape, a_shape, b_shape, inplace in test_cases:
+        test(lib, handle, "musa", c_shape, a_shape, b_shape, tensor_dtype=torch.float16, inplace=inplace)
+        test(lib, handle, "musa", c_shape, a_shape, b_shape, tensor_dtype=torch.float32, inplace=inplace)
+    destroy_handle(lib, handle)
+
 
 if __name__ == "__main__":
     test_cases = [
@@ -163,6 +173,8 @@ def test_bang(lib, test_cases):
         test_cuda(lib, test_cases)
     if args.bang:
         test_bang(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang):
+    if args.musa:
+        test_musa(lib, test_cases)
+    if not (args.cpu or args.cuda or args.bang or args.musa):
         test_cpu(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
14 changes: 13 additions & 1 deletion operatorspy/tests/causal_softmax.py
@@ -119,6 +119,16 @@ def test_maca(lib, test_cases):
 
     destroy_handle(lib, handle)
 
+def test_musa(lib, test_cases):
+    import torch_musa
+    device = DeviceEnum.DEVICE_MUSA
+
+    handle = create_handle(lib, device)
+    for x_shape, x_stride in test_cases:
+        test(lib, handle, "musa", x_shape, x_stride)
+
+    destroy_handle(lib, handle)
+
 if __name__ == "__main__":
     test_cases = [
         # x_shape, x_stride
@@ -161,6 +171,8 @@ def test_maca(lib, test_cases):
         test_ascend(lib, test_cases)
     if args.maca:
         test_maca(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
+    if args.musa:
+        test_musa(lib, test_cases)
+    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa):
         test_cpu(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
14 changes: 13 additions & 1 deletion operatorspy/tests/expand.py
@@ -133,6 +133,16 @@ def test_bang(lib, test_cases):
         test(lib, handle, "mlu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32)
     destroy_handle(lib, handle)
 
+def test_musa(lib, test_cases):
+    import torch_musa
+
+    device = DeviceEnum.DEVICE_MUSA
+    handle = create_handle(lib, device)
+    for y_shape, x_shape, y_stride, x_stride in test_cases:
+        test(lib, handle, "musa", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float16)
+        test(lib, handle, "musa", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32)
+    destroy_handle(lib, handle)
+
 
 if __name__ == "__main__":
     test_cases = [
@@ -174,6 +184,8 @@ def test_bang(lib, test_cases):
         test_cuda(lib, test_cases)
     if args.bang:
         test_bang(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang):
+    if args.musa:
+        test_musa(lib, test_cases)
+    if not (args.cpu or args.cuda or args.bang or args.musa):
         test_cpu(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
35 changes: 34 additions & 1 deletion operatorspy/tests/matmul.py
@@ -325,6 +325,37 @@ def test_maca(lib, test_cases):
 
     destroy_handle(lib, handle)
 
+def test_musa(lib, test_cases):
+    import torch_musa
+
+    device = DeviceEnum.DEVICE_MUSA
+    handle = create_handle(lib, device)
+    for (
+        alpha,
+        beta,
+        a_shape,
+        b_shape,
+        c_shape,
+        a_stride,
+        b_stride,
+        c_stride,
+        dtype,
+    ) in test_cases:
+        test(
+            lib,
+            handle,
+            "musa",
+            alpha,
+            beta,
+            a_shape,
+            b_shape,
+            c_shape,
+            a_stride,
+            b_stride,
+            c_stride,
+            dtype,
+        )
+
 if __name__ == "__main__":
     test_cases = [
         # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride, dtype
@@ -387,6 +418,8 @@ def test_maca(lib, test_cases):
         test_ascend(lib, test_cases)
     if args.maca:
         test_maca(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
+    if args.musa:
+        test_musa(lib, test_cases)
+    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa):
         test_cpu(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
15 changes: 12 additions & 3 deletions operatorspy/tests/random_sample.py
@@ -94,7 +94,7 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
     if(torch_device == 'maca'):
         indices = torch.zeros([1], dtype = torch.int64).to('cuda')
     else:
-        indices = torch.zeros([1], dtype = torch.uint64).to(torch_device)
+        indices = torch.zeros([1], dtype = torch.int64).to(torch_device)
     x_tensor = to_tensor(data, lib)
     indices_tensor = to_tensor(indices, lib)
     indices_tensor.descriptor.contents.dt = U64 # treat int64 as uint64
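The switch from uint64 to int64 sidesteps PyTorch's lack of general uint64 tensor support: the buffer is allocated as int64 and the descriptor is then relabeled to U64, which is safe because the sampled indices are non-negative. A minimal illustration of the underlying assumption (pure PyTorch):

    import torch

    # PyTorch cannot allocate a general-purpose uint64 tensor, so the test
    # allocates int64 and lets the library reinterpret the bytes as unsigned.
    idx = torch.zeros([1], dtype=torch.int64)
    idx[0] = 12345
    # For non-negative values, the two's-complement int64 bit pattern is
    # identical to the uint64 bit pattern, so the relabeling loses nothing.
    assert idx.item() >= 0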
@@ -170,7 +170,7 @@ def test_ascend(lib, test_cases):
     for (voc, random_val, topp, topk, temperature) in test_cases:
         test(lib, handle, "npu", voc, random_val, topp, topk, temperature)
     destroy_handle(lib, handle)
-
+
 def test_maca(lib, test_cases):
     device = DeviceEnum.DEVICE_MACA
     handle = create_handle(lib, device)
@@ -179,6 +179,13 @@ def test_maca(lib, test_cases):
     destroy_handle(lib, handle)
 
 
+def test_musa(lib, test_cases):
+    import torch_musa
+    device = DeviceEnum.DEVICE_MUSA
+    handle = create_handle(lib, device)
+    for (voc, random_val, topp, topk, temperature) in test_cases:
+        test(lib, handle, "musa", voc, random_val, topp, topk, temperature)
+    destroy_handle(lib, handle)
 
 if __name__ == "__main__":
     test_cases = [
@@ -236,6 +243,8 @@ def test_maca(lib, test_cases):
         test_ascend(lib, test_cases)
     if args.maca:
         test_maca(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
+    if args.musa:
+        test_musa(lib, test_cases)
+    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa):
         test_cpu(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
22 changes: 22 additions & 0 deletions operatorspy/tests/rearrange.py
@@ -117,6 +117,26 @@ def test_maca(lib, test_cases):
         test(lib, handle, "cuda", x_shape, x_stride, y_shape, y_stride)
     destroy_handle(lib, handle)
 
+def test_musa(lib, test_cases):
+    import torch_musa
+    device = DeviceEnum.DEVICE_MUSA
+    handle = create_handle(lib, device)
+    for test_case in test_cases:
+        x_shape, x_stride = test_case[0]
+        y_shape, y_stride = test_case[1]
+        test(lib, handle, "musa", x_shape, x_stride, y_shape, y_stride)
+    destroy_handle(lib, handle)
+
+def test_musa(lib, test_cases):
+    import torch_musa
+    device = DeviceEnum.DEVICE_MUSA
+    handle = create_handle(lib, device)
+    for test_case in test_cases:
+        x_shape, x_stride = test_case[0]
+        y_shape, y_stride = test_case[1]
+        test(lib, handle, "musa", x_shape, x_stride, y_shape, y_stride)
+    destroy_handle(lib, handle)
+
 if __name__ == "__main__":
     args = get_args()
     test_cases = [
@@ -156,4 +176,6 @@ def test_maca(lib, test_cases):
         test_ascend(lib, test_cases)
     if args.maca:
         test_maca(lib, test_cases)
+    if args.musa:
+        test_musa(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
14 changes: 13 additions & 1 deletion operatorspy/tests/relu.py
@@ -132,6 +132,16 @@ def test_bang(lib, test_cases):
         test(lib, handle, "mlu", tensor_shape, tensor_dtype=torch.float32, inplace=inplace)
     destroy_handle(lib, handle)
 
+def test_musa(lib, test_cases):
+    import torch_musa
+
+    device = DeviceEnum.DEVICE_MUSA
+    handle = create_handle(lib, device)
+    for tensor_shape, inplace in test_cases:
+        test(lib, handle, "musa", tensor_shape, tensor_dtype=torch.float16, inplace=inplace)
+        test(lib, handle, "musa", tensor_shape, tensor_dtype=torch.float32, inplace=inplace)
+    destroy_handle(lib, handle)
+
 
 if __name__ == "__main__":
     test_cases = [
@@ -172,6 +182,8 @@ def test_bang(lib, test_cases):
         test_cuda(lib, test_cases)
    if args.bang:
         test_bang(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang):
+    if args.musa:
+        test_musa(lib, test_cases)
+    if not (args.cpu or args.cuda or args.bang or args.musa):
         test_cpu(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
12 changes: 11 additions & 1 deletion operatorspy/tests/rms_norm.py
@@ -125,6 +125,14 @@ def test_maca(lib, test_cases):
 
     destroy_handle(lib, handle)
 
+def test_musa(lib, test_cases):
+    import torch_musa
+    device = DeviceEnum.DEVICE_MUSA
+    handle = create_handle(lib, device)
+    for (y_shape, x_shape, w_shape, dtype, w_dtype) in test_cases:
+        test(lib, handle, "musa", y_shape, x_shape, w_shape, dtype, w_dtype)
+    destroy_handle(lib, handle)
+
 if __name__ == "__main__":
     test_cases = [
         # y_shape, x_shape, w_shape, dtype, w_dtype
@@ -174,6 +182,8 @@ def test_maca(lib, test_cases):
         test_ascend(lib, test_cases)
     if args.maca:
         test_maca(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
+    if args.musa:
+        test_musa(lib, test_cases)
+    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa):
         test_cpu(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
14 changes: 12 additions & 2 deletions operatorspy/tests/rotary_embedding.py
@@ -77,7 +77,7 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
         pos[2 * i] = posTmp[i]
         pos[2 * i + 1] = 0
     theta = 1e4
-    if torch_device == 'mlu' or torch_device == 'npu':
+    if torch_device == 'mlu' or torch_device == 'npu' or torch_device == 'musa':
         ans = rotary_embedding(t, posTmp, theta, "cpu").to(torch_device)
         pos = pos.to(torch_device)
         t = t.to(torch_device)
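As with the existing 'mlu' and 'npu' branches, the reference answer is computed with the pure-PyTorch rotary_embedding on the CPU and only then moved to the device, presumably because the reference path does not run natively on these backends.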
@@ -181,6 +181,14 @@ def test_maca(lib, test_cases) :
         test(lib, handle, "maca", shape, strides, dtype)
     destroy_handle(lib, handle)
 
+def test_musa(lib, test_cases) :
+    import torch_musa
+    device = DeviceEnum.DEVICE_MUSA
+    handle = create_handle(lib, device)
+    for shape, strides, dtype in test_cases:
+        test(lib, handle, "musa", shape, strides, dtype)
+    destroy_handle(lib, handle)
+
 if __name__ == "__main__":
     test_cases = [
         ((1, 32, 128), None, torch.float16),
@@ -233,6 +241,8 @@ def test_maca(lib, test_cases) :
         test_ascend(lib, test_cases)
     if args.maca:
         test_maca(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
+    if args.musa:
+        test_musa(lib, test_cases)
+    if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa):
         test_cpu(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
16 changes: 16 additions & 0 deletions operatorspy/tests/swiglu.py
@@ -262,6 +262,20 @@ def test_maca(lib, test_cases):
 
     destroy_handle(lib, handle)
 
+def test_musa(lib, test_cases):
+    import torch_musa
+    device = DeviceEnum.DEVICE_MUSA
+    handle = create_handle(lib, device)
+
+    for shape, a_stride, b_stride, c_stride, dtype in test_cases:
+        test_out_of_place(
+            lib, handle, "musa", shape, a_stride, b_stride, c_stride, dtype
+        )
+        test_in_place1(lib, handle, "musa", shape, a_stride, b_stride, dtype)
+        test_in_place2(lib, handle, "musa", shape, a_stride, b_stride, dtype)
+
+    destroy_handle(lib, handle)
+
 
 if __name__ == "__main__":
     test_cases = [
@@ -307,4 +321,6 @@ def test_maca(lib, test_cases):
         test_ascend(lib, test_cases)
     if args.maca:
         test_maca(lib, test_cases)
+    if args.musa:
+        test_musa(lib, test_cases)
     print("\033[92mTest passed!\033[0m")
5 changes: 5 additions & 0 deletions operatorspy/tests/test_utils.py
@@ -32,6 +32,11 @@ def get_args():
         action="store_true",
         help="Run ASCEND NPU test",
     )
+    parser.add_argument(
+        "--musa",
+        action="store_true",
+        help="Run MUSA test",
+    )
 
     return parser.parse_args()
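Taken together with the per-file dispatch changes above, the new switch behaves exactly like the existing backend flags; a self-contained sketch of the selection logic (help strings for the pre-existing flags are abbreviated, and the test runners are stubbed with prints):

    import argparse

    def get_args():
        parser = argparse.ArgumentParser()
        for flag, msg in [("--cpu", "Run CPU test"), ("--cuda", "Run CUDA test"),
                          ("--bang", "Run BANG test"), ("--musa", "Run MUSA test")]:
            parser.add_argument(flag, action="store_true", help=msg)
        return parser.parse_args()

    if __name__ == "__main__":
        args = get_args()
        if args.musa:
            print("would run the MUSA tests")
        if not (args.cpu or args.cuda or args.bang or args.musa):
            print("no backend flag given; falling back to the CPU tests")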
14 changes: 14 additions & 0 deletions src/devices/handle.cc
@@ -14,6 +14,9 @@
 #ifdef ENABLE_METAX_GPU
 #include "./maca/maca_handle.h"
 #endif
+#ifdef ENABLE_MTHREADS_GPU
+#include "./musa/musa_handle.h"
+#endif
 
 
 __C infiniopStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr, Device device, int device_id) {
@@ -48,6 +51,11 @@ __C infiniopStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr, Device d
         case DevMetaxGpu: {
             return createMacaHandle((MacaHandle_t *) handle_ptr, device_id);
         }
 #endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            return createMusaHandle((MusaHandle_t *) handle_ptr, device_id);
+        }
+#endif
     }
     return STATUS_BAD_DEVICE;
@@ -81,6 +89,12 @@ __C infiniopStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
         case DevMetaxGpu: {
             return deleteMacaHandle((MacaHandle_t) handle);
         }
 #endif
+#ifdef ENABLE_MTHREADS_GPU
+        case DevMthreadsGpu: {
+            deleteMusaHandle((MusaHandle_t) handle);
+            return STATUS_SUCCESS;
+        }
+#endif
     }
     return STATUS_BAD_DEVICE;
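On the Python side, the new dispatch branch is reached through the same helpers used by every test above; a minimal round-trip sketch (the operatorspy import path is an assumption, and the library must be compiled with ENABLE_MTHREADS_GPU):

    # Import path assumed; these helper names appear throughout the tests above.
    from operatorspy import DeviceEnum, create_handle, destroy_handle

    def musa_handle_round_trip(lib):
        device = DeviceEnum.DEVICE_MUSA      # maps to DevMthreadsGpu in handle.cc
        handle = create_handle(lib, device)  # dispatches to createMusaHandle(...)
        destroy_handle(lib, handle)          # dispatches to deleteMusaHandle(...)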