diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c435b9c042a..0c875d2bac9 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -13,7 +13,10 @@ v2026.05.0 (unreleased) New Features ~~~~~~~~~~~~ - +- Added `PyArrowCapsule interface `_ + to :py:class:`DataArray` (``__arrow_c_schema__`` and ``__arrow_c_stream__``), enabling near zero-copy + export to pyarrow, polars or duckdb. + By `Jules Chéron `_. Breaking Changes ~~~~~~~~~~~~~~~~ diff --git a/pyproject.toml b/pyproject.toml index 640e4d3846c..d54e639a273 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ io = [ "cftime", "pooch", ] +arrow = ["pyarrow"] etc = ["sparse>=0.15"] parallel = ["dask[complete]"] viz = ["cartopy>=0.24", "matplotlib>=3.10", "nc-time-axis", "seaborn"] @@ -157,6 +158,7 @@ module = [ "opt_einsum.*", "pint.*", "pooch.*", + "polars.*", "pyarrow.*", "pydap.*", "seaborn.*", diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index d0df9bc061b..9050583f987 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2,6 +2,7 @@ import copy import datetime +import json import warnings from collections.abc import ( Callable, @@ -255,6 +256,21 @@ def __setitem__(self, key, value) -> None: _THIS_ARRAY = ReprObject("") +class _NumpyEncoder(json.JSONEncoder): + """Encode numpy value in Arrow schema metadata""" + + def default(self, obj: Any) -> Any: + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.bool_): + return bool(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() + return super().default(obj) + + class DataArray( AbstractArray, DataWithCoords, @@ -477,6 +493,94 @@ def __init__( self._close = None + def __arrow_c_schema__(self): + try: + import pyarrow as pa + except ImportError: + raise ImportError( + "pyarrow is required to export via the Arrow PyCapsule Interface." + ) from None + + values_column = self.name or "values" + + fields = [] + for name, coord in self._coords.items(): + arrow_dtype = pa.from_numpy_dtype(coord.dtype) + fields.append(pa.field(str(name), arrow_dtype)) + + fields.append( + pa.field(str(values_column), pa.from_numpy_dtype(self._variable.dtype)) + ) + + xarray_metadata = { + "name": self.name, + "dims": list(self._variable.dims), + "shape": list(self._variable.shape), + "attrs": self.attrs, + "coords": { + str(name): variable.to_dict(data=False) + for name, variable in self._coords.items() + }, + } + schema_metadata = { + b"xarray:arrow_schema_version": b"v1", + b"xarray": json.dumps(xarray_metadata, cls=_NumpyEncoder).encode(), + } + + schema = pa.schema(fields, metadata=schema_metadata) + return schema.__arrow_c_schema__() + + def __arrow_c_stream__(self, requested_schema: Any = None) -> Any: + """Export the DataArray through the Arrow PyCapsule Interface. + + https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html + """ + try: + import pyarrow as pa + except ImportError: + raise ImportError( + "pyarrow is required to export via the Arrow PyCapsule Interface." + ) from None + + values = self._variable.values + dims = self._variable.dims + shape = self._variable.shape + + values_column = self.name or "values" + + if not values.flags.c_contiguous: + values = np.ascontiguousarray(values) + + columns: dict[Hashable, pa.Array] = {} + for name, coord in self._coords.items(): + # Broadcast each coordinate up to the full data shape so that 1D + # dimension coordinates and N-D (e.g. curvilinear) coordinates + # flatten consistently with the data values. + + # Order axes based on Variable dims + dim_order = [coord.dims.index(dim) for dim in dims if dim in coord.dims] + + # Reorder coords values to variable dim order + ordered_coords = coord.values.transpose(dim_order) + + # Expand coord dims + # coord dims (x, y) variable dims (x,y,z) -> (x, y, 1) + # NOTE: Insert a length-1 axis for each data dim missing for coordinates + # (slice(None) keeps an existing axis, np.newaxis adds one) + indexer = tuple( + slice(None) if dim in coord.dims else np.newaxis for dim in dims + ) + expanded_coords = ordered_coords[indexer] + + # Broadcast to full flattened shape (x, y, 1) -> (x, y, z) + broadcasted = np.broadcast_to(expanded_coords, shape) + columns[name] = pa.array(np.ravel(broadcasted)) + + columns[values_column] = pa.array(np.ravel(values)) + + table = pa.table(columns, schema=pa.schema(self)) + return table.__arrow_c_stream__(requested_schema) + @classmethod def _construct_direct( cls, diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 2b1d276bb1a..82477e0d708 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -156,6 +156,7 @@ def _importorskip( has_iris, requires_iris = _importorskip("iris") has_numbagg, requires_numbagg = _importorskip("numbagg") has_pyarrow, requires_pyarrow = _importorskip("pyarrow") +has_polars, requires_polars = _importorskip("polars") with warnings.catch_warnings(): warnings.filterwarnings( "ignore", diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 8eb52046a31..26c2bc19712 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -59,6 +59,7 @@ requires_iris, requires_numexpr, requires_pint, + requires_polars, requires_pyarrow, requires_scipy, requires_sparse, @@ -7673,3 +7674,229 @@ def test_unstack_index_var() -> None: name="x", ) assert_identical(actual, expected) + + +class TestArrowPyCapsule: + @requires_pyarrow + def test_pyarrow_table_1d(self): + import pyarrow as pa + + da = xr.DataArray( + [1.0, 2.0, 3.0], + dims=["x"], + coords={"x": [10, 20, 30]}, + name="temperature", + ) + table = pa.table(da) + + assert isinstance(table, pa.Table) + assert set(table.column_names) == {"x", "temperature"} + assert table.num_rows == 3 + assert table.schema.field("x").type == pa.int64() + assert table.schema.field("temperature").type == pa.float64() + np.testing.assert_array_equal(table["x"].to_pylist(), [10, 20, 30]) + np.testing.assert_array_equal(table["temperature"].to_pylist(), [1.0, 2.0, 3.0]) + + @requires_pyarrow + def test_pyarrow_table_2d(self): + import pyarrow as pa + + da = xr.DataArray( + np.arange(6, dtype=float).reshape(2, 3), + dims=["x", "y"], + coords={"x": [0, 1], "y": [10, 20, 30]}, + name="data", + ) + table = pa.table(da) + + assert isinstance(table, pa.Table) + assert set(table.column_names) == {"x", "y", "data"} + assert table.num_rows == 6 + assert table.schema.field("x").type == pa.int64() + assert table.schema.field("y").type == pa.int64() + assert table.schema.field("data").type == pa.float64() + np.testing.assert_array_equal( + table["data"].to_pylist(), list(np.arange(6, dtype=float)) + ) + + @requires_pyarrow + def test_data_array_unnamed_variable(self): + import pyarrow as pa + + da = xr.DataArray([1, 2, 3], dims=["x"], coords={"x": [0, 1, 2]}) + table = pa.table(da) + + assert "values" in table.column_names + + @requires_polars + def test_polars_dataframe_1d(self): + import polars as pl + + da = xr.DataArray( + [1.0, 2.0, 3.0], + dims=["x"], + coords={"x": [10, 20, 30]}, + name="temperature", + ) + df = pl.from_arrow(da) + + assert isinstance(df, pl.DataFrame) + assert set(df.columns) == {"x", "temperature"} + assert len(df) == 3 + np.testing.assert_array_equal(df["x"].to_list(), [10, 20, 30]) + np.testing.assert_array_equal(df["temperature"].to_list(), [1.0, 2.0, 3.0]) + + @requires_polars + def test_polars_dataframe_2d(self): + import polars as pl + + da = xr.DataArray( + np.arange(6, dtype=float).reshape(2, 3), + dims=["x", "y"], + coords={"x": [0, 1], "y": [10, 20, 30]}, + name="data", + ) + df = pl.from_arrow(da) + + assert isinstance(df, pl.DataFrame) + assert set(df.columns) == {"x", "y", "data"} + assert len(df) == 6 + np.testing.assert_array_equal( + df["data"].to_list(), list(np.arange(6, dtype=float)) + ) + # x repeats for each y: [0,0,0,1,1,1] + np.testing.assert_array_equal(df["x"].to_list(), [0, 0, 0, 1, 1, 1]) + # y cycles for each x: [10,20,30,10,20,30] + np.testing.assert_array_equal(df["y"].to_list(), [10, 20, 30, 10, 20, 30]) + + @requires_dask + @requires_pyarrow + def test_dask_dataarray(self): + import dask.array as da + import pyarrow as pa + + dask_da = xr.DataArray( + da.from_array(np.arange(6, dtype=float).reshape(2, 3)), + dims=["x", "y"], + coords={"x": [0, 1], "y": [10, 20, 30]}, + name="data", + ) + + table = pa.table(dask_da) + assert isinstance(table, pa.Table) + np.testing.assert_array_equal( + table["data"].to_pylist(), list(np.arange(6, dtype=float)) + ) + + @requires_polars + @requires_pyarrow + def test_polars_pyarrow_consistent(self): + import polars as pl + import pyarrow as pa + + da = xr.DataArray( + np.arange(6, dtype=float).reshape(2, 3), + dims=["x", "y"], + coords={"x": [0, 1], "y": [10, 20, 30]}, + name="data", + ) + pa_table = pa.table(da) + pl_df = pl.from_arrow(da) + + for col in pa_table.column_names: + np.testing.assert_array_equal( + pa_table[col].to_pylist(), pl_df[col].to_list() + ) + + @requires_pyarrow + def test_arrow_schema_fields(self): + import pyarrow as pa + + da = xr.DataArray( + np.arange(6, dtype=float).reshape(2, 3), + dims=["x", "y"], + coords={"x": [0, 1], "y": [10, 20, 30]}, + name="data", + ) + schema = pa.schema(da) + + assert isinstance(schema, pa.Schema) + assert schema.names == ["x", "y", "data"] + assert schema.field("x").type == pa.int64() + assert schema.field("y").type == pa.int64() + assert schema.field("data").type == pa.float64() + + @requires_pyarrow + def test_arrow_schema_metadata(self): + import json + + import pyarrow as pa + + da = xr.DataArray( + [1.0, 2.0, 3.0], + dims=["x"], + coords={"x": [10, 20, 30]}, + name="temperature", + attrs={"units": "K", "long_name": "temperature"}, + ) + schema = pa.schema(da) + + assert schema.metadata[b"xarray:arrow_schema_version"] == b"v1" + + xarray_meta = json.loads(schema.metadata[b"xarray"]) + assert xarray_meta["name"] == "temperature" + assert xarray_meta["dims"] == ["x"] + assert xarray_meta["attrs"] == {"units": "K", "long_name": "temperature"} + assert "x" in xarray_meta["coords"] + + @requires_pyarrow + def test_pyarrow_table_curvilinear_coords(self): + import pyarrow as pa + + # non-dimension coordinates spanning multiple dims (e.g. a curvilinear + # grid with 2D lat/lon) should be supported + lat = np.array([[10.0, 11.0, 12.0], [13.0, 14.0, 15.0]]) + lon = np.array([[20.0, 21.0, 22.0], [23.0, 24.0, 25.0]]) + da = xr.DataArray( + np.arange(6, dtype=float).reshape(2, 3), + dims=["x", "y"], + coords={"lat": (["x", "y"], lat), "lon": (["x", "y"], lon)}, + name="data", + ) + table = pa.table(da) + + assert isinstance(table, pa.Table) + assert set(table.column_names) == {"lat", "lon", "data"} + assert table.num_rows == 6 + assert table.schema.field("lat").type == pa.float64() + assert table.schema.field("lon").type == pa.float64() + assert table.schema.field("data").type == pa.float64() + np.testing.assert_array_equal(table["lat"].to_pylist(), lat.ravel()) + np.testing.assert_array_equal(table["lon"].to_pylist(), lon.ravel()) + np.testing.assert_array_equal( + table["data"].to_pylist(), np.arange(6, dtype=float) + ) + + @requires_pyarrow + def test_pyarrow_table_transposed_coords(self): + import pyarrow as pa + + lat = np.array([[10.0, 11.0], [12.0, 13.0], [14.0, 15.0]]) + + # Array with swapped dims order + da = xr.DataArray( + np.arange(6, dtype=float).reshape(2, 3), + dims=["x", "y"], + coords={"lat": (["y", "x"], lat)}, + name="data", + ) + table = pa.table(da) + + assert isinstance(table, pa.Table) + assert set(table.column_names) == {"lat", "data"} + assert table.num_rows == 6 + assert table.schema.field("lat").type == pa.float64() + np.testing.assert_array_equal(table["lat"].to_pylist(), lat.T.ravel()) + np.testing.assert_array_equal( + table["data"].to_pylist(), np.arange(6, dtype=float) + )