Skip to content

Commit

Permalink
Implement setitem syntax for .oindex and .vindex properties (#8845)
Browse files Browse the repository at this point in the history
* Implement setitem syntax for `.oindex` and `.vindex` properties

* Apply suggestions from code review

Co-authored-by: Deepak Cherian <[email protected]>

* use getter and setter properties instead of func_get and func_set methods

* delete unnecessary _indexing_array_and_key method

* Add tests for IndexCallable class

* fix bug/unnecessary code introduced in #8790

* add unit tests

---------

Co-authored-by: Deepak Cherian <[email protected]>
  • Loading branch information
andersy005 and dcherian authored Mar 19, 2024
1 parent c6c01b1 commit 79272c3
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 67 deletions.
171 changes: 114 additions & 57 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,18 +326,23 @@ def as_integer_slice(value):


class IndexCallable:
"""Provide getitem syntax for a callable object."""
"""Provide getitem and setitem syntax for callable objects."""

__slots__ = ("func",)
__slots__ = ("getter", "setter")

def __init__(self, func):
self.func = func
def __init__(self, getter, setter=None):
self.getter = getter
self.setter = setter

def __getitem__(self, key):
return self.func(key)
return self.getter(key)

def __setitem__(self, key, value):
raise NotImplementedError
if self.setter is None:
raise NotImplementedError(
"Setting values is not supported for this indexer."
)
self.setter(key, value)


class BasicIndexer(ExplicitIndexer):
Expand Down Expand Up @@ -486,10 +491,24 @@ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
return np.asarray(self.get_duck_array(), dtype=dtype)

def _oindex_get(self, key):
raise NotImplementedError("This method should be overridden")
raise NotImplementedError(
f"{self.__class__.__name__}._oindex_get method should be overridden"
)

def _vindex_get(self, key):
raise NotImplementedError("This method should be overridden")
raise NotImplementedError(
f"{self.__class__.__name__}._vindex_get method should be overridden"
)

def _oindex_set(self, key, value):
raise NotImplementedError(
f"{self.__class__.__name__}._oindex_set method should be overridden"
)

def _vindex_set(self, key, value):
raise NotImplementedError(
f"{self.__class__.__name__}._vindex_set method should be overridden"
)

def _check_and_raise_if_non_basic_indexer(self, key):
if isinstance(key, (VectorizedIndexer, OuterIndexer)):
Expand All @@ -500,11 +519,11 @@ def _check_and_raise_if_non_basic_indexer(self, key):

@property
def oindex(self):
return IndexCallable(self._oindex_get)
return IndexCallable(self._oindex_get, self._oindex_set)

@property
def vindex(self):
return IndexCallable(self._vindex_get)
return IndexCallable(self._vindex_get, self._vindex_set)


class ImplicitToExplicitIndexingAdapter(NDArrayMixin):
Expand Down Expand Up @@ -616,12 +635,18 @@ def __getitem__(self, indexer):
self._check_and_raise_if_non_basic_indexer(indexer)
return type(self)(self.array, self._updated_key(indexer))

def _vindex_set(self, key, value):
raise NotImplementedError(
"Lazy item assignment with the vectorized indexer is not yet "
"implemented. Load your data first by .load() or compute()."
)

def _oindex_set(self, key, value):
full_key = self._updated_key(key)
self.array.oindex[full_key] = value

def __setitem__(self, key, value):
if isinstance(key, VectorizedIndexer):
raise NotImplementedError(
"Lazy item assignment with the vectorized indexer is not yet "
"implemented. Load your data first by .load() or compute()."
)
self._check_and_raise_if_non_basic_indexer(key)
full_key = self._updated_key(key)
self.array[full_key] = value

Expand Down Expand Up @@ -657,7 +682,6 @@ def shape(self) -> tuple[int, ...]:
return np.broadcast(*self.key.tuple).shape

def get_duck_array(self):

if isinstance(self.array, ExplicitlyIndexedNDArrayMixin):
array = apply_indexer(self.array, self.key)
else:
Expand Down Expand Up @@ -739,8 +763,18 @@ def __getitem__(self, key):
def transpose(self, order):
return self.array.transpose(order)

def _vindex_set(self, key, value):
self._ensure_copied()
self.array.vindex[key] = value

def _oindex_set(self, key, value):
self._ensure_copied()
self.array.oindex[key] = value

def __setitem__(self, key, value):
self._check_and_raise_if_non_basic_indexer(key)
self._ensure_copied()

self.array[key] = value

def __deepcopy__(self, memo):
Expand Down Expand Up @@ -779,7 +813,14 @@ def __getitem__(self, key):
def transpose(self, order):
return self.array.transpose(order)

def _vindex_set(self, key, value):
self.array.vindex[key] = value

def _oindex_set(self, key, value):
self.array.oindex[key] = value

def __setitem__(self, key, value):
self._check_and_raise_if_non_basic_indexer(key)
self.array[key] = value


Expand Down Expand Up @@ -950,6 +991,16 @@ def apply_indexer(indexable, indexer):
return indexable[indexer]


def set_with_indexer(indexable, indexer, value):
"""Set values in an indexable object using an indexer."""
if isinstance(indexer, VectorizedIndexer):
indexable.vindex[indexer] = value
elif isinstance(indexer, OuterIndexer):
indexable.oindex[indexer] = value
else:
indexable[indexer] = value


def decompose_indexer(
indexer: ExplicitIndexer, shape: tuple[int, ...], indexing_support: IndexingSupport
) -> tuple[ExplicitIndexer, ExplicitIndexer]:
Expand Down Expand Up @@ -1399,24 +1450,6 @@ def __init__(self, array):
)
self.array = array

