1515class TestXpxFunction :
1616 """Tests for the xpx() function."""
1717
18- def test_xpx_returns_accessor (self ) -> None :
19- """Test that xpx() returns a DataArrayPlotlyAccessor."""
18+ def test_xpx_returns_dataarray_accessor (self ) -> None :
19+ """Test that xpx() returns a DataArrayPlotlyAccessor for DataArray ."""
2020 da = xr .DataArray (np .random .rand (10 ), dims = ["time" ])
2121 accessor = xpx (da )
2222 assert hasattr (accessor , "line" )
2323 assert hasattr (accessor , "bar" )
2424 assert hasattr (accessor , "scatter" )
25+ assert hasattr (accessor , "imshow" )
2526
26- def test_xpx_equivalent_to_accessor (self ) -> None :
27+ def test_xpx_returns_dataset_accessor (self ) -> None :
28+ """Test that xpx() returns a DatasetPlotlyAccessor for Dataset."""
29+ ds = xr .Dataset ({"temp" : (["time" ], np .random .rand (10 ))})
30+ accessor = xpx (ds )
31+ assert hasattr (accessor , "line" )
32+ assert hasattr (accessor , "bar" )
33+ assert hasattr (accessor , "scatter" )
34+ # Dataset accessor should not have imshow
35+ assert not hasattr (accessor , "imshow" )
36+
37+ def test_xpx_dataarray_equivalent_to_accessor (self ) -> None :
2738 """Test that xpx(da).line() works the same as da.plotly.line()."""
2839 da = xr .DataArray (
2940 np .random .rand (10 , 3 ),
@@ -36,6 +47,19 @@ def test_xpx_equivalent_to_accessor(self) -> None:
3647 assert isinstance (fig1 , go .Figure )
3748 assert isinstance (fig2 , go .Figure )
3849
50+ def test_xpx_dataset_equivalent_to_accessor (self ) -> None :
51+ """Test that xpx(ds).line() works the same as ds.plotly.line()."""
52+ ds = xr .Dataset (
53+ {
54+ "temperature" : (["time" , "city" ], np .random .rand (10 , 3 )),
55+ "humidity" : (["time" , "city" ], np .random .rand (10 , 3 )),
56+ }
57+ )
58+ fig1 = xpx (ds ).line ()
59+ fig2 = ds .plotly .line ()
60+ assert isinstance (fig1 , go .Figure )
61+ assert isinstance (fig2 , go .Figure )
62+
3963
4064class TestDataArrayPxplot :
4165 """Tests for DataArray.plotly accessor."""
@@ -206,3 +230,65 @@ def test_value_label_from_attrs(self) -> None:
206230 """Test that value labels are extracted from attributes."""
207231 fig = self .da .plotly .line ()
208232 assert isinstance (fig , go .Figure )
233+
234+
235+ class TestDatasetPlotlyAccessor :
236+ """Tests for Dataset.plotly accessor."""
237+
238+ @pytest .fixture (autouse = True )
239+ def setup (self ) -> None :
240+ """Set up test data."""
241+ self .ds = xr .Dataset (
242+ {
243+ "temperature" : (["time" , "city" ], np .random .rand (10 , 3 )),
244+ "humidity" : (["time" , "city" ], np .random .rand (10 , 3 )),
245+ },
246+ coords = {
247+ "time" : pd .date_range ("2020" , periods = 10 ),
248+ "city" : ["NYC" , "LA" , "Chicago" ],
249+ },
250+ )
251+
252+ def test_accessor_exists (self ) -> None :
253+ """Test that plotly accessor is available on Dataset."""
254+ assert hasattr (self .ds , "plotly" )
255+ assert hasattr (self .ds .plotly , "line" )
256+ assert hasattr (self .ds .plotly , "bar" )
257+ assert hasattr (self .ds .plotly , "area" )
258+ assert hasattr (self .ds .plotly , "scatter" )
259+ assert hasattr (self .ds .plotly , "box" )
260+
261+ def test_line_all_variables (self ) -> None :
262+ """Test line plot with all variables."""
263+ fig = self .ds .plotly .line ()
264+ assert isinstance (fig , go .Figure )
265+
266+ def test_line_single_variable (self ) -> None :
267+ """Test line plot with single variable."""
268+ fig = self .ds .plotly .line (var = "temperature" )
269+ assert isinstance (fig , go .Figure )
270+
271+ def test_line_variable_as_facet (self ) -> None :
272+ """Test line plot with variable as facet."""
273+ fig = self .ds .plotly .line (facet_col = "variable" )
274+ assert isinstance (fig , go .Figure )
275+
276+ def test_bar_all_variables (self ) -> None :
277+ """Test bar plot with all variables."""
278+ fig = self .ds .plotly .bar ()
279+ assert isinstance (fig , go .Figure )
280+
281+ def test_area_all_variables (self ) -> None :
282+ """Test area plot with all variables."""
283+ fig = self .ds .plotly .area ()
284+ assert isinstance (fig , go .Figure )
285+
286+ def test_scatter_all_variables (self ) -> None :
287+ """Test scatter plot with all variables."""
288+ fig = self .ds .plotly .scatter ()
289+ assert isinstance (fig , go .Figure )
290+
291+ def test_box_all_variables (self ) -> None :
292+ """Test box plot with all variables."""
293+ fig = self .ds .plotly .box ()
294+ assert isinstance (fig , go .Figure )
0 commit comments