Skip to content

Commit 49c83bf

Browse files
danielStroblgithub-actions[bot]scottgigante-immunai
authored
Daniel strobl hvg conservation fix (#785)
* hvg conservation metric fix * pre-commit * Allow for uppercase repo owner * Fix sklearn req * bash not sh * bugfix use index * add to api * pre-commit * list instead of index * check number of genes * pre-commit * addressing comments * pre-commit * shorten line * addressing comments * pre-commit * fix checks * remove magic numbers * pre-commit * int -> numbers.Integral * Fix typo * fix dataset size assumption and duck-type hvg_unint --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Scott Gigante <[email protected]> Co-authored-by: Scott Gigante <[email protected]> Former-commit-id: 814fedc
1 parent c5e6f56 commit 49c83bf

File tree

7 files changed

+71
-6
lines changed

7 files changed

+71
-6
lines changed

openproblems/tasks/_batch_integration/_common/api.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from ....data.sample import load_sample_data
22
from ....tools.decorators import dataset
33
from .utils import filter_celltypes
4+
from .utils import precompute_hvg
45

6+
import numbers
57
import numpy as np
68

79
MIN_CELLS_PER_CELLTYPE = 50
10+
N_HVG_UNINT = 2000
811

912

1013
def check_neighbors(adata, neighbors_key, connectivities_key, distances_key):
@@ -15,7 +18,12 @@ def check_neighbors(adata, neighbors_key, connectivities_key, distances_key):
1518
assert distances_key in adata.obsp
1619

1720

18-
def check_dataset(adata, do_check_pca=False, do_check_neighbors=False):
21+
def check_dataset(
22+
adata,
23+
do_check_pca=False,
24+
do_check_neighbors=False,
25+
do_check_hvg=False,
26+
):
1927
"""Check that dataset output fits expected API."""
2028

2129
assert "batch" in adata.obs
@@ -28,12 +36,21 @@ def check_dataset(adata, do_check_pca=False, do_check_neighbors=False):
2836
assert adata.var_names.is_unique
2937
assert adata.obs_names.is_unique
3038

39+
assert "n_genes_pre" in adata.uns
40+
assert isinstance(adata.uns["n_genes_pre"], numbers.Integral)
41+
assert adata.uns["n_genes_pre"] == adata.n_vars
42+
3143
assert "organism" in adata.uns
3244
assert adata.uns["organism"] in ["mouse", "human"]
3345

3446
if do_check_pca:
3547
assert "X_uni_pca" in adata.obsm
3648

49+
if do_check_hvg:
50+
assert "hvg_unint" in adata.uns
51+
assert len(adata.uns["hvg_unint"]) == min(N_HVG_UNINT, adata.n_vars)
52+
assert np.all(np.isin(adata.uns["hvg_unint"], adata.var.index))
53+
3754
if do_check_neighbors:
3855
check_neighbors(adata, "uni", "uni_connectivities", "uni_distances")
3956

@@ -58,6 +75,10 @@ def sample_dataset(run_pca: bool = False, run_neighbors: bool = False):
5875
adata.obs["batch"] = np.random.choice(2, adata.shape[0], replace=True).astype(str)
5976
adata.obs["labels"] = np.random.choice(3, adata.shape[0], replace=True).astype(str)
6077
adata = filter_celltypes(adata)
78+
79+
adata.uns["hvg_unint"] = precompute_hvg(adata)
80+
adata.uns["n_genes_pre"] = adata.n_vars
81+
6182
if run_pca:
6283
adata.obsm["X_uni_pca"] = sc.pp.pca(adata.X)
6384
if run_neighbors:

openproblems/tasks/_batch_integration/_common/datasets/immune.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .....data.immune_cells import load_immune
22
from .....tools.decorators import dataset
33
from ..utils import filter_celltypes
4+
from ..utils import precompute_hvg
45
from typing import Optional
56

67

@@ -13,7 +14,11 @@
1314
"Smart-seq2).",
1415
image="openproblems",
1516
)
16-
def immune_batch(test: bool = False, min_celltype_count: Optional[int] = None):
17+
def immune_batch(
18+
test: bool = False,
19+
min_celltype_count: Optional[int] = None,
20+
n_hvg: Optional[int] = None,
21+
):
1722
import scanpy as sc
1823

1924
adata = load_immune(test)
@@ -38,4 +43,7 @@ def immune_batch(test: bool = False, min_celltype_count: Optional[int] = None):
3843

3944
sc.pp.neighbors(adata, use_rep="X_uni_pca", key_added="uni")
4045
adata.var_names_make_unique()
46+
47+
adata.uns["hvg_unint"] = precompute_hvg(adata, n_genes=n_hvg)
48+
adata.uns["n_genes_pre"] = adata.n_vars
4149
return adata

