@@ -760,7 +760,6 @@ def _set_color_source_vec(
760760 )[value_to_plot ]
761761
762762 # numerical case, return early
763- # TODO temporary split until refactor is complete
764763 if color_source_vector is not None and not isinstance (color_source_vector .dtype , pd .CategoricalDtype ):
765764 if (
766765 not isinstance (element , GeoDataFrame )
@@ -777,18 +776,50 @@ def _set_color_source_vec(
777776
778777 color_source_vector = pd .Categorical (color_source_vector ) # convert, e.g., `pd.Series`
779778
780- # TODO check why table_name is not passed here.
781- color_mapping = _get_categorical_color_mapping (
782- adata = sdata ["table" ],
783- cluster_key = value_to_plot ,
784- color_source_vector = color_source_vector ,
785- cmap_params = cmap_params ,
786- alpha = alpha ,
787- groups = groups ,
788- palette = palette ,
789- na_color = na_color ,
790- render_type = render_type ,
791- )
779+ # Use the provided table_name parameter, fall back to only one present
780+ if table_name is not None :
781+ table_to_use = table_name
782+ else :
783+ table_to_use = list (sdata .tables .keys ())[0 ]
784+ logger .warning (f"No table name provided, using '{ table_to_use } ' as fallback for color mapping." )
785+
786+ # Check if custom colors exist in the table's .uns slot
787+ if _has_colors_in_uns (sdata , table_name , value_to_plot ):
788+ # Extract colors directly from the table's .uns slot
789+ color_mapping = _extract_colors_from_table_uns (
790+ sdata = sdata ,
791+ table_name = table_name ,
792+ col_to_colorby = value_to_plot ,
793+ color_source_vector = color_source_vector ,
794+ na_color = na_color ,
795+ )
796+ if color_mapping is None :
797+ logger .warning (f"Failed to extract colors for '{ value_to_plot } ', falling back to default mapping." )
798+ # Fall back to the existing method if extraction fails
799+ color_mapping = _get_categorical_color_mapping (
800+ adata = sdata [table_to_use ],
801+ cluster_key = value_to_plot ,
802+ color_source_vector = color_source_vector ,
803+ cmap_params = cmap_params ,
804+ alpha = alpha ,
805+ groups = groups ,
806+ palette = palette ,
807+ na_color = na_color ,
808+ render_type = render_type ,
809+ )
810+ else :
811+ # Use the existing color mapping method
812+ color_mapping = _get_categorical_color_mapping (
813+ adata = sdata [table_to_use ],
814+ cluster_key = value_to_plot ,
815+ color_source_vector = color_source_vector ,
816+ cmap_params = cmap_params ,
817+ alpha = alpha ,
818+ groups = groups ,
819+ palette = palette ,
820+ na_color = na_color ,
821+ render_type = render_type ,
822+ )
792823
793824 color_source_vector = color_source_vector .set_categories (color_mapping .keys ())
794825 if color_mapping is None :
@@ -897,6 +928,140 @@ def _generate_base_categorial_color_mapping(
897928 return _get_default_categorial_color_mapping (color_source_vector = color_source_vector , cmap_params = cmap_params )
898929
899930
931+ def _has_colors_in_uns (
932+ sdata : sd .SpatialData ,
933+ table_name : str | None ,
934+ col_to_colorby : str ,
935+ ) -> bool :
936+ """
937+ Check if <column_name>_colors exists in the specified table's .uns slot.
938+
939+ Parameters
940+ ----------
941+ sdata
942+ SpatialData object containing tables
943+ table_name
944+ Name of the table to check. If None, uses the first available table.
945+ col_to_colorby
946+ Name of the categorical column (e.g., "celltype")
947+
948+ Returns
949+ -------
950+ True if <col_to_colorby>_colors exists in the table's .uns, False otherwise
951+ """
952+ # Determine which table to use
953+ if table_name is not None :
954+ if table_name not in sdata .tables :
955+ return False
956+ table_to_use = table_name
957+ else :
958+ if len (sdata .tables ) == 0 :
959+ return False
960+ table_to_use = list (sdata .tables .keys ())[0 ]
961+
962+ adata = sdata .tables [table_to_use ]
963+ color_key = f"{ col_to_colorby } _colors"
964+
965+ return color_key in adata .uns
966+
967+
968+ def _extract_colors_from_table_uns (
969+ sdata : sd .SpatialData ,
970+ table_name : str | None ,
971+ col_to_colorby : str ,
972+ color_source_vector : ArrayLike | pd .Series [CategoricalDtype ],
973+ na_color : ColorLike ,
974+ ) -> Mapping [str , str ] | None :
975+ """
976+ Extract categorical colors from the <column_name>_colors pattern in adata.uns.
977+
978+ This function looks for colors stored in the format <col_to_colorby>_colors in the
979+ specified table's .uns slot and creates a mapping from categories to colors.
980+
981+ Parameters
982+ ----------
983+ sdata
984+ SpatialData object containing tables
985+ table_name
986+ Name of the table to look in. If None, uses the first available table.
987+ col_to_colorby
988+ Name of the categorical column (e.g., "celltype")
989+ color_source_vector
990+ Categorical vector containing the categories to map
991+ na_color
992+ Color to use for NaN/missing values
993+
994+ Returns
995+ -------
996+ Mapping from category names to hex colors, or None if colors not found
997+ """
998+ # Determine which table to use
999+ if table_name is not None :
1000+ if table_name not in sdata .tables :
1001+ logger .warning (f"Table '{ table_name } ' not found in sdata. Available tables: { list (sdata .tables .keys ())} " )
1002+ return None
1003+ table_to_use = table_name
1004+ else :
1005+ if len (sdata .tables ) == 0 :
1006+ logger .warning ("No tables found in sdata." )
1007+ return None
1008+ table_to_use = list (sdata .tables .keys ())[0 ]
1009+ logger .info (f"No table name provided, using '{ table_to_use } ' for color extraction." )
1010+
1011+ adata = sdata .tables [table_to_use ]
1012+ color_key = f"{ col_to_colorby } _colors"
1013+
1014+ # Check if the color pattern exists
1015+ if color_key not in adata .uns :
1016+ logger .debug (f"Color key '{ color_key } ' not found in table '{ table_to_use } ' uns." )
1017+ return None
1018+
1019+ # Extract colors and categories
1020+ stored_colors = adata .uns [color_key ]
1021+ categories = color_source_vector .categories .tolist ()
1022+
1023+ # Validate na_color format
1024+ if "#" not in str (na_color ):
1025+ logger .warning ("Expected `na_color` to be a hex color, converting..." )
1026+ na_color = to_hex (to_rgba (na_color )[:3 ])
1027+
1028+ # Strip alpha channel from na_color if present
1029+ if len (str (na_color )) == 9 : # #rrggbbaa format
1030+ na_color = str (na_color )[:7 ] # Keep only #rrggbb
1031+
1032+ # Convert stored colors to hex format (without alpha channel)
1033+ try :
1034+ hex_colors = []
1035+ for color in stored_colors :
1036+ rgba = to_rgba (color )[:3 ] # Take only RGB, drop alpha
1037+ hex_color = to_hex (rgba )
1038+ # Ensure we strip alpha channel if present
1039+ if len (hex_color ) == 9 : # #rrggbbaa format
1040+ hex_color = hex_color [:7 ] # Keep only #rrggbb
1041+ hex_colors .append (hex_color )
1042+ except Exception as e :
1043+ logger .warning (f"Error converting colors to hex format: { e } " )
1044+ return None
1045+
1046+ # Create the mapping
1047+ color_mapping = {}
1048+
1049+ # Map categories to colors
1050+ for i , category in enumerate (categories ):
1051+ if i < len (hex_colors ):
1052+ color_mapping [category ] = hex_colors [i ]
1053+ else :
1054+ # Not enough colors provided, use na_color for extra categories
1055+ logger .warning (f"Not enough colors provided for category '{ category } ', using na_color." )
1056+ color_mapping [category ] = na_color
1057+
1058+ # Add NaN category
1059+ color_mapping ["NaN" ] = na_color
1060+
1061+ logger .info (f"Successfully extracted { len (hex_colors )} colors from '{ color_key } ' in table '{ table_to_use } '." )
1062+ return color_mapping
1063+
1064+
9001065def _modify_categorical_color_mapping (
9011066 mapping : Mapping [str , str ],
9021067 groups : list [str ] | str | None = None ,
0 commit comments