Skip to content

Commit 9e71f53

Browse files
committed
Add conv optimization and updates, including mathType, argument data type, etc.
1 parent 8cf0604 commit 9e71f53

File tree

7 files changed

+67
-50
lines changed

7 files changed

+67
-50
lines changed

include/ops/conv/conv.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,9 @@ __C __export infiniopStatus_t infiniopCreateConvDescriptor(infiniopHandle_t hand
1515
infiniopTensorDescriptor_t y,
1616
infiniopTensorDescriptor_t x,
1717
infiniopTensorDescriptor_t w,
18-
void *pads,
19-
void *strides,
20-
void *dilations,
18+
uint64_t const *pads,
19+
int64_t const *strides,
20+
uint64_t const *dilations,
2121
uint64_t n);
2222

2323
__C __export infiniopStatus_t infiniopGetConvWorkspaceSize(infiniopConvDescriptor_t desc, uint64_t *size);

operatorspy/tests/conv.py

Lines changed: 19 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -39,22 +39,25 @@ class ConvDescriptor(Structure):
3939

4040

4141
def conv(x, w, stride, padding, dilation):
42-
match len(x.shape) - 2:
43-
case 1:
44-
return F.conv1d(
45-
x, w, stride=stride, padding=padding, dilation=dilation
46-
)
47-
case 2:
48-
return F.conv2d(
49-
x, w, stride=stride, padding=padding, dilation=dilation
50-
)
51-
case 3:
52-
return F.conv3d(
53-
x, w, stride=stride, padding=padding, dilation=dilation
54-
)
55-
case _:
56-
print("Error: Pytorch -> Unsupported tensor dimension")
57-
return None
42+
ndim = len(x.shape) - 2
43+
conv_func_map = {
44+
1: F.conv1d,
45+
2: F.conv2d,
46+
3: F.conv3d
47+
}
48+
49+
if ndim not in conv_func_map:
50+
print("Error: Pytorch -> Unsupported tensor dimension")
51+
return None
52+
53+
# Select the appropriate convolution function
54+
conv_func = conv_func_map[ndim]
55+
56+
if PROFILE:
57+
ans = conv_func(x, w, stride=stride, padding=padding, dilation=dilation)
58+
torch.cuda.synchronize()
59+
return ans
60+
return conv_func(x, w, stride=stride, padding=padding, dilation=dilation)
5861

5962

6063
# infer the shape of the output given the inputs for a N-ary convolution

src/ops/conv/cpu/conv_cpu.cc