openproblems/tasks/_batch_integration/_common/datasets/pancreas.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .....data.pancreas import load_pancreas
22
from .....tools.decorators import dataset
33
from ..utils import filter_celltypes
4+
from ..utils import precompute_hvg
45
from typing import Optional
56

67

@@ -13,7 +14,11 @@
1314
"and SMARTER-seq).",
1415
image="openproblems",
1516
)
16-
def pancreas_batch(test: bool = False, min_celltype_count: Optional[int] = None):
17+
def pancreas_batch(
18+
test: bool = False,
19+
min_celltype_count: Optional[int] = None,
20+
n_hvg: Optional[int] = None,
21+
):
1722
import scanpy as sc
1823

1924
adata = load_pancreas(test)
@@ -38,4 +43,7 @@ def pancreas_batch(test: bool = False, min_celltype_count: Optional[int] = None)
3843
sc.pp.neighbors(adata, use_rep="X_uni_pca", key_added="uni")
3944

4045
adata.var_names_make_unique()
46+
47+
adata.uns["hvg_unint"] = precompute_hvg(adata, n_genes=n_hvg)
48+
adata.uns["n_genes_pre"] = adata.n_vars
4149
return adata
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,27 @@
1+
from . import api
2+
from scanpy.pp import highly_variable_genes
13
from typing import Optional
24

35

46
def filter_celltypes(adata, min_celltype_count: Optional[int] = None):
57

6-
min_celltype_count = min_celltype_count or 50
8+
min_celltype_count = min_celltype_count or api.MIN_CELLS_PER_CELLTYPE
79

810
celltype_counts = adata.obs["labels"].value_counts()
911
keep_celltypes = celltype_counts[celltype_counts >= min_celltype_count].index
1012
keep_cells = adata.obs["labels"].isin(keep_celltypes)
1113
return adata[keep_cells].copy()
14+
15+
16+
def precompute_hvg(adata, n_genes: Optional[int] = None):
17+
18+
n_genes = n_genes or api.N_HVG_UNINT
19+
hvg_unint = highly_variable_genes(
20+
adata,
21+
n_top_genes=n_genes,
22+
layer="log_normalized",
23+
flavor="cell_ranger",
24+
batch_key="batch",
25+
inplace=False,
26+
)
27+
return list(hvg_unint[hvg_unint.highly_variable].index)

openproblems/tasks/_batch_integration/batch_integration_feature/README.md

+4
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,12 @@ Datasets should contain the following attributes:
4646
* `adata.layers['counts']` with raw, integer UMI count data,
4747
* `adata.layers['log_normalized']` with log-normalized data and
4848
* `adata.X` with log-normalized data
49+
* `adata.uns['n_genes_pre']` with the number of genes present before integration
50+
* `adata.uns['hvg_unint']` with a list of 2000 highly variable genes
51+
prior to integration (for the hvg conservation metric)
4952

5053
Methods should store their a batch-corrected gene expression matrix in `adata.X`.
54+
The output should should contain at least 2000 features.
5155

5256
The `openproblems-python-batch-integration` docker container is used for the methods
5357
that

openproblems/tasks/_batch_integration/batch_integration_feature/api.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,19 @@
33

44
import functools
55

6-
check_dataset = functools.partial(api.check_dataset, do_check_pca=True)
6+
check_dataset = functools.partial(
7+
api.check_dataset, do_check_hvg=True, do_check_pca=True
8+
)
79

810

911
def check_method(adata, is_baseline=False):
1012
"""Check that method output fits expected API."""
1113
assert "log_normalized" in adata.layers
14+
# check hvg_unint is still there
15+
assert "hvg_unint" in adata.uns
16+
# check n_vars is not too small
17+
assert "n_genes_pre" in adata.uns
18+
assert adata.n_vars >= min(api.N_HVG_UNINT, adata.uns["n_genes_pre"])
1219
if not is_baseline:
1320
assert adata.layers["log_normalized"] is not adata.X
1421
return True

openproblems/tasks/_batch_integration/batch_integration_feature/metrics/hvg_conservation.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,6 @@ def hvg_conservation(adata):
2929

3030
adata_unint = adata.copy()
3131
adata_unint.X = adata_unint.layers["log_normalized"]
32+
hvg_both = list(set(adata.uns["hvg_unint"]).intersection(adata.var_names))
3233

33-
return hvg_overlap(adata_unint, adata, "batch")
34+
return hvg_overlap(adata_unint, adata[:, hvg_both], "batch")

0 commit comments

Comments
 (0)