From 01831a4507f4c1f5bd362041521d5fdb7274addf Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 21 Oct 2024 09:52:47 -0600 Subject: [PATCH] flox: Properly propagate multiindex (#9649) * flox: Properly propagate multiindex Closes #9648 * skip test on old pandas * small optimization * fix --- doc/whats-new.rst | 2 +- xarray/core/coordinates.py | 11 ++++++++++ xarray/core/groupby.py | 41 ++++++++++++++++-------------------- xarray/groupers.py | 13 +----------- xarray/tests/__init__.py | 2 +- xarray/tests/test_groupby.py | 21 ++++++++++++++++++ 6 files changed, 53 insertions(+), 37 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e290cd88485..c47b184faa9 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -63,7 +63,7 @@ Bug fixes the non-missing times could in theory be encoded with integers (:issue:`9488`, :pull:`9497`). By `Spencer Clark <https://github.com/spencerkclark>`_. -- Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`). +- Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`, :issue:`9648`). By `Deepak Cherian <https://github.com/dcherian>`_. - Fix the safe_chunks validation option on the to_zarr method (:issue:`5511`, :pull:`9559`).
By `Joseph Nowak <https://github.com/josephnowak>`_. diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 91ef9b6ccad..c4a082f22b7 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -1116,3 +1116,14 @@ def create_coords_with_default_indexes( new_coords = Coordinates._construct_direct(coords=variables, indexes=indexes) return new_coords + + +def _coordinates_from_variable(variable: Variable) -> Coordinates: + from xarray.core.indexes import create_default_index_implicit + + (name,) = variable.dims + new_index, index_vars = create_default_index_implicit(variable) + indexes = {k: new_index for k in index_vars} + new_vars = new_index.create_variables() + new_vars[name].attrs = variable.attrs + return Coordinates(new_vars, indexes) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index b09d7cf852c..5536c5d2e26 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -21,13 +21,13 @@ from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce from xarray.core.concat import concat -from xarray.core.coordinates import Coordinates +from xarray.core.coordinates import Coordinates, _coordinates_from_variable from xarray.core.formatting import format_array_flat from xarray.core.indexes import ( - PandasIndex, PandasMultiIndex, filter_indexes_from_coords, ) +from xarray.core.merge import merge_coords from xarray.core.options import OPTIONS, _get_keep_attrs from xarray.core.types import ( Dims, @@ -851,7 +851,6 @@ def _flox_reduce( from flox.xarray import xarray_reduce from xarray.core.dataset import Dataset - from xarray.groupers import BinGrouper obj = self._original_obj variables = ( @@ -901,13 +900,6 @@ def _flox_reduce( # set explicitly to avoid unnecessarily accumulating count kwargs["min_count"] = 0 - unindexed_dims: tuple[Hashable, ...]
= tuple( - grouper.name - for grouper in self.groupers - if isinstance(grouper.group, _DummyGroup) - and not isinstance(grouper.grouper, BinGrouper) - ) - parsed_dim: tuple[Hashable, ...] if isinstance(dim, str): parsed_dim = (dim,) @@ -963,26 +955,29 @@ def _flox_reduce( # we did end up reducing over dimension(s) that are # in the grouped variable group_dims = set(grouper.group.dims) - new_coords = {} + new_coords = [] + to_drop = [] if group_dims.issubset(set(parsed_dim)): - new_indexes = {} for grouper in self.groupers: output_index = grouper.full_index if isinstance(output_index, pd.RangeIndex): + # flox always assigns an index so we must drop it here if we don't need it. + to_drop.append(grouper.name) continue - name = grouper.name - new_coords[name] = IndexVariable( - dims=name, data=np.array(output_index), attrs=grouper.codes.attrs - ) - index_cls = ( - PandasIndex - if not isinstance(output_index, pd.MultiIndex) - else PandasMultiIndex + new_coords.append( + # Using IndexVariable here ensures we reconstruct PandasMultiIndex with + # all associated levels properly. 
+ _coordinates_from_variable( + IndexVariable( + dims=grouper.name, + data=output_index, + attrs=grouper.codes.attrs, + ) + ) ) - new_indexes[name] = index_cls(output_index, dim=name) result = result.assign_coords( - Coordinates(new_coords, new_indexes) - ).drop_vars(unindexed_dims) + Coordinates._construct_direct(*merge_coords(new_coords)) + ).drop_vars(to_drop) # broadcast any non-dim coord variables that don't # share all dimensions with the grouper diff --git a/xarray/groupers.py b/xarray/groupers.py index e4cb884e6de..996f86317b9 100644 --- a/xarray/groupers.py +++ b/xarray/groupers.py @@ -16,7 +16,7 @@ from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq from xarray.core import duck_array_ops -from xarray.core.coordinates import Coordinates +from xarray.core.coordinates import Coordinates, _coordinates_from_variable from xarray.core.dataarray import DataArray from xarray.core.groupby import T_Group, _DummyGroup from xarray.core.indexes import safe_cast_to_index @@ -42,17 +42,6 @@ RESAMPLE_DIM = "__resample_dim__" -def _coordinates_from_variable(variable: Variable) -> Coordinates: - from xarray.core.indexes import create_default_index_implicit - - (name,) = variable.dims - new_index, index_vars = create_default_index_implicit(variable) - indexes = {k: new_index for k in index_vars} - new_vars = new_index.create_variables() - new_vars[name].attrs = variable.attrs - return Coordinates(new_vars, indexes) - - @dataclass(init=False) class EncodedGroups: """ diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index bd7ec6297b9..a0ac8d51f95 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -135,7 +135,7 @@ def _importorskip( has_pint, requires_pint = _importorskip("pint") has_numexpr, requires_numexpr = _importorskip("numexpr") has_flox, requires_flox = _importorskip("flox") -has_pandas_ge_2_2, __ = _importorskip("pandas", "2.2") +has_pandas_ge_2_2, requires_pandas_ge_2_2 = _importorskip("pandas", "2.2") 
has_pandas_3, requires_pandas_3 = _importorskip("pandas", "3.0.0.dev0") diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index dc869cc3a34..3c321166619 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -35,6 +35,7 @@ requires_dask, requires_flox, requires_flox_0_9_12, + requires_pandas_ge_2_2, requires_scipy, ) @@ -145,6 +146,26 @@ def test_multi_index_groupby_sum() -> None: assert_equal(expected, actual) +@requires_pandas_ge_2_2 +def test_multi_index_propagation(): + # regression test for GH9648 + times = pd.date_range("2023-01-01", periods=4) + locations = ["A", "B"] + data = [[0.5, 0.7], [0.6, 0.5], [0.4, 0.6], [0.4, 0.9]] + + da = xr.DataArray( + data, dims=["time", "location"], coords={"time": times, "location": locations} + ) + da = da.stack(multiindex=["time", "location"]) + grouped = da.groupby("multiindex") + + with xr.set_options(use_flox=True): + actual = grouped.sum() + with xr.set_options(use_flox=False): + expected = grouped.first() + assert_identical(actual, expected) + + def test_groupby_da_datetime() -> None: # test groupby with a DataArray of dtype datetime for GH1132 # create test data