Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rank features groups obs #622

Merged
merged 19 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions ehrapy/tools/feature_ranking/_rank_features_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import scanpy as sc
from anndata import AnnData

from ehrapy.anndata import move_to_x
from ehrapy.preprocessing import encode
from ehrapy.tools import _method_options


Expand Down Expand Up @@ -255,6 +257,7 @@ def rank_features_groups(
correction_method: _method_options._correction_method = "benjamini-hochberg",
tie_correct: bool = False,
layer: Optional[str] = None,
rank_obs_columns: Optional[Union[list[str], str]] = None,
eroell marked this conversation as resolved.
Show resolved Hide resolved
**kwds,
) -> None: # pragma: no cover
"""Rank features for characterizing groups.
Expand Down Expand Up @@ -288,6 +291,8 @@ def rank_features_groups(
Used only for statistical tests (e.g. doesn't work for "logreg" `num_cols_method`)
tie_correct: Use tie correction for `'wilcoxon'` scores. Used only for `'wilcoxon'`.
layer: Key from `adata.layers` whose value will be used to perform tests on.
rank_obs_columns: Whether to rank `adata.obs` columns instead of features in `adata.layer`. If `True`, all observation columns are ranked. If list of column names, only those are ranked.
layer needs to be None if this is used.
**kwds: Are passed to test methods. Currently this affects only parameters that
are passed to :class:`sklearn.linear_model.LogisticRegression`.
For instance, you can pass `penalty='l1'` to try to come up with a
Expand Down Expand Up @@ -320,8 +325,54 @@ def rank_features_groups(
>>> ep.tl.rank_features_groups(adata, "service_unit")
>>> ep.pl.rank_features_groups(adata)
"""
# if rank_obs_columns is indicated, layer must be None
eroell marked this conversation as resolved.
Show resolved Hide resolved
if layer is not None and rank_obs_columns is not None:
raise ValueError("Only one of layer and rank_obs_columns can be specified.")
eroell marked this conversation as resolved.
Show resolved Hide resolved

adata = adata.copy() if copy else adata

if rank_obs_columns is not None:
# keep reference to original adata, needed if copy=False
adata_orig = adata
# copy adata to work on
eroell marked this conversation as resolved.
Show resolved Hide resolved
adata = adata.copy()
eroell marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(rank_obs_columns, str):
if rank_obs_columns == "all":
rank_obs_columns = adata.obs.keys().to_list()
else:
raise ValueError(
f"rank_obs_columns should be 'all' or Iterable of column names, not {rank_obs_columns}."
eroell marked this conversation as resolved.
Show resolved Hide resolved
)

# consider adata where all columns from obs become the features, and the other features are dropped
if not all(elem in adata.obs.columns.values for elem in rank_obs_columns):
raise ValueError(
f"Columns `{[col for col in rank_obs_columns if col not in adata.obs.columns.values]}` are not in obs."
)

# if groupby in rank_obs_columns:
eroell marked this conversation as resolved.
Show resolved Hide resolved
# rank_obs_columns.remove(groupby)

# move obs columns to X
eroell marked this conversation as resolved.
Show resolved Hide resolved
adata_with_moved_columns = move_to_x(adata, rank_obs_columns)

# remove columns previously in X
eroell marked this conversation as resolved.
Show resolved Hide resolved
columns_to_select = adata_with_moved_columns.var_names.difference(adata.var_names)
adata_with_moved_columns = adata_with_moved_columns[:, columns_to_select]

# encode categoricals
eroell marked this conversation as resolved.
Show resolved Hide resolved
adata_with_moved_columns = encode(adata_with_moved_columns, autodetect=True, encodings="label")

# assign numeric and categorical columns
eroell marked this conversation as resolved.
Show resolved Hide resolved
adata_with_moved_columns.uns[
"non_numerical_columns"
] = [] # this should be empty, as have only numeric and encoded
adata_with_moved_columns.uns["numerical_columns"] = adata_with_moved_columns.var_names.difference(
adata_with_moved_columns.uns["encoded_non_numerical_columns"]
).to_list() # this is sensitive to `encode` really detecting what it should
adata = adata_with_moved_columns

if not adata.obs[groupby].dtype == "category":
adata.obs[groupby] = pd.Categorical(adata.obs[groupby])

Expand Down Expand Up @@ -409,9 +460,14 @@ def rank_features_groups(
adata.uns[key_added]["pvals"], corr_method=correction_method
)

# For some reason, pts should be a DataFrame
if "pts" in adata.uns[key_added]:
adata.uns[key_added]["pts"] = pd.DataFrame(adata.uns[key_added]["pts"])

