diff --git a/include/fdeep/import_model.hpp b/include/fdeep/import_model.hpp
index c71f4568..a759348f 100644
--- a/include/fdeep/import_model.hpp
+++ b/include/fdeep/import_model.hpp
@@ -452,13 +452,16 @@ inline layer_ptr create_batch_normalization_layer(const get_param_f& get_param,
         decode_floats(get_param(name, "moving_variance"));
     const bool center = data["config"]["center"];
     const bool scale = data["config"]["scale"];
+    const auto axis_vec = create_vector<int>(create_int, data["config"]["axis"]);
+    assertion(axis_vec.size() == 1, "invalid axis configuration");
+    const int axis = axis_vec.front();
     const float_type epsilon = data["config"]["epsilon"];
     float_vec gamma;
     float_vec beta;
     if (scale) gamma = decode_floats(get_param(name, "gamma"));
     if (center) beta = decode_floats(get_param(name, "beta"));
     return std::make_shared<batch_normalization_layer>(
-        name, moving_mean, moving_variance, beta, gamma, epsilon);
+        name, axis, moving_mean, moving_variance, beta, gamma, epsilon);
 }
 
 inline layer_ptr create_identity_layer(
@@ -998,7 +1001,7 @@ inline layer_ptr create_bidirectional_layer(const get_param_f& get_param,
         );
     const bool return_sequences = json_object_get(layer_config, "return_sequences", false);
     const bool stateful = json_object_get(layer_config, "stateful", false);
-    
+
     return std::make_shared<bidirectional_layer>(name, merge_mode, units, unit_activation,
         recurrent_activation, wrapped_layer_type, use_bias, reset_after, return_sequences, stateful,
diff --git a/include/fdeep/layers/batch_normalization_layer.hpp b/include/fdeep/layers/batch_normalization_layer.hpp
index 37e5b052..486bb06e 100644
--- a/include/fdeep/layers/batch_normalization_layer.hpp
+++ b/include/fdeep/layers/batch_normalization_layer.hpp
@@ -19,12 +19,14 @@ class batch_normalization_layer : public layer
 {
 public:
     explicit batch_normalization_layer(const std::string& name,
+        int axis,
         const float_vec& moving_mean,
         const float_vec& moving_variance,
         const float_vec& beta,
         const float_vec& gamma,
         float_type epsilon)
         : layer(name),
+        axis_(axis),
         moving_mean_(moving_mean),
         moving_variance_(moving_variance),
         beta_(beta),
@@ -33,6 +35,7 @@ class batch_normalization_layer : public layer
     {
     }
 protected:
+    int axis_;
     float_vec moving_mean_;
     float_vec moving_variance_;
     float_vec beta_;
@@ -59,21 +62,27 @@ class batch_normalization_layer : public layer
         }
         tensor5 output(input.shape(), 0);
-        for (std::size_t z = 0; z < output.shape().depth_; ++z)
+        for (std::size_t dim5 = 0; dim5 < output.shape().size_dim_5_; ++dim5)
         {
-            const float_type denom = std::sqrt(moving_variance_[z] + epsilon_);
-            for (std::size_t y = 0; y < output.shape().height_; ++y)
+            for (std::size_t dim4 = 0; dim4 < output.shape().size_dim_4_; ++dim4)
             {
-                for (std::size_t x = 0; x < output.shape().width_; ++x)
+                for (std::size_t z = 0; z < output.shape().depth_; ++z)
                 {
-                    float_type val = input.get(0, 0, y, x, z);
-                    val -= moving_mean_[z];
-                    if (use_gamma)
-                        val *= gamma_[z];
-                    val /= denom;
-                    if (use_beta)
-                        val += beta_[z];
-                    output.set(0, 0, y, x, z, val);
+                    const float_type denom = std::sqrt(moving_variance_[z] + epsilon_);
+                    for (std::size_t y = 0; y < output.shape().height_; ++y)
+                    {
+                        for (std::size_t x = 0; x < output.shape().width_; ++x)
+                        {
+                            float_type val = input.get(dim5, dim4, y, x, z);
+                            val -= moving_mean_[z];
+                            if (use_gamma)
+                                val *= gamma_[z];
+                            val /= denom;
+                            if (use_beta)
+                                val += beta_[z];
+                            output.set(dim5, dim4, y, x, z, val);
+                        }
+                    }
                 }
             }
         }
         return output;
@@ -84,7 +93,46 @@ class batch_normalization_layer : public layer
     {
         assertion(inputs.size() == 1, "invalid number of tensors");
         const auto& input = inputs.front();
-        return {apply_to_slices(input)};
+        const int adjusted_axis =
+            axis_ == -1
+                ? 5
+                : 5 + axis_ - static_cast<int>(input.shape().rank());
+
+        if (adjusted_axis == 5)
+        {
+            return {apply_to_slices(input)};
+        }
+        else if (adjusted_axis == 4)
+        {
+            return {permute_tensor5(apply_to_slices(permute_tensor5(input,
+                {1, 2, 3, 5, 4})),
+                {1, 2, 3, 5, 4})};
+        }
+        else if (adjusted_axis == 3)
+        {
+            return {permute_tensor5(apply_to_slices(permute_tensor5(input,
+                {1, 2, 5, 4, 3})),
+                {1, 2, 5, 4, 3})};
+        }
+        else if (adjusted_axis == 2)
+        {
+            return {permute_tensor5(apply_to_slices(permute_tensor5(input,
+                {1, 5, 3, 4, 2})),
+                {1, 5, 3, 4, 2})};
+        }
+        else if (adjusted_axis == 1)
+        {
+            return {permute_tensor5(apply_to_slices(permute_tensor5(input,
+                {5, 2, 3, 4, 1})),
+                {5, 2, 3, 4, 1})};
+        }
+        else
+        {
+            raise_error("Invalid axis for batch normalization.");
+            // Just to make the compiler happy.
+            // In reality, this is never called.
+            return inputs;
+        }
     }
 };
diff --git a/keras_export/convert_model.py b/keras_export/convert_model.py
index 5a7db5d7..9da5fb2e 100755
--- a/keras_export/convert_model.py
+++ b/keras_export/convert_model.py
@@ -354,8 +354,6 @@ def show_batch_normalization_layer(layer):
     else:
         assert len(layer.axis) == 1
         layer_axis = layer.axis[0]
-    assert layer_axis == -1 or layer_axis + 1 == len(layer.input_shape), \
-        "BatchNormalization only supported on the last tensor axis"
     moving_mean = K.get_value(layer.moving_mean)
     moving_variance = K.get_value(layer.moving_variance)
     result = {}
diff --git a/keras_export/generate_test_models.py b/keras_export/generate_test_models.py
index 3784d830..e2eba789 100644
--- a/keras_export/generate_test_models.py
+++ b/keras_export/generate_test_models.py
@@ -152,7 +152,27 @@ def get_test_model_exhaustive():
     outputs.append(GlobalMaxPooling2D()(inputs[4]))
     outputs.append(GlobalMaxPooling2D(data_format="channels_first")(inputs[4]))
 
+    outputs.append(BatchNormalization()(inputs[0]))
+    outputs.append(BatchNormalization(axis=1)(inputs[0]))
+    outputs.append(BatchNormalization(axis=2)(inputs[0]))
+    outputs.append(BatchNormalization(axis=3)(inputs[0]))
+    outputs.append(BatchNormalization(axis=4)(inputs[0]))
+    outputs.append(BatchNormalization(axis=5)(inputs[0]))
+    outputs.append(BatchNormalization()(inputs[2]))
+    outputs.append(BatchNormalization(axis=1)(inputs[2]))
+    outputs.append(BatchNormalization(axis=2)(inputs[2]))
+    outputs.append(BatchNormalization(axis=3)(inputs[2]))
+    outputs.append(BatchNormalization(axis=4)(inputs[2]))
     outputs.append(BatchNormalization()(inputs[4]))
+    #outputs.append(BatchNormalization(axis=1)(inputs[4]))  # tensorflow.python.framework.errors_impl.InternalError: The CPU implementation of FusedBatchNorm only supports NHWC tensor format for now.
+    outputs.append(BatchNormalization(axis=2)(inputs[4]))
+    outputs.append(BatchNormalization(axis=3)(inputs[4]))
+    outputs.append(BatchNormalization()(inputs[6]))
+    outputs.append(BatchNormalization(axis=1)(inputs[6]))
+    outputs.append(BatchNormalization(axis=2)(inputs[6]))
+    outputs.append(BatchNormalization()(inputs[8]))
+    outputs.append(BatchNormalization(axis=1)(inputs[8]))
+
     outputs.append(Dropout(0.5)(inputs[4]))
     outputs.append(ZeroPadding2D(2)(inputs[4]))
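
Editor's note (not part of the patch): the C++ change supports a non-default BatchNormalization axis by permuting the tensor so the requested axis becomes the innermost (depth) dimension, running the existing per-channel normalization in apply_to_slices, and permuting the result back; adjusted_axis = 5 + axis_ - rank maps the Keras axis onto fdeep's rank-5 tensor layout. A minimal NumPy sketch of that permute/normalize/permute-back idea follows; the helper names (normalize_last_axis, batch_norm_over_axis) are illustrative only and not part of fdeep's API.

    import numpy as np

    def normalize_last_axis(x, mean, var, gamma, beta, eps=1e-3):
        # Per-channel normalization over the last axis, analogous to apply_to_slices.
        return gamma * (x - mean) / np.sqrt(var + eps) + beta

    def batch_norm_over_axis(x, axis, mean, var, gamma, beta, eps=1e-3):
        # Move the requested axis to the end, normalize, then move it back --
        # the same trick as the permute_tensor5 calls in the patch above.
        perm = [a for a in range(x.ndim) if a != axis] + [axis]
        inv = np.argsort(perm)
        y = normalize_last_axis(x.transpose(perm), mean, var, gamma, beta, eps)
        return y.transpose(inv)

    # Tiny self-check against normalizing axis 1 directly.
    x = np.random.rand(2, 3, 4).astype(np.float32)
    mean, var = x.mean(axis=(0, 2)), x.var(axis=(0, 2))
    gamma, beta = np.ones(3, np.float32), np.zeros(3, np.float32)
    direct = (x - mean[None, :, None]) / np.sqrt(var[None, :, None] + 1e-3)
    via_permute = batch_norm_over_axis(x, 1, mean, var, gamma, beta)
    assert np.allclose(direct, via_permute, atol=1e-5)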