Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ v2026.05.0 (unreleased)

New Features
~~~~~~~~~~~~

- Added support for the `Arrow PyCapsule Interface <https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html#arrow-pycapsule-interface>`_
  to :py:class:`DataArray` (``__arrow_c_schema__`` and ``__arrow_c_stream__``), enabling near zero-copy
  export to pyarrow, polars, or DuckDB.
  By `Jules Chéron <https://github.com/jules-ch>`_.

Breaking Changes
~~~~~~~~~~~~~~~~
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ io = [
"cftime",
"pooch",
]
arrow = ["pyarrow"]
etc = ["sparse>=0.15"]
parallel = ["dask[complete]"]
viz = ["cartopy>=0.24", "matplotlib>=3.10", "nc-time-axis", "seaborn"]
Expand Down Expand Up @@ -157,6 +158,7 @@ module = [
"opt_einsum.*",
"pint.*",
"pooch.*",
"polars.*",
"pyarrow.*",
"pydap.*",
"seaborn.*",
Expand Down
104 changes: 104 additions & 0 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import copy
import datetime
import json
import warnings
from collections.abc import (
Callable,
Expand Down Expand Up @@ -255,6 +256,21 @@ def __setitem__(self, key, value) -> None:
_THIS_ARRAY = ReprObject("<this-array>")


class _NumpyEncoder(json.JSONEncoder):
    """JSON encoder that accepts NumPy scalars and arrays.

    Used when serialising xarray metadata into the Arrow schema metadata,
    where values may be NumPy objects that the standard encoder rejects.
    """

    def default(self, obj: Any) -> Any:
        """Convert NumPy objects to built-in Python equivalents.

        Arrays become (nested) lists, and integer, floating and boolean
        scalars become ``int``, ``float`` and ``bool``. Any other object is
        passed to the base class, which raises ``TypeError``.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()

        # The scalar categories are disjoint, so the lookup order is irrelevant.
        scalar_converters = (
            (np.integer, int),
            (np.floating, float),
            (np.bool_, bool),
        )
        for numpy_type, convert in scalar_converters:
            if isinstance(obj, numpy_type):
                return convert(obj)

        return super().default(obj)


class DataArray(
AbstractArray,
DataWithCoords,
Expand Down Expand Up @@ -477,6 +493,94 @@ def __init__(

self._close = None

def __arrow_c_schema__(self):
    """Export the Arrow schema of this DataArray as a PyCapsule.

    The schema holds one field per coordinate, followed by one field for the
    data values. The data field is named after the array, or ``"values"``
    when the array name is falsy. xarray-specific information (name, dims,
    shape, attrs and a per-coordinate description) is JSON-encoded into the
    schema metadata under the ``b"xarray"`` key.

    https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html

    Raises
    ------
    ImportError
        If pyarrow is not installed.
    """
    try:
        import pyarrow as pa
    except ImportError:
        raise ImportError(
            "pyarrow is required to export via the Arrow PyCapsule Interface."
        ) from None

    data_field_name = self.name or "values"

    # Coordinate fields come first; the data field is always last.
    arrow_fields = [
        pa.field(str(coord_name), pa.from_numpy_dtype(coord_var.dtype))
        for coord_name, coord_var in self._coords.items()
    ]
    arrow_fields.append(
        pa.field(str(data_field_name), pa.from_numpy_dtype(self._variable.dtype))
    )

    # Describe each coordinate without embedding its data.
    coords_description = {}
    for coord_name, coord_var in self._coords.items():
        coords_description[str(coord_name)] = coord_var.to_dict(data=False)

    payload = json.dumps(
        {
            "name": self.name,
            "dims": list(self._variable.dims),
            "shape": list(self._variable.shape),
            "attrs": self.attrs,
            "coords": coords_description,
        },
        cls=_NumpyEncoder,
    )
    metadata = {
        b"xarray:arrow_schema_version": b"v1",
        b"xarray": payload.encode(),
    }

    return pa.schema(arrow_fields, metadata=metadata).__arrow_c_schema__()

def __arrow_c_stream__(self, requested_schema: Any = None) -> Any:
    """Export the DataArray through the Arrow PyCapsule Interface.

    The array is flattened in C order into a long-format table: one column
    per coordinate (broadcast against the data) plus one column holding the
    data values. The table is built against the schema obtained from
    ``pa.schema(self)``.

    https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html

    Parameters
    ----------
    requested_schema : optional
        Forwarded unchanged to ``pyarrow.Table.__arrow_c_stream__``.

    Returns
    -------
    The object returned by ``pyarrow.Table.__arrow_c_stream__``.

    Raises
    ------
    ImportError
        If pyarrow is not installed.
    """
    try:
        import pyarrow as pa
    except ImportError:
        raise ImportError(
            "pyarrow is required to export via the Arrow PyCapsule Interface."
        ) from None

    # NOTE(review): ``.values`` materialises the data as a NumPy array, so a
    # lazily evaluated (e.g. dask-backed) array is presumably computed here.
    values = self._variable.values
    dims = self._variable.dims
    shape = self._variable.shape

    # The schema names its fields with ``str(name)``; the columns must use
    # the same string keys, because pyarrow looks columns up by field name
    # when a schema is supplied. Keying by the raw (possibly non-string)
    # name would make that lookup fail.
    values_column = str(self.name or "values")

    columns: dict[str, Any] = {}
    for name, coord in self._coords.items():
        # Broadcast each coordinate up to the full data shape so that 1D
        # dimension coordinates and N-D (e.g. curvilinear) coordinates
        # flatten consistently with the data values.
        # NOTE(review): sharing this logic with the ``to_dataframe`` path
        # was discussed in review and deferred to a follow-up change.

        # Axes of the coordinate, listed in the order of the data dims.
        dim_order = [coord.dims.index(dim) for dim in dims if dim in coord.dims]

        # Reorder the coordinate values to follow the data dim order.
        ordered_coords = coord.values.transpose(dim_order)

        # Insert a length-1 axis for every data dim the coordinate lacks,
        # e.g. coord dims (x, y) with data dims (x, y, z) -> (x, y, 1).
        # ``slice(None)`` keeps an existing axis, ``np.newaxis`` adds one.
        indexer = tuple(
            slice(None) if dim in coord.dims else np.newaxis for dim in dims
        )
        expanded_coords = ordered_coords[indexer]

        # Broadcast to the full data shape, e.g. (x, y, 1) -> (x, y, z).
        broadcasted = np.broadcast_to(expanded_coords, shape)

        columns[str(name)] = pa.array(np.ravel(broadcasted))

    # ``np.ravel`` already returns a contiguous, C-ordered 1-D array (copying
    # only when needed), so no separate ``np.ascontiguousarray`` is required.
    # NOTE(review): if the array name equals a coordinate name, this entry
    # replaces that coordinate's column.
    columns[values_column] = pa.array(np.ravel(values))

    table = pa.table(columns, schema=pa.schema(self))
    return table.__arrow_c_stream__(requested_schema)

@classmethod
def _construct_direct(
cls,
Expand Down
1 change: 1 addition & 0 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def _importorskip(
has_iris, requires_iris = _importorskip("iris")
has_numbagg, requires_numbagg = _importorskip("numbagg")
has_pyarrow, requires_pyarrow = _importorskip("pyarrow")
has_polars, requires_polars = _importorskip("polars")
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
Expand Down
227 changes: 227 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
requires_iris,
requires_numexpr,
requires_pint,
requires_polars,
requires_pyarrow,
requires_scipy,
requires_sparse,
Expand Down Expand Up @@ -7673,3 +7674,229 @@ def test_unstack_index_var() -> None:
name="x",
)
assert_identical(actual, expected)


class TestArrowPyCapsule:
Comment thread
dcherian marked this conversation as resolved.
@requires_pyarrow
def test_pyarrow_table_1d(self):
    import pyarrow as pa

    # A named 1-D array with a single integer dimension coordinate.
    arr = xr.DataArray(
        [1.0, 2.0, 3.0],
        dims=["x"],
        coords={"x": [10, 20, 30]},
        name="temperature",
    )
    table = pa.table(arr)

    assert isinstance(table, pa.Table)
    assert set(table.column_names) == {"x", "temperature"}
    assert table.num_rows == 3

    # Expected Arrow type and contents for each column.
    expected = {
        "x": (pa.int64(), [10, 20, 30]),
        "temperature": (pa.float64(), [1.0, 2.0, 3.0]),
    }
    for column, (arrow_type, _) in expected.items():
        assert table.schema.field(column).type == arrow_type
    for column, (_, column_values) in expected.items():
        np.testing.assert_array_equal(table[column].to_pylist(), column_values)

@requires_pyarrow
def test_pyarrow_table_2d(self) -> None:
    """A 2-D array exports as a long-format table in C (row-major) order."""
    import pyarrow as pa

    da = xr.DataArray(
        np.arange(6, dtype=float).reshape(2, 3),
        dims=["x", "y"],
        coords={"x": [0, 1], "y": [10, 20, 30]},
        name="data",
    )
    table = pa.table(da)

    assert isinstance(table, pa.Table)
    assert set(table.column_names) == {"x", "y", "data"}
    assert table.num_rows == 6
    assert table.schema.field("x").type == pa.int64()
    assert table.schema.field("y").type == pa.int64()
    assert table.schema.field("data").type == pa.float64()
    np.testing.assert_array_equal(
        table["data"].to_pylist(), list(np.arange(6, dtype=float))
    )
    # The coordinate columns must be broadcast consistently with the data:
    # x repeats for each y, and y cycles for each x.
    np.testing.assert_array_equal(table["x"].to_pylist(), [0, 0, 0, 1, 1, 1])
    np.testing.assert_array_equal(
        table["y"].to_pylist(), [10, 20, 30, 10, 20, 30]
    )

@requires_pyarrow
def test_data_array_unnamed_variable(self):
    import pyarrow as pa

    # No ``name`` is given, so the data column falls back to "values".
    unnamed = xr.DataArray([1, 2, 3], coords={"x": [0, 1, 2]}, dims=["x"])
    column_names = pa.table(unnamed).column_names

    assert "values" in column_names

@requires_polars
def test_polars_dataframe_1d(self):
    import polars as pl

    # polars consumes the DataArray directly through the PyCapsule protocol.
    arr = xr.DataArray(
        [1.0, 2.0, 3.0],
        dims=["x"],
        coords={"x": [10, 20, 30]},
        name="temperature",
    )
    frame = pl.from_arrow(arr)

    assert isinstance(frame, pl.DataFrame)
    assert set(frame.columns) == {"x", "temperature"}
    assert len(frame) == 3

    expected_columns = {
        "x": [10, 20, 30],
        "temperature": [1.0, 2.0, 3.0],
    }
    for column, column_values in expected_columns.items():
        np.testing.assert_array_equal(frame[column].to_list(), column_values)

@requires_polars
def test_polars_dataframe_2d(self):
    import polars as pl

    data = np.arange(6, dtype=float)
    arr = xr.DataArray(
        data.reshape(2, 3),
        dims=["x", "y"],
        coords={"x": [0, 1], "y": [10, 20, 30]},
        name="data",
    )
    frame = pl.from_arrow(arr)

    assert isinstance(frame, pl.DataFrame)
    assert set(frame.columns) == {"x", "y", "data"}
    assert len(frame) == 6

    # Rows follow C order: x repeats for each y, y cycles for each x.
    expected_columns = {
        "data": list(data),
        "x": [0, 0, 0, 1, 1, 1],
        "y": [10, 20, 30, 10, 20, 30],
    }
    for column, column_values in expected_columns.items():
        np.testing.assert_array_equal(frame[column].to_list(), column_values)

@requires_dask
@requires_pyarrow
def test_dask_dataarray(self):
    # NOTE(review): exporting a dask-backed array currently loads it into
    # memory (the values are cast to NumPy before the pyarrow conversion).
    # In review the maintainers said they would prefer raising an error and
    # asking the user to compute explicitly; streaming dask chunks as Arrow
    # RecordBatches was suggested as a follow-up. Revisit this test if the
    # export behaviour changes.
    import dask.array as da
    import pyarrow as pa

    expected_values = np.arange(6, dtype=float)
    lazy_values = da.from_array(expected_values.reshape(2, 3))
    dask_da = xr.DataArray(
        lazy_values,
        dims=["x", "y"],
        coords={"x": [0, 1], "y": [10, 20, 30]},
        name="data",
    )

    table = pa.table(dask_da)

    assert isinstance(table, pa.Table)
    np.testing.assert_array_equal(table["data"].to_pylist(), list(expected_values))

@requires_polars
@requires_pyarrow
def test_polars_pyarrow_consistent(self):
    import polars as pl
    import pyarrow as pa

    arr = xr.DataArray(
        np.arange(6, dtype=float).reshape(2, 3),
        dims=["x", "y"],
        coords={"x": [0, 1], "y": [10, 20, 30]},
        name="data",
    )

    # Both consumers read the same export, so every column must agree.
    arrow_table = pa.table(arr)
    polars_frame = pl.from_arrow(arr)

    for column in arrow_table.column_names:
        from_arrow = arrow_table[column].to_pylist()
        from_polars = polars_frame[column].to_list()
        np.testing.assert_array_equal(from_arrow, from_polars)

@requires_pyarrow
def test_arrow_schema_fields(self):
    import pyarrow as pa

    arr = xr.DataArray(
        np.arange(6, dtype=float).reshape(2, 3),
        dims=["x", "y"],
        coords={"x": [0, 1], "y": [10, 20, 30]},
        name="data",
    )
    schema = pa.schema(arr)

    assert isinstance(schema, pa.Schema)
    # Coordinate fields come first; the data field is last.
    assert schema.names == ["x", "y", "data"]

    expected_types = {"x": pa.int64(), "y": pa.int64(), "data": pa.float64()}
    for field_name, arrow_type in expected_types.items():
        assert schema.field(field_name).type == arrow_type

@requires_pyarrow
def test_arrow_schema_metadata(self) -> None:
    """The schema metadata carries a JSON description of the DataArray."""
    import json

    import pyarrow as pa

    da = xr.DataArray(
        [1.0, 2.0, 3.0],
        dims=["x"],
        coords={"x": [10, 20, 30]},
        name="temperature",
        attrs={"units": "K", "long_name": "temperature"},
    )
    schema = pa.schema(da)

    assert schema.metadata[b"xarray:arrow_schema_version"] == b"v1"

    xarray_meta = json.loads(schema.metadata[b"xarray"])
    assert xarray_meta["name"] == "temperature"
    assert xarray_meta["dims"] == ["x"]
    # The shape is exported alongside the dims, so check it as well.
    assert xarray_meta["shape"] == [3]
    assert xarray_meta["attrs"] == {"units": "K", "long_name": "temperature"}
    # Only the coordinates present on the array may be described.
    assert set(xarray_meta["coords"]) == {"x"}

@requires_pyarrow
def test_pyarrow_table_curvilinear_coords(self):
    import pyarrow as pa

    # Non-dimension coordinates spanning several dims (e.g. the 2D lat/lon
    # of a curvilinear grid) should be supported.
    lat = np.array([[10.0, 11.0, 12.0], [13.0, 14.0, 15.0]])
    lon = np.array([[20.0, 21.0, 22.0], [23.0, 24.0, 25.0]])
    data = np.arange(6, dtype=float)
    arr = xr.DataArray(
        data.reshape(2, 3),
        dims=["x", "y"],
        coords={"lat": (["x", "y"], lat), "lon": (["x", "y"], lon)},
        name="data",
    )
    table = pa.table(arr)

    assert isinstance(table, pa.Table)
    assert set(table.column_names) == {"lat", "lon", "data"}
    assert table.num_rows == 6

    for column in ("lat", "lon", "data"):
        assert table.schema.field(column).type == pa.float64()

    # Each 2D coordinate flattens in the same order as the data.
    expected_columns = {"lat": lat.ravel(), "lon": lon.ravel(), "data": data}
    for column, column_values in expected_columns.items():
        np.testing.assert_array_equal(table[column].to_pylist(), column_values)

@requires_pyarrow
def test_pyarrow_table_transposed_coords(self):
    import pyarrow as pa

    # The coordinate is declared on (y, x) while the data uses (x, y), so
    # the export has to reorder the coordinate axes before flattening.
    lat = np.array([[10.0, 11.0], [12.0, 13.0], [14.0, 15.0]])
    data = np.arange(6, dtype=float)
    arr = xr.DataArray(
        data.reshape(2, 3),
        dims=["x", "y"],
        coords={"lat": (["y", "x"], lat)},
        name="data",
    )
    table = pa.table(arr)

    assert isinstance(table, pa.Table)
    assert set(table.column_names) == {"lat", "data"}
    assert table.num_rows == 6
    assert table.schema.field("lat").type == pa.float64()

    # In data order the coordinate is the transpose of how it was given.
    lat_in_data_order = lat.T.ravel()
    np.testing.assert_array_equal(table["lat"].to_pylist(), lat_in_data_order)
    np.testing.assert_array_equal(table["data"].to_pylist(), data)
Loading