Skip to content

Commit b248ae2

Browse files
committed
1. xarray_plotly/plotting.py: Added robust parameter to imshow() with global bounds computation
2. xarray_plotly/accessor.py: Added robust parameter to accessor method 3. tests/test_accessor.py: Added 4 tests for bounds behavior New behavior: - Default: Global min/max across all data (fixes animation consistency) - robust=True: Uses 2nd/98th percentile (handles outliers) - zmin/zmax: User override still works
1 parent afcc9b9 commit b248ae2

File tree

3 files changed

+73
-0
lines changed

3 files changed

+73
-0
lines changed

tests/test_accessor.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,3 +292,52 @@ def test_box_all_variables(self) -> None:
292292
"""Test box plot with all variables."""
293293
fig = self.ds.plotly.box()
294294
assert isinstance(fig, go.Figure)
295+
296+
297+
class TestImshowBounds:
298+
"""Tests for imshow global bounds and robust mode."""
299+
300+
def test_imshow_global_bounds(self) -> None:
301+
"""Test that imshow uses global min/max by default."""
302+
da = xr.DataArray(
303+
np.array([[[1, 2], [3, 4]], [[5, 6], [7, 100]]]),
304+
dims=["time", "y", "x"],
305+
)
306+
fig = da.plotly.imshow(animation_frame="time")
307+
# Check coloraxis for zmin/zmax (plotly stores them there)
308+
coloraxis = fig.layout.coloraxis
309+
assert coloraxis.cmin == 1.0
310+
assert coloraxis.cmax == 100.0
311+
312+
def test_imshow_robust_bounds(self) -> None:
313+
"""Test that robust=True uses percentile-based bounds."""
314+
# Create data with outlier
315+
data = np.random.rand(10, 20) * 100
316+
data[0, 0] = 10000 # extreme outlier
317+
da = xr.DataArray(data, dims=["y", "x"])
318+
319+
fig = da.plotly.imshow(robust=True)
320+
# With robust=True, cmax should be much less than the outlier
321+
coloraxis = fig.layout.coloraxis
322+
assert coloraxis.cmax < 10000
323+
assert coloraxis.cmax < 200 # Should be around 98th percentile (~98)
324+
325+
def test_imshow_user_zmin_zmax_override(self) -> None:
326+
"""Test that user-provided zmin/zmax overrides auto bounds."""
327+
da = xr.DataArray(np.random.rand(10, 20) * 100, dims=["y", "x"])
328+
fig = da.plotly.imshow(zmin=0, zmax=50)
329+
coloraxis = fig.layout.coloraxis
330+
assert coloraxis.cmin == 0
331+
assert coloraxis.cmax == 50
332+
333+
def test_imshow_animation_consistent_bounds(self) -> None:
334+
"""Test that animation frames have consistent color bounds."""
335+
da = xr.DataArray(
336+
np.array([[[0, 10], [20, 30]], [[40, 50], [60, 70]]]),
337+
dims=["time", "y", "x"],
338+
)
339+
fig = da.plotly.imshow(animation_frame="time")
340+
# All frames should use global min (0) and max (70)
341+
coloraxis = fig.layout.coloraxis
342+
assert coloraxis.cmin == 0.0
343+
assert coloraxis.cmax == 70.0

xarray_plotly/accessor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ def imshow(
250250
y: SlotValue = auto,
251251
facet_col: SlotValue = auto,
252252
animation_frame: SlotValue = auto,
253+
robust: bool = False,
253254
**px_kwargs: Any,
254255
) -> go.Figure:
255256
"""Create an interactive heatmap image.
@@ -261,7 +262,9 @@ def imshow(
261262
y: Dimension for y-axis (rows). Default: first dimension.
262263
facet_col: Dimension for subplot columns. Default: third dimension.
263264
animation_frame: Dimension for animation. Default: fourth dimension.
265+
robust: If True, use 2nd/98th percentiles for color bounds (handles outliers).
264266
**px_kwargs: Additional arguments passed to `plotly.express.imshow()`.
267+
Use `zmin` and `zmax` to manually set color scale bounds.
265268
266269
Returns:
267270
Interactive Plotly Figure.
@@ -272,6 +275,7 @@ def imshow(
272275
y=y,
273276
facet_col=facet_col,
274277
animation_frame=animation_frame,
278+
robust=robust,
275279
**px_kwargs,
276280
)
277281

xarray_plotly/plotting.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from typing import TYPE_CHECKING, Any
88

9+
import numpy as np
910
import plotly.express as px
1011

1112
from xarray_plotly.common import (
@@ -398,6 +399,7 @@ def imshow(
398399
y: SlotValue = auto,
399400
facet_col: SlotValue = auto,
400401
animation_frame: SlotValue = auto,
402+
robust: bool = False,
401403
**px_kwargs: Any,
402404
) -> go.Figure:
403405
"""
@@ -418,8 +420,12 @@ def imshow(
418420
Dimension for subplot columns. Default: third dimension.
419421
animation_frame
420422
Dimension for animation. Default: fourth dimension.
423+
robust
424+
If True, compute color bounds using 2nd and 98th percentiles
425+
for robustness against outliers. Default: False.
421426
**px_kwargs
422427
Additional arguments passed to `plotly.express.imshow()`.
428+
Use `zmin` and `zmax` to manually set color scale bounds.
423429
424430
Returns
425431
-------
@@ -440,6 +446,20 @@ def imshow(
440446
]
441447
plot_data = darray.transpose(*transpose_order) if transpose_order else darray
442448

449+
# Compute global color bounds if not provided
450+
if "zmin" not in px_kwargs or "zmax" not in px_kwargs:
451+
values = plot_data.values
452+
if robust:
453+
# Use percentiles for outlier robustness
454+
zmin = float(np.nanpercentile(values, 2))
455+
zmax = float(np.nanpercentile(values, 98))
456+
else:
457+
# Use global min/max across all data
458+
zmin = float(np.nanmin(values))
459+
zmax = float(np.nanmax(values))
460+
px_kwargs.setdefault("zmin", zmin)
461+
px_kwargs.setdefault("zmax", zmax)
462+
443463
return px.imshow(
444464
plot_data,
445465
facet_col=slots.get("facet_col"),

0 commit comments

Comments
 (0)