Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move information on numerical/non_numerical/encoded_non_numerical from .uns to .var #630

Merged
merged 29 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0c52f13
tests for rank features groups with obs
eroell Dec 6, 2023
dd022ec
first drafted feature ranking using obs
eroell Dec 6, 2023
a818728
fixed encoding names
eroell Dec 6, 2023
0790a80
Merge branch 'main' into rank-features-groups-obs
eroell Dec 6, 2023
adaa53b
remove comment
eroell Dec 6, 2023
37c6a9a
Remove comment
eroell Dec 6, 2023
06c2a00
Remove comment
eroell Dec 6, 2023
6fa3de1
Remove comment
eroell Dec 6, 2023
458520e
Update ehrapy/tools/feature_ranking/_rank_features_groups.py
eroell Dec 6, 2023
210bac6
Remove comment
eroell Dec 6, 2023
ffadf71
Remove comment
eroell Dec 6, 2023
f0f4867
Update ehrapy/tools/feature_ranking/_rank_features_groups.py
eroell Dec 6, 2023
02148f5
Iterable to list and import from future
eroell Dec 6, 2023
a6f5606
no expensive copy
eroell Dec 6, 2023
bf02fff
updated to use layer, obs, or both
eroell Dec 7, 2023
213d8d5
this test data should be more stable
eroell Dec 7, 2023
034f820
Update ehrapy/tools/feature_ranking/_rank_features_groups.py
eroell Dec 7, 2023
6458265
correct indent of previous commit and added comment on dummy X
eroell Dec 7, 2023
f444a59
bug fixes, more tests and (fixed) examples in docstring
eroell Dec 7, 2023
8b3fcce
corrected for tests and modified encode
eroell Dec 18, 2023
e729baf
remove need for fields in .uns, updated use in var
eroell Dec 18, 2023
32313e1
merge
eroell Dec 18, 2023
9c24a11
forgot to commit ep.anndata._constants.py
eroell Dec 18, 2023
f9e598f
Update ehrapy/anndata/anndata_ext.py
eroell Dec 19, 2023
9a7c9b2
Update ehrapy/anndata/anndata_ext.py
eroell Dec 19, 2023
3bd7cba
remove usage of .uns, rewrite tests
eroell Dec 19, 2023
1e1d16f
remove type in docstring
eroell Dec 19, 2023
5fb7c21
Merge remote-tracking branch 'origin/main' into rank-features-groups-obs
eroell Dec 19, 2023
436a6cc
remove commented code
eroell Dec 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions ehrapy/anndata/_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Typing Column
# -----------------------
# The column name and used values in adata.var for column types.

# Name of the column in adata.var that holds the type tag of each feature.
EHRAPY_TYPE_KEY = "ehrapy_column_type"
# Tag for features that are treated as numeric.
NUMERIC_TAG = "numeric"
# Tag for features that are not numeric and have not been encoded.
NON_NUMERIC_TAG = "non_numeric"
# Tag for non-numeric features that have been encoded by ehrapy.
NON_NUMERIC_ENCODED_TAG = "non_numeric_encoded"
88 changes: 62 additions & 26 deletions ehrapy/anndata/anndata_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from scipy.sparse import issparse

from ehrapy import logging as logg
from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_ENCODED_TAG, NON_NUMERIC_TAG, NUMERIC_TAG
Zethson marked this conversation as resolved.
Show resolved Hide resolved

