diff --git a/unyt/tests/test_array_functions.py b/unyt/tests/test_array_functions.py index 3d9fa7f5..8e59f50d 100644 --- a/unyt/tests/test_array_functions.py +++ b/unyt/tests/test_array_functions.py @@ -18,6 +18,8 @@ UnytError, ) from unyt.testing import assert_array_equal_units +from unyt.unit_object import Unit +from unyt.unit_registry import UnitRegistry NUMPY_VERSION = Version(version("numpy")) @@ -265,6 +267,23 @@ def test_wrapping_completeness(): assert function in all_funcs +@pytest.mark.parametrize( + "arrays", + [ + [np.array([1]), [2] * Unit()], + [np.array([1]), [2] * Unit(registry=UnitRegistry())], + # [[1], [2] * Unit()], + ], +) +def test_unit_validation(arrays): + # see https://github.com/yt-project/unyt/issues/462 + # numpy.concatenate isn't essential to this test + # what we're really testing is the unit consistency validation + # underneath, but we do so using public API + res = np.concatenate(arrays) + assert res.units.is_dimensionless + + def test_array_repr(): arr = [1, 2, 3] * cm assert re.fullmatch(r"unyt_array\(\[1, 2, 3\], (units=)?'cm'\)", np.array_repr(arr))