From d2b5758f2756f7fbb972bb59e73a23f77e077447 Mon Sep 17 00:00:00 2001 From: Tobias Hermann Date: Sat, 3 Jul 2021 16:51:45 +0200 Subject: [PATCH] Fix invalid-beta error when using BatchNormalization layer inside TimeDistribution * Add test case for BatchNormalization as inner layer of TimeDistributed * Fix rank of slices passed to inner layer of TimeDistribution --- .../fdeep/layers/time_distributed_layer.hpp | 5 +++++ include/fdeep/tensor.hpp | 4 ++++ include/fdeep/tensor_shape.hpp | 18 ++++++++++++++++++ keras_export/generate_test_models.py | 1 + 4 files changed, 28 insertions(+) diff --git a/include/fdeep/layers/time_distributed_layer.hpp b/include/fdeep/layers/time_distributed_layer.hpp index 745fb8a1..921793ac 100644 --- a/include/fdeep/layers/time_distributed_layer.hpp +++ b/include/fdeep/layers/time_distributed_layer.hpp @@ -67,6 +67,11 @@ class time_distributed_layer : public layer else raise_error("invalid input dim for TimeDistributed"); + for (auto& slice: slices) + { + slice.shrink_rank(); + } + if (td_output_len_ == 2) concat_axis = 2; else if (td_output_len_ == 3) diff --git a/include/fdeep/tensor.hpp b/include/fdeep/tensor.hpp index 7dc8eb88..a22545bc 100644 --- a/include/fdeep/tensor.hpp +++ b/include/fdeep/tensor.hpp @@ -123,6 +123,10 @@ class tensor { return shape_; } + void shrink_rank() + { + shape_.shrink_rank(); + } std::size_t depth() const { return shape().depth_; diff --git a/include/fdeep/tensor_shape.hpp b/include/fdeep/tensor_shape.hpp index 2e7e3172..81f517c0 100644 --- a/include/fdeep/tensor_shape.hpp +++ b/include/fdeep/tensor_shape.hpp @@ -122,6 +122,24 @@ class tensor_shape return rank_; } + std::size_t minimal_rank() const + { + if (size_dim_5_ > 1) + return 5; + if (size_dim_4_ > 1) + return 4; + if (height_ > 1) + return 3; + if (width_ > 1) + return 2; + return 1; + } + + void shrink_rank() + { + rank_ = minimal_rank(); + } + std::vector dimensions() const { if (rank() == 5) diff --git a/keras_export/generate_test_models.py b/keras_export/generate_test_models.py index d8481d87..17a997fc 100644 --- a/keras_export/generate_test_models.py +++ b/keras_export/generate_test_models.py @@ -466,6 +466,7 @@ def get_test_model_recurrent(): outputs.append(TimeDistributed(MaxPooling2D(2, 2))(inputs[3])) outputs.append(TimeDistributed(AveragePooling2D(2, 2))(inputs[3])) + outputs.append(TimeDistributed(BatchNormalization())(inputs[3])) model = Model(inputs=inputs, outputs=outputs, name='test_model_recurrent') model.compile(loss='mse', optimizer='nadam')