diff --git a/doc/whats-new.rst b/doc/whats-new.rst index cc3bf5ca410..f4dea1cd2aa 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -37,6 +37,9 @@ Deprecations Bug fixes ~~~~~~~~~ +- Allow accessing arbitrary attributes on Pandas ExtensionArrays. + By `Deepak Cherian `_. + Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 269016ddfd1..096a427e425 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -2,7 +2,7 @@ from collections.abc import Callable, Sequence from dataclasses import dataclass -from typing import Generic, cast +from typing import Any, Generic, cast import numpy as np import pandas as pd @@ -142,3 +142,10 @@ def __array__( return np.asarray(self.array, dtype=dtype, copy=copy) else: return np.asarray(self.array, dtype=dtype) + + def __getattr__(self, attr: str) -> Any: + # with __deepcopy__ or __copy__, the object is first constructed and then the sub-objects are attached (see https://docs.python.org/3/library/copy.html) + # Thus, if we didn't have `super().__getattribute__("array")` this method would call `self.array` (i.e., `getattr(self, "array")`) again while looking for `__setstate__` + # (which is apparently the first thing sought in copy.copy from the under-construction copied object), + # which would cause a recursion error since `array` is not present on the object when it is being constructed during `__{deep}copy__`. + return getattr(super().__getattribute__("array"), attr) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index af7db7294a8..11f9e88c69b 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -297,7 +297,7 @@ def test_repr(self) -> None: var1 (dim1, dim2) float64 576B -0.9891 -0.3678 1.288 ... -0.2116 0.364 var2 (dim1, dim2) float64 576B 0.953 1.52 1.704 ... 0.1347 -0.6423 var3 (dim3, dim1) float64 640B 0.4107 0.9941 0.1665 ... 0.716 1.555 - var4 (dim1) category 64B 'b' 'c' 'b' 'a' 'c' 'a' 'c' 'a' + var4 (dim1) category 32B 'b' 'c' 'b' 'a' 'c' 'a' 'c' 'a' Attributes: foo: bar""".format( data["dim3"].dtype, diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index ff84041f8f1..4052d414f63 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -1,6 +1,7 @@ from __future__ import annotations import datetime as dt +import pickle import warnings import numpy as np @@ -1094,3 +1095,19 @@ def test_extension_array_singleton_equality(categorical1): def test_extension_array_repr(int1): int_duck_array = PandasExtensionArray(int1) assert repr(int1) in repr(int_duck_array) + + +def test_extension_array_attr(): + array = pd.Categorical(["cat2", "cat1", "cat2", "cat3", "cat1"]) + wrapped = PandasExtensionArray(array) + assert_array_equal(array.categories, wrapped.categories) + assert array.nbytes == wrapped.nbytes + + roundtripped = pickle.loads(pickle.dumps(wrapped)) + assert isinstance(roundtripped, PandasExtensionArray) + assert (roundtripped == wrapped).all() + + interval_array = pd.arrays.IntervalArray.from_breaks([0, 1, 2, 3], closed="right") + wrapped = PandasExtensionArray(interval_array) + assert_array_equal(wrapped.left, interval_array.left, strict=True) + assert wrapped.closed == interval_array.closed