Skip to content

Commit 352aaae

Browse files
committed
Speed-up batch normalization by avoiding branches in the inner loop
1 parent dc53ae0 commit 352aaae

File tree

1 file changed

+38
-9
lines changed

include/fdeep/layers/batch_normalization_layer.hpp

Lines changed: 38 additions & 9 deletions

Original file line number | Diff line number | Diff line change
@@ -118,26 +118,55 @@ class batch_normalization_layer : public layer
118118
}
119119

120120
tensor output(input.shape(), 0);
121+
121122
const auto denoms = fplus::transform([this](const auto& mv)
122123
{ return std::sqrt(mv + this->epsilon_); },
123124
moving_variance_);
124-
for (std::size_t dim5 = 0; dim5 < output.shape().size_dim_5_; ++dim5)
125-
{
126-
for (std::size_t dim4 = 0; dim4 < output.shape().size_dim_4_; ++dim4)
125+
126+
if (use_gamma && use_beta) {
127+
for (std::size_t dim5 = 0; dim5 < output.shape().size_dim_5_; ++dim5)
127128
{
128-
for (std::size_t z = 0; z < output.shape().depth_; ++z)
129+
for (std::size_t dim4 = 0; dim4 < output.shape().size_dim_4_; ++dim4)
129130
{
130-
if (use_gamma && use_beta) {
131+
for (std::size_t z = 0; z < output.shape().depth_; ++z)
132+
{
131133
apply_to_channel(apply_to_value_gamma_beta, moving_mean_, beta_, gamma_, input, output, denoms[z], z, dim5, dim4);
132134
}
133-
else if (use_gamma) {
135+
}
136+
}
137+
}
138+
else if (use_gamma) {
139+
for (std::size_t dim5 = 0; dim5 < output.shape().size_dim_5_; ++dim5)
140+
{
141+
for (std::size_t dim4 = 0; dim4 < output.shape().size_dim_4_; ++dim4)
142+
{
143+
for (std::size_t z = 0; z < output.shape().depth_; ++z)
144+
{
134145
apply_to_channel(apply_to_value_gamma, moving_mean_, beta_, gamma_, input, output, denoms[z], z, dim5, dim4);
135146
}
136-
else if (use_beta) {
147+
}
148+
}
149+
}
150+
else if (use_beta) {
151+
for (std::size_t dim5 = 0; dim5 < output.shape().size_dim_5_; ++dim5)
152+
{
153+
for (std::size_t dim4 = 0; dim4 < output.shape().size_dim_4_; ++dim4)
154+
{
155+
for (std::size_t z = 0; z < output.shape().depth_; ++z)
156+
{
137157
apply_to_channel(apply_to_value_beta, moving_mean_, beta_, gamma_, input, output, denoms[z], z, dim5, dim4);
138158
}
139-
else {
140-
apply_to_channel(apply_to_value, moving_mean_, beta_, gamma_, input, output, denoms[z], z, dim5, dim4);
159+
}
160+
}
161+
}
162+
else {
163+
for (std::size_t dim5 = 0; dim5 < output.shape().size_dim_5_; ++dim5)
164+
{
165+
for (std::size_t dim4 = 0; dim4 < output.shape().size_dim_4_; ++dim4)
166+
{
167+
for (std::size_t z = 0; z < output.shape().depth_; ++z)
168+
{
169+
apply_to_channel(apply_to_value, moving_mean_, beta_, gamma_, input, output, denoms[z], z, dim5, dim4);
141170
}
142171
}
143172
}

0 commit comments

Comments (0)