Skip to content

Commit 0df4c97

Browse files
committed
now respecting uns color
1 parent fd11c33 commit 0df4c97

1 file changed

Lines changed: 178 additions & 13 deletions

File tree

src/spatialdata_plot/pl/utils.py

Lines changed: 178 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -760,7 +760,6 @@ def _set_color_source_vec(
760760
)[value_to_plot]
761761

762762
# numerical case, return early
763-
# TODO temporary split until refactor is complete
764763
if color_source_vector is not None and not isinstance(color_source_vector.dtype, pd.CategoricalDtype):
765764
if (
766765
not isinstance(element, GeoDataFrame)
@@ -777,18 +776,50 @@ def _set_color_source_vec(
777776

778777
color_source_vector = pd.Categorical(color_source_vector) # convert, e.g., `pd.Series`
779778

780-
# TODO check why table_name is not passed here.
781-
color_mapping = _get_categorical_color_mapping(
782-
adata=sdata["table"],
783-
cluster_key=value_to_plot,
784-
color_source_vector=color_source_vector,
785-
cmap_params=cmap_params,
786-
alpha=alpha,
787-
groups=groups,
788-
palette=palette,
789-
na_color=na_color,
790-
render_type=render_type,
791-
)
779+
# Use the provided table_name parameter, fall back to only one present
780+
if table_name is not None:
781+
table_to_use = table_name
782+
else:
783+
table_to_use = list(sdata.tables.keys())[0]
784+
logger.warning(f"No table name provided, using '{table_to_use}' as fallback for color mapping.")
785+
786+
# Check if custom colors exist in the table's .uns slot
787+
if _has_colors_in_uns(sdata, table_name, value_to_plot):
788+
# Extract colors directly from the table's .uns slot
789+
color_mapping = _extract_colors_from_table_uns(
790+
sdata=sdata,
791+
table_name=table_name,
792+
col_to_colorby=value_to_plot,
793+
color_source_vector=color_source_vector,
794+
na_color=na_color,
795+
)
796+
if color_mapping is None:
797+
logger.warning(f"Failed to extract colors for '{value_to_plot}', falling back to default mapping.")
798+
# Fall back to the existing method if extraction fails
799+
color_mapping = _get_categorical_color_mapping(
800+
adata=sdata[table_to_use],
801+
cluster_key=value_to_plot,
802+
color_source_vector=color_source_vector,
803+
cmap_params=cmap_params,
804+
alpha=alpha,
805+
groups=groups,
806+
palette=palette,
807+
na_color=na_color,
808+
render_type=render_type,
809+
)
810+
else:
811+
# Use the existing color mapping method
812+
color_mapping = _get_categorical_color_mapping(
813+
adata=sdata[table_to_use],
814+
cluster_key=value_to_plot,
815+
color_source_vector=color_source_vector,
816+
cmap_params=cmap_params,
817+
alpha=alpha,
818+
groups=groups,
819+
palette=palette,
820+
na_color=na_color,
821+
render_type=render_type,
822+
)
792823

793824
color_source_vector = color_source_vector.set_categories(color_mapping.keys())
794825
if color_mapping is None:
@@ -897,6 +928,140 @@ def _generate_base_categorial_color_mapping(
897928
return _get_default_categorial_color_mapping(color_source_vector=color_source_vector, cmap_params=cmap_params)
898929

899930

931+
def _has_colors_in_uns(
932+
sdata: sd.SpatialData,
933+
table_name: str | None,
934+
col_to_colorby: str,
935+
) -> bool:
936+
"""
937+
Check if <column_name>_colors exists in the specified table's .uns slot.
938+
939+
Parameters
940+
----------
941+
sdata
942+
SpatialData object containing tables
943+
table_name
944+
Name of the table to check. If None, uses the first available table.
945+
col_to_colorby
946+
Name of the categorical column (e.g., "celltype")
947+
948+
Returns
949+
-------
950+
True if <col_to_colorby>_colors exists in the table's .uns, False otherwise
951+
"""
952+
# Determine which table to use
953+
if table_name is not None:
954+
if table_name not in sdata.tables:
955+
return False
956+
table_to_use = table_name
957+
else:
958+
if len(sdata.tables) == 0:
959+
return False
960+
table_to_use = list(sdata.tables.keys())[0]
961+
962+
adata = sdata.tables[table_to_use]
963+
color_key = f"{col_to_colorby}_colors"
964+
965+
return color_key in adata.uns
966+
967+
968+
def _extract_colors_from_table_uns(
969+
sdata: sd.SpatialData,
970+
table_name: str | None,
971+
col_to_colorby: str,
972+
color_source_vector: ArrayLike | pd.Series[CategoricalDtype],
973+
na_color: ColorLike,
974+
) -> Mapping[str, str] | None:
975+
"""
976+
Extract categorical colors from the <column_name>_colors pattern in adata.uns.
977+
978+
This function looks for colors stored in the format <col_to_colorby>_colors in the
979+
specified table's .uns slot and creates a mapping from categories to colors.
980+
981+
Parameters
982+
----------
983+
sdata
984+
SpatialData object containing tables
985+
table_name
986+
Name of the table to look in. If None, uses the first available table.
987+
col_to_colorby
988+
Name of the categorical column (e.g., "celltype")
989+
color_source_vector
990+
Categorical vector containing the categories to map
991+
na_color
992+
Color to use for NaN/missing values
993+
994+
Returns
995+
-------
996+
Mapping from category names to hex colors, or None if colors not found
997+
"""
998+
# Determine which table to use
999+
if table_name is not None:
1000+
if table_name not in sdata.tables:
1001+
logger.warning(f"Table '{table_name}' not found in sdata. Available tables: {list(sdata.tables.keys())}")
1002+
return None
1003+
table_to_use = table_name
1004+
else:
1005+
if len(sdata.tables) == 0:
1006+
logger.warning("No tables found in sdata.")
1007+
return None
1008+
table_to_use = list(sdata.tables.keys())[0]
1009+
logger.info(f"No table name provided, using '{table_to_use}' for color extraction.")
1010+
1011+
adata = sdata.tables[table_to_use]
1012+
color_key = f"{col_to_colorby}_colors"
1013+
1014+
# Check if the color pattern exists
1015+
if color_key not in adata.uns:
1016+
logger.debug(f"Color key '{color_key}' not found in table '{table_to_use}' uns.")
1017+
return None
1018+
1019+
# Extract colors and categories
1020+
stored_colors = adata.uns[color_key]
1021+
categories = color_source_vector.categories.tolist()
1022+
1023+
# Validate na_color format
1024+
if "#" not in str(na_color):
1025+
logger.warning("Expected `na_color` to be a hex color, converting...")
1026+
na_color = to_hex(to_rgba(na_color)[:3])
1027+
1028+
# Strip alpha channel from na_color if present
1029+
if len(str(na_color)) == 9: # #rrggbbaa format
1030+
na_color = str(na_color)[:7] # Keep only #rrggbb
1031+
1032+
# Convert stored colors to hex format (without alpha channel)
1033+
try:
1034+
hex_colors = []
1035+
for color in stored_colors:
1036+
rgba = to_rgba(color)[:3] # Take only RGB, drop alpha
1037+
hex_color = to_hex(rgba)
1038+
# Ensure we strip alpha channel if present
1039+
if len(hex_color) == 9: # #rrggbbaa format
1040+
hex_color = hex_color[:7] # Keep only #rrggbb
1041+
hex_colors.append(hex_color)
1042+
except Exception as e:
1043+
logger.warning(f"Error converting colors to hex format: {e}")
1044+
return None
1045+
1046+
# Create the mapping
1047+
color_mapping = {}
1048+
1049+
# Map categories to colors
1050+
for i, category in enumerate(categories):
1051+
if i < len(hex_colors):
1052+
color_mapping[category] = hex_colors[i]
1053+
else:
1054+
# Not enough colors provided, use na_color for extra categories
1055+
logger.warning(f"Not enough colors provided for category '{category}', using na_color.")
1056+
color_mapping[category] = na_color
1057+
1058+
# Add NaN category
1059+
color_mapping["NaN"] = na_color
1060+
1061+
logger.info(f"Successfully extracted {len(hex_colors)} colors from '{color_key}' in table '{table_to_use}'.")
1062+
return color_mapping
1063+
1064+
9001065
def _modify_categorical_color_mapping(
9011066
mapping: Mapping[str, str],
9021067
groups: list[str] | str | None = None,

0 commit comments

Comments
 (0)