Skip to content

Commit

Permalink
Expand use of .oindex and .vindex (#8790)
Browse files Browse the repository at this point in the history
* refactor __getitem__() by removing vectorized and orthogonal indexing logic from it

* implement explicit routing of vectorized and outer indexers

* Add VectorizedIndexer and OuterIndexer to ScipyArrayWrapper's __getitem__ method

* Refactor indexing in LazilyIndexedArray and LazilyVectorizedIndexedArray

* Add vindex and oindex methods to StackedBytesArray

* handle explicitlyindexed arrays

* Refactor LazilyIndexedArray and LazilyVectorizedIndexedArray classes

* Remove TODO comments in indexing.py

* use indexing.explicit_indexing_adapter() in scipy backend

* update docstring

* fix docstring

* Add _oindex_get and _vindex_get methods to NativeEndiannessArray and BoolTypeArray

* Update indexing support in ScipyArrayWrapper

* Update xarray/tests/test_indexing.py

Co-authored-by: Deepak Cherian <[email protected]>

* Fix assert statement in test_decompose_indexers

* add comments clarifying why the else branch is needed

* Add _oindex_get and _vindex_get methods to _ElementwiseFunctionArray

* update whats-new

* Refactor apply_indexer function in indexing.py and variable.py for code reuse

* cleanup

---------

Co-authored-by: Deepak Cherian <[email protected]>
Co-authored-by: Deepak Cherian <[email protected]>
  • Loading branch information
3 people authored Mar 15, 2024
1 parent 3dcfa31 commit c9d3084
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 55 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ New Features
- Add the ``.vindex`` property to Explicitly Indexed Arrays for vectorized indexing functionality. (:issue:`8238`, :pull:`8780`)
By `Anderson Banihirwe <https://github.com/andersy005>`_.

- Expand use of ``.oindex`` and ``.vindex`` properties. (:pull:`8790`)
By `Anderson Banihirwe <https://github.com/andersy005>`_ and `Deepak Cherian <https://github.com/dcherian>`_.

Breaking changes
~~~~~~~~~~~~~~~~

Expand Down
11 changes: 9 additions & 2 deletions xarray/backends/scipy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
is_valid_nc3_name,
)
from xarray.backends.store import StoreBackendEntrypoint
from xarray.core.indexing import NumpyIndexingAdapter
from xarray.core import indexing
from xarray.core.utils import (
Frozen,
FrozenDict,
Expand Down Expand Up @@ -63,8 +63,15 @@ def get_variable(self, needs_lock=True):
ds = self.datastore._manager.acquire(needs_lock)
return ds.variables[self.variable_name]

# Raw indexing helper: fetch the variable's data under the datastore lock and
# index it with `key`.
# get_variable() is called with needs_lock=False -- presumably because the
# lock is already held by the `with` statement here; confirm against
# get_variable()/the file manager.
def _getitem(self, key):
with self.datastore.lock:
data = self.get_variable(needs_lock=False).data
return data[key]

def __getitem__(self, key):
data = NumpyIndexingAdapter(self.get_variable().data)[key]
data = indexing.explicit_indexing_adapter(
key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
)
# Copy data if the source file is mmapped. This makes things consistent
# with the netCDF4 library by ensuring we can safely read arrays even
# after closing associated files.
Expand Down
6 changes: 6 additions & 0 deletions xarray/coding/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ def shape(self) -> tuple[int, ...]:
# Render as "<ClassName>(<repr of the wrapped array>)".
def __repr__(self):
return f"{type(self).__name__}({self.array!r})"

# Vectorized-indexing path: index the wrapped array through its .vindex
# accessor, then pass the result through _numpy_char_to_bytes.
def _vindex_get(self, key):
return _numpy_char_to_bytes(self.array.vindex[key])

# Outer-indexing path: index the wrapped array through its .oindex
# accessor, then pass the result through _numpy_char_to_bytes.
def _oindex_get(self, key):
return _numpy_char_to_bytes(self.array.oindex[key])

def __getitem__(self, key):
# require slicing the last dimension completely
key = type(key)(indexing.expanded_indexer(key.tuple, self.array.ndim))
Expand Down
18 changes: 18 additions & 0 deletions xarray/coding/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ def __init__(self, array, func: Callable, dtype: np.typing.DTypeLike):
def dtype(self) -> np.dtype:
return np.dtype(self._dtype)

# Outer-indexing path: index the wrapped array via .oindex and wrap the
# result in a new instance of this class with the same func and dtype.
def _oindex_get(self, key):
return type(self)(self.array.oindex[key], self.func, self.dtype)

# Vectorized-indexing path: index the wrapped array via .vindex and wrap the
# result in a new instance of this class with the same func and dtype.
def _vindex_get(self, key):
return type(self)(self.array.vindex[key], self.func, self.dtype)

# Plain [] indexing: index the wrapped array with `key` and wrap the result
# in a new instance of this class with the same func and dtype.
def __getitem__(self, key):
return type(self)(self.array[key], self.func, self.dtype)

Expand Down Expand Up @@ -109,6 +115,12 @@ def __init__(self, array) -> None:
def dtype(self) -> np.dtype:
return np.dtype(self.array.dtype.kind + str(self.array.dtype.itemsize))

# Outer-indexing path: index the wrapped array via .oindex and convert the
# result to a NumPy array of this wrapper's dtype.
def _oindex_get(self, key):
return np.asarray(self.array.oindex[key], dtype=self.dtype)

# Vectorized-indexing path: index the wrapped array via .vindex and convert
# the result to a NumPy array of this wrapper's dtype.
def _vindex_get(self, key):
return np.asarray(self.array.vindex[key], dtype=self.dtype)

# Plain [] indexing: index the wrapped array with `key` and convert the
# result to a NumPy array of this wrapper's dtype.
def __getitem__(self, key) -> np.ndarray:
return np.asarray(self.array[key], dtype=self.dtype)

Expand Down Expand Up @@ -141,6 +153,12 @@ def __init__(self, array) -> None:
def dtype(self) -> np.dtype:
return np.dtype("bool")

# Outer-indexing path: index the wrapped array via .oindex and convert the
# result to a NumPy array of this wrapper's dtype.
def _oindex_get(self, key):
return np.asarray(self.array.oindex[key], dtype=self.dtype)

# Vectorized-indexing path: index the wrapped array via .vindex and convert
# the result to a NumPy array of this wrapper's dtype.
def _vindex_get(self, key):
return np.asarray(self.array.vindex[key], dtype=self.dtype)

# Plain [] indexing: index the wrapped array with `key` and convert the
# result to a NumPy array of this wrapper's dtype.
def __getitem__(self, key) -> np.ndarray:
return np.asarray(self.array[key], dtype=self.dtype)

Expand Down
88 changes: 56 additions & 32 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,13 @@ def _oindex_get(self, key):
# Hook for vectorized indexing. This implementation only raises
# NotImplementedError; per its message, it is meant to be overridden.
def _vindex_get(self, key):
raise NotImplementedError("This method should be overridden")

# Guard used by plain [] indexing: raise TypeError when `key` is a
# VectorizedIndexer or an OuterIndexer, directing the caller to the .vindex /
# .oindex properties instead. Any other key falls through and the method
# returns None.
# NOTE(review): the message opens with "Vectorized indexing with ..." even
# though the check is reached from basic indexing and covers outer indexers
# too -- the wording may deserve a follow-up, but callers/tests may match on it.
def _check_and_raise_if_non_basic_indexer(self, key):
if isinstance(key, (VectorizedIndexer, OuterIndexer)):
raise TypeError(
"Vectorized indexing with vectorized or outer indexers is not supported. "
"Please use .vindex and .oindex properties to index the array."
)

# Accessor for outer indexing: returns an IndexCallable built around
# self._oindex_get. Presumably subscripting the returned object
# (obj.oindex[key]) forwards to _oindex_get -- confirm against IndexCallable.
@property
def oindex(self):
return IndexCallable(self._oindex_get)
Expand All @@ -517,7 +524,10 @@ def get_duck_array(self):

def __getitem__(self, key):
key = expanded_indexer(key, self.ndim)
result = self.array[self.indexer_cls(key)]
indexer = self.indexer_cls(key)

result = apply_indexer(self.array, indexer)

if isinstance(result, ExplicitlyIndexed):
return type(self)(result, self.indexer_cls)
else:
Expand Down Expand Up @@ -577,7 +587,13 @@ def shape(self) -> tuple[int, ...]:
return tuple(shape)

def get_duck_array(self):
array = self.array[self.key]
if isinstance(self.array, ExplicitlyIndexedNDArrayMixin):
array = apply_indexer(self.array, self.key)
else:
# If the array is not an ExplicitlyIndexedNDArrayMixin,
# it may wrap a BackendArray so use its __getitem__
array = self.array[self.key]

# self.array[self.key] is now a numpy array when
# self.array is a BackendArray subclass
# and self.key is BasicIndexer((slice(None, None, None),))
Expand All @@ -594,12 +610,10 @@ def _oindex_get(self, indexer):

def _vindex_get(self, indexer):
array = LazilyVectorizedIndexedArray(self.array, self.key)
return array[indexer]
return array.vindex[indexer]

def __getitem__(self, indexer):
if isinstance(indexer, VectorizedIndexer):
array = LazilyVectorizedIndexedArray(self.array, self.key)
return array[indexer]
self._check_and_raise_if_non_basic_indexer(indexer)
return type(self)(self.array, self._updated_key(indexer))

def __setitem__(self, key, value):
Expand Down Expand Up @@ -643,7 +657,13 @@ def shape(self) -> tuple[int, ...]:
return np.broadcast(*self.key.tuple).shape

def get_duck_array(self):
array = self.array[self.key]

if isinstance(self.array, ExplicitlyIndexedNDArrayMixin):
array = apply_indexer(self.array, self.key)
else:
# If the array is not an ExplicitlyIndexedNDArrayMixin,
# it may wrap a BackendArray so use its __getitem__
array = self.array[self.key]
# self.array[self.key] is now a numpy array when
# self.array is a BackendArray subclass
# and self.key is BasicIndexer((slice(None, None, None),))
Expand All @@ -662,6 +682,7 @@ def _vindex_get(self, indexer):
return type(self)(self.array, self._updated_key(indexer))

def __getitem__(self, indexer):
self._check_and_raise_if_non_basic_indexer(indexer)
# If the indexed array becomes a scalar, return LazilyIndexedArray
if all(isinstance(ind, integer_types) for ind in indexer.tuple):
key = BasicIndexer(tuple(k[indexer.tuple] for k in self.key.tuple))
Expand Down Expand Up @@ -706,12 +727,13 @@ def get_duck_array(self):
return self.array.get_duck_array()

def _oindex_get(self, key):
return type(self)(_wrap_numpy_scalars(self.array[key]))
return type(self)(_wrap_numpy_scalars(self.array.oindex[key]))

def _vindex_get(self, key):
return type(self)(_wrap_numpy_scalars(self.array[key]))
return type(self)(_wrap_numpy_scalars(self.array.vindex[key]))

# Plain [] indexing: run the non-basic-indexer guard on `key`, then index the
# wrapped array, pass the result through _wrap_numpy_scalars, and wrap it in a
# new instance of this class.
def __getitem__(self, key):
self._check_and_raise_if_non_basic_indexer(key)
return type(self)(_wrap_numpy_scalars(self.array[key]))

def transpose(self, order):
Expand Down Expand Up @@ -745,12 +767,13 @@ def get_duck_array(self):
return self.array.get_duck_array()

def _oindex_get(self, key):
return type(self)(_wrap_numpy_scalars(self.array[key]))
return type(self)(_wrap_numpy_scalars(self.array.oindex[key]))

def _vindex_get(self, key):
return type(self)(_wrap_numpy_scalars(self.array[key]))
return type(self)(_wrap_numpy_scalars(self.array.vindex[key]))

# Plain [] indexing: run the non-basic-indexer guard on `key`, then index the
# wrapped array, pass the result through _wrap_numpy_scalars, and wrap it in a
# new instance of this class.
def __getitem__(self, key):
self._check_and_raise_if_non_basic_indexer(key)
return type(self)(_wrap_numpy_scalars(self.array[key]))

def transpose(self, order):
Expand Down Expand Up @@ -912,10 +935,21 @@ def explicit_indexing_adapter(
result = raw_indexing_method(raw_key.tuple)
if numpy_indices.tuple:
# index the loaded np.ndarray
result = NumpyIndexingAdapter(result)[numpy_indices]
indexable = NumpyIndexingAdapter(result)
result = apply_indexer(indexable, numpy_indices)
return result


def apply_indexer(indexable, indexer):
    """Index ``indexable`` with ``indexer`` via the accessor matching its type.

    A ``VectorizedIndexer`` is routed through ``indexable.vindex``, an
    ``OuterIndexer`` through ``indexable.oindex``, and any other indexer
    through plain ``indexable[indexer]``.
    """
    # Pick the object to subscript first, then subscript it exactly once.
    if isinstance(indexer, VectorizedIndexer):
        target = indexable.vindex
    elif isinstance(indexer, OuterIndexer):
        target = indexable.oindex
    else:
        target = indexable
    return target[indexer]


def decompose_indexer(
indexer: ExplicitIndexer, shape: tuple[int, ...], indexing_support: IndexingSupport
) -> tuple[ExplicitIndexer, ExplicitIndexer]:
Expand Down Expand Up @@ -987,10 +1021,10 @@ def _decompose_vectorized_indexer(
>>> array = np.arange(36).reshape(6, 6)
>>> backend_indexer = OuterIndexer((np.array([0, 1, 3]), np.array([2, 3])))
>>> # load subslice of the array
... array = NumpyIndexingAdapter(array)[backend_indexer]
... array = NumpyIndexingAdapter(array).oindex[backend_indexer]
>>> np_indexer = VectorizedIndexer((np.array([0, 2, 1]), np.array([0, 1, 0])))
>>> # vectorized indexing for on-memory np.ndarray.
... NumpyIndexingAdapter(array)[np_indexer]
... NumpyIndexingAdapter(array).vindex[np_indexer]
array([ 2, 21, 8])
"""
assert isinstance(indexer, VectorizedIndexer)
Expand Down Expand Up @@ -1072,7 +1106,7 @@ def _decompose_outer_indexer(
... array = NumpyIndexingAdapter(array)[backend_indexer]
>>> np_indexer = OuterIndexer((np.array([0, 2, 1]), np.array([0, 1, 0])))
>>> # outer indexing for on-memory np.ndarray.
... NumpyIndexingAdapter(array)[np_indexer]
... NumpyIndexingAdapter(array).oindex[np_indexer]
array([[ 2, 3, 2],
[14, 15, 14],
[ 8, 9, 8]])
Expand Down Expand Up @@ -1395,6 +1429,7 @@ def _vindex_get(self, key):
return array[key.tuple]

# Plain [] indexing: run the non-basic-indexer guard on `key`, then let
# _indexing_array_and_key translate it into an (array, key) pair and index
# that array with that key.
def __getitem__(self, key):
self._check_and_raise_if_non_basic_indexer(key)
array, key = self._indexing_array_and_key(key)
return array[key]

Expand Down Expand Up @@ -1450,15 +1485,8 @@ def _vindex_get(self, key):
raise TypeError("Vectorized indexing is not supported")

def __getitem__(self, key):
if isinstance(key, BasicIndexer):
return self.array[key.tuple]
elif isinstance(key, OuterIndexer):
return self.oindex[key]
else:
if isinstance(key, VectorizedIndexer):
raise TypeError("Vectorized indexing is not supported")
else:
raise TypeError(f"Unrecognized indexer: {key}")
self._check_and_raise_if_non_basic_indexer(key)
return self.array[key.tuple]

def __setitem__(self, key, value):
if isinstance(key, (BasicIndexer, OuterIndexer)):
Expand Down Expand Up @@ -1499,13 +1527,8 @@ def _vindex_get(self, key):
return self.array.vindex[key.tuple]

def __getitem__(self, key):
if isinstance(key, BasicIndexer):
return self.array[key.tuple]
elif isinstance(key, VectorizedIndexer):
return self.vindex[key]
else:
assert isinstance(key, OuterIndexer)
return self.oindex[key]
self._check_and_raise_if_non_basic_indexer(key)
return self.array[key.tuple]

def __setitem__(self, key, value):
if isinstance(key, BasicIndexer):
Expand Down Expand Up @@ -1603,7 +1626,8 @@ def __getitem__(
(key,) = key

if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional
return NumpyIndexingAdapter(np.asarray(self))[indexer]
indexable = NumpyIndexingAdapter(np.asarray(self))
return apply_indexer(indexable, indexer)

result = self.array[key]

Expand Down
17 changes: 4 additions & 13 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,12 +761,8 @@ def __getitem__(self, key) -> Self:
dims, indexer, new_order = self._broadcast_indexes(key)
indexable = as_indexable(self._data)

if isinstance(indexer, OuterIndexer):
data = indexable.oindex[indexer]
elif isinstance(indexer, VectorizedIndexer):
data = indexable.vindex[indexer]
else:
data = indexable[indexer]
data = indexing.apply_indexer(indexable, indexer)

if new_order:
data = np.moveaxis(data, range(len(new_order)), new_order)
return self._finalize_indexing_result(dims, data)
Expand All @@ -791,6 +787,7 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA):
dims, indexer, new_order = self._broadcast_indexes(key)

if self.size:

if is_duck_dask_array(self._data):
# dask's indexing is faster this way; also vindex does not
# support negative indices yet:
Expand All @@ -800,14 +797,8 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA):
actual_indexer = indexer

indexable = as_indexable(self._data)
data = indexing.apply_indexer(indexable, actual_indexer)

if isinstance(indexer, OuterIndexer):
data = indexable.oindex[indexer]

elif isinstance(indexer, VectorizedIndexer):
data = indexable.vindex[indexer]
else:
data = indexable[actual_indexer]
mask = indexing.create_mask(indexer, self.shape, data)
# we need to invert the mask in order to pass data first. This helps
# pint to choose the correct unit
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_coding_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def test_StackedBytesArray_vectorized_indexing() -> None:

V = IndexerMaker(indexing.VectorizedIndexer)
indexer = V[np.array([[0, 1], [1, 0]])]
actual = stacked[indexer]
actual = stacked.vindex[indexer]
assert_array_equal(actual, expected)


Expand Down
Loading

0 comments on commit c9d3084

Please sign in to comment.