-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* (feat): first pass supporting extension arrays * (feat): categorical tests + functionality * (feat): use multiple dispatch for unimplemented ops * (feat): implement (not really) broadcasting * (chore): add more `groupby` tests * (fix): fix more groupby incompatibility * (bug): fix unused categories * (chore): refactor dispatched methods + tests * (fix): shared type should check for extension arrays first and then fall back to numpy * (refactor): tests moved * (chore): more higher level tests * (feat): to/from dataframe * (chore): check for plum import * (fix): `__setitem__`/`__getitem__` * (chore): disallow stacking * (fix): `pyproject.toml` * (fix): `as_shared_type` fix * (chore): add variable tests * (fix): dask + categoricals * (chore): notes/docs * (chore): remove old testing file * (chore): remove ocmmented out code * (fix): import plum dispatch * (refactor): use `is_extension_array_dtype` as much as possible * (refactor): `extension_array`->`array` + move to `indexing` * (refactor): change order of classes * (chore): add small pyarrow test * (fix): fix some mypy issues * (fix): don't register unregisterable method * (fix): appease mypy * (fix): more sensible default implemetations allow most use without `plum` * (fix): handling `pyarrow` tests * (fix): actually do import correctly * (fix): `reduce` condition * (fix): column ordering for dataframes * (refactor): remove encoding business * (refactor): raise error for dask + extension array * (fix): only wrap `ExtensionDuckArray` that has a `.array` which is a pandas extension array * (fix): use duck array equality method, not pandas * (refactor): bye plum! 
* (fix): `and` to `or` for casting to `ExtensionDuckArray` * (fix): check for class, not type * (fix): only support native endianness * (refactor): no need for superfluous checks in `_maybe_wrap_data` * (chore): clean up docs to no longer reference `plum` * (fix): no longer allow `ExtensionDuckArray` to wrap `ExtensionDuckArray` * (refactor): move `implements` logic to `indexing` * (refactor): `indexing.py` -> `extension_array.py` * (refactor): `ExtensionDuckArray` -> `PandasExtensionArray` * (fix): add writeable property * (fix): don't check writeable for `PandasExtensionArray` * (fix): move check eariler * (refactor): correct guard clause * (chore): remove unnecessary `AttributeError` * (feat): singleton wrapped as array * (feat): remove shared dtype casting * (feat): loop once over `dataframe.items` * (feat): add `__len__` attribute * (fix): ensure constructor recieves `pd.Categorical` * Update xarray/core/extension_array.py Co-authored-by: Deepak Cherian <[email protected]> * Update xarray/core/extension_array.py Co-authored-by: Deepak Cherian <[email protected]> * (fix): drop condition for categorical corrected * Apply suggestions from code review * (chore): test `chunk` behavior * Update xarray/core/variable.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * (fix): bring back error * (chore): add test for dropping cat for mean * Update whats-new.rst * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Deepak Cherian <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
60f3e74
commit 9eb180b
Showing
16 changed files
with
434 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -130,6 +130,7 @@ module = [ | |
"opt_einsum.*", | ||
"pandas.*", | ||
"pooch.*", | ||
"pyarrow.*", | ||
"pydap.*", | ||
"pytest.*", | ||
"scipy.*", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
from __future__ import annotations | ||
|
||
from collections.abc import Sequence | ||
from typing import Callable, Generic | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from pandas.api.types import is_extension_array_dtype | ||
|
||
from xarray.core.types import DTypeLikeSave, T_ExtensionArray | ||
|
||
# Registry mapping a NumPy function to the callable that handles it for
# pandas extension arrays; filled in by the ``implements`` decorator.
HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {}


def implements(numpy_function):
    """Return a decorator that registers a handler for ``numpy_function``.

    The decorated callable is stored in ``HANDLED_EXTENSION_ARRAY_FUNCTIONS``
    under the key ``numpy_function`` and is handed back unchanged, so the
    decorated name still refers to the original function.
    """

    def register(handler):
        HANDLED_EXTENSION_ARRAY_FUNCTIONS.update({numpy_function: handler})
        return handler

    return register
|
||
|
||
@implements(np.issubdtype)
def __extension_duck_array__issubdtype(
    extension_array_dtype: T_ExtensionArray, other_dtype: DTypeLikeSave
) -> bool:
    """Handler registered for ``np.issubdtype``; always answers ``False``.

    Both arguments are ignored: we never want a function to think a pandas
    extension dtype is a subtype of a NumPy dtype.
    """
    is_subdtype = False
    return is_subdtype
|
||
|
||
@implements(np.broadcast_to)
def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple):
    """Handler registered for ``np.broadcast_to``.

    Only the identity broadcast is supported: a 1-d target shape whose single
    entry equals ``len(arr)``.

    Parameters
    ----------
    arr : T_ExtensionArray
        The array to broadcast.
    shape : tuple or int
        Requested shape. A bare integer is treated as a 1-tuple, matching
        ``np.broadcast_to``.

    Returns
    -------
    T_ExtensionArray
        ``arr`` itself (not a copy) when the requested shape is ``(len(arr),)``.

    Raises
    ------
    NotImplementedError
        For every other requested shape.
    """
    if isinstance(shape, (int, np.integer)):
        shape = (shape,)
    # Check the rank before indexing so that an empty shape ``()`` ends in
    # NotImplementedError rather than an IndexError from ``shape[0]``.
    if len(shape) == 1 and shape[0] == len(arr):
        return arr
    raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.")
|
||
|
||
@implements(np.stack)
def __extension_duck_array__stack(
    arr: T_ExtensionArray, axis: int = 0, out=None, **kwargs
):
    """Handler registered for ``np.stack``; stacking is never supported.

    ``axis`` defaults to 0 and ``out`` plus any further keywords are accepted
    (and ignored) so that every valid ``np.stack`` call ends in the
    ``NotImplementedError`` below rather than in a ``TypeError`` about a
    missing or unexpected argument.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError("Cannot stack 1d-only pandas categorical array.")
|
||
|
||
@implements(np.concatenate)
def __extension_duck_array__concatenate(
    arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None
) -> T_ExtensionArray:
    """Handler registered for ``np.concatenate``.

    Concatenation is delegated to ``_concat_same_type`` of the first array's
    class. ``axis`` and ``out`` are accepted for signature compatibility but
    are not used.
    """
    array_cls = type(arrays[0])
    return array_cls._concat_same_type(arrays)
|
||
|
||
@implements(np.where)
def __extension_duck_array__where(
    condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray
) -> T_ExtensionArray:
    """Handler registered for ``np.where``.

    Parameters
    ----------
    condition : np.ndarray
        Mask passed to ``pandas.Series.where``: entries of ``x`` are kept
        where it is true and replaced by the entries of ``y`` elsewhere.
    x, y : T_ExtensionArray
        The two candidate arrays.

    Returns
    -------
    T_ExtensionArray
        The ``.array`` of the resulting ``pandas.Series``.
    """
    if (
        isinstance(x, pd.Categorical)
        and isinstance(y, pd.Categorical)
        and x.dtype != y.dtype
    ):
        # Give both categoricals the union of their categories. The missing
        # categories are collected up front and in their original order:
        # going through ``set`` would make the resulting category order
        # depend on hash iteration order, which differs between interpreter
        # runs for strings.
        missing_from_x = [cat for cat in y.categories if cat not in x.categories]
        missing_from_y = [cat for cat in x.categories if cat not in y.categories]
        x = x.add_categories(missing_from_x)
        y = y.add_categories(missing_from_y)
    return pd.Series(x).where(condition, pd.Series(y)).array
|
||
|
||
class PandasExtensionArray(Generic[T_ExtensionArray]):
    """NEP-18 compliant wrapper for pandas extension arrays.

    NumPy functions are routed through ``HANDLED_EXTENSION_ARRAY_FUNCTIONS``
    by ``__array_function__``; ufuncs are applied to the wrapped array by
    ``__array_ufunc__``. Attributes not defined on the wrapper are looked up
    on the wrapped array.
    """

    array: T_ExtensionArray

    def __init__(self, array: T_ExtensionArray):
        """Wrap ``array``.

        Parameters
        ----------
        array : T_ExtensionArray
            The array to be wrapped upon e.g. :py:class:`xarray.Variable`
            creation.

        Raises
        ------
        TypeError
            If ``array`` is not a pandas ``ExtensionArray``.
        """
        if not isinstance(array, pd.api.extensions.ExtensionArray):
            raise TypeError(f"{array} is not an pandas ExtensionArray.")
        self.array = array

    @staticmethod
    def _unwrap_extension_arrays(values) -> list:
        """Return ``values`` as a list with every wrapper replaced by its array.

        Tuples and lists are processed recursively (a tuple stays a tuple, a
        list stays a list); any other value is passed through untouched.
        """
        unwrapped = []
        for value in values:
            if isinstance(value, PandasExtensionArray):
                unwrapped.append(value.array)
            elif isinstance(value, tuple):
                unwrapped.append(
                    tuple(PandasExtensionArray._unwrap_extension_arrays(value))
                )
            elif isinstance(value, list):
                unwrapped.append(PandasExtensionArray._unwrap_extension_arrays(value))
            else:
                unwrapped.append(value)
        return unwrapped

    def __array_function__(self, func, types, args, kwargs):
        """Dispatch a NumPy function call (NEP-18 protocol).

        Positional arguments are unwrapped first. If ``func`` has a registered
        handler it is used, otherwise ``func`` itself is called on the
        unwrapped arguments. An extension-array result from a registered
        handler is wrapped again; anything else is returned as is.
        """
        args = tuple(self._unwrap_extension_arrays(args))
        handler = HANDLED_EXTENSION_ARRAY_FUNCTIONS.get(func)
        if handler is None:
            return func(*args, **kwargs)
        res = handler(*args, **kwargs)
        if is_extension_array_dtype(res):
            return type(self)(res)
        return res

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """Apply a NumPy ufunc to the wrapped array(s).

        Inputs (and ``out`` targets, if given) are unwrapped before the call.
        Handing the wrapper itself back to the ufunc would re-enter this
        method and never terminate.
        """
        inputs = tuple(self._unwrap_extension_arrays(inputs))
        if "out" in kwargs:
            kwargs["out"] = tuple(self._unwrap_extension_arrays(kwargs["out"]))
        res = getattr(ufunc, method)(*inputs, **kwargs)
        if is_extension_array_dtype(res):
            return type(self)(res)
        return res

    def __repr__(self):
        return f"{type(self)}(array={self.array!r})"

    def __getattr__(self, attr: str) -> object:
        """Forward lookups of unknown attributes to the wrapped array."""
        # ``self.array`` must not be used here. While an instance is being
        # rebuilt by ``copy`` or ``pickle`` the ``array`` attribute does not
        # exist yet, so ``self.array`` would call ``__getattr__`` again and
        # recurse without end. ``object.__getattribute__`` raises a plain
        # AttributeError instead.
        return getattr(object.__getattribute__(self, "array"), attr)

    def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]:
        """Index the wrapped array.

        An extension-array result is wrapped again. A scalar result is turned
        into a wrapped length-1 array of the same dtype. Any other result is
        returned unchanged.
        """
        item = self.array[key]
        if is_extension_array_dtype(item):
            return type(self)(item)
        if np.isscalar(item):
            # ``_from_sequence`` with an explicit dtype keeps e.g. the full
            # set of categories and works for array types whose constructor
            # does not accept a plain list.
            return type(self)(
                type(self.array)._from_sequence([item], dtype=self.array.dtype)
            )
        return item

    def __setitem__(self, key, val):
        self.array[key] = val

    def __eq__(self, other):
        """Element-wise equality against a wrapper, an array or a scalar."""
        if isinstance(other, PandasExtensionArray):
            other = other.array
        # Scalars are compared directly: wrapping one in a length-1 array
        # makes the comparison fail on a length mismatch.
        return self.array == other

    def __ne__(self, other):
        return ~(self == other)

    def __len__(self):
        return len(self.array)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.