Skip to content

Commit

Permalink
Fix one-hot encoding tests (#644)
Browse files Browse the repository at this point in the history
Signed-off-by: zethson <[email protected]>
  • Loading branch information
Zethson authored Jan 23, 2024
1 parent 68d6498 commit f55ef2b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
1 change: 0 additions & 1 deletion docs/usage/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ Other than tools, preprocessing steps usually don’t return an easily interpret
:nosignatures:
preprocessing.encode
preprocessing.undo_encoding
```

### Normalization
Expand Down
16 changes: 8 additions & 8 deletions ehrapy/preprocessing/_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ def _encode(
Returns:
An :class:`~anndata.AnnData` object with the encoded values in X
"""
# check whether original layer exists; if not add it
if "original" not in adata.layers.keys():
adata.layers["original"] = adata.X.copy()

# autodetect categorical values, which could lead to more categoricals
if autodetect:
if "var_to_encoding" in adata.uns.keys():
Expand All @@ -146,7 +146,7 @@ def _encode(
return adata
categoricals_names = _get_var_indices_for_type(adata, NON_NUMERIC_TAG)

# no columns were detected, that would require an encoding (e.g. non numerical columns)
# no columns were detected, that would require an encoding (e.g. non-numerical columns)
if not categoricals_names:
print("[bold yellow]Detected no columns that need to be encoded. Leaving passed AnnData object unchanged.")
return adata
Expand Down Expand Up @@ -214,7 +214,7 @@ def _encode(

# user passed categorical values with encoding mode for each of them
else:
# reencode data
# re-encode data
if "var_to_encoding" in adata.uns.keys():
encodings = _reorder_encodings(adata, encodings) # type: ignore
adata = _undo_encoding(adata, "all")
Expand All @@ -229,7 +229,7 @@ def _encode(
adata.uns["encoding_to_var"] = encodings

categoricals_not_flat = list(chain(*encodings.values())) # type: ignore
# this is needed since multi column encoding will get passed a list of list instead of a flat list
# this is needed since multi-column encoding will get passed a list of list instead of a flat list
categoricals = list(
chain(
*(
Expand Down Expand Up @@ -313,7 +313,7 @@ def _encode(
# update current encodings in uns
encoded_ann_data.uns["var_to_encoding"] = var_to_encoding

# if the user did not pass every non numerical column for encoding, an Anndata object cannot be created
# if the user did not pass every non-numerical column for encoding, an Anndata object cannot be created
except ValueError:
raise AnnDataCreationError(
"Creation of AnnData object failed. Ensure that you passed all non numerical, "
Expand Down Expand Up @@ -351,8 +351,8 @@ def _one_hot_encoding(
Encoded new X and the corresponding new var names
"""
original_values = _initial_encoding(uns, categories)
progress.update(task, description="[bold blue]Running one hot encoding on passed columns ...")
encoder = OneHotEncoder(handle_unknown="ignore", sparse=False).fit(original_values)
progress.update(task, description="[bold blue]Running one-hot encoding on passed columns ...")
encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False).fit(original_values)
categorical_prefixes = [
f"ehrapycat_{category}_{str(suffix).strip()}"
for idx, category in enumerate(categories)
Expand All @@ -366,7 +366,7 @@ def _one_hot_encoding(
progress.update(task, description="[blue]Updating X and var ...")

temp_x, temp_var_names = _update_encoded_data(X, transformed, var_names, categorical_prefixes, categories)
progress.update(task, description="[blue]Finished one hot encoding.")
progress.update(task, description="[blue]Finished one-hot encoding.")

return temp_x, temp_var_names

Expand Down

0 comments on commit f55ef2b

Please sign in to comment.