Skip to content

Commit 113aa11

Browse files
timtreisclaude
andcommitted
Fix axes mismatch when elements have multi-CS transformations (#176)
When elements have transformations to multiple coordinate systems, filter_by_coordinate_system cannot strip the extra transformations (upstream limitation). This caused show() to auto-detect too many CS and raise a mismatch error when the user passed their own axes. Apply a stricter CS filter when coordinate_systems is auto-detected and axes are provided: keep only CS that have element types for all render commands. Also improve the error message with a hint to pass coordinate_systems= explicitly when the mismatch cannot be resolved. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 303140c commit 113aa11

File tree

4 files changed

+94
-11
lines changed

4 files changed

+94
-11
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
_FontWeight,
5252
)
5353
from spatialdata_plot.pl.utils import (
54+
_RENDER_CMD_TO_CS_FLAG,
5455
_get_cs_contents,
5556
_get_elements_to_be_rendered,
5657
_get_valid_cs,
@@ -993,9 +994,11 @@ def show(
993994
ax_x_min, ax_x_max = ax.get_xlim()
994995
ax_y_max, ax_y_min = ax.get_ylim() # (0, 0) is top-left
995996

996-
coordinate_systems = sdata.coordinate_systems if coordinate_systems is None else coordinate_systems
997+
cs_was_auto = coordinate_systems is None
998+
coordinate_systems = list(sdata.coordinate_systems) if cs_was_auto else coordinate_systems
997999
if isinstance(coordinate_systems, str):
9981000
coordinate_systems = [coordinate_systems]
1001+
assert coordinate_systems is not None
9991002

10001003
for cs in coordinate_systems:
10011004
if cs not in sdata.coordinate_systems:
@@ -1019,14 +1022,32 @@ def show(
10191022
elements=elements_to_be_rendered,
10201023
)
10211024

1022-
# catch error in ruff-friendly way
1023-
if ax is not None: # we'll generate matching number then
1025+
# When CS was auto-detected and ax is provided, keep only CS that have
1026+
# element types for ALL render commands (workaround for upstream #176).
1027+
if ax is not None:
10241028
n_ax = 1 if isinstance(ax, Axes) else len(ax)
1029+
if cs_was_auto and len(coordinate_systems) > n_ax:
1030+
required_flags = [_RENDER_CMD_TO_CS_FLAG[cmd] for cmd in cmds if cmd in _RENDER_CMD_TO_CS_FLAG]
1031+
strict_cs = [
1032+
cs_name
1033+
for cs_name in coordinate_systems
1034+
if all(cs_contents.query(f"cs == '{cs_name}'").iloc[0][flag] for flag in required_flags)
1035+
]
1036+
if strict_cs:
1037+
coordinate_systems = strict_cs
1038+
10251039
if len(coordinate_systems) != n_ax:
1026-
raise ValueError(
1040+
msg = (
10271041
f"Mismatch between number of matplotlib axes objects ({n_ax}) "
10281042
f"and number of coordinate systems ({len(coordinate_systems)})."
10291043
)
1044+
if cs_was_auto:
1045+
msg += (
1046+
" This can happen when elements have transformations to multiple "
1047+
"coordinate systems (e.g. after filter_by_coordinate_system). "
1048+
"Pass `coordinate_systems=` explicitly to select which ones to plot."
1049+
)
1050+
raise ValueError(msg)
10301051

10311052
# set up canvas
10321053
fig_params, scalebar_params = _prepare_params_plot(

src/spatialdata_plot/pl/utils.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,13 @@
9595

9696
_GROUPS_IGNORED_WARNING = "Parameter 'groups' is ignored when 'color' is a literal color, not a column name."
9797

98+
_RENDER_CMD_TO_CS_FLAG: dict[str, str] = {
99+
"render_images": "has_images",
100+
"render_shapes": "has_shapes",
101+
"render_points": "has_points",
102+
"render_labels": "has_labels",
103+
}
104+
98105

99106
def _gate_palette_and_groups(
100107
element_params: dict[str, Any],
@@ -264,6 +271,7 @@ def _prepare_params_plot(
264271
if ax is not None and len(ax) != num_panels:
265272
raise ValueError(f"Len of `ax`: {len(ax)} is not equal to number of panels: {num_panels}.")
266273
if fig is None:
274+
# TODO(#579): infer fig from ax[0].get_figure() instead of requiring it
267275
raise ValueError(
268276
f"Invalid value of `fig`: {fig}. If a list of `Axes` is passed, a `Figure` must also be specified."
269277
)
@@ -2080,17 +2088,11 @@ def _get_elements_to_be_rendered(
20802088
List of names of the SpatialElements to be rendered in the plot.
20812089
"""
20822090
elements_to_be_rendered: list[str] = []
2083-
render_cmds_map = {
2084-
"render_images": "has_images",
2085-
"render_shapes": "has_shapes",
2086-
"render_points": "has_points",
2087-
"render_labels": "has_labels",
2088-
}
20892091

20902092
cs_query = cs_contents.query(f"cs == '{cs}'")
20912093

20922094
for cmd, params in render_cmds:
2093-
key = render_cmds_map.get(cmd)
2095+
key = _RENDER_CMD_TO_CS_FLAG.get(cmd)
20942096
if key and cs_query[key][0]:
20952097
elements_to_be_rendered += [params.element]
20962098

tests/conftest.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,31 @@ def _get_sdata_with_multiple_images(share_coordinate_system: str = "all"):
631631
return _get_sdata_with_multiple_images
632632

633633

634+
@pytest.fixture
635+
def sdata_multi_cs():
636+
"""SpatialData with an image in one CS and shapes in two CS.
637+
638+
Useful for testing behaviour when elements have transformations to
639+
different sets of coordinate systems (e.g. after
640+
``filter_by_coordinate_system``).
641+
"""
642+
from shapely.geometry import Point
643+
644+
image = Image2DModel.parse(
645+
np.zeros((1, 10, 10)),
646+
dims=("c", "y", "x"),
647+
transformations={"aligned": sd.transformations.Identity()},
648+
)
649+
shapes = ShapesModel.parse(
650+
GeoDataFrame(geometry=[Point(5, 5)], data={"radius": [2]}),
651+
transformations={
652+
"aligned": sd.transformations.Identity(),
653+
"global": sd.transformations.Identity(),
654+
},
655+
)
656+
return SpatialData(images={"img": image}, shapes={"shp": shapes})
657+
658+
634659
@pytest.fixture
635660
def sdata_hexagonal_grid_spots():
636661
"""Create a hexagonal grid of points for testing visium_hex functionality."""

tests/pl/test_render.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,38 @@ def test_keyerror_when_shape_element_does_not_exist(request):
6262

6363
with pytest.raises(KeyError):
6464
sdata.pl.render_shapes(element="not_found").pl.show()
65+
66+
67+
# Regression tests for #176: plotting with user-supplied ax when elements
68+
# have transformations to multiple coordinate systems.
69+
70+
71+
def test_single_ax_after_filter_by_coordinate_system(sdata_multi_cs):
72+
"""After filter_by_coordinate_system, single ax should work without specifying CS."""
73+
sdata_filt = sdata_multi_cs.filter_by_coordinate_system("aligned")
74+
75+
_, ax = plt.subplots(1, 1)
76+
sdata_filt.pl.render_images("img").pl.render_shapes("shp").pl.show(ax=ax)
77+
assert ax.get_title() == "aligned"
78+
79+
80+
def test_single_ax_with_explicit_cs(sdata_multi_cs):
81+
"""Explicit coordinate_systems with single ax should work."""
82+
_, ax = plt.subplots(1, 1)
83+
sdata_multi_cs.pl.render_images("img").pl.render_shapes("shp").pl.show(ax=ax, coordinate_systems="aligned")
84+
assert ax.get_title() == "aligned"
85+
86+
87+
def test_single_ax_explicit_multi_cs_raises(sdata_multi_cs):
88+
"""Explicitly requesting more CS than axes should still raise."""
89+
_, ax = plt.subplots(1, 1)
90+
with pytest.raises(ValueError, match="Mismatch"):
91+
sdata_multi_cs.pl.render_shapes("shp").pl.show(ax=ax, coordinate_systems=["aligned", "global"])
92+
93+
94+
def test_single_ax_auto_cs_unresolvable_raises(sdata_multi_cs):
95+
"""When strict filtering can't resolve the mismatch, error includes hint."""
96+
_, ax = plt.subplots(1, 1)
97+
with pytest.raises(ValueError, match="coordinate_systems="):
98+
# Only render shapes (present in both CS), so strict filter can't narrow down
99+
sdata_multi_cs.pl.render_shapes("shp").pl.show(ax=ax)

0 commit comments

Comments
 (0)