_sort_features(adata, key_added)

if rank_obs_columns is not None:
adata_orig.uns[key_added] = adata.uns[key_added]
adata = adata_orig

return adata if copy else None
13 changes: 13 additions & 0 deletions tests/tools/test_data_features_ranking/dataset1.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
idx,sys_bp_entry,dia_bp_entry,glucose,weight,disease,station
1,138,78,80,77,A,ICU
2,139,79,90,76,A,ICU
3,140,80,120,60,A,MICU
4,141,81,130,90,A,MICU
5,148,77,80,110,B,ICU
6,149,78,130,78,B,ICU
7,150,79,120,56,B,MICU
8,151,80,90,76,B,MICU
9,158,55,80,67,C,ICU
10,159,56,90,82,C,ICU
11,160,57,120,59,C,MICU
12,161,58,130,81,C,MICU
63 changes: 63 additions & 0 deletions tests/tools/test_features_ranking.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from pathlib import Path

import numpy as np
import pandas as pd
import pytest

import ehrapy as ep
import ehrapy.tools.feature_ranking._rank_features_groups as _utils
from ehrapy.io._read import read_csv

CURRENT_DIR = Path(__file__).parent
_TEST_PATH = f"{CURRENT_DIR}/test_data_features_ranking"


class TestHelperFunctions:
Expand Down Expand Up @@ -270,3 +276,60 @@ def test_only_cat_features(self):
assert "scores" in adata.uns["rank_features_groups"]
assert "logfoldchanges" in adata.uns["rank_features_groups"]
assert "pvals_adj" in adata.uns["rank_features_groups"]

def test_rank_obs(
self,
):
# prepare data with some interesting features in .obs
adata_features_in_obs = read_csv(
dataset_path=f"{_TEST_PATH}/dataset1.csv",
columns_obs_only=["disease", "station", "sys_bp_entry", "dia_bp_entry"],
)

# prepare data with these features in .X
adata_features_in_x = read_csv(
dataset_path=f"{_TEST_PATH}/dataset1.csv", columns_x_only=["station", "sys_bp_entry", "dia_bp_entry"]
)
adata_features_in_x = ep.pp.encode(adata_features_in_x, encodings={"label": ["station"]})

# rank_features_groups on .obs
ep.tl.rank_features_groups(adata_features_in_obs, groupby="disease", rank_obs_columns="all")

# rank features groups on .X
ep.tl.rank_features_groups(adata_features_in_x, groupby="disease")

# check standard rank_features_groups entries
assert "names" in adata_features_in_obs.uns["rank_features_groups"]
assert "pvals" in adata_features_in_obs.uns["rank_features_groups"]
assert "scores" in adata_features_in_obs.uns["rank_features_groups"]
assert "pvals_adj" in adata_features_in_obs.uns["rank_features_groups"]
assert "log2foldchanges" not in adata_features_in_obs.uns["rank_features_groups"]
assert "pts" not in adata_features_in_obs.uns["rank_features_groups"]
assert (
len(adata_features_in_obs.uns["rank_features_groups"]["names"]) == 3
) # It only captures the length of each group
assert len(adata_features_in_obs.uns["rank_features_groups"]["pvals"]) == 3
assert len(adata_features_in_obs.uns["rank_features_groups"]["scores"]) == 3

# check the obs are used indeed
assert "sys_bp_entry" in adata_features_in_obs.uns["rank_features_groups"]["names"][0]
assert "sys_bp_entry" in adata_features_in_obs.uns["rank_features_groups"]["names"][1]
assert "ehrapycat_station" in adata_features_in_obs.uns["rank_features_groups"]["names"][2]

# check the X are not used
assert "glucose" not in adata_features_in_obs.uns["rank_features_groups"]["names"][0]

# check the results are the same
for record in adata_features_in_obs.uns["rank_features_groups"]["names"].dtype.names:
assert np.allclose(
adata_features_in_obs.uns["rank_features_groups"]["scores"][record],
adata_features_in_x.uns["rank_features_groups"]["scores"][record],
)
assert np.allclose(
np.array(adata_features_in_obs.uns["rank_features_groups"]["pvals"][record]),
np.array(adata_features_in_x.uns["rank_features_groups"]["pvals"][record]),
)
assert np.array_equal(
np.array(adata_features_in_obs.uns["rank_features_groups"]["names"][record]),
np.array(adata_features_in_x.uns["rank_features_groups"]["names"][record]),
)
Loading