def _indexing_array_and_key(self, key):
if isinstance(key, OuterIndexer):
array = self.array
key = _outer_to_numpy_indexer(key, self.array.shape)
elif isinstance(key, VectorizedIndexer):
array = NumpyVIndexAdapter(self.array)
key = key.tuple
elif isinstance(key, BasicIndexer):
array = self.array
# We want 0d slices rather than scalars. This is achieved by
# appending an ellipsis (see
# https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes).
key = key.tuple + (Ellipsis,)
else:
raise TypeError(f"unexpected key type: {type(key)}")

return array, key

def transpose(self, order):
return self.array.transpose(order)

Expand All @@ -1430,22 +1463,43 @@ def _vindex_get(self, key):

def __getitem__(self, key):
self._check_and_raise_if_non_basic_indexer(key)
array, key = self._indexing_array_and_key(key)

array = self.array
# We want 0d slices rather than scalars. This is achieved by
# appending an ellipsis (see
# https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes).
key = key.tuple + (Ellipsis,)
return array[key]

def __setitem__(self, key, value):
array, key = self._indexing_array_and_key(key)
def _safe_setitem(self, array, key, value):
try:
array[key] = value
except ValueError:
except ValueError as exc:
# More informative exception if read-only view
if not array.flags.writeable and not array.flags.owndata:
raise ValueError(
"Assignment destination is a view. "
"Do you want to .copy() array first?"
)
else:
raise
raise exc

def _oindex_set(self, key, value):
key = _outer_to_numpy_indexer(key, self.array.shape)
self._safe_setitem(self.array, key, value)

def _vindex_set(self, key, value):
array = NumpyVIndexAdapter(self.array)
self._safe_setitem(array, key.tuple, value)

def __setitem__(self, key, value):
self._check_and_raise_if_non_basic_indexer(key)
array = self.array
# We want 0d slices rather than scalars. This is achieved by
# appending an ellipsis (see
# https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes).
key = key.tuple + (Ellipsis,)
self._safe_setitem(array, key, value)


class NdArrayLikeIndexingAdapter(NumpyIndexingAdapter):
Expand Down Expand Up @@ -1488,13 +1542,15 @@ def __getitem__(self, key):
self._check_and_raise_if_non_basic_indexer(key)
return self.array[key.tuple]

def _oindex_set(self, key, value):
self.array[key.tuple] = value

def _vindex_set(self, key, value):
raise TypeError("Vectorized indexing is not supported")

def __setitem__(self, key, value):
if isinstance(key, (BasicIndexer, OuterIndexer)):
self.array[key.tuple] = value
elif isinstance(key, VectorizedIndexer):
raise TypeError("Vectorized indexing is not supported")
else:
raise TypeError(f"Unrecognized indexer: {key}")
self._check_and_raise_if_non_basic_indexer(key)
self.array[key.tuple] = value

