Skip to content

Commit 6a100a8

Browse files
committed
未完成
1 parent f4405a3 commit 6a100a8

File tree

12 files changed

+508
-39
lines changed

12 files changed

+508
-39
lines changed

operatorspy/tests/avg_pool.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ def test(
8989
f"Testing AvgPool on {torch_device} with x_shape:{x_shape} kernel_shape:{k_shape} padding:{padding} strides:{strides} dtype:{tensor_dtype}"
9090
)
9191

92-
x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
93-
y = torch.rand(inferShape(x_shape, k_shape, padding, strides), dtype=tensor_dtype).to(torch_device)
92+
x = torch.ones(x_shape, dtype=tensor_dtype).to(torch_device)
93+
y = torch.zeros(inferShape(x_shape, k_shape, padding, strides), dtype=tensor_dtype).to(torch_device)
9494

9595
for i in range(NUM_PRERUN if PROFILE else 1):
9696
ans = pool(x, k_shape, padding, strides)
@@ -152,6 +152,10 @@ def test(
152152
elapsed = (time.time() - start_time) / NUM_ITERATIONS
153153
print(f" lib time: {elapsed :6f}")
154154

155+
156+
print(x)
157+
print(y)
158+
print(ans)
155159
assert torch.allclose(y, ans, atol=0, rtol=1e-3)
156160
check_error(lib.infiniopDestroyAvgPoolDescriptor(descriptor))
157161

@@ -184,12 +188,23 @@ def test_bang(lib, test_cases):
184188
test(lib, handle, "mlu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
185189
destroy_handle(lib, handle)
186190

191+
def test_musa(lib, test_cases):
192+
import torch_musa
193+
194+
device = DeviceEnum.DEVICE_MUSA
195+
handle = create_handle(lib, device)
196+
for x_shape, kernel_shape, padding, strides in test_cases:
197+
# test(lib, handle, "musa", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
198+
test(lib, handle, "musa", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
199+
destroy_handle(lib, handle)
200+
187201

188202
if __name__ == "__main__":
189203
test_cases = [
190204
# x_shape, kernel_shape, padding, strides
191-
((1, 1, 10), (3,), (1,), (1,)),
192-
((32, 3, 224, 224), (3, 3), (1, 1), (2, 2)),
205+
# ((1, 1, 10), (3,), (1,), (1,)),
206+
((1, 1, 2, 2), (2, 2), (1, 1), (1, 1)),
207+
((32, 4, 224, 224), (3, 3), (1, 1), (2, 2)),
193208
((1, 1, 16, 16, 16), (5, 5, 5), (2, 2, 2), (2, 2, 2)),
194209
]
195210
args = get_args()
@@ -230,6 +245,8 @@ def test_bang(lib, test_cases):
230245
test_cuda(lib, test_cases)
231246
if args.bang:
232247
test_bang(lib, test_cases)
233-
if not (args.cpu or args.cuda or args.bang):
248+
if args.musa:
249+
test_musa(lib, test_cases)
250+
if not (args.cpu or args.cuda or args.bang or args.musa):
234251
test_cpu(lib, test_cases)
235252
print("\033[92mTest passed!\033[0m")

operatorspy/tests/conv.py

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -39,22 +39,25 @@ class ConvDescriptor(Structure):
3939

4040

4141
def conv(x, w, stride, padding, dilation):
42-
match len(x.shape) - 2:
43-
case 1:
44-
return F.conv1d(
45-
x, w, stride=stride, padding=padding, dilation=dilation
46-
)
47-
case 2:
48-
return F.conv2d(
49-
x, w, stride=stride, padding=padding, dilation=dilation
50-
)
51-
case 3:
52-
return F.conv3d(
53-
x, w, stride=stride, padding=padding, dilation=dilation
54-
)
55-
case _:
56-
print("Error: Pytorch -> Unsupported tensor dimension")
57-
return None
42+
ndim = len(x.shape) - 2
43+
conv_func_map = {
44+
1: F.conv1d,
45+
2: F.conv2d,
46+
3: F.conv3d
47+
}
48+
49+
if ndim not in conv_func_map:
50+
print("Error: Pytorch -> Unsupported tensor dimension")
51+
return None
52+
53+
# Select the appropriate convolution function
54+
conv_func = conv_func_map[ndim]
55+
56+
if PROFILE:
57+
ans = conv_func(x, w, stride=stride, padding=padding, dilation=dilation)
58+
torch.cuda.synchronize()
59+
return ans
60+
return conv_func(x, w, stride=stride, padding=padding, dilation=dilation)
5861

5962

6063
# infer the shape of the output given the inputs for a N-ary convolution
@@ -206,18 +209,28 @@ def test_bang(lib, test_cases):
206209
test(lib, handle, "mlu", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float32)
207210
destroy_handle(lib, handle)
208211

212+
def test_musa(lib, test_cases):
213+
import torch_musa
214+
215+
device = DeviceEnum.DEVICE_MUSA
216+
handle = create_handle(lib, device)
217+
for x_shape, w_shape, pads, strides, dilations, x_strides in test_cases:
218+
# test(lib, handle, "musa", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float16)
219+
test(lib, handle, "musa", x_shape, w_shape, pads, strides, dilations, x_strides, tensor_dtype=torch.float32)
220+
destroy_handle(lib, handle)
221+
209222

210223
if __name__ == "__main__":
211224
test_cases = [
212225
# x_shape, w_shape, pads, strides, dilations, x_strides
213-
(
214-
(32, 3, 4),
215-
(32, 3, 5),
216-
(1,),
217-
(1,),
218-
(1,),
219-
None,
220-
),
226+
# (
227+
# (32, 3, 4),
228+
# (32, 3, 5),
229+
# (1,),
230+
# (1,),
231+
# (1,),
232+
# None,
233+
# ),
221234
(
222235
(1, 3, 4, 4),
223236
(2, 3, 3, 3),
@@ -228,9 +241,9 @@ def test_bang(lib, test_cases):
228241
),
229242
(
230243
(32, 3, 128, 128),
231-
(64, 3, 5, 5),
232-
(2, 2),
233-
(2, 2),
244+
(1, 3, 3, 3),
245+
(1, 1),
246+
(1, 1),
234247
(1, 1),
235248
None,
236249
),
@@ -286,6 +299,8 @@ def test_bang(lib, test_cases):
286299
test_cuda(lib, test_cases)
287300
if args.bang:
288301
test_bang(lib, test_cases)
289-
if not (args.cpu or args.cuda or args.bang):
302+
if args.musa:
303+
test_musa(lib, test_cases)
304+
if not (args.cpu or args.cuda or args.bang or args.musa):
290305
test_cpu(lib, test_cases)
291306
print("\033[92mTest passed!\033[0m")

operatorspy/tests/max_pool.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def test(
8888

8989
x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
9090
y = torch.rand(inferShape(x_shape, k_shape, padding, strides), dtype=tensor_dtype).to(torch_device)
91-
91+
9292
for i in range(NUM_PRERUN if PROFILE else 1):
9393
ans = pool(x, k_shape, padding, strides)
9494
if PROFILE:
@@ -148,7 +148,9 @@ def test(
148148
)
149149
elapsed = (time.time() - start_time) / NUM_ITERATIONS
150150
print(f" lib time: {elapsed :6f}")
151-
151+
print(x)
152+
print(y)
153+
print(ans)
152154
assert torch.allclose(y, ans, atol=0, rtol=1e-3)
153155
check_error(lib.infiniopDestroyMaxPoolDescriptor(descriptor))
154156

@@ -181,6 +183,16 @@ def test_bang(lib, test_cases):
181183
test(lib, handle, "mlu", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
182184
destroy_handle(lib, handle)
183185

186+
def test_musa(lib, test_cases):
187+
import torch_musa
188+
189+
device = DeviceEnum.DEVICE_MUSA
190+
handle = create_handle(lib, device)
191+
for x_shape, kernel_shape, padding, strides in test_cases:
192+
# test(lib, handle, "musa", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float16)
193+
test(lib, handle, "musa", x_shape, kernel_shape, padding, strides, tensor_dtype=torch.float32)
194+
destroy_handle(lib, handle)
195+
184196

185197
if __name__ == "__main__":
186198
test_cases = [
@@ -227,6 +239,8 @@ def test_bang(lib, test_cases):
227239
test_cuda(lib, test_cases)
228240
if args.bang:
229241
test_bang(lib, test_cases)
230-
if not (args.cpu or args.cuda or args.bang):
242+
if args.musa:
243+
test_musa(lib, test_cases)
244+
if not (args.cpu or args.cuda or args.bang or args.musa):
231245
test_cpu(lib, test_cases)
232246
print("\033[92mTest passed!\033[0m")

src/ops/add/musa/add_musa.mu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ __global__ void add(
6969
}
7070

7171
template<typename Tdata, typename BTdata>
72-
void _add_nv_gpu(AddMusaDescriptor_t desc, Tdata *c, Tdata const *a, Tdata const *b, uint64_t data_size, uint64_t pack_size, uint64_t offset, void *stream) {
72+
void _add_mt_gpu(AddMusaDescriptor_t desc, Tdata *c, Tdata const *a, Tdata const *b, uint64_t data_size, uint64_t pack_size, uint64_t offset, void *stream) {
7373
if (data_size == 0) {
7474
return;
7575
}
@@ -92,13 +92,13 @@ infiniopStatus_t add_mt_gpu(AddMusaDescriptor_t desc, void *c, void const *a, vo
9292
const auto a_vec = reinterpret_cast<const Tdata *>(a);
9393
const auto b_vec = reinterpret_cast<const Tdata *>(b);
9494
const auto c_vec = reinterpret_cast<Tdata *>(c);
95-
_add_nv_gpu<Tdata, TIdata>(desc, c_vec, a_vec, b_vec, data_size, pack_size, 0, stream);
95+
_add_mt_gpu<Tdata, TIdata>(desc, c_vec, a_vec, b_vec, data_size, pack_size, 0, stream);
9696

9797
const auto remainder = desc->c_data_size % pack_size;
9898
const auto a_ = reinterpret_cast<const TIdata *>(a);
9999
const auto b_ = reinterpret_cast<const TIdata *>(b);
100100
const auto c_ = reinterpret_cast<TIdata *>(c);
101-
_add_nv_gpu<TIdata, TIdata>(desc, c_, a_, b_, remainder, 1, data_size * pack_size, stream);
101+
_add_mt_gpu<TIdata, TIdata>(desc, c_, a_, b_, remainder, 1, data_size * pack_size, stream);
102102
return STATUS_SUCCESS;
103103
}
104104

src/ops/conv/musa/conv_musa.cc

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#include "conv_musa.h"
2+
#include "../../../devices/musa/common_musa.h"
3+
#include "../../utils.h"
4+
#include <vector>
5+
6+
infiniopStatus_t musaCreateConvDescriptor(MusaHandle_t handle,
7+
ConvMusaDescriptor_t *desc_ptr,
8+
infiniopTensorDescriptor_t y,
9+
infiniopTensorDescriptor_t x,
10+
infiniopTensorDescriptor_t w,
11+
void const *pads,
12+
void const *strides,
13+
void const *dilations,
14+
uint64_t n) {
15+
uint64_t ndim = y->ndim;
16+
if (ndim < 3 || ndim != x->ndim || ndim != w->ndim) {
17+
return STATUS_BAD_TENSOR_SHAPE;
18+
}
19+
if (x->shape[0] != y->shape[0] || w->shape[0] != y->shape[1] || x->shape[1] != w->shape[1]) {
20+
return STATUS_BAD_TENSOR_SHAPE;
21+
}
22+
if (y->dt != F16 && y->dt != F32) {
23+
return STATUS_BAD_TENSOR_DTYPE;
24+
}
25+
if (y->dt != x->dt || y->dt != w->dt) {
26+
return STATUS_BAD_TENSOR_DTYPE;
27+
}
28+
29+
const auto new_ndim = std::max(4UL, ndim);
30+
// convert pads, strides, dilations into int32[]
31+
int *pad = new int[new_ndim];
32+
int *stride = new int[new_ndim];
33+
int *dilation = new int[new_ndim];
34+
int64_t *x_shape = new int64_t[new_ndim];
35+
int64_t *w_shape = new int64_t[new_ndim];
36+
int64_t *y_shape = new int64_t[new_ndim];
37+
auto pads_ = reinterpret_cast<uint64_t const *>(pads);
38+
auto strides_ = reinterpret_cast<int64_t const *>(strides);
39+
auto dilations_ = reinterpret_cast<uint64_t const *>(dilations);
40+
for (size_t i = 0; i < new_ndim; ++i) {
41+
pad[i] = i < ndim - 2 ? static_cast<int>(pads_[i]) : 0;
42+
stride[i] = i < ndim - 2 ? static_cast<int>(strides_[i]) : 1;
43+
dilation[i] = i < ndim - 2 ? static_cast<int>(dilations_[i]) : 1;
44+
x_shape[i] = i < ndim ? static_cast<int64_t>(x->shape[i]) : 1;
45+
w_shape[i] = i < ndim ? static_cast<int64_t>(w->shape[i]) : 1;
46+
y_shape[i] = i < ndim ? static_cast<int64_t>(y->shape[i]) : 1;
47+
}
48+
49+
musa::dnn::Tensor *x_tensor = new musa::dnn::Tensor();
50+
musa::dnn::Tensor *y_tensor = new musa::dnn::Tensor();
51+
musa::dnn::Tensor *w_tensor = new musa::dnn::Tensor();
52+
53+
if (y->dt == F16) {
54+
x_tensor->SetType(musa::dnn::Tensor::Type::HALF);
55+
y_tensor->SetType(musa::dnn::Tensor::Type::HALF);
56+
w_tensor->SetType(musa::dnn::Tensor::Type::HALF);
57+
} else if (y->dt == F32) {
58+
x_tensor->SetType(musa::dnn::Tensor::Type::FLOAT);
59+
y_tensor->SetType(musa::dnn::Tensor::Type::FLOAT);
60+
w_tensor->SetType(musa::dnn::Tensor::Type::FLOAT);
61+
}
62+
63+
x_tensor->SetFormat(musa::dnn::Tensor::Format::NCHW);
64+
y_tensor->SetFormat(musa::dnn::Tensor::Format::NCHW);
65+
w_tensor->SetFormat(musa::dnn::Tensor::Format::NCHW);
66+
67+
x_tensor->SetNdInfo((int) new_ndim, x_shape);
68+
y_tensor->SetNdInfo((int) new_ndim, y_shape);
69+
w_tensor->SetNdInfo((int) new_ndim, w_shape);
70+
71+
musa::dnn::Convolution* conv_operator = new musa::dnn::Convolution();
72+
conv_operator->SetNdInfo((int) new_ndim-2, pad, stride, dilation);
73+
musa::dnn::Convolution::Algorithm algo = musa::dnn::Convolution::Algorithm::DIRECT;
74+
size_t workspace_size = 0;
75+
76+
use_mudnn(handle->mudnn_handles_t, handle->device_id, nullptr, [&](musa::dnn::Handle* handle) {
77+
printf(" %d \n", conv_operator->GetRecommendForwardAlgorithm(*handle, algo, *y_tensor, *x_tensor, *w_tensor));
78+
// printf(" %d \n", conv_operator->GetForwardWorkspaceSize(*handle, workspace_size, *y_tensor, *x_tensor, *w_tensor, algo));
79+
});
80+
const float alpha = 1.0f;
81+
const float beta = 0.0f;
82+
printf("after: %d\n", algo);
83+
84+
printf("A\n");
85+
86+
*desc_ptr = new ConvMusaDescriptor{
87+
DevMtGpu,
88+
y->dt,
89+
handle->device_id,
90+
handle->mudnn_handles_t,
91+
x_tensor,
92+
w_tensor,
93+
y_tensor,
94+
conv_operator,
95+
algo,
96+
alpha,
97+
beta,
98+
workspace_size};
99+
100+
delete[] pad;
101+
delete[] stride;
102+
delete[] dilation;
103+
delete[] x_shape;
104+
delete[] w_shape;
105+
delete[] y_shape;
106+
107+
return STATUS_SUCCESS;
108+
}
109+
110+
infiniopStatus_t musaGetConvWorkspaceSize(ConvMusaDescriptor_t desc, uint64_t *size) {
111+
*size = desc->workspace_size;
112+
return STATUS_SUCCESS;
113+
}
114+
115+
infiniopStatus_t musaDestroyConvDescriptor(ConvMusaDescriptor_t desc) {
116+
117+
desc->mudnn_handles_t = nullptr;
118+
delete desc;
119+
return STATUS_SUCCESS;
120+
}

src/ops/conv/musa/conv_musa.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#ifndef __MUSA_CONV_H__
2+
#define __MUSA_CONV_H__
3+
4+
#include "../../../devices/musa/common_musa.h"
5+
#include "../../../devices/musa/musa_handle.h"
6+
#include "operators.h"
7+
#include <mudnn.h>
8+
9+
struct ConvMusaDescriptor {
10+
Device device;
11+
DT dtype;
12+
int device_id;
13+
std::shared_ptr<Pool<musa::dnn::Handle>> mudnn_handles_t;
14+
musa::dnn::Tensor* x_tensor;
15+
musa::dnn::Tensor* w_tensor;
16+
musa::dnn::Tensor* y_tensor;
17+
musa::dnn::Convolution* conv_operator;
18+
musa::dnn::Convolution::Algorithm algo;
19+
const float alpha;
20+
const float beta;
21+
uint64_t workspace_size;
22+
};
23+
24+
typedef struct ConvMusaDescriptor *ConvMusaDescriptor_t;
25+
26+
infiniopStatus_t musaCreateConvDescriptor(MusaHandle_t,
27+
ConvMusaDescriptor_t *,
28+
infiniopTensorDescriptor_t y,
29+
infiniopTensorDescriptor_t x,
30+
infiniopTensorDescriptor_t w,
31+
void const *pads,
32+
void const *strides,
33+
void const *dilations,
34+
uint64_t n);
35+
36+
infiniopStatus_t musaGetConvWorkspaceSize(ConvMusaDescriptor_t desc, uint64_t *size);
37+
38+
infiniopStatus_t musaConv(ConvMusaDescriptor_t desc,
39+
void *workspace, uint64_t workspace_size,
40+
void *y, void const *x, void const *w,
41+
void *stream);
42+
43+
infiniopStatus_t musaDestroyConvDescriptor(ConvMusaDescriptor_t desc);
44+
45+
#endif

0 commit comments

Comments
 (0)