From 8754a5c1abaff4cc83021e78ad70b02153b83be3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Wed, 1 Nov 2023 22:38:15 +0100 Subject: [PATCH 1/2] TST: add tests for bug 462 --- unyt/tests/test_array_functions.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/unyt/tests/test_array_functions.py b/unyt/tests/test_array_functions.py index 3d9fa7f5..8e59f50d 100644 --- a/unyt/tests/test_array_functions.py +++ b/unyt/tests/test_array_functions.py @@ -18,6 +18,8 @@ UnytError, ) from unyt.testing import assert_array_equal_units +from unyt.unit_object import Unit +from unyt.unit_registry import UnitRegistry NUMPY_VERSION = Version(version("numpy")) @@ -265,6 +267,23 @@ def test_wrapping_completeness(): assert function in all_funcs +@pytest.mark.parametrize( + "arrays", + [ + [np.array([1]), [2] * Unit()], + [np.array([1]), [2] * Unit(registry=UnitRegistry())], + # [[1], [2] * Unit()], + ], +) +def test_unit_validation(arrays): + # see https://github.com/yt-project/unyt/issues/462 + # numpy.concatenate isn't essential to this test + # what we're really testing is the unit consistency validation + # underneath, but we do so using public API + res = np.concatenate(arrays) + assert res.units.is_dimensionless + + def test_array_repr(): arr = [1, 2, 3] * cm assert re.fullmatch(r"unyt_array\(\[1, 2, 3\], (units=)?'cm'\)", np.array_repr(arr)) From 9dbd8ec36107008b7c628f97ae39cf1584212a19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Wed, 1 Nov 2023 22:50:10 +0100 Subject: [PATCH 2/2] BUG: fix an issue where array functions would raise UnitConsistencyError on unyt arrays using non-default unit registries --- unyt/_array_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py index ded81b49..0585df24 100644 --- a/unyt/_array_functions.py +++ b/unyt/_array_functions.py @@ -216,8 +216,8 @@ def _validate_units_consistency(objs): # because it's already a necessary condition for numpy to use our # custom implementations units = get_units(objs) - sunits = set(units) - if len(sunits) == 1: + unique_units = set(units) + if len(unique_units) == 1 or all(u.is_dimensionless for u in unique_units): return units[0] else: raise UnitInconsistencyError(*units)