if TYPE_CHECKING:
from collections.abc import Collection, Iterable, Sequence
Expand Down Expand Up @@ -93,19 +94,23 @@ def df_to_anndata(

# initializing an OrderedDict with a non-empty dict might not be intended,
# see: https://stackoverflow.com/questions/25480089/right-way-to-initialize-an-ordereddict-using-its-constructor-such-that-it-retain/25480206
uns = OrderedDict()
uns = OrderedDict() # type: ignore
# store all numerical/non-numerical columns that are not obs only
binary_columns = _detect_binary_columns(df, numerical_columns)
uns["numerical_columns"] = list(set(numerical_columns) | set(binary_columns))
uns["non_numerical_columns"] = list(set(dataframes.df.columns) ^ set(uns["numerical_columns"]))

var = pd.DataFrame(index=list(dataframes.df.columns))
var[EHRAPY_TYPE_KEY] = NON_NUMERIC_TAG
var.loc[var.index.isin(list(set(numerical_columns) | set(binary_columns))), EHRAPY_TYPE_KEY] = NUMERIC_TAG
# in case of encoded columns by ehrapy, want to be able to read it back in
var.loc[var.index.str.contains("ehrapycat"), EHRAPY_TYPE_KEY] = NON_NUMERIC_ENCODED_TAG

all_num = True if len(numerical_columns) == len(list(dataframes.df.columns)) else False
X = X.astype(np.number) if all_num else X.astype(object)
# cast non numerical obs only columns to category or bool dtype, which is needed for writing to .h5ad files
adata = AnnData(
X=X,
obs=_cast_obs_columns(dataframes.obs),
var=pd.DataFrame(index=list(dataframes.df.columns)),
var=var,
layers={"original": X.copy()},
uns=uns,
)
Expand Down Expand Up @@ -202,37 +207,41 @@ def move_to_obs(adata: AnnData, to_obs: list[str] | str, copy_obs: bool = False)
f"Columns `{[col for col in to_obs if col not in adata.var_names.values]}` are not in var_names."
)

cols_to_obs_indices = adata.var_names.isin(to_obs)

num_set = _get_var_indices_for_type(adata, NUMERIC_TAG)
var_num = list(set(to_obs) & set(num_set))
Zethson marked this conversation as resolved.
Show resolved Hide resolved

if copy_obs:
cols_to_obs_indices = adata.var_names.isin(to_obs)
cols_to_obs = adata[:, cols_to_obs_indices].to_df()
adata.obs = adata.obs.join(cols_to_obs)
num_set = set(adata.uns["numerical_columns"].copy())
non_num_set = set(adata.uns["non_numerical_columns"].copy())
var_num = []
var_non_num = []
for var in to_obs:
if var in num_set:
var_num.append(var)
elif var in non_num_set:
var_non_num.append(var)
adata.obs[var_num] = adata.obs[var_num].apply(pd.to_numeric, errors="ignore", downcast="float")
adata.obs = _cast_obs_columns(adata.obs)
else:
cols_to_obs_indices = adata.var_names.isin(to_obs)
df = adata[:, cols_to_obs_indices].to_df()
adata._inplace_subset_var(~cols_to_obs_indices)
adata.obs = adata.obs.join(df)
updated_num_uns, updated_non_num_uns, num_var = _update_uns(adata, to_obs)
adata.obs[num_var] = adata.obs[num_var].apply(pd.to_numeric, errors="ignore", downcast="float")
adata.obs[var_num] = adata.obs[var_num].apply(pd.to_numeric, errors="ignore", downcast="float")
adata.obs = _cast_obs_columns(adata.obs)
adata.uns["numerical_columns"] = updated_num_uns
adata.uns["non_numerical_columns"] = updated_non_num_uns

logg.info(f"Added `{to_obs}` to `obs`.")

return adata


def _get_var_indices_for_type(adata: AnnData, tag: str) -> list[str]:
    """Collect the names of all features in var that carry a given type tag.

    Args:
        adata: The AnnData object whose `.var` holds the `EHRAPY_TYPE_KEY` column.
        tag: The tag to look up, should be one of `NUMERIC_TAG`, `NON_NUMERIC_TAG` or `NON_NUMERIC_ENCODED_TAG`.

    Returns:
        List of the var names whose entry in the `EHRAPY_TYPE_KEY` column equals `tag`.
    """
    # Boolean mask over the features, True where the stored tag matches.
    has_tag = adata.var[EHRAPY_TYPE_KEY] == tag
    matching_names = adata.var_names[has_tag]
    return matching_names.tolist()


