Skip to content

Commit

Permalink
Merge pull request #463 from neutrinoceros/hotfix_unit_consistency_va…
Browse files Browse the repository at this point in the history
…lidation

BUG: fix an issue where array functions would raise UnitConsistencyError on unyt arrays using non-default unit registries
  • Loading branch information
jzuhone authored Nov 2, 2023
2 parents 9e4974a + 9dbd8ec commit e59cea9
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
4 changes: 2 additions & 2 deletions unyt/_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ def _validate_units_consistency(objs):
# because it's already a necessary condition for numpy to use our
# custom implementations
units = get_units(objs)
sunits = set(units)
if len(sunits) == 1:
unique_units = set(units)
if len(unique_units) == 1 or all(u.is_dimensionless for u in unique_units):
return units[0]
else:
raise UnitInconsistencyError(*units)
Expand Down
19 changes: 19 additions & 0 deletions unyt/tests/test_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
UnytError,
)
from unyt.testing import assert_array_equal_units
from unyt.unit_object import Unit
from unyt.unit_registry import UnitRegistry

NUMPY_VERSION = Version(version("numpy"))

Expand Down Expand Up @@ -265,6 +267,23 @@ def test_wrapping_completeness():
assert function in all_funcs


@pytest.mark.parametrize(
"arrays",
[
[np.array([1]), [2] * Unit()],
[np.array([1]), [2] * Unit(registry=UnitRegistry())],
# [[1], [2] * Unit()],
],
)
def test_unit_validation(arrays):
# see https://github.com/yt-project/unyt/issues/462
# numpy.concatenate isn't essential to this test
# what we're really testing is the unit consistency validation
# underneath, but we do so using public API
res = np.concatenate(arrays)
assert res.units.is_dimensionless


def test_array_repr():
arr = [1, 2, 3] * cm
assert re.fullmatch(r"unyt_array\(\[1, 2, 3\], (units=)?'cm'\)", np.array_repr(arr))
Expand Down

0 comments on commit e59cea9

Please sign in to comment.