Skip to content

Commit

Permalink
Add support for shape inference in Reshape layers, fixes #282
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobiasd authored Jul 2, 2021
1 parent bd045c7 commit a6ec354
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 12 deletions.
7 changes: 6 additions & 1 deletion include/fdeep/import_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ inline fplus::maybe<std::size_t> create_maybe_size_t(const nlohmann::json& data)
{
return fplus::nothing<std::size_t>();
}
const int signed_result = data;
if (signed_result < 0)
{
return fplus::nothing<std::size_t>();
}
const std::size_t result = data;
return fplus::just(result);
}
Expand Down Expand Up @@ -725,7 +730,7 @@ inline layer_ptr create_reshape_layer(
const get_param_f&, const nlohmann::json& data,
const std::string& name)
{
const auto target_shape = create_tensor_shape(data["config"]["target_shape"]);
const auto target_shape = create_tensor_shape_variable(data["config"]["target_shape"]);
return std::make_shared<reshape_layer>(name, target_shape);
}

Expand Down
8 changes: 5 additions & 3 deletions include/fdeep/layers/reshape_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class reshape_layer : public layer
{
public:
explicit reshape_layer(const std::string& name,
const tensor_shape& target_shape)
const tensor_shape_variable& target_shape)
: layer(name),
target_shape_(target_shape)
{
Expand All @@ -27,9 +27,11 @@ class reshape_layer : public layer
tensors apply_impl(const tensors& inputs) const override
{
    const auto& input = single_tensor_from_tensors(inputs);
    // Resolve any unspecified target dimensions from the input's total
    // element count before constructing the output tensor.
    // (Removed a stale duplicate of the pre-refactor return statement that
    // used target_shape_ directly; target_shape_ is now a
    // tensor_shape_variable and must be fixed first.)
    const auto fixed_target_shape = derive_fixed_tensor_shape(
        input.shape().volume(), target_shape_);
    // Reshaping only reinterprets the layout; the flat value vector is
    // reused unchanged.
    return {tensor(fixed_target_shape, input.as_vector())};
}
tensor_shape target_shape_;
tensor_shape_variable target_shape_;
};

} } // namespace fdeep, namespace internal
9 changes: 9 additions & 0 deletions include/fdeep/tensor_shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,15 @@ inline tensor_shape make_tensor_shape_with(
fplus::just_with_default(default_shape.depth_, shape.depth_));
}

// Resolve a possibly-partial (variable) tensor shape into a fixed one,
// inferring every unspecified dimension from the total value count.
//
// values: total number of elements of the tensor being reshaped.
// shape:  target shape in which unknown dimensions are "nothing"
//         (e.g. produced from a Keras Reshape target containing -1).
//
// Every unknown dimension is filled with values / shape.minimal_volume(),
// i.e. the element count remaining after all known dimensions are
// accounted for; make_tensor_shape_with keeps the known dimensions as-is.
//
// NOTE(review): assumes values is evenly divisible by
// shape.minimal_volume(); a non-divisible input silently truncates —
// confirm upstream validation.
inline tensor_shape derive_fixed_tensor_shape(
    std::size_t values,
    const tensor_shape_variable& shape)  // by const ref: avoid copying
{
    // Size shared by all unknown dimensions (at most one is expected).
    const auto inferred = values / shape.minimal_volume();
    return make_tensor_shape_with(
        tensor_shape(inferred, inferred, inferred, inferred, inferred), shape);
}

inline bool tensor_shape_equals_tensor_shape_variable(
const tensor_shape& lhs, const tensor_shape_variable& rhs)
{
Expand Down
10 changes: 10 additions & 0 deletions include/fdeep/tensor_shape_variable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ class tensor_shape_variable
{
}

// Product of all known dimensions, counting each unspecified
// (maybe-empty) dimension as size 1 — the minimum number of values
// this variable shape can describe.
std::size_t minimal_volume() const
{
    std::size_t volume = 1;
    volume *= fplus::just_with_default<std::size_t>(1, size_dim_5_);
    volume *= fplus::just_with_default<std::size_t>(1, size_dim_4_);
    volume *= fplus::just_with_default<std::size_t>(1, height_);
    volume *= fplus::just_with_default<std::size_t>(1, width_);
    volume *= fplus::just_with_default<std::size_t>(1, depth_);
    return volume;
}

std::size_t rank() const
{
return rank_;
Expand Down
9 changes: 1 addition & 8 deletions keras_export/convert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,12 +497,6 @@ def show_softmax_layer(layer):
assert layer.axis == -1


def show_reshape_layer(layer):
    """Serialize reshape layer to dict"""
    assert all(dim_size != -1 for dim_size in layer.target_shape), \
        'Reshape inference not supported'


def get_layer_functions_dict():
return {
'Conv1D': show_conv_1d_layer,
Expand All @@ -521,8 +515,7 @@ def get_layer_functions_dict():
'Bidirectional': show_bidirectional_layer,
'TimeDistributed': show_time_distributed_layer,
'Input': show_input_layer,
'Softmax': show_softmax_layer,
'Reshape': show_reshape_layer
'Softmax': show_softmax_layer
}


Expand Down
2 changes: 2 additions & 0 deletions keras_export/generate_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,8 @@ def get_test_model_variable():
outputs.append(Conv2D(8, (3, 3), padding='same', activation='elu')(inputs[0]))
outputs.append(Conv2D(8, (3, 3), padding='same', activation='relu')(inputs[1]))
outputs.append(GlobalMaxPooling2D()(inputs[0]))
outputs.append(Reshape((2, -1))(inputs[2]))
outputs.append(Reshape((-1, 2))(inputs[2]))
outputs.append(MaxPooling2D()(inputs[1]))
outputs.append(AveragePooling1D()(inputs[2]))

Expand Down

0 comments on commit a6ec354

Please sign in to comment.