From 196685d0f3ecee2f3dce4b9da2cdc8360dfd65c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Wed, 18 Dec 2024 10:45:02 +0100 Subject: [PATCH 1/2] TST: add regression tests for issue gh-540 --- unyt/tests/test_array_functions.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/unyt/tests/test_array_functions.py b/unyt/tests/test_array_functions.py index a039939e..291db7d8 100644 --- a/unyt/tests/test_array_functions.py +++ b/unyt/tests/test_array_functions.py @@ -9,7 +9,7 @@ from numpy.testing import assert_allclose from packaging.version import Version -from unyt import A, K, cm, degC, delta_degC, g, km, rad, s +from unyt import A, K, Msun, cm, degC, delta_degC, g, km, rad, s from unyt._array_functions import ( _HANDLED_FUNCTIONS as HANDLED_FUNCTIONS, _UNSUPPORTED_FUNCTIONS as UNSUPPORTED_FUNCTIONS, @@ -831,6 +831,12 @@ def test_histogramdd_with_weights_and_dimless_arr(self): assert not hasattr(ywbins2, "units") assert not hasattr(zwbins2, "units") + @pytest.mark.parametrize("weights", [None, [0, 1, 2], [0, 1, 2] * cm]) + def test_histogramdd_recursion(self, weights): + # regression test for https://github.com/yt-project/unyt/issues/540 + sample = [unyt_array(np.arange(3), Msun)] + np.histogramdd(sample, density=True, weights=weights) + def test_histogram_bin_edges(self): rng = np.random.default_rng() arr = rng.normal(size=1000) * cm From d41543be27fc571ff0659d5fc3078e275677e3ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Wed, 18 Dec 2024 10:48:22 +0100 Subject: [PATCH 2/2] BUG: fix an issue where np.histogramdd could create infinite recursion on some inputs --- unyt/_array_functions.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/unyt/_array_functions.py b/unyt/_array_functions.py index 62aca69b..3bd811b7 100644 --- a/unyt/_array_functions.py +++ b/unyt/_array_functions.py @@ -286,9 +286,16 @@ def _histogramdd( # don't getattr(..., "units", NULL_UNIT) because e.g. we don't want # a unyt_array if weights are not a unyt_array and not density if density: + divider_units = 1 * NULL_UNIT for s in sample: - counts /= getattr(s, "units", 1) - counts *= getattr(weights, "units", 1) + if not hasattr(s, "units"): + continue + divider_units *= s.units + if divider_units != NULL_UNIT: + counts /= divider_units + + if weights is not None and hasattr(weights, "units"): + counts *= weights.units return counts, tuple(_bin * getattr(s, "units", 1) for _bin, s in zip(bins, sample))