Skip to content

Commit 2bd55da

Browse files
authored
Merge pull request #152 from PanZezhong1725/maca
feat: 增加沐曦Maca(Ht)的大语言模型算子
2 parents 0353a33 + 7cf84bf commit 2bd55da

Some content is hidden

Large commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+2108
-14
lines changed

include/device.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
#define __DEVICE_H__
33

44
enum DeviceEnum {
5-
DevCpu,
6-
DevNvGpu,
7-
DevCambriconMlu,
8-
DevAscendNpu,
5+
DevCpu = 0,
6+
DevNvGpu = 1,
7+
DevCambriconMlu = 2,
8+
DevAscendNpu = 3,
9+
DevMetaxGpu = 4,
10+
DevMthreadsGpu = 5,
911
};
1012

1113
typedef enum DeviceEnum Device;

operatorspy/devices.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@ class DeviceEnum:
33
DEVICE_CUDA = 1
44
DEVICE_BANG = 2
55
DEVICE_ASCEND = 3
6+
DEVICE_MACA = 4
7+
DEVICE_MUSA = 5

operatorspy/tests/causal_softmax.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,14 @@ def test_ascend(lib, test_cases):
111111

112112
destroy_handle(lib, handle)
113113

114+
def test_maca(lib, test_cases):
115+
device = DeviceEnum.DEVICE_MACA
116+
handle = create_handle(lib, device)
117+
for x_shape, x_stride in test_cases:
118+
test(lib, handle, "cuda", x_shape, x_stride)
119+
120+
destroy_handle(lib, handle)
121+
114122
if __name__ == "__main__":
115123
test_cases = [
116124
# x_shape, x_stride
@@ -151,6 +159,8 @@ def test_ascend(lib, test_cases):
151159
test_bang(lib, test_cases)
152160
if args.ascend:
153161
test_ascend(lib, test_cases)
154-
if not (args.cpu or args.cuda or args.bang or args.ascend):
162+
if args.maca:
163+
test_maca(lib, test_cases)
164+
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
155165
test_cpu(lib, test_cases)
156166
print("\033[92mTest passed!\033[0m")

operatorspy/tests/matmul.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,38 @@ def test_ascend(lib, test_cases):
293293

294294
destroy_handle(lib, handle)
295295

296+
def test_maca(lib, test_cases):
297+
device = DeviceEnum.DEVICE_MACA
298+
handle = create_handle(lib, device)
299+
300+
for (
301+
alpha,
302+
beta,
303+
a_shape,
304+
b_shape,
305+
c_shape,
306+
a_stride,
307+
b_stride,
308+
c_stride,
309+
dtype,
310+
) in test_cases:
311+
test(
312+
lib,
313+
handle,
314+
"cuda",
315+
alpha,
316+
beta,
317+
a_shape,
318+
b_shape,
319+
c_shape,
320+
a_stride,
321+
b_stride,
322+
c_stride,
323+
dtype,
324+
)
325+
326+
destroy_handle(lib, handle)
327+
296328
if __name__ == "__main__":
297329
test_cases = [
298330
# alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride, dtype
@@ -353,6 +385,8 @@ def test_ascend(lib, test_cases):
353385
test_bang(lib, test_cases)
354386
if args.ascend:
355387
test_ascend(lib, test_cases)
356-
if not (args.cpu or args.cuda or args.bang or args.ascend):
388+
if args.maca:
389+
test_maca(lib, test_cases)
390+
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
357391
test_cpu(lib, test_cases)
358392
print("\033[92mTest passed!\033[0m")

operatorspy/tests/random_sample.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,18 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
8383
)
8484
data = torch.arange(voc).float() * 0.0001
8585
_perm = torch.randperm(voc)
86-
data = data[_perm].to(x_dtype).to(torch_device)
86+
if (torch_device == 'maca'):
87+
data = data[_perm].to(x_dtype).to('cuda')
88+
else:
89+
data = data[_perm].to(x_dtype).to(torch_device)
8790
if(topp > 0 and topk > 1):
8891
ans = random_sample(data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu")
8992
else:
9093
ans = random_sample_0(data)
91-
indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
94+
if(torch_device == 'maca'):
95+
indices = torch.zeros([1], dtype = torch.int64).to('cuda')
96+
else:
97+
indices = torch.zeros([1], dtype = torch.int64).to(torch_device)
9298
x_tensor = to_tensor(data, lib)
9399
indices_tensor = to_tensor(indices, lib)
94100
indices_tensor.descriptor.contents.dt = U64 # treat int64 as uint64
@@ -163,7 +169,15 @@ def test_ascend(lib, test_cases):
163169
handle = create_handle(lib, device)
164170
for (voc, random_val, topp, topk, temperature) in test_cases:
165171
test(lib, handle, "npu", voc, random_val, topp, topk, temperature)
166-
destroy_handle(lib, handle)
172+
destroy_handle(lib, handle)
173+
174+
def test_maca(lib, test_cases):
175+
device = DeviceEnum.DEVICE_MACA
176+
handle = create_handle(lib, device)
177+
for (voc, random_val, topp, topk, temperature) in test_cases:
178+
test(lib, handle, "maca", voc, random_val, topp, topk, temperature)
179+
destroy_handle(lib, handle)
180+
167181

168182

169183
if __name__ == "__main__":
@@ -220,6 +234,8 @@ def test_ascend(lib, test_cases):
220234
test_bang(lib, test_cases)
221235
if args.ascend:
222236
test_ascend(lib, test_cases)
223-
if not (args.cpu or args.cuda or args.bang or args.ascend):
237+
if args.maca:
238+
test_maca(lib, test_cases)
239+
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
224240
test_cpu(lib, test_cases)
225241
print("\033[92mTest passed!\033[0m")

operatorspy/tests/rearrange.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,15 @@ def test_ascend(lib, test_cases):
108108
test(lib, handle, "npu", x_shape, x_stride, y_shape, y_stride)
109109
destroy_handle(lib, handle)
110110

111+
def test_maca(lib, test_cases):
112+
device = DeviceEnum.DEVICE_MACA
113+
handle = create_handle(lib, device)
114+
for test_case in test_cases:
115+
x_shape, x_stride = test_case[0]
116+
y_shape, y_stride = test_case[1]
117+
test(lib, handle, "cuda", x_shape, x_stride, y_shape, y_stride)
118+
destroy_handle(lib, handle)
119+
111120
if __name__ == "__main__":
112121
args = get_args()
113122
test_cases = [
@@ -145,4 +154,6 @@ def test_ascend(lib, test_cases):
145154
test_bang(lib, test_cases)
146155
if args.ascend:
147156
test_ascend(lib, test_cases)
157+
if args.maca:
158+
test_maca(lib, test_cases)
148159
print("\033[92mTest passed!\033[0m")

operatorspy/tests/rms_norm.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,14 @@ def test_ascend(lib, test_cases):
117117

118118
destroy_handle(lib, handle)
119119

120+
def test_maca(lib, test_cases):
121+
device = DeviceEnum.DEVICE_MACA
122+
handle = create_handle(lib, device)
123+
for (y_shape, x_shape, w_shape, dtype, w_dtype) in test_cases:
124+
test(lib, handle, "cuda", y_shape, x_shape, w_shape, dtype, w_dtype)
125+
126+
destroy_handle(lib, handle)
127+
120128
if __name__ == "__main__":
121129
test_cases = [
122130
# y_shape, x_shape, w_shape, dtype, w_dtype
@@ -164,6 +172,8 @@ def test_ascend(lib, test_cases):
164172
test_bang(lib, test_cases)
165173
if args.ascend:
166174
test_ascend(lib, test_cases)
167-
if not (args.cpu or args.cuda or args.bang or args.ascend):
175+
if args.maca:
176+
test_maca(lib, test_cases)
177+
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
168178
test_cpu(lib, test_cases)
169179
print("\033[92mTest passed!\033[0m")

operatorspy/tests/rotary_embedding.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ def rotary_embedding(t, pos, theta, torch_device):
4545
)
4646
freqs = torch.outer(pos, freqs)
4747
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
48-
4948
t_ = torch.view_as_complex(t.reshape(*t.shape[:-1], -1, 2))
5049
freqs_cis = reshape_for_broadcast(freqs_cis, t_)
5150
t_out = torch.view_as_real(t_ * freqs_cis).flatten(2).to(t.dtype)
@@ -82,6 +81,10 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
8281
ans = rotary_embedding(t, posTmp, theta, "cpu").to(torch_device)
8382
pos = pos.to(torch_device)
8483
t = t.to(torch_device)
84+
elif torch_device == 'maca':
85+
ans = rotary_embedding(t, posTmp, theta, "cpu").to('cuda')
86+
pos = pos.to('cuda')
87+
t = t.to('cuda')
8588
else:
8689
t = t.to(torch_device)
8790
pos = pos.to(torch_device)
@@ -133,7 +136,6 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
133136
None,
134137
)
135138
)
136-
137139
assert torch.allclose(t, ans, atol=1e-4, rtol=1e-2)
138140
check_error(lib.infiniopDestroyRoPEDescriptor(descriptor))
139141

@@ -172,6 +174,13 @@ def test_ascend(lib, test_cases) :
172174
test(lib, handle, "npu", shape, strides, dtype)
173175
destroy_handle(lib, handle)
174176

177+
def test_maca(lib, test_cases) :
178+
device = DeviceEnum.DEVICE_MACA
179+
handle = create_handle(lib, device)
180+
for shape, strides, dtype in test_cases:
181+
test(lib, handle, "maca", shape, strides, dtype)
182+
destroy_handle(lib, handle)
183+
175184
if __name__ == "__main__":
176185
test_cases = [
177186
((1, 32, 128), None, torch.float16),
@@ -222,6 +231,8 @@ def test_ascend(lib, test_cases) :
222231
test_bang(lib, test_cases)
223232
if args.ascend:
224233
test_ascend(lib, test_cases)
225-
if not (args.cpu or args.cuda or args.bang or args.ascend):
234+
if args.maca:
235+
test_maca(lib, test_cases)
236+
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
226237
test_cpu(lib, test_cases)
227238
print("\033[92mTest passed!\033[0m")

operatorspy/tests/swiglu.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,18 @@ def test_ascend(lib, test_cases):
250250

251251
destroy_handle(lib, handle)
252252

253+
def test_maca(lib, test_cases):
254+
device = DeviceEnum.DEVICE_MACA
255+
handle = create_handle(lib, device)
256+
257+
for shape, a_stride, b_stride, c_stride, dtype in test_cases:
258+
test_out_of_place(
259+
lib, handle, "cuda", shape, a_stride, b_stride, c_stride, dtype)
260+
test_in_place1(lib, handle, "cuda", shape, a_stride, b_stride, dtype)
261+
test_in_place2(lib, handle, "cuda", shape, a_stride, b_stride, dtype)
262+
263+
destroy_handle(lib, handle)
264+
253265

254266
if __name__ == "__main__":
255267
test_cases = [
@@ -293,4 +305,6 @@ def test_ascend(lib, test_cases):
293305
test_bang(lib, test_cases)
294306
if args.ascend:
295307
test_ascend(lib, test_cases)
308+
if args.maca:
309+
test_maca(lib, test_cases)
296310
print("\033[92mTest passed!\033[0m")

operatorspy/tests/test_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ def get_args():
2727
action="store_true",
2828
help="Run ASCEND NPU test",
2929
)
30+
parser.add_argument(
31+
"--maca",
32+
action="store_true",
33+
help="Run MACA test",
34+
)
3035

3136
return parser.parse_args()
3237

operatorspy/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ def create_workspace(size, torch_device):
5050
if size == 0:
5151
return None
5252
import torch
53+
if (torch_device == 'maca'):
54+
return torch.zeros(size=(size,), dtype=torch.uint8, device='cuda')
5355
return torch.zeros(size=(size,), dtype=torch.uint8, device=torch_device)
5456

5557
def create_handle(lib, device, id=0):

src/devices/handle.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
#ifdef ENABLE_ASCEND_NPU
1212
#include "./ascend/ascend_handle.h"
1313
#endif
14+
#ifdef ENABLE_METAX_GPU
15+
#include "./maca/maca_handle.h"
16+
#endif
1417

1518

1619
__C infiniopStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr, Device device, int device_id) {
@@ -40,6 +43,11 @@ __C infiniopStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr, Device d
4043
case DevAscendNpu: {
4144
return createAscendHandle((AscendHandle_t *) handle_ptr, device_id);
4245
}
46+
#endif
47+
#ifdef ENABLE_METAX_GPU
48+
case DevMetaxGpu: {
49+
return createMacaHandle((MacaHandle_t *) handle_ptr, device_id);
50+
}
4351
#endif
4452
}
4553
return STATUS_BAD_DEVICE;
@@ -68,6 +76,11 @@ __C infiniopStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
6876
case DevAscendNpu: {
6977
return deleteAscendHandle((AscendHandle_t) handle);
7078
}
79+
#endif
80+
#ifdef ENABLE_METAX_GPU
81+
case DevMetaxGpu: {
82+
return deleteMacaHandle((MacaHandle_t) handle);
83+
}
7184
#endif
7285
}
7386
return STATUS_BAD_DEVICE;

0 commit comments

Comments
 (0)