Skip to content

Commit ec7959d

Browse files
committed
Add dataset plotting accessor
1 parent 23ce8ff commit ec7959d

File tree

3 files changed

+373
-14
lines changed

3 files changed

+373
-14
lines changed

tests/test_accessor.py

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,26 @@
1515
class TestXpxFunction:
1616
"""Tests for the xpx() function."""
1717

18-
def test_xpx_returns_accessor(self) -> None:
19-
"""Test that xpx() returns a DataArrayPlotlyAccessor."""
18+
def test_xpx_returns_dataarray_accessor(self) -> None:
19+
"""Test that xpx() returns a DataArrayPlotlyAccessor for DataArray."""
2020
da = xr.DataArray(np.random.rand(10), dims=["time"])
2121
accessor = xpx(da)
2222
assert hasattr(accessor, "line")
2323
assert hasattr(accessor, "bar")
2424
assert hasattr(accessor, "scatter")
25+
assert hasattr(accessor, "imshow")
2526

26-
def test_xpx_equivalent_to_accessor(self) -> None:
27+
def test_xpx_returns_dataset_accessor(self) -> None:
28+
"""Test that xpx() returns a DatasetPlotlyAccessor for Dataset."""
29+
ds = xr.Dataset({"temp": (["time"], np.random.rand(10))})
30+
accessor = xpx(ds)
31+
assert hasattr(accessor, "line")
32+
assert hasattr(accessor, "bar")
33+
assert hasattr(accessor, "scatter")
34+
# Dataset accessor should not have imshow
35+
assert not hasattr(accessor, "imshow")
36+
37+
def test_xpx_dataarray_equivalent_to_accessor(self) -> None:
2738
"""Test that xpx(da).line() works the same as da.plotly.line()."""
2839
da = xr.DataArray(
2940
np.random.rand(10, 3),
@@ -36,6 +47,19 @@ def test_xpx_equivalent_to_accessor(self) -> None:
3647
assert isinstance(fig1, go.Figure)
3748
assert isinstance(fig2, go.Figure)
3849

50+
def test_xpx_dataset_equivalent_to_accessor(self) -> None:
51+
"""Test that xpx(ds).line() works the same as ds.plotly.line()."""
52+
ds = xr.Dataset(
53+
{
54+
"temperature": (["time", "city"], np.random.rand(10, 3)),
55+
"humidity": (["time", "city"], np.random.rand(10, 3)),
56+
}
57+
)
58+
fig1 = xpx(ds).line()
59+
fig2 = ds.plotly.line()
60+
assert isinstance(fig1, go.Figure)
61+
assert isinstance(fig2, go.Figure)
62+
3963

4064
class TestDataArrayPxplot:
4165
"""Tests for DataArray.plotly accessor."""
@@ -206,3 +230,65 @@ def test_value_label_from_attrs(self) -> None:
206230
"""Test that value labels are extracted from attributes."""
207231
fig = self.da.plotly.line()
208232
assert isinstance(fig, go.Figure)
233+
234+
235+
class TestDatasetPlotlyAccessor:
236+
"""Tests for Dataset.plotly accessor."""
237+
238+
@pytest.fixture(autouse=True)
239+
def setup(self) -> None:
240+
"""Set up test data."""
241+
self.ds = xr.Dataset(
242+
{
243+
"temperature": (["time", "city"], np.random.rand(10, 3)),
244+
"humidity": (["time", "city"], np.random.rand(10, 3)),
245+
},
246+
coords={
247+
"time": pd.date_range("2020", periods=10),
248+
"city": ["NYC", "LA", "Chicago"],
249+
},
250+
)
251+
252+
def test_accessor_exists(self) -> None:
253+
"""Test that plotly accessor is available on Dataset."""
254+
assert hasattr(self.ds, "plotly")
255+
assert hasattr(self.ds.plotly, "line")
256+
assert hasattr(self.ds.plotly, "bar")
257+
assert hasattr(self.ds.plotly, "area")
258+
assert hasattr(self.ds.plotly, "scatter")
259+
assert hasattr(self.ds.plotly, "box")
260+
261+
def test_line_all_variables(self) -> None:
262+
"""Test line plot with all variables."""
263+
fig = self.ds.plotly.line()
264+
assert isinstance(fig, go.Figure)
265+
266+
def test_line_single_variable(self) -> None:
267+
"""Test line plot with single variable."""
268+
fig = self.ds.plotly.line(var="temperature")
269+
assert isinstance(fig, go.Figure)
270+
271+
def test_line_variable_as_facet(self) -> None:
272+
"""Test line plot with variable as facet."""
273+
fig = self.ds.plotly.line(facet_col="variable")
274+
assert isinstance(fig, go.Figure)
275+
276+
def test_bar_all_variables(self) -> None:
277+
"""Test bar plot with all variables."""
278+
fig = self.ds.plotly.bar()
279+
assert isinstance(fig, go.Figure)
280+
281+
def test_area_all_variables(self) -> None:
282+
"""Test area plot with all variables."""
283+
fig = self.ds.plotly.area()
284+
assert isinstance(fig, go.Figure)
285+
286+
def test_scatter_all_variables(self) -> None:
287+
"""Test scatter plot with all variables."""
288+
fig = self.ds.plotly.scatter()
289+
assert isinstance(fig, go.Figure)
290+
291+
def test_box_all_variables(self) -> None:
292+
"""Test box plot with all variables."""
293+
fig = self.ds.plotly.box()
294+
assert isinstance(fig, go.Figure)

xarray_plotly/__init__.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""Interactive Plotly Express plotting for xarray.
22
3-
This package provides a `plotly` accessor for xarray DataArray objects,
3+
This package provides a `plotly` accessor for xarray DataArray and Dataset objects,
44
enabling interactive visualization with Plotly Express.
55
66
Features:
77
- **Interactive plots**: Zoom, pan, hover, toggle traces
88
- **Automatic dimension assignment**: Dimensions fill slots (x, color, facet) by position
99
- **Multiple plot types**: line, bar, area, scatter, box, imshow
10+
- **Dataset support**: Plot all variables at once with "variable" dimension
1011
- **Faceting and animation**: Built-in subplot grids and animated plots
1112
- **Customizable**: Returns Plotly Figure objects for further modification
1213
@@ -15,11 +16,13 @@
1516
1617
import xarray_plotly
1718
fig = da.plotly.line()
19+
fig = ds.plotly.line() # Dataset: all variables
1820
1921
Function style (recommended for IDE completion)::
2022
2123
from xarray_plotly import xpx
2224
fig = xpx(da).line()
25+
fig = xpx(ds).line() # Dataset: all variables
2326
2427
Example:
2528
```python
@@ -34,48 +37,72 @@
3437
fig = xpx(da).line() # Auto: time->x, city->color, scenario->facet_col
3538
fig = xpx(da).line(x="time", color="scenario") # Explicit
3639
fig = xpx(da).line(color=None) # Skip slot
40+
41+
# Dataset: plot all variables (accessor or xpx)
42+
ds = xr.Dataset({"temp": da, "precip": da})
43+
fig = xpx(ds).line() # "variable" dimension for color
44+
fig = xpx(ds).line(facet_col="variable") # Facet by variable
3745
```
3846
"""
3947

4048
from importlib.metadata import version
49+
from typing import overload
4150

42-
from xarray import DataArray, register_dataarray_accessor
51+
from xarray import DataArray, Dataset, register_dataarray_accessor, register_dataset_accessor
4352

4453
from xarray_plotly import config
45-
from xarray_plotly.accessor import DataArrayPlotlyAccessor
54+
from xarray_plotly.accessor import DataArrayPlotlyAccessor, DatasetPlotlyAccessor
4655
from xarray_plotly.common import SLOT_ORDERS, auto
4756

4857
__all__ = [
4958
"SLOT_ORDERS",
5059
"DataArrayPlotlyAccessor",
60+
"DatasetPlotlyAccessor",
5161
"auto",
5262
"config",
5363
"xpx",
5464
]
5565

5666

57-
def xpx(da: DataArray) -> DataArrayPlotlyAccessor:
58-
"""Get the plotly accessor for a DataArray with full IDE code completion.
67+
@overload
68+
def xpx(data: DataArray) -> DataArrayPlotlyAccessor: ...
69+
70+
71+
@overload
72+
def xpx(data: Dataset) -> DatasetPlotlyAccessor: ...
73+
5974

60-
This is an alternative to `da.plotly` that provides proper type hints
75+
def xpx(data: DataArray | Dataset) -> DataArrayPlotlyAccessor | DatasetPlotlyAccessor:
76+
"""Get the plotly accessor for a DataArray or Dataset with full IDE code completion.
77+
78+
This is an alternative to `da.plotly` / `ds.plotly` that provides proper type hints
6179
and code completion in IDEs.
6280
6381
Args:
64-
da: The DataArray to plot.
82+
data: The DataArray or Dataset to plot.
6583
6684
Returns:
6785
The accessor with plotting methods (line, bar, area, scatter, box, imshow).
6886
6987
Example:
7088
```python
7189
from xarray_plotly import xpx
90+
91+
# DataArray
7292
fig = xpx(da).line() # Full code completion works here
93+
94+
# Dataset
95+
fig = xpx(ds).line() # Plots all variables
96+
fig = xpx(ds).line(var="temperature") # Single variable
7397
```
7498
"""
75-
return DataArrayPlotlyAccessor(da)
99+
if isinstance(data, Dataset):
100+
return DatasetPlotlyAccessor(data)
101+
return DataArrayPlotlyAccessor(data)
76102

77103

78104
__version__ = version("xarray_plotly")
79105

80-
# Register the accessor
106+
# Register the accessors
81107
register_dataarray_accessor("plotly")(DataArrayPlotlyAccessor)
108+
register_dataset_accessor("plotly")(DatasetPlotlyAccessor)

0 commit comments

Comments
 (0)