Skip to content

Commit b49e8a1

Browse files
authored
Merge pull request #36 from quadbio/feat/symmetrization
Refactor the neighbor classes and fix symmetrization
2 parents 88710a7 + 33e0132 commit b49e8a1

File tree

10 files changed

+1107
-971
lines changed

10 files changed

+1107
-971
lines changed

src/cellmapper/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
from .logging import logger
44
from .model.cellmapper import CellMapper
5+
from .model.kernel import Kernel
56
from .model.neighbors import Neighbors
67

7-
__all__ = ["logger", "CellMapper", "Neighbors"]
8+
__all__ = ["logger", "CellMapper", "Kernel", "Neighbors"]
89

910
__version__ = version("cellmapper")

src/cellmapper/constants.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@ class PackageConstants:
99
DEFAULT_SELF_MAPPING_METHOD: str = "umap"
1010
DEFAULT_CROSS_MAPPING_METHOD: str = "gauss"
1111

12+
# Kernel method categories
13+
JACCARD_BASED_KERNELS = {"jaccard", "hnoca"}
14+
CONNECTIVITY_BASED_KERNELS = {"gauss", "scarches", "inverse_distance", "random", "equal", "umap"}
15+
1216
# Kernel methods that only work in self-mapping mode
13-
SELF_MAPPING_ONLY_KERNELS = {"umap", "adaptive_gauss"}
17+
SELF_MAPPING_ONLY_KERNELS = {"umap"}
1418

1519
# Threshold for recommending spectral method over iterative for matrix powers
1620
SPECTRAL_METHOD_THRESHOLD: int = 10

src/cellmapper/model/cellmapper.py

Lines changed: 20 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""k-NN based mapping of labels, embeddings, and expression values."""
22

33
import gc
4-
from typing import Any, Literal, cast
4+
from typing import Any, Literal
55

66
import numpy as np
77
import pandas as pd
@@ -14,8 +14,8 @@
1414
from cellmapper.logging import logger
1515
from cellmapper.model.embedding import EmbeddingMixin
1616
from cellmapper.model.evaluate import EvaluationMixin
17+
from cellmapper.model.kernel import Kernel
1718
from cellmapper.model.mapping_operator import MappingOperator
18-
from cellmapper.model.neighbors import Neighbors
1919
from cellmapper.utils import create_imputed_anndata, get_n_comps
2020

2121

@@ -61,7 +61,7 @@ def __init__(self, query: AnnData, reference: AnnData | None = None) -> None:
6161
)
6262

6363
# Initialize result containers
64-
self.knn: Neighbors | None = None
64+
self.knn: Kernel | None = None
6565
self._mapping_operator: MappingOperator | None = None
6666
self.label_transfer_metrics: dict[str, Any] | None = None
6767
self.label_transfer_report: pd.DataFrame | None = None
@@ -220,7 +220,7 @@ def compute_neighbors(
220220
xrep = xrep[:, :n_comps]
221221
yrep = yrep[:, :n_comps]
222222

223-
self.knn = Neighbors(
223+
self.knn = Kernel(
224224
np.ascontiguousarray(xrep),
225225
None if self._is_self_mapping else np.ascontiguousarray(yrep),
226226
is_self_mapping=self._is_self_mapping,
@@ -315,51 +315,25 @@ def compute_mapping_matrix(
315315

316316
logger.info("Computing mapping matrix using method '%s'.", method)
317317

318-
if method in ["jaccard", "hnoca"]:
319-
# In cross-mapping mode, we need all four adjacency matrices
320-
if self.only_yx and not self._is_self_mapping:
321-
raise ValueError(
322-
"Jaccard and HNOCa methods require both x and y neighbors to be computed in cross-mapping mode. Set only_yx=False."
323-
)
318+
# Compute kernel matrix using the new unified method
319+
self.knn.compute_kernel_matrix(
320+
method=method,
321+
symmetrize=symmetrize,
322+
self_edges=self_edges,
323+
)
324324

325-
# symmetrize and self_edges only apply to self-terms (xx, yy) in cross-mapping mode
326-
xx, yy, xy, yx = self.knn.get_adjacency_matrices(
327-
symmetrize=symmetrize,
328-
self_edges=True,
329-
)
330-
# Type assertion for mypy - get_adjacency_matrices validates that xx is not None
331-
n_neighbors = self.knn.yx.n_neighbors
332-
333-
kernel_matrix = (yx @ xx.T) + (yy @ xy.T)
334-
335-
if method == "jaccard":
336-
kernel_matrix.data /= 4 * n_neighbors - kernel_matrix.data
337-
elif method == "hnoca":
338-
kernel_matrix.data /= 2 * n_neighbors - kernel_matrix.data
339-
kernel_matrix.data = kernel_matrix.data**2
340-
341-
elif method in ["gauss", "scarches", "inverse_distance", "random", "equal", "umap"]:
342-
# Validate self-mapping-only kernels
343-
if method in PackageConstants.SELF_MAPPING_ONLY_KERNELS and not self._is_self_mapping:
344-
raise ValueError(f"Method '{method}' is only supported for self-mapping mode. ")
345-
346-
# Type cast to satisfy the type checker since we've filtered to only valid kernel methods
347-
kernel_method = cast(
348-
Literal["gauss", "scarches", "inverse_distance", "random", "equal", "umap"],
349-
method,
325+
# Validate expected shape before creating mapping operator
326+
expected_shape = (self.query.n_obs, self.reference.n_obs)
327+
actual_shape = self.knn.kernel_matrix.shape
328+
if actual_shape != expected_shape:
329+
raise ValueError(
330+
f"Kernel matrix shape {actual_shape} does not match expected shape {expected_shape}. "
331+
f"Expected ({self.query.n_obs} query cells, {self.reference.n_obs} reference cells)."
350332
)
351-
# Type assertion for mypy - neighbors validation ensures yx is not None
352-
kernel_matrix = self.knn.yx.knn_graph_connectivities(
353-
kernel=kernel_method, symmetrize=symmetrize, self_edges=self_edges
354-
)
355-
else:
356-
raise NotImplementedError(f"Method '{method}' is not implemented.")
357333

358-
# Create mapping operator with the computed matrix (single point of construction)
334+
# Create mapping operator with the computed matrix (simplified interface)
359335
self._mapping_operator = MappingOperator(
360-
kernel_matrix=kernel_matrix,
361-
is_self_mapping=self._is_self_mapping,
362-
expected_shape=(self.query.n_obs, self.reference.n_obs),
336+
kernel_matrix=self.knn, # Pass the Kernel object directly
363337
n_eigenvectors=n_eigenvectors,
364338
eigen_solver=eigen_solver,
365339
)
@@ -663,7 +637,7 @@ def load_precomputed_distances(self, distances_key: str = "distances", remove_la
663637
distances_matrix = csr_matrix(distances_matrix)
664638

665639
# Create a Kernel object using the factory method
666-
self.knn = Neighbors.from_distances(distances_matrix, remove_last_neighbor)
640+
self.knn = Kernel.from_distances(distances_matrix, remove_last_neighbor)
667641

668642
# Type assertion for mypy - from_distances creates a valid Kernel object with xx
669643
assert self.knn.xx is not None

0 commit comments

Comments
 (0)