Skip to content

Commit

Permalink
Format code
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobiasd committed Nov 28, 2023
1 parent f7630b1 commit 689454f
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 2 deletions.
2 changes: 2 additions & 0 deletions include/fdeep/layers/batch_normalization_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,11 @@ class batch_normalization_layer : public layer
tensors apply_impl(const tensors& inputs) const override
{
const auto input = single_tensor_from_tensors(inputs);

std::vector<std::size_t> dims(5, 1);
dims[rank_aligned_axis_to_absolute_axis(input.shape().rank(), axis_) - 1] = moving_mean_->size();
const tensor_shape params_shape = create_tensor_shape_from_dims(dims);

return {batch_normalization(
input,
broadcast(tensor(params_shape, moving_mean_), input.shape()),
Expand Down
4 changes: 3 additions & 1 deletion include/fdeep/layers/layer_normalization_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ class layer_normalization_layer : public layer

tensors apply_impl(const tensors& inputs) const override
{
// https://github.com/keras-team/keras/blob/v2.14.0/keras/layers/normalization/layer_normalization.py#L291-L304
const auto& input = single_tensor_from_tensors(inputs);

// https://github.com/keras-team/keras/blob/v2.14.0/keras/layers/normalization/layer_normalization.py#L291-L304
const auto& input_moments = moments(input, axes_);
const auto& mean = input_moments.first;
const auto& variance = input_moments.second;
Expand All @@ -52,6 +53,7 @@ class layer_normalization_layer : public layer
dims[pos] = input_shape_dimensions[pos];
}
const tensor_shape params_shape = create_tensor_shape_from_dims(dims);

return {batch_normalization(
input,
mean,
Expand Down
3 changes: 2 additions & 1 deletion include/fdeep/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,9 @@ inline tensor broadcast(const tensor& t, const tensor_shape& shape)
(t.shape().width_ == 1 || t.shape().width_ == shape.width_) &&
(t.shape().depth_ == 1 || t.shape().depth_ == shape.depth_),
"Invalid shapes for combining tensors.");

tensor out_tensor = tensor(shape, static_cast<float_type>(0));

loop_over_all_dims(out_tensor.shape(), [&](
std::size_t dim5, std::size_t dim4, std::size_t y, std::size_t x, std::size_t z)
{
Expand Down Expand Up @@ -907,7 +909,6 @@ inline tensor batch_normalization(
transform_tensor(
fplus::add_to(variance_epsilon), variance)),
scale);

return add_tensors(
mult_tensors(x, inv),
subtract_tensors(
Expand Down

0 comments on commit 689454f

Please sign in to comment.