Skip to content

Commit c8fa71d

Browse files
flying-sheepsshen8
andauthored
Backport PR #2546: Fix getting log1p base (#2549)
Co-authored-by: Simon P Shen <[email protected]>
1 parent c40d2ee commit c8fa71d

File tree

3 files changed

+18
-5
lines changed

3 files changed

+18
-5
lines changed

scanpy/preprocessing/_highly_variable_genes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def _highly_variable_genes_single_batch(
194194
"""
195195
X = adata.layers[layer] if layer is not None else adata.X
196196
if flavor == 'seurat':
197-
if 'log1p' in adata.uns_keys() and adata.uns['log1p']['base'] is not None:
197+
if 'log1p' in adata.uns_keys() and adata.uns['log1p'].get('base') is not None:
198198
X *= np.log(adata.uns['log1p']['base'])
199199
X = np.expm1(X)
200200

scanpy/tests/test_rank_genes_groups.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,21 @@ def test_emptycat():
248248
rank_genes_groups(pbmc, groupby='louvain')
249249

250250

251+
def test_log1p_save_restore(tmp_path):
252+
"""tests the sequence log1p→save→load→rank_genes_groups"""
253+
from anndata import read
254+
255+
pbmc = pbmc68k_reduced()
256+
sc.pp.log1p(pbmc)
257+
258+
path = tmp_path / 'test.h5ad'
259+
pbmc.write(path)
260+
261+
pbmc = read(path)
262+
263+
sc.tl.rank_genes_groups(pbmc, groupby='bulk_labels', use_raw=True)
264+
265+
251266
def test_wilcoxon_symmetry():
252267
pbmc = pbmc68k_reduced()
253268

scanpy/tools/_rank_genes_groups.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,7 @@ def __init__(
8888
layer=None,
8989
comp_pts=False,
9090
):
91-
92-
if 'log1p' in adata.uns_keys() and adata.uns['log1p']['base'] is not None:
91+
if 'log1p' in adata.uns_keys() and adata.uns['log1p'].get('base') is not None:
9392
self.expm1_func = lambda x: np.expm1(x * np.log(adata.uns['log1p']['base']))
9493
else:
9594
self.expm1_func = np.expm1
@@ -362,7 +361,6 @@ def compute_statistics(
362361
tie_correct=False,
363362
**kwds,
364363
):
365-
366364
if method in {'t-test', 't-test_overestim_var'}:
367365
generate_test_results = self.t_test(method)
368366
elif method == 'wilcoxon':
@@ -753,7 +751,7 @@ def filter_rank_genes_groups(
753751
index=gene_names.index,
754752
)
755753

756-
if 'log1p' in adata.uns_keys() and adata.uns['log1p']['base'] is not None:
754+
if 'log1p' in adata.uns_keys() and adata.uns['log1p'].get('base') is not None:
757755
expm1_func = lambda x: np.expm1(x * np.log(adata.uns['log1p']['base']))
758756
else:
759757
expm1_func = np.expm1

0 commit comments

Comments
 (0)