Skip to content

Commit

Permalink
Fix DataTree repr to not repeat inherited coordinates (pydata#9532)
Browse files Browse the repository at this point in the history
* Fix DataTree repr to not repeat inherited coordinates

Fixes GH9499

* skip failing test on Windows
  • Loading branch information
shoyer authored Sep 22, 2024
1 parent 2a6212e commit ab84e04
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 2 deletions.
3 changes: 2 additions & 1 deletion xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,8 @@ def _datatree_node_repr(node: DataTree, show_inherited: bool) -> str:
summary.append(f"{dims_start}({dims_values})")

if node._node_coord_variables:
summary.append(coords_repr(node.coords, col_width=col_width, max_rows=max_rows))
node_coords = node.to_dataset(inherited=False).coords
summary.append(coords_repr(node_coords, col_width=col_width, max_rows=max_rows))

if show_inherited and inherited_coords:
summary.append(
Expand Down
107 changes: 106 additions & 1 deletion xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
import sys
import typing
from copy import copy, deepcopy
from textwrap import dedent
Expand All @@ -15,6 +16,8 @@
from xarray.testing import assert_equal, assert_identical
from xarray.tests import assert_array_equal, create_test_data, source_ndarray

ON_WINDOWS = sys.platform == "win32"


class TestTreeCreation:
def test_empty(self):
Expand Down Expand Up @@ -1052,7 +1055,7 @@ def test_repr_two_children(self):
{
"/": Dataset(coords={"x": [1.0]}),
"/first_child": None,
"/second_child": Dataset({"foo": ("x", [0.0])}),
"/second_child": Dataset({"foo": ("x", [0.0])}, coords={"z": 1.0}),
}
)

Expand All @@ -1067,6 +1070,8 @@ def test_repr_two_children(self):
├── Group: /first_child
└── Group: /second_child
Dimensions: (x: 1)
Coordinates:
z float64 8B 1.0
Data variables:
foo (x) float64 8B 0.0
"""
Expand All @@ -1091,6 +1096,8 @@ def test_repr_two_children(self):
<xarray.DataTree 'second_child'>
Group: /second_child
Dimensions: (x: 1)
Coordinates:
z float64 8B 1.0
Inherited coordinates:
* x (x) float64 8B 1.0
Data variables:
Expand Down Expand Up @@ -1138,6 +1145,104 @@ def test_repr_inherited_dims(self):
).strip()
assert result == expected

@pytest.mark.skipif(
ON_WINDOWS, reason="windows (pre NumPy2) uses int32 instead of int64"
)
def test_doc_example(self):
# regression test for https://github.com/pydata/xarray/issues/9499
time = xr.DataArray(data=["2022-01", "2023-01"], dims="time")
stations = xr.DataArray(data=list("abcdef"), dims="station")
lon = [-100, -80, -60]
lat = [10, 20, 30]
# Set up fake data
wind_speed = xr.DataArray(np.ones((2, 6)) * 2, dims=("time", "station"))
pressure = xr.DataArray(np.ones((2, 6)) * 3, dims=("time", "station"))
air_temperature = xr.DataArray(np.ones((2, 6)) * 4, dims=("time", "station"))
dewpoint = xr.DataArray(np.ones((2, 6)) * 5, dims=("time", "station"))
infrared = xr.DataArray(np.ones((2, 3, 3)) * 6, dims=("time", "lon", "lat"))
true_color = xr.DataArray(np.ones((2, 3, 3)) * 7, dims=("time", "lon", "lat"))
tree = xr.DataTree.from_dict(
{
"/": xr.Dataset(
coords={"time": time},
),
"/weather": xr.Dataset(
coords={"station": stations},
data_vars={
"wind_speed": wind_speed,
"pressure": pressure,
},
),
"/weather/temperature": xr.Dataset(
data_vars={
"air_temperature": air_temperature,
"dewpoint": dewpoint,
},
),
"/satellite": xr.Dataset(
coords={"lat": lat, "lon": lon},
data_vars={
"infrared": infrared,
"true_color": true_color,
},
),
},
)

result = repr(tree)
expected = dedent(
"""
<xarray.DataTree>
Group: /
│ Dimensions: (time: 2)
│ Coordinates:
│ * time (time) <U7 56B '2022-01' '2023-01'
├── Group: /weather
│ │ Dimensions: (station: 6, time: 2)
│ │ Coordinates:
│ │ * station (station) <U1 24B 'a' 'b' 'c' 'd' 'e' 'f'
│ │ Data variables:
│ │ wind_speed (time, station) float64 96B 2.0 2.0 2.0 2.0 ... 2.0 2.0 2.0 2.0
│ │ pressure (time, station) float64 96B 3.0 3.0 3.0 3.0 ... 3.0 3.0 3.0 3.0
│ └── Group: /weather/temperature
│ Dimensions: (time: 2, station: 6)
│ Data variables:
│ air_temperature (time, station) float64 96B 4.0 4.0 4.0 4.0 ... 4.0 4.0 4.0
│ dewpoint (time, station) float64 96B 5.0 5.0 5.0 5.0 ... 5.0 5.0 5.0
└── Group: /satellite
Dimensions: (lat: 3, lon: 3, time: 2)
Coordinates:
* lat (lat) int64 24B 10 20 30
* lon (lon) int64 24B -100 -80 -60
Data variables:
infrared (time, lon, lat) float64 144B 6.0 6.0 6.0 6.0 ... 6.0 6.0 6.0
true_color (time, lon, lat) float64 144B 7.0 7.0 7.0 7.0 ... 7.0 7.0 7.0
"""
).strip()
assert result == expected

result = repr(tree["weather"])
expected = dedent(
"""
<xarray.DataTree 'weather'>
Group: /weather
│ Dimensions: (time: 2, station: 6)
│ Coordinates:
│ * station (station) <U1 24B 'a' 'b' 'c' 'd' 'e' 'f'
│ Inherited coordinates:
│ * time (time) <U7 56B '2022-01' '2023-01'
│ Data variables:
│ wind_speed (time, station) float64 96B 2.0 2.0 2.0 2.0 ... 2.0 2.0 2.0 2.0
│ pressure (time, station) float64 96B 3.0 3.0 3.0 3.0 ... 3.0 3.0 3.0 3.0
└── Group: /weather/temperature
Dimensions: (time: 2, station: 6)
Data variables:
air_temperature (time, station) float64 96B 4.0 4.0 4.0 4.0 ... 4.0 4.0 4.0
dewpoint (time, station) float64 96B 5.0 5.0 5.0 5.0 ... 5.0 5.0 5.0
"""
).strip()
assert result == expected


def _exact_match(message: str) -> str:
return re.escape(dedent(message).strip())
Expand Down

0 comments on commit ab84e04

Please sign in to comment.