|
1 | 1 | """k-NN based mapping of labels, embeddings, and expression values.""" |
2 | 2 |
|
3 | 3 | import gc |
4 | | -from typing import Any, Literal, cast |
| 4 | +from typing import Any, Literal |
5 | 5 |
|
6 | 6 | import numpy as np |
7 | 7 | import pandas as pd |
|
14 | 14 | from cellmapper.logging import logger |
15 | 15 | from cellmapper.model.embedding import EmbeddingMixin |
16 | 16 | from cellmapper.model.evaluate import EvaluationMixin |
| 17 | +from cellmapper.model.kernel import Kernel |
17 | 18 | from cellmapper.model.mapping_operator import MappingOperator |
18 | | -from cellmapper.model.neighbors import Neighbors |
19 | 19 | from cellmapper.utils import create_imputed_anndata, get_n_comps |
20 | 20 |
|
21 | 21 |
|
@@ -61,7 +61,7 @@ def __init__(self, query: AnnData, reference: AnnData | None = None) -> None: |
61 | 61 | ) |
62 | 62 |
|
63 | 63 | # Initialize result containers |
64 | | - self.knn: Neighbors | None = None |
| 64 | + self.knn: Kernel | None = None |
65 | 65 | self._mapping_operator: MappingOperator | None = None |
66 | 66 | self.label_transfer_metrics: dict[str, Any] | None = None |
67 | 67 | self.label_transfer_report: pd.DataFrame | None = None |
@@ -220,7 +220,7 @@ def compute_neighbors( |
220 | 220 | xrep = xrep[:, :n_comps] |
221 | 221 | yrep = yrep[:, :n_comps] |
222 | 222 |
|
223 | | - self.knn = Neighbors( |
| 223 | + self.knn = Kernel( |
224 | 224 | np.ascontiguousarray(xrep), |
225 | 225 | None if self._is_self_mapping else np.ascontiguousarray(yrep), |
226 | 226 | is_self_mapping=self._is_self_mapping, |
@@ -315,51 +315,25 @@ def compute_mapping_matrix( |
315 | 315 |
|
316 | 316 | logger.info("Computing mapping matrix using method '%s'.", method) |
317 | 317 |
|
318 | | - if method in ["jaccard", "hnoca"]: |
319 | | - # In cross-mapping mode, we need all four adjacency matrices |
320 | | - if self.only_yx and not self._is_self_mapping: |
321 | | - raise ValueError( |
322 | | - "Jaccard and HNOCa methods require both x and y neighbors to be computed in cross-mapping mode. Set only_yx=False." |
323 | | - ) |
| 318 | + # Compute kernel matrix using the new unified method |
| 319 | + self.knn.compute_kernel_matrix( |
| 320 | + method=method, |
| 321 | + symmetrize=symmetrize, |
| 322 | + self_edges=self_edges, |
| 323 | + ) |
324 | 324 |
|
325 | | - # symmetrize and self_edges only apply to self-terms (xx, yy) in cross-mapping mode |
326 | | - xx, yy, xy, yx = self.knn.get_adjacency_matrices( |
327 | | - symmetrize=symmetrize, |
328 | | - self_edges=True, |
329 | | - ) |
330 | | - # Type assertion for mypy - get_adjacency_matrices validates that xx is not None |
331 | | - n_neighbors = self.knn.yx.n_neighbors |
332 | | - |
333 | | - kernel_matrix = (yx @ xx.T) + (yy @ xy.T) |
334 | | - |
335 | | - if method == "jaccard": |
336 | | - kernel_matrix.data /= 4 * n_neighbors - kernel_matrix.data |
337 | | - elif method == "hnoca": |
338 | | - kernel_matrix.data /= 2 * n_neighbors - kernel_matrix.data |
339 | | - kernel_matrix.data = kernel_matrix.data**2 |
340 | | - |
341 | | - elif method in ["gauss", "scarches", "inverse_distance", "random", "equal", "umap"]: |
342 | | - # Validate self-mapping-only kernels |
343 | | - if method in PackageConstants.SELF_MAPPING_ONLY_KERNELS and not self._is_self_mapping: |
344 | | - raise ValueError(f"Method '{method}' is only supported for self-mapping mode. ") |
345 | | - |
346 | | - # Type cast to satisfy the type checker since we've filtered to only valid kernel methods |
347 | | - kernel_method = cast( |
348 | | - Literal["gauss", "scarches", "inverse_distance", "random", "equal", "umap"], |
349 | | - method, |
| 325 | + # Validate expected shape before creating mapping operator |
| 326 | + expected_shape = (self.query.n_obs, self.reference.n_obs) |
| 327 | + actual_shape = self.knn.kernel_matrix.shape |
| 328 | + if actual_shape != expected_shape: |
| 329 | + raise ValueError( |
| 330 | + f"Kernel matrix shape {actual_shape} does not match expected shape {expected_shape}. " |
| 331 | + f"Expected ({self.query.n_obs} query cells, {self.reference.n_obs} reference cells)." |
350 | 332 | ) |
351 | | - # Type assertion for mypy - neighbors validation ensures yx is not None |
352 | | - kernel_matrix = self.knn.yx.knn_graph_connectivities( |
353 | | - kernel=kernel_method, symmetrize=symmetrize, self_edges=self_edges |
354 | | - ) |
355 | | - else: |
356 | | - raise NotImplementedError(f"Method '{method}' is not implemented.") |
357 | 333 |
|
358 | | - # Create mapping operator with the computed matrix (single point of construction) |
| 334 | + # Create mapping operator with the computed matrix (simplified interface) |
359 | 335 | self._mapping_operator = MappingOperator( |
360 | | - kernel_matrix=kernel_matrix, |
361 | | - is_self_mapping=self._is_self_mapping, |
362 | | - expected_shape=(self.query.n_obs, self.reference.n_obs), |
| 336 | + kernel_matrix=self.knn, # Pass the Kernel object directly |
363 | 337 | n_eigenvectors=n_eigenvectors, |
364 | 338 | eigen_solver=eigen_solver, |
365 | 339 | ) |
@@ -663,7 +637,7 @@ def load_precomputed_distances(self, distances_key: str = "distances", remove_la |
663 | 637 | distances_matrix = csr_matrix(distances_matrix) |
664 | 638 |
|
665 | 639 | # Create a Kernel object using the factory method |
666 | | - self.knn = Neighbors.from_distances(distances_matrix, remove_last_neighbor) |
| 640 | + self.knn = Kernel.from_distances(distances_matrix, remove_last_neighbor) |
667 | 641 |
|
668 | 642 | # Type assertion for mypy - from_distances creates a valid Kernel object with xx |
669 | 643 | assert self.knn.xx is not None |
|
0 commit comments