diff --git a/include/fdeep/layers/batch_normalization_layer.hpp b/include/fdeep/layers/batch_normalization_layer.hpp
index 796c73fc..e9b7e872 100644
--- a/include/fdeep/layers/batch_normalization_layer.hpp
+++ b/include/fdeep/layers/batch_normalization_layer.hpp
@@ -48,9 +48,11 @@ class batch_normalization_layer : public layer
     tensors apply_impl(const tensors& inputs) const override
     {
         const auto input = single_tensor_from_tensors(inputs);
+
         std::vector<std::size_t> dims(5, 1);
         dims[rank_aligned_axis_to_absolute_axis(input.shape().rank(), axis_) - 1] = moving_mean_->size();
         const tensor_shape params_shape = create_tensor_shape_from_dims(dims);
+
         return {batch_normalization(
             input,
             broadcast(tensor(params_shape, moving_mean_), input.shape()),
diff --git a/include/fdeep/layers/layer_normalization_layer.hpp b/include/fdeep/layers/layer_normalization_layer.hpp
index c2887ee1..8c9c99a9 100644
--- a/include/fdeep/layers/layer_normalization_layer.hpp
+++ b/include/fdeep/layers/layer_normalization_layer.hpp
@@ -37,8 +37,9 @@ class layer_normalization_layer : public layer
     tensors apply_impl(const tensors& inputs) const override
     {
-        // https://github.com/keras-team/keras/blob/v2.14.0/keras/layers/normalization/layer_normalization.py#L291-L304
         const auto& input = single_tensor_from_tensors(inputs);
+
+        // https://github.com/keras-team/keras/blob/v2.14.0/keras/layers/normalization/layer_normalization.py#L291-L304
         const auto& input_moments = moments(input, axes_);
         const auto& mean = input_moments.first;
         const auto& variance = input_moments.second;
@@ -52,6 +53,7 @@ class layer_normalization_layer : public layer
             dims[pos] = input_shape_dimensions[pos];
         }
         const tensor_shape params_shape = create_tensor_shape_from_dims(dims);
+
         return {batch_normalization(
             input,
             mean,
diff --git a/include/fdeep/tensor.hpp b/include/fdeep/tensor.hpp
index 3d3ae4ac..ddc44f01 100644
--- a/include/fdeep/tensor.hpp
+++ b/include/fdeep/tensor.hpp
@@ -814,7 +814,9 @@ inline tensor broadcast(const tensor& t, const tensor_shape& shape)
         (t.shape().width_ == 1 || t.shape().width_ == shape.width_) &&
         (t.shape().depth_ == 1 || t.shape().depth_ == shape.depth_),
         "Invalid shapes for combining tensors.");
+
     tensor out_tensor = tensor(shape, static_cast<float_type>(0));
+
     loop_over_all_dims(out_tensor.shape(), [&](
         std::size_t dim5, std::size_t dim4, std::size_t y, std::size_t x, std::size_t z)
     {
@@ -907,7 +909,6 @@ inline tensor batch_normalization(
             transform_tensor(
                 fplus::add_to(variance_epsilon),
                 variance)),
         scale);
-
     return add_tensors(
         mult_tensors(x, inv),
         subtract_tensors(
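
A note for reviewers: the precondition asserted at the top of `broadcast` (every source dimension is either 1 or equal to the corresponding target dimension) implies the usual repeat-along-size-1-dimensions semantics. A minimal 2-D sketch of that behavior, independent of fdeep's actual tensor types (the names `broadcast_2d`, `src_h`, etc. are illustrative only, not from the library):

#include <cstddef>
#include <iostream>
#include <vector>

// Expand a (src_h, src_w) row-major array to (out_h, out_w), assuming each
// source dimension is either 1 or already equal to the target dimension.
// Indices along size-1 dimensions are clamped to 0, i.e. the data repeats.
std::vector<float> broadcast_2d(const std::vector<float>& src,
    std::size_t src_h, std::size_t src_w,
    std::size_t out_h, std::size_t out_w)
{
    std::vector<float> out(out_h * out_w);
    for (std::size_t y = 0; y < out_h; ++y)
        for (std::size_t x = 0; x < out_w; ++x)
            out[y * out_w + x] =
                src[(src_h == 1 ? 0 : y) * src_w + (src_w == 1 ? 0 : x)];
    return out;
}

int main()
{
    // A (1, 3) row of per-channel parameters repeated over a (2, 3) target,
    // mirroring how the batch_normalization_layer hunk above broadcasts its
    // params_shape tensor (all dimensions 1 except the channel axis) over
    // the full input shape.
    const auto out = broadcast_2d({ 1.f, 2.f, 3.f }, 1, 3, 2, 3);
    for (std::size_t y = 0; y < 2; ++y)
        for (std::size_t x = 0; x < 3; ++x)
            std::cout << out[y * 3 + x] << (x == 2 ? '\n' : ' ');
}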
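
Similarly, the last hunk's `batch_normalization` helper appears to follow the fused form used by TensorFlow's `tf.nn.batch_normalization`: inv = scale / sqrt(variance + epsilon), then y = x * inv + (offset - mean * inv), which is algebraically the textbook scale * (x - mean) / sqrt(variance + epsilon) + offset. A scalar sketch of that equivalence (plain C++, no fdeep types; `batch_norm_fused` and the test values are made up for illustration):

#include <cassert>
#include <cmath>

// Scalar model of the fused batch-normalization form: precompute inv once,
// then apply it as a single multiply-add per element.
float batch_norm_fused(float x, float mean, float variance,
    float offset, float scale, float epsilon)
{
    const float inv = scale / std::sqrt(variance + epsilon);
    return x * inv + (offset - mean * inv);
}

int main()
{
    const float x = 2.5f, mean = 1.0f, variance = 4.0f;
    const float offset = 0.5f, scale = 2.0f, eps = 1e-3f;
    const float fused = batch_norm_fused(x, mean, variance, offset, scale, eps);
    // The textbook form gives the same result up to float rounding.
    const float textbook = scale * (x - mean) / std::sqrt(variance + eps) + offset;
    assert(std::fabs(fused - textbook) < 1e-5f);
}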