@@ -118,26 +118,55 @@ class batch_normalization_layer : public layer
118
118
}
119
119
120
120
tensor output (input.shape (), 0 );
121
+
121
122
const auto denoms = fplus::transform ([this ](const auto & mv)
122
123
{ return std::sqrt (mv + this ->epsilon_ ); },
123
124
moving_variance_);
124
- for (std:: size_t dim5 = 0 ; dim5 < output. shape (). size_dim_5_ ; ++dim5)
125
- {
126
- for (std::size_t dim4 = 0 ; dim4 < output.shape ().size_dim_4_ ; ++dim4 )
125
+
126
+ if (use_gamma && use_beta) {
127
+ for (std::size_t dim5 = 0 ; dim5 < output.shape ().size_dim_5_ ; ++dim5 )
127
128
{
128
- for (std::size_t z = 0 ; z < output.shape ().depth_ ; ++z )
129
+ for (std::size_t dim4 = 0 ; dim4 < output.shape ().size_dim_4_ ; ++dim4 )
129
130
{
130
- if (use_gamma && use_beta) {
131
+ for (std::size_t z = 0 ; z < output.shape ().depth_ ; ++z)
132
+ {
131
133
apply_to_channel (apply_to_value_gamma_beta, moving_mean_, beta_, gamma_, input, output, denoms[z], z, dim5, dim4);
132
134
}
133
- else if (use_gamma) {
135
+ }
136
+ }
137
+ }
138
+ else if (use_gamma) {
139
+ for (std::size_t dim5 = 0 ; dim5 < output.shape ().size_dim_5_ ; ++dim5)
140
+ {
141
+ for (std::size_t dim4 = 0 ; dim4 < output.shape ().size_dim_4_ ; ++dim4)
142
+ {
143
+ for (std::size_t z = 0 ; z < output.shape ().depth_ ; ++z)
144
+ {
134
145
apply_to_channel (apply_to_value_gamma, moving_mean_, beta_, gamma_, input, output, denoms[z], z, dim5, dim4);
135
146
}
136
- else if (use_beta) {
147
+ }
148
+ }
149
+ }
150
+ else if (use_beta) {
151
+ for (std::size_t dim5 = 0 ; dim5 < output.shape ().size_dim_5_ ; ++dim5)
152
+ {
153
+ for (std::size_t dim4 = 0 ; dim4 < output.shape ().size_dim_4_ ; ++dim4)
154
+ {
155
+ for (std::size_t z = 0 ; z < output.shape ().depth_ ; ++z)
156
+ {
137
157
apply_to_channel (apply_to_value_beta, moving_mean_, beta_, gamma_, input, output, denoms[z], z, dim5, dim4);
138
158
}
139
- else {
140
- apply_to_channel (apply_to_value, moving_mean_, beta_, gamma_, input, output, denoms[z], z, dim5, dim4);
159
+ }
160
+ }
161
+ }
162
+ else {
163
+ for (std::size_t dim5 = 0 ; dim5 < output.shape ().size_dim_5_ ; ++dim5)
164
+ {
165
+ for (std::size_t dim4 = 0 ; dim4 < output.shape ().size_dim_4_ ; ++dim4)
166
+ {
167
+ for (std::size_t z = 0 ; z < output.shape ().depth_ ; ++z)
168
+ {
169
+ apply_to_channel (apply_to_value, moving_mean_, beta_, gamma_, input, output, denoms[z], z, dim5, dim4);
141
170
}
142
171
}
143
172
}
0 commit comments