Skip to content

Commit f0f9e96

Browse files
authored
Remove spatial interlacement prototype from palette generation (#586)
1 parent c6e4644 commit f0f9e96

File tree

2 files changed

+24
-252
lines changed

2 files changed

+24
-252
lines changed

src/spatialdata_plot/pl/_palette.py

Lines changed: 18 additions & 181 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
- :func:`make_palette` — produce *n* colours, optionally reordered for
66
maximum perceptual contrast or colourblind accessibility.
77
- :func:`make_palette_from_data` — like :func:`make_palette` but derives
8-
the number of colours and (for ``spaco`` methods) the assignment order
9-
from a :class:`~spatialdata.SpatialData` element.
8+
the number of colours from a :class:`~spatialdata.SpatialData` element.
109
1110
Both share the same *palette* / *method* vocabulary. The *palette*
1211
parameter controls **which** colours are used (the source), while
@@ -22,13 +21,8 @@
2221
from matplotlib.colors import ListedColormap, to_hex, to_rgb
2322
from matplotlib.pyplot import colormaps as mpl_colormaps
2423
from scanpy.plotting.palettes import default_20, default_28, default_102
25-
from scipy.spatial import cKDTree
26-
27-
from spatialdata_plot._logging import logger
2824

2925
if TYPE_CHECKING:
30-
from collections.abc import Sequence
31-
3226
import spatialdata as sd
3327

3428
# ---------------------------------------------------------------------------
@@ -163,9 +157,6 @@ def _optimize_assignment(
163157
) -> np.ndarray:
164158
"""Find a permutation that maximizes ``sum(weights * color_dist[perm, perm])``.
165159
166-
Works for both spatial interlacement weights (spaco) and uniform
167-
weights (pure contrast maximization).
168-
169160
Returns an index array: ``perm[category_idx] = color_idx``.
170161
"""
171162
if rng is None:
@@ -233,56 +224,6 @@ def _optimized_order(
233224
return [to_hex(rgb[perm[i]]) for i in range(n)]
234225

235226

236-
# ---------------------------------------------------------------------------
237-
# Spatial interlacement (spaco-specific)
238-
# ---------------------------------------------------------------------------
239-
240-
241-
def _spatial_interlacement(
242-
coords: np.ndarray,
243-
labels: np.ndarray,
244-
categories: Sequence[str],
245-
n_neighbors: int = 15,
246-
) -> np.ndarray:
247-
"""Build a symmetric interlacement matrix (n_categories × n_categories).
248-
249-
Entry (i, j) reflects how much categories i and j are spatially
250-
interleaved, measured by inverse-distance-weighted neighbor counts.
251-
"""
252-
n_cat = len(categories)
253-
cat_to_idx = {c: i for i, c in enumerate(categories)}
254-
label_idx = np.array([cat_to_idx[l] for l in labels])
255-
256-
tree = cKDTree(coords)
257-
dists, indices = tree.query(coords, k=min(n_neighbors + 1, len(coords)))
258-
259-
# Vectorized accumulation (avoids Python double-loop over cells × neighbors)
260-
neighbor_dists = dists[:, 1:]
261-
neighbor_indices = indices[:, 1:]
262-
cell_cats = label_idx
263-
neighbor_cats = label_idx[neighbor_indices]
264-
265-
# Mask: different category and positive distance
266-
cross_cat = neighbor_cats != cell_cats[:, np.newaxis]
267-
valid_dist = neighbor_dists > 0
268-
mask = cross_cat & valid_dist
269-
270-
weights = np.where(mask, 1.0 / np.where(neighbor_dists > 0, neighbor_dists, 1.0), 0.0)
271-
272-
rows = np.broadcast_to(cell_cats[:, np.newaxis], neighbor_cats.shape)[mask]
273-
cols = neighbor_cats[mask]
274-
vals = weights[mask]
275-
276-
mat = np.zeros((n_cat, n_cat), dtype=np.float64)
277-
np.add.at(mat, (rows, cols), vals)
278-
279-
mat = np.maximum(mat, mat.T)
280-
max_val = mat.max()
281-
if max_val > 0:
282-
mat /= max_val
283-
return mat # type: ignore[no-any-return]
284-
285-
286227
# ---------------------------------------------------------------------------
287228
# Palette resolution
288229
# ---------------------------------------------------------------------------
@@ -339,35 +280,24 @@ def _resolve_element(
339280
element: str,
340281
color: str,
341282
table_name: str | None = None,
342-
) -> tuple[np.ndarray, pd.Categorical]:
343-
"""Extract coordinates and categorical labels from a SpatialData element.
283+
) -> pd.Categorical:
284+
"""Extract categorical labels from a SpatialData element.
344285
345-
Coordinates come from the element geometry (shapes) or x/y columns
346-
(points). Labels come from a column on the element itself, or from
347-
a linked table (joined on the instance key to guarantee alignment).
286+
Labels come from a column on the element itself, or from a linked
287+
table (joined on the instance key to guarantee alignment).
348288
"""
349289
if element in sdata.shapes:
350290
gdf = sdata.shapes[element]
351-
coords = np.column_stack([gdf.geometry.centroid.x, gdf.geometry.centroid.y])
352291
if color in gdf.columns:
353292
labels_series = gdf[color]
354293
else:
355-
labels_series, matched_indices = _get_labels_from_table(sdata, element, color, table_name)
356-
# Align coords to table rows via matched instance indices
357-
coords = coords[matched_indices]
294+
labels_series = _get_labels_from_table(sdata, element, color, table_name)
358295
elif element in sdata.points:
359296
ddf = sdata.points[element]
360-
if "x" not in ddf.columns or "y" not in ddf.columns:
361-
raise ValueError(f"Points element '{element}' does not have 'x' and 'y' columns.")
362297
if color in ddf.columns:
363-
df = ddf[["x", "y", color]].compute()
364-
coords = df[["x", "y"]].values.astype(np.float64)
365-
labels_series = df[color]
298+
labels_series = ddf[[color]].compute()[color]
366299
else:
367-
df = ddf[["x", "y"]].compute()
368-
coords = df[["x", "y"]].values.astype(np.float64)
369-
labels_series, matched_indices = _get_labels_from_table(sdata, element, color, table_name)
370-
coords = coords[matched_indices]
300+
labels_series = _get_labels_from_table(sdata, element, color, table_name)
371301
else:
372302
available = list(sdata.shapes.keys()) + list(sdata.points.keys())
373303
raise KeyError(
@@ -376,24 +306,16 @@ def _resolve_element(
376306
)
377307

378308
is_categorical = isinstance(getattr(labels_series, "dtype", None), pd.CategoricalDtype)
379-
labels_cat = labels_series.values if is_categorical else pd.Categorical(labels_series)
380-
return coords, labels_cat
309+
return labels_series.values if is_categorical else pd.Categorical(labels_series)
381310

382311

383312
def _get_labels_from_table(
384313
sdata: sd.SpatialData,
385314
element: str,
386315
color: str,
387316
table_name: str | None = None,
388-
) -> tuple[pd.Series, np.ndarray]:
389-
"""Extract a column from the table linked to an element.
390-
391-
Returns (labels_series, element_indices) where element_indices maps
392-
each table row to its position in the element, ensuring coord-label
393-
alignment.
394-
"""
395-
from spatialdata.models import get_table_keys
396-
317+
) -> pd.Series:
318+
"""Extract a column from the table linked to an element."""
397319
matches: list[str] = []
398320
for name in sdata.tables:
399321
table = sdata.tables[name]
@@ -423,29 +345,7 @@ def _get_labels_from_table(
423345
)
424346

425347
table = sdata.tables[resolved_name]
426-
_, _, instance_key = get_table_keys(table)
427-
428-
# Join on instance key to align table rows with element positions
429-
instance_ids = table.obs[instance_key].values
430-
element_index = sdata.shapes[element].index if element in sdata.shapes else sdata.points[element].compute().index
431-
432-
# Map each table instance_id to its position in the element index
433-
element_idx_map = {val: i for i, val in enumerate(element_index)}
434-
matched_indices = []
435-
valid_mask = []
436-
for iid in instance_ids:
437-
if iid in element_idx_map:
438-
matched_indices.append(element_idx_map[iid])
439-
valid_mask.append(True)
440-
else:
441-
valid_mask.append(False)
442-
443-
valid_mask_arr = np.array(valid_mask)
444-
if not any(valid_mask):
445-
raise ValueError(f"No matching instance keys between table '{resolved_name}' and element '{element}'.")
446-
447-
labels = table.obs.loc[valid_mask_arr, color]
448-
return labels.reset_index(drop=True), np.array(matched_indices)
348+
return table.obs[color].reset_index(drop=True)
449349

450350

451351
# ---------------------------------------------------------------------------
@@ -461,16 +361,7 @@ def _get_labels_from_table(
461361
"tritanopia": "tritanopia",
462362
}
463363

464-
# Maps spaco methods → CVD type (None = normal vision).
465-
_SPACO_CVD_TYPES: dict[str, str | None] = {
466-
"spaco": None,
467-
"spaco_colorblind": "general",
468-
"spaco_protanopia": "protanopia",
469-
"spaco_deuteranopia": "deuteranopia",
470-
"spaco_tritanopia": "tritanopia",
471-
}
472-
473-
_ALL_METHODS = sorted({"default", *_CONTRAST_CVD_TYPES, *_SPACO_CVD_TYPES})
364+
_ALL_METHODS = sorted({"default", *_CONTRAST_CVD_TYPES})
474365

475366

476367
# ---------------------------------------------------------------------------
@@ -484,11 +375,6 @@ def _get_labels_from_table(
484375
"protanopia",
485376
"deuteranopia",
486377
"tritanopia",
487-
"spaco",
488-
"spaco_colorblind",
489-
"spaco_protanopia",
490-
"spaco_deuteranopia",
491-
"spaco_tritanopia",
492378
]
493379

494380

@@ -528,9 +414,6 @@ def make_palette(
528414
under worst-case colour-vision deficiency.
529415
- ``"protanopia"`` / ``"deuteranopia"`` / ``"tritanopia"`` —
530416
reorder for a specific colour-vision deficiency.
531-
532-
The ``spaco*`` methods require spatial data and are only
533-
available via :func:`make_palette_from_data`.
534417
n_random
535418
Random permutations to try (optimisation methods only).
536419
n_swaps
@@ -553,9 +436,6 @@ def make_palette(
553436
if n < 1:
554437
raise ValueError(f"n must be at least 1, got {n}.")
555438

556-
if method in _SPACO_CVD_TYPES:
557-
raise ValueError(f"Method '{method}' requires spatial data. Use make_palette_from_data() instead.")
558-
559439
colors = _resolve_palette(palette, n)
560440

561441
if method == "default":
@@ -577,7 +457,6 @@ def make_palette_from_data(
577457
palette: list[str] | str | None = None,
578458
method: Method = "default",
579459
table_name: str | None = None,
580-
n_neighbors: int = 15,
581460
n_random: int = 5000,
582461
n_swaps: int = 10000,
583462
seed: int = 0,
@@ -605,25 +484,13 @@ def make_palette_from_data(
605484
Name of the table to use when *color* is looked up from a linked
606485
table. Required when multiple tables annotate the same element.
607486
method
608-
Strategy for assigning colours to categories. Accepts all
609-
methods from :func:`make_palette` plus spatially-aware ones:
487+
Strategy for assigning colours to categories:
610488
611489
- ``"default"`` — assign in sorted category order (reproduces
612490
the current render-pipeline behaviour).
613491
- ``"contrast"`` / ``"colorblind"`` / ``"protanopia"`` /
614492
``"deuteranopia"`` / ``"tritanopia"`` — reorder to maximise
615-
perceptual spread (ignores spatial layout).
616-
- ``"spaco"`` — spatially-aware assignment (Jing et al.,
617-
*Patterns* 2023). Maximises perceptual contrast between
618-
categories that are spatially interleaved.
619-
- ``"spaco_colorblind"`` — like ``"spaco"`` but optimises under
620-
worst-case colour-vision deficiency (all three types).
621-
- ``"spaco_protanopia"`` / ``"spaco_deuteranopia"`` /
622-
``"spaco_tritanopia"`` — like ``"spaco"`` but optimises for
623-
a specific colour-vision deficiency.
624-
n_neighbors
625-
Only used with ``spaco`` methods. Number of spatial neighbours
626-
for the interlacement computation.
493+
perceptual spread.
627494
n_random
628495
Random permutations to try (optimisation methods only).
629496
n_swaps
@@ -641,11 +508,11 @@ def make_palette_from_data(
641508
--------
642509
>>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type")
643510
>>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", palette="tab10")
644-
>>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="spaco")
645-
>>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="spaco_colorblind")
511+
>>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="contrast")
512+
>>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="colorblind")
646513
>>> sdata.pl.render_shapes("cells", color="cell_type", palette=palette).pl.show()
647514
"""
648-
coords, labels_cat = _resolve_element(sdata, element, color, table_name=table_name)
515+
labels_cat = _resolve_element(sdata, element, color, table_name=table_name)
649516

650517
categories = list(labels_cat.categories)
651518
n_cat = len(categories)
@@ -657,42 +524,12 @@ def make_palette_from_data(
657524
if method == "default":
658525
return {cat: to_hex(to_rgb(c)) for cat, c in zip(categories, colors_list, strict=True)}
659526

660-
# Non-spatial contrast methods (same as make_palette but returns dict)
661527
if method in _CONTRAST_CVD_TYPES:
662528
cvd_type = _CONTRAST_CVD_TYPES[method]
663529
reordered = _optimized_order(
664530
colors_list, colorblind_type=cvd_type, n_random=n_random, n_swaps=n_swaps, seed=seed
665531
)
666532
return dict(zip(categories, reordered, strict=True))
667533

668-
# Spaco methods (spatially-aware)
669-
if method in _SPACO_CVD_TYPES:
670-
cvd_type = _SPACO_CVD_TYPES[method]
671-
672-
# Filter NaN labels
673-
mask = labels_cat.codes != -1
674-
coords_clean = coords[mask]
675-
labels_clean = np.array(categories)[labels_cat.codes[mask]]
676-
677-
if len(coords_clean) == 0:
678-
raise ValueError(f"All values in column '{color}' are NaN.")
679-
680-
rgb = np.array([to_rgb(c) for c in colors_list])
681-
682-
if n_cat == 1:
683-
return {categories[0]: to_hex(rgb[0])}
684-
685-
logger.info(f"Computing spatial interlacement for {n_cat} categories ({len(coords_clean)} cells)...")
686-
inter = _spatial_interlacement(coords_clean, labels_clean, categories, n_neighbors=n_neighbors)
687-
688-
logger.info("Computing perceptual distance matrix...")
689-
cdist = _perceptual_distance_matrix(rgb, colorblind_type=cvd_type)
690-
691-
logger.info("Optimizing color assignment...")
692-
rng = np.random.default_rng(seed)
693-
perm = _optimize_assignment(inter, cdist, n_random=n_random, n_swaps=n_swaps, rng=rng)
694-
695-
return {cat: to_hex(rgb[perm[i]]) for i, cat in enumerate(categories)}
696-
697534
valid = ", ".join(f"'{m}'" for m in _ALL_METHODS)
698535
raise ValueError(f"Unknown method '{method}'. Choose from {valid}.")

0 commit comments

Comments
 (0)