Fix numba dispatch not returning arrays or wrong dtypes #1406

Merged · 8 commits · May 23, 2025
27 changes: 1 addition & 26 deletions pytensor/link/numba/dispatch/basic.py
@@ -31,7 +31,6 @@
fgraph_to_python,
)
from pytensor.scalar.basic import ScalarType
from pytensor.scalar.math import Softplus
from pytensor.sparse import SparseTensorType
from pytensor.tensor.basic import Nonzero
from pytensor.tensor.blas import BatchedDot
@@ -466,7 +465,7 @@ def argort_vec(X, axis):
axis = axis.item()

Y = np.swapaxes(X, axis, 0)
result = np.empty_like(Y)
result = np.empty_like(Y, dtype="int64")

indices = list(np.ndindex(Y.shape[1:]))
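A quick aside on the `empty_like` change above (a standalone NumPy sketch, not code from this PR): without an explicit dtype, the result buffer inherits the input's dtype, so a float input would produce float "indices"; the explicit `dtype="int64"` matches the integer indices that argsort produces.

```python
import numpy as np

Y = np.zeros((3, 4), dtype="float64")

print(np.empty_like(Y).dtype)                 # float64 -- inherits Y's dtype
print(np.empty_like(Y, dtype="int64").dtype)  # int64 -- matches argsort's index dtype
print(np.argsort(Y[:, 0]).dtype)              # int64 on 64-bit platforms
```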

@@ -607,25 +606,6 @@ def dot(x, y):
return dot


@numba_funcify.register(Softplus)
def numba_funcify_Softplus(op, node, **kwargs):
x_dtype = np.dtype(node.inputs[0].dtype)

@numba_njit
def softplus(x):
if x < -37.0:
value = np.exp(x)
elif x < 18.0:
value = np.log1p(np.exp(x))
elif x < 33.3:
value = x + np.exp(-x)
else:
value = x
return direct_cast(value, x_dtype)

return softplus


@numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs):
assume_a = op.assume_a
@@ -689,11 +669,6 @@ def batched_dot(x, y):
return batched_dot


# NOTE: The remaining `pytensor.tensor.blas` `Op`s appear unnecessary, because
# they're only used to optimize basic `Dot` nodes, and those GEMV and GEMM
# optimizations are apparently already performed by Numba


@numba_funcify.register(IfElse)
def numba_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs
2 changes: 1 addition & 1 deletion pytensor/link/numba/dispatch/elemwise.py
@@ -561,7 +561,7 @@ def numba_funcify_Argmax(op, node, **kwargs):

@numba_basic.numba_njit(inline="always")
def argmax(x):
return 0
return np.array(0, dtype="int64")

else:
axes = tuple(int(ax) for ax in axis)
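For context on the change above (a NumPy-only sketch, not PR code): a bare Python `0` carries neither shape nor dtype, while `np.array(0, dtype="int64")` is a proper 0-d int64 array, which is what the PR title is about.

```python
import numpy as np

out = np.array(0, dtype="int64")
print(isinstance(out, np.ndarray), out.ndim, out.dtype)  # True 0 int64
print(isinstance(0, np.ndarray))                         # False: a plain int has no dtype or shape
```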
2 changes: 1 addition & 1 deletion pytensor/link/numba/dispatch/linalg/decomposition/lu.py
@@ -30,7 +30,7 @@
# Fortran is 1 indexed, so we need to subtract 1 from the IPIV array
IPIV = IPIV - 1
p_inv = _pivot_to_permutation(IPIV, dtype=dtype)
perm = np.argsort(p_inv)
perm = np.argsort(p_inv).astype("int32")

return perm, L, U
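A standalone note on the `argsort` line above (illustrative NumPy, not PR code): argsorting an inverse permutation recovers the forward permutation, and the explicit `astype("int32")` pins the output dtype, presumably to match the dtype the Op declares for its permutation output.

```python
import numpy as np

p_inv = np.array([2, 0, 1])   # an inverse permutation
perm = np.argsort(p_inv)      # [1, 2, 0]: the forward permutation
assert np.array_equal(p_inv[perm], np.arange(3))

print(perm.dtype)                  # intp (int64 on 64-bit platforms)
print(perm.astype("int32").dtype)  # int32
```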

6 changes: 3 additions & 3 deletions pytensor/link/numba/dispatch/nlinalg.py
@@ -52,7 +52,7 @@ def numba_funcify_Det(op, node, **kwargs):

@numba_basic.numba_njit(inline="always")
def det(x):
return numba_basic.direct_cast(np.linalg.det(inputs_cast(x)), out_dtype)
return np.array(np.linalg.det(inputs_cast(x))).astype(out_dtype)
Copilot AI commented on May 20, 2025:
[nitpick] Consider adding a comment explaining why wrapping the output with np.array is necessary to ensure consistency in array outputs, which aids in maintainability.


return det
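On the nitpick above, a standalone NumPy illustration (not PR code) of why the wrapping matters: `np.linalg.det` returns a NumPy scalar rather than an ndarray, while the numba dispatch is being fixed here to return a 0-d array of the declared dtype.

```python
import numpy as np

x = np.eye(3)
raw = np.linalg.det(x)
wrapped = np.array(np.linalg.det(x)).astype("float64")

print(type(raw))                                   # <class 'numpy.float64'>: a scalar, not an ndarray
print(type(wrapped), wrapped.ndim, wrapped.dtype)  # 0-d float64 ndarray
```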

@@ -68,8 +68,8 @@ def numba_funcify_SLogDet(op, node, **kwargs):
def slogdet(x):
sign, det = np.linalg.slogdet(inputs_cast(x))
return (
numba_basic.direct_cast(sign, out_dtype_1),
numba_basic.direct_cast(det, out_dtype_2),
np.array(sign).astype(out_dtype_1),
np.array(det).astype(out_dtype_2),
)

return slogdet
21 changes: 20 additions & 1 deletion pytensor/link/numba/dispatch/scalar.py
@@ -28,7 +28,7 @@
Second,
Switch,
)
from pytensor.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid
from pytensor.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid, Softplus


@numba_funcify.register(ScalarOp)
@@ -312,3 +312,22 @@
@numba_funcify.register(Erfc)
def numba_funcify_Erfc(op, **kwargs):
return numba_basic.global_numba_func(erfc)


@numba_funcify.register(Softplus)
def numba_funcify_Softplus(op, node, **kwargs):
out_dtype = np.dtype(node.outputs[0].type.dtype)

@numba_basic.numba_njit
def softplus(x):
if x < -37.0:
value = np.exp(x)

elif x < 18.0:
value = np.log1p(np.exp(x))

elif x < 33.3:
value = x + np.exp(-x)

else:
value = x
return numba_basic.direct_cast(value, out_dtype)


return softplus
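The registration above replaces the one removed from pytensor/link/numba/dispatch/basic.py earlier in this diff, and now casts to the output dtype rather than the input dtype. As a standalone sanity check of the piecewise branches (plain NumPy, no numba; thresholds copied from the code above), the stable form agrees with the naive `log1p(exp(x))` wherever the naive form is well behaved:

```python
import numpy as np

def softplus_piecewise(x):
    # Same branch structure and thresholds as the dispatcher above.
    if x < -37.0:
        return np.exp(x)
    elif x < 18.0:
        return np.log1p(np.exp(x))
    elif x < 33.3:
        return x + np.exp(-x)
    else:
        return x

for v in (-30.0, -1.0, 0.0, 5.0, 20.0, 30.0):
    np.testing.assert_allclose(softplus_piecewise(v), np.log1p(np.exp(v)), rtol=1e-12)
```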
56 changes: 8 additions & 48 deletions tests/link/numba/test_basic.py
@@ -14,7 +14,6 @@
numba = pytest.importorskip("numba")

import pytensor.scalar as ps
import pytensor.scalar.math as psm
import pytensor.tensor as pt
import pytensor.tensor.math as ptm
from pytensor import config, shared
@@ -260,9 +259,12 @@ def compare_numba_and_py(
if assert_fn is None:

def assert_fn(x, y):
return np.testing.assert_allclose(x, y, rtol=1e-4) and compare_shape_dtype(
x, y
)
np.testing.assert_allclose(x, y, rtol=1e-4, strict=True)
Member Author commented:
strict=True covers the shape/dtype mismatch.

ricardoV94 (Member, Author) commented on May 20, 2025:
The old logic with `return x and y` never actually ran y (the compare_shape_dtype call): the assert either fails or returns None, which is falsy, so the `and` short-circuits. We don't use the returned value anyway, and it is always falsy when the assert doesn't fail.
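A standalone sketch of the short-circuit described above (the helper name mirrors the old test code; this is illustrative, not part of the PR):

```python
import numpy as np

def compare_shape_dtype(x, y):
    print("compare_shape_dtype ran")  # never printed with the old pattern
    return x.shape == y.shape and x.dtype == y.dtype

x = np.ones(3, dtype="float64")
y = np.ones(3, dtype="float32")  # a dtype mismatch the old check was meant to catch

# assert_allclose returns None on success; `None and ...` short-circuits,
# so compare_shape_dtype never runs and the dtype mismatch goes unnoticed.
result = np.testing.assert_allclose(x, y, rtol=1e-4) and compare_shape_dtype(x, y)
print(result)  # None
```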

# Make sure we don't have one input be a np.ndarray while the other is not
if isinstance(x, np.ndarray):
assert isinstance(y, np.ndarray), "y is not a NumPy array, but x is"
else:
assert not isinstance(y, np.ndarray), "y is a NumPy array, but x is not"

if any(
inp.owner is not None
@@ -295,8 +297,8 @@ def assert_fn(x, y):
test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs
numba_res = pytensor_numba_fn(*test_inputs_copy)
if isinstance(graph_outputs, tuple | list):
for j, p in zip(numba_res, py_res, strict=True):
assert_fn(j, p)
for numba_res_i, python_res_i in zip(numba_res, py_res, strict=True):
assert_fn(numba_res_i, python_res_i)
else:
assert_fn(numba_res, py_res)

@@ -640,48 +642,6 @@ def test_Dot(x, y, exc):
)


@pytest.mark.parametrize(
"x, exc",
[
(
(ps.float64(), np.array(0.0, dtype="float64")),
None,
),
(
(ps.float64(), np.array(-32.0, dtype="float64")),
None,
),
(
(ps.float64(), np.array(-40.0, dtype="float64")),
None,
),
(
(ps.float64(), np.array(32.0, dtype="float64")),
None,
),
(
(ps.float64(), np.array(40.0, dtype="float64")),
None,
),
(
(ps.int64(), np.array(32, dtype="int64")),
None,
),
],
)
def test_Softplus(x, exc):
x, x_test_value = x
g = psm.Softplus(ps.upgrade_to_float)(x)

cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
[x],
[g],
[x_test_value],
)


@pytest.mark.parametrize(
"x, y, exc",
[
66 changes: 9 additions & 57 deletions tests/link/numba/test_nlinalg.py
@@ -11,68 +11,20 @@
rng = np.random.default_rng(42849)


@pytest.mark.parametrize("dtype", ("float64", "int64"))
@pytest.mark.parametrize(
"x, exc",
[
(
(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
None,
),
(
(
pt.lmatrix(),
(lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")),
),
None,
),
],
"op", (nlinalg.Det(), nlinalg.SLogDet()), ids=["det", "slogdet"]
)
def test_Det(x, exc):
x, test_x = x
g = nlinalg.Det()(x)
def test_Det_SLogDet(op, dtype):
x = pt.matrix(dtype=dtype)

cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
[x],
g,
[test_x],
)
rng = np.random.default_rng([50, sum(map(ord, dtype))])
x_ = rng.random(size=(3, 3)).astype(dtype)
test_x = x_.T.dot(x_)

g = op(x)

@pytest.mark.parametrize(
"x, exc",
[
(
(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
None,
),
(
(
pt.lmatrix(),
(lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")),
),
None,
),
],
)
def test_SLogDet(x, exc):
x, test_x = x
g = nlinalg.SLogDet()(x)

cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
[x],
g,
[test_x],
)
compare_numba_and_py([x], g, [test_x])


# We were seeing some weird results in CI where the following two almost
45 changes: 42 additions & 3 deletions tests/link/numba/test_scalar.py
@@ -3,12 +3,13 @@

import pytensor.scalar as ps
import pytensor.scalar.basic as psb
import pytensor.scalar.math as psm
import pytensor.tensor as pt
from pytensor import config
from pytensor import config, function
from pytensor.scalar.basic import Composite
from pytensor.tensor import tensor
from pytensor.tensor.elemwise import Elemwise
from tests.link.numba.test_basic import compare_numba_and_py
from tests.link.numba.test_basic import compare_numba_and_py, numba_mode, py_mode


rng = np.random.default_rng(42849)
@@ -99,7 +100,11 @@ def test_Composite(inputs, input_values, scalar_fn):
"v, dtype",
[
((pt.fscalar(), np.array(1.0, dtype="float32")), psb.float64),
((pt.dscalar(), np.array(1.0, dtype="float64")), psb.float32),
pytest.param(
(pt.dscalar(), np.array(1.0, dtype="float64")),
psb.float32,
marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"),
),
],
)
def test_Cast(v, dtype):
@@ -145,3 +150,37 @@ def test_isnan(composite):
[out],
[np.array([1, 0], dtype="float64")],
)


@pytest.mark.parametrize(
"dtype",
[
pytest.param(
"float32",
marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"),
),
"float64",
pytest.param(
"int16",
marks=pytest.mark.xfail(reason="Scalar downcasting not supported in numba"),
),
"int64",
"uint32",
],
)
def test_Softplus(dtype):
x = ps.get_scalar_type(dtype)("x")
g = psm.softplus(x)

py_fn = function([x], g, mode=py_mode)
numba_fn = function([x], g, mode=numba_mode)
for value in (-40, -32, 0, 32, 40):
if value < 0 and dtype.startswith("u"):
continue
test_x = np.dtype(dtype).type(value)
np.testing.assert_allclose(
py_fn(test_x),
numba_fn(test_x),
strict=True,
err_msg=f"Failed for value {value}",
)
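
To exercise just the updated tests locally, something like the following works (a usage sketch; the paths are the test modules touched in this PR):

```python
import pytest

# Equivalent to: pytest tests/link/numba/test_scalar.py tests/link/numba/test_nlinalg.py -k "Softplus or Det"
pytest.main(
    [
        "tests/link/numba/test_scalar.py",
        "tests/link/numba/test_nlinalg.py",
        "-k",
        "Softplus or Det",
    ]
)
```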