Rank features groups obs (#622)
* tests for rank features groups with obs

* First draft of feature ranking using obs

* fixed encoding names

* remove comment

* Remove comment

Co-authored-by: Lukas Heumos <[email protected]>

* Remove comment

Co-authored-by: Lukas Heumos <[email protected]>

* Remove comment

Co-authored-by: Lukas Heumos <[email protected]>

* Update ehrapy/tools/feature_ranking/_rank_features_groups.py

Co-authored-by: Lukas Heumos <[email protected]>

* Remove comment

Co-authored-by: Lukas Heumos <[email protected]>

* Remove comment

Co-authored-by: Lukas Heumos <[email protected]>

* Update ehrapy/tools/feature_ranking/_rank_features_groups.py

Co-authored-by: Lukas Heumos <[email protected]>

* Iterable to list and import from future

* Updated to use layer, obs, or both

* this test data should be more stable

* Update ehrapy/tools/feature_ranking/_rank_features_groups.py

Co-authored-by: Lukas Heumos <[email protected]>

* Corrected indent of previous commit and added comment on dummy X

* bug fixes, more tests and (fixed) examples in docstring

---------

Co-authored-by: Lukas Heumos <[email protected]>
eroell and Zethson authored Dec 7, 2023
1 parent 8c771a3 commit 5447087
Showing 3 changed files with 340 additions and 7 deletions.
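
In short, ep.tl.rank_features_groups gains two arguments, field_to_rank and columns_to_rank, so that columns stored in .obs can be ranked as well, either alone or together with .X / a layer. A minimal usage sketch, mirroring the docstring examples added in this commit:

import ehrapy as ep

adata = ep.dt.mimic_2(encoded=False)
# Move some metadata from X into .obs so it can be ranked via the new obs path.
ep.anndata.move_to_obs(adata, to_obs=["service_unit", "service_num", "age", "mort_day_censored"])

# Rank only the selected .obs columns across the service_unit groups.
ep.tl.rank_features_groups(
    adata,
    "service_unit",
    field_to_rank="obs",
    columns_to_rank={"obs_names": ["age", "mort_day_censored"]},
)
ep.pl.rank_features_groups(adata)
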
177 changes: 170 additions & 7 deletions ehrapy/tools/feature_ranking/_rank_features_groups.py
@@ -1,3 +1,5 @@
from __future__ import annotations

from collections.abc import Iterable
from typing import Literal, Optional, Union

@@ -6,6 +8,8 @@
import scanpy as sc
from anndata import AnnData

from ehrapy.anndata import move_to_x
from ehrapy.preprocessing import encode
from ehrapy.tools import _method_options


@@ -240,6 +244,55 @@ def _evaluate_categorical_features(
    )


def _check_no_datetime_columns(df):
    datetime_cols = [
        col
        for col in df.columns
        if pd.api.types.is_datetime64_any_dtype(df[col]) or pd.api.types.is_timedelta64_dtype(df[col])
    ]
    if datetime_cols:
        raise ValueError(f"Columns with datetime format found: {datetime_cols}")


def _get_intersection(adata_uns, key, selection):
    """Get intersection of adata_uns[key] and selection"""
    if key in adata_uns:
        uns_enc_to_keep = list(set(adata_uns[key]) & set(selection))
    else:
        uns_enc_to_keep = []
    return uns_enc_to_keep


def _check_columns_to_rank_dict(columns_to_rank):
    if isinstance(columns_to_rank, str):
        if columns_to_rank == "all":
            _var_subset = _obs_subset = False
        else:
            raise ValueError("If columns_to_rank is a string, it must be 'all'.")

    elif isinstance(columns_to_rank, dict):
        allowed_keys = {"var_names", "obs_names"}
        for key in columns_to_rank.keys():
            if key not in allowed_keys:
                raise ValueError(
                    f"columns_to_rank dictionary must have only keys 'var_names' and/or 'obs_names', not {key}."
                )
            if not isinstance(key, str):
                raise ValueError(f"columns_to_rank dictionary keys must be strings, not {type(key)}.")

        for key, value in columns_to_rank.items():
            if not isinstance(value, Iterable) or any(not isinstance(item, str) for item in value):
                raise ValueError(f"The value associated with key '{key}' must be an iterable of strings.")

        _var_subset = "var_names" in columns_to_rank.keys()
        _obs_subset = "obs_names" in columns_to_rank.keys()

    else:
        raise ValueError("columns_to_rank must be either 'all' or a dictionary.")

    return _var_subset, _obs_subset


def rank_features_groups(
    adata: AnnData,
    groupby: str,
@@ -255,6 +308,8 @@
    correction_method: _method_options._correction_method = "benjamini-hochberg",
    tie_correct: bool = False,
    layer: Optional[str] = None,
    field_to_rank: Union[Literal["layer"], Literal["obs"], Literal["layer_and_obs"]] = "layer",
    columns_to_rank: Union[dict[str, Iterable[str]], Literal["all"]] = "all",
    **kwds,
) -> None:  # pragma: no cover
    """Rank features for characterizing groups.
@@ -288,40 +343,143 @@ def rank_features_groups(
            Used only for statistical tests (e.g. doesn't work for "logreg" `num_cols_method`)
        tie_correct: Use tie correction for `'wilcoxon'` scores. Used only for `'wilcoxon'`.
        layer: Key from `adata.layers` whose value will be used to perform tests on.
        field_to_rank: Set to `layer` to rank variables in `adata.X` or `adata.layers[layer]` (default), `obs` to rank `adata.obs`, or `layer_and_obs` to rank both. Layer needs to be None if this is not 'layer'.
        columns_to_rank: Subset of columns to rank. If 'all', all columns are used. If a dictionary, it must have keys 'var_names' and/or 'obs_names' and values must be iterables of strings. E.g. {'var_names': ['glucose'], 'obs_names': ['age', 'height']}.
        **kwds: Are passed to test methods. Currently this affects only parameters that
            are passed to :class:`sklearn.linear_model.LogisticRegression`.
            For instance, you can pass `penalty='l1'` to try to come up with a
            minimal set of genes that are good predictors (sparse solution meaning few non-zero fitted coefficients).
    Returns:
        *names*: structured `np.ndarray` (`.uns['rank_features_groups']`)
        *names* structured `np.ndarray` (`.uns['rank_features_groups']`)
            Structured array to be indexed by group id storing the gene
            names. Ordered according to scores.
        *scores*: structured `np.ndarray` (`.uns['rank_features_groups']`)
        *scores* structured `np.ndarray` (`.uns['rank_features_groups']`)
            Structured array to be indexed by group id storing the z-score
            underlying the computation of a p-value for each gene for each group.
            Ordered according to scores.
        *logfoldchanges*: structured `np.ndarray` (`.uns['rank_features_groups']`)
        *logfoldchanges* structured `np.ndarray` (`.uns['rank_features_groups']`)
            Structured array to be indexed by group id storing the log2
            fold change for each gene for each group. Ordered according to scores.
            Only provided if method is 't-test' like.
            Note: this is an approximation calculated from mean-log values.
        *pvals*: structured `np.ndarray` (`.uns['rank_features_groups']`) p-values.
        *pvals_adj* : structured `np.ndarray` (`.uns['rank_features_groups']`) Corrected p-values.
        *pvals* structured `np.ndarray` (`.uns['rank_features_groups']`) p-values.
        *pvals_adj* structured `np.ndarray` (`.uns['rank_features_groups']`) Corrected p-values.
        *pts*: `pandas.DataFrame` (`.uns['rank_features_groups']`)
            Fraction of cells expressing the genes for each group.
        *pts_rest*: `pandas.DataFrame` (`.uns['rank_features_groups']`)
        *pts_rest* `pandas.DataFrame` (`.uns['rank_features_groups']`)
            Only if `reference` is set to `'rest'`.
            Fraction of observations from the union of the rest of each group containing the features.
    Examples:
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=True)
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> # want to move some metadata to the obs field
        >>> ep.anndata.move_to_obs(adata, to_obs=["service_unit", "service_num", "age", "mort_day_censored"])
        >>> ep.tl.rank_features_groups(adata, "service_unit")
        >>> ep.pl.rank_features_groups(adata)
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> # want to move some metadata to the obs field
        >>> ep.anndata.move_to_obs(adata, to_obs=["service_unit", "service_num", "age", "mort_day_censored"])
        >>> ep.tl.rank_features_groups(adata, "service_unit", field_to_rank="obs", columns_to_rank={"obs_names": ["age", "mort_day_censored"]})
        >>> ep.pl.rank_features_groups(adata)
        >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=False)
        >>> # want to move some metadata to the obs field
        >>> ep.anndata.move_to_obs(adata, to_obs=["service_unit", "service_num", "age", "mort_day_censored"])
        >>> ep.tl.rank_features_groups(adata, "service_unit", field_to_rank="layer_and_obs", columns_to_rank={"var_names": ['copd_flg', 'renal_flg'], "obs_names": ["age", "mort_day_censored"]})
        >>> ep.pl.rank_features_groups(adata)
    """
    if layer is not None and field_to_rank == "obs":
        raise ValueError("If 'layer' is not None, 'field_to_rank' cannot be 'obs'.")

    if field_to_rank not in ["layer", "obs", "layer_and_obs"]:
        raise ValueError(f"field_to_rank must be one of 'layer', 'obs', 'layer_and_obs', not {field_to_rank}")

    # to give better error messages, check if columns_to_rank has valid keys and values here
    _var_subset, _obs_subset = _check_columns_to_rank_dict(columns_to_rank)

    adata = adata.copy() if copy else adata

    # to create a minimal adata object below, grab a reference to X/layer of the original adata,
    # subsetted to the specified columns
    if field_to_rank in ["layer", "layer_and_obs"]:
        # for some reason ruff insists on this type check. columns_to_rank is always a dict with key "var_names" if _var_subset is True
        if _var_subset and isinstance(columns_to_rank, dict):
            X_to_keep = (
                adata[:, columns_to_rank["var_names"]].X
                if layer is None
                else adata[:, columns_to_rank["var_names"]].layers[layer]
            )
            var_to_keep = adata[:, columns_to_rank["var_names"]].var
            uns_num_to_keep = _get_intersection(
                adata_uns=adata.uns, key="numerical_columns", selection=columns_to_rank["var_names"]
            )
            uns_non_num_to_keep = _get_intersection(
                adata_uns=adata.uns, key="non_numerical_columns", selection=columns_to_rank["var_names"]
            )
            uns_enc_to_keep = _get_intersection(
                adata_uns=adata.uns, key="encoded_non_numerical_columns", selection=columns_to_rank["var_names"]
            )

        else:
            X_to_keep = adata.X if layer is None else adata.layers[layer]
            var_to_keep = adata.var
            uns_num_to_keep = adata.uns["numerical_columns"] if "numerical_columns" in adata.uns else []
            uns_enc_to_keep = (
                adata.uns["encoded_non_numerical_columns"] if "encoded_non_numerical_columns" in adata.uns else []
            )
            uns_non_num_to_keep = adata.uns["non_numerical_columns"] if "non_numerical_columns" in adata.uns else []

    else:
        # dummy 1-dimensional X to be used by move_to_x, and removed again afterwards
        X_to_keep = np.zeros((len(adata), 1))
        var_to_keep = pd.DataFrame({"dummy": [0]})
        uns_num_to_keep = []
        uns_enc_to_keep = []
        uns_non_num_to_keep = []

    adata_minimal = sc.AnnData(
        X=X_to_keep,
        obs=adata.obs,
        var=var_to_keep,
        uns={
            "numerical_columns": uns_num_to_keep,
            "encoded_non_numerical_columns": uns_enc_to_keep,
            "non_numerical_columns": uns_non_num_to_keep,
        },
    )

    if field_to_rank in ["obs", "layer_and_obs"]:
        # want columns of obs to become variables in X to be able to use rank_features_groups
        # for some reason ruff insists on this type check. columns_to_rank is always a dict with key "obs_names" if _obs_subset is True
        if _obs_subset and isinstance(columns_to_rank, dict):
            obs_to_move = adata.obs[columns_to_rank["obs_names"]].keys()
        else:
            obs_to_move = adata.obs.keys()
        _check_no_datetime_columns(adata.obs[obs_to_move])
        adata_minimal = move_to_x(adata_minimal, list(obs_to_move))

        if field_to_rank == "obs":
            # the 0th column is a dummy of zeros and is meaningless in this case, and needs to be removed
            adata_minimal = adata_minimal[:, 1:]

        adata_minimal = encode(adata_minimal, autodetect=True, encodings="label")
        # this is needed because encode() doesn't add this key if there are no categorical columns to encode
        if "encoded_non_numerical_columns" not in adata_minimal.uns:
            adata_minimal.uns["encoded_non_numerical_columns"] = []

    if layer is not None:
        adata_minimal.layers[layer] = adata_minimal.X

    # save the reference to the original adata, because we will need to access it later
    adata_orig = adata
    adata = adata_minimal

    if not adata.obs[groupby].dtype == "category":
        adata.obs[groupby] = pd.Categorical(adata.obs[groupby])

@@ -403,12 +561,17 @@ def rank_features_groups(
        groups_order=group_names,
    )

    # if field_to_rank was obs or layer_and_obs, the adata object we have been working with is adata_minimal
    adata_orig.uns[key_added] = adata.uns[key_added]
    adata = adata_orig

    # Adjust p values
    if "pvals" in adata.uns[key_added]:
        adata.uns[key_added]["pvals_adj"] = _adjust_pvalues(
            adata.uns[key_added]["pvals"], corr_method=correction_method
        )

    # For some reason, pts should be a DataFrame
    if "pts" in adata.uns[key_added]:
        adata.uns[key_added]["pts"] = pd.DataFrame(adata.uns[key_added]["pts"])

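Taken together, the new code path validates columns_to_rank, builds a minimal AnnData, optionally moves the requested obs columns into X, label-encodes them, and finally copies the result back into the original object's .uns. A sketch of the combined layer-and-obs usage plus the up-front validation, reusing the MIMIC-II example from the docstring; the printed error text is the one raised by _check_columns_to_rank_dict above:

import ehrapy as ep

adata = ep.dt.mimic_2(encoded=False)
ep.anndata.move_to_obs(adata, to_obs=["service_unit", "service_num", "age", "mort_day_censored"])

# Rank a subset of X columns together with a subset of obs columns in one call.
ep.tl.rank_features_groups(
    adata,
    "service_unit",
    field_to_rank="layer_and_obs",
    columns_to_rank={"var_names": ["copd_flg", "renal_flg"], "obs_names": ["age", "mort_day_censored"]},
)

# Invalid keys are rejected early by _check_columns_to_rank_dict.
try:
    ep.tl.rank_features_groups(adata, "service_unit", columns_to_rank={"columns": ["age"]})
except ValueError as err:
    print(err)  # columns_to_rank dictionary must have only keys 'var_names' and/or 'obs_names', ...
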
13 changes: 13 additions & 0 deletions tests/tools/test_data_features_ranking/dataset1.csv
@@ -0,0 +1,13 @@
idx,sys_bp_entry,dia_bp_entry,glucose,weight,disease,station
1,138,78,80,77,A,ICU
2,139,79,90,76,A,ICU
3,140,80,120,60,A,MICU
4,141,81,130,90,A,MICU
5,148,77,80,110,B,ICU
6,149,78,135,78,B,ICU
7,150,79,125,56,B,MICU
8,151,80,95,76,B,MICU
9,158,55,70,67,C,ICU
10,159,56,85,82,C,ICU
11,160,57,125,59,C,MICU
12,161,58,125,81,C,MICU
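
This small fixture pairs four numeric measurements with two categorical columns per patient. A rough sketch of how it could be fed through the new obs-based ranking, assuming ehrapy's CSV reader ep.io.read_csv and its index_column / columns_obs_only arguments, which are not part of this diff:

import ehrapy as ep

# Hypothetical setup of the fixture; the tests added in this commit may set it up differently.
adata = ep.io.read_csv(
    "tests/tools/test_data_features_ranking/dataset1.csv",
    index_column="idx",  # assumed parameter name
    columns_obs_only=["disease", "station"],  # assumed parameter name
)

# Rank the categorical "station" column across the three disease groups.
ep.tl.rank_features_groups(
    adata,
    "disease",
    field_to_rank="obs",
    columns_to_rank={"obs_names": ["station"]},
)
print(adata.uns["rank_features_groups"]["names"])
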