Skip to content

Commit

Permalink
Move information on numerical/non_numerical/encoded_non_numerical fro…
Browse files Browse the repository at this point in the history
…m .uns to .var (#630)

* tests for rank features groups with obs

* first drafted feature ranking using obs

* fixed encoding names

* remove comment

* Remove comment

Co-authored-by: Lukas Heumos <[email protected]>

* Remove comment

Co-authored-by: Lukas Heumos <[email protected]>

* Remove comment

Co-authored-by: Lukas Heumos <[email protected]>

* Update ehrapy/tools/feature_ranking/_rank_features_groups.py

Co-authored-by: Lukas Heumos <[email protected]>

* Remove comment

Co-authored-by: Lukas Heumos <[email protected]>

* Remove comment

Co-authored-by: Lukas Heumos <[email protected]>

* Update ehrapy/tools/feature_ranking/_rank_features_groups.py

Co-authored-by: Lukas Heumos <[email protected]>

* Iterable to list and import from future

* updated to use layer, obs, or both

* this test data should be more stable

* Update ehrapy/tools/feature_ranking/_rank_features_groups.py

Co-authored-by: Lukas Heumos <[email protected]>

* correct indent of previous commit and added comment on dummy X

* bug fixes, more tests and (fixed) examples in docstring

* corrected for tests and modified encode

* remove need for fields in .uns, updated use in var

* forgot to commit ep.anndata._constants.py

* Update ehrapy/anndata/anndata_ext.py

Co-authored-by: Lukas Heumos <[email protected]>

* Update ehrapy/anndata/anndata_ext.py

No type annotation in docstrings

Co-authored-by: Lukas Heumos <[email protected]>

* remove usage of .uns, rewrite tests

* remove type in docstring

* remove commented code

---------

Co-authored-by: Lukas Heumos <[email protected]>
  • Loading branch information
eroell and Zethson authored Dec 19, 2023
1 parent 255f9a7 commit 3cbeafd
Show file tree
Hide file tree
Showing 9 changed files with 257 additions and 141 deletions.
8 changes: 8 additions & 0 deletions ehrapy/anndata/_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Typing Column
# -----------------------
# The column name and used values in adata.var for column types.
#
# EHRAPY_TYPE_KEY is the name of the column in `adata.var` that stores one of the
# tags below for every variable (feature) of the AnnData object.

# Name of the `adata.var` column holding the type tag of each variable.
EHRAPY_TYPE_KEY = "ehrapy_column_type"
# Tag for variables that hold numeric values.
NUMERIC_TAG = "numeric"
# Tag for variables that hold non-numeric values that have not been encoded.
NON_NUMERIC_TAG = "non_numeric"
# Tag for non-numeric variables that have been encoded by ehrapy.
NON_NUMERIC_ENCODED_TAG = "non_numeric_encoded"
81 changes: 55 additions & 26 deletions ehrapy/anndata/anndata_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from scipy.sparse import issparse

from ehrapy import logging as logg
from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_ENCODED_TAG, NON_NUMERIC_TAG, NUMERIC_TAG

if TYPE_CHECKING:
from collections.abc import Collection, Iterable, Sequence
Expand Down Expand Up @@ -93,19 +94,23 @@ def df_to_anndata(

# initializing an OrderedDict with a non-empty dict might not be intended,
# see: https://stackoverflow.com/questions/25480089/right-way-to-initialize-an-ordereddict-using-its-constructor-such-that-it-retain/25480206
uns = OrderedDict()
uns = OrderedDict() # type: ignore
# store all numerical/non-numerical columns that are not obs only
binary_columns = _detect_binary_columns(df, numerical_columns)
uns["numerical_columns"] = list(set(numerical_columns) | set(binary_columns))
uns["non_numerical_columns"] = list(set(dataframes.df.columns) ^ set(uns["numerical_columns"]))

var = pd.DataFrame(index=list(dataframes.df.columns))
var[EHRAPY_TYPE_KEY] = NON_NUMERIC_TAG
var.loc[var.index.isin(list(set(numerical_columns) | set(binary_columns))), EHRAPY_TYPE_KEY] = NUMERIC_TAG
# in case of encoded columns by ehrapy, want to be able to read it back in
var.loc[var.index.str.contains("ehrapycat"), EHRAPY_TYPE_KEY] = NON_NUMERIC_ENCODED_TAG

all_num = True if len(numerical_columns) == len(list(dataframes.df.columns)) else False
X = X.astype(np.number) if all_num else X.astype(object)
# cast non numerical obs only columns to category or bool dtype, which is needed for writing to .h5ad files
adata = AnnData(
X=X,
obs=_cast_obs_columns(dataframes.obs),
var=pd.DataFrame(index=list(dataframes.df.columns)),
var=var,
layers={"original": X.copy()},
uns=uns,
)
Expand Down Expand Up @@ -202,37 +207,41 @@ def move_to_obs(adata: AnnData, to_obs: list[str] | str, copy_obs: bool = False)
f"Columns `{[col for col in to_obs if col not in adata.var_names.values]}` are not in var_names."
)

cols_to_obs_indices = adata.var_names.isin(to_obs)

num_set = _get_var_indices_for_type(adata, NUMERIC_TAG)
var_num = list(set(to_obs) & set(num_set))

if copy_obs:
cols_to_obs_indices = adata.var_names.isin(to_obs)
cols_to_obs = adata[:, cols_to_obs_indices].to_df()
adata.obs = adata.obs.join(cols_to_obs)
num_set = set(adata.uns["numerical_columns"].copy())
non_num_set = set(adata.uns["non_numerical_columns"].copy())
var_num = []
var_non_num = []
for var in to_obs:
if var in num_set:
var_num.append(var)
elif var in non_num_set:
var_non_num.append(var)
adata.obs[var_num] = adata.obs[var_num].apply(pd.to_numeric, errors="ignore", downcast="float")
adata.obs = _cast_obs_columns(adata.obs)
else:
cols_to_obs_indices = adata.var_names.isin(to_obs)
df = adata[:, cols_to_obs_indices].to_df()
adata._inplace_subset_var(~cols_to_obs_indices)
adata.obs = adata.obs.join(df)
updated_num_uns, updated_non_num_uns, num_var = _update_uns(adata, to_obs)
adata.obs[num_var] = adata.obs[num_var].apply(pd.to_numeric, errors="ignore", downcast="float")
adata.obs[var_num] = adata.obs[var_num].apply(pd.to_numeric, errors="ignore", downcast="float")
adata.obs = _cast_obs_columns(adata.obs)
adata.uns["numerical_columns"] = updated_num_uns
adata.uns["non_numerical_columns"] = updated_non_num_uns

logg.info(f"Added `{to_obs}` to `obs`.")

return adata


def _get_var_indices_for_type(adata: AnnData, tag: str) -> list[str]:
    """Collect the names of all variables whose ehrapy column type equals `tag`.

    Args:
        adata: The AnnData object whose `.var` holds the `EHRAPY_TYPE_KEY` column.
        tag: The tag to search for, should be one of `NUMERIC_TAG`, `NON_NUMERIC_TAG` or `NON_NUMERIC_ENCODED_TAG`.

    Returns:
        Names of the variables in `.var` that are tagged with `tag`.
    """
    # Boolean mask over the rows of .var, one entry per variable.
    has_tag = adata.var[EHRAPY_TYPE_KEY] == tag
    return adata.var_names[has_tag].tolist()


def delete_from_obs(adata: AnnData, to_delete: list[str]) -> AnnData:
"""Delete features from obs.
Expand Down Expand Up @@ -305,11 +314,11 @@ def move_to_x(adata: AnnData, to_x: list[str] | str) -> AnnData:
if cols_not_in_x:
new_adata = concat([adata, AnnData(adata.obs[cols_not_in_x])], axis=1)
new_adata.obs = adata.obs[adata.obs.columns[~adata.obs.columns.isin(cols_not_in_x)]]
# update uns (copy maybe: could be a costly operation but reduces reference cycles)
# users might save those as separate AnnData object and this could be unexpected behaviour if we dont copy
num_columns_moved, non_num_columns_moved, _ = _update_uns(adata, cols_not_in_x, True)
new_adata.uns["numerical_columns"] = adata.uns["numerical_columns"] + num_columns_moved
new_adata.uns["non_numerical_columns"] = adata.uns["non_numerical_columns"] + non_num_columns_moved

# AnnData's concat discards var if they dont match in their keys, so we need to create a new var
created_var = _create_new_var(adata, cols_not_in_x)
new_adata.var = pd.concat([adata.var, created_var], axis=0)

logg.info(f"Added `{cols_not_in_x}` features to `X`.")
else:
new_adata = adata
Expand Down Expand Up @@ -486,10 +495,11 @@ def get_numeric_vars(adata: AnnData) -> list[str]:
"""
_assert_encoded(adata)

if "numerical_columns" not in adata.uns_keys():
# This behaviour is consistent with the previous behaviour, allowing for a simple fully numeric X
if EHRAPY_TYPE_KEY not in adata.var.columns:
return list(adata.var_names.values)
else:
return adata.uns["numerical_columns"]
return _get_var_indices_for_type(adata, NUMERIC_TAG)


def assert_numeric_vars(adata: AnnData, vars: Sequence[str]):
Expand Down Expand Up @@ -579,6 +589,25 @@ def _update_uns(
return all_moved_num_columns, list(all_moved_non_num_columns), None


def _create_new_var(adata: AnnData, cols_not_in_x: list[str]) -> pd.DataFrame:
    """Build the `.var` rows for columns that are moved from `.obs` into `X`.

    Every moved column starts out tagged as `NON_NUMERIC_TAG`; the columns that
    have a numeric dtype in `adata.obs` are then re-tagged as `NUMERIC_TAG`.

    Args:
        adata: From where to get the .obs
        cols_not_in_x: .obs columns to move to X

    Returns:
        New var DataFrame, indexed by `cols_not_in_x`, with the EHRAPY_TYPE_KEY column set for entries from .obs
    """
    numeric_obs_columns = adata.obs.select_dtypes(include="number").columns

    typed_var = pd.DataFrame(index=cols_not_in_x)
    typed_var[EHRAPY_TYPE_KEY] = NON_NUMERIC_TAG

    # Select the moved columns that are numeric in .obs via a boolean mask over the new index.
    is_numeric = typed_var.index.isin(numeric_obs_columns)
    typed_var.loc[is_numeric, EHRAPY_TYPE_KEY] = NUMERIC_TAG

    return typed_var


def _detect_binary_columns(df: pd.DataFrame, numerical_columns: list[str]) -> list[str]:
"""Detect all columns that contain only 0 and 1 (besides NaNs).
Expand Down
50 changes: 27 additions & 23 deletions ehrapy/preprocessing/_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

from ehrapy import logging as logg
from ehrapy.anndata.anndata_ext import _update_uns
from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_ENCODED_TAG, NON_NUMERIC_TAG, NUMERIC_TAG
from ehrapy.anndata.anndata_ext import _get_var_indices_for_type

multi_encoding_modes = {"hash"}
available_encodings = {"one-hot", "label", "count", *multi_encoding_modes}
Expand Down Expand Up @@ -143,8 +144,8 @@ def _encode(
"[bold yellow]The current AnnData object has been already encoded. Returning original AnnData object!"
)
return adata
categoricals_names = _get_var_indices_for_type(adata, NON_NUMERIC_TAG)

categoricals_names = adata.uns["non_numerical_columns"]
# no columns were detected, that would require an encoding (e.g. non numerical columns)
if not categoricals_names:
print("[bold yellow]Detected no columns that need to be encoded. Leaving passed AnnData object unchanged.")
Expand Down Expand Up @@ -194,33 +195,30 @@ def _encode(
)
progress.update(task, description=f"[bold blue]Finished {encodings} of autodetected columns.")

# copy non-encoded columns, and add new tag for encoded columns. This is needed to track encodings
new_var = pd.DataFrame(index=encoded_var_names)
new_var[EHRAPY_TYPE_KEY] = adata.var[EHRAPY_TYPE_KEY].copy()
new_var.loc[new_var.index.str.contains("ehrapycat")] = NON_NUMERIC_ENCODED_TAG

encoded_ann_data = AnnData(
encoded_x,
obs=adata.obs.copy(),
var={"var_names": encoded_var_names},
var=new_var,
uns=orig_uns_copy,
layers={"original": updated_layer},
)
encoded_ann_data.uns["var_to_encoding"] = {categorical: encodings for categorical in categoricals_names}
encoded_ann_data.uns["encoding_to_var"] = {encodings: categoricals_names}

encoded_ann_data.uns["numerical_columns"] = adata.uns["numerical_columns"].copy()
encoded_ann_data.uns["non_numerical_columns"] = []
encoded_ann_data.uns["encoded_non_numerical_columns"] = [
column for column in encoded_ann_data.var_names if column.startswith("ehrapycat_")
]

_add_categoricals_to_obs(adata, encoded_ann_data, categoricals_names)

# user passed categorical values with encoding mode for each of them
else:
# Required since this would be deleted through side references
non_numericals = adata.uns["non_numerical_columns"].copy()
# reencode data
if "var_to_encoding" in adata.uns.keys():
encodings = _reorder_encodings(adata, encodings) # type: ignore
adata = _undo_encoding(adata, "all")
adata.uns["non_numerical_columns"] = non_numericals

# are all specified encodings valid?
for encoding in encodings.keys(): # type: ignore
if encoding not in available_encodings:
Expand All @@ -246,7 +244,7 @@ def _encode(
"The categorical column names given contain at least one duplicate column. "
"Check the column names to ensure that no column is encoded twice!"
)
elif any(cat in adata.uns["numerical_columns"] for cat in categoricals):
elif any(cat in adata.var_names[adata.var[EHRAPY_TYPE_KEY] == NUMERIC_TAG] for cat in categoricals):
print(
"[bold yellow]At least one of passed column names seems to have numerical dtype. In general it is not recommended "
"to encode numerical columns!"
Expand Down Expand Up @@ -298,11 +296,17 @@ def _encode(
adata.var_names.to_list(),
categoricals,
)

# copy non-encoded columns, and add new tag for encoded columns. This is needed to track encodings
new_var = pd.DataFrame(index=encoded_var_names)
new_var[EHRAPY_TYPE_KEY] = adata.var[EHRAPY_TYPE_KEY].copy()
new_var.loc[new_var.index.str.contains("ehrapycat")] = NON_NUMERIC_ENCODED_TAG

try:
encoded_ann_data = AnnData(
X=encoded_x,
obs=adata.obs.copy(),
var={"var_names": encoded_var_names},
var=new_var,
uns=orig_uns_copy,
layers={"original": updated_layer},
)
Expand All @@ -315,12 +319,6 @@ def _encode(
"Creation of AnnData object failed. Ensure that you passed all non numerical, "
"categorical values for encoding!"
) from None
updated_num_uns, updated_non_num_uns, _ = _update_uns(adata, categoricals)
encoded_ann_data.uns["numerical_columns"] = updated_num_uns
encoded_ann_data.uns["non_numerical_columns"] = updated_non_num_uns
encoded_ann_data.uns["encoded_non_numerical_columns"] = [
column for column in encoded_ann_data.var_names if column.startswith("ehrapycat_")
]

_add_categoricals_to_obs(adata, encoded_ann_data, categoricals)

Expand Down Expand Up @@ -686,7 +684,8 @@ def _undo_encoding(
new_obs = adata.obs[columns_obs_only]
uns = OrderedDict()
# reset uns and keep numerical/non-numerical columns
num_vars, non_num_vars = adata.uns["numerical_columns"], adata.uns["non_numerical_columns"]
num_vars = _get_var_indices_for_type(adata, NUMERIC_TAG)
non_num_vars = _get_var_indices_for_type(adata, NON_NUMERIC_TAG)
for cat in categoricals:
original_values = adata.uns["original_values_categoricals"][cat]
type_first_nan = original_values[np.where(original_values != np.nan)][0]
Expand All @@ -695,6 +694,11 @@ def _undo_encoding(
else:
non_num_vars.append(cat)

var = pd.DataFrame(index=new_var_names)
var[EHRAPY_TYPE_KEY] = NON_NUMERIC_TAG
# Notice previously encoded columns are now newly added, and will stay tagged as non numeric
var.loc[num_vars, EHRAPY_TYPE_KEY] = NUMERIC_TAG

uns["numerical_columns"] = num_vars
uns["non_numerical_columns"] = non_num_vars

Expand All @@ -703,7 +707,7 @@ def _undo_encoding(
return AnnData(
new_x,
obs=new_obs,
var=pd.DataFrame(index=new_var_names),
var=var,
uns=uns,
layers={"original": new_x.copy()},
)
Expand Down Expand Up @@ -877,7 +881,7 @@ def _add_categoricals_to_uns(original: AnnData, new: AnnData, categorical_names:
continue
elif var_name in categorical_names:
# keep numerical dtype when writing original values to uns
if var_name in original.uns["numerical_columns"]:
if var_name in original.var_names[original.var[EHRAPY_TYPE_KEY] == NUMERIC_TAG]:
new["original_values_categoricals"][var_name] = original.X[::, idx : idx + 1].astype("float")
else:
new["original_values_categoricals"][var_name] = original.X[::, idx : idx + 1].astype("str")
Expand Down
13 changes: 7 additions & 6 deletions ehrapy/preprocessing/_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from ehrapy import logging as logg
from ehrapy import settings
from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_TAG
from ehrapy.anndata.anndata_ext import _get_column_indices
from ehrapy.core._tool_available import _check_module_importable

Expand Down Expand Up @@ -252,7 +253,7 @@ def knn_impute(
else:
# ordinal encoding is used since non-numerical data can not be imputed using KNN Imputation
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using KNN imputation
_knn_impute(adata, var_names, n_neighbours)
Expand Down Expand Up @@ -513,7 +514,7 @@ def soft_impute(
else:
# ordinal encoding is used since non-numerical data can not be imputed using SoftImpute
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using SoftImpute
_soft_impute(
Expand Down Expand Up @@ -669,7 +670,7 @@ def iterative_svd_impute(
else:
# ordinal encoding is used since non-numerical data can not be imputed using IterativeSVD
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using IterativeSVD
_iterative_svd_impute(
Expand Down Expand Up @@ -809,7 +810,7 @@ def matrix_factorization_impute(
else:
# ordinal encoding is used since non-numerical data can not be imputed using MatrixFactorization
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using MatrixFactorization
_matrix_factorization_impute(
Expand Down Expand Up @@ -931,7 +932,7 @@ def nuclear_norm_minimization_impute(
else:
# ordinal encoding is used since non-numerical data can not be imputed using NuclearNormMinimization
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using NuclearNormMinimization
_nuclear_norm_minimization_impute(
Expand Down Expand Up @@ -1057,7 +1058,7 @@ def mice_forest_impute(
else:
# ordinal encoding is used since non-numerical data can not be imputed using miceforest
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using miceforest
_miceforest_impute(
Expand Down
Loading

0 comments on commit 3cbeafd

Please sign in to comment.