diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 03941dd8..f7cc0e25 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -755,7 +755,7 @@ def _set_color_source_vec( color_source_vector = pd.Categorical(color_source_vector) # convert, e.g., `pd.Series` color_mapping = _get_categorical_color_mapping( - adata=sdata.table, + adata=sdata.tables[table_name] if table_name is not None else sdata.table, cluster_key=value_to_plot, color_source_vector=color_source_vector, groups=groups, diff --git a/tests/_images/Labels_can_annotate_labels_with_table_layer.png b/tests/_images/Labels_can_annotate_labels_with_table_layer.png index ce2c179a..4b1f4d75 100644 Binary files a/tests/_images/Labels_can_annotate_labels_with_table_layer.png and b/tests/_images/Labels_can_annotate_labels_with_table_layer.png differ diff --git a/tests/_images/Labels_label_categorical_color_and_colors_in_uns.png b/tests/_images/Labels_label_categorical_color_and_colors_in_uns.png new file mode 100644 index 00000000..eff4ca5d Binary files /dev/null and b/tests/_images/Labels_label_categorical_color_and_colors_in_uns.png differ diff --git a/tests/_images/Labels_label_categorical_color_and_colors_in_uns_query_uns_colors_removed.png b/tests/_images/Labels_label_categorical_color_and_colors_in_uns_query_uns_colors_removed.png new file mode 100644 index 00000000..c8cfba5c Binary files /dev/null and b/tests/_images/Labels_label_categorical_color_and_colors_in_uns_query_uns_colors_removed.png differ diff --git a/tests/_images/Labels_label_categorical_color_and_colors_in_uns_query_workaround.png b/tests/_images/Labels_label_categorical_color_and_colors_in_uns_query_workaround.png new file mode 100644 index 00000000..be9003ec Binary files /dev/null and b/tests/_images/Labels_label_categorical_color_and_colors_in_uns_query_workaround.png differ diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py index d7697bd7..abaf06df 100644 --- a/tests/pl/test_render_labels.py +++ b/tests/pl/test_render_labels.py @@ -7,13 +7,12 @@ import scanpy as sc from anndata import AnnData from spatial_image import to_spatial_image -from spatialdata import SpatialData, deepcopy, get_element_instances +from spatialdata import SpatialData, bounding_box_query, deepcopy, get_element_instances from spatialdata.models import TableModel import spatialdata_plot # noqa: F401 from tests.conftest import DPI, PlotTester, PlotTesterMeta -RNG = np.random.default_rng(seed=42) sc.pl.set_rcParams_defaults() sc.set_figure_params(dpi=DPI, color_map="viridis") matplotlib.use("agg") # same as GitHub action runner @@ -214,7 +213,63 @@ def test_plot_label_categorical_color(self, sdata_blobs: SpatialData): self._make_tablemodel_with_categorical_labels(sdata_blobs, labels_name="blobs_labels") sdata_blobs.pl.render_labels("blobs_labels", color="category").pl.show() + def test_plot_label_categorical_color_and_colors_in_uns(self, sdata_blobs: SpatialData): + self._make_tablemodel_with_categorical_labels(sdata_blobs, labels_name="blobs_labels") + # purple, green, yellow + sdata_blobs["other_table"].uns["category_colors"] = ["#800080", "#008000", "#FFFF00"] + # placeholder, otherwise "category_colors" will be ignored + sdata_blobs["other_table"].uns["category"] = "__value__" + sdata_blobs.pl.render_labels("blobs_labels", color="category").pl.show() + + def test_plot_label_categorical_color_and_colors_in_uns_query_uns_colors_removed(self, sdata_blobs: SpatialData): + self._make_tablemodel_with_categorical_labels(sdata_blobs, labels_name="blobs_labels") + # purple, green, yellow + sdata_blobs["other_table"].uns["category_colors"] = ["#800080", "#008000", "#FFFF00"] + # placeholder, otherwise "category_colors" will be ignored + sdata_blobs["other_table"].uns["category"] = "__value__" + sdata_blobs = bounding_box_query( + sdata_blobs, + axes=("y", "x"), + min_coordinate=[0, 0], + max_coordinate=[100, 100], + target_coordinate_system="global", + ) + # we would expect colors purple and yellow for a and c, but we see default colors blue and orange, + # Reason: "category_colors" is removed by `.filter_by_coordinate_system` in + # `spatialdata_plot.pl.render._render_labels`. + # Why? Because `.bounding_box_query` removes "category_colors" that are not in the query, + # but restores original number of catergories in `.obs["category"]`, see https://github.com/scverse/anndata/issues/997, + # leading to mismatch and removal of "category_colors" by `.filter_by_coordinate_system`. + assert all(sdata_blobs["other_table"].obs["category"].unique() == ["a", "c"]) + assert all(sdata_blobs["other_table"].uns["category_colors"] == ["#800080", "#FFFF00"]) + # but due to https://github.com/scverse/anndata/issues/997: + assert all(sdata_blobs["other_table"].obs["category"].cat.categories == ["a", "b", "c"]) + sdata_blobs.pl.render_labels("blobs_labels", color="category").pl.show() + + def test_plot_label_categorical_color_and_colors_in_uns_query_workaround(self, sdata_blobs: SpatialData): + self._make_tablemodel_with_categorical_labels(sdata_blobs, labels_name="blobs_labels") + # purple, green, yellow + sdata_blobs["other_table"].uns["category_colors"] = ["#800080", "#008000", "#FFFF00"] + # placeholder, otherwise "category_colors" will be ignored + sdata_blobs["other_table"].uns["category"] = "__value__" + sdata_blobs = bounding_box_query( + sdata_blobs, + axes=("y", "x"), + min_coordinate=[0, 0], + max_coordinate=[100, 100], + target_coordinate_system="global", + ) + assert all(sdata_blobs["other_table"].obs["category"].unique() == ["a", "c"]) + assert all(sdata_blobs["other_table"].uns["category_colors"] == ["#800080", "#FFFF00"]) + # but due to https://github.com/scverse/anndata/issues/997: + assert all(sdata_blobs["other_table"].obs["category"].cat.categories == ["a", "b", "c"]) + sdata_blobs["other_table"].obs["category"] = ( + sdata_blobs["other_table"].obs["category"].cat.remove_unused_categories() + ) + sdata_blobs.pl.render_labels("blobs_labels", color="category").pl.show() + def _make_tablemodel_with_categorical_labels(self, sdata_blobs, labels_name: str): + RNG = np.random.default_rng(seed=42) instances = get_element_instances(sdata_blobs[labels_name]) n_obs = len(instances) adata = AnnData( @@ -235,5 +290,6 @@ def _make_tablemodel_with_categorical_labels(self, sdata_blobs, labels_name: str sdata_blobs["other_table"].obs["category"] = sdata_blobs["other_table"].obs["category"].astype("category") def test_plot_can_annotate_labels_with_table_layer(self, sdata_blobs: SpatialData): + RNG = np.random.default_rng(seed=42) sdata_blobs["table"].layers["normalized"] = RNG.random(sdata_blobs["table"].X.shape) sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", table_layer="normalized").pl.show()