Refactor outliers and IQR (#692)
* Refactor outliers

Signed-off-by: zethson <[email protected]>

* Add IQR to var_metrics

Signed-off-by: zethson <[email protected]>

* Fix tests

Signed-off-by: zethson <[email protected]>

* Fix tests

Signed-off-by: zethson <[email protected]>

* dtype boolean for metric

Signed-off-by: zethson <[email protected]>

* Disable fate nb

Signed-off-by: zethson <[email protected]>

* Checkout version

Signed-off-by: zethson <[email protected]>

* Only install cellrank

Signed-off-by: zethson <[email protected]>

* Use uv for notebooks

Signed-off-by: zethson <[email protected]>

---------

Signed-off-by: zethson <[email protected]>
Zethson authored Apr 14, 2024
1 parent eb27f49 commit eebc635
Showing 9 changed files with 67 additions and 147 deletions.
9 changes: 6 additions & 3 deletions .github/workflows/run_notebooks.yml

@@ -19,7 +19,7 @@ jobs:
           #  "docs/tutorials/notebooks/medcat.ipynb",
         ]
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
        with:
          submodules: "true"
          token: "${{ secrets.CT_SYNC_TOKEN }}"
@@ -28,11 +28,14 @@ jobs:
        with:
          python-version: "3.11"
 
+      - name: Install UV
+        run: pip install uv
+
       - name: Install ehrapy
-        run: pip install .
+        run: uv pip install --system .
 
       - name: Install additional dependencies
-        run: pip install medcat cellrank
+        run: uv pip install --system cellrank
 
       - name: Install nbconvert ipykernel
         run: pip install nbconvert ipykernel
7 changes: 4 additions & 3 deletions ehrapy/anndata/anndata_ext.py

@@ -106,7 +106,7 @@ def df_to_anndata(
 
     all_num = True if len(numerical_columns) == len(list(dataframes.df.columns)) else False
     X = X.astype(np.number) if all_num else X.astype(object)
-    # cast non numerical obs only columns to category or bool dtype, which is needed for writing to .h5ad files
+    # cast non-numerical obs only columns to category or bool dtype, which is needed for writing to .h5ad files
     adata = AnnData(
         X=X,
         obs=_cast_obs_columns(dataframes.obs),
@@ -215,13 +215,14 @@ def move_to_obs(adata: AnnData, to_obs: list[str] | str, copy_obs: bool = False)
     if copy_obs:
         cols_to_obs = adata[:, cols_to_obs_indices].to_df()
         adata.obs = adata.obs.join(cols_to_obs)
-        adata.obs[var_num] = adata.obs[var_num].apply(pd.to_numeric, errors="ignore", downcast="float")
+        adata.obs[var_num] = adata.obs[var_num].apply(pd.to_numeric, downcast="float")
+
         adata.obs = _cast_obs_columns(adata.obs)
     else:
         df = adata[:, cols_to_obs_indices].to_df()
         adata._inplace_subset_var(~cols_to_obs_indices)
         adata.obs = adata.obs.join(df)
-        adata.obs[var_num] = adata.obs[var_num].apply(pd.to_numeric, errors="ignore", downcast="float")
+        adata.obs[var_num] = adata.obs[var_num].apply(pd.to_numeric, downcast="float")
         adata.obs = _cast_obs_columns(adata.obs)
 
     return adata
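Context for the `move_to_obs` change above: `errors="ignore"` is deprecated in recent pandas releases, so the call now relies on the default `errors="raise"` and keeps only the `downcast` hint. A minimal sketch of the remaining behaviour on a made-up column, using only the documented pandas API:

    import pandas as pd

    df = pd.DataFrame({"bmi": ["21.5", "30.1", "25.0"]})
    # Parse strings to numbers; downcast float64 to the smallest float subtype
    df["bmi"] = pd.to_numeric(df["bmi"], downcast="float")
    assert df["bmi"].dtype == "float32"

Unparseable values now raise a ValueError instead of being passed through silently, which is why the call is restricted to the numeric columns in `var_num`.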
73 changes: 2 additions & 71 deletions ehrapy/preprocessing/_imputation.py

@@ -70,7 +70,6 @@ def explicit_impute(
     # 1: Replace all missing values with the specified value
     if isinstance(replacement, (int, str)):
         _replace_explicit(adata.X, replacement, impute_empty_strings)
-        logg.debug(f"Imputed missing values in the AnnData object by `{replacement}`")
 
     # 2: Replace all missing values in a subset of columns with a specified value per column or a default value, when the column is not explicitly named
     elif isinstance(replacement, dict):
@@ -81,9 +80,6 @@
                 _replace_explicit(adata.X[:, idx : idx + 1], imputation_value, impute_empty_strings)
             else:
                 print(f"[bold yellow]No replace value passed and found for var [not bold green]{column_name}.")
-        logg.debug(
-            f"Imputed missing values in columns `{replacement.keys()}` by `{replacement.values()}` respectively."
-        )
     else:
         raise ValueError(  # pragma: no cover
             f"Type {type(replacement)} is not a valid datatype for replacement parameter. Either use int, str or a dict!"
@@ -165,16 +161,14 @@ def simple_impute(
     if strategy in {"median", "mean"}:
         try:
             _simple_impute(adata, var_names, strategy)
-            logg.debug(f"Imputed the AnnData object using `{strategy}` Imputation.")
         except ValueError:
             raise ValueError(
                 f"Can only impute numerical data using {strategy} strategy. Try to restrict imputation"
                 "to certain columns using var_names parameter or use a different mode."
             ) from None
-    # most_frequent imputation works with non numerical data as well
+    # most_frequent imputation works with non-numerical data as well
     elif strategy == "most_frequent":
         _simple_impute(adata, var_names, strategy)
-        logg.debug("Imputed the AnnData object using `most_frequent` Imputation.")
     # unknown simple imputation strategy
     else:
         raise ValueError(  # pragma: no cover
@@ -272,21 +266,11 @@ def knn_impute(
     if _check_module_importable("sklearnex"):  # pragma: no cover
         unpatch_sklearn()
 
-    if var_names:
-        logg.debug(
-            f"Imputed the columns `{var_names}` in the AnnData object using kNN Imputation with {n_neighbours} neighbours considered."
-        )
-    elif not var_names:
-        logg.debug(
-            f"Imputed the data in the AnnData object using kNN Imputation with {n_neighbours} neighbours considered."
-        )
-
     if copy:
         return adata
 
 
 def _knn_impute(adata: AnnData, var_names: Iterable[str] | None, n_neighbours: int) -> None:
-    """Utility function to impute data using KNN-Imputer"""
     from sklearn.impute import KNNImputer
 
     imputer = KNNImputer(n_neighbors=n_neighbours)
@@ -428,13 +412,6 @@ def miss_forest_impute(
     if _check_module_importable("sklearnex"):  # pragma: no cover
         unpatch_sklearn()
 
-    if var_names:
-        logg.debug(
-            f"Imputed the columns `{var_names}` in the AnnData object with MissForest Imputation using {num_initial_strategy} strategy."
-        )
-    elif not var_names:
-        logg.debug("Imputed the data in the AnnData object using MissForest Imputation.")
-
     if copy:
         return adata
 
@@ -535,15 +512,6 @@
     # decode ordinal encoding to obtain imputed original data
     adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])
 
-    if var_names:
-        logg.debug(
-            f"Imputed the columns `{var_names}` in the AnnData object using Soft Imputation with shrinkage value of `{shrinkage_value}`."
-        )
-    elif not var_names:
-        logg.debug(
-            f"Imputed the data in the AnnData object using Soft Imputation with shrinkage value of `{shrinkage_value}`."
-        )
-
     return adata
 
 
@@ -561,7 +529,6 @@ def _soft_impute(
     normalizer,
     verbose,
 ) -> None:
-    """Utility function to impute data using SoftImpute"""
     from fancyimpute import SoftImpute
 
     imputer = SoftImpute(
@@ -690,11 +657,6 @@ def iterative_svd_impute(
     # decode ordinal encoding to obtain imputed original data
     adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])
 
-    if var_names:
-        logg.debug(f"Imputed the columns `{var_names}` in the AnnData object using IterativeSVD Imputation.")
-    elif not var_names:
-        logg.debug("Imputed the data in the AnnData object using IterativeSVD Imputation.")
-
     return adata
 
 
@@ -711,7 +673,6 @@ def _iterative_svd_impute(
     max_value,
     verbose,
 ) -> None:
-    """Utility function to impute data using IterativeSVD"""
     from fancyimpute import IterativeSVD
 
     imputer = IterativeSVD(
@@ -773,7 +734,7 @@ def matrix_factorization_impute(
             Defaults to None.
         max_value: The maximum value allowed for the imputed data. Any imputed value greater than `max_value` is clipped to `max_value`.
             Defaults to None.
-        verbose: Whether or not to printout training progress. Defaults to False.
+        verbose: Whether to printout training progress. Defaults to False.
         copy: Whether to return a copy or act in place. Defaults to False.
 
     Returns:
@@ -827,15 +788,6 @@ def matrix_factorization_impute(
     adata.X = adata.X.astype("object")
     adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])
 
-    if var_names:
-        logg.debug(
-            f"Imputed the columns `{var_names}` in the AnnData object using MatrixFactorization Imputation with learning rate `{learning_rate}` and shrinkage value `{shrinkage_value}`."
-        )
-    elif not var_names:
-        logg.debug(
-            f"Imputed the data in the AnnData object using MatrixFactorization Imputation with learning rate `{learning_rate}` and shrinkage value `{shrinkage_value}`."
-        )
-
     return adata
 
 
@@ -850,7 +802,6 @@ def _matrix_factorization_impute(
     max_value,
     verbose,
 ) -> None:
-    """Utility function to impute data using MatrixFactorization"""
     from fancyimpute import MatrixFactorization
 
     imputer = MatrixFactorization(
@@ -949,15 +900,6 @@ def nuclear_norm_minimization_impute(
     # decode ordinal encoding to obtain imputed original data
     adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])
 
-    if var_names:
-        logg.debug(
-            f"Imputed the columns `{var_names}` in the AnnData object using NuclearNormMinimization Imputation with error tolerance of `{error_tolerance}`."
-        )
-    elif not var_names:
-        logg.debug(
-            f"Imputed the data in the AnnData object using NuclearNormMinimization Imputation with error tolerance of `{error_tolerance}`."
-        )
-
     return adata
 
 
@@ -971,7 +913,6 @@ def _nuclear_norm_minimization_impute(
     max_iters,
     verbose,
 ) -> None:
-    """Utility function to impute data using NuclearNormMinimization"""
     from fancyimpute import NuclearNormMinimization
 
     imputer = NuclearNormMinimization(
@@ -1079,22 +1020,12 @@ def mice_forest_impute(
         print("[bold red]Check that your matrix does not contain any NaN only columns!")
         raise
 
-    if var_names:
-        logg.debug(
-            f"Imputed the columns `{var_names}` in the AnnData object using MiceForest Imputation with `{iterations}` iterations."
-        )
-    elif not var_names:
-        logg.debug(
-            f"Imputed the data in the AnnData object using MiceForest Imputation with `{iterations}` iterations."
-        )
-
     return adata
 
 
 def _miceforest_impute(
     adata, var_names, save_all_iterations, random_state, inplace, iterations, variable_parameters, verbose
 ) -> None:
-    """Utility function to impute data using miceforest"""
     import miceforest as mf
 
     if isinstance(var_names, Iterable):
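The `_knn_impute` helper retained above delegates to scikit-learn's `KNNImputer`, as the import in the hunk shows. A minimal self-contained sketch of that underlying call, on a small made-up matrix:

    import numpy as np
    from sklearn.impute import KNNImputer

    X = np.array([[1.0, 2.0], [np.nan, 3.0], [7.0, 6.0]])
    # Each NaN is replaced by the mean of that feature across the
    # n_neighbors nearest rows, using NaN-aware Euclidean distance.
    imputer = KNNImputer(n_neighbors=2)
    X_imputed = imputer.fit_transform(X)

The removed `logg.debug` calls only reported which columns were imputed; the imputation logic of each function is unchanged.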
57 changes: 23 additions & 34 deletions ehrapy/preprocessing/_outliers.py

@@ -7,14 +7,17 @@
 import scipy.stats.mstats
 
 if TYPE_CHECKING:
+    from collections.abc import Collection
+
     from anndata import AnnData
 
 
 def winsorize(
     adata: AnnData,
-    vars: str | list[str] | set[str] = None,
-    obs_cols: str | list[str] | set[str] = None,
-    limits: list[float] = None,
+    vars: Collection[str] = None,
+    obs_cols: Collection[str] = None,
+    *,
+    limits: tuple[float, float] = (0.01, 0.99),
     copy: bool = False,
     **kwargs,
 ) -> AnnData:
@@ -23,12 +26,12 @@ def winsorize(
     The implementation is based on https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.mstats.winsorize.html
 
     Args:
-        adata: AnnData object to winsorize
-        vars: The features to winsorize.
-        obs_cols: Columns in obs with features to winsorize.
+        adata: AnnData object to winsorize.
+        vars: The features to winsorize. Defaults to None.
+        obs_cols: Columns in obs with features to winsorize. Defaults to None.
         limits: Tuple of the percentages to cut on each side of the array as floats between 0. and 1.
             Defaults to (0.01, 0.99)
-        copy: Whether to return a copy or not
+        copy: Whether to return a copy.
         **kwargs: Keywords arguments get passed to scipy.stats.mstats.winsorize
 
     Returns:
@@ -37,7 +40,7 @@ def winsorize(
     Examples:
         >>> import ehrapy as ep
        >>> adata = ep.dt.mimic_2(encoded=True)
-        >>> ep.pp.winsorize(adata, ["bmi"])
+        >>> ep.pp.winsorize(adata, vars=["bmi"])
     """
     if copy:  # pragma: no cover
         adata = adata.copy()
@@ -61,22 +64,21 @@
 
 def clip_quantile(
     adata: AnnData,
-    limits: list[float],
-    vars: str | list[str] | set[str] = None,
-    obs_cols: str | list[str] | set[str] = None,
+    limits: tuple[float, float],
+    vars: Collection[str] = None,
+    obs_cols: Collection[str] = None,
+    *,
     copy: bool = False,
 ) -> AnnData:
     """Clips (limits) features.
 
     Given an interval, values outside the interval are clipped to the interval edges.
     The implementation is based on https://numpy.org/doc/stable/reference/generated/numpy.clip.html
 
     Args:
-        adata: The AnnData object
-        vars: Columns in var with features to clip
+        adata: The AnnData object to clip.
+        limits: Values outside the interval are clipped to the interval edges.
+        vars: Columns in var with features to clip.
         obs_cols: Columns in obs with features to clip
-        limits: Interval, values outside of which are clipped to the interval edges
         copy: Whether to return a copy of AnnData or not
 
     Returns:
@@ -85,7 +87,7 @@ def clip_quantile(
     Examples:
         >>> import ehrapy as ep
         >>> adata = ep.dt.mimic_2(encoded=True)
-        >>> ep.pp.clip_quantile(adata, ["bmi"])
+        >>> ep.pp.clip_quantile(adata, vars=["bmi"])
     """
     obs_cols, vars = _validate_outlier_input(adata, obs_cols, vars)  # type: ignore
 
@@ -106,23 +108,10 @@
     return adata
 
 
-def _validate_outlier_input(
-    adata, obs_cols: str | list[str] | set[str], vars: str | list[str] | set[str]
-) -> tuple[set[str], set[str]]:
-    """Validates the obs/var columns for outlier preprocessing.
-
-    Args:
-        adata: AnnData object
-        obs_cols: str or list of obs columns
-        vars: str or list of var names
-
-    Returns:
-        A tuple of lists of obs/var columns
-    """
-    if isinstance(vars, str) or isinstance(vars, list):  # pragma: no cover
-        vars = set(vars)
-    if isinstance(obs_cols, str) or isinstance(obs_cols, list):  # pragma: no cover
-        obs_cols = set(obs_cols)
+def _validate_outlier_input(adata, obs_cols: Collection[str], vars: Collection[str]) -> tuple[set[str], set[str]]:
+    """Validates the obs/var columns for outlier preprocessing."""
+    vars = set(vars) if vars else set()
+    obs_cols = set(obs_cols) if obs_cols else set()
 
     if vars is not None:
         diff = vars - set(adata.var_names)
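Per their docstrings, `winsorize` wraps `scipy.stats.mstats.winsorize` and `clip_quantile` follows `numpy.clip`. A minimal sketch of the two underlying operations on a plain array; the data and the quantile choice here are illustrative only, not necessarily how ehrapy derives the clip interval:

    import numpy as np
    import scipy.stats.mstats

    x = np.array([1.0, 2.0, 3.0, 4.0, 100.0])

    # Winsorizing: the most extreme values in each tail are replaced by
    # the nearest kept value, e.g. 100.0 -> 4.0 with 20% cut per side.
    winsorized = scipy.stats.mstats.winsorize(x, limits=(0.2, 0.2))

    # Quantile clipping: values outside the interval edges are set to the edges.
    low, high = np.quantile(x, [0.01, 0.99])
    clipped = np.clip(x, low, high)

The refactored `_validate_outlier_input` also normalizes `None` for `vars` or `obs_cols` to an empty set up front, replacing the previous isinstance checks.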