-
-
Notifications
You must be signed in to change notification settings - Fork 1.3k
feat: add __arrow_c_stream__ function #11338
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f70df33
11fd74b
8b0c828
ac02132
f17c813
74a7251
e69b4f4
57b836d
e7cc3f1
ba824b4
ecadc76
52fed6b
5872890
5c9c8c0
276c394
64b695d
d998785
1eda714
ad8c76d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
|
|
||
| import copy | ||
| import datetime | ||
| import json | ||
| import warnings | ||
| from collections.abc import ( | ||
| Callable, | ||
|
|
@@ -255,6 +256,21 @@ def __setitem__(self, key, value) -> None: | |
| _THIS_ARRAY = ReprObject("<this-array>") | ||
|
|
||
|
|
||
| class _NumpyEncoder(json.JSONEncoder): | ||
| """Encode numpy value in Arrow schema metadata""" | ||
|
|
||
| def default(self, obj: Any) -> Any: | ||
| if isinstance(obj, np.integer): | ||
| return int(obj) | ||
| if isinstance(obj, np.floating): | ||
| return float(obj) | ||
| if isinstance(obj, np.bool_): | ||
| return bool(obj) | ||
| if isinstance(obj, np.ndarray): | ||
| return obj.tolist() | ||
| return super().default(obj) | ||
|
|
||
|
|
||
| class DataArray( | ||
| AbstractArray, | ||
| DataWithCoords, | ||
|
|
@@ -477,6 +493,94 @@ def __init__( | |
|
|
||
| self._close = None | ||
|
|
||
def __arrow_c_schema__(self):
    """Export the Arrow schema describing this DataArray.

    The schema holds one field per coordinate (in coordinate order),
    followed by a field for the data values named after ``self.name``
    (``"values"`` when the name is falsy). Field types are derived from the
    NumPy dtypes with ``pyarrow.from_numpy_dtype``.

    xarray-specific information (name, dims, shape, attrs and a data-free
    description of every coordinate) is JSON-encoded into the schema
    metadata under the ``b"xarray"`` key, next to a version marker.

    Returns
    -------
    The result of ``pyarrow.Schema.__arrow_c_schema__()``.

    Raises
    ------
    ImportError
        If pyarrow is not installed.
    """
    try:
        import pyarrow as pa
    except ImportError:
        raise ImportError(
            "pyarrow is required to export via the Arrow PyCapsule Interface."
        ) from None

    data_field_name = self.name or "values"

    # Coordinate fields first, data field last.
    arrow_fields = [
        pa.field(str(coord_name), pa.from_numpy_dtype(coord_var.dtype))
        for coord_name, coord_var in self._coords.items()
    ]
    data_type = pa.from_numpy_dtype(self._variable.dtype)
    arrow_fields.append(pa.field(str(data_field_name), data_type))

    # Describe coordinates without their data so the metadata stays small.
    coords_description = {
        str(coord_name): coord_var.to_dict(data=False)
        for coord_name, coord_var in self._coords.items()
    }
    xarray_payload = json.dumps(
        {
            "name": self.name,
            "dims": list(self._variable.dims),
            "shape": list(self._variable.shape),
            "attrs": self.attrs,
            "coords": coords_description,
        },
        cls=_NumpyEncoder,
    )
    metadata = {
        b"xarray:arrow_schema_version": b"v1",
        b"xarray": xarray_payload.encode(),
    }

    return pa.schema(arrow_fields, metadata=metadata).__arrow_c_schema__()
|
|
||
def __arrow_c_stream__(self, requested_schema: Any = None) -> Any:
    """Export the DataArray through the Arrow PyCapsule Interface.

    The array is flattened in C order into a long-format table: one column
    per coordinate, each broadcast against the full data shape, plus one
    column with the data values named after ``self.name`` (``"values"``
    when the name is falsy).

    https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html

    Parameters
    ----------
    requested_schema : optional
        Forwarded unchanged to ``pyarrow.Table.__arrow_c_stream__``.

    Returns
    -------
    The result of ``pyarrow.Table.__arrow_c_stream__``.

    Raises
    ------
    ImportError
        If pyarrow is not installed.
    """
    try:
        import pyarrow as pa
    except ImportError:
        raise ImportError(
            "pyarrow is required to export via the Arrow PyCapsule Interface."
        ) from None

    # NOTE(review): ``.values`` converts the data to a NumPy array, so chunked
    # (e.g. dask) data is loaded into memory here. Reviewers asked for an
    # explicit error instead; not implemented in this block.
    values = self._variable.values
    dims = self._variable.dims
    shape = self._variable.shape

    # Must stay in sync with the field naming in ``__arrow_c_schema__``.
    values_column = self.name or "values"

    # Keys are stringified because the schema fields are named ``str(name)``;
    # raw non-string names (e.g. integer dims) would not be found when the
    # table is built against that schema.
    columns: dict[str, pa.Array] = {}
    for name, coord in self._coords.items():
        # Broadcast each coordinate up to the full data shape so that 1D
        # dimension coordinates and N-D (e.g. curvilinear) coordinates
        # flatten consistently with the data values.
        # NOTE(review): sharing this logic with other conversion code was
        # suggested and deferred to a follow-up PR.

        # Order axes based on Variable dims
        dim_order = [coord.dims.index(dim) for dim in dims if dim in coord.dims]

        # Reorder coords values to variable dim order
        ordered_coords = coord.values.transpose(dim_order)

        # Expand coord dims
        # coord dims (x, y) variable dims (x,y,z) -> (x, y, 1)
        # NOTE: Insert a length-1 axis for each data dim missing for coordinates
        # (slice(None) keeps an existing axis, np.newaxis adds one)
        indexer = tuple(
            slice(None) if dim in coord.dims else np.newaxis for dim in dims
        )
        expanded_coords = ordered_coords[indexer]

        # Broadcast to full flattened shape (x, y, 1) -> (x, y, z)
        broadcasted = np.broadcast_to(expanded_coords, shape)

        columns[str(name)] = pa.array(np.ravel(broadcasted))

    # ``np.ravel`` flattens in C order and copies only when needed, so no
    # separate contiguity step is required.
    columns[str(values_column)] = pa.array(np.ravel(values))

    table = pa.table(columns, schema=pa.schema(self))
    return table.__arrow_c_stream__(requested_schema)
|
|
||
| @classmethod | ||
| def _construct_direct( | ||
| cls, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -59,6 +59,7 @@ | |
| requires_iris, | ||
| requires_numexpr, | ||
| requires_pint, | ||
| requires_polars, | ||
| requires_pyarrow, | ||
| requires_scipy, | ||
| requires_sparse, | ||
|
|
@@ -7673,3 +7674,229 @@ def test_unstack_index_var() -> None: | |
| name="x", | ||
| ) | ||
| assert_identical(actual, expected) | ||
|
|
||
|
|
||
class TestArrowPyCapsule:
    """Tests for exporting a DataArray through the Arrow PyCapsule Interface."""

    @staticmethod
    def _grid_2d() -> xr.DataArray:
        """Return a 2x3 float array named ``data`` with integer x/y coords."""
        return xr.DataArray(
            np.arange(6, dtype=float).reshape(2, 3),
            dims=["x", "y"],
            coords={"x": [0, 1], "y": [10, 20, 30]},
            name="data",
        )

    @requires_pyarrow
    def test_pyarrow_table_1d(self) -> None:
        import pyarrow as pa

        da = xr.DataArray(
            [1.0, 2.0, 3.0],
            dims=["x"],
            coords={"x": [10, 20, 30]},
            name="temperature",
        )
        table = pa.table(da)

        assert isinstance(table, pa.Table)
        assert set(table.column_names) == {"x", "temperature"}
        assert table.num_rows == 3
        assert table.schema.field("x").type == pa.int64()
        assert table.schema.field("temperature").type == pa.float64()
        np.testing.assert_array_equal(table["x"].to_pylist(), [10, 20, 30])
        np.testing.assert_array_equal(table["temperature"].to_pylist(), [1.0, 2.0, 3.0])

    @requires_pyarrow
    def test_pyarrow_table_2d(self) -> None:
        import pyarrow as pa

        table = pa.table(self._grid_2d())

        assert isinstance(table, pa.Table)
        assert set(table.column_names) == {"x", "y", "data"}
        assert table.num_rows == 6
        assert table.schema.field("x").type == pa.int64()
        assert table.schema.field("y").type == pa.int64()
        assert table.schema.field("data").type == pa.float64()
        np.testing.assert_array_equal(
            table["data"].to_pylist(), list(np.arange(6, dtype=float))
        )
        # x repeats for each y: [0,0,0,1,1,1]
        np.testing.assert_array_equal(table["x"].to_pylist(), [0, 0, 0, 1, 1, 1])
        # y cycles for each x: [10,20,30,10,20,30]
        np.testing.assert_array_equal(
            table["y"].to_pylist(), [10, 20, 30, 10, 20, 30]
        )

    @requires_pyarrow
    def test_data_array_unnamed_variable(self) -> None:
        import pyarrow as pa

        da = xr.DataArray([1, 2, 3], dims=["x"], coords={"x": [0, 1, 2]})
        table = pa.table(da)

        # An unnamed array exports its data under the fallback column name.
        assert "values" in table.column_names
        np.testing.assert_array_equal(table["values"].to_pylist(), [1, 2, 3])

    @requires_polars
    def test_polars_dataframe_1d(self) -> None:
        import polars as pl

        da = xr.DataArray(
            [1.0, 2.0, 3.0],
            dims=["x"],
            coords={"x": [10, 20, 30]},
            name="temperature",
        )
        df = pl.from_arrow(da)

        assert isinstance(df, pl.DataFrame)
        assert set(df.columns) == {"x", "temperature"}
        assert len(df) == 3
        np.testing.assert_array_equal(df["x"].to_list(), [10, 20, 30])
        np.testing.assert_array_equal(df["temperature"].to_list(), [1.0, 2.0, 3.0])

    @requires_polars
    def test_polars_dataframe_2d(self) -> None:
        import polars as pl

        df = pl.from_arrow(self._grid_2d())

        assert isinstance(df, pl.DataFrame)
        assert set(df.columns) == {"x", "y", "data"}
        assert len(df) == 6
        np.testing.assert_array_equal(
            df["data"].to_list(), list(np.arange(6, dtype=float))
        )
        # x repeats for each y: [0,0,0,1,1,1]
        np.testing.assert_array_equal(df["x"].to_list(), [0, 0, 0, 1, 1, 1])
        # y cycles for each x: [10,20,30,10,20,30]
        np.testing.assert_array_equal(df["y"].to_list(), [10, 20, 30, 10, 20, 30])

    @requires_dask
    @requires_pyarrow
    def test_dask_dataarray(self) -> None:
        # NOTE(review): the export currently loads chunked data into memory
        # before building the table. Reviewers prefer raising an error that
        # asks the user to load explicitly; update this test if that lands.
        import dask.array as da
        import pyarrow as pa

        dask_da = xr.DataArray(
            da.from_array(np.arange(6, dtype=float).reshape(2, 3)),
            dims=["x", "y"],
            coords={"x": [0, 1], "y": [10, 20, 30]},
            name="data",
        )

        table = pa.table(dask_da)
        assert isinstance(table, pa.Table)
        np.testing.assert_array_equal(
            table["data"].to_pylist(), list(np.arange(6, dtype=float))
        )

    @requires_polars
    @requires_pyarrow
    def test_polars_pyarrow_consistent(self) -> None:
        import polars as pl
        import pyarrow as pa

        da = self._grid_2d()
        pa_table = pa.table(da)
        pl_df = pl.from_arrow(da)

        # Both consumers must see identical column contents.
        for col in pa_table.column_names:
            np.testing.assert_array_equal(
                pa_table[col].to_pylist(), pl_df[col].to_list()
            )

    @requires_pyarrow
    def test_arrow_schema_fields(self) -> None:
        import pyarrow as pa

        schema = pa.schema(self._grid_2d())

        assert isinstance(schema, pa.Schema)
        # Coordinates come first, the data column last.
        assert schema.names == ["x", "y", "data"]
        assert schema.field("x").type == pa.int64()
        assert schema.field("y").type == pa.int64()
        assert schema.field("data").type == pa.float64()

    @requires_pyarrow
    def test_arrow_schema_metadata(self) -> None:
        import json

        import pyarrow as pa

        da = xr.DataArray(
            [1.0, 2.0, 3.0],
            dims=["x"],
            coords={"x": [10, 20, 30]},
            name="temperature",
            attrs={"units": "K", "long_name": "temperature"},
        )
        schema = pa.schema(da)

        assert schema.metadata[b"xarray:arrow_schema_version"] == b"v1"

        xarray_meta = json.loads(schema.metadata[b"xarray"])
        assert xarray_meta["name"] == "temperature"
        assert xarray_meta["dims"] == ["x"]
        assert xarray_meta["shape"] == [3]
        assert xarray_meta["attrs"] == {"units": "K", "long_name": "temperature"}
        assert "x" in xarray_meta["coords"]

    @requires_pyarrow
    def test_pyarrow_table_curvilinear_coords(self) -> None:
        import pyarrow as pa

        # non-dimension coordinates spanning multiple dims (e.g. a curvilinear
        # grid with 2D lat/lon) should be supported
        lat = np.array([[10.0, 11.0, 12.0], [13.0, 14.0, 15.0]])
        lon = np.array([[20.0, 21.0, 22.0], [23.0, 24.0, 25.0]])
        da = xr.DataArray(
            np.arange(6, dtype=float).reshape(2, 3),
            dims=["x", "y"],
            coords={"lat": (["x", "y"], lat), "lon": (["x", "y"], lon)},
            name="data",
        )
        table = pa.table(da)

        assert isinstance(table, pa.Table)
        assert set(table.column_names) == {"lat", "lon", "data"}
        assert table.num_rows == 6
        assert table.schema.field("lat").type == pa.float64()
        assert table.schema.field("lon").type == pa.float64()
        assert table.schema.field("data").type == pa.float64()
        np.testing.assert_array_equal(table["lat"].to_pylist(), lat.ravel())
        np.testing.assert_array_equal(table["lon"].to_pylist(), lon.ravel())
        np.testing.assert_array_equal(
            table["data"].to_pylist(), np.arange(6, dtype=float)
        )

    @requires_pyarrow
    def test_pyarrow_table_transposed_coords(self) -> None:
        import pyarrow as pa

        lat = np.array([[10.0, 11.0], [12.0, 13.0], [14.0, 15.0]])

        # Array with swapped dims order
        da = xr.DataArray(
            np.arange(6, dtype=float).reshape(2, 3),
            dims=["x", "y"],
            coords={"lat": (["y", "x"], lat)},
            name="data",
        )
        table = pa.table(da)

        assert isinstance(table, pa.Table)
        assert set(table.column_names) == {"lat", "data"}
        assert table.num_rows == 6
        assert table.schema.field("lat").type == pa.float64()
        # The coordinate must be reordered to the data's (x, y) dim order.
        np.testing.assert_array_equal(table["lat"].to_pylist(), lat.T.ravel())
        np.testing.assert_array_equal(
            table["data"].to_pylist(), np.arange(6, dtype=float)
        )
Uh oh!
There was an error while loading. Please reload this page.