diff --git a/include/fdeep/import_model.hpp b/include/fdeep/import_model.hpp
index e559150f..79a4ba28 100644
--- a/include/fdeep/import_model.hpp
+++ b/include/fdeep/import_model.hpp
@@ -111,6 +111,11 @@ inline fplus::maybe<std::size_t> create_maybe_size_t(const nlohmann::json& data)
     {
         return fplus::nothing<std::size_t>();
     }
+    const int signed_result = data;
+    if (signed_result < 0)
+    {
+        return fplus::nothing<std::size_t>();
+    }
     const std::size_t result = data;
     return fplus::just(result);
 }
@@ -725,7 +730,7 @@ inline layer_ptr create_reshape_layer(
     const get_param_f&, const nlohmann::json& data,
     const std::string& name)
 {
-    const auto target_shape = create_tensor_shape(data["config"]["target_shape"]);
+    const auto target_shape = create_tensor_shape_variable(data["config"]["target_shape"]);
     return std::make_shared<reshape_layer>(name, target_shape);
 }
diff --git a/include/fdeep/layers/reshape_layer.hpp b/include/fdeep/layers/reshape_layer.hpp
index d2250a08..4c0f8907 100644
--- a/include/fdeep/layers/reshape_layer.hpp
+++ b/include/fdeep/layers/reshape_layer.hpp
@@ -18,7 +18,7 @@ class reshape_layer : public layer
 {
 public:
     explicit reshape_layer(const std::string& name,
-        const tensor_shape& target_shape)
+        const tensor_shape_variable& target_shape)
         : layer(name),
         target_shape_(target_shape)
     {
@@ -27,9 +27,11 @@ class reshape_layer : public layer
     tensors apply_impl(const tensors& inputs) const override
     {
         const auto& input = single_tensor_from_tensors(inputs);
-        return {tensor(target_shape_, input.as_vector())};
+        const auto fixed_target_shape = derive_fixed_tensor_shape(
+            input.shape().volume(), target_shape_);
+        return {tensor(fixed_target_shape, input.as_vector())};
     }
-    tensor_shape target_shape_;
+    tensor_shape_variable target_shape_;
 };
 } } // namespace fdeep, namespace internal
diff --git a/include/fdeep/tensor_shape.hpp b/include/fdeep/tensor_shape.hpp
index 2ec6695c..2e7e3172 100644
--- a/include/fdeep/tensor_shape.hpp
+++ b/include/fdeep/tensor_shape.hpp
@@ -206,6 +206,15 @@ inline tensor_shape make_tensor_shape_with(
         fplus::just_with_default(default_shape.depth_, shape.depth_));
 }
 
+inline tensor_shape derive_fixed_tensor_shape(
+    std::size_t values,
+    const tensor_shape_variable shape)
+{
+    const auto inferred = values / shape.minimal_volume();
+    return make_tensor_shape_with(
+        tensor_shape(inferred, inferred, inferred, inferred, inferred), shape);
+}
+
 inline bool tensor_shape_equals_tensor_shape_variable(
     const tensor_shape& lhs, const tensor_shape_variable& rhs)
 {
diff --git a/include/fdeep/tensor_shape_variable.hpp b/include/fdeep/tensor_shape_variable.hpp
index 0cdde3f7..06a16486 100644
--- a/include/fdeep/tensor_shape_variable.hpp
+++ b/include/fdeep/tensor_shape_variable.hpp
@@ -87,6 +87,16 @@ class tensor_shape_variable
     {
     }
 
+    std::size_t minimal_volume() const
+    {
+        return
+            fplus::just_with_default<std::size_t>(1, size_dim_5_) *
+            fplus::just_with_default<std::size_t>(1, size_dim_4_) *
+            fplus::just_with_default<std::size_t>(1, height_) *
+            fplus::just_with_default<std::size_t>(1, width_) *
+            fplus::just_with_default<std::size_t>(1, depth_);
+    }
+
     std::size_t rank() const
     {
         return rank_;
diff --git a/keras_export/convert_model.py b/keras_export/convert_model.py
index 8bc60d2b..82869ab5 100755
--- a/keras_export/convert_model.py
+++ b/keras_export/convert_model.py
@@ -497,12 +497,6 @@ def show_softmax_layer(layer):
     assert layer.axis == -1
 
 
-def show_reshape_layer(layer):
-    """Serialize reshape layer to dict"""
-    for dim_size in layer.target_shape:
-        assert dim_size != -1, 'Reshape inference not supported'
-
-
 def get_layer_functions_dict():
     return {
         'Conv1D': show_conv_1d_layer,
@@ -521,8 +515,7 @@ def get_layer_functions_dict():
         'Bidirectional': show_bidirectional_layer,
         'TimeDistributed': show_time_distributed_layer,
         'Input': show_input_layer,
-        'Softmax': show_softmax_layer,
-        'Reshape': show_reshape_layer
+        'Softmax': show_softmax_layer
     }
diff --git a/keras_export/generate_test_models.py b/keras_export/generate_test_models.py
index 8c6da9f1..d8481d87 100644
--- a/keras_export/generate_test_models.py
+++ b/keras_export/generate_test_models.py
@@ -673,6 +673,8 @@ def get_test_model_variable():
     outputs.append(Conv2D(8, (3, 3), padding='same', activation='elu')(inputs[0]))
     outputs.append(Conv2D(8, (3, 3), padding='same', activation='relu')(inputs[1]))
     outputs.append(GlobalMaxPooling2D()(inputs[0]))
+    outputs.append(Reshape((2, -1))(inputs[2]))
+    outputs.append(Reshape((-1, 2))(inputs[2]))
     outputs.append(MaxPooling2D()(inputs[1]))
     outputs.append(AveragePooling1D()(inputs[2]))
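
Reviewer note, not part of the patch: `derive_fixed_tensor_shape` implements the usual Keras convention for `Reshape` targets containing a `-1`. The unknown dimension is set to the input volume divided by the product of the fixed dimensions (what the patch calls `minimal_volume`), so 24 input values reshaped to `(2, -1)` become `(2, 12)`. Below is a minimal standalone sketch of that rule (C++17), with `std::optional` standing in for `fplus::maybe` and a plain vector standing in for `tensor_shape_variable`; the names `shape_variable`, `minimal_volume`, and `derive_fixed_shape` here are illustrative, not the fdeep API.

```cpp
#include <cassert>
#include <cstddef>
#include <iostream>
#include <optional>
#include <vector>

// Illustrative stand-in for fdeep's tensor_shape_variable:
// nullopt marks the dimension given as -1 in the Keras target shape.
using shape_variable = std::vector<std::optional<std::size_t>>;

// Product of all known dimensions, treating the unknown one as 1.
// Mirrors tensor_shape_variable::minimal_volume from the patch.
std::size_t minimal_volume(const shape_variable& shape)
{
    std::size_t volume = 1;
    for (const auto& dim : shape)
        volume *= dim.value_or(1);
    return volume;
}

// Resolve the unknown dimension so the total volume matches `values`.
// Mirrors derive_fixed_tensor_shape from the patch.
std::vector<std::size_t> derive_fixed_shape(std::size_t values,
    const shape_variable& shape)
{
    const std::size_t inferred = values / minimal_volume(shape);
    std::vector<std::size_t> result;
    for (const auto& dim : shape)
        result.push_back(dim.value_or(inferred));
    return result;
}

int main()
{
    // 24 input values reshaped to (2, -1) -> (2, 12),
    // matching the Reshape((2, -1)) test case added above.
    const shape_variable target = {std::optional<std::size_t>(2), std::nullopt};
    const auto fixed = derive_fixed_shape(24, target);
    assert((fixed == std::vector<std::size_t>{2, 12}));
    std::cout << fixed[0] << " x " << fixed[1] << "\n"; // prints: 2 x 12
}
```

Like the patch itself, the sketch assumes exactly one unknown dimension and does not check that the product of the fixed dimensions divides the input volume; the integer division truncates silently otherwise.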