Skip to content

Commit

Permalink
Fix softmax implementation, calculate max per depth slice, issue #393
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobiasd committed Sep 20, 2023
1 parent 67a8fbc commit 1d2982e
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions include/fdeep/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1226,29 +1226,28 @@ inline tensor smart_resize_tensor_2d(const tensor& in_vol, const shape2& target_

inline tensor softmax(const tensor& input)
{
// Get unnormalized values of exponent function.
const auto ex = [](float_type x) -> float_type
{
return std::exp(x);
};
const float_type m = input.get(tensor_max_pos(input));
const auto inp_shifted = subtract_tensor(input, tensor(input.shape(), m));
auto output = transform_tensor(ex, inp_shifted);
tensor output = tensor(input.shape(), static_cast<float_type>(0));

// Softmax function is applied along channel dimension.
for (size_t y = 0; y < input.shape().height_; ++y)
{
for (size_t x = 0; x < input.shape().width_; ++x)
{
// Get the sum of unnormalized values for one pixel.
float_type m = std::numeric_limits<float_type>::lowest();
for (size_t z_class = 0; z_class < input.shape().depth_; ++z_class)
{
m = std::max(m, input.get_ignore_rank(tensor_pos(y, x, z_class)));
}
const auto inp_shifted = subtract_tensor(input, tensor(input.shape(), m));

// We are not using Kahan summation, since the number
// of object classes is usually quite small.
float_type sum_shifted = 0.0f;
for (size_t z_class = 0; z_class < input.shape().depth_; ++z_class)
{
sum_shifted += output.get_ignore_rank(tensor_pos(y, x, z_class));
sum_shifted += std::exp(inp_shifted.get_ignore_rank(tensor_pos(y, x, z_class)));
}
// Divide the unnormalized values of each pixel by the stacks sum.

const auto log_sum_shifted = std::log(sum_shifted);
for (size_t z_class = 0; z_class < input.shape().depth_; ++z_class)
{
Expand Down

0 comments on commit 1d2982e

Please sign in to comment.