-
Notifications
You must be signed in to change notification settings - Fork 131
Fix numba dispatch not returning arrays or wrong dtypes #1406
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix numba dispatch not returning arrays or wrong dtypes #1406
Conversation
return np.testing.assert_allclose(x, y, rtol=1e-4) and compare_shape_dtype( | ||
x, y | ||
) | ||
np.testing.assert_allclose(x, y, rtol=1e-4, strict=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
strict=True covers the shape/dtype mismatch
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The old logic with return x and y, didn't trigger y (the compare_shape_dtype
), because the assert eithers fails or returns None which is Falsy. We don't do anything with the output of this return which again is always Falsy if it doesn't fail
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR fixes an issue where the numba dispatch for Det and SLogDet was returning non-array outputs, which was causing failures in statespace models.
- Consolidated tests for both Det and SLogDet with parameterized inputs and dtypes.
- Updated the numba dispatch functions to wrap scalar outputs in NumPy arrays with type conversion.
- Refined the assertion checks in tests to ensure both outputs are consistently NumPy arrays.
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
tests/link/numba/test_nlinalg.py | Consolidated test cases for Det and SLogDet with parametrize for dtype and op. |
tests/link/numba/test_basic.py | Modified the assertion function to ensure output array consistency using strict checks. |
pytensor/link/numba/dispatch/nlinalg.py | Updated Det and SLogDet dispatch functions to return NumPy arrays with proper type conversion. |
@@ -52,7 +52,7 @@ def numba_funcify_Det(op, node, **kwargs): | |||
|
|||
@numba_basic.numba_njit(inline="always") | |||
def det(x): | |||
return numba_basic.direct_cast(np.linalg.det(inputs_cast(x)), out_dtype) | |||
return np.array(np.linalg.det(inputs_cast(x))).astype(out_dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] Consider adding a comment explaining why wrapping the output with np.array is necessary to ensure consistency in array outputs, which aids in maintainability.
Copilot uses AI. Check for mistakes.
The new tests show other Ops that were failing to respect dtypes. For instance |
NOTE: CI failing at this point
1657d4a
to
594d433
Compare
594d433
to
614ffdd
Compare
Codecov ReportAttention: Patch coverage is
❌ Your patch status has failed because the patch coverage (52.63%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #1406 +/- ##
========================================
Coverage 82.10% 82.11%
========================================
Files 208 211 +3
Lines 49576 49686 +110
Branches 8791 8813 +22
========================================
+ Hits 40704 40798 +94
- Misses 6699 6710 +11
- Partials 2173 2178 +5
🚀 New features to boost your workflow:
|
According to the numba devs the dtypes are correct but python converts them back to float / integer when we get out of numba. Will tweak the tests to not consider it an xfail then, although it would be better if numba returned numpy scalars |
is that a problem on our side or numba's side? |
It's numba behavior. It doesn't return numpy scalars from jitted functions. import numpy as np
import numba
@numba.njit
def f(x):
return x + 1
x = np.int32(0)
y = f(x)
type(y), type(x + 1) # (int, numpy.int32) |
This lead to failures in statespace models, as the Elemwise raises if the inputs are not arrays.
Closes pymc-devs/pymc-extras#476
📚 Documentation preview 📚: https://pytensor--1406.org.readthedocs.build/en/1406/