diff --git a/ehrapy/tools/feature_ranking/_rank_features_groups.py b/ehrapy/tools/feature_ranking/_rank_features_groups.py index 4fe8d52e..5625d95f 100644 --- a/ehrapy/tools/feature_ranking/_rank_features_groups.py +++ b/ehrapy/tools/feature_ranking/_rank_features_groups.py @@ -257,7 +257,7 @@ def _check_no_datetime_columns(df): def _get_intersection(adata_uns, key, selection): """Get intersection of adata_uns[key] and selection""" if key in adata_uns: - uns_enc_to_keep = list(set(adata_uns["encoded_non_numerical_columns"]) & set(selection)) + uns_enc_to_keep = list(set(adata_uns[key]) & set(selection)) else: uns_enc_to_keep = [] return uns_enc_to_keep @@ -351,31 +351,48 @@ def rank_features_groups( minimal set of genes that are good predictors (sparse solution meaning few non-zero fitted coefficients). Returns: - *names*: structured `np.ndarray` (`.uns['rank_features_groups']`) + *names* structured `np.ndarray` (`.uns['rank_features_groups']`) Structured array to be indexed by group id storing the gene names. Ordered according to scores. - *scores*: structured `np.ndarray` (`.uns['rank_features_groups']`) + *scores* structured `np.ndarray` (`.uns['rank_features_groups']`) Structured array to be indexed by group id storing the z-score underlying the computation of a p-value for each gene for each group. Ordered according to scores. - *logfoldchanges*: structured `np.ndarray` (`.uns['rank_features_groups']`) + *logfoldchanges* structured `np.ndarray` (`.uns['rank_features_groups']`) Structured array to be indexed by group id storing the log2 fold change for each gene for each group. Ordered according to scores. Only provided if method is 't-test' like. Note: this is an approximation calculated from mean-log values. - *pvals*: structured `np.ndarray` (`.uns['rank_features_groups']`) p-values. - *pvals_adj* : structured `np.ndarray` (`.uns['rank_features_groups']`) Corrected p-values. + *pvals* structured `np.ndarray` (`.uns['rank_features_groups']`) p-values. + *pvals_adj* structured `np.ndarray` (`.uns['rank_features_groups']`) Corrected p-values. *pts*: `pandas.DataFrame` (`.uns['rank_features_groups']`) Fraction of cells expressing the genes for each group. - *pts_rest*: `pandas.DataFrame` (`.uns['rank_features_groups']`) + *pts_rest* `pandas.DataFrame` (`.uns['rank_features_groups']`) Only if `reference` is set to `'rest'`. Fraction of observations from the union of the rest of each group containing the features. Examples: >>> import ehrapy as ep - >>> adata = ep.dt.mimic_2(encoded=True) + >>> adata = ep.dt.mimic_2(encoded=False) + >>> # want to move some metadata to the obs field + >>> ep.anndata.move_to_obs(adata, to_obs=["service_unit", "service_num", "age", "mort_day_censored"]) >>> ep.tl.rank_features_groups(adata, "service_unit") >>> ep.pl.rank_features_groups(adata) + + >>> import ehrapy as ep + >>> adata = ep.dt.mimic_2(encoded=False) + >>> # want to move some metadata to the obs field + >>> ep.anndata.move_to_obs(adata, to_obs=["service_unit", "service_num", "age", "mort_day_censored"]) + >>> ep.tl.rank_features_groups(adata, "service_unit", field_to_rank="obs", columns_to_rank={"obs_names": ["age", "mort_day_censored"]}) + >>> ep.pl.rank_features_groups(adata) + + >>> import ehrapy as ep + >>> adata = ep.dt.mimic_2(encoded=False) + >>> # want to move some metadata to the obs field + >>> ep.anndata.move_to_obs(adata, to_obs=["service_unit", "service_num", "age", "mort_day_censored"]) + >>> ep.tl.rank_features_groups(adata, "service_unit", field_to_rank="layer_and_obs", columns_to_rank={"var_names": ['copd_flg', 'renal_flg'], "obs_names": ["age", "mort_day_censored"]}) + >>> ep.pl.rank_features_groups(adata) + """ if layer is not None and field_to_rank == "obs": raise ValueError("If 'layer' is not None, 'field_to_rank' cannot be 'obs'.") @@ -452,6 +469,9 @@ def rank_features_groups( adata_minimal = adata_minimal[:, 1:] adata_minimal = encode(adata_minimal, autodetect=True, encodings="label") + # this is needed because encode() doesn't add this key if there are no categorical columns to encode + if "encoded_non_numerical_columns" not in adata_minimal.uns: + adata_minimal.uns["encoded_non_numerical_columns"] = [] if layer is not None: adata_minimal.layers[layer] = adata_minimal.X diff --git a/tests/tools/test_features_ranking.py b/tests/tools/test_features_ranking.py index 8e10ead9..cc9c3c22 100644 --- a/tests/tools/test_features_ranking.py +++ b/tests/tools/test_features_ranking.py @@ -384,3 +384,46 @@ def test_rank_features_groups_consistent_results(self): np.array(adata_features_in_x.uns["rank_features_groups"]["names"][record]), np.array(adata_features_in_x_and_obs.uns["rank_features_groups"]["names"][record]), ) + + def test_rank_features_group_column_to_rank(self): + adata = read_csv( + dataset_path=f"{_TEST_PATH}/dataset1.csv", + columns_obs_only=["disease", "station", "sys_bp_entry", "dia_bp_entry"], + index_column="idx", + ) + + # get a fresh adata for every test to not have any side effects + adata_copy = adata.copy() + + ep.tl.rank_features_groups(adata, groupby="disease", columns_to_rank="all") + assert len(adata.uns["rank_features_groups"]["names"]) == 2 + + # want to check a "complete selection" works + adata = adata_copy.copy() + ep.tl.rank_features_groups(adata, groupby="disease", columns_to_rank={"var_names": ["glucose", "weight"]}) + assert len(adata.uns["rank_features_groups"]["names"]) == 2 + + # want to check a "sub-selection" works + adata = adata_copy.copy() + ep.tl.rank_features_groups(adata, groupby="disease", columns_to_rank={"var_names": ["glucose"]}) + assert len(adata.uns["rank_features_groups"]["names"]) == 1 + + # want to check a "complete" selection works + adata = adata_copy.copy() + ep.tl.rank_features_groups( + adata, + groupby="disease", + field_to_rank="obs", + columns_to_rank={"obs_names": ["station", "sys_bp_entry", "dia_bp_entry"]}, + ) + assert len(adata.uns["rank_features_groups"]["names"]) == 3 + + # want to check a "sub-selection" selection works + adata = adata_copy.copy() + ep.tl.rank_features_groups( + adata, + groupby="disease", + field_to_rank="obs", + columns_to_rank={"obs_names": ["sys_bp_entry", "dia_bp_entry"]}, + ) + assert len(adata.uns["rank_features_groups"]["names"]) == 2