// conv.cpp
#include <cstddef>
#include <functional>
#include <sstream>
#include <stdexcept>

#include "gradstudent/iter.h"
#include "gradstudent/ops.h"

namespace gs {

/* SLIDING WINDOW HELPER FUNCTIONS */

// Applies `transform` to a window of `input` for every index of `result`,
// where `windowFn` maps each result index to its corresponding window.
void slidingWindowTransform(
Tensor &result, const Tensor &input,
const std::function<Tensor(const Tensor &, const array_t &)> &windowFn,
const std::function<double(const Tensor &)> &transform) {
for (auto [resIdx, res] : ITensorIter(result)) {
Tensor window(windowFn(input, resIdx));
res = transform(window);
}
}
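
// Illustrative note (not part of the library): with a windowFn that returns
// the fixed-size subtensor anchored at resIdx and transform = sum, every
// element of `result` becomes the sum over its window, i.e. an unnormalized
// box filter.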

// Unit-stride variant: the window anchored at each result index spans
// [resIdx, resIdx + windowShape), so adjacent windows overlap.
void slidingWindowTransformNoStride(
    Tensor &result, const Tensor &input, const array_t &windowShape,
    const std::function<double(const Tensor &)> &transform) {
slidingWindowTransform(
result, input,
[&](const Tensor &input, const array_t &resIdx) {
return truncate(input, resIdx, resIdx + windowShape);
},
transform);
}

// Full-stride variant: the stride equals the window shape, so the windows
// tile the input without overlapping.
void slidingWindowTransformFullStride(
    Tensor &result, const Tensor &input, const array_t &windowShape,
    const std::function<double(const Tensor &)> &transform) {
slidingWindowTransform(
result, input,
[&](const Tensor &input, const array_t &resIdx) {
return truncate(input, resIdx * windowShape,
resIdx * windowShape + windowShape);
},
transform);
}
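
// Worked shape example (illustrative): for input of shape (4, 4) and
// windowShape (2, 2), result has shape (2, 2), and result[i, j] applies
// transform to the non-overlapping window covering rows 2i..2i+1 and
// columns 2j..2j+1.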

/* CONVOLUTION OVER ALL DIMENSIONS */

// Valid (no-padding) convolution of `input` with a single kernel over the
// first n dimensions; as in most ML libraries, the kernel is not flipped,
// so this is strictly a cross-correlation.
Tensor singleConv(const Tensor &input, const Tensor &kernel, size_t n) {
  array_t kernelShape = kernel.shape().sliceTo(n);
  array_t resultShape = input.shape().sliceTo(n) - kernelShape + 1;
  Tensor result(resultShape);
  slidingWindowTransformNoStride(
      result, input, kernelShape,
[&](const Tensor &window) { return sum(window * kernel); });
return result;
}
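
// Worked shape example (illustrative): input (5, 5), kernel (3, 3), n = 2
// gives result shape (5, 5) - (3, 3) + 1 = (3, 3), where result[i, j] is
// the sum of the elementwise product of the kernel with the 3x3 window of
// the input anchored at (i, j).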

// Convolves `input` with a stack of kernels: dimension 0 of `kernel`
// indexes independent kernels, and dimension 0 of the result indexes the
// corresponding output channels.
Tensor multiConv(const Tensor &input, const Tensor &kernel, size_t n) {
auto singleKernelShape = kernel.shape().slice(1, 1 + n);
array_t singleResultShape = input.shape().sliceTo(n) - singleKernelShape + 1;
auto resultShape = array_t{kernel.shape()[0]} | singleResultShape;
Tensor result(resultShape);
for (size_t i = 0; i < kernel.shape()[0]; ++i) {
slice(result, {i}) = singleConv(input, slice(kernel, array_t{i}), n);
}
return result;
}
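
// Worked shape example (illustrative): input (28, 28), kernel (8, 3, 3),
// n = 2: each of the 8 kernel slices is convolved with the input
// independently, yielding a result of shape (8, 26, 26).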

// Public entry point: validates ranks, then dispatches to singleConv when
// the kernel and input ranks match, or to multiConv when the kernel has one
// extra leading dimension; n = 0 convolves over all input dimensions.
Tensor conv(const Tensor &input, const Tensor &kernel, size_t n) {
if (n > input.ndims()) {
std::stringstream ss;
ss << "Convolution rank " << n << " exceeds input rank " << input.ndims();
throw std::invalid_argument(ss.str());
}
  if (kernel.ndims() < input.ndims()) {
    std::stringstream ss;
    ss << "Kernel rank must be at least the input rank, got kernel rank "
       << kernel.ndims() << " and input rank " << input.ndims();
    throw std::invalid_argument(ss.str());
  }
if (kernel.ndims() > 1 + input.ndims()) {
std::stringstream ss;
ss << "Kernel rank should not exceed input rank by more than 1. Got kernel "
"rank "
<< kernel.ndims() << " and input rank " << input.ndims();
throw std::invalid_argument(ss.str());
}
n = n > 0 ? n : input.ndims();
if (kernel.ndims() == input.ndims()) {
return singleConv(input, kernel, n);
}
return multiConv(input, kernel, n);
}
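
// Dispatch examples (illustrative): input (28, 28) with kernel (3, 3) uses
// singleConv (equal ranks); input (28, 28) with kernel (8, 3, 3) uses
// multiConv (one extra leading kernel dimension); kernel (8, 4, 3, 3) with
// input (28, 28) is rejected.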

/* MAX POOLING */

// Max pooling over all dimensions of `input`; poolShape must divide the
// input shape elementwise.
Tensor singleMaxPool(const Tensor &input, const array_t &poolShape) {
  array_t resultShape;
  try {
    resultShape = input.shape() / poolShape;
  } catch (const std::invalid_argument &) {
    std::stringstream ss;
    ss << "Pool shape " << poolShape << " does not divide input shape "
       << input.shape();
    throw std::invalid_argument(ss.str());
  }
  Tensor result(resultShape);
slidingWindowTransformFullStride(
result, input, poolShape,
[](const Tensor &window) { return max(window); });
return result;
}
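
// Worked shape example (illustrative): input (4, 6) with poolShape (2, 3)
// yields result shape (2, 2); each result element is the maximum over a
// non-overlapping 2x3 block of the input.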

// Public entry point: pools over all dimensions when the pool rank matches
// the input rank; when the input has one extra leading dimension, that
// dimension is treated as a batch/channel axis and pooled slice by slice.
Tensor maxPool(const Tensor &input, const array_t &poolShape) {
  if (input.ndims() > poolShape.size() + 1) {
    std::stringstream ss;
    ss << "Input rank must be at most one greater than pool rank, got input "
          "rank "
       << input.ndims() << " and pool rank " << poolShape.size();
    throw std::invalid_argument(ss.str());
  }
if (poolShape.size() == input.ndims()) {
return singleMaxPool(input, poolShape);
}
array_t inputSliceShape = input.shape().sliceFrom(1);
array_t resultSliceShape;
try {
resultSliceShape = inputSliceShape / poolShape;
  } catch (const std::invalid_argument &) {
std::stringstream ss;
ss << "Pool shape " << poolShape << " does not divide input shape "
<< input.shape();
throw std::invalid_argument(ss.str());
}
  array_t resultShape = array_t{input.shape()[0]} | resultSliceShape;
  Tensor result(resultShape);
for (size_t i = 0; i < input.shape()[0]; ++i) {
slice(result, {i}) = singleMaxPool(slice(input, {i}), poolShape);
}
return result;
}
} // namespace gs
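
// Usage sketch (illustrative, not part of the library): a hypothetical
// example guarded by CONV_EXAMPLE_MAIN. It assumes array_t lives in
// namespace gs, as its unqualified use above suggests, and relies only on
// the Tensor(array_t) constructor already used in this file.
#ifdef CONV_EXAMPLE_MAIN
int main() {
  gs::Tensor image(gs::array_t{28, 28});     // single-channel input
  gs::Tensor filters(gs::array_t{8, 3, 3});  // 8 kernels of shape (3, 3)
  // n = 0 convolves over all input dimensions -> feat has shape (8, 26, 26)
  gs::Tensor feat = gs::conv(image, filters, 0);
  // Pool rank 2 against input rank 3: dimension 0 is the channel axis, so
  // each (26, 26) slice is pooled independently -> pooled: (8, 13, 13)
  gs::Tensor pooled = gs::maxPool(feat, gs::array_t{2, 2});
}
#endif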