Skip to content

Commit

Permalink
Fix interpretation of axis value in Normalization layer, closes #357
Browse files: browse the repository at this point in the history
  • Loading branch information
Dobiasd committed Jul 22, 2022
1 parent 6009a9b commit 2c354ac
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 15 deletions.
27 changes: 15 additions & 12 deletions include/fdeep/layers/normalization_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,35 +38,38 @@ class normalization_layer : public layer

const int rank = static_cast<int>(input.shape().rank());

const auto axes = fplus::transform([&](int a) {
return a < 0 ? a + rank : rank - a;
}, axes_);

const auto transform_slice = [&](const std::size_t idx, const tensor& slice) -> tensor
{
const auto sqrt_of_variance = std::sqrt(variance_[idx]);
return transform_tensor([&](float_type x){ return (x - mean_[idx]) / sqrt_of_variance; }, slice);
};

if (axes_.empty()) {
return {transform_slice(0, input)};
assertion(variance_.size() == 1, "Invalid number of variance values in Normalization layer.");
}

assertion(axes_.size() <= 1, "Unsupported number of axes for Normalization layer. Must be 0 or 1.");
const auto axis_dim = axes_[0] == -1 ? 0 : rank - axes_[0];

const auto transform_slice_with_idx = [&](const tensors& slices) -> tensors
{
assertion(variance_.size() == slices.size(), "Invalid number of variance values in Normalization layer.");
return fplus::transform_with_idx(transform_slice, slices);
};

if (axes_.empty())
return {transform_slice(0, input)};
else if (axes[0] == 0)
if (axis_dim == 0)
return {concatenate_tensors_depth(transform_slice_with_idx(tensor_to_depth_slices(input)))};
else if (axes[0] == 1)
else if (axis_dim == 1)
return {concatenate_tensors_width(transform_slice_with_idx(tensor_to_tensors_width_slices(input)))};
else if (axes[0] == 2)
else if (axis_dim == 2)
return {concatenate_tensors_height(transform_slice_with_idx(tensor_to_tensors_height_slices(input)))};
else if (axes[0] == 3)
else if (axis_dim == 3)
return {concatenate_tensors_dim4(transform_slice_with_idx(tensor_to_tensors_dim4_slices(input)))};
else if (axes[0] == 4)
else if (axis_dim == 4)
return {concatenate_tensors_dim5(transform_slice_with_idx(tensor_to_tensors_dim5_slices(input)))};
else
raise_error("Invalid axis (" + std::to_string(axes[0]) + ") for Normalization layer");
raise_error("Invalid axis (" + std::to_string(axis_dim) + ") for Normalization layer");
return {};
}
const std::vector<int> axes_;
Expand Down
4 changes: 3 additions & 1 deletion keras_export/convert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,9 @@ def show_softmax_layer(layer):

def show_normalization_layer(layer):
"""Serialize normalization layer to dict"""
assert len(layer.axis) <= 1, "Multiple normalization axis are not supported"
assert len(layer.axis) <= 1, "Multiple normalization axes are not supported"
if len(layer.axis) == 1:
assert layer.axis[0] in (-1, 1, 2, 3, 4, 5), "Invalid axis for Normalization layer."
return {
'mean': encode_floats(layer.mean),
'variance': encode_floats(layer.variance)
Expand Down
12 changes: 10 additions & 2 deletions keras_export/generate_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def get_test_model_exhaustive():
(1, 1, 4, 1, 6),
(1, 3, 1, 5, 1),
(2, 1, 4, 1, 1),
(1, ), # 46
(3, 1),
(6, 5, 4, 3, 2),
]

inputs = [Input(shape=s) for s in input_shapes]
Expand All @@ -157,14 +160,19 @@ def get_test_model_exhaustive():
outputs.append(GlobalAveragePooling1D()(inputs[6]))
outputs.append(GlobalAveragePooling1D(data_format="channels_first")(inputs[6]))

outputs.append(Normalization(axis=None, mean=2.1, variance=2.2)(inputs[4]))
outputs.append(Normalization(axis=-1, mean=2.1, variance=2.2)(inputs[6]))
outputs.append(Normalization(axis=-1, mean=2.1, variance=2.2)(inputs[46]))
outputs.append(Normalization(axis=1, mean=2.1, variance=2.2)(inputs[46]))
outputs.append(Normalization(axis=-1, mean=2.1, variance=2.2)(inputs[47]))
outputs.append(Normalization(axis=1, mean=2.1, variance=2.2)(inputs[47]))
outputs.append(Normalization(axis=2, mean=2.1, variance=2.2)(inputs[47]))
for axis in range(1, 6):
shape = input_shapes[0][axis - 1]
outputs.append(Normalization(axis=axis,
mean=np.random.rand(shape),
variance=np.random.rand(shape)
)(inputs[0]))
outputs.append(Normalization(axis=None, mean=2.1, variance=2.2)(inputs[4]))
outputs.append(Normalization(axis=-1, mean=2.1, variance=2.2)(inputs[6]))

outputs.append(Rescaling(23.5, 42.1)(inputs[0]))

Expand Down

0 comments on commit 2c354ac

Please sign in to comment.