diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9d7eb55071b..e57be1df177 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -26,6 +26,8 @@ New Features - Add the ``.oindex`` property to Explicitly Indexed Arrays for orthogonal indexing functionality. (:issue:`8238`, :pull:`8750`) By `Anderson Banihirwe `_. +- Add the ``.vindex`` property to Explicitly Indexed Arrays for vectorized indexing functionality. (:issue:`8238`, :pull:`8780`) + By `Anderson Banihirwe `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 43867bc552b..62889e03861 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -488,10 +488,17 @@ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray: def _oindex_get(self, key): raise NotImplementedError("This method should be overridden") + def _vindex_get(self, key): + raise NotImplementedError("This method should be overridden") + @property def oindex(self): return IndexCallable(self._oindex_get) + @property + def vindex(self): + return IndexCallable(self._vindex_get) + class ImplicitToExplicitIndexingAdapter(NDArrayMixin): """Wrap an array, converting tuples into the indicated explicit indexer.""" @@ -585,6 +592,10 @@ def transpose(self, order): def _oindex_get(self, indexer): return type(self)(self.array, self._updated_key(indexer)) + def _vindex_get(self, indexer): + array = LazilyVectorizedIndexedArray(self.array, self.key) + return array[indexer] + def __getitem__(self, indexer): if isinstance(indexer, VectorizedIndexer): array = LazilyVectorizedIndexedArray(self.array, self.key) @@ -644,6 +655,12 @@ def get_duck_array(self): def _updated_key(self, new_key): return _combine_indexers(self.key, self.shape, new_key) + def _oindex_get(self, indexer): + return type(self)(self.array, self._updated_key(indexer)) + + def _vindex_get(self, indexer): + return type(self)(self.array, self._updated_key(indexer)) + def __getitem__(self, indexer): # If the indexed array becomes a scalar, return LazilyIndexedArray if all(isinstance(ind, integer_types) for ind in indexer.tuple): @@ -691,6 +708,9 @@ def get_duck_array(self): def _oindex_get(self, key): return type(self)(_wrap_numpy_scalars(self.array[key])) + def _vindex_get(self, key): + return type(self)(_wrap_numpy_scalars(self.array[key])) + def __getitem__(self, key): return type(self)(_wrap_numpy_scalars(self.array[key])) @@ -727,6 +747,9 @@ def get_duck_array(self): def _oindex_get(self, key): return type(self)(_wrap_numpy_scalars(self.array[key])) + def _vindex_get(self, key): + return type(self)(_wrap_numpy_scalars(self.array[key])) + def __getitem__(self, key): return type(self)(_wrap_numpy_scalars(self.array[key])) @@ -1364,8 +1387,12 @@ def transpose(self, order): return self.array.transpose(order) def _oindex_get(self, key): - array, key = self._indexing_array_and_key(key) - return array[key] + key = _outer_to_numpy_indexer(key, self.array.shape) + return self.array[key] + + def _vindex_get(self, key): + array = NumpyVIndexAdapter(self.array) + return array[key.tuple] def __getitem__(self, key): array, key = self._indexing_array_and_key(key) @@ -1419,6 +1446,9 @@ def _oindex_get(self, key): value = value[(slice(None),) * axis + (subkey, Ellipsis)] return value + def _vindex_get(self, key): + raise TypeError("Vectorized indexing is not supported") + def __getitem__(self, key): if isinstance(key, BasicIndexer): return self.array[key.tuple] @@ -1465,11 +1495,14 @@ def _oindex_get(self, key): value = value[(slice(None),) * axis + (subkey,)] return value + def _vindex_get(self, key): + return self.array.vindex[key.tuple] + def __getitem__(self, key): if isinstance(key, BasicIndexer): return self.array[key.tuple] elif isinstance(key, VectorizedIndexer): - return self.array.vindex[key.tuple] + return self.vindex[key] else: assert isinstance(key, OuterIndexer) return self.oindex[key] @@ -1551,6 +1584,9 @@ def _convert_scalar(self, item): def _oindex_get(self, key): return self.__getitem__(key) + def _vindex_get(self, key): + return self.__getitem__(key) + def __getitem__( self, indexer ) -> ( diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 6834931fa11..4da3e5fd841 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -759,10 +759,10 @@ def __getitem__(self, key) -> Self: dims, indexer, new_order = self._broadcast_indexes(key) indexable = as_indexable(self._data) - if isinstance(indexer, BasicIndexer): - data = indexable[indexer] - elif isinstance(indexer, OuterIndexer): + if isinstance(indexer, OuterIndexer): data = indexable.oindex[indexer] + elif isinstance(indexer, VectorizedIndexer): + data = indexable.vindex[indexer] else: data = indexable[indexer] if new_order: @@ -801,6 +801,9 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA): if isinstance(indexer, OuterIndexer): data = indexable.oindex[indexer] + + elif isinstance(indexer, VectorizedIndexer): + data = indexable.vindex[indexer] else: data = indexable[actual_indexer] mask = indexing.create_mask(indexer, self.shape, data)