Fix multiple grouping with missing groups (#9650)
* Fix multiple grouping with missing groups

Closes #9360

* Small repr improvement

* Small optimization in mask

* Add whats-new

* fix doctests
dcherian authored Oct 21, 2024
1 parent 01831a4 commit df87f69
Showing 5 changed files with 49 additions and 20 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
@@ -68,6 +68,8 @@ Bug fixes
 - Fix the safe_chunks validation option on the to_zarr method
   (:issue:`5511`, :pull:`9559`). By `Joseph Nowak
   <https://github.com/josephnowak>`_.
+- Fix binning by multiple variables where some bins have no observations. (:issue:`9630`).
+  By `Deepak Cherian <https://github.com/dcherian>`_.
 
 Documentation
 ~~~~~~~~~~~~~
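
For context, a minimal sketch of the behaviour this entry describes (counting over two `BinGrouper`s where most bin combinations hold no observations) is shown below. It mirrors the regression test added further down in this commit; the `from xarray.groupers import BinGrouper` import path is assumed.

```python
# Sketch of the fixed behaviour: binning by two variables where some
# (x, y) bin combinations contain no observations (mirrors the new test below).
import numpy as np
import xarray as xr
from xarray.groupers import BinGrouper  # assumed import path for Grouper objects

ds = xr.Dataset(
    {"foo": ("z", np.arange(12))},
    coords={"x": ("z", np.arange(12)), "y": ("z", np.arange(12))},
)

# x and y increase together, so most of the 3 x 7 bin grid is empty.
counts = ds.groupby(
    x=BinGrouper(bins=np.arange(0, 13, 4)),
    y=BinGrouper(bins=np.arange(0, 16, 2)),
).count()

# The result is a 3 x 7 grid over x_bins/y_bins; empty combinations come back as NaN.
print(counts["foo"].shape)
```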
6 changes: 3 additions & 3 deletions xarray/core/dataarray.py
@@ -6800,7 +6800,7 @@ def groupby(
         >>> da.groupby("letters")
         <DataArrayGroupBy, grouped over 1 grouper(s), 2 groups in total:
-            'letters': 2 groups with labels 'a', 'b'>
+            'letters': 2/2 groups present with labels 'a', 'b'>
 
         Execute a reduction

@@ -6816,8 +6816,8 @@
         >>> da.groupby(["letters", "x"])
         <DataArrayGroupBy, grouped over 2 grouper(s), 8 groups in total:
-            'letters': 2 groups with labels 'a', 'b'
-            'x': 4 groups with labels 10, 20, 30, 40>
+            'letters': 2/2 groups present with labels 'a', 'b'
+            'x': 4/4 groups present with labels 10, 20, 30, 40>
 
         Use Grouper objects to express more complicated GroupBy operations
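
The doctest examples above only exercise cases where every group is populated (hence 2/2 and 4/4). A hedged sketch of where the new counter actually differs, namely a `BinGrouper` with an unpopulated bin, follows; the quoted output is approximate and simply traces the format string changed in `xarray/core/groupby.py` below.

```python
# Sketch: a grouper with an empty bin, so "present" < "total" in the repr.
# The output quoted in the comments is approximate, not a doctest.
import numpy as np
import xarray as xr
from xarray.groupers import BinGrouper  # assumed import path

da = xr.DataArray([1.0, 2.0, 3.0, 4.0], dims="x", coords={"x": [1, 2, 6, 7]})
gb = da.groupby(x=BinGrouper(bins=[0, 5, 10, 15]))  # the (10, 15] bin holds no data
print(gb)
# The grouper line of the repr should now read along the lines of:
#     'x_bins': 2/3 groups present with labels (0, 5], (5, 10]
```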
6 changes: 3 additions & 3 deletions xarray/core/dataset.py
@@ -10403,7 +10403,7 @@ def groupby(
         >>> ds.groupby("letters")
         <DatasetGroupBy, grouped over 1 grouper(s), 2 groups in total:
-            'letters': 2 groups with labels 'a', 'b'>
+            'letters': 2/2 groups present with labels 'a', 'b'>
 
         Execute a reduction

@@ -10420,8 +10420,8 @@ def groupby(
         >>> ds.groupby(["letters", "x"])
         <DatasetGroupBy, grouped over 2 grouper(s), 8 groups in total:
-            'letters': 2 groups with labels 'a', 'b'
-            'x': 4 groups with labels 10, 20, 30, 40>
+            'letters': 2/2 groups present with labels 'a', 'b'
+            'x': 4/4 groups present with labels 10, 20, 30, 40>
 
         Use Grouper objects to express more complicated GroupBy operations
19 changes: 7 additions & 12 deletions xarray/core/groupby.py
@@ -454,6 +454,7 @@ def factorize(self) -> EncodedGroups:
         # At this point all arrays have been factorized.
         codes = tuple(grouper.codes for grouper in groupers)
         shape = tuple(grouper.size for grouper in groupers)
+        masks = tuple((code == -1) for code in codes)
         # We broadcast the codes against each other
         broadcasted_codes = broadcast(*codes)
         # This fully broadcasted DataArray is used as a template later
@@ -464,24 +465,18 @@
         )
         # NaNs; as well as values outside the bins are coded by -1
         # Restore these after the raveling
-        mask = functools.reduce(
-            np.logical_or,  # type: ignore[arg-type]
-            [(code == -1) for code in broadcasted_codes],
-        )
+        broadcasted_masks = broadcast(*masks)
+        mask = functools.reduce(np.logical_or, broadcasted_masks)  # type: ignore[arg-type]
         _flatcodes[mask] = -1
 
-        midx = pd.MultiIndex.from_product(
-            (grouper.unique_coord.data for grouper in groupers),
+        full_index = pd.MultiIndex.from_product(
+            (grouper.full_index.values for grouper in groupers),
             names=tuple(grouper.name for grouper in groupers),
         )
         # Constructing an index from the product is wrong when there are missing groups
         # (e.g. binning, resampling). Account for that now.
-        midx = midx[np.sort(pd.unique(_flatcodes[~mask]))]
+        midx = full_index[np.sort(pd.unique(_flatcodes[~mask]))]
 
-        full_index = pd.MultiIndex.from_product(
-            (grouper.full_index.values for grouper in groupers),
-            names=tuple(grouper.name for grouper in groupers),
-        )
         dim_name = "stacked_" + "_".join(str(grouper.name) for grouper in groupers)
 
         coords = Coordinates.from_pandas_multiindex(midx, dim=dim_name)
@@ -684,7 +679,7 @@ def __repr__(self) -> str:
         for grouper in self.groupers:
             coord = grouper.unique_coord
             labels = ", ".join(format_array_flat(coord, 30).split())
-            text += f"\n    {grouper.name!r}: {coord.size} groups with labels {labels}"
+            text += f"\n    {grouper.name!r}: {coord.size}/{grouper.full_index.size} groups present with labels {labels}"
         return text + ">"
 
     def _iter_grouped(self) -> Iterator[T_Xarray]:
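
To make the `factorize` change above easier to follow: the old code built the result `MultiIndex` from each grouper's observed values, which goes wrong once some combinations of the product never occur (binning, resampling); the new code builds the full product index and then keeps only the combinations whose raveled codes are actually present. Below is a standalone sketch of that idea in plain numpy/pandas, using made-up codes rather than xarray's internal variables.

```python
# Standalone sketch (plain numpy/pandas, not xarray internals): build the index
# from the full product of groups, then keep only the combinations present.
import numpy as np
import pandas as pd

# Per-grouper integer codes for six observations; -1 marks "no group"
# (NaN, or a value falling outside the bins).
codes_x = np.array([0, 0, 1, 1, 2, -1])  # 3 possible x groups
codes_y = np.array([0, 1, 1, -1, 3, 2])  # 4 possible y groups
shape = (3, 4)

# Ravel the pair of codes into one code per observation ...
flat = np.ravel_multi_index((codes_x, codes_y), shape, mode="wrap")
# ... then restore -1 wherever either grouper had a missing code.
mask = (codes_x == -1) | (codes_y == -1)
flat[mask] = -1

# Full product of all possible groups (pd.MultiIndex.from_product), then the
# subset of combinations that actually occur in the data.
full_index = pd.MultiIndex.from_product([range(3), range(4)], names=["x", "y"])
present = full_index[np.sort(pd.unique(flat[~mask]))]
print(present.tolist())  # [(0, 0), (0, 1), (1, 1), (2, 3)]: empty combinations dropped
```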
36 changes: 34 additions & 2 deletions xarray/tests/test_groupby.py
@@ -606,7 +606,7 @@ def test_groupby_repr(obj, dim) -> None:
     N = len(np.unique(obj[dim]))
     expected = f"<{obj.__class__.__name__}GroupBy"
     expected += f", grouped over 1 grouper(s), {N} groups in total:"
-    expected += f"\n    {dim!r}: {N} groups with labels "
+    expected += f"\n    {dim!r}: {N}/{N} groups present with labels "
     if dim == "x":
         expected += "1, 2, 3, 4, 5>"
     elif dim == "y":
@@ -623,7 +623,7 @@ def test_groupby_repr_datetime(obj) -> None:
     actual = repr(obj.groupby("t.month"))
     expected = f"<{obj.__class__.__name__}GroupBy"
     expected += ", grouped over 1 grouper(s), 12 groups in total:\n"
-    expected += "    'month': 12 groups with labels "
+    expected += "    'month': 12/12 groups present with labels "
     expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>"
     assert actual == expected
 
@@ -2953,3 +2953,35 @@ def test_groupby_transpose():
     second = data.groupby("x").sum()
 
     assert_identical(first, second.transpose(*first.dims))
+
+
+def test_groupby_multiple_bin_grouper_missing_groups():
+    from numpy import nan
+
+    ds = xr.Dataset(
+        {"foo": (("z"), np.arange(12))},
+        coords={"x": ("z", np.arange(12)), "y": ("z", np.arange(12))},
+    )
+
+    actual = ds.groupby(
+        x=BinGrouper(np.arange(0, 13, 4)), y=BinGrouper(bins=np.arange(0, 16, 2))
+    ).count()
+    expected = Dataset(
+        {
+            "foo": (
+                ("x_bins", "y_bins"),
+                np.array(
+                    [
+                        [2.0, 2.0, nan, nan, nan, nan, nan],
+                        [nan, nan, 2.0, 2.0, nan, nan, nan],
+                        [nan, nan, nan, nan, 2.0, 1.0, nan],
+                    ]
+                ),
+            )
+        },
+        coords={
+            "x_bins": ("x_bins", pd.IntervalIndex.from_breaks(np.arange(0, 13, 4))),
+            "y_bins": ("y_bins", pd.IntervalIndex.from_breaks(np.arange(0, 16, 2))),
+        },
+    )
+    assert_identical(actual, expected)
