Skip to content

Commit 5bb08ca

Browse files
authored
Fix render_shapes losing transformation after groups filtering (#564)
1 parent 6b2c484 commit 5bb08ca

File tree

2 files changed

+46
-4
lines changed

2 files changed

+46
-4
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,10 @@ def _render_shapes(
350350

351351
shapes = sdata_filt[element]
352352

353+
# Capture the transformation *before* any groups filtering that may strip
354+
# coordinate-system metadata from the element (see #420, #447).
355+
trans, trans_data = _prepare_transformation(sdata_filt.shapes[element], coordinate_system)
356+
353357
# get color vector (categorical or continuous)
354358
color_source_vector, color_vector, _ = _set_color_source_vec(
355359
sdata=sdata_filt,
@@ -425,9 +429,6 @@ def _render_shapes(
425429
# necessary in case different shapes elements are annotated with one table
426430
color_source_vector = color_source_vector.remove_unused_categories()
427431

428-
# Apply the transformation to the PatchCollection's paths
429-
trans, trans_data = _prepare_transformation(sdata_filt.shapes[element], coordinate_system)
430-
431432
shapes = gpd.GeoDataFrame(shapes, geometry="geometry")
432433
# convert shapes if necessary
433434
if render_params.shape is not None:

tests/pl/test_render_shapes.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from shapely.geometry import MultiPolygon, Point, Polygon
1414
from spatialdata import SpatialData, deepcopy
1515
from spatialdata.models import ShapesModel, TableModel
16-
from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Sequence, Translation
16+
from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Sequence, Translation, set_transformation
1717
from spatialdata.transformations._utils import _set_transformations
1818

1919
import spatialdata_plot # noqa: F401
@@ -1067,6 +1067,47 @@ def test_plot_can_handle_non_numeric_radius_values(sdata_blobs: SpatialData):
10671067
sdata_blobs.pl.render_shapes(element="blobs_circles", color="red").pl.show()
10681068

10691069

1070+
def test_groups_filtering_preserves_transformation(sdata_blobs: SpatialData):
1071+
"""Regression test for #420: groups filtering must not strip coordinate-system metadata.
1072+
1073+
Simulates the exact sequence that ``_render_shapes`` performs —
1074+
filter_by_coordinate_system -> groups boolean-index -> reset_index ->
1075+
re-assign to sdata_filt -> GeoDataFrame re-wrap — then asserts that
1076+
``_prepare_transformation`` can still retrieve the correct transformation.
1077+
"""
1078+
from spatialdata_plot.pl.utils import _prepare_transformation
1079+
1080+
scale_factor = 2.5
1081+
cs = "not_global"
1082+
set_transformation(
1083+
sdata_blobs["blobs_polygons"],
1084+
transformation={cs: Scale([scale_factor, scale_factor], axes=("x", "y"))},
1085+
set_all=True,
1086+
)
1087+
sdata_blobs.shapes["blobs_polygons"]["cluster"] = pd.Categorical(["c1", "c2", "c1", "c2", "c1"])
1088+
1089+
sdata_filt = sdata_blobs.filter_by_coordinate_system(coordinate_system=cs, filter_tables=False)
1090+
1091+
# Replicate groups filtering: boolean-index -> reset_index -> re-assign
1092+
shapes = sdata_filt.shapes["blobs_polygons"]
1093+
keep = shapes["cluster"] == "c1"
1094+
shapes = shapes[keep].reset_index(drop=True)
1095+
sdata_filt["blobs_polygons"] = shapes
1096+
# GeoDataFrame re-wrap strips .attrs (this is what _render_shapes does next)
1097+
shapes = gpd.GeoDataFrame(shapes, geometry="geometry")
1098+
1099+
# sdata_filt's element must still carry the correct transformation
1100+
trans, _ = _prepare_transformation(sdata_filt.shapes["blobs_polygons"], cs)
1101+
matrix = trans.get_matrix()
1102+
np.testing.assert_allclose(matrix[0, 0], scale_factor, err_msg="x-scale lost after groups filtering")
1103+
np.testing.assert_allclose(matrix[1, 1], scale_factor, err_msg="y-scale lost after groups filtering")
1104+
1105+
# The GeoDataFrame re-wrap strips attrs — reading the transform from
1106+
# the re-wrapped object must fail, proving why early capture matters.
1107+
with pytest.raises(AssertionError):
1108+
_prepare_transformation(shapes, cs)
1109+
1110+
10701111
def test_plot_can_handle_mixed_numeric_and_color_data(sdata_blobs: SpatialData):
10711112
"""Test that mixed numeric and color-like data raises a clear error."""
10721113
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs)

0 commit comments

Comments
 (0)