Skip to content

Commit

Permalink
bug fixes, more tests and (fixed) examples in docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
eroell committed Dec 7, 2023
1 parent 6458265 commit f444a59
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 8 deletions.
36 changes: 28 additions & 8 deletions ehrapy/tools/feature_ranking/_rank_features_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def _check_no_datetime_columns(df):
def _get_intersection(adata_uns, key, selection):
"""Get intersection of adata_uns[key] and selection"""
if key in adata_uns:
uns_enc_to_keep = list(set(adata_uns["encoded_non_numerical_columns"]) & set(selection))
uns_enc_to_keep = list(set(adata_uns[key]) & set(selection))
else:
uns_enc_to_keep = []
return uns_enc_to_keep
Expand Down Expand Up @@ -351,31 +351,48 @@ def rank_features_groups(
minimal set of genes that are good predictors (sparse solution meaning few non-zero fitted coefficients).
Returns:
*names*: structured `np.ndarray` (`.uns['rank_features_groups']`)
*names* structured `np.ndarray` (`.uns['rank_features_groups']`)
Structured array to be indexed by group id storing the gene
names. Ordered according to scores.
*scores*: structured `np.ndarray` (`.uns['rank_features_groups']`)
*scores* structured `np.ndarray` (`.uns['rank_features_groups']`)
Structured array to be indexed by group id storing the z-score
underlying the computation of a p-value for each gene for each group.
Ordered according to scores.
*logfoldchanges*: structured `np.ndarray` (`.uns['rank_features_groups']`)
*logfoldchanges* structured `np.ndarray` (`.uns['rank_features_groups']`)
Structured array to be indexed by group id storing the log2
fold change for each gene for each group. Ordered according to scores.
Only provided if method is 't-test' like.
Note: this is an approximation calculated from mean-log values.
*pvals*: structured `np.ndarray` (`.uns['rank_features_groups']`) p-values.
*pvals_adj* : structured `np.ndarray` (`.uns['rank_features_groups']`) Corrected p-values.
*pvals* structured `np.ndarray` (`.uns['rank_features_groups']`) p-values.
*pvals_adj* structured `np.ndarray` (`.uns['rank_features_groups']`) Corrected p-values.
*pts*: `pandas.DataFrame` (`.uns['rank_features_groups']`)
Fraction of cells expressing the genes for each group.
*pts_rest*: `pandas.DataFrame` (`.uns['rank_features_groups']`)
*pts_rest* `pandas.DataFrame` (`.uns['rank_features_groups']`)
Only if `reference` is set to `'rest'`.
Fraction of observations from the union of the rest of each group containing the features.
Examples:
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=True)
>>> adata = ep.dt.mimic_2(encoded=False)
>>> # want to move some metadata to the obs field
>>> ep.anndata.move_to_obs(adata, to_obs=["service_unit", "service_num", "age", "mort_day_censored"])
>>> ep.tl.rank_features_groups(adata, "service_unit")
>>> ep.pl.rank_features_groups(adata)
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=False)
>>> # want to move some metadata to the obs field
>>> ep.anndata.move_to_obs(adata, to_obs=["service_unit", "service_num", "age", "mort_day_censored"])
>>> ep.tl.rank_features_groups(adata, "service_unit", field_to_rank="obs", columns_to_rank={"obs_names": ["age", "mort_day_censored"]})
>>> ep.pl.rank_features_groups(adata)
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=False)
>>> # want to move some metadata to the obs field
>>> ep.anndata.move_to_obs(adata, to_obs=["service_unit", "service_num", "age", "mort_day_censored"])
>>> ep.tl.rank_features_groups(adata, "service_unit", field_to_rank="layer_and_obs", columns_to_rank={"var_names": ['copd_flg', 'renal_flg'], "obs_names": ["age", "mort_day_censored"]})
>>> ep.pl.rank_features_groups(adata)
"""
if layer is not None and field_to_rank == "obs":
raise ValueError("If 'layer' is not None, 'field_to_rank' cannot be 'obs'.")
Expand Down Expand Up @@ -452,6 +469,9 @@ def rank_features_groups(
adata_minimal = adata_minimal[:, 1:]

adata_minimal = encode(adata_minimal, autodetect=True, encodings="label")
# this is needed because encode() doesn't add this key if there are no categorical columns to encode
if "encoded_non_numerical_columns" not in adata_minimal.uns:
adata_minimal.uns["encoded_non_numerical_columns"] = []

if layer is not None:
adata_minimal.layers[layer] = adata_minimal.X
Expand Down
43 changes: 43 additions & 0 deletions tests/tools/test_features_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,46 @@ def test_rank_features_groups_consistent_results(self):
np.array(adata_features_in_x.uns["rank_features_groups"]["names"][record]),
np.array(adata_features_in_x_and_obs.uns["rank_features_groups"]["names"][record]),
)

def test_rank_features_group_column_to_rank(self):
adata = read_csv(
dataset_path=f"{_TEST_PATH}/dataset1.csv",
columns_obs_only=["disease", "station", "sys_bp_entry", "dia_bp_entry"],
index_column="idx",
)

# get a fresh adata for every test to not have any side effects
adata_copy = adata.copy()

ep.tl.rank_features_groups(adata, groupby="disease", columns_to_rank="all")
assert len(adata.uns["rank_features_groups"]["names"]) == 2

# want to check a "complete selection" works
adata = adata_copy.copy()
ep.tl.rank_features_groups(adata, groupby="disease", columns_to_rank={"var_names": ["glucose", "weight"]})
assert len(adata.uns["rank_features_groups"]["names"]) == 2

# want to check a "sub-selection" works
adata = adata_copy.copy()
ep.tl.rank_features_groups(adata, groupby="disease", columns_to_rank={"var_names": ["glucose"]})
assert len(adata.uns["rank_features_groups"]["names"]) == 1

# want to check a "complete" selection works
adata = adata_copy.copy()
ep.tl.rank_features_groups(
adata,
groupby="disease",
field_to_rank="obs",
columns_to_rank={"obs_names": ["station", "sys_bp_entry", "dia_bp_entry"]},
)
assert len(adata.uns["rank_features_groups"]["names"]) == 3

# want to check a "sub-selection" selection works
adata = adata_copy.copy()
ep.tl.rank_features_groups(
adata,
groupby="disease",
field_to_rank="obs",
columns_to_rank={"obs_names": ["sys_bp_entry", "dia_bp_entry"]},
)
assert len(adata.uns["rank_features_groups"]["names"]) == 2

0 comments on commit f444a59

Please sign in to comment.