def delete_from_obs(adata: AnnData, to_delete: list[str]) -> AnnData:
"""Delete features from obs.

Expand Down Expand Up @@ -305,11 +314,11 @@ def move_to_x(adata: AnnData, to_x: list[str] | str) -> AnnData:
if cols_not_in_x:
new_adata = concat([adata, AnnData(adata.obs[cols_not_in_x])], axis=1)
new_adata.obs = adata.obs[adata.obs.columns[~adata.obs.columns.isin(cols_not_in_x)]]
# update uns (copy maybe: could be a costly operation but reduces reference cycles)
# users might save those as separate AnnData object and this could be unexpected behaviour if we dont copy
num_columns_moved, non_num_columns_moved, _ = _update_uns(adata, cols_not_in_x, True)
new_adata.uns["numerical_columns"] = adata.uns["numerical_columns"] + num_columns_moved
new_adata.uns["non_numerical_columns"] = adata.uns["non_numerical_columns"] + non_num_columns_moved

# AnnData's concat discards var if they dont match in their keys, so we need to create a new var
created_var = _create_new_var(adata, cols_not_in_x)
new_adata.var = pd.concat([adata.var, created_var], axis=0)

logg.info(f"Added `{cols_not_in_x}` features to `X`.")
else:
new_adata = adata
Expand Down Expand Up @@ -486,10 +495,18 @@ def get_numeric_vars(adata: AnnData) -> list[str]:
"""
_assert_encoded(adata)

if "numerical_columns" not in adata.uns_keys():
# if "numerical_columns" not in adata.uns_keys():
Zethson marked this conversation as resolved.
Show resolved Hide resolved
# return list(adata.var_names.values)
# else:
# return adata.uns["numerical_columns"]
# if "numerical_columns" in adata.uns_keys():
# raise UserWarning("numerical_columns is deprecated. Use EHRAPY_TYPE_KEY instead.")

# This behaviour is consistent with the previous behaviour, allowing for a simple fully numeric X
if EHRAPY_TYPE_KEY not in adata.var.columns:
return list(adata.var_names.values)
else:
return adata.uns["numerical_columns"]
return _get_var_indices_for_type(adata, NUMERIC_TAG)


