@@ -517,6 +517,7 @@ def map(
517517 symmetrize : bool | None = None ,
518518 self_edges : bool | None = None ,
519519 prediction_postfix : str = "pred" ,
520+ subset_categories : None | list [str ] | str = None ,
520521 ) -> "CellMapper" :
521522 """
522523 Map data from reference to query datasets.
@@ -540,6 +541,7 @@ def map(
540541 %(symmetrize)s
541542 %(self_edges)s
542543 %(prediction_postfix)s
544+ %(subset_categories)s
543545 """
544546 if self .knn is None :
545547 self .compute_neighbors (
@@ -553,27 +555,23 @@ def map(
553555 self .compute_mapping_matrix (kernel_method = kernel_method , symmetrize = symmetrize , self_edges = self_edges )
554556
555557 if obs_keys is not None :
556- # Handle both single key and list of keys for backward compatibility
557- if isinstance (obs_keys , str ):
558+ # Normalize to list for consistent handling
559+ obs_keys_list = [obs_keys ] if isinstance (obs_keys , str ) else obs_keys
560+ for obs_key in obs_keys_list :
558561 self .map_obs (
559- key = obs_keys , t = t , diffusion_method = diffusion_method , prediction_postfix = prediction_postfix
562+ key = obs_key ,
563+ t = t ,
564+ diffusion_method = diffusion_method ,
565+ prediction_postfix = prediction_postfix ,
566+ subset_categories = subset_categories ,
560567 )
561- else :
562- for obs_key in obs_keys :
563- self .map_obs (
564- key = obs_key , t = t , diffusion_method = diffusion_method , prediction_postfix = prediction_postfix
565- )
566568 if obsm_keys is not None :
567- # Handle both single key and list of keys for backward compatibility
568- if isinstance (obsm_keys , str ):
569+ # Normalize to list for consistent handling
570+ obsm_keys_list = [obsm_keys ] if isinstance (obsm_keys , str ) else obsm_keys
571+ for obsm_key in obsm_keys_list :
569572 self .map_obsm (
570- key = obsm_keys , t = t , diffusion_method = diffusion_method , prediction_postfix = prediction_postfix
573+ key = obsm_key , t = t , diffusion_method = diffusion_method , prediction_postfix = prediction_postfix
571574 )
572- else :
573- for obsm_key in obsm_keys :
574- self .map_obsm (
575- key = obsm_key , t = t , diffusion_method = diffusion_method , prediction_postfix = prediction_postfix
576- )
577575 if layer_key is not None :
578576 self .map_layers (key = layer_key , t = t , diffusion_method = diffusion_method )
579577 if obs_keys is None and obsm_keys is None and layer_key is None :
@@ -650,6 +648,7 @@ def map_obs(
650648 prediction_postfix : str = "pred" ,
651649 confidence_postfix : str = "conf" ,
652650 return_probabilities : bool = False ,
651+ subset_categories : None | list [str ] | str = None ,
653652 ) -> np .ndarray | csr_matrix | None :
654653 """
655654 Map observation data from reference dataset to query dataset.
@@ -672,6 +671,7 @@ def map_obs(
672671 return_probabilities
673672 If True, return the probability matrix for categorical data.
674673 Only applicable for categorical data. The matrix is never densified.
674+ %(subset_categories)s
675675
676676 Returns
677677 -------
@@ -705,6 +705,43 @@ def map_obs(
705705 or pd .api .types .is_string_dtype (reference_data )
706706 )
707707
708+ # Handle subset_categories parameter and warnings
709+ if subset_categories is not None :
710+ if not is_categorical :
711+ logger .warning (
712+ "subset_categories parameter specified for numerical data in key '%s'. This parameter will be ignored for numerical data." ,
713+ key ,
714+ )
715+ subset_categories = None
716+ else :
717+ # Convert single string to list for consistent handling
718+ if isinstance (subset_categories , str ):
719+ subset_categories = [subset_categories ]
720+
721+ # Check if specified categories exist in the data
722+ available_categories = set (
723+ reference_data .cat .categories if hasattr (reference_data , "cat" ) else reference_data .unique ()
724+ )
725+ invalid_categories = set (subset_categories ) - available_categories
726+
727+ if invalid_categories :
728+ logger .warning (
729+ "Some specified categories for key '%s' do not exist in the data and will be ignored: %s. Available categories: %s" ,
730+ key ,
731+ list (invalid_categories ),
732+ list (available_categories ),
733+ )
734+ # Filter out invalid categories
735+ subset_categories = [cat for cat in subset_categories if cat in available_categories ]
736+
737+ # If no valid categories remain, set to None to use all
738+ if not subset_categories :
739+ logger .warning (
740+ "No valid categories remaining for key '%s' after filtering. Using all available categories." ,
741+ key ,
742+ )
743+ subset_categories = None
744+
708745 # Log the operation being performed
709746 data_type = "categorical" if is_categorical else "numerical"
710747 if t is None :
@@ -720,7 +757,13 @@ def map_obs(
720757
721758 if is_categorical :
722759 return self ._map_obs_categorical (
723- key , prediction_postfix , confidence_postfix , t , diffusion_method , return_probabilities
760+ key ,
761+ prediction_postfix ,
762+ confidence_postfix ,
763+ t ,
764+ diffusion_method ,
765+ return_probabilities ,
766+ subset_categories ,
724767 )
725768 else :
726769 if return_probabilities :
@@ -736,12 +779,35 @@ def _map_obs_categorical(
736779 t : int | None ,
737780 diffusion_method : Literal ["iterative" , "spectral" ],
738781 return_probabilities : bool = False ,
782+ subset_categories : None | list [str ] = None ,
739783 ) -> np .ndarray | csr_matrix | None :
740784 """Map categorical observation data using one-hot encoding."""
741- onehot = OneHotEncoder (dtype = np .float32 )
742- xtab = onehot .fit_transform (
743- self .reference .obs [[key ]],
744- ) # shape = (n_reference_cells x n_categories), sparse csr matrix, float32
785+ # Get the reference data
786+ reference_data = self .reference .obs [key ]
787+
788+ if subset_categories is not None :
789+ # Create a filtered version of reference data for one-hot encoding
790+ # Only include rows that have the desired categories
791+ mask = reference_data .isin (subset_categories )
792+
793+ # Create a filtered DataFrame with only the subset categories
794+ filtered_reference_data = reference_data .copy ()
795+ filtered_reference_data [~ mask ] = pd .NA # Set non-subset categories to missing
796+
797+ # Create one-hot encoding only for the subset categories
798+ onehot = OneHotEncoder (dtype = np .float32 , handle_unknown = "ignore" )
799+ # Create a DataFrame with only subset categories for fitting
800+ subset_df = pd .DataFrame ({key : pd .Categorical (subset_categories , categories = subset_categories )})
801+ onehot .fit (subset_df )
802+
803+ # Transform the full reference data (missing values will be ignored)
804+ xtab = onehot .transform (filtered_reference_data .to_frame ())
805+ else :
806+ # Use the original approach for all categories
807+ onehot = OneHotEncoder (dtype = np .float32 )
808+ xtab = onehot .fit_transform (self .reference .obs [[key ]])
809+
810+ # Apply the mapping
745811 ytab = self .mapping_operator .apply (
746812 xtab , t = t , diffusion_method = diffusion_method
747813 ) # shape = (n_query_cells x n_categories), sparse csr matrix, float32
0 commit comments