Skip to content

Commit

Permalink
Merge pull request #466 from neutrinoceros/hotfix_1D_histogram_implic…
Browse files Browse the repository at this point in the history
…t_units
  • Loading branch information
jzuhone authored Nov 2, 2023
2 parents bba2d87 + 7bb2324 commit 9e4974a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
13 changes: 10 additions & 3 deletions unyt/_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,16 @@ def _sanitize_range(_range, units):
ilim = _range[2 * i : 2 * (i + 1)]
imin, imax = ilim
if not (hasattr(imin, "units") and hasattr(imax, "units")):
raise TypeError(
f"Elements of range must both have a 'units' attribute. Got {_range}"
)
if len(units) == 1:
# allow range to be pure numerical scalars
# for backward compatibility with unyt 2.9.5
# see https://github.com/yt-project/unyt/issues/465
imin *= units[0]
imax *= units[0]
else:
raise TypeError(
f"Elements of range must both have a 'units' attribute. Got {_range}"
)
new_range[i] = imin.to_value(units[i]), imax.to_value(units[i])
return new_range.squeeze()

Expand Down
8 changes: 8 additions & 0 deletions unyt/tests/test_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,14 @@ def test_histogram():
assert bins.units == arr.units


def test_histogram_implicit_units():
# see https://github.com/yt-project/unyt/issues/465
arr = np.random.normal(size=1000) * cm
counts, bins = np.histogram(arr, bins=10, range=(arr.min().value, arr.max().value))
assert type(counts) is np.ndarray
assert bins.units == arr.units


def test_histogram2d():
x = np.random.normal(size=100) * cm
y = np.random.normal(loc=10, size=100) * s
Expand Down

0 comments on commit 9e4974a

Please sign in to comment.