def assert_numeric_vars(adata: AnnData, vars: Sequence[str]):
Expand Down Expand Up @@ -579,6 +596,25 @@ def _update_uns(
return all_moved_num_columns, list(all_moved_non_num_columns), None


def _create_new_var(adata: AnnData, cols_not_in_x: list[str]) -> pd.DataFrame:
    """Build the var DataFrame for columns that are moved from .obs to X.

    Every moved column gets an entry in the `EHRAPY_TYPE_KEY` column: `NUMERIC_TAG` if the
    column has a numeric dtype in `adata.obs`, `NON_NUMERIC_TAG` otherwise.

    Args:
        adata: The AnnData object whose .obs holds the columns to move.
        cols_not_in_x: Names of the .obs columns to move to X.

    Returns:
        A DataFrame indexed by `cols_not_in_x` with the `EHRAPY_TYPE_KEY` column set.
    """
    # Names of all .obs columns with a numeric dtype.
    numeric_obs_columns = adata.obs.select_dtypes(include="number").columns

    moved_var = pd.DataFrame(index=cols_not_in_x)
    # Start from the non-numeric tag and overwrite it for the numeric columns.
    moved_var[EHRAPY_TYPE_KEY] = NON_NUMERIC_TAG
    is_numeric = moved_var.index.isin(numeric_obs_columns)
    moved_var.loc[is_numeric, EHRAPY_TYPE_KEY] = NUMERIC_TAG

    return moved_var


def _detect_binary_columns(df: pd.DataFrame, numerical_columns: list[str]) -> list[str]:
"""Detect all columns that contain only 0 and 1 (besides NaNs).

Expand Down
50 changes: 27 additions & 23 deletions ehrapy/preprocessing/_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

from ehrapy import logging as logg
from ehrapy.anndata.anndata_ext import _update_uns
from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_ENCODED_TAG, NON_NUMERIC_TAG, NUMERIC_TAG
from ehrapy.anndata.anndata_ext import _get_var_indices_for_type

multi_encoding_modes = {"hash"}
available_encodings = {"one-hot", "label", "count", *multi_encoding_modes}
Expand Down Expand Up @@ -145,8 +146,8 @@ def _encode(
"[bold yellow]The current AnnData object has been already encoded. Returning original AnnData object!"
)
return adata
categoricals_names = _get_var_indices_for_type(adata, NON_NUMERIC_TAG)

categoricals_names = adata.uns["non_numerical_columns"]
# no columns were detected, that would require an encoding (e.g. non numerical columns)
if not categoricals_names:
print("[bold yellow]Detected no columns that need to be encoded. Leaving passed AnnData object unchanged.")
Expand Down Expand Up @@ -196,33 +197,30 @@ def _encode(
)
progress.update(task, description=f"[bold blue]Finished {encodings} of autodetected columns.")

# copy non-encoded columns, and add new tag for encoded columns. This is needed to track encodings
new_var = pd.DataFrame(index=encoded_var_names)
new_var[EHRAPY_TYPE_KEY] = adata.var[EHRAPY_TYPE_KEY].copy()
new_var.loc[new_var.index.str.contains("ehrapycat")] = NON_NUMERIC_ENCODED_TAG

encoded_ann_data = AnnData(
encoded_x,
obs=adata.obs.copy(),
var={"var_names": encoded_var_names},
var=new_var,
uns=orig_uns_copy,
layers={"original": updated_layer},
)
encoded_ann_data.uns["var_to_encoding"] = {categorical: encodings for categorical in categoricals_names}
encoded_ann_data.uns["encoding_to_var"] = {encodings: categoricals_names}

encoded_ann_data.uns["numerical_columns"] = adata.uns["numerical_columns"].copy()
encoded_ann_data.uns["non_numerical_columns"] = []
encoded_ann_data.uns["encoded_non_numerical_columns"] = [
column for column in encoded_ann_data.var_names if column.startswith("ehrapycat_")
]

_add_categoricals_to_obs(adata, encoded_ann_data, categoricals_names)

# user passed categorical values with encoding mode for each of them
else:
# Required since this would be deleted through side references
non_numericals = adata.uns["non_numerical_columns"].copy()
# reencode data
if "var_to_encoding" in adata.uns.keys():
encodings = _reorder_encodings(adata, encodings) # type: ignore
adata = _undo_encoding(adata, "all")
adata.uns["non_numerical_columns"] = non_numericals

# are all specified encodings valid?
for encoding in encodings.keys(): # type: ignore
if encoding not in available_encodings:
Expand All @@ -248,7 +246,7 @@ def _encode(
"The categorical column names given contain at least one duplicate column. "
"Check the column names to ensure that no column is encoded twice!"
)
elif any(cat in adata.uns["numerical_columns"] for cat in categoricals):
elif any(cat in adata.var_names[adata.var[EHRAPY_TYPE_KEY] == NUMERIC_TAG] for cat in categoricals):
print(
"[bold yellow]At least one of passed column names seems to have numerical dtype. In general it is not recommended "
"to encode numerical columns!"
Expand Down Expand Up @@ -300,11 +298,17 @@ def _encode(
adata.var_names.to_list(),
categoricals,
)

# copy non-encoded columns, and add new tag for encoded columns. This is needed to track encodings
new_var = pd.DataFrame(index=encoded_var_names)
new_var[EHRAPY_TYPE_KEY] = adata.var[EHRAPY_TYPE_KEY].copy()
new_var.loc[new_var.index.str.contains("ehrapycat")] = NON_NUMERIC_ENCODED_TAG

try:
encoded_ann_data = AnnData(
X=encoded_x,
obs=adata.obs.copy(),
var={"var_names": encoded_var_names},
var=new_var,
uns=orig_uns_copy,
layers={"original": updated_layer},
)
Expand All @@ -317,12 +321,6 @@ def _encode(
"Creation of AnnData object failed. Ensure that you passed all non numerical, "
"categorical values for encoding!"
) from None
updated_num_uns, updated_non_num_uns, _ = _update_uns(adata, categoricals)
encoded_ann_data.uns["numerical_columns"] = updated_num_uns
encoded_ann_data.uns["non_numerical_columns"] = updated_non_num_uns
encoded_ann_data.uns["encoded_non_numerical_columns"] = [
column for column in encoded_ann_data.var_names if column.startswith("ehrapycat_")
]

_add_categoricals_to_obs(adata, encoded_ann_data, categoricals)

Expand Down Expand Up @@ -688,7 +686,8 @@ def _undo_encoding(
new_obs = adata.obs[columns_obs_only]
uns = OrderedDict()
# reset uns and keep numerical/non-numerical columns
num_vars, non_num_vars = adata.uns["numerical_columns"], adata.uns["non_numerical_columns"]
num_vars = _get_var_indices_for_type(adata, NUMERIC_TAG)
non_num_vars = _get_var_indices_for_type(adata, NON_NUMERIC_TAG)
for cat in categoricals:
original_values = adata.uns["original_values_categoricals"][cat]
type_first_nan = original_values[np.where(original_values != np.nan)][0]
Expand All @@ -697,6 +696,11 @@ def _undo_encoding(
else:
non_num_vars.append(cat)

var = pd.DataFrame(index=new_var_names)
var[EHRAPY_TYPE_KEY] = NON_NUMERIC_TAG
# Notice previously encoded columns are now newly added, and will stay tagged as non numeric
var.loc[num_vars, EHRAPY_TYPE_KEY] = NUMERIC_TAG

uns["numerical_columns"] = num_vars
uns["non_numerical_columns"] = non_num_vars

Expand All @@ -705,7 +709,7 @@ def _undo_encoding(
return AnnData(
new_x,
obs=new_obs,
var=pd.DataFrame(index=new_var_names),
var=var,
uns=uns,
layers={"original": new_x.copy()},
)
Expand Down Expand Up @@ -880,7 +884,7 @@ def _add_categoricals_to_uns(original: AnnData, new: AnnData, categorical_names:
continue
elif var_name in categorical_names:
# keep numerical dtype when writing original values to uns
if var_name in original.uns["numerical_columns"]:
if var_name in original.var_names[original.var[EHRAPY_TYPE_KEY] == NUMERIC_TAG]:
new["original_values_categoricals"][var_name] = original.X[::, idx : idx + 1].astype("float")
else:
new["original_values_categoricals"][var_name] = original.X[::, idx : idx + 1].astype("str")
Expand Down
13 changes: 7 additions & 6 deletions ehrapy/preprocessing/_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from ehrapy import logging as logg
from ehrapy import settings
from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_TAG
from ehrapy.anndata.anndata_ext import _get_column_indices
from ehrapy.core._tool_available import _check_module_importable

Expand Down Expand Up @@ -252,7 +253,7 @@ def knn_impute(
else:
# ordinal encoding is used since non-numerical data can not be imputed using KNN Imputation
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using KNN imputation
_knn_impute(adata, var_names, n_neighbours)
Expand Down Expand Up @@ -502,7 +503,7 @@ def soft_impute(
else:
# ordinal encoding is used since non-numerical data can not be imputed using SoftImpute
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using SoftImpute
_soft_impute(
Expand Down Expand Up @@ -658,7 +659,7 @@ def iterative_svd_impute(
else:
# ordinal encoding is used since non-numerical data can not be imputed using IterativeSVD
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using IterativeSVD
_iterative_svd_impute(
Expand Down Expand Up @@ -798,7 +799,7 @@ def matrix_factorization_impute(
else:
# ordinal encoding is used since non-numerical data can not be imputed using MatrixFactorization
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using MatrixFactorization
_matrix_factorization_impute(
Expand Down Expand Up @@ -920,7 +921,7 @@ def nuclear_norm_minimization_impute(
else:
# ordinal encoding is used since non-numerical data can not be imputed using NuclearNormMinimization
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using NuclearNormMinimization
_nuclear_norm_minimization_impute(
Expand Down Expand Up @@ -1039,7 +1040,7 @@ def mice_forest_impute(
else:
# ordinal encoding is used since non-numerical data can not be imputed using miceforest
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using miceforest
_miceforest_impute(
Expand Down
Loading
Loading