55- :func:`make_palette` — produce *n* colours, optionally reordered for
66 maximum perceptual contrast or colourblind accessibility.
77- :func:`make_palette_from_data` — like :func:`make_palette` but derives
8- the number of colours and (for ``spaco`` methods) the assignment order
9- from a :class:`~spatialdata.SpatialData` element.
8+ the number of colours from a :class:`~spatialdata.SpatialData` element.
109
1110Both share the same *palette* / *method* vocabulary. The *palette*
1211parameter controls **which** colours are used (the source), while
2221from matplotlib .colors import ListedColormap , to_hex , to_rgb
2322from matplotlib .pyplot import colormaps as mpl_colormaps
2423from scanpy .plotting .palettes import default_20 , default_28 , default_102
25- from scipy .spatial import cKDTree
26-
27- from spatialdata_plot ._logging import logger
2824
2925if TYPE_CHECKING :
30- from collections .abc import Sequence
31-
3226 import spatialdata as sd
3327
3428# ---------------------------------------------------------------------------
@@ -163,9 +157,6 @@ def _optimize_assignment(
163157) -> np .ndarray :
164158 """Find a permutation that maximizes ``sum(weights * color_dist[perm, perm])``.
165159
166- Works for both spatial interlacement weights (spaco) and uniform
167- weights (pure contrast maximization).
168-
169160 Returns an index array: ``perm[category_idx] = color_idx``.
170161 """
171162 if rng is None :
@@ -233,56 +224,6 @@ def _optimized_order(
233224 return [to_hex (rgb [perm [i ]]) for i in range (n )]
234225
235226
236- # ---------------------------------------------------------------------------
237- # Spatial interlacement (spaco-specific)
238- # ---------------------------------------------------------------------------
239-
240-
241- def _spatial_interlacement (
242- coords : np .ndarray ,
243- labels : np .ndarray ,
244- categories : Sequence [str ],
245- n_neighbors : int = 15 ,
246- ) -> np .ndarray :
247- """Build a symmetric interlacement matrix (n_categories × n_categories).
248-
249- Entry (i, j) reflects how much categories i and j are spatially
250- interleaved, measured by inverse-distance-weighted neighbor counts.
251- """
252- n_cat = len (categories )
253- cat_to_idx = {c : i for i , c in enumerate (categories )}
254- label_idx = np .array ([cat_to_idx [l ] for l in labels ])
255-
256- tree = cKDTree (coords )
257- dists , indices = tree .query (coords , k = min (n_neighbors + 1 , len (coords )))
258-
259- # Vectorized accumulation (avoids Python double-loop over cells × neighbors)
260- neighbor_dists = dists [:, 1 :]
261- neighbor_indices = indices [:, 1 :]
262- cell_cats = label_idx
263- neighbor_cats = label_idx [neighbor_indices ]
264-
265- # Mask: different category and positive distance
266- cross_cat = neighbor_cats != cell_cats [:, np .newaxis ]
267- valid_dist = neighbor_dists > 0
268- mask = cross_cat & valid_dist
269-
270- weights = np .where (mask , 1.0 / np .where (neighbor_dists > 0 , neighbor_dists , 1.0 ), 0.0 )
271-
272- rows = np .broadcast_to (cell_cats [:, np .newaxis ], neighbor_cats .shape )[mask ]
273- cols = neighbor_cats [mask ]
274- vals = weights [mask ]
275-
276- mat = np .zeros ((n_cat , n_cat ), dtype = np .float64 )
277- np .add .at (mat , (rows , cols ), vals )
278-
279- mat = np .maximum (mat , mat .T )
280- max_val = mat .max ()
281- if max_val > 0 :
282- mat /= max_val
283- return mat # type: ignore[no-any-return]
284-
285-
286227# ---------------------------------------------------------------------------
287228# Palette resolution
288229# ---------------------------------------------------------------------------
@@ -339,35 +280,24 @@ def _resolve_element(
339280 element : str ,
340281 color : str ,
341282 table_name : str | None = None ,
342- ) -> tuple [ np . ndarray , pd .Categorical ] :
343- """Extract coordinates and categorical labels from a SpatialData element.
283+ ) -> pd .Categorical :
284+ """Extract categorical labels from a SpatialData element.
344285
345- Coordinates come from the element geometry (shapes) or x/y columns
346- (points). Labels come from a column on the element itself, or from
347- a linked table (joined on the instance key to guarantee alignment).
286+ Labels come from a column on the element itself, or from a linked
287+ table (joined on the instance key to guarantee alignment).
348288 """
349289 if element in sdata .shapes :
350290 gdf = sdata .shapes [element ]
351- coords = np .column_stack ([gdf .geometry .centroid .x , gdf .geometry .centroid .y ])
352291 if color in gdf .columns :
353292 labels_series = gdf [color ]
354293 else :
355- labels_series , matched_indices = _get_labels_from_table (sdata , element , color , table_name )
356- # Align coords to table rows via matched instance indices
357- coords = coords [matched_indices ]
294+ labels_series = _get_labels_from_table (sdata , element , color , table_name )
358295 elif element in sdata .points :
359296 ddf = sdata .points [element ]
360- if "x" not in ddf .columns or "y" not in ddf .columns :
361- raise ValueError (f"Points element '{ element } ' does not have 'x' and 'y' columns." )
362297 if color in ddf .columns :
363- df = ddf [["x" , "y" , color ]].compute ()
364- coords = df [["x" , "y" ]].values .astype (np .float64 )
365- labels_series = df [color ]
298+ labels_series = ddf [[color ]].compute ()[color ]
366299 else :
367- df = ddf [["x" , "y" ]].compute ()
368- coords = df [["x" , "y" ]].values .astype (np .float64 )
369- labels_series , matched_indices = _get_labels_from_table (sdata , element , color , table_name )
370- coords = coords [matched_indices ]
300+ labels_series = _get_labels_from_table (sdata , element , color , table_name )
371301 else :
372302 available = list (sdata .shapes .keys ()) + list (sdata .points .keys ())
373303 raise KeyError (
@@ -376,24 +306,16 @@ def _resolve_element(
376306 )
377307
378308 is_categorical = isinstance (getattr (labels_series , "dtype" , None ), pd .CategoricalDtype )
379- labels_cat = labels_series .values if is_categorical else pd .Categorical (labels_series )
380- return coords , labels_cat
309+ return labels_series .values if is_categorical else pd .Categorical (labels_series )
381310
382311
383312def _get_labels_from_table (
384313 sdata : sd .SpatialData ,
385314 element : str ,
386315 color : str ,
387316 table_name : str | None = None ,
388- ) -> tuple [pd .Series , np .ndarray ]:
389- """Extract a column from the table linked to an element.
390-
391- Returns (labels_series, element_indices) where element_indices maps
392- each table row to its position in the element, ensuring coord-label
393- alignment.
394- """
395- from spatialdata .models import get_table_keys
396-
317+ ) -> pd .Series :
318+ """Extract a column from the table linked to an element."""
397319 matches : list [str ] = []
398320 for name in sdata .tables :
399321 table = sdata .tables [name ]
@@ -423,29 +345,7 @@ def _get_labels_from_table(
423345 )
424346
425347 table = sdata .tables [resolved_name ]
426- _ , _ , instance_key = get_table_keys (table )
427-
428- # Join on instance key to align table rows with element positions
429- instance_ids = table .obs [instance_key ].values
430- element_index = sdata .shapes [element ].index if element in sdata .shapes else sdata .points [element ].compute ().index
431-
432- # Map each table instance_id to its position in the element index
433- element_idx_map = {val : i for i , val in enumerate (element_index )}
434- matched_indices = []
435- valid_mask = []
436- for iid in instance_ids :
437- if iid in element_idx_map :
438- matched_indices .append (element_idx_map [iid ])
439- valid_mask .append (True )
440- else :
441- valid_mask .append (False )
442-
443- valid_mask_arr = np .array (valid_mask )
444- if not any (valid_mask ):
445- raise ValueError (f"No matching instance keys between table '{ resolved_name } ' and element '{ element } '." )
446-
447- labels = table .obs .loc [valid_mask_arr , color ]
448- return labels .reset_index (drop = True ), np .array (matched_indices )
348+ return table .obs [color ].reset_index (drop = True )
449349
450350
451351# ---------------------------------------------------------------------------
@@ -461,16 +361,7 @@ def _get_labels_from_table(
461361 "tritanopia" : "tritanopia" ,
462362}
463363
464- # Maps spaco methods → CVD type (None = normal vision).
465- _SPACO_CVD_TYPES : dict [str , str | None ] = {
466- "spaco" : None ,
467- "spaco_colorblind" : "general" ,
468- "spaco_protanopia" : "protanopia" ,
469- "spaco_deuteranopia" : "deuteranopia" ,
470- "spaco_tritanopia" : "tritanopia" ,
471- }
472-
473- _ALL_METHODS = sorted ({"default" , * _CONTRAST_CVD_TYPES , * _SPACO_CVD_TYPES })
364+ _ALL_METHODS = sorted ({"default" , * _CONTRAST_CVD_TYPES })
474365
475366
476367# ---------------------------------------------------------------------------
@@ -484,11 +375,6 @@ def _get_labels_from_table(
484375 "protanopia" ,
485376 "deuteranopia" ,
486377 "tritanopia" ,
487- "spaco" ,
488- "spaco_colorblind" ,
489- "spaco_protanopia" ,
490- "spaco_deuteranopia" ,
491- "spaco_tritanopia" ,
492378]
493379
494380
@@ -528,9 +414,6 @@ def make_palette(
528414 under worst-case colour-vision deficiency.
529415 - ``"protanopia"`` / ``"deuteranopia"`` / ``"tritanopia"`` —
530416 reorder for a specific colour-vision deficiency.
531-
532- The ``spaco*`` methods require spatial data and are only
533- available via :func:`make_palette_from_data`.
534417 n_random
535418 Random permutations to try (optimisation methods only).
536419 n_swaps
@@ -553,9 +436,6 @@ def make_palette(
553436 if n < 1 :
554437 raise ValueError (f"n must be at least 1, got { n } ." )
555438
556- if method in _SPACO_CVD_TYPES :
557- raise ValueError (f"Method '{ method } ' requires spatial data. Use make_palette_from_data() instead." )
558-
559439 colors = _resolve_palette (palette , n )
560440
561441 if method == "default" :
@@ -577,7 +457,6 @@ def make_palette_from_data(
577457 palette : list [str ] | str | None = None ,
578458 method : Method = "default" ,
579459 table_name : str | None = None ,
580- n_neighbors : int = 15 ,
581460 n_random : int = 5000 ,
582461 n_swaps : int = 10000 ,
583462 seed : int = 0 ,
@@ -605,25 +484,13 @@ def make_palette_from_data(
605484 Name of the table to use when *color* is looked up from a linked
606485 table. Required when multiple tables annotate the same element.
607486 method
608- Strategy for assigning colours to categories. Accepts all
609- methods from :func:`make_palette` plus spatially-aware ones:
487+ Strategy for assigning colours to categories:
610488
611489 - ``"default"`` — assign in sorted category order (reproduces
612490 the current render-pipeline behaviour).
613491 - ``"contrast"`` / ``"colorblind"`` / ``"protanopia"`` /
614492 ``"deuteranopia"`` / ``"tritanopia"`` — reorder to maximise
615- perceptual spread (ignores spatial layout).
616- - ``"spaco"`` — spatially-aware assignment (Jing et al.,
617- *Patterns* 2023). Maximises perceptual contrast between
618- categories that are spatially interleaved.
619- - ``"spaco_colorblind"`` — like ``"spaco"`` but optimises under
620- worst-case colour-vision deficiency (all three types).
621- - ``"spaco_protanopia"`` / ``"spaco_deuteranopia"`` /
622- ``"spaco_tritanopia"`` — like ``"spaco"`` but optimises for
623- a specific colour-vision deficiency.
624- n_neighbors
625- Only used with ``spaco`` methods. Number of spatial neighbours
626- for the interlacement computation.
493+ perceptual spread.
627494 n_random
628495 Random permutations to try (optimisation methods only).
629496 n_swaps
@@ -641,11 +508,11 @@ def make_palette_from_data(
641508 --------
642509 >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type")
643510 >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", palette="tab10")
644- >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="spaco ")
645- >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="spaco_colorblind ")
511+ >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="contrast ")
512+ >>> palette = sdp.pl.make_palette_from_data(sdata, "cells", "cell_type", method="colorblind ")
646513 >>> sdata.pl.render_shapes("cells", color="cell_type", palette=palette).pl.show()
647514 """
648- coords , labels_cat = _resolve_element (sdata , element , color , table_name = table_name )
515+ labels_cat = _resolve_element (sdata , element , color , table_name = table_name )
649516
650517 categories = list (labels_cat .categories )
651518 n_cat = len (categories )
@@ -657,42 +524,12 @@ def make_palette_from_data(
657524 if method == "default" :
658525 return {cat : to_hex (to_rgb (c )) for cat , c in zip (categories , colors_list , strict = True )}
659526
660- # Non-spatial contrast methods (same as make_palette but returns dict)
661527 if method in _CONTRAST_CVD_TYPES :
662528 cvd_type = _CONTRAST_CVD_TYPES [method ]
663529 reordered = _optimized_order (
664530 colors_list , colorblind_type = cvd_type , n_random = n_random , n_swaps = n_swaps , seed = seed
665531 )
666532 return dict (zip (categories , reordered , strict = True ))
667533
668- # Spaco methods (spatially-aware)
669- if method in _SPACO_CVD_TYPES :
670- cvd_type = _SPACO_CVD_TYPES [method ]
671-
672- # Filter NaN labels
673- mask = labels_cat .codes != - 1
674- coords_clean = coords [mask ]
675- labels_clean = np .array (categories )[labels_cat .codes [mask ]]
676-
677- if len (coords_clean ) == 0 :
678- raise ValueError (f"All values in column '{ color } ' are NaN." )
679-
680- rgb = np .array ([to_rgb (c ) for c in colors_list ])
681-
682- if n_cat == 1 :
683- return {categories [0 ]: to_hex (rgb [0 ])}
684-
685- logger .info (f"Computing spatial interlacement for { n_cat } categories ({ len (coords_clean )} cells)..." )
686- inter = _spatial_interlacement (coords_clean , labels_clean , categories , n_neighbors = n_neighbors )
687-
688- logger .info ("Computing perceptual distance matrix..." )
689- cdist = _perceptual_distance_matrix (rgb , colorblind_type = cvd_type )
690-
691- logger .info ("Optimizing color assignment..." )
692- rng = np .random .default_rng (seed )
693- perm = _optimize_assignment (inter , cdist , n_random = n_random , n_swaps = n_swaps , rng = rng )
694-
695- return {cat : to_hex (rgb [perm [i ]]) for i , cat in enumerate (categories )}
696-
697534 valid = ", " .join (f"'{ m } '" for m in _ALL_METHODS )
698535 raise ValueError (f"Unknown method '{ method } '. Choose from { valid } ." )
0 commit comments