Skip to content

Commit cf34bfe

Browse files
timtreisclaude
andcommitted
Fix axes mismatch when elements have multi-CS transformations (#176)
When elements have transformations to multiple coordinate systems, filter_by_coordinate_system cannot strip the extra transformations (upstream limitation). This caused show() to auto-detect too many CS and raise a mismatch error when the user passed their own axes. Apply a stricter CS filter when coordinate_systems is auto-detected and axes are provided: keep only CS that have element types for all render commands. Also improve the error message with a hint to pass coordinate_systems= explicitly when the mismatch cannot be resolved. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 303140c commit cf34bfe

File tree

3 files changed

+100
-2
lines changed

3 files changed

+100
-2
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -993,9 +993,11 @@ def show(
993993
ax_x_min, ax_x_max = ax.get_xlim()
994994
ax_y_max, ax_y_min = ax.get_ylim() # (0, 0) is top-left
995995

996-
coordinate_systems = sdata.coordinate_systems if coordinate_systems is None else coordinate_systems
996+
cs_was_auto = coordinate_systems is None
997+
coordinate_systems = list(sdata.coordinate_systems) if cs_was_auto else coordinate_systems
997998
if isinstance(coordinate_systems, str):
998999
coordinate_systems = [coordinate_systems]
1000+
assert coordinate_systems is not None # guaranteed by the branches above
9991001

10001002
for cs in coordinate_systems:
10011003
if cs not in sdata.coordinate_systems:
@@ -1019,14 +1021,43 @@ def show(
10191021
elements=elements_to_be_rendered,
10201022
)
10211023

1024+
# When coordinate_systems was auto-detected and the user provided axes,
1025+
# apply a stricter filter: keep only CS that have element types for ALL
1026+
# render commands. This handles the case where filter_by_coordinate_system
1027+
# leaves extra CS due to elements having multi-CS transformations (#176).
1028+
if ax is not None and cs_was_auto:
1029+
n_ax = 1 if isinstance(ax, Axes) else len(ax)
1030+
if len(coordinate_systems) > n_ax:
1031+
render_type_flags = {
1032+
"render_images": "has_images",
1033+
"render_labels": "has_labels",
1034+
"render_points": "has_points",
1035+
"render_shapes": "has_shapes",
1036+
}
1037+
required_flags = [render_type_flags[cmd] for cmd in cmds if cmd in render_type_flags]
1038+
strict_cs = []
1039+
for cs_name in coordinate_systems:
1040+
row = cs_contents.query(f"cs == '{cs_name}'").iloc[0]
1041+
if all(row[flag] for flag in required_flags):
1042+
strict_cs.append(cs_name)
1043+
if strict_cs:
1044+
coordinate_systems = strict_cs
1045+
10221046
# catch error in ruff-friendly way
10231047
if ax is not None: # we'll generate matching number then
10241048
n_ax = 1 if isinstance(ax, Axes) else len(ax)
10251049
if len(coordinate_systems) != n_ax:
1026-
raise ValueError(
1050+
msg = (
10271051
f"Mismatch between number of matplotlib axes objects ({n_ax}) "
10281052
f"and number of coordinate systems ({len(coordinate_systems)})."
10291053
)
1054+
if cs_was_auto:
1055+
msg += (
1056+
" This can happen when elements have transformations to multiple "
1057+
"coordinate systems (e.g. after filter_by_coordinate_system). "
1058+
"Pass `coordinate_systems=` explicitly to select which ones to plot."
1059+
)
1060+
raise ValueError(msg)
10301061

10311062
# set up canvas
10321063
fig_params, scalebar_params = _prepare_params_plot(

src/spatialdata_plot/pl/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ def _prepare_params_plot(
264264
if ax is not None and len(ax) != num_panels:
265265
raise ValueError(f"Len of `ax`: {len(ax)} is not equal to number of panels: {num_panels}.")
266266
if fig is None:
267+
# TODO(#579): infer fig from ax[0].get_figure() instead of requiring it
267268
raise ValueError(
268269
f"Invalid value of `fig`: {fig}. If a list of `Axes` is passed, a `Figure` must also be specified."
269270
)

tests/pl/test_render.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import matplotlib.pyplot as plt
2+
import numpy as np
23
import pytest
4+
import spatialdata as sd
5+
from spatialdata.transformations import Identity
36

47

58
def test_render_images_can_plot_one_cyx_image(request):
@@ -62,3 +65,66 @@ def test_keyerror_when_shape_element_does_not_exist(request):
6265

6366
with pytest.raises(KeyError):
6467
sdata.pl.render_shapes(element="not_found").pl.show()
68+
69+
70+
# Regression tests for #176: plotting with user-supplied ax when elements
71+
# have transformations to multiple coordinate systems.
72+
73+
74+
def _make_multi_cs_sdata():
75+
"""Create sdata where shapes has transformations to two CS but image has only one."""
76+
import geopandas as gpd
77+
from shapely import Point
78+
79+
image = sd.models.Image2DModel.parse(
80+
np.zeros((1, 10, 10)), dims=("c", "y", "x"), transformations={"aligned": Identity()}
81+
)
82+
shapes = sd.models.ShapesModel.parse(
83+
gpd.GeoDataFrame(geometry=[Point(5, 5)], data={"radius": [2]}),
84+
transformations={"aligned": Identity(), "global": Identity()},
85+
)
86+
return sd.SpatialData(images={"img": image}, shapes={"shp": shapes})
87+
88+
89+
def test_single_ax_after_filter_by_coordinate_system():
90+
"""After filter_by_coordinate_system, single ax should work without specifying CS."""
91+
sdata = _make_multi_cs_sdata()
92+
sdata_filt = sdata.filter_by_coordinate_system("aligned")
93+
94+
_, ax = plt.subplots(1, 1)
95+
sdata_filt.pl.render_images("img").pl.render_shapes("shp").pl.show(ax=ax)
96+
assert ax.get_title() == "aligned"
97+
98+
99+
def test_single_ax_with_explicit_cs():
100+
"""Explicit coordinate_systems with single ax should work."""
101+
sdata = _make_multi_cs_sdata()
102+
103+
_, ax = plt.subplots(1, 1)
104+
sdata.pl.render_images("img").pl.render_shapes("shp").pl.show(ax=ax, coordinate_systems="aligned")
105+
assert ax.get_title() == "aligned"
106+
107+
108+
def test_single_ax_explicit_multi_cs_raises():
109+
"""Explicitly requesting more CS than axes should still raise."""
110+
sdata = _make_multi_cs_sdata()
111+
112+
_, ax = plt.subplots(1, 1)
113+
with pytest.raises(ValueError, match="Mismatch"):
114+
sdata.pl.render_shapes("shp").pl.show(ax=ax, coordinate_systems=["aligned", "global"])
115+
116+
117+
def test_single_ax_auto_cs_unresolvable_raises():
118+
"""When strict filtering can't resolve the mismatch, error includes hint."""
119+
import geopandas as gpd
120+
from shapely import Point
121+
122+
shapes = sd.models.ShapesModel.parse(
123+
gpd.GeoDataFrame(geometry=[Point(5, 5)], data={"radius": [2]}),
124+
transformations={"aligned": Identity(), "global": Identity()},
125+
)
126+
sdata = sd.SpatialData(shapes={"shp": shapes})
127+
128+
_, ax = plt.subplots(1, 1)
129+
with pytest.raises(ValueError, match="coordinate_systems="):
130+
sdata.pl.render_shapes("shp").pl.show(ax=ax)

0 commit comments

Comments
 (0)