Skip to content

Commit 041ce94

Browse files
authored
Merge pull request #46 from quadbio/feat/subset_categories
Enable subsetting categories
2 parents b79584b + 41924c3 commit 041ce94

File tree

3 files changed

+227
-21
lines changed

3 files changed

+227
-21
lines changed

src/cellmapper/_docs.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,13 @@
9191
knn_dist_metric
9292
Distance metric to use for nearest neighbors. See the knn algorithms documentation for details. """
9393

94+
_subset_categories = """\
95+
subset_categories
96+
For categorical data, optionally specify a subset of categories to include in the mapping.
97+
If None (default), all categories are included. If specified, only the listed categories
98+
will be mapped, and others will be ignored. For numerical data, this parameter is ignored
99+
with a warning. Can be a single category string or a list of category strings."""
100+
94101

95102
d = DocstringProcessor(
96103
t=_t,
@@ -106,4 +113,5 @@
106113
n_neighbors=_n_neighbors,
107114
use_rep=_use_rep,
108115
knn_dist_metric=_knn_dist_metric,
116+
subset_categories=_subset_categories,
109117
)

src/cellmapper/model/cellmapper.py

Lines changed: 87 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,7 @@ def map(
517517
symmetrize: bool | None = None,
518518
self_edges: bool | None = None,
519519
prediction_postfix: str = "pred",
520+
subset_categories: None | list[str] | str = None,
520521
) -> "CellMapper":
521522
"""
522523
Map data from reference to query datasets.
@@ -540,6 +541,7 @@ def map(
540541
%(symmetrize)s
541542
%(self_edges)s
542543
%(prediction_postfix)s
544+
%(subset_categories)s
543545
"""
544546
if self.knn is None:
545547
self.compute_neighbors(
@@ -553,27 +555,23 @@ def map(
553555
self.compute_mapping_matrix(kernel_method=kernel_method, symmetrize=symmetrize, self_edges=self_edges)
554556

555557
if obs_keys is not None:
556-
# Handle both single key and list of keys for backward compatibility
557-
if isinstance(obs_keys, str):
558+
# Normalize to list for consistent handling
559+
obs_keys_list = [obs_keys] if isinstance(obs_keys, str) else obs_keys
560+
for obs_key in obs_keys_list:
558561
self.map_obs(
559-
key=obs_keys, t=t, diffusion_method=diffusion_method, prediction_postfix=prediction_postfix
562+
key=obs_key,
563+
t=t,
564+
diffusion_method=diffusion_method,
565+
prediction_postfix=prediction_postfix,
566+
subset_categories=subset_categories,
560567
)
561-
else:
562-
for obs_key in obs_keys:
563-
self.map_obs(
564-
key=obs_key, t=t, diffusion_method=diffusion_method, prediction_postfix=prediction_postfix
565-
)
566568
if obsm_keys is not None:
567-
# Handle both single key and list of keys for backward compatibility
568-
if isinstance(obsm_keys, str):
569+
# Normalize to list for consistent handling
570+
obsm_keys_list = [obsm_keys] if isinstance(obsm_keys, str) else obsm_keys
571+
for obsm_key in obsm_keys_list:
569572
self.map_obsm(
570-
key=obsm_keys, t=t, diffusion_method=diffusion_method, prediction_postfix=prediction_postfix
573+
key=obsm_key, t=t, diffusion_method=diffusion_method, prediction_postfix=prediction_postfix
571574
)
572-
else:
573-
for obsm_key in obsm_keys:
574-
self.map_obsm(
575-
key=obsm_key, t=t, diffusion_method=diffusion_method, prediction_postfix=prediction_postfix
576-
)
577575
if layer_key is not None:
578576
self.map_layers(key=layer_key, t=t, diffusion_method=diffusion_method)
579577
if obs_keys is None and obsm_keys is None and layer_key is None:
@@ -650,6 +648,7 @@ def map_obs(
650648
prediction_postfix: str = "pred",
651649
confidence_postfix: str = "conf",
652650
return_probabilities: bool = False,
651+
subset_categories: None | list[str] | str = None,
653652
) -> np.ndarray | csr_matrix | None:
654653
"""
655654
Map observation data from reference dataset to query dataset.
@@ -672,6 +671,7 @@ def map_obs(
672671
return_probabilities
673672
If True, return the probability matrix for categorical data.
674673
Only applicable for categorical data. The matrix is never densified.
674+
%(subset_categories)s
675675
676676
Returns
677677
-------
@@ -705,6 +705,43 @@ def map_obs(
705705
or pd.api.types.is_string_dtype(reference_data)
706706
)
707707

708+
# Handle subset_categories parameter and warnings
709+
if subset_categories is not None:
710+
if not is_categorical:
711+
logger.warning(
712+
"subset_categories parameter specified for numerical data in key '%s'. This parameter will be ignored for numerical data.",
713+
key,
714+
)
715+
subset_categories = None
716+
else:
717+
# Convert single string to list for consistent handling
718+
if isinstance(subset_categories, str):
719+
subset_categories = [subset_categories]
720+
721+
# Check if specified categories exist in the data
722+
available_categories = set(
723+
reference_data.cat.categories if hasattr(reference_data, "cat") else reference_data.unique()
724+
)
725+
invalid_categories = set(subset_categories) - available_categories
726+
727+
if invalid_categories:
728+
logger.warning(
729+
"Some specified categories for key '%s' do not exist in the data and will be ignored: %s. Available categories: %s",
730+
key,
731+
list(invalid_categories),
732+
list(available_categories),
733+
)
734+
# Filter out invalid categories
735+
subset_categories = [cat for cat in subset_categories if cat in available_categories]
736+
737+
# If no valid categories remain, set to None to use all
738+
if not subset_categories:
739+
logger.warning(
740+
"No valid categories remaining for key '%s' after filtering. Using all available categories.",
741+
key,
742+
)
743+
subset_categories = None
744+
708745
# Log the operation being performed
709746
data_type = "categorical" if is_categorical else "numerical"
710747
if t is None:
@@ -720,7 +757,13 @@ def map_obs(
720757

721758
if is_categorical:
722759
return self._map_obs_categorical(
723-
key, prediction_postfix, confidence_postfix, t, diffusion_method, return_probabilities
760+
key,
761+
prediction_postfix,
762+
confidence_postfix,
763+
t,
764+
diffusion_method,
765+
return_probabilities,
766+
subset_categories,
724767
)
725768
else:
726769
if return_probabilities:
@@ -736,12 +779,35 @@ def _map_obs_categorical(
736779
t: int | None,
737780
diffusion_method: Literal["iterative", "spectral"],
738781
return_probabilities: bool = False,
782+
subset_categories: None | list[str] = None,
739783
) -> np.ndarray | csr_matrix | None:
740784
"""Map categorical observation data using one-hot encoding."""
741-
onehot = OneHotEncoder(dtype=np.float32)
742-
xtab = onehot.fit_transform(
743-
self.reference.obs[[key]],
744-
) # shape = (n_reference_cells x n_categories), sparse csr matrix, float32
785+
# Get the reference data
786+
reference_data = self.reference.obs[key]
787+
788+
if subset_categories is not None:
789+
# Create a filtered version of reference data for one-hot encoding
790+
# Rows outside the subset are kept but masked as missing below, not dropped
791+
mask = reference_data.isin(subset_categories)
792+
793+
# Copy the reference Series so entries outside the subset can be masked
794+
filtered_reference_data = reference_data.copy()
795+
filtered_reference_data[~mask] = pd.NA # Set non-subset categories to missing
796+
797+
# Create one-hot encoding only for the subset categories
798+
onehot = OneHotEncoder(dtype=np.float32, handle_unknown="ignore")
799+
# Create a DataFrame with only subset categories for fitting
800+
subset_df = pd.DataFrame({key: pd.Categorical(subset_categories, categories=subset_categories)})
801+
onehot.fit(subset_df)
802+
803+
# Transform the full reference data (missing values will be ignored)
804+
xtab = onehot.transform(filtered_reference_data.to_frame())
805+
else:
806+
# Use the original approach for all categories
807+
onehot = OneHotEncoder(dtype=np.float32)
808+
xtab = onehot.fit_transform(self.reference.obs[[key]])
809+
810+
# Apply the mapping
745811
ytab = self.mapping_operator.apply(
746812
xtab, t=t, diffusion_method=diffusion_method
747813
) # shape = (n_query_cells x n_categories), sparse csr matrix, float32

tests/model/test_query_to_reference_mapping.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,3 +307,135 @@ def test_map_obs_pseudotime_cross_mapping(self, query_reference_adata):
307307

308308
# Verify no confidence scores for numerical data
309309
assert "dpt_pseudotime_conf" not in cmap.query.obs
310+
311+
def test_map_obs_subset_categories(self, query_reference_adata):
312+
"""Test mapping with subset_categories parameter for categorical data."""
313+
query, reference = query_reference_adata
314+
315+
# Create CellMapper and compute mapping matrix
316+
cmap = CellMapper(query=query, reference=reference)
317+
cmap.compute_neighbors(n_neighbors=30, use_rep="X_pca", knn_method="sklearn")
318+
cmap.compute_mapping_matrix(kernel_method="gauss")
319+
320+
# Get available leiden categories in reference
321+
available_categories = list(reference.obs["leiden"].cat.categories)
322+
323+
# Test with subset of categories
324+
subset_cats = available_categories[:2] # Take first 2 categories
325+
cmap.map_obs(key="leiden", subset_categories=subset_cats)
326+
327+
# Check that mapping was performed
328+
assert "leiden_pred" in cmap.query.obs
329+
assert "leiden_conf" in cmap.query.obs
330+
331+
# Check that predictions only contain subset categories (or might be missing if no assignment)
332+
predicted_categories = set(cmap.query.obs["leiden_pred"].dropna().unique())
333+
assert predicted_categories.issubset(set(subset_cats)), (
334+
f"Predicted categories {predicted_categories} not subset of {subset_cats}"
335+
)
336+
337+
def test_map_obs_subset_categories_single_string(self, query_reference_adata):
338+
"""Test mapping with subset_categories as single string."""
339+
query, reference = query_reference_adata
340+
341+
# Create CellMapper and compute mapping matrix
342+
cmap = CellMapper(query=query, reference=reference)
343+
cmap.compute_neighbors(n_neighbors=30, use_rep="X_pca", knn_method="sklearn")
344+
cmap.compute_mapping_matrix(kernel_method="gauss")
345+
346+
# Get first available category
347+
first_category = reference.obs["leiden"].cat.categories[0]
348+
349+
# Test with single category as string
350+
cmap.map_obs(key="leiden", subset_categories=first_category)
351+
352+
# Check that mapping was performed and only contains the specified category
353+
assert "leiden_pred" in cmap.query.obs
354+
predicted_categories = set(cmap.query.obs["leiden_pred"].dropna().unique())
355+
assert predicted_categories.issubset({first_category}), (
356+
f"Predicted categories {predicted_categories} not subset of {first_category}"
357+
)
358+
359+
def test_map_obs_subset_categories_invalid_categories(self, query_reference_adata, caplog):
360+
"""Test mapping with some invalid categories in subset_categories."""
361+
query, reference = query_reference_adata
362+
363+
# Create CellMapper and compute mapping matrix
364+
cmap = CellMapper(query=query, reference=reference)
365+
cmap.compute_neighbors(n_neighbors=30, use_rep="X_pca", knn_method="sklearn")
366+
cmap.compute_mapping_matrix(kernel_method="gauss")
367+
368+
# Mix valid and invalid categories
369+
valid_category = reference.obs["leiden"].cat.categories[0]
370+
invalid_categories = ["nonexistent1", "nonexistent2"]
371+
mixed_categories = [valid_category] + invalid_categories
372+
373+
# Test with mixed valid/invalid categories - should work without errors
374+
cmap.map_obs(key="leiden", subset_categories=mixed_categories)
375+
376+
# Check that mapping still worked with valid categories
377+
assert "leiden_pred" in cmap.query.obs
378+
predicted_categories = set(cmap.query.obs["leiden_pred"].dropna().unique())
379+
assert predicted_categories.issubset({valid_category})
380+
381+
def test_map_obs_subset_categories_all_invalid(self, query_reference_adata, caplog):
382+
"""Test mapping with all invalid categories in subset_categories."""
383+
query, reference = query_reference_adata
384+
385+
# Create CellMapper and compute mapping matrix
386+
cmap = CellMapper(query=query, reference=reference)
387+
cmap.compute_neighbors(n_neighbors=30, use_rep="X_pca", knn_method="sklearn")
388+
cmap.compute_mapping_matrix(kernel_method="gauss")
389+
390+
# Use only invalid categories
391+
invalid_categories = ["nonexistent1", "nonexistent2"]
392+
393+
# Test with all invalid categories - should fallback to using all categories
394+
cmap.map_obs(key="leiden", subset_categories=invalid_categories)
395+
396+
# Check that mapping still worked with all categories (fallback)
397+
assert "leiden_pred" in cmap.query.obs
398+
# Should have predictions from all available categories since it fell back
399+
predicted_categories = set(cmap.query.obs["leiden_pred"].dropna().unique())
400+
available_categories = set(reference.obs["leiden"].cat.categories)
401+
# At least one category should be predicted (could be subset due to k-NN mapping)
402+
assert len(predicted_categories) > 0
403+
assert predicted_categories.issubset(available_categories)
404+
405+
def test_map_obs_subset_categories_numerical_warning(self, query_reference_adata, caplog):
406+
"""Test that subset_categories generates warning for numerical data."""
407+
query, reference = query_reference_adata
408+
409+
# Create CellMapper and compute mapping matrix
410+
cmap = CellMapper(query=query, reference=reference)
411+
cmap.compute_neighbors(n_neighbors=30, use_rep="X_pca", knn_method="sklearn")
412+
cmap.compute_mapping_matrix(kernel_method="gauss")
413+
414+
# Test with numerical data and subset_categories - should work and ignore the parameter
415+
cmap.map_obs(key="dpt_pseudotime", subset_categories=["some_category"])
416+
417+
# Check that mapping still worked normally (parameter was ignored)
418+
assert "dpt_pseudotime_pred" in cmap.query.obs
419+
# Confidence scores should not be created for numerical data
420+
assert "dpt_pseudotime_conf" not in cmap.query.obs
421+
422+
def test_map_method_with_subset_categories(self, query_reference_adata):
423+
"""Test that subset_categories parameter works through the high-level map method."""
424+
query, reference = query_reference_adata
425+
426+
# Create CellMapper
427+
cmap = CellMapper(query=query, reference=reference)
428+
429+
# Get available categories
430+
available_categories = list(reference.obs["leiden"].cat.categories)
431+
subset_cats = available_categories[:2]
432+
433+
# Test high-level map method with subset_categories
434+
cmap.map(
435+
obs_keys="leiden", n_neighbors=30, use_rep="X_pca", kernel_method="gauss", subset_categories=subset_cats
436+
)
437+
438+
# Check that mapping was performed with subset
439+
assert "leiden_pred" in cmap.query.obs
440+
predicted_categories = set(cmap.query.obs["leiden_pred"].dropna().unique())
441+
assert predicted_categories.issubset(set(subset_cats))

0 commit comments

Comments
 (0)