Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions docs/source/pitfall.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,54 @@ in the returned expression.
Replacing ``auto tmp`` with ``xt::xarray<double> tmp`` does not change anything, ``tmp``
is still an lvalue and thus captured by reference.

.. warning::

This issue is particularly subtle with reducer functions like :cpp:func:`xt::amax`,
:cpp:func:`xt::sum`, etc. Consider the following function:

.. code::

template <typename T>
xt::xtensor<T, 2> logSoftmax(const xt::xtensor<T, 2> &matrix)
{
xt::xtensor<T, 2> maxVals = xt::amax(matrix, {1}, xt::keep_dims);
auto shifted = matrix - maxVals;
auto expVals = xt::exp(shifted);
auto sumExp = xt::sum(expVals, {1}, xt::keep_dims);
return shifted - xt::log(sumExp);
}

This function may produce incorrect results or crash, especially in optimized builds.
The issue is that ``shifted``, ``expVals``, and ``sumExp`` are all lazy expressions
that hold references to local variables. When the function returns, these local
variables are destroyed, and the returned expression contains dangling references.

The fix is to use explicit container types to force evaluation:

.. code::

template <typename T>
xt::xtensor<T, 2> logSoftmax(const xt::xtensor<T, 2> &matrix)
{
xt::xtensor<T, 2> maxVals = xt::amax(matrix, {1}, xt::keep_dims);
xt::xtensor<T, 2> shifted = matrix - maxVals;
xt::xtensor<T, 2> expVals = xt::exp(shifted);
xt::xtensor<T, 2> sumExp = xt::sum(expVals, {1}, xt::keep_dims);
return shifted - xt::log(sumExp);
}

Alternatively, you can use :cpp:func:`xt::eval` to force evaluation:

.. code::

auto shifted = xt::eval(matrix - maxVals);

Or use the immediate evaluation strategy for reducers:

.. code::

auto sumExp = xt::sum(expVals, {1}, xt::evaluation_strategy::immediate | xt::keep_dims);

Random numbers not consistent
-----------------------------

Expand Down
46 changes: 46 additions & 0 deletions test/test_xmath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -969,4 +969,50 @@ namespace xt
EXPECT_TRUE(xt::allclose(expected, unwrapped));
}
}

// Test for GitHub issue #2871: Proper handling of intermediate results
// This test documents the correct way to use reducers with keep_dims
// when intermediate expressions are needed.
TEST(xmath, issue_2871_intermediate_result_handling)
{
// This test verifies the correct pattern for using reducers with
// intermediate results. Using 'auto' with lazy expressions can lead
// to dangling references when the function returns.

// The CORRECT way: use explicit container types for intermediate results
auto logSoftmax_correct = [](const xt::xtensor<double, 2>& matrix)
{
xt::xtensor<double, 2> maxVals = xt::amax(matrix, {1}, xt::keep_dims);
xt::xtensor<double, 2> shifted = matrix - maxVals;
xt::xtensor<double, 2> expVals = xt::exp(shifted);
xt::xtensor<double, 2> sumExp = xt::sum(expVals, {1}, xt::keep_dims);
return xt::xtensor<double, 2>(shifted - xt::log(sumExp));
};

// Alternative CORRECT way: use xt::eval for intermediate results
auto logSoftmax_eval = [](const xt::xtensor<double, 2>& matrix)
{
auto maxVals = xt::eval(xt::amax(matrix, {1}, xt::keep_dims));
auto shifted = xt::eval(matrix - maxVals);
auto expVals = xt::eval(xt::exp(shifted));
auto sumExp = xt::eval(xt::sum(expVals, {1}, xt::keep_dims));
return xt::xtensor<double, 2>(shifted - xt::log(sumExp));
};

// Test data
xt::xtensor<double, 2> input = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}};

// Both implementations should produce the same result
auto result1 = logSoftmax_correct(input);
auto result2 = logSoftmax_eval(input);

EXPECT_TRUE(xt::allclose(result1, result2));

// Verify the result is a valid log-softmax (rows sum to 0 in log space)
// exp(log_softmax).sum(axis=1) should equal 1
auto exp_result = xt::exp(result1);
auto row_sums = xt::sum(exp_result, {1});
xt::xtensor<double, 1> expected_sums = {1.0, 1.0};
EXPECT_TRUE(xt::allclose(row_sums, expected_sums));
}
}