From f55ef2b294495a1198a639c1b11d8ca39a9d90c8 Mon Sep 17 00:00:00 2001 From: Lukas Heumos Date: Tue, 23 Jan 2024 02:36:21 -0800 Subject: [PATCH] Fix one-hot encoding tests (#644) Signed-off-by: zethson --- docs/usage/usage.md | 1 - ehrapy/preprocessing/_encode.py | 16 ++++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/docs/usage/usage.md b/docs/usage/usage.md index a50e8d78..9c749245 100644 --- a/docs/usage/usage.md +++ b/docs/usage/usage.md @@ -119,7 +119,6 @@ Other than tools, preprocessing steps usually don’t return an easily interpret :nosignatures: preprocessing.encode - preprocessing.undo_encoding ``` ### Normalization diff --git a/ehrapy/preprocessing/_encode.py b/ehrapy/preprocessing/_encode.py index 5fd247b3..c3b4c5d1 100644 --- a/ehrapy/preprocessing/_encode.py +++ b/ehrapy/preprocessing/_encode.py @@ -134,9 +134,9 @@ def _encode( Returns: An :class:`~anndata.AnnData` object with the encoded values in X """ - # check whether original layer exists; if not add it if "original" not in adata.layers.keys(): adata.layers["original"] = adata.X.copy() + # autodetect categorical values, which could lead to more categoricals if autodetect: if "var_to_encoding" in adata.uns.keys(): @@ -146,7 +146,7 @@ def _encode( return adata categoricals_names = _get_var_indices_for_type(adata, NON_NUMERIC_TAG) - # no columns were detected, that would require an encoding (e.g. non numerical columns) + # no columns were detected, that would require an encoding (e.g. non-numerical columns) if not categoricals_names: print("[bold yellow]Detected no columns that need to be encoded. Leaving passed AnnData object unchanged.") return adata @@ -214,7 +214,7 @@ def _encode( # user passed categorical values with encoding mode for each of them else: - # reencode data + # re-encode data if "var_to_encoding" in adata.uns.keys(): encodings = _reorder_encodings(adata, encodings) # type: ignore adata = _undo_encoding(adata, "all") @@ -229,7 +229,7 @@ def _encode( adata.uns["encoding_to_var"] = encodings categoricals_not_flat = list(chain(*encodings.values())) # type: ignore - # this is needed since multi column encoding will get passed a list of list instead of a flat list + # this is needed since multi-column encoding will get passed a list of list instead of a flat list categoricals = list( chain( *( @@ -313,7 +313,7 @@ def _encode( # update current encodings in uns encoded_ann_data.uns["var_to_encoding"] = var_to_encoding - # if the user did not pass every non numerical column for encoding, an Anndata object cannot be created + # if the user did not pass every non-numerical column for encoding, an Anndata object cannot be created except ValueError: raise AnnDataCreationError( "Creation of AnnData object failed. Ensure that you passed all non numerical, " @@ -351,8 +351,8 @@ def _one_hot_encoding( Encoded new X and the corresponding new var names """ original_values = _initial_encoding(uns, categories) - progress.update(task, description="[bold blue]Running one hot encoding on passed columns ...") - encoder = OneHotEncoder(handle_unknown="ignore", sparse=False).fit(original_values) + progress.update(task, description="[bold blue]Running one-hot encoding on passed columns ...") + encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False).fit(original_values) categorical_prefixes = [ f"ehrapycat_{category}_{str(suffix).strip()}" for idx, category in enumerate(categories) @@ -366,7 +366,7 @@ def _one_hot_encoding( progress.update(task, description="[blue]Updating X and var ...") temp_x, temp_var_names = _update_encoded_data(X, transformed, var_names, categorical_prefixes, categories) - progress.update(task, description="[blue]Finished one hot encoding.") + progress.update(task, description="[blue]Finished one-hot encoding.") return temp_x, temp_var_names