-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathim2col_cython.pyx
121 lines (96 loc) · 4.87 KB
/
im2col_cython.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import numpy as np
cimport numpy as np
cimport cython
# DTYPE_t is a fused type: each function below that is declared with it is
# specialized for both float32 and float64 arrays. It replaces the earlier
# float64-only definition:
# DTYPE = np.float64
# ctypedef np.float64_t DTYPE_t
ctypedef fused DTYPE_t:
    np.float32_t
    np.float64_t
def im2col_cython(np.ndarray[DTYPE_t, ndim=4] x, int field_height,
                  int field_width, int padding, int stride):
    """Unroll every sliding window of ``x`` into one column of a 2-D matrix.

    Parameters
    ----------
    x : ndarray, 4-D, indexed as (N, C, H, W)
        Input batch; float32 or float64.
    field_height, field_width : int
        Height and width of the sliding window.
    padding : int
        Number of zeros added on each side of the two trailing axes.
    stride : int
        Step between consecutive window positions.

    Returns
    -------
    ndarray, shape (C * field_height * field_width, N * HH * WW)
        Same dtype as ``x``. ``HH`` and ``WW`` are the number of window
        positions along the height and width axes. Column index is
        ``yy * WW * N + xx * N + i`` for window position (yy, xx) of
        sample ``i``.
    """
    cdef int N = x.shape[0]
    cdef int C = x.shape[1]
    cdef int H = x.shape[2]
    cdef int W = x.shape[3]

    # Floor division is required here: under language_level=3 a plain ``/``
    # between C ints is true division and yields a double, which cannot be
    # assigned to a ``cdef int``.
    cdef int HH = (H + 2 * padding - field_height) // stride + 1
    cdef int WW = (W + 2 * padding - field_width) // stride + 1

    cdef int p = padding
    cdef np.ndarray[DTYPE_t, ndim=4] x_padded = np.pad(
        x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')

    cdef np.ndarray[DTYPE_t, ndim=2] cols = np.zeros(
        (C * field_height * field_width, N * HH * WW),
        dtype=x.dtype)

    # Moving the inner loop to a C function with no bounds checking works, but
    # does not seem to help performance in any measurable way.
    im2col_cython_inner(cols, x_padded, N, C, H, W, HH, WW,
                        field_height, field_width, padding, stride)
    return cols
@cython.boundscheck(False)
cdef int im2col_cython_inner(np.ndarray[DTYPE_t, ndim=2] cols,
                             np.ndarray[DTYPE_t, ndim=4] x_padded,
                             int N, int C, int H, int W, int HH, int WW,
                             int field_height, int field_width, int padding, int stride) except? -1:
    """Fill ``cols`` in place with the windows read from ``x_padded``.

    ``cols[row, col]`` receives ``x_padded[i, c, stride*yy + ii, stride*xx + jj]``
    where ``row`` enumerates (c, ii, jj) in row-major order and
    ``col = yy * WW * N + xx * N + i``.

    ``H``, ``W`` and ``padding`` are accepted but not used by the loop.
    Bounds checking is disabled, so the caller must size both arrays
    consistently with the other arguments. Returns 0.
    """
    cdef int c, ii, jj, row, yy, xx, i, col

    for c in range(C):
        for yy in range(HH):
            for xx in range(WW):
                for ii in range(field_height):
                    for jj in range(field_width):
                        # Row-major offset inside the window: the row stride
                        # is the window *width*. Using field_height here makes
                        # rows collide whenever the window is not square.
                        row = c * field_width * field_height + ii * field_width + jj
                        for i in range(N):
                            col = yy * WW * N + xx * N + i
                            cols[row, col] = x_padded[i, c, stride * yy + ii, stride * xx + jj]
    return 0
def col2im_cython(np.ndarray[DTYPE_t, ndim=2] cols, int N, int C, int H, int W,
                  int field_height, int field_width, int padding, int stride):
    """Scatter-add the columns of ``cols`` back into an (N, C, H, W) array.

    Parameters
    ----------
    cols : ndarray, 2-D
        Column matrix; float32 or float64. Rows enumerate (channel, window
        row, window column); columns are indexed ``yy * WW * N + xx * N + i``.
    N, C, H, W : int
        Dimensions of the array to rebuild (before padding).
    field_height, field_width : int
        Height and width of the sliding window.
    padding : int
        Padding that was applied on each side of the two trailing axes.
    stride : int
        Step between consecutive window positions.

    Returns
    -------
    ndarray, shape (N, C, H, W), same dtype as ``cols``
        Positions covered by several windows hold the sum of their
        contributions. When ``padding > 0`` the result is a slice (view) of
        the padded buffer rather than a fresh array.
    """
    # Floor division is required here: under language_level=3 a plain ``/``
    # between C ints is true division and yields a double, which cannot be
    # assigned to a ``cdef int``.
    cdef int HH = (H + 2 * padding - field_height) // stride + 1
    cdef int WW = (W + 2 * padding - field_width) // stride + 1

    cdef np.ndarray[DTYPE_t, ndim=4] x_padded = np.zeros(
        (N, C, H + 2 * padding, W + 2 * padding), dtype=cols.dtype)

    # Moving the inner loop to a C-function with no bounds checking improves
    # performance quite a bit for col2im.
    col2im_cython_inner(cols, x_padded, N, C, H, W, HH, WW,
                        field_height, field_width, padding, stride)

    # Strip the padding border; with padding == 0 the slice ``0:-0`` would be
    # empty, so that case returns the buffer as is.
    if padding > 0:
        return x_padded[:, :, padding:-padding, padding:-padding]
    return x_padded
@cython.boundscheck(False)
cdef int col2im_cython_inner(np.ndarray[DTYPE_t, ndim=2] cols,
                             np.ndarray[DTYPE_t, ndim=4] x_padded,
                             int N, int C, int H, int W, int HH, int WW,
                             int field_height, int field_width, int padding, int stride) except? -1:
    """Accumulate ``cols`` into ``x_padded`` in place.

    Adds ``cols[row, col]`` to ``x_padded[i, c, stride*yy + ii, stride*xx + jj]``
    where ``row`` enumerates (c, ii, jj) in row-major order and
    ``col = yy * WW * N + xx * N + i``.

    ``H``, ``W`` and ``padding`` are accepted but not used by the loop.
    Bounds checking is disabled, so the caller must size both arrays
    consistently with the other arguments. Returns 0.
    """
    cdef int c, ii, jj, row, yy, xx, i, col

    for c in range(C):
        for ii in range(field_height):
            for jj in range(field_width):
                # Row-major offset inside the window: the row stride is the
                # window *width*. Using field_height here makes rows collide
                # whenever the window is not square.
                row = c * field_width * field_height + ii * field_width + jj
                for yy in range(HH):
                    for xx in range(WW):
                        for i in range(N):
                            col = yy * WW * N + xx * N + i
                            x_padded[i, c, stride * yy + ii, stride * xx + jj] += cols[row, col]
    return 0
@cython.boundscheck(False)
@cython.wraparound(False)
cdef col2im_6d_cython_inner(np.ndarray[DTYPE_t, ndim=6] cols,
                            np.ndarray[DTYPE_t, ndim=4] x_padded,
                            int N, int C, int H, int W, int HH, int WW,
                            int out_h, int out_w, int pad, int stride):
    """Accumulate a 6-D column tensor into ``x_padded`` in place.

    ``cols`` is indexed as [c, hh, ww, n, h, w]; each entry is added to
    ``x_padded[n, c, stride*h + hh, stride*w + ww]``.

    ``H``, ``W`` and ``pad`` are accepted but not used by the loop. Bounds
    checking and negative-index wraparound are both disabled, so the caller
    must size the arrays consistently with the other arguments.
    """
    cdef int sample, chan, win_y, win_x, out_y, out_x
    cdef int dst_row

    for sample in range(N):
        for chan in range(C):
            for win_y in range(HH):
                for win_x in range(WW):
                    for out_y in range(out_h):
                        # The destination row does not depend on out_x.
                        dst_row = stride * out_y + win_y
                        for out_x in range(out_w):
                            x_padded[sample, chan, dst_row, stride * out_x + win_x] += \
                                cols[chan, win_y, win_x, sample, out_y, out_x]
def col2im_6d_cython(np.ndarray[DTYPE_t, ndim=6] cols, int N, int C, int H, int W,
                     int HH, int WW, int pad, int stride):
    """Scatter-add a 6-D column tensor back into an (N, C, H, W) array.

    Parameters
    ----------
    cols : ndarray, 6-D, indexed as (C, HH, WW, N, out_h, out_w)
        Column tensor; float32 or float64.
    N, C, H, W : int
        Dimensions of the array to rebuild (before padding).
    HH, WW : int
        Height and width of the sliding window.
    pad : int
        Padding that was applied on each side of the two trailing axes.
    stride : int
        Step between consecutive window positions.

    Returns
    -------
    ndarray, shape (N, C, H, W), same dtype as ``cols``
        Positions covered by several windows hold the sum of their
        contributions. When ``pad > 0`` the result is a slice (view) of the
        padded buffer rather than a fresh array.
    """
    # Floor division is required here: under language_level=3 a plain ``/``
    # between C ints is true division and yields a double, which cannot be
    # assigned to a ``cdef int``.
    cdef int out_h = (H + 2 * pad - HH) // stride + 1
    cdef int out_w = (W + 2 * pad - WW) // stride + 1

    cdef np.ndarray[DTYPE_t, ndim=4] x_padded = np.zeros(
        (N, C, H + 2 * pad, W + 2 * pad), dtype=cols.dtype)

    col2im_6d_cython_inner(cols, x_padded, N, C, H, W, HH, WW, out_h, out_w, pad, stride)

    # Strip the padding border; with pad == 0 the slice ``0:-0`` would be
    # empty, so that case returns the buffer as is.
    if pad > 0:
        return x_padded[:, :, pad:-pad, pad:-pad]
    return x_padded