Skip to content

Commit

Permalink
Extend padding functionalities (#9353)
Browse files Browse the repository at this point in the history
* add functionality to pad dataset data variables with unique constant values

* clean up implementation of variable specific padding for dataset. Add tests

* more expressive docsting and symplefy type signature with alias in dataset pad func. enforce number values to be converted to tuples for all in `_pad_options_dim_to_index`. make variable pad funtion consistent with dataset. extend tests

* fix typing

* add terms to conf.py, make docstrings more accurate, expand tests for dataset pad function

* filter constant value types without mutating input map

* add todo to change default padding for missing variables in constant_values

* add changes to whats new

---------

Co-authored-by: tiago <[email protected]>
Co-authored-by: Deepak Cherian <[email protected]>
  • Loading branch information
3 people authored Aug 21, 2024
1 parent a56a407 commit ed5900b
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 30 deletions.
3 changes: 3 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@
"matplotlib colormap name": ":doc:`matplotlib colormap name <matplotlib:gallery/color/colormap_reference>`",
"matplotlib axes object": ":py:class:`matplotlib axes object <matplotlib.axes.Axes>`",
"colormap": ":py:class:`colormap <matplotlib.colors.Colormap>`",
# xarray terms
"dim name": ":term:`dimension name <name>`",
"var name": ":term:`variable name <name>`",
# objects without namespace: xarray
"DataArray": "~xarray.DataArray",
"Dataset": "~xarray.Dataset",
Expand Down
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ New Features
~~~~~~~~~~~~
- Make chunk manager an option in ``set_options`` (:pull:`9362`).
By `Tom White <https://github.com/tomwhite>`_.
- Allow data variable specific ``constant_values`` in the dataset ``pad`` function (:pull:`9353``).
By `Tiago Sanona <https://github.com/tsanona>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
32 changes: 22 additions & 10 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@
ReindexMethodOptions,
SideOptions,
T_ChunkDimFreq,
T_DatasetPadConstantValues,
T_Xarray,
)
from xarray.core.weighted import DatasetWeighted
Expand Down Expand Up @@ -9153,9 +9154,7 @@ def pad(
stat_length: (
int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None
) = None,
constant_values: (
float | tuple[float, float] | Mapping[Any, tuple[float, float]] | None
) = None,
constant_values: T_DatasetPadConstantValues | None = None,
end_values: int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None = None,
reflect_type: PadReflectOptions = None,
keep_attrs: bool | None = None,
Expand Down Expand Up @@ -9211,17 +9210,19 @@ def pad(
(stat_length,) or int is a shortcut for before = after = statistic
length for all axes.
Default is ``None``, to use the entire axis.
constant_values : scalar, tuple or mapping of hashable to tuple, default: 0
Used in 'constant'. The values to set the padded values for each
axis.
constant_values : scalar, tuple, mapping of dim name to scalar or tuple, or \
mapping of var name to scalar, tuple or to mapping of dim name to scalar or tuple, default: None
Used in 'constant'. The values to set the padded values for each data variable / axis.
``{var_1: {dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}, ...
var_M: (before, after)}`` unique pad constants per data variable.
``{dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}`` unique
pad constants along each dimension.
``((before, after),)`` yields same before and after constants for each
dimension.
``(constant,)`` or ``constant`` is a shortcut for ``before = after = constant`` for
all dimensions.
Default is 0.
end_values : scalar, tuple or mapping of hashable to tuple, default: 0
Default is ``None``, pads with ``np.nan``.
end_values : scalar, tuple or mapping of hashable to tuple, default: None
Used in 'linear_ramp'. The values used for the ending value of the
linear_ramp and that will form the edge of the padded array.
``{dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}`` unique
Expand All @@ -9230,7 +9231,7 @@ def pad(
axis.
``(constant,)`` or ``constant`` is a shortcut for ``before = after = constant`` for
all axes.
Default is 0.
Default is None.
reflect_type : {"even", "odd", None}, optional
Used in "reflect", and "symmetric". The "even" style is the
default with an unaltered reflection around the edge value. For
Expand Down Expand Up @@ -9304,11 +9305,22 @@ def pad(
if not var_pad_width:
variables[name] = var
elif name in self.data_vars:
if utils.is_dict_like(constant_values):
if name in constant_values.keys():
filtered_constant_values = constant_values[name]
elif not set(var.dims).isdisjoint(constant_values.keys()):
filtered_constant_values = {
k: v for k, v in constant_values.items() if k in var.dims
}
else:
filtered_constant_values = 0 # TODO: https://github.com/pydata/xarray/pull/9353#discussion_r1724018352
else:
filtered_constant_values = constant_values
variables[name] = var.pad(
pad_width=var_pad_width,
mode=mode,
stat_length=stat_length,
constant_values=constant_values,
constant_values=filtered_constant_values,
end_values=end_values,
reflect_type=reflect_type,
keep_attrs=keep_attrs,
Expand Down
5 changes: 5 additions & 0 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,11 @@ def copy(
"symmetric",
"wrap",
]
T_PadConstantValues = float | tuple[float, float]
T_VarPadConstantValues = T_PadConstantValues | Mapping[Any, T_PadConstantValues]
T_DatasetPadConstantValues = (
T_VarPadConstantValues | Mapping[Any, T_VarPadConstantValues]
)
PadReflectOptions = Literal["even", "odd", None]

CFCalendar = Literal[
Expand Down
18 changes: 9 additions & 9 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
Self,
T_Chunks,
T_DuckArray,
T_VarPadConstantValues,
)
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint

Expand Down Expand Up @@ -1121,9 +1122,14 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs):

def _pad_options_dim_to_index(
self,
pad_option: Mapping[Any, int | tuple[int, int]],
pad_option: Mapping[Any, int | float | tuple[int, int] | tuple[float, float]],
fill_with_shape=False,
):
# change number values to a tuple of two of those values
for k, v in pad_option.items():
if isinstance(v, numbers.Number):
pad_option[k] = (v, v)

if fill_with_shape:
return [
(n, n) if d not in pad_option else pad_option[d]
Expand All @@ -1138,9 +1144,7 @@ def pad(
stat_length: (
int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None
) = None,
constant_values: (
float | tuple[float, float] | Mapping[Any, tuple[float, float]] | None
) = None,
constant_values: T_VarPadConstantValues | None = None,
end_values: int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None = None,
reflect_type: PadReflectOptions = None,
keep_attrs: bool | None = None,
Expand All @@ -1160,7 +1164,7 @@ def pad(
stat_length : int, tuple or mapping of hashable to tuple
Used in 'maximum', 'mean', 'median', and 'minimum'. Number of
values at edge of each axis used to calculate the statistic value.
constant_values : scalar, tuple or mapping of hashable to tuple
constant_values : scalar, tuple or mapping of hashable to scalar or tuple
Used in 'constant'. The values to set the padded values for each
axis.
end_values : scalar, tuple or mapping of hashable to tuple
Expand Down Expand Up @@ -1207,10 +1211,6 @@ def pad(
if stat_length is None and mode in ["maximum", "mean", "median", "minimum"]:
stat_length = [(n, n) for n in self.data.shape] # type: ignore[assignment]

# change integer values to a tuple of two of those values and change pad_width to index
for k, v in pad_width.items():
if isinstance(v, numbers.Number):
pad_width[k] = (v, v)
pad_width_by_index = self._pad_options_dim_to_index(pad_width)

# create pad_options_kwargs, numpy/dask requires only relevant kwargs to be nonempty
Expand Down
84 changes: 73 additions & 11 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6704,18 +6704,80 @@ def test_polyfit_warnings(self) -> None:
ds.var1.polyfit("dim2", 10, full=True)
assert len(ws) == 1

def test_pad(self) -> None:
ds = create_test_data(seed=1)
padded = ds.pad(dim2=(1, 1), constant_values=42)

assert padded["dim2"].shape == (11,)
assert padded["var1"].shape == (8, 11)
assert padded["var2"].shape == (8, 11)
assert padded["var3"].shape == (10, 8)
assert dict(padded.sizes) == {"dim1": 8, "dim2": 11, "dim3": 10, "time": 20}
@staticmethod
def _test_data_var_interior(
original_data_var, padded_data_var, padded_dim_name, expected_pad_values
):
np.testing.assert_equal(
np.unique(padded_data_var.isel({padded_dim_name: [0, -1]})),
expected_pad_values,
)
np.testing.assert_array_equal(
padded_data_var.isel({padded_dim_name: slice(1, -1)}), original_data_var
)

np.testing.assert_equal(padded["var1"].isel(dim2=[0, -1]).data, 42)
np.testing.assert_equal(padded["dim2"][[0, -1]].data, np.nan)
@pytest.mark.parametrize("padded_dim_name", ["dim1", "dim2", "dim3", "time"])
@pytest.mark.parametrize(
["constant_values"],
[
pytest.param(None, id="default"),
pytest.param(42, id="scalar"),
pytest.param((42, 43), id="tuple"),
pytest.param({"dim1": 42, "dim2": 43}, id="per dim scalar"),
pytest.param({"dim1": (42, 43), "dim2": (43, 44)}, id="per dim tuple"),
pytest.param({"var1": 42, "var2": (42, 43)}, id="per var"),
pytest.param({"var1": 42, "dim1": (42, 43)}, id="mixed"),
],
)
def test_pad(self, padded_dim_name, constant_values) -> None:
ds = create_test_data(seed=1)
padded = ds.pad({padded_dim_name: (1, 1)}, constant_values=constant_values)

# test padded dim values and size
for ds_dim_name, ds_dim in ds.sizes.items():
if ds_dim_name == padded_dim_name:
np.testing.assert_equal(padded.sizes[ds_dim_name], ds_dim + 2)
if ds_dim_name in padded.coords:
assert padded[ds_dim_name][[0, -1]].isnull().all()
else:
np.testing.assert_equal(padded.sizes[ds_dim_name], ds_dim)

# check if coord "numbers" with dimention dim3 is paded correctly
if padded_dim_name == "dim3":
assert padded["numbers"][[0, -1]].isnull().all()
# twarning: passes but dtype changes from int to float
np.testing.assert_array_equal(padded["numbers"][1:-1], ds["numbers"])

# test if data_vars are paded with correct values
for data_var_name, data_var in padded.data_vars.items():
if padded_dim_name in data_var.dims:
if utils.is_dict_like(constant_values):
if (
expected := constant_values.get(data_var_name, None)
) is not None:
self._test_data_var_interior(
ds[data_var_name], data_var, padded_dim_name, expected
)
elif (
expected := constant_values.get(padded_dim_name, None)
) is not None:
self._test_data_var_interior(
ds[data_var_name], data_var, padded_dim_name, expected
)
else:
self._test_data_var_interior(
ds[data_var_name], data_var, padded_dim_name, 0
)
elif constant_values:
self._test_data_var_interior(
ds[data_var_name], data_var, padded_dim_name, constant_values
)
else:
self._test_data_var_interior(
ds[data_var_name], data_var, padded_dim_name, np.nan
)
else:
assert_array_equal(data_var, ds[data_var_name])

@pytest.mark.parametrize(
["keep_attrs", "attrs", "expected"],
Expand Down

0 comments on commit ed5900b

Please sign in to comment.