Skip to content

Commit

Permalink
Add support for keepdims in global pooling layers (#387)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobiasd authored Jul 23, 2023
1 parent c3f4c64 commit e5db13c
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 12 deletions.
10 changes: 6 additions & 4 deletions include/fdeep/import_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,17 +568,19 @@ inline layer_ptr create_average_pooling_3d_layer(
}

inline layer_ptr create_global_max_pooling_3d_layer(
const get_param_f&, const nlohmann::json&,
const get_param_f&, const nlohmann::json& data,
const std::string& name)
{
return std::make_shared<global_max_pooling_3d_layer>(name);
const bool keepdims = data["config"]["keepdims"];
return std::make_shared<global_max_pooling_3d_layer>(name, keepdims);
}

inline layer_ptr create_global_average_pooling_3d_layer(
const get_param_f&, const nlohmann::json&,
const get_param_f&, const nlohmann::json& data,
const std::string& name)
{
return std::make_shared<global_average_pooling_3d_layer>(name);
const bool keepdims = data["config"]["keepdims"];
return std::make_shared<global_average_pooling_3d_layer>(name, keepdims);
}

inline layer_ptr create_upsampling_1d_layer(
Expand Down
10 changes: 7 additions & 3 deletions include/fdeep/layers/global_average_pooling_3d_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@ namespace fdeep { namespace internal
class global_average_pooling_3d_layer : public global_pooling_layer
{
public:
explicit global_average_pooling_3d_layer(const std::string& name) :
global_pooling_layer(name)
explicit global_average_pooling_3d_layer(const std::string& name, bool keepdims) :
global_pooling_layer(name), keepdims_(keepdims)
{
}
protected:
tensor pool(const tensor& in) const override
{
tensor out(tensor_shape(in.shape().depth_), 0);
const auto out_dimensions = keepdims_ ?
fplus::append_elem(in.shape().depth_, std::vector<std::size_t>(in.shape().rank() - 1, 1)) :
fplus::singleton_seq(in.shape().depth_);
tensor out(create_tensor_shape_from_dims(out_dimensions), 0);
for (std::size_t z = 0; z < in.shape().depth_; ++z)
{
float_type val = 0;
Expand All @@ -41,6 +44,7 @@ class global_average_pooling_3d_layer : public global_pooling_layer
}
return out;
}
bool keepdims_;
};

} } // namespace fdeep, namespace internal
10 changes: 7 additions & 3 deletions include/fdeep/layers/global_max_pooling_3d_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@ namespace fdeep { namespace internal
class global_max_pooling_3d_layer : public global_pooling_layer
{
public:
explicit global_max_pooling_3d_layer(const std::string& name) :
global_pooling_layer(name)
explicit global_max_pooling_3d_layer(const std::string& name, bool keepdims) :
global_pooling_layer(name), keepdims_(keepdims)
{
}
protected:
tensor pool(const tensor& in) const override
{
tensor out(tensor_shape(in.shape().depth_), 0);
const auto out_dimensions = keepdims_ ?
fplus::append_elem(in.shape().depth_, std::vector<std::size_t>(in.shape().rank() - 1, 1)) :
fplus::singleton_seq(in.shape().depth_);
tensor out(create_tensor_shape_from_dims(out_dimensions), 0);
for (std::size_t z = 0; z < in.shape().depth_; ++z)
{
float_type val = std::numeric_limits<float_type>::lowest();
Expand All @@ -43,6 +46,7 @@ class global_max_pooling_3d_layer : public global_pooling_layer
}
return out;
}
bool keepdims_;
};

} } // namespace fdeep, namespace internal
2 changes: 0 additions & 2 deletions keras_export/convert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,8 +631,6 @@ def get_layer_weights(layer, name):
layer_type = type(layer).__name__
if hasattr(layer, 'data_format'):
assert layer.data_format == 'channels_last'
if hasattr(layer, 'keepdims'): # Pooling layers
assert not layer.keepdims

show_func = get_layer_functions_dict().get(layer_type, None)
shown_layer = None
Expand Down
6 changes: 6 additions & 0 deletions keras_export/generate_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ def get_test_model_exhaustive():
outputs.append(AveragePooling1D(2, strides=2, padding='same')(inputs[6]))
outputs.append(GlobalMaxPooling1D()(inputs[6]))
outputs.append(GlobalAveragePooling1D()(inputs[6]))
outputs.append(GlobalMaxPooling1D(keepdims=True)(inputs[6]))
outputs.append(GlobalAveragePooling1D(keepdims=True)(inputs[6]))

outputs.append(Normalization(axis=None, mean=2.1, variance=2.2)(inputs[4]))
outputs.append(Normalization(axis=-1, mean=2.1, variance=2.2)(inputs[6]))
Expand Down Expand Up @@ -203,6 +205,10 @@ def get_test_model_exhaustive():
outputs.append(GlobalAveragePooling3D()(inputs[2]))
outputs.append(GlobalMaxPooling2D()(inputs[4]))
outputs.append(GlobalMaxPooling3D()(inputs[2]))
outputs.append(GlobalAveragePooling2D(keepdims=True)(inputs[4]))
outputs.append(GlobalAveragePooling3D(keepdims=True)(inputs[2]))
outputs.append(GlobalMaxPooling2D(keepdims=True)(inputs[4]))
outputs.append(GlobalMaxPooling3D(keepdims=True)(inputs[2]))

outputs.append(CenterCrop(4, 5)(inputs[4]))
outputs.append(CenterCrop(5, 6)(inputs[4]))
Expand Down

0 comments on commit e5db13c

Please sign in to comment.