Add support for batch normalization on arbitrary axes
Dobiasd committed Dec 17, 2019
1 parent a085aa3 commit 61b67cc
Showing 4 changed files with 86 additions and 17 deletions.
7 changes: 5 additions & 2 deletions include/fdeep/import_model.hpp
@@ -452,13 +452,16 @@ inline layer_ptr create_batch_normalization_layer(const get_param_f& get_param,
         decode_floats(get_param(name, "moving_variance"));
     const bool center = data["config"]["center"];
     const bool scale = data["config"]["scale"];
+    const auto axis_vec = create_vector<int>(create_int, data["config"]["axis"]);
+    assertion(axis_vec.size() == 1, "invalid axis configuration");
+    const int axis = axis_vec.front();
     const float_type epsilon = data["config"]["epsilon"];
     float_vec gamma;
     float_vec beta;
     if (scale) gamma = decode_floats(get_param(name, "gamma"));
     if (center) beta = decode_floats(get_param(name, "beta"));
     return std::make_shared<batch_normalization_layer>(
-        name, moving_mean, moving_variance, beta, gamma, epsilon);
+        name, axis, moving_mean, moving_variance, beta, gamma, epsilon);
 }

 inline layer_ptr create_identity_layer(
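The importer now expects the serialized layer config to carry the normalization axis. A quick way to inspect that field is to dump a model's JSON from Keras; a minimal sketch (import paths and the exact serialized form vary with the Keras/TensorFlow version, so the printed value is only illustrative):

import json
from tensorflow.keras.layers import Input, BatchNormalization
from tensorflow.keras.models import Model

inp = Input(shape=(5, 6, 3))
model = Model(inp, BatchNormalization(axis=2)(inp))

# The last entry in "layers" is the BatchNormalization layer.
bn_config = json.loads(model.to_json())["config"]["layers"][-1]["config"]
# Depending on the Keras version, "axis" is either a plain int or a
# one-element list such as [2]; create_batch_normalization_layer above
# asserts exactly one entry and takes its front value.
print(bn_config["axis"])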
@@ -998,7 +1001,7 @@ inline layer_ptr create_bidirectional_layer(const get_param_f& get_param,
);
const bool return_sequences = json_object_get(layer_config, "return_sequences", false);
const bool stateful = json_object_get(layer_config, "stateful", false);

return std::make_shared<bidirectional_layer>(name, merge_mode, units, unit_activation,
recurrent_activation, wrapped_layer_type,
use_bias, reset_after, return_sequences, stateful,
74 changes: 61 additions & 13 deletions include/fdeep/layers/batch_normalization_layer.hpp
@@ -19,12 +19,14 @@ class batch_normalization_layer : public layer
 {
 public:
     explicit batch_normalization_layer(const std::string& name,
+        int axis,
         const float_vec& moving_mean,
         const float_vec& moving_variance,
         const float_vec& beta,
         const float_vec& gamma,
         float_type epsilon)
         : layer(name),
+        axis_(axis),
         moving_mean_(moving_mean),
         moving_variance_(moving_variance),
         beta_(beta),
@@ -33,6 +35,7 @@ class batch_normalization_layer : public layer
     {
     }
 protected:
+    int axis_;
     float_vec moving_mean_;
     float_vec moving_variance_;
     float_vec beta_;
@@ -59,21 +62,27 @@ class batch_normalization_layer : public layer
         }

         tensor5 output(input.shape(), 0);
-        for (std::size_t z = 0; z < output.shape().depth_; ++z)
+        for (std::size_t dim5 = 0; dim5 < output.shape().size_dim_5_; ++dim5)
         {
-            const float_type denom = std::sqrt(moving_variance_[z] + epsilon_);
-            for (std::size_t y = 0; y < output.shape().height_; ++y)
+            for (std::size_t dim4 = 0; dim4 < output.shape().size_dim_4_; ++dim4)
             {
-                for (std::size_t x = 0; x < output.shape().width_; ++x)
+                for (std::size_t z = 0; z < output.shape().depth_; ++z)
                 {
-                    float_type val = input.get(0, 0, y, x, z);
-                    val -= moving_mean_[z];
-                    if (use_gamma)
-                        val *= gamma_[z];
-                    val /= denom;
-                    if (use_beta)
-                        val += beta_[z];
-                    output.set(0, 0, y, x, z, val);
+                    const float_type denom = std::sqrt(moving_variance_[z] + epsilon_);
+                    for (std::size_t y = 0; y < output.shape().height_; ++y)
+                    {
+                        for (std::size_t x = 0; x < output.shape().width_; ++x)
+                        {
+                            float_type val = input.get(dim5, dim4, y, x, z);
+                            val -= moving_mean_[z];
+                            if (use_gamma)
+                                val *= gamma_[z];
+                            val /= denom;
+                            if (use_beta)
+                                val += beta_[z];
+                            output.set(dim5, dim4, y, x, z, val);
+                        }
+                    }
                 }
             }
         }
@@ -84,7 +93,46 @@ class batch_normalization_layer
     {
         assertion(inputs.size() == 1, "invalid number of tensors");
         const auto& input = inputs.front();
-        return {apply_to_slices(input)};
+        const int adjusted_axis =
+            axis_ == -1
+            ? 5
+            : 5 + axis_ - static_cast<int>(input.shape().rank());
+
+        if (adjusted_axis == 5)
+        {
+            return {apply_to_slices(input)};
+        }
+        else if (adjusted_axis == 4)
+        {
+            return {permute_tensor5(apply_to_slices(permute_tensor5(input,
+                {1, 2, 3, 5, 4})),
+                {1, 2, 3, 5, 4})};
+        }
+        else if (adjusted_axis == 3)
+        {
+            return {permute_tensor5(apply_to_slices(permute_tensor5(input,
+                {1, 2, 5, 4, 3})),
+                {1, 2, 5, 4, 3})};
+        }
+        else if (adjusted_axis == 2)
+        {
+            return {permute_tensor5(apply_to_slices(permute_tensor5(input,
+                {1, 5, 3, 4, 2})),
+                {1, 5, 3, 4, 2})};
+        }
+        else if (adjusted_axis == 1)
+        {
+            return {permute_tensor5(apply_to_slices(permute_tensor5(input,
+                {5, 2, 3, 4, 1})),
+                {5, 2, 3, 4, 1})};
+        }
+        else {
+            raise_error("Invalid axis for batch normalization.");
+            // Just to make the compiler happy.
+            // In reality, this is never called.
+            return inputs;
+        }
+
     }
 };

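The per-channel arithmetic is unchanged: y = (x - mean) * gamma / sqrt(variance + epsilon) + beta along the normalized axis. What the new apply_impl adds is the mapping of the Keras axis onto fdeep's fixed five-dimensional tensor layout; for any axis other than the last it permutes that axis to the back, reuses the existing last-axis code path (apply_to_slices), and permutes the result back. Assuming input.shape().rank() excludes the batch dimension while the Keras axis includes it, an input of shape (batch, height, width, channels) has rank 3, so axis=3 maps to adjusted_axis = 5 + 3 - 3 = 5 (the depth axis, no permutation needed) and axis=1 maps to adjusted_axis = 3, which goes through the {1, 2, 5, 4, 3} permutation. A rough NumPy sketch of the same permute / normalize / permute-back idea:

import numpy as np

def batch_norm_on_axis(x, mean, variance, beta, gamma, epsilon, axis=-1):
    # Move the normalized axis to the last position, apply the per-channel
    # transform there (matching the order of operations in apply_to_slices),
    # then restore the original axis order.
    mean, variance, beta, gamma = map(np.asarray, (mean, variance, beta, gamma))
    axis = axis % x.ndim
    xp = np.moveaxis(x, axis, -1)
    yp = (xp - mean) * gamma / np.sqrt(variance + epsilon) + beta
    return np.moveaxis(yp, -1, axis)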
2 changes: 0 additions & 2 deletions keras_export/convert_model.py
@@ -354,8 +354,6 @@ def show_batch_normalization_layer(layer):
     else:
         assert len(layer.axis) == 1
         layer_axis = layer.axis[0]
-    assert layer_axis == -1 or layer_axis + 1 == len(layer.input_shape), \
-        "BatchNormalization only supported on the last tensor axis"
     moving_mean = K.get_value(layer.moving_mean)
     moving_variance = K.get_value(layer.moving_variance)
     result = {}
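With the last-axis assertion removed, models that normalize over other axes can be exported like any other. A minimal hypothetical round-trip (file names are placeholders, and the conversion call assumes the usual two-argument convert_model.py invocation):

from tensorflow.keras.layers import Input, BatchNormalization
from tensorflow.keras.models import Model

inp = Input(shape=(8, 8, 3))
out = BatchNormalization(axis=1)(inp)  # normalize over height rather than channels
Model(inp, out).save("bn_axis1.h5")

# afterwards, on the command line:
#   python3 keras_export/convert_model.py bn_axis1.h5 bn_axis1.json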
20 changes: 20 additions & 0 deletions keras_export/generate_test_models.py
@@ -152,7 +152,27 @@ def get_test_model_exhaustive():
     outputs.append(GlobalMaxPooling2D()(inputs[4]))
     outputs.append(GlobalMaxPooling2D(data_format="channels_first")(inputs[4]))

+    outputs.append(BatchNormalization()(inputs[0]))
+    outputs.append(BatchNormalization(axis=1)(inputs[0]))
+    outputs.append(BatchNormalization(axis=2)(inputs[0]))
+    outputs.append(BatchNormalization(axis=3)(inputs[0]))
+    outputs.append(BatchNormalization(axis=4)(inputs[0]))
+    outputs.append(BatchNormalization(axis=5)(inputs[0]))
+    outputs.append(BatchNormalization()(inputs[2]))
+    outputs.append(BatchNormalization(axis=1)(inputs[2]))
+    outputs.append(BatchNormalization(axis=2)(inputs[2]))
+    outputs.append(BatchNormalization(axis=3)(inputs[2]))
+    outputs.append(BatchNormalization(axis=4)(inputs[2]))
+    outputs.append(BatchNormalization()(inputs[4]))
+    #outputs.append(BatchNormalization(axis=1)(inputs[4])) # tensorflow.python.framework.errors_impl.InternalError: The CPU implementation of FusedBatchNorm only supports NHWC tensor format for now.
+    outputs.append(BatchNormalization(axis=2)(inputs[4]))
+    outputs.append(BatchNormalization(axis=3)(inputs[4]))
+    outputs.append(BatchNormalization()(inputs[6]))
+    outputs.append(BatchNormalization(axis=1)(inputs[6]))
+    outputs.append(BatchNormalization(axis=2)(inputs[6]))
+    outputs.append(BatchNormalization()(inputs[8]))
+    outputs.append(BatchNormalization(axis=1)(inputs[8]))

     outputs.append(Dropout(0.5)(inputs[4]))

     outputs.append(ZeroPadding2D(2)(inputs[4]))
