Skip to content

Commit 4f21b1b

Browse files
wes-lewis, github-actions[bot], scottgigante-immunai
authored
Add reversed norm order for ALRA in Denoising Task (#835)
* add reverse order/regular order
  Try to match the magic code format for decorators and implementation of reverse norm
* Update __init__.py
* pre-commit
* Update alra.py
* Update alra.py
* pre-commit
* Fix method names
* Update alra.py
* function names should be lowercase
* Update alra.R
* X is unused
* Fix bug
* Revert 5fa2c64 [formerly 77302de]
* pre-commit
* Actually pass sqrt in sqrt norm methods

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Scott Gigante <[email protected]>
Former-commit-id: 86375dd
1 parent 1caa044 commit 4f21b1b

File tree

2 files changed

+58
-57
lines changed

2 files changed

+58
-57
lines changed

openproblems/tasks/denoising/methods/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from .alra import alra_log
2+
from .alra import alra_log_reversenorm
23
from .alra import alra_sqrt
4+
from .alra import alra_sqrt_reversenorm
35
from .baseline import no_denoising
46
from .baseline import perfect_denoising
57
from .dca import dca
openproblems/tasks/denoising/methods/alra.py

+56 -57
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,61 @@
11
from ....tools.conversion import r_function
22
from ....tools.decorators import method
33

4+
import functools
45
import logging
56

6-
_alra = r_function("alra.R")
7+
_r_alra = r_function("alra.R")
78

89
log = logging.getLogger("openproblems")
910

1011

11-
@method(
12-
method_name="ALRA (sqrt norm, reversed normalization)",
12+
method_name = ("ALRA (sqrt norm, reversed normalization)",)
13+
_alra_method = functools.partial(
14+
method,
1315
paper_name="Zero-preserving imputation of scRNA-seq data using "
1416
"low-rank approximation",
1517
paper_reference="linderman2018zero",
1618
paper_year=2018,
1719
code_url="https://github.com/KlugerLab/ALRA",
1820
image="openproblems-r-extras",
1921
)
20-
def alra_sqrt(adata, test=False):
22+
23+
24+
def _alra(adata, normtype="log", reverse_norm_order=False, test=False):
2125
import numpy as np
2226
import rpy2.rinterface_lib.embedded
2327
import scprep
2428

25-
# libsize and sqrt norm
26-
adata.obsm["train_norm"] = scprep.utils.matrix_transform(
27-
adata.obsm["train"], np.sqrt
28-
)
29-
adata.obsm["train_norm"], libsize = scprep.normalize.library_size_normalize(
30-
adata.obsm["train_norm"], rescale=1, return_library_size=True
31-
)
32-
adata.obsm["train_norm"] = adata.obsm["train_norm"].tocsr()
29+
if normtype == "sqrt":
30+
norm_fn = np.sqrt
31+
denorm_fn = np.square
32+
elif normtype == "log":
33+
norm_fn = np.log1p
34+
denorm_fn = np.expm1
35+
else:
36+
raise NotImplementedError
37+
38+
X = adata.obsm["train"].copy()
39+
if reverse_norm_order:
40+
# inexplicably, this sometimes performs better
41+
X = scprep.utils.matrix_transform(X, norm_fn)
42+
X, libsize = scprep.normalize.library_size_normalize(
43+
X, rescale=1, return_library_size=True
44+
)
45+
else:
46+
X, libsize = scprep.normalize.library_size_normalize(
47+
X, rescale=1, return_library_size=True
48+
)
49+
X = scprep.utils.matrix_transform(X, norm_fn)
50+
51+
adata.obsm["train_norm"] = X.tocsr()
3352
# run alra
34-
# _alra takes sparse array, returns dense array
53+
# _r_alra takes sparse array, returns dense array
3554
Y = None
3655
attempts = 0
3756
while Y is None:
3857
try:
39-
Y = _alra(adata)
58+
Y = _r_alra(adata)
4059
except rpy2.rinterface_lib.embedded.RRuntimeError: # pragma: no cover
4160
if attempts < 10:
4261
attempts += 1
@@ -46,57 +65,37 @@ def alra_sqrt(adata, test=False):
4665

4766
# transform back into original space
4867
# functions are reversed!
49-
Y = scprep.utils.matrix_transform(Y, np.square)
68+
Y = scprep.utils.matrix_transform(Y, denorm_fn)
5069
Y = scprep.utils.matrix_vector_elementwise_multiply(Y, libsize, axis=0)
5170
adata.obsm["denoised"] = Y
5271

5372
adata.uns["method_code_version"] = "1.0.0"
5473
return adata
5574

5675

57-
@method(
58-
method_name="ALRA (log norm)",
59-
paper_name="Zero-preserving imputation of scRNA-seq data using "
60-
"low-rank approximation",
61-
paper_reference="linderman2018zero",
62-
paper_year=2018,
63-
code_url="https://github.com/KlugerLab/ALRA",
64-
image="openproblems-r-extras",
76+
@_alra_method(
77+
method_name="ALRA (sqrt norm, reversed normalization)",
6578
)
66-
def alra_log(adata, test=False):
67-
import numpy as np
68-
import rpy2.rinterface_lib.embedded
69-
import scprep
79+
def alra_sqrt_reversenorm(adata, test=False):
80+
return _alra(adata, normtype="sqrt", reverse_norm_order=True, test=False)
7081

71-
# libsize and log norm
72-
# lib norm
73-
adata.obsm["train_norm"], libsize = scprep.normalize.library_size_normalize(
74-
adata.obsm["train"], rescale=1, return_library_size=True
75-
)
76-
# log
77-
adata.obsm["train_norm"] = scprep.utils.matrix_transform(
78-
adata.obsm["train_norm"], np.log1p
79-
)
80-
# to csr
81-
adata.obsm["train_norm"] = adata.obsm["train_norm"].tocsr()
82-
# run alra
83-
# _alra takes sparse array, returns dense array
84-
Y = None
85-
attempts = 0
86-
while Y is None:
87-
try:
88-
Y = _alra(adata)
89-
except rpy2.rinterface_lib.embedded.RRuntimeError: # pragma: no cover
90-
if attempts < 10:
91-
attempts += 1
92-
log.warning(f"alra.R failed (attempt {attempts})")
93-
else:
94-
raise
9582

96-
# transform back into original space
97-
Y = scprep.utils.matrix_transform(Y, np.expm1)
98-
Y = scprep.utils.matrix_vector_elementwise_multiply(Y, libsize, axis=0)
99-
adata.obsm["denoised"] = Y
83+
@_alra_method(
84+
method_name="ALRA (log norm, reversed normalization)",
85+
)
86+
def alra_log_reversenorm(adata, test=False):
87+
return _alra(adata, normtype="log", reverse_norm_order=True, test=False)
10088

101-
adata.uns["method_code_version"] = "1.0.0"
102-
return adata
89+
90+
@_alra_method(
91+
method_name="ALRA (sqrt norm)",
92+
)
93+
def alra_sqrt(adata, test=False):
94+
return _alra(adata, normtype="sqrt", reverse_norm_order=False, test=False)
95+
96+
97+
@_alra_method(
98+
method_name="ALRA (log norm)",
99+
)
100+
def alra_log(adata, test=False):
101+
return _alra(adata, normtype="log", reverse_norm_order=False, test=False)

0 commit comments

Comments
 (0)