Pre-calculate denominators in batch normalization for performance
Dobiasd committed Oct 23, 2023
1 parent 44db357 commit dc53ae0
Showing 1 changed file with 7 additions and 5 deletions.
include/fdeep/layers/batch_normalization_layer.hpp: 12 changes (7 additions, 5 deletions)
@@ -118,24 +118,26 @@ class batch_normalization_layer : public layer
         }

         tensor output(input.shape(), 0);
+        const auto denoms = fplus::transform([this](const auto& mv)
+            { return std::sqrt(mv + this->epsilon_); },
+            moving_variance_);
         for (std::size_t dim5 = 0; dim5 < output.shape().size_dim_5_; ++dim5)
         {
             for (std::size_t dim4 = 0; dim4 < output.shape().size_dim_4_; ++dim4)
             {
                 for (std::size_t z = 0; z < output.shape().depth_; ++z)
                 {
-                    const float_type denom = std::sqrt(moving_variance_[z] + epsilon_);
                     if (use_gamma && use_beta) {
-                        apply_to_channel(apply_to_value_gamma_beta, moving_mean_, beta_, gamma_, input, output, denom, z, dim5, dim4);
+                        apply_to_channel(apply_to_value_gamma_beta, moving_mean_, beta_, gamma_, input, output, denoms[z], z, dim5, dim4);
                     }
                     else if (use_gamma) {
-                        apply_to_channel(apply_to_value_gamma, moving_mean_, beta_, gamma_, input, output, denom, z, dim5, dim4);
+                        apply_to_channel(apply_to_value_gamma, moving_mean_, beta_, gamma_, input, output, denoms[z], z, dim5, dim4);
                     }
                     else if (use_beta) {
-                        apply_to_channel(apply_to_value_beta, moving_mean_, beta_, gamma_, input, output, denom, z, dim5, dim4);
+                        apply_to_channel(apply_to_value_beta, moving_mean_, beta_, gamma_, input, output, denoms[z], z, dim5, dim4);
                     }
                     else {
-                        apply_to_channel(apply_to_value, moving_mean_, beta_, gamma_, input, output, denom, z, dim5, dim4);
+                        apply_to_channel(apply_to_value, moving_mean_, beta_, gamma_, input, output, denoms[z], z, dim5, dim4);
                     }
                 }
             }
         }
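
The change hoists the per-channel denominator, std::sqrt(moving_variance_[z] + epsilon_), out of the nested loops: previously the square root was recomputed for every (dim5, dim4, z) combination, even though it only depends on the channel index z. After the change, the denominators are computed once per channel via fplus::transform and looked up as denoms[z]. Below is a minimal standalone sketch of the same idea, assuming a flat channels-last buffer; it uses std::transform instead of FunctionalPlus, and the names (normalize_naive, normalize_precomputed, spatial_size) are illustrative only, not part of fdeep's API.

// Minimal standalone sketch (not fdeep code): pre-calculate the per-channel
// denominator sqrt(variance + epsilon) instead of recomputing it inside the loops.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

using float_type = float;

// Baseline: recomputes the denominator for every (position, channel) pair.
void normalize_naive(std::vector<float_type>& values,
                     const std::vector<float_type>& moving_mean,
                     const std::vector<float_type>& moving_variance,
                     float_type epsilon,
                     std::size_t spatial_size)
{
    const std::size_t channels = moving_mean.size();
    for (std::size_t s = 0; s < spatial_size; ++s)
        for (std::size_t z = 0; z < channels; ++z)
        {
            const float_type denom = std::sqrt(moving_variance[z] + epsilon);
            values[s * channels + z] = (values[s * channels + z] - moving_mean[z]) / denom;
        }
}

// Optimized: one sqrt per channel, looked up inside the loops.
void normalize_precomputed(std::vector<float_type>& values,
                           const std::vector<float_type>& moving_mean,
                           const std::vector<float_type>& moving_variance,
                           float_type epsilon,
                           std::size_t spatial_size)
{
    const std::size_t channels = moving_mean.size();
    std::vector<float_type> denoms(channels);
    std::transform(moving_variance.begin(), moving_variance.end(), denoms.begin(),
                   [epsilon](float_type v) { return std::sqrt(v + epsilon); });
    for (std::size_t s = 0; s < spatial_size; ++s)
        for (std::size_t z = 0; z < channels; ++z)
            values[s * channels + z] = (values[s * channels + z] - moving_mean[z]) / denoms[z];
}

int main()
{
    std::vector<float_type> a { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f }; // 3 positions, 2 channels
    std::vector<float_type> b = a;
    const std::vector<float_type> mean { 0.5f, 1.0f };
    const std::vector<float_type> variance { 0.25f, 4.0f };
    normalize_naive(a, mean, variance, 0.001f, 3);
    normalize_precomputed(b, mean, variance, 0.001f, 3);
    std::cout << (a == b ? "identical results\n" : "results differ\n");
}

Both functions evaluate the same expression per element, so the results are bit-identical; the precomputed variant trades one small per-call allocation for avoiding spatial_size - 1 redundant sqrt calls per channel.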