def transpose(self, order):
xp = self.array.__array_namespace__()
Expand Down Expand Up @@ -1530,19 +1586,20 @@ def __getitem__(self, key):
self._check_and_raise_if_non_basic_indexer(key)
return self.array[key.tuple]

def _oindex_set(self, key, value):
num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in key.tuple)
if num_non_slices > 1:
raise NotImplementedError(
"xarray can't set arrays with multiple " "array indices to dask yet."
)
self.array[key.tuple] = value

def _vindex_set(self, key, value):
self.array.vindex[key.tuple] = value

def __setitem__(self, key, value):
if isinstance(key, BasicIndexer):
self.array[key.tuple] = value
elif isinstance(key, VectorizedIndexer):
self.array.vindex[key.tuple] = value
elif isinstance(key, OuterIndexer):
num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in key.tuple)
if num_non_slices > 1:
raise NotImplementedError(
"xarray can't set arrays with multiple "
"array indices to dask yet."
)
self.array[key.tuple] = value
self._check_and_raise_if_non_basic_indexer(key)
self.array[key.tuple] = value

def transpose(self, order):
return self.array.transpose(order)
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,7 @@ def __setitem__(self, key, value):
value = np.moveaxis(value, new_order, range(len(new_order)))

indexable = as_indexable(self._data)
indexable[index_tuple] = value
indexing.set_with_indexer(indexable, index_tuple, value)

@property
def encoding(self) -> dict[Any, Any]:
Expand Down
68 changes: 59 additions & 9 deletions xarray/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,28 @@
B = IndexerMaker(indexing.BasicIndexer)


class TestIndexCallable:
def test_getitem(self):
def getter(key):
return key * 2

indexer = indexing.IndexCallable(getter)
assert indexer[3] == 6
assert indexer[0] == 0
assert indexer[-1] == -2

def test_setitem(self):
def getter(key):
return key * 2

def setter(key, value):
raise NotImplementedError("Setter not implemented")

indexer = indexing.IndexCallable(getter, setter)
with pytest.raises(NotImplementedError):
indexer[3] = 6


class TestIndexers:
def set_to_zero(self, x, i):
x = x.copy()
Expand Down Expand Up @@ -361,15 +383,8 @@ def test_vectorized_lazily_indexed_array(self) -> None:

def check_indexing(v_eager, v_lazy, indexers):
for indexer in indexers:
if isinstance(indexer, indexing.VectorizedIndexer):
actual = v_lazy.vindex[indexer]
expected = v_eager.vindex[indexer]
elif isinstance(indexer, indexing.OuterIndexer):
actual = v_lazy.oindex[indexer]
expected = v_eager.oindex[indexer]
else:
actual = v_lazy[indexer]
expected = v_eager[indexer]
actual = v_lazy[indexer]
expected = v_eager[indexer]
assert expected.shape == actual.shape
assert isinstance(
actual._data,
Expand Down Expand Up @@ -406,6 +421,41 @@ def check_indexing(v_eager, v_lazy, indexers):
]
check_indexing(v_eager, v_lazy, indexers)

def test_lazily_indexed_array_vindex_setitem(self) -> None:

lazy = indexing.LazilyIndexedArray(np.random.rand(10, 20, 30))

# vectorized indexing
indexer = indexing.VectorizedIndexer(
(np.array([0, 1]), np.array([0, 1]), slice(None, None, None))
)
with pytest.raises(
NotImplementedError,
match=r"Lazy item assignment with the vectorized indexer is not yet",
):
lazy.vindex[indexer] = 0

@pytest.mark.parametrize(
"indexer_class, key, value",
[
(indexing.OuterIndexer, (0, 1, slice(None, None, None)), 10),
(indexing.BasicIndexer, (0, 1, slice(None, None, None)), 10),
],
)
def test_lazily_indexed_array_setitem(self, indexer_class, key, value) -> None:
original = np.random.rand(10, 20, 30)
x = indexing.NumpyIndexingAdapter(original)
lazy = indexing.LazilyIndexedArray(x)

if indexer_class is indexing.BasicIndexer:
indexer = indexer_class(key)
lazy[indexer] = value
elif indexer_class is indexing.OuterIndexer:
indexer = indexer_class(key)
lazy.oindex[indexer] = value

assert_array_equal(original[key], value)


class TestCopyOnWriteArray:
def test_setitem(self) -> None:
Expand Down

0 comments on commit 79272c3

Please sign in to comment.