From ed5900bb9f46389d0cf04015cf3f7126d87253c0 Mon Sep 17 00:00:00 2001 From: Tiago Sanona <40792244+tsanona@users.noreply.github.com> Date: Wed, 21 Aug 2024 22:47:35 +0200 Subject: [PATCH] Extend padding functionalities (#9353) * add functionality to pad dataset data variables with unique constant values * clean up implementation of variable specific padding for dataset. Add tests * more expressive docstring and simplify type signature with alias in dataset pad func. enforce number values to be converted to tuples for all in `_pad_options_dim_to_index`. make variable pad function consistent with dataset. extend tests * fix typing * add terms to conf.py, make docstrings more accurate, expand tests for dataset pad function * filter constant value types without mutating input map * add todo to change default padding for missing variables in constant_values * add changes to whats new --------- Co-authored-by: tiago Co-authored-by: Deepak Cherian --- doc/conf.py | 3 ++ doc/whats-new.rst | 2 + xarray/core/dataset.py | 32 +++++++++----- xarray/core/types.py | 5 +++ xarray/core/variable.py | 18 ++++---- xarray/tests/test_dataset.py | 84 +++++++++++++++++++++++++++++++----- 6 files changed, 114 insertions(+), 30 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 4f1fc6751d2..93a0e459a33 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -153,6 +153,9 @@ "matplotlib colormap name": ":doc:`matplotlib colormap name `", "matplotlib axes object": ":py:class:`matplotlib axes object `", "colormap": ":py:class:`colormap `", + # xarray terms + "dim name": ":term:`dimension name `", + "var name": ":term:`variable name `", # objects without namespace: xarray "DataArray": "~xarray.DataArray", "Dataset": "~xarray.Dataset", diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e4ffc5ca799..2cf2d5928bf 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -24,6 +24,8 @@ New Features ~~~~~~~~~~~~ - Make chunk manager an option in ``set_options`` (:pull:`9362`). By `Tom White `_. 
+- Allow data variable specific ``constant_values`` in the dataset ``pad`` function (:pull:`9353`). + By `Tiago Sanona `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 0b9a085cebc..dbc00a03025 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -163,6 +163,7 @@ ReindexMethodOptions, SideOptions, T_ChunkDimFreq, + T_DatasetPadConstantValues, T_Xarray, ) from xarray.core.weighted import DatasetWeighted @@ -9153,9 +9154,7 @@ def pad( stat_length: ( int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None ) = None, - constant_values: ( - float | tuple[float, float] | Mapping[Any, tuple[float, float]] | None - ) = None, + constant_values: T_DatasetPadConstantValues | None = None, end_values: int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None = None, reflect_type: PadReflectOptions = None, keep_attrs: bool | None = None, @@ -9211,17 +9210,19 @@ def pad( (stat_length,) or int is a shortcut for before = after = statistic length for all axes. Default is ``None``, to use the entire axis. - constant_values : scalar, tuple or mapping of hashable to tuple, default: 0 - Used in 'constant'. The values to set the padded values for each - axis. + constant_values : scalar, tuple, mapping of dim name to scalar or tuple, or \ + mapping of var name to scalar, tuple or to mapping of dim name to scalar or tuple, default: None + Used in 'constant'. The values to set the padded values for each data variable / axis. + ``{var_1: {dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}, ... + var_M: (before, after)}`` unique pad constants per data variable. ``{dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}`` unique pad constants along each dimension. ``((before, after),)`` yields same before and after constants for each dimension. ``(constant,)`` or ``constant`` is a shortcut for ``before = after = constant`` for all dimensions. - Default is 0. 
- end_values : scalar, tuple or mapping of hashable to tuple, default: 0 + Default is ``None``, pads with ``np.nan``. + end_values : scalar, tuple or mapping of hashable to tuple, default: None Used in 'linear_ramp'. The values used for the ending value of the linear_ramp and that will form the edge of the padded array. ``{dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}`` unique @@ -9230,7 +9231,7 @@ def pad( axis. ``(constant,)`` or ``constant`` is a shortcut for ``before = after = constant`` for all axes. - Default is 0. + Default is None. reflect_type : {"even", "odd", None}, optional Used in "reflect", and "symmetric". The "even" style is the default with an unaltered reflection around the edge value. For @@ -9304,11 +9305,22 @@ def pad( if not var_pad_width: variables[name] = var elif name in self.data_vars: + if utils.is_dict_like(constant_values): + if name in constant_values.keys(): + filtered_constant_values = constant_values[name] + elif not set(var.dims).isdisjoint(constant_values.keys()): + filtered_constant_values = { + k: v for k, v in constant_values.items() if k in var.dims + } + else: + filtered_constant_values = 0 # TODO: https://github.com/pydata/xarray/pull/9353#discussion_r1724018352 + else: + filtered_constant_values = constant_values variables[name] = var.pad( pad_width=var_pad_width, mode=mode, stat_length=stat_length, - constant_values=constant_values, + constant_values=filtered_constant_values, end_values=end_values, reflect_type=reflect_type, keep_attrs=keep_attrs, diff --git a/xarray/core/types.py b/xarray/core/types.py index 0e432283ba9..3eb97f86c4a 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -243,6 +243,11 @@ def copy( "symmetric", "wrap", ] +T_PadConstantValues = float | tuple[float, float] +T_VarPadConstantValues = T_PadConstantValues | Mapping[Any, T_PadConstantValues] +T_DatasetPadConstantValues = ( + T_VarPadConstantValues | Mapping[Any, T_VarPadConstantValues] +) PadReflectOptions = Literal["even", 
"odd", None] CFCalendar = Literal[ diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 3cd8e4acbd5..a74fb4d8ce9 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -65,6 +65,7 @@ Self, T_Chunks, T_DuckArray, + T_VarPadConstantValues, ) from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint @@ -1121,9 +1122,14 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): def _pad_options_dim_to_index( self, - pad_option: Mapping[Any, int | tuple[int, int]], + pad_option: Mapping[Any, int | float | tuple[int, int] | tuple[float, float]], fill_with_shape=False, ): + # change number values to a tuple of two of those values + for k, v in pad_option.items(): + if isinstance(v, numbers.Number): + pad_option[k] = (v, v) + if fill_with_shape: return [ (n, n) if d not in pad_option else pad_option[d] @@ -1138,9 +1144,7 @@ def pad( stat_length: ( int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None ) = None, - constant_values: ( - float | tuple[float, float] | Mapping[Any, tuple[float, float]] | None - ) = None, + constant_values: T_VarPadConstantValues | None = None, end_values: int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None = None, reflect_type: PadReflectOptions = None, keep_attrs: bool | None = None, @@ -1160,7 +1164,7 @@ def pad( stat_length : int, tuple or mapping of hashable to tuple Used in 'maximum', 'mean', 'median', and 'minimum'. Number of values at edge of each axis used to calculate the statistic value. - constant_values : scalar, tuple or mapping of hashable to tuple + constant_values : scalar, tuple or mapping of hashable to scalar or tuple Used in 'constant'. The values to set the padded values for each axis. 
end_values : scalar, tuple or mapping of hashable to tuple @@ -1207,10 +1211,6 @@ def pad( if stat_length is None and mode in ["maximum", "mean", "median", "minimum"]: stat_length = [(n, n) for n in self.data.shape] # type: ignore[assignment] - # change integer values to a tuple of two of those values and change pad_width to index - for k, v in pad_width.items(): - if isinstance(v, numbers.Number): - pad_width[k] = (v, v) pad_width_by_index = self._pad_options_dim_to_index(pad_width) # create pad_options_kwargs, numpy/dask requires only relevant kwargs to be nonempty diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index fb3d487f2ef..f2e712e334c 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6704,18 +6704,80 @@ def test_polyfit_warnings(self) -> None: ds.var1.polyfit("dim2", 10, full=True) assert len(ws) == 1 - def test_pad(self) -> None: - ds = create_test_data(seed=1) - padded = ds.pad(dim2=(1, 1), constant_values=42) - - assert padded["dim2"].shape == (11,) - assert padded["var1"].shape == (8, 11) - assert padded["var2"].shape == (8, 11) - assert padded["var3"].shape == (10, 8) - assert dict(padded.sizes) == {"dim1": 8, "dim2": 11, "dim3": 10, "time": 20} + @staticmethod + def _test_data_var_interior( + original_data_var, padded_data_var, padded_dim_name, expected_pad_values + ): + np.testing.assert_equal( + np.unique(padded_data_var.isel({padded_dim_name: [0, -1]})), + expected_pad_values, + ) + np.testing.assert_array_equal( + padded_data_var.isel({padded_dim_name: slice(1, -1)}), original_data_var + ) - np.testing.assert_equal(padded["var1"].isel(dim2=[0, -1]).data, 42) - np.testing.assert_equal(padded["dim2"][[0, -1]].data, np.nan) + @pytest.mark.parametrize("padded_dim_name", ["dim1", "dim2", "dim3", "time"]) + @pytest.mark.parametrize( + ["constant_values"], + [ + pytest.param(None, id="default"), + pytest.param(42, id="scalar"), + pytest.param((42, 43), id="tuple"), + pytest.param({"dim1": 42, 
"dim2": 43}, id="per dim scalar"), + pytest.param({"dim1": (42, 43), "dim2": (43, 44)}, id="per dim tuple"), + pytest.param({"var1": 42, "var2": (42, 43)}, id="per var"), + pytest.param({"var1": 42, "dim1": (42, 43)}, id="mixed"), + ], + ) + def test_pad(self, padded_dim_name, constant_values) -> None: + ds = create_test_data(seed=1) + padded = ds.pad({padded_dim_name: (1, 1)}, constant_values=constant_values) + + # test padded dim values and size + for ds_dim_name, ds_dim in ds.sizes.items(): + if ds_dim_name == padded_dim_name: + np.testing.assert_equal(padded.sizes[ds_dim_name], ds_dim + 2) + if ds_dim_name in padded.coords: + assert padded[ds_dim_name][[0, -1]].isnull().all() + else: + np.testing.assert_equal(padded.sizes[ds_dim_name], ds_dim) + + # check if coord "numbers" with dimention dim3 is paded correctly + if padded_dim_name == "dim3": + assert padded["numbers"][[0, -1]].isnull().all() + # twarning: passes but dtype changes from int to float + np.testing.assert_array_equal(padded["numbers"][1:-1], ds["numbers"]) + + # test if data_vars are paded with correct values + for data_var_name, data_var in padded.data_vars.items(): + if padded_dim_name in data_var.dims: + if utils.is_dict_like(constant_values): + if ( + expected := constant_values.get(data_var_name, None) + ) is not None: + self._test_data_var_interior( + ds[data_var_name], data_var, padded_dim_name, expected + ) + elif ( + expected := constant_values.get(padded_dim_name, None) + ) is not None: + self._test_data_var_interior( + ds[data_var_name], data_var, padded_dim_name, expected + ) + else: + self._test_data_var_interior( + ds[data_var_name], data_var, padded_dim_name, 0 + ) + elif constant_values: + self._test_data_var_interior( + ds[data_var_name], data_var, padded_dim_name, constant_values + ) + else: + self._test_data_var_interior( + ds[data_var_name], data_var, padded_dim_name, np.nan + ) + else: + assert_array_equal(data_var, ds[data_var_name]) @pytest.mark.parametrize( ["keep_attrs", 
"attrs", "expected"],