Skip to content

Commit 222404a

Browse files
timtreis and claude committed
Fix shapes datashader colorbar exceeding data range (#559)
The default datashader reduction for shapes was "sum", causing overlapping shapes to inflate the colorbar beyond the true data maximum. Changed the default to "max" which preserves the actual data range and closely matches the matplotlib rendering. Also: extract _default_reduction to prevent log/aggregation drift, add logger_no_warns test helper, short-circuit _want_decorations for diverse color vectors. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f2404e0 commit 222404a

File tree

6 files changed

+99
-40
lines changed

6 files changed

+99
-40
lines changed

src/spatialdata_plot/_logging.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,40 @@ def logger_warns(
113113
if not any(pattern.search(r.getMessage()) for r in records):
114114
msgs = [r.getMessage() for r in records]
115115
raise AssertionError(f"Did not find log matching {match!r} in records: {msgs!r}")
116+
117+
118+
@contextmanager
119+
def logger_no_warns(
120+
caplog: LogCaptureFixture,
121+
logger: logging.Logger,
122+
match: str | None = None,
123+
level: int = logging.WARNING,
124+
) -> Iterator[None]:
125+
"""Assert that no log record matching *match* is emitted.
126+
127+
Counterpart to :func:`logger_warns`.
128+
"""
129+
initial_record_count = len(caplog.records)
130+
131+
handler = caplog.handler
132+
logger.addHandler(handler)
133+
original_level = logger.level
134+
logger.setLevel(level)
135+
136+
with caplog.at_level(level, logger=logger.name):
137+
try:
138+
yield
139+
finally:
140+
logger.removeHandler(handler)
141+
logger.setLevel(original_level)
142+
143+
records = [r for r in caplog.records[initial_record_count:] if r.levelno >= level]
144+
145+
if match is not None:
146+
pattern = re.compile(match)
147+
matching = [r.getMessage() for r in records if pattern.search(r.getMessage())]
148+
if matching:
149+
raise AssertionError(f"Found unexpected log matching {match!r}: {matching!r}")
150+
elif records:
151+
msgs = [r.getMessage() for r in records]
152+
raise AssertionError(f"Expected no log records at level>={level}, but got: {msgs!r}")

src/spatialdata_plot/pl/basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,8 @@ def render_shapes(
272272
273273
datashader_reduction : Literal[
274274
"sum", "mean", "any", "count", "std", "var", "max", "min"
275-
], default: "sum"
276-
Reduction method for datashader when coloring by continuous values. Defaults to 'sum'.
275+
], default: "max"
276+
Reduction method for datashader when coloring by continuous values. Defaults to 'max'.
277277
278278
279279
Notes

src/spatialdata_plot/pl/render.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
_ds_aggregate,
3535
_ds_shade_categorical,
3636
_ds_shade_continuous,
37+
_DsReduction,
3738
_render_ds_image,
3839
_render_ds_outlines,
3940
)
@@ -82,11 +83,9 @@ def _want_decorations(color_vector: Any, na_color: Color) -> bool:
8283
cv = np.asarray(color_vector)
8384
if cv.size == 0:
8485
return False
85-
# Fast check: if any value differs from the first, there is variety → show decorations.
8686
first = cv.flat[0]
87-
if not (cv == first).all():
87+
if any(v != first for v in cv.flat[1:]):
8888
return True
89-
# All values are the same — suppress decorations when that value is the NA color.
9089
na_hex = na_color.get_hex()
9190
if isinstance(first, str) and first.startswith("#") and na_hex.startswith("#"):
9291
return _hex_no_alpha(first) != _hex_no_alpha(na_hex)
@@ -425,10 +424,14 @@ def _render_shapes(
425424
if method is None:
426425
method = "datashader" if len(shapes) > 10000 else "matplotlib"
427426

427+
_default_reduction: _DsReduction = "max"
428+
428429
if method != "matplotlib":
429-
# we only notify the user when we switched away from matplotlib
430+
_effective_reduction = (
431+
render_params.ds_reduction if render_params.ds_reduction is not None else _default_reduction
432+
)
430433
logger.info(
431-
f"Using '{method}' backend with '{render_params.ds_reduction}' as reduction"
434+
f"Using '{method}' backend with '{_effective_reduction}' as reduction"
432435
" method to speed up plotting. Depending on the reduction method, the value"
433436
" range of the plot might change. Set method to 'matplotlib' to disable"
434437
" this behaviour."
@@ -506,7 +509,7 @@ def _render_shapes(
506509
col_for_color,
507510
color_by_categorical,
508511
render_params.ds_reduction,
509-
"mean",
512+
_default_reduction,
510513
"shapes",
511514
)
512515

@@ -784,8 +787,7 @@ def _render_points(
784787
# from the registered points (see above) avoids duplicate-origin ambiguities.
785788
color_table_name = table_name
786789

787-
# When color was already loaded from a table (line 690), pass it directly
788-
# to avoid a redundant get_values() call inside _set_color_source_vec.
790+
# Reuse color data already loaded from the table to avoid a redundant get_values() call.
789791
_preloaded = points_pd_with_color[col_for_color] if added_color_from_table and col_for_color is not None else None
790792

791793
color_source_vector, color_vector, _ = _set_color_source_vec(
@@ -840,12 +842,16 @@ def _render_points(
840842
if method is None:
841843
method = "datashader" if n_points > 10000 else "matplotlib"
842844

845+
_default_reduction: _DsReduction = "sum"
846+
843847
if method == "datashader":
844-
# we only notify the user when we switched away from matplotlib
848+
_effective_reduction = (
849+
render_params.ds_reduction if render_params.ds_reduction is not None else _default_reduction
850+
)
845851
logger.info(
846-
f"Using '{method}' backend with '{render_params.ds_reduction}' as reduction"
852+
f"Using '{method}' backend with '{_effective_reduction}' as reduction"
847853
" method to speed up plotting. Depending on the reduction method, the value"
848-
" range of the plot might change. Set method to 'matplotlib' do disable"
854+
" range of the plot might change. Set method to 'matplotlib' to disable"
849855
" this behaviour."
850856
)
851857

@@ -906,7 +912,7 @@ def _render_points(
906912
col_for_color,
907913
color_by_categorical,
908914
render_params.ds_reduction,
909-
"sum",
915+
_default_reduction,
910916
"points",
911917
)
912918

src/spatialdata_plot/pl/utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2507,9 +2507,6 @@ def _ensure_table_and_layer_exist_in_sdata(
25072507
if ds_reduction and (ds_reduction not in valid_ds_reduction_methods):
25082508
raise ValueError(f"Parameter 'ds_reduction' must be one of the following: {valid_ds_reduction_methods}.")
25092509

2510-
if method == "datashader" and ds_reduction is None:
2511-
param_dict["ds_reduction"] = "sum"
2512-
25132510
return param_dict
25142511

25152512

tests/pl/test_render_points.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import logging
21
import math
32

43
import dask.dataframe
@@ -24,7 +23,7 @@
2423
from spatialdata.transformations._utils import _set_transformations
2524

2625
import spatialdata_plot # noqa: F401
27-
from spatialdata_plot._logging import logger, logger_warns
26+
from spatialdata_plot._logging import logger, logger_no_warns, logger_warns
2827
from spatialdata_plot.pl._datashader import (
2928
_build_datashader_color_key,
3029
_ds_aggregate,
@@ -832,13 +831,8 @@ def test_ds_reduction_ignored_for_categorical(caplog):
832831
def test_ds_reduction_no_warning_when_none(caplog):
833832
"""No spurious warning when ds_reduction is None (the default)."""
834833
cvs, df = _make_ds_canvas_and_df()
835-
with caplog.at_level(logging.WARNING, logger=logger.name):
836-
logger.addHandler(caplog.handler)
837-
try:
838-
_ds_aggregate(cvs, df.copy(), "cat", True, None, "sum", "points")
839-
finally:
840-
logger.removeHandler(caplog.handler)
841-
assert not any("ignored" in r.message.lower() for r in caplog.records)
834+
with logger_no_warns(caplog, logger, match="ignored"):
835+
_ds_aggregate(cvs, df.copy(), "cat", True, None, "sum", "points")
842836

843837

844838
@pytest.mark.parametrize("reduction", ["mean", "max", "min", "count", "std", "var"])
@@ -866,13 +860,8 @@ def test_warn_groups_ignored_continuous_emits(caplog):
866860

867861
def test_warn_groups_ignored_continuous_silent_for_categorical(caplog):
868862
"""No warning when color_source_vector is present (categorical)."""
869-
with caplog.at_level(logging.WARNING, logger=logger.name):
870-
logger.addHandler(caplog.handler)
871-
try:
872-
_warn_groups_ignored_continuous(["A"], pd.Categorical(["A", "B"]), "cat_col")
873-
finally:
874-
logger.removeHandler(caplog.handler)
875-
assert not any("ignored" in r.message for r in caplog.records)
863+
with logger_no_warns(caplog, logger, match="ignored"):
864+
_warn_groups_ignored_continuous(["A"], pd.Categorical(["A", "B"]), "cat_col")
876865

877866

878867
def test_color_key_warns_on_short_color_vector(caplog):
@@ -893,13 +882,8 @@ def test_color_key_warns_on_long_color_vector(caplog):
893882
def test_color_key_no_warning_when_lengths_match(caplog):
894883
"""No warning when lengths match."""
895884
cat = pd.Categorical(["A", "B", "C"])
896-
with caplog.at_level(logging.WARNING, logger=logger.name):
897-
logger.addHandler(caplog.handler)
898-
try:
899-
_build_datashader_color_key(cat, ["#ff0000", "#00ff00", "#0000ff"], "#cccccc")
900-
finally:
901-
logger.removeHandler(caplog.handler)
902-
assert not any("color_vector length" in r.message for r in caplog.records)
885+
with logger_no_warns(caplog, logger, match="color_vector length"):
886+
_build_datashader_color_key(cat, ["#ff0000", "#00ff00", "#0000ff"], "#cccccc")
903887

904888

905889
def test_color_key_unseen_category_gets_na_color(caplog):

tests/pl/test_render_shapes.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,3 +1142,38 @@ def test_datashader_alpha_not_applied_twice(sdata_blobs: SpatialData):
11421142
"on top of the alpha already in the RGBA channels — causing double transparency."
11431143
)
11441144
plt.close(fig)
1145+
1146+
1147+
def test_datashader_colorbar_range_matches_data(sdata_blobs: SpatialData):
1148+
"""Datashader colorbar range must not exceed the actual data range for shapes.
1149+
1150+
Regression test for https://github.com/scverse/spatialdata-plot/issues/559.
1151+
Before the fix, shapes defaulted to 'sum' aggregation, causing overlapping
1152+
shapes to inflate the colorbar beyond the true data maximum.
1153+
"""
1154+
n = len(sdata_blobs.shapes["blobs_circles"])
1155+
rng = np.random.default_rng(0)
1156+
values = rng.uniform(0, 100, size=n)
1157+
sdata_blobs.shapes["blobs_circles"]["continuous_val"] = values
1158+
data_max = float(values.max())
1159+
data_min = float(values.min())
1160+
1161+
fig, ax = plt.subplots()
1162+
sdata_blobs.pl.render_shapes("blobs_circles", color="continuous_val", method="datashader").pl.show(ax=ax)
1163+
1164+
# Find the colorbar axis — it's a child axes with a ScalarMappable
1165+
cbar_vmax = None
1166+
cbar_vmin = None
1167+
for child in fig.get_children():
1168+
if isinstance(child, matplotlib.axes.Axes) and child is not ax:
1169+
ylim = child.get_ylim()
1170+
if ylim != (0.0, 1.0): # colorbar axes have non-default limits
1171+
cbar_vmin, cbar_vmax = ylim
1172+
1173+
assert cbar_vmax is not None, "Could not find colorbar in figure"
1174+
assert cbar_vmax <= data_max * 1.01, (
1175+
f"Colorbar max ({cbar_vmax:.2f}) exceeds data max ({data_max:.2f}); "
1176+
"datashader aggregation is likely using 'sum' instead of 'max'"
1177+
)
1178+
assert cbar_vmin >= data_min * 0.99 - 0.01, f"Colorbar min ({cbar_vmin:.2f}) is below data min ({data_min:.2f})"
1179+
plt.close(fig)

0 commit comments

Comments (0)