Lines changed: 28 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -17,9 +17,9 @@ infiniopStatus_t cpuCreateConvDescriptor(infiniopHandle_t,
1717
infiniopTensorDescriptor_t y,
1818
infiniopTensorDescriptor_t x,
1919
infiniopTensorDescriptor_t w,
20-
void const *pads,
21-
void const *strides,
22-
void const *dilations,
20+
uint64_t const *pads,
21+
int64_t const *strides,
22+
uint64_t const *dilations,
2323
uint64_t n) {
2424
uint64_t ndim = y->ndim;
2525
if (ndim < 3 || ndim != x->ndim || ndim != w->ndim) {
@@ -36,27 +36,39 @@ infiniopStatus_t cpuCreateConvDescriptor(infiniopHandle_t,
3636
}
3737

3838
uint64_t y_size = getTotalSize(y->shape, ndim);
39-
const auto pads_ = reinterpret_cast<uint64_t const *>(pads);
40-
uint64_t padded_x_size = requirePadding(pads_, ndim) ? getPaddedSize(ndim, x->shape, pads_) : 0;
39+
uint64_t padded_x_size = requirePadding(pads, ndim) ? getPaddedSize(ndim, x->shape, pads) : 0;
4140
uint64_t *x_shape = new uint64_t[ndim];
4241
uint64_t *w_shape = new uint64_t[ndim];
4342
uint64_t *y_shape = new uint64_t[ndim];
43+
uint64_t *pads_ = new uint64_t[n];
44+
int64_t *strides_ = new int64_t[n];
45+
uint64_t *dilations_ = new uint64_t[n];
4446
memcpy(x_shape, x->shape, ndim * sizeof(uint64_t));
4547
memcpy(w_shape, w->shape, ndim * sizeof(uint64_t));
4648
memcpy(y_shape, y->shape, ndim * sizeof(uint64_t));
49+
memcpy(pads_, pads, n * sizeof(*pads));
50+
memcpy(strides_, strides, n * sizeof(*strides));
51+
memcpy(dilations_, dilations, n * sizeof(*dilations));
52+
53+
uint64_t *padded_shape = nullptr;
54+
if (padded_x_size > 0) {
55+
padded_shape = new uint64_t[ndim];
56+
getPaddedShape(ndim, x_shape, pads_, padded_shape);
57+
}
4758

4859
*desc_ptr = new ConvCpuDescriptor{
4960
DevCpu,
5061
y->dt,
5162
ndim,
5263
y_size,
5364
padded_x_size,
65+
padded_shape,
5466
x_shape,
5567
w_shape,
5668
y_shape,
57-
reinterpret_cast<uint64_t const *>(pads),
58-
reinterpret_cast<int64_t const *>(strides),
59-
reinterpret_cast<uint64_t const *>(dilations),
69+
pads_,
70+
strides_,
71+
dilations_,
6072
};
6173

6274
return STATUS_SUCCESS;
@@ -71,9 +83,13 @@ infiniopStatus_t cpuGetConvWorkspaceSize(ConvCpuDescriptor_t desc, uint64_t *siz
7183
}
7284

7385
infiniopStatus_t cpuDestroyConvDescriptor(ConvCpuDescriptor_t desc) {
86+
delete[] desc->padded_shape;
7487
delete[] desc->x_shape;
7588
delete[] desc->w_shape;
7689
delete[] desc->y_shape;
90+
delete[] desc->pads;
91+
delete[] desc->strides;
92+
delete[] desc->dilations;
7793
delete desc;
7894
return STATUS_SUCCESS;
7995
}
@@ -121,6 +137,7 @@ void _applyConv(ConvCpuDescriptor_t desc, Ydata *y, Xdata const *x,
121137

122138
// perform all the convolutions along this axis
123139
for (size_t i = 0; i < steps; ++i, ++y_index) {
140+
#pragma unroll
124141
// perform a single convolution
125142
for (size_t k = 0; k < kernel_size; ++k) {
126143
// calculate the current indices
@@ -129,7 +146,7 @@ void _applyConv(ConvCpuDescriptor_t desc, Ydata *y, Xdata const *x,
129146

130147
// base case (last dimension)
131148
if (ndim == desc->ndim - 1) {
132-
if (desc->dtype == F16) {
149+
if constexpr (std::is_same_v<Xdata, uint16_t>) {
133150
y[y_index] += f16_to_f32(x[curr_x_index]) * f16_to_f32(w[curr_w_index]);
134151
} else {
135152
y[y_index] += x[curr_x_index] * w[curr_w_index];
@@ -173,11 +190,9 @@ void _conv_cpu(ConvCpuDescriptor_t desc, void *workspace, uint64_t workspace_siz
173190
Ydata *y, Xdata const *x, Xdata const *w) {
174191
if (desc->padded_x_size > 0) {
175192
auto padded_x = reinterpret_cast<Xdata *>(workspace);
176-
uint64_t padded_shape[desc->ndim];
177193
std::fill(padded_x, padded_x + desc->padded_x_size, 0);
178-
getPaddedShape(desc->ndim, desc->x_shape, desc->pads, padded_shape);
179-
fillPaddedInput<Xdata>(desc, padded_shape, padded_x, x, desc->pads, 0, 0, 0);
180-
applyConv<Xdata, Ydata>(desc, y, padded_x, w, padded_shape);
194+
fillPaddedInput<Xdata>(desc, desc->padded_shape, padded_x, x, desc->pads, 0, 0, 0);
195+
applyConv<Xdata, Ydata>(desc, y, padded_x, w, desc->padded_shape);
181196
} else {
182197
applyConv<Xdata, Ydata>(desc, y, x, w, desc->x_shape);
183198
}

src/ops/conv/cpu/conv_cpu.h

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@ struct ConvCpuDescriptor {
1313
uint64_t ndim;
1414
uint64_t y_size;
1515
uint64_t padded_x_size;
16+
uint64_t const *padded_shape;
1617
uint64_t const *x_shape;
1718
uint64_t const *w_shape;
1819
uint64_t const *y_shape;
@@ -28,9 +29,9 @@ infiniopStatus_t cpuCreateConvDescriptor(infiniopHandle_t,
2829
infiniopTensorDescriptor_t y,
2930
infiniopTensorDescriptor_t x,
3031
infiniopTensorDescriptor_t w,
31-
void const *pads,
32-
void const *strides,
33-
void const *dilations,
32+
uint64_t const *pads,
33+
int64_t const *strides,
34+
uint64_t const *dilations,
3435
uint64_t n);
3536

3637
infiniopStatus_t cpuGetConvWorkspaceSize(ConvCpuDescriptor_t desc, uint64_t *size);

src/ops/conv/cuda/conv.cc

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -7,9 +7,9 @@ infiniopStatus_t cudaCreateConvDescriptor(CudaHandle_t handle,
77
infiniopTensorDescriptor_t y,
88
infiniopTensorDescriptor_t x,
99
infiniopTensorDescriptor_t w,
10-
void const *pads,
11-
void const *strides,
12-
void const *dilations,
10+
uint64_t const *pads,
11+
int64_t const *strides,
12+
uint64_t const *dilations,
1313
uint64_t n) {
1414
uint64_t ndim = y->ndim;
1515
if (ndim < 3 || ndim != x->ndim || ndim != w->ndim) {
@@ -33,13 +33,10 @@ infiniopStatus_t cudaCreateConvDescriptor(CudaHandle_t handle,
3333
int32_t *x_shape = new int32_t[new_ndim];
3434
int32_t *w_shape = new int32_t[new_ndim];
3535
int32_t *y_shape = new int32_t[new_ndim];
36-
auto pads_ = reinterpret_cast<uint64_t const *>(pads);
37-
auto strides_ = reinterpret_cast<int64_t const *>(strides);
38-
auto dilations_ = reinterpret_cast<uint64_t const *>(dilations);
3936
for (size_t i = 0; i < new_ndim; ++i) {
40-
pad[i] = i < ndim - 2 ? static_cast<int32_t>(pads_[i]) : 0;
41-
stride[i] = i < ndim - 2 ? static_cast<int32_t>(strides_[i]) : 1;
42-
dilation[i] = i < ndim - 2 ? static_cast<int32_t>(dilations_[i]) : 1;
37+
pad[i] = i < ndim - 2 ? static_cast<int32_t>(pads[i]) : 0;
38+
stride[i] = i < ndim - 2 ? static_cast<int32_t>(strides[i]) : 1;
39+
dilation[i] = i < ndim - 2 ? static_cast<int32_t>(dilations[i]) : 1;
4340
x_shape[i] = i < ndim ? static_cast<int32_t>(x->shape[i]) : 1;
4441
w_shape[i] = i < ndim ? static_cast<int32_t>(w->shape[i]) : 1;
4542
y_shape[i] = i < ndim ? static_cast<int32_t>(y->shape[i]) : 1;
@@ -92,6 +89,7 @@ infiniopStatus_t cudaCreateConvDescriptor(CudaHandle_t handle,
9289
checkCudnnError(cudnnCreateTensorDescriptor(&y_desc));
9390
checkCudnnError(cudnnSetTensorNdDescriptorEx(y_desc, CUDNN_TENSOR_NCHW, static_cast<cudnnDataType_t>(tensor_dt), new_ndim, y_shape));
9491

92+
cudnnSetConvolutionMathType(op_desc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION);
9593

9694
// tuning: get the best algorithm
9795
int requestedAlgoCount = 1;

src/ops/conv/cuda/conv.cuh

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -28,9 +28,9 @@ infiniopStatus_t cudaCreateConvDescriptor(CudaHandle_t,
2828
infiniopTensorDescriptor_t y,
2929
infiniopTensorDescriptor_t x,
3030
infiniopTensorDescriptor_t w,
31-
void const *pads,
32-
void const *strides,
33-
void const *dilations,
31+
uint64_t const *pads,
32+
int64_t const *strides,
33+
uint64_t const *dilations,
3434
uint64_t n);
3535

3636
infiniopStatus_t cudaGetConvWorkspaceSize(ConvCudaDescriptor_t desc, uint64_t *size);

src/ops/conv/operator.cc

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,9 +16,9 @@ __C infiniopStatus_t infiniopCreateConvDescriptor(
1616
infiniopTensorDescriptor_t y,
1717
infiniopTensorDescriptor_t x,
1818
infiniopTensorDescriptor_t w,
19-
void *pads,
20-
void *strides,
21-
void *dilations,
19+
uint64_t const *pads,
20+
int64_t const *strides,
21+
uint64_t const *dilations,
2222
uint64_t n) {
2323
switch (handle->device) {
2424
#ifdef ENABLE_CPU

0 commit comments

Comments
 (0)