Skip to content

Commit

Permalink
Unify feature type annotations (#697)
Browse files Browse the repository at this point in the history
* Added infer and check feature types methods

* Added and tested decorator and adapted feature importances

* Added test cases and updated imputation

* Adapted encoding

* Feature specifications output

* Fix HVF test

* Added tree printing for inferred feature types

* Notebook fixes

* Fix feature importance test

* Beautify tree

* Base encoding on original feature types

* Added to usage

* Update logging message

* Improved method description

* Submodule update

Signed-off-by: zethson <[email protected]>

* PR Revisions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update submodule

* Extended method docs description

---------

Signed-off-by: zethson <[email protected]>
Co-authored-by: zethson <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 23, 2024
1 parent 5b4c09a commit 169a5bb
Show file tree
Hide file tree
Showing 13 changed files with 278 additions and 65 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
1 change: 1 addition & 0 deletions docs/usage/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ Methods that extract and visualize tool-specific annotation in an AnnData object
:toctree: anndata
:nosignatures:
anndata.infer_feature_types
anndata.df_to_anndata
anndata.anndata_to_df
anndata.move_to_obs
Expand Down
1 change: 1 addition & 0 deletions ehrapy/anndata/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ehrapy.anndata._feature_specifications import check_feature_types, infer_feature_types
from ehrapy.anndata.anndata_ext import (
anndata_to_df,
delete_from_obs,
Expand Down
8 changes: 7 additions & 1 deletion ehrapy/anndata/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
# -----------------------
# The column name and used values in adata.var for column types.

EHRAPY_TYPE_KEY = "ehrapy_column_type"
EHRAPY_TYPE_KEY = "ehrapy_column_type" # TODO: Change to ENCODING_TYPE_KEY
NUMERIC_TAG = "numeric"
NON_NUMERIC_TAG = "non_numeric"
NON_NUMERIC_ENCODED_TAG = "non_numeric_encoded"


FEATURE_TYPE_KEY = "feature_type"
CONTINUOUS_TAG = "numeric" # TODO: Eventually rename to NUMERIC_TAG (as soon as the other NUMERIC_TAG is removed)
CATEGORICAL_TAG = "categorical"
DATE_TAG = "date"
95 changes: 95 additions & 0 deletions ehrapy/anndata/_feature_specifications.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from typing import Literal

import numpy as np
import pandas as pd
from anndata import AnnData
from rich import print
from rich.tree import Tree

from ehrapy import logging as logg
from ehrapy.anndata._constants import CATEGORICAL_TAG, CONTINUOUS_TAG, DATE_TAG, FEATURE_TYPE_KEY
from ehrapy.anndata.anndata_ext import anndata_to_df


def infer_feature_types(adata: AnnData, layer: str | None = None, output: Literal["tree", "dataframe"] | None = "tree"):
"""Infer feature types from AnnData object.
For each feature in adata.var_names, the method infers one of the following types: 'date', 'categorical', or 'numeric'.
The inferred types are stored in adata.var['feature_type']. Please check the inferred types and adjust if necessary using
adata.var['feature_type']['feature1']='corrected_type'.
Be aware that not all features stored numerically are of 'numeric' type, as categorical features might be stored in a numerically encoded format.
For example, a feature with values [0, 1, 2] might be a categorical feature with three categories. This is accounted for in the method, but it is
recommended to check the inferred types.
Args:
adata: :class:`~anndata.AnnData` object storing the EHR data.
layer: The layer to use from the AnnData object. If None, the X layer is used.
output: The output format. Choose between 'tree', 'dataframe', or None. If 'tree', the feature types will be printed to the console in a tree format.
If 'dataframe', a pandas DataFrame with the feature types will be returned. If None, nothing will be returned. Defaults to 'tree'.
"""
feature_types = {}

df = anndata_to_df(adata, layer=layer)
for feature in adata.var_names:
col = df[feature].dropna()
majority_type = col.apply(type).value_counts().idxmax()
if majority_type == pd.Timestamp:
feature_types[feature] = DATE_TAG
elif majority_type not in [int, float, complex]:
feature_types[feature] = CATEGORICAL_TAG
# Guess categorical if the feature is an integer and the values are 0/1 to n-1 with no gaps
elif np.all(i.is_integer() for i in col) and (
(col.min() == 0 and np.all(np.sort(col.unique()) == np.arange(col.nunique())))
or (col.min() == 1 and np.all(np.sort(col.unique()) == np.arange(1, col.nunique() + 1)))
):
feature_types[feature] = CATEGORICAL_TAG
else:
feature_types[feature] = CONTINUOUS_TAG

adata.var[FEATURE_TYPE_KEY] = pd.Series(feature_types)[adata.var_names]

logg.info(
f"Stored feature types in adata.var['{FEATURE_TYPE_KEY}']. Please verify and adjust if necessary using adata.var['{FEATURE_TYPE_KEY}']['feature1']='corrected_type'."
)

if output == "tree":
feature_type_overview(adata)
elif output == "dataframe":
return adata.var[FEATURE_TYPE_KEY]
elif output is not None:
raise ValueError(f"Output format {output} not recognized. Choose between 'tree', 'dataframe', or None.")


def check_feature_types(func):
def wrapper(adata, *args, **kwargs):
if FEATURE_TYPE_KEY not in adata.var.keys():
raise ValueError("Feature types are not specified in adata.var. Please run `infer_feature_types` first.")
np.all(adata.var[FEATURE_TYPE_KEY].isin([CATEGORICAL_TAG, CONTINUOUS_TAG, DATE_TAG]))
return func(adata, *args, **kwargs)

return wrapper


@check_feature_types
def feature_type_overview(adata: AnnData):
"""Print an overview of the feature types in the AnnData object."""
tree = Tree(
f"[b] Detected feature types for AnnData object with {len(adata.obs_names)} obs and {len(adata.var_names)} vars",
guide_style="underline2",
)

branch = tree.add("📅[b] Date features")
for date in sorted(adata.var_names[adata.var[FEATURE_TYPE_KEY] == DATE_TAG]):
branch.add(date)

branch = tree.add("📐[b] Numerical features")
for numeric in sorted(adata.var_names[adata.var[FEATURE_TYPE_KEY] == CONTINUOUS_TAG]):
branch.add(numeric)

branch = tree.add("🗂️[b] Categorical features")
cat_features = adata.var_names[adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG]
df = anndata_to_df(adata[:, cat_features])
for categorical in sorted(cat_features):
branch.add(f"{categorical} ({df.loc[:, categorical].nunique()} categories)")

print(tree)
1 change: 1 addition & 0 deletions ehrapy/anndata/anndata_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ def _adata_type_overview(
f"[b green]Variable names for AnnData object with {len(adata.obs_names)} obs and {len(adata.var_names)} vars",
guide_style="underline2 bright_blue",
)

if "var_to_encoding" in adata.uns.keys():
original_values = adata.uns["original_values_categoricals"]
branch = tree.add("🔐 Encoded variables", style="b green")
Expand Down
17 changes: 16 additions & 1 deletion ehrapy/preprocessing/_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,16 @@
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

from ehrapy import logging as logg
from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_ENCODED_TAG, NON_NUMERIC_TAG, NUMERIC_TAG
from ehrapy.anndata._constants import (
CATEGORICAL_TAG,
CONTINUOUS_TAG,
DATE_TAG,
EHRAPY_TYPE_KEY,
FEATURE_TYPE_KEY,
NON_NUMERIC_ENCODED_TAG,
NON_NUMERIC_TAG,
NUMERIC_TAG,
)
from ehrapy.anndata.anndata_ext import _get_var_indices_for_type

multi_encoding_modes = {"hash"}
Expand Down Expand Up @@ -141,6 +150,9 @@ def encode(
new_var = pd.DataFrame(index=encoded_var_names)
new_var[EHRAPY_TYPE_KEY] = adata.var[EHRAPY_TYPE_KEY].copy()
new_var.loc[new_var.index.str.contains("ehrapycat")] = NON_NUMERIC_ENCODED_TAG
if FEATURE_TYPE_KEY in adata.var.keys():
new_var[FEATURE_TYPE_KEY] = adata.var[FEATURE_TYPE_KEY].copy()
new_var.loc[new_var.index.str.contains("ehrapycat"), FEATURE_TYPE_KEY] = CATEGORICAL_TAG

encoded_ann_data = AnnData(
encoded_x,
Expand Down Expand Up @@ -243,6 +255,9 @@ def encode(
new_var = pd.DataFrame(index=encoded_var_names)
new_var[EHRAPY_TYPE_KEY] = adata.var[EHRAPY_TYPE_KEY].copy()
new_var.loc[new_var.index.str.contains("ehrapycat")] = NON_NUMERIC_ENCODED_TAG
if FEATURE_TYPE_KEY in adata.var.keys():
new_var[FEATURE_TYPE_KEY] = adata.var[FEATURE_TYPE_KEY].copy()
new_var.loc[new_var.index.str.contains("ehrapycat"), FEATURE_TYPE_KEY] = CATEGORICAL_TAG

try:
encoded_ann_data = AnnData(
Expand Down
27 changes: 20 additions & 7 deletions ehrapy/preprocessing/_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

from ehrapy import logging as logg
from ehrapy import settings
from ehrapy.anndata._constants import EHRAPY_TYPE_KEY, NON_NUMERIC_TAG
from ehrapy.anndata import check_feature_types
from ehrapy.anndata._constants import CATEGORICAL_TAG, FEATURE_TYPE_KEY
from ehrapy.anndata.anndata_ext import _get_column_indices
from ehrapy.core._tool_available import _check_module_importable

Expand Down Expand Up @@ -188,6 +189,7 @@ def _simple_impute(adata: AnnData, var_names: Iterable[str] | None, strategy: st
adata.X = imputer.fit_transform(adata.X)


@check_feature_types
def knn_impute(
adata: AnnData,
var_names: Iterable[str] | None = None,
Expand Down Expand Up @@ -219,6 +221,7 @@ def knn_impute(
Examples:
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=True)
>>> ep.ad.infer_feature_types(adata)
>>> ep.pp.knn_impute(adata)
"""
if copy:
Expand Down Expand Up @@ -247,7 +250,7 @@ def knn_impute(
else:
# ordinal encoding is used since non-numerical data can not be imputed using KNN Imputation
enc = OrdinalEncoder()
column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG
column_indices = adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using KNN imputation
_knn_impute(adata, var_names, n_neighbours)
Expand Down Expand Up @@ -416,6 +419,7 @@ def miss_forest_impute(
return adata


@check_feature_types
def soft_impute(
adata: AnnData,
var_names: Iterable[str] | None = None,
Expand Down Expand Up @@ -460,6 +464,7 @@ def soft_impute(
Examples:
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=True)
>>> ep.ad.infer_feature_types(adata)
>>> ep.pp.soft_impute(adata)
"""
if copy:
Expand Down Expand Up @@ -491,7 +496,7 @@ def soft_impute(
else:
# ordinal encoding is used since non-numerical data can not be imputed using SoftImpute
enc = OrdinalEncoder()
column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG
column_indices = adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using SoftImpute
_soft_impute(
Expand Down Expand Up @@ -551,6 +556,7 @@ def _soft_impute(
adata.X = imputer.fit_transform(adata.X)


@check_feature_types
def iterative_svd_impute(
adata: AnnData,
var_names: Iterable[str] | None = None,
Expand Down Expand Up @@ -607,6 +613,7 @@ def iterative_svd_impute(
Examples:
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=True)
>>> ep.ad.infer_feature_types(adata)
>>> ep.pp.iterative_svd_impute(adata)
"""
if copy:
Expand Down Expand Up @@ -637,7 +644,7 @@ def iterative_svd_impute(
else:
# ordinal encoding is used since non-numerical data can not be imputed using IterativeSVD
enc = OrdinalEncoder()
column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG
column_indices = adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using IterativeSVD
_iterative_svd_impute(
Expand Down Expand Up @@ -694,6 +701,7 @@ def _iterative_svd_impute(
adata.X = imputer.fit_transform(adata.X)


@check_feature_types
def matrix_factorization_impute(
adata: AnnData,
var_names: Iterable[str] | None = None,
Expand Down Expand Up @@ -743,6 +751,7 @@ def matrix_factorization_impute(
Examples:
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=True)
>>> ep.ad.infer_feature_types(adata)
>>> ep.pp.matrix_factorization_impute(adata)
"""
if copy:
Expand Down Expand Up @@ -771,7 +780,7 @@ def matrix_factorization_impute(
else:
# ordinal encoding is used since non-numerical data can not be imputed using MatrixFactorization
enc = OrdinalEncoder()
column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG
column_indices = adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using MatrixFactorization
_matrix_factorization_impute(
Expand Down Expand Up @@ -821,6 +830,7 @@ def _matrix_factorization_impute(
adata.X = imputer.fit_transform(adata.X)


@check_feature_types
def nuclear_norm_minimization_impute(
adata: AnnData,
var_names: Iterable[str] | None = None,
Expand Down Expand Up @@ -856,6 +866,7 @@ def nuclear_norm_minimization_impute(
Examples:
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=True)
>>> ep.ad.infer_feature_types(adata)
>>> ep.pp.nuclear_norm_minimization_impute(adata)
"""
if copy:
Expand Down Expand Up @@ -883,7 +894,7 @@ def nuclear_norm_minimization_impute(
else:
# ordinal encoding is used since non-numerical data can not be imputed using NuclearNormMinimization
enc = OrdinalEncoder()
column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG
column_indices = adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using NuclearNormMinimization
_nuclear_norm_minimization_impute(
Expand Down Expand Up @@ -931,6 +942,7 @@ def _nuclear_norm_minimization_impute(
adata.X = imputer.fit_transform(adata.X)


@check_feature_types
def mice_forest_impute(
adata: AnnData,
var_names: Iterable[str] | None = None,
Expand Down Expand Up @@ -972,6 +984,7 @@ def mice_forest_impute(
Examples:
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=True)
>>> ep.ad.infer_feature_types(adata)
>>> ep.pp.mice_forest_impute(adata)
"""
if copy:
Expand Down Expand Up @@ -999,7 +1012,7 @@ def mice_forest_impute(
else:
# ordinal encoding is used since non-numerical data can not be imputed using miceforest
enc = OrdinalEncoder()
column_indices = adata.var[EHRAPY_TYPE_KEY] == NON_NUMERIC_TAG
column_indices = adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using miceforest
_miceforest_impute(
Expand Down
Loading

0 comments on commit 169a5bb

Please sign in to comment.