Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feat): Support for pandas ExtensionArray #8723

Merged
merged 101 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from 40 commits
Commits
Show all changes
101 commits
Select commit Hold shift + click to select a range
b2712f1
(feat): first pass supporting extension arrays
ilan-gold Feb 2, 2024
47bddd2
(feat): categorical tests + functionality
ilan-gold Feb 2, 2024
dc8b788
(feat): use multiple dispatch for unimplemented ops
ilan-gold Feb 5, 2024
75524c8
(feat): implement (not really) broadcasting
ilan-gold Feb 5, 2024
c9ab452
(chore): add more `groupby` tests
ilan-gold Feb 5, 2024
1f3d0fa
(fix): fix more groupby incompatibility
ilan-gold Feb 5, 2024
8a70e3c
(bug): fix unused categories
ilan-gold Feb 5, 2024
f5a6505
(chore): refactor dispatched methods + tests
ilan-gold Feb 5, 2024
08a4feb
(fix): shared type should check for extension arrays first and then f…
ilan-gold Feb 6, 2024
d5b218b
(refactor): tests moved
ilan-gold Feb 6, 2024
00256fa
(chore): more higher level tests
ilan-gold Feb 6, 2024
b7ddbd6
(feat): to/from dataframe
ilan-gold Feb 8, 2024
a165851
(chore): check for plum import
ilan-gold Feb 8, 2024
a826edd
(fix): `__setitem__`/`__getitem__`
ilan-gold Feb 8, 2024
fde19ea
(chore): disallow stacking
ilan-gold Feb 8, 2024
4c55707
(fix): `pyproject.toml`
ilan-gold Feb 8, 2024
58ba17d
(fix): `as_shared_type` fix
ilan-gold Feb 8, 2024
a255310
(chore): add variable tests
ilan-gold Feb 8, 2024
4e78b7e
(fix): dask + categoricals
ilan-gold Feb 8, 2024
d9cedf5
(chore): notes/docs
ilan-gold Feb 8, 2024
426664d
(chore): remove old testing file
ilan-gold Feb 8, 2024
22ca77d
(chore): remove ocmmented out code
ilan-gold Feb 8, 2024
f32cfdf
Merge branch 'main' into extension_arrays
ilan-gold Feb 8, 2024
60f8927
(fix): import plum dispatch
ilan-gold Feb 8, 2024
ff22d76
Merge branch 'extension_arrays' of github.com:ilan-gold/xarray into e…
ilan-gold Feb 8, 2024
2153e81
Merge branch 'main' into extension_arrays
ilan-gold Feb 9, 2024
b6d0b31
(refactor): use `is_extension_array_dtype` as much as possible
ilan-gold Feb 9, 2024
d285871
Merge branch 'extension_arrays' of github.com:ilan-gold/xarray into e…
ilan-gold Feb 9, 2024
d847277
Merge branch 'main' into extension_arrays
ilan-gold Feb 9, 2024
8238c64
(refactor): `extension_array`->`array` + move to `indexing`
ilan-gold Feb 10, 2024
1260cd4
Merge branch 'extension_arrays' of github.com:ilan-gold/xarray into e…
ilan-gold Feb 10, 2024
b04ef98
(refactor): change order of classes
ilan-gold Feb 10, 2024
b9937bf
(chore): add small pyarrow test
ilan-gold Feb 12, 2024
0bba03f
(fix): fix some mypy issues
ilan-gold Feb 12, 2024
b714549
(fix): don't register unregisterable method
ilan-gold Feb 12, 2024
a3a678c
(fix): appease mypy
ilan-gold Feb 12, 2024
e521844
(fix): more sensible default implemetations allow most use without `p…
ilan-gold Feb 12, 2024
2d3e930
(fix): handling `pyarrow` tests
ilan-gold Feb 12, 2024
04c9969
(fix): actually do import correctly
ilan-gold Feb 12, 2024
5514539
Merge branch 'main' into extension_arrays
ilan-gold Feb 12, 2024
bedfa5c
(fix): `reduce` condition
ilan-gold Feb 13, 2024
e6c2690
Merge branch 'main' into extension_arrays
ilan-gold Feb 13, 2024
82dbda9
(fix): column ordering for dataframes
ilan-gold Feb 13, 2024
12217ed
(refactor): remove encoding business
ilan-gold Feb 13, 2024
dd5b87d
(refactor): raise error for dask + extension array
ilan-gold Feb 13, 2024
761a874
Merge branch 'extension_arrays' of github.com:ilan-gold/xarray into e…
ilan-gold Feb 13, 2024
52cabc8
Merge branch 'main' into extension_arrays
ilan-gold Feb 13, 2024
e0d58fa
(fix): only wrap `ExtensionDuckArray` that has a `.array` which is a …
ilan-gold Feb 15, 2024
c1e0e64
(fix): use duck array equality method, not pandas
ilan-gold Feb 15, 2024
17e3390
(refactor): bye plum!
ilan-gold Feb 15, 2024
dd2ef39
Merge branch 'main' into extension_arrays
ilan-gold Feb 15, 2024
c8e6bfe
(fix): `and` to `or` for casting to `ExtensionDuckArray`
ilan-gold Feb 15, 2024
b2a9517
(fix): check for class, not type
ilan-gold Feb 16, 2024
f5e1bd0
Merge branch 'main' into extension_arrays
ilan-gold Feb 16, 2024
407fad1
(fix): only support native endianness
ilan-gold Feb 19, 2024
3a47f09
Merge branch 'extension_arrays' of github.com:ilan-gold/xarray into e…
ilan-gold Feb 19, 2024
fdd3de4
Merge branch 'main' into extension_arrays
ilan-gold Feb 19, 2024
6b23629
Merge branch 'main' into extension_arrays
ilan-gold Feb 20, 2024
1c9047f
(refactor): no need for superfluous checks in `_maybe_wrap_data`
ilan-gold Feb 22, 2024
9be6b03
Merge branch 'extension_arrays' of github.com:ilan-gold/xarray into e…
ilan-gold Feb 22, 2024
d9304f1
(chore): clean up docs to no longer reference `plum`
ilan-gold Feb 22, 2024
6ec6725
(fix): no longer allow `ExtensionDuckArray` to wrap `ExtensionDuckArray`
ilan-gold Feb 22, 2024
bc9ac4c
(refactor): move `implements` logic to `indexing`
ilan-gold Feb 22, 2024
1e906db
Merge branch 'main' into extension_arrays
ilan-gold Feb 29, 2024
6fb8668
(refactor): `indexing.py` -> `extension_array.py`
ilan-gold Feb 29, 2024
8f034b4
(refactor): `ExtensionDuckArray` -> `PandasExtensionArray`
ilan-gold Feb 29, 2024
90a6de6
Merge branch 'main' into extension_arrays
dcherian Mar 3, 2024
2bd422a
Merge branch 'main' into extension_arrays
ilan-gold Mar 18, 2024
ff67943
Merge branch 'main' into extension_arrays
ilan-gold Mar 25, 2024
661d9f2
(fix): add writeable property
ilan-gold Mar 25, 2024
caee1c6
(fix): don't check writeable for `PandasExtensionArray`
ilan-gold Mar 25, 2024
1d12f5e
(fix): move check eariler
ilan-gold Mar 25, 2024
31dfbb5
Merge branch 'main' into extension_arrays
ilan-gold Mar 26, 2024
23b347f
Merge branch 'main' into extension_arrays
ilan-gold Mar 28, 2024
902c74b
(refactor): correct guard clause
ilan-gold Mar 28, 2024
0b64506
(chore): remove unnecessary `AttributeError`
ilan-gold Mar 28, 2024
0c7e023
(feat): singleton wrapped as array
ilan-gold Mar 28, 2024
dd7fe98
(feat): remove shared dtype casting
ilan-gold Mar 28, 2024
f0df768
(feat): loop once over `dataframe.items`
ilan-gold Mar 28, 2024
e2f0487
(feat): add `__len__` attribute
ilan-gold Mar 28, 2024
1eb6741
(fix): ensure constructor recieves `pd.Categorical`
ilan-gold Mar 28, 2024
2a7300a
Merge branch 'extension_arrays' of github.com:ilan-gold/xarray into e…
ilan-gold Mar 28, 2024
9cceadc
Update xarray/core/extension_array.py
ilan-gold Mar 28, 2024
f2588c1
Update xarray/core/extension_array.py
ilan-gold Mar 28, 2024
a0a63bd
(fix): drop condition for categorical corrected
ilan-gold Mar 28, 2024
5bb2bde
Merge branch 'main' into extension_arrays
ilan-gold Mar 28, 2024
f85f166
Merge branch 'main' into extension_arrays
ilan-gold Apr 3, 2024
7ecdeba
Merge branch 'main' into extension_arrays
ilan-gold Apr 4, 2024
6bc40fc
Merge branch 'main' into extension_arrays
ilan-gold Apr 11, 2024
e9dc53f
Apply suggestions from code review
dcherian Apr 13, 2024
4791799
(chore): test `chunk` behavior
ilan-gold Apr 16, 2024
c649362
Merge branch 'extension_arrays' of github.com:ilan-gold/xarray into e…
ilan-gold Apr 16, 2024
fc60dcf
Merge branch 'main' into extension_arrays
ilan-gold Apr 16, 2024
0374086
Update xarray/core/variable.py
dcherian Apr 16, 2024
b9515a6
Merge branch 'main' into extension_arrays
dcherian Apr 16, 2024
72bf807
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2024
63b6c42
(fix): bring back error
ilan-gold Apr 17, 2024
1d18439
(chore): add test for dropping cat for mean
ilan-gold Apr 17, 2024
17f05da
Update whats-new.rst
dcherian Apr 17, 2024
c906c81
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
e6db83b
Merge branch 'main' into extension_arrays
ilan-gold Apr 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ New Features
This is currently limited to the linear interpolation method (`method='linear'`).
(:issue:`7377`, :pull:`8684`) By `Marco Wolsza <https://github.com/maawoo>`_.

- Xarray now makes a best attempt not to coerce :py:class:`pandas.api.extensions.ExtensionArray` to a numpy array
by supporting 1D `ExtensionArray` objects internally where possible. Thus, `Dataset`s initialized with a `pd.Catgeorical`, for example,
will retain the object. However, one cannot do operations that are not possible on the `ExtensionArray` then, such as broadcasting.

Breaking changes
~~~~~~~~~~~~~~~~

Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dependencies = [

[project.optional-dependencies]
accel = ["scipy", "bottleneck", "numbagg", "flox", "opt_einsum"]
complete = ["xarray[accel,io,parallel,viz,dev]"]
complete = ["xarray[accel,io,parallel,viz,dev,extension_arrays]"]
dev = [
"hypothesis",
"pre-commit",
Expand All @@ -45,6 +45,7 @@ dev = [
io = ["netCDF4", "h5netcdf", "scipy", 'pydap; python_version<"3.10"', "zarr", "fsspec", "cftime", "pooch"]
parallel = ["dask[complete]"]
viz = ["matplotlib", "seaborn", "nc-time-axis"]
extension-arrays = ["plum-dispatch"]
ilan-gold marked this conversation as resolved.
Show resolved Hide resolved

[project.urls]
Documentation = "https://docs.xarray.dev"
Expand Down Expand Up @@ -124,6 +125,7 @@ module = [
"opt_einsum.*",
"pandas.*",
"pooch.*",
"pyarrow.*",
"pydap.*",
"pytest.*",
"scipy.*",
Expand Down
8 changes: 4 additions & 4 deletions xarray/coding/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import partial

import numpy as np
from pandas.api.types import is_extension_array_dtype

from xarray.coding.variables import (
VariableCoder,
Expand All @@ -27,11 +28,10 @@ def create_vlen_dtype(element_type):


def check_vlen_dtype(dtype):
if dtype.kind != "O" or dtype.metadata is None:
if is_extension_array_dtype(dtype) or dtype.kind != "O" or dtype.metadata is None:
ilan-gold marked this conversation as resolved.
Show resolved Hide resolved
return None
else:
# check xarray (element_type) as well as h5py (vlen)
return dtype.metadata.get("element_type", dtype.metadata.get("vlen"))
# check xarray (element_type) as well as h5py (vlen)
return dtype.metadata.get("element_type", dtype.metadata.get("vlen"))


def is_unicode_dtype(dtype):
Expand Down
12 changes: 8 additions & 4 deletions xarray/coding/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import pandas as pd
from pandas.api.types import is_extension_array_dtype
from pandas.errors import OutOfBoundsDatetime, OutOfBoundsTimedelta

from xarray.coding.variables import (
Expand Down Expand Up @@ -967,9 +968,10 @@ def __init__(self, use_cftime: bool | None = None) -> None:
self.use_cftime = use_cftime

def encode(self, variable: Variable, name: T_Name = None) -> Variable:
if np.issubdtype(
variable.data.dtype, np.datetime64
) or contains_cftime_datetimes(variable):
if (not is_extension_array_dtype(variable.data)) and (
ilan-gold marked this conversation as resolved.
Show resolved Hide resolved
np.issubdtype(variable.data.dtype, np.datetime64)
or contains_cftime_datetimes(variable)
):
dims, data, attrs, encoding = unpack_for_encoding(variable)

units = encoding.pop("units", None)
Expand Down Expand Up @@ -1007,7 +1009,9 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable:

class CFTimedeltaCoder(VariableCoder):
def encode(self, variable: Variable, name: T_Name = None) -> Variable:
if np.issubdtype(variable.data.dtype, np.timedelta64):
if (not is_extension_array_dtype(variable.data)) and np.issubdtype(
ilan-gold marked this conversation as resolved.
Show resolved Hide resolved
variable.data.dtype, np.timedelta64
):
dims, data, attrs, encoding = unpack_for_encoding(variable)

data, units = encode_cf_timedelta(
Expand Down
20 changes: 12 additions & 8 deletions xarray/coding/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import pandas as pd
from pandas.api.types import is_extension_array_dtype

from xarray.core import dtypes, duck_array_ops, indexing
from xarray.core.parallelcompat import get_chunked_array_type
Expand Down Expand Up @@ -250,7 +251,6 @@ class CFMaskCoder(VariableCoder):
def encode(self, variable: Variable, name: T_Name = None):
dims, data, attrs, encoding = unpack_for_encoding(variable)

dtype = np.dtype(encoding.get("dtype", data.dtype))
fv = encoding.get("_FillValue")
mv = encoding.get("missing_value")

Expand All @@ -268,6 +268,8 @@ def encode(self, variable: Variable, name: T_Name = None):
# special case DateTime to properly handle NaT
is_time_like = _is_time_like(attrs.get("units"))

dtype = np.dtype(encoding.get("dtype", data.dtype))

if fv_exists:
# Ensure _FillValue is cast to same dtype as data's
encoding["_FillValue"] = dtype.type(fv)
Expand Down Expand Up @@ -472,16 +474,18 @@ class DefaultFillvalueCoder(VariableCoder):

def encode(self, variable: Variable, name: T_Name = None) -> Variable:
dims, data, attrs, encoding = unpack_for_encoding(variable)
has_no_fill = "_FillValue" not in attrs and "_FillValue" not in encoding
# make NaN the fill value for float types
if (
"_FillValue" not in attrs
and "_FillValue" not in encoding
and np.issubdtype(variable.dtype, np.floating)
):
if is_extension_array_dtype(data):
if not has_no_fill:
raise ValueError(
"Found _FillValue encoding or attr on extension array."
)
return variable
if has_no_fill and np.issubdtype(variable.dtype, np.floating):
attrs["_FillValue"] = variable.dtype.type(np.nan)
return Variable(dims, data, attrs, encoding, fastpath=True)
else:
return variable
return variable

def decode(self, variable: Variable, name: T_Name = None) -> Variable:
raise NotImplementedError()
Expand Down
12 changes: 8 additions & 4 deletions xarray/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
import pandas as pd
from pandas.api.types import is_extension_array_dtype

from xarray.coding import strings, times, variables
from xarray.coding.variables import SerializationWarning, pop_to
Expand Down Expand Up @@ -114,7 +115,10 @@ def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable:
dims, data, attrs, encoding = variables.unpack_for_encoding(var)

# leave vlen dtypes unchanged
if strings.check_vlen_dtype(data.dtype) is not None:
if (
is_extension_array_dtype(data)
or strings.check_vlen_dtype(data.dtype) is not None
):
return var

if is_duck_dask_array(data):
Expand Down Expand Up @@ -356,9 +360,9 @@ def _update_bounds_encoding(variables: T_Variables) -> None:
attrs = v.attrs
encoding = v.encoding
has_date_units = "units" in encoding and "since" in encoding["units"]
is_datetime_type = np.issubdtype(
v.dtype, np.datetime64
) or contains_cftime_datetimes(v)
is_datetime_type = (not is_extension_array_dtype(v)) and (
contains_cftime_datetimes(v) or np.issubdtype(v.dtype, np.datetime64)
)

if (
is_datetime_type
Expand Down
43 changes: 36 additions & 7 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload

import numpy as np
from pandas.api.types import is_extension_array_dtype

# remove once numpy 2.0 is the oldest supported version
try:
Expand Down Expand Up @@ -6835,6 +6836,7 @@ def reduce(
# that don't have the reduce dims: PR5393
not reduce_dims
or not numeric_only
or not is_extension_array_dtype(var.dtype)
or np.issubdtype(var.dtype, np.number)
or (var.dtype == np.bool_)
):
Expand Down Expand Up @@ -7149,13 +7151,33 @@ def to_pandas(self) -> pd.Series | pd.DataFrame:
)

def _to_dataframe(self, ordered_dims: Mapping[Any, int]):
columns = [k for k in self.variables if k not in self.dims]
columns = [
k
for k in self.variables
if k not in self.dims
and not is_extension_array_dtype(self.variables[k].data)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TomNicholas see for example here about checking explicitly, so decoupling this PR should be doable. Behavior on tests seemed the same

]
extension_array_columns = [
k
for k in self.variables
if k not in self.dims and is_extension_array_dtype(self.variables[k].data)
]
data = [
self._variables[k].set_dims(ordered_dims).values.reshape(-1)
for k in columns
]
index = self.coords.to_index([*ordered_dims])
return pd.DataFrame(dict(zip(columns, data)), index=index)
broadcasted_df = pd.DataFrame(dict(zip(columns, data)), index=index)
for extension_array_column in extension_array_columns:
extension_array = self.variables[extension_array_column].data.array
index = self[self.variables[extension_array_column].dims[0]].data
cat_df = pd.DataFrame(
{extension_array_column: extension_array},
index=self[self.variables[extension_array_column].dims[0]].data,
)
cat_df.index.name = self.variables[extension_array_column].dims[0]
broadcasted_df = broadcasted_df.join(cat_df)
return broadcasted_df

def to_dataframe(self, dim_order: Sequence[Hashable] | None = None) -> pd.DataFrame:
"""Convert this dataset into a pandas.DataFrame.
Expand Down Expand Up @@ -7301,11 +7323,14 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
"cannot convert a DataFrame with a non-unique MultiIndex into xarray"
)

# Cast to a NumPy array first, in case the Series is a pandas Extension
# array (which doesn't have a valid NumPy dtype)
# TODO: allow users to control how this casting happens, e.g., by
# forwarding arguments to pandas.Series.to_numpy?
arrays = [(k, np.asarray(v)) for k, v in dataframe.items()]
arrays = [
ilan-gold marked this conversation as resolved.
Show resolved Hide resolved
(k, np.asarray(v))
for k, v in dataframe.items()
if not is_extension_array_dtype(v)
]
extension_arrays = [
(k, v) for k, v in dataframe.items() if is_extension_array_dtype(v)
]

indexes: dict[Hashable, Index] = {}
index_vars: dict[Hashable, Variable] = {}
Expand All @@ -7319,6 +7344,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
xr_idx = PandasIndex(lev, dim)
indexes[dim] = xr_idx
index_vars.update(xr_idx.create_variables())
arrays += [(k, np.asarray(v)) for k, v in extension_arrays]
extension_arrays = []
else:
index_name = idx.name if idx.name is not None else "index"
dims = (index_name,)
Expand All @@ -7332,6 +7359,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
obj._set_sparse_data_from_dataframe(idx, arrays, dims)
else:
obj._set_numpy_data_from_dataframe(idx, arrays, dims)
for name, extension_array in extension_arrays:
obj[name] = (dims, extension_array)
return obj

def to_dask_dataframe(
Expand Down
86 changes: 82 additions & 4 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
import datetime
import inspect
import warnings
from collections.abc import Sequence
from functools import partial
from importlib import import_module
from typing import Callable

import numpy as np
import pandas as pd
Expand All @@ -32,11 +34,21 @@
from numpy import concatenate as _concatenate
from numpy.lib.stride_tricks import sliding_window_view # noqa
from packaging.version import Version
from pandas.api.types import is_extension_array_dtype

try:
from plum import dispatch # type: ignore[import-not-found]
except ImportError:

def dispatch(func):
return func


from xarray.core import dask_array_ops, dtypes, nputils, pycompat
from xarray.core.options import OPTIONS
from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array
from xarray.core.pycompat import array_type, is_duck_dask_array
from xarray.core.types import DTypeLikeSave, T_ExtensionArray
from xarray.core.utils import is_duck_array, module_available

# remove once numpy 2.0 is the oldest supported version
Expand All @@ -53,6 +65,64 @@
dask_available = module_available("dask")


HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {}


def implements(numpy_function):
dcherian marked this conversation as resolved.
Show resolved Hide resolved
"""Register an __array_function__ implementation for MyArray objects."""

def decorator(func):
HANDLED_EXTENSION_ARRAY_FUNCTIONS[numpy_function] = func
return func

return decorator


@implements(np.issubdtype)
@dispatch
def __extension_duck_array__issubdtype(
extension_array_dtype: T_ExtensionArray, other_dtype: DTypeLikeSave
) -> bool:
return False # never want a function to think a pandas extension dtype is a subtype of numpy


@implements(np.broadcast_to)
@dispatch
def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@benbovy speaking of the indexing.py file, I noticed the PandasIndexingAdapter returns a np.array version of the index in certain cases so that multi-dim opertaions (such as broadcast?) can work. Should I do the same here? I explicitly don't support this functionality here (see point 1. on this comment for why I made this decision). I don't have strong feelings, but I am inclined to not support it by default, especially since with plum.dispatch, the behavior can be overridden.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have strong feelings either on this.

if shape[0] == len(arr) and len(shape) == 1:
return arr
raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.")


@implements(np.stack)
@dispatch
def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int):
raise NotImplementedError("Cannot stack 1d-only pandas categorical array.")


@implements(np.concatenate)
@dispatch
ilan-gold marked this conversation as resolved.
Show resolved Hide resolved
def __extension_duck_array__concatenate(
arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None
) -> T_ExtensionArray:
return type(arrays[0])._concat_same_type(arrays)


@implements(np.where)
@dispatch
def __extension_duck_array__where(
condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray
) -> T_ExtensionArray:
if (
isinstance(x, pd.Categorical)
and isinstance(y, pd.Categorical)
and x.dtype != y.dtype
):
x = x.add_categories(set(y.categories).difference(set(x.categories)))
y = y.add_categories(set(x.categories).difference(set(y.categories)))
return pd.Series(x).where(condition, pd.Series(y)).array


def get_array_namespace(x):
if hasattr(x, "__array_namespace__"):
return x.__array_namespace__()
Expand Down Expand Up @@ -155,7 +225,7 @@ def isnull(data):
return full_like(data, dtype=bool, fill_value=False)
else:
# at this point, array should have dtype=object
if isinstance(data, np.ndarray):
if isinstance(data, np.ndarray) or is_extension_array_dtype(data):
return pandas_isnull(data)
else:
# Not reachable yet, but intended for use with other duck array
Expand Down Expand Up @@ -220,9 +290,17 @@ def asarray(data, xp=np):

def as_shared_dtype(scalars_or_arrays, xp=np):
"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
array_type_cupy = array_type("cupy")
if array_type_cupy and any(
isinstance(x, array_type_cupy) for x in scalars_or_arrays
if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
extension_array_types = [
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
]
if len(extension_array_types) == len(scalars_or_arrays) and all(
isinstance(x, type(extension_array_types[0])) for x in extension_array_types
):
return scalars_or_arrays
ilan-gold marked this conversation as resolved.
Show resolved Hide resolved
arrays = [asarray(np.array(x), xp=xp) for x in scalars_or_arrays]
elif array_type_cupy := array_type("cupy") and any( # noqa: F841
isinstance(x, array_type_cupy) for x in scalars_or_arrays # noqa: F821
ilan-gold marked this conversation as resolved.
Show resolved Hide resolved
):
import cupy as cp

Expand Down
Loading
Loading