From d3d6d3c74d6b80d7858c657897341865bc1028c0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 24 Nov 2024 13:15:07 +0100 Subject: [PATCH 01/12] [pre-commit.ci] pre-commit autoupdate (#826) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.7.3 → v0.7.4](https://github.com/astral-sh/ruff-pre-commit/compare/v0.7.3...v0.7.4) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4b3d9285..edd6d21f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: hooks: - id: prettier - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.7.3 + rev: v0.7.4 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes] From f05adda01099fc2c29e99bf0a78b49b6b2f4cb82 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Nov 2024 13:18:00 -0800 Subject: [PATCH 02/12] [pre-commit.ci] pre-commit autoupdate (#828) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.7.4 → v0.8.0](https://github.com/astral-sh/ruff-pre-commit/compare/v0.7.4...v0.8.0) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index edd6d21f..0a5b6429 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: hooks: - id: prettier - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.7.4 + rev: v0.8.0 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes] From 67fedbf59c20e6d9cb52212e6c74a742995b2df1 Mon Sep 17 00:00:00 2001 From: Nicolas Sidoux <145907251+nicolassidoux@users.noreply.github.com> Date: Mon, 25 Nov 2024 23:28:03 +0100 Subject: [PATCH 03/12] Make all imputation methods consistent in regard to encoding requirements (#827) * Before test * After tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review part 1 Co-authored-by: Lukas Heumos * @nicolassidoux @Zethson Apply suggestions from code review part 2 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Updated _base_check_imputation to throw exception * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added spinner support * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * After @eroell review * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Changed spinner to Rich * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed missing import * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * After @eroell review * Updated returns in imputation, rewrote miss_forest_impute * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed imputation 
returns --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Lukas Heumos Co-authored-by: PRECIPOINT\nicolas.sidoux --- ehrapy/_settings.py | 2 +- ..._tool_available.py => _utils_available.py} | 6 +- ehrapy/{_doc_util.py => _utils_doc.py} | 0 ehrapy/_utils_rendering.py | 21 + ehrapy/anndata/anndata_ext.py | 52 ++- ehrapy/plot/_scanpy_pl_api.py | 2 +- ehrapy/preprocessing/_imputation.py | 374 +++++++----------- ehrapy/preprocessing/_normalization.py | 6 +- tests/anndata/test_anndata_ext.py | 16 + tests/preprocessing/test_imputation.py | 236 +++++++---- .../test_utils_available.py} | 4 +- 11 files changed, 407 insertions(+), 312 deletions(-) rename ehrapy/{core/_tool_available.py => _utils_available.py} (79%) rename ehrapy/{_doc_util.py => _utils_doc.py} (100%) create mode 100644 ehrapy/_utils_rendering.py rename tests/{core/_test_tool_available.py => utils/test_utils_available.py} (90%) diff --git a/ehrapy/_settings.py b/ehrapy/_settings.py index 3547af0a..f733c059 100644 --- a/ehrapy/_settings.py +++ b/ehrapy/_settings.py @@ -53,7 +53,7 @@ def __init__( figdir: str | Path = "./figures/", cache_compression: str | None = "lzf", max_memory=15, - n_jobs: int = 1, + n_jobs: int = -1, logfile: str | Path | None = None, categories_to_ignore: Iterable[str] = ("N/A", "dontknow", "no_gate", "?"), _frameon: bool = True, diff --git a/ehrapy/core/_tool_available.py b/ehrapy/_utils_available.py similarity index 79% rename from ehrapy/core/_tool_available.py rename to ehrapy/_utils_available.py index 75153b41..7b116681 100644 --- a/ehrapy/core/_tool_available.py +++ b/ehrapy/_utils_available.py @@ -4,7 +4,7 @@ from subprocess import PIPE, Popen -def _check_module_importable(package: str) -> bool: # pragma: no cover +def _check_module_importable(package: str) -> bool: """Checks whether a module is installed and can be loaded. Args: @@ -19,7 +19,7 @@ def _check_module_importable(package: str) -> bool: # pragma: no cover return module_available -def _shell_command_accessible(command: list[str]) -> bool: # pragma: no cover +def _shell_command_accessible(command: list[str]) -> bool: """Checks whether the provided command is accessible in the current shell. Args: @@ -29,7 +29,7 @@ def _shell_command_accessible(command: list[str]) -> bool: # pragma: no cover True if the command is accessible, False otherwise. 
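+    Example (an illustrative doctest; assumes a ``git`` executable is on the PATH, and the
+    exact outcome depends on the host system):
+        >>> _shell_command_accessible(["git", "--version"])  # doctest: +SKIP
+        True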
""" command_accessible = Popen(command, stdout=PIPE, stderr=PIPE, universal_newlines=True, shell=True) - (commmand_stdout, command_stderr) = command_accessible.communicate() + command_accessible.communicate() if command_accessible.returncode != 0: return False diff --git a/ehrapy/_doc_util.py b/ehrapy/_utils_doc.py similarity index 100% rename from ehrapy/_doc_util.py rename to ehrapy/_utils_doc.py diff --git a/ehrapy/_utils_rendering.py b/ehrapy/_utils_rendering.py new file mode 100644 index 00000000..43596c54 --- /dev/null +++ b/ehrapy/_utils_rendering.py @@ -0,0 +1,21 @@ +import functools + +from rich.progress import Progress, SpinnerColumn + + +def spinner(message: str = "Running task"): + def wrap(func): + @functools.wraps(func) + def wrapped_f(*args, **kwargs): + with Progress( + "[progress.description]{task.description}", + SpinnerColumn(), + refresh_per_second=1500, + ) as progress: + progress.add_task(f"[blue]{message}", total=1) + result = func(*args, **kwargs) + return result + + return wrapped_f + + return wrap diff --git a/ehrapy/anndata/anndata_ext.py b/ehrapy/anndata/anndata_ext.py index 36bdd6a5..a82721d3 100644 --- a/ehrapy/anndata/anndata_ext.py +++ b/ehrapy/anndata/anndata_ext.py @@ -3,7 +3,7 @@ import random from collections import OrderedDict from string import ascii_letters -from typing import TYPE_CHECKING, NamedTuple +from typing import TYPE_CHECKING, Any, NamedTuple import numpy as np import pandas as pd @@ -303,7 +303,7 @@ def move_to_x(adata: AnnData, to_x: list[str] | str) -> AnnData: return new_adata -def _get_column_indices(adata: AnnData, col_names: str | Iterable[str]) -> list[int]: +def get_column_indices(adata: AnnData, col_names: str | Iterable[str]) -> list[int]: """Fetches the column indices in X for a given list of column names Args: @@ -383,7 +383,7 @@ def set_numeric_vars( if copy: adata = adata.copy() - vars_idx = _get_column_indices(adata, vars) + vars_idx = get_column_indices(adata, vars) adata.X[:, vars_idx] = values @@ -663,3 +663,49 @@ def get_rank_features_df( class NotEncodedError(AssertionError): pass + + +def _are_ndarrays_equal(arr1: np.ndarray, arr2: np.ndarray) -> np.bool_: + """Check if two arrays are equal member-wise. + + Note: Two NaN are considered equal. + + Args: + arr1: First array to compare + arr2: Second array to compare + + Returns: + True if the two arrays are equal member-wise + """ + return np.all(np.equal(arr1, arr2, dtype=object) | ((arr1 != arr1) & (arr2 != arr2))) + + +def _is_val_missing(data: np.ndarray) -> np.ndarray[Any, np.dtype[np.bool_]]: + """Check if values in a AnnData matrix are missing. + + Args: + data: The AnnData matrix to check + + Returns: + An array of bool representing the missingness of the original data, with the same shape + """ + return np.isin(data, [None, ""]) | (data != data) + + +def _to_dense_matrix(adata: AnnData, layer: str | None = None) -> np.ndarray: # pragma: no cover + """Extract a layer from an AnnData object and convert it to a dense matrix if required. + + Args: + adata: The AnnData where to extract the layer from. + layer: Name of the layer to extract. If omitted, X is considered. + + Returns: + The layer as a dense matrix. If a conversion was required, this function returns a copy of the original layer, + othersize this function returns a reference. 
+ """ + from scipy.sparse import issparse + + if layer is None: + return adata.X.toarray() if issparse(adata.X) else adata.X + else: + return adata.layers[layer].toarray() if issparse(adata.layers[layer]) else adata.layers[layer] diff --git a/ehrapy/plot/_scanpy_pl_api.py b/ehrapy/plot/_scanpy_pl_api.py index d09ca8ed..a1d032a2 100644 --- a/ehrapy/plot/_scanpy_pl_api.py +++ b/ehrapy/plot/_scanpy_pl_api.py @@ -9,7 +9,7 @@ import scanpy as sc from scanpy.plotting import DotPlot, MatrixPlot, StackedViolin -from ehrapy._doc_util import ( +from ehrapy._utils_doc import ( _doc_params, doc_adata_color_etc, doc_common_groupby_plot_args, diff --git a/ehrapy/preprocessing/_imputation.py b/ehrapy/preprocessing/_imputation.py index 6ff60a2d..60facfeb 100644 --- a/ehrapy/preprocessing/_imputation.py +++ b/ehrapy/preprocessing/_imputation.py @@ -7,22 +7,20 @@ import numpy as np import pandas as pd from lamin_utils import logger -from rich import print -from rich.progress import Progress, SpinnerColumn -from sklearn.experimental import enable_iterative_imputer # required to enable IterativeImputer (experimental feature) +from sklearn.experimental import enable_iterative_imputer # noinspection PyUnresolvedReference from sklearn.impute import SimpleImputer -from sklearn.preprocessing import OrdinalEncoder from ehrapy import settings +from ehrapy._utils_available import _check_module_importable +from ehrapy._utils_rendering import spinner from ehrapy.anndata import check_feature_types -from ehrapy.anndata._constants import CATEGORICAL_TAG, FEATURE_TYPE_KEY -from ehrapy.anndata.anndata_ext import _get_column_indices -from ehrapy.core._tool_available import _check_module_importable +from ehrapy.anndata.anndata_ext import get_column_indices if TYPE_CHECKING: from anndata import AnnData +@spinner("Performing explicit impute") def explicit_impute( adata: AnnData, replacement: (str | int) | (dict[str, str | int]), @@ -30,7 +28,7 @@ def explicit_impute( impute_empty_strings: bool = True, warning_threshold: int = 70, copy: bool = False, -) -> AnnData: +) -> AnnData | None: """Replaces all missing values in all columns or a subset of columns specified by the user with the passed replacement value. There are two scenarios to cover: @@ -47,7 +45,7 @@ def explicit_impute( Returns: If copy is True, a modified copy of the original AnnData object with imputed X. - If copy is False, the original AnnData object is modified in place. + If copy is False, the original AnnData object is modified in place, and None is returned. 
Examples: Replace all missing values in adata with the value 0: @@ -56,7 +54,7 @@ def explicit_impute( >>> adata = ep.dt.mimic_2(encoded=True) >>> ep.pp.explicit_impute(adata, replacement=0) """ - if copy: # pragma: no cover + if copy: adata = adata.copy() if isinstance(replacement, int) or isinstance(replacement, str): @@ -64,32 +62,25 @@ def explicit_impute( else: _warn_imputation_threshold(adata, var_names=replacement.keys(), threshold=warning_threshold) # type: ignore - with Progress( - "[progress.description]{task.description}", - SpinnerColumn(), - refresh_per_second=1500, - ) as progress: - progress.add_task("[blue]Running explicit imputation", total=1) - # 1: Replace all missing values with the specified value - if isinstance(replacement, (int, str)): - _replace_explicit(adata.X, replacement, impute_empty_strings) - - # 2: Replace all missing values in a subset of columns with a specified value per column or a default value, when the column is not explicitly named - elif isinstance(replacement, dict): - for idx, column_name in enumerate(adata.var_names): - imputation_value = _extract_impute_value(replacement, column_name) - # only replace if an explicit value got passed or could be extracted from replacement - if imputation_value: - _replace_explicit(adata.X[:, idx : idx + 1], imputation_value, impute_empty_strings) - else: - logger.warning(f"No replace value passed and found for var [not bold green]{column_name}.") - else: - raise ValueError( # pragma: no cover - f"Type {type(replacement)} is not a valid datatype for replacement parameter. Either use int, str or a dict!" - ) + # 1: Replace all missing values with the specified value + if isinstance(replacement, (int, str)): + _replace_explicit(adata.X, replacement, impute_empty_strings) + + # 2: Replace all missing values in a subset of columns with a specified value per column or a default value, when the column is not explicitly named + elif isinstance(replacement, dict): + for idx, column_name in enumerate(adata.var_names): + imputation_value = _extract_impute_value(replacement, column_name) + # only replace if an explicit value got passed or could be extracted from replacement + if imputation_value: + _replace_explicit(adata.X[:, idx : idx + 1], imputation_value, impute_empty_strings) + else: + logger.warning(f"No replace value passed and found for var [not bold green]{column_name}.") + else: + raise ValueError( # pragma: no cover + f"Type {type(replacement)} is not a valid datatype for replacement parameter. Either use int, str or a dict!" + ) - if copy: - return adata + return adata if copy else None def _replace_explicit(arr: np.ndarray, replacement: str | int, impute_empty_strings: bool) -> None: @@ -119,6 +110,7 @@ def _extract_impute_value(replacement: dict[str, str | int], column_name: str) - return None +@spinner("Performing simple impute") def simple_impute( adata: AnnData, var_names: Iterable[str] | None = None, @@ -126,9 +118,12 @@ def simple_impute( strategy: Literal["mean", "median", "most_frequent"] = "mean", copy: bool = False, warning_threshold: int = 70, -) -> AnnData: +) -> AnnData | None: """Impute missing values in numerical data using mean/median/most frequent imputation. + If required and using mean or median strategy, the data needs to be properly encoded as this imputation requires + numerical data only. + Args: adata: The annotated data matrix to impute missing values on. var_names: A list of column names to apply imputation on (if None, impute all columns). 
@@ -137,13 +132,8 @@ def simple_impute( copy:Whether to return a copy of `adata` or modify it inplace. Returns: - An updated AnnData object with imputed values. - - Raises: - ValueError: - If the selected imputation strategy is not applicable to the data. - ValueError: - If an unknown imputation strategy is provided. + If copy is True, a modified copy of the original AnnData object with imputed X. + If copy is False, the original AnnData object is modified in place, and None is returned. Examples: >>> import ehrapy as ep @@ -155,43 +145,35 @@ def simple_impute( _warn_imputation_threshold(adata, var_names, threshold=warning_threshold) - with Progress( - "[progress.description]{task.description}", - SpinnerColumn(), - refresh_per_second=1500, - ) as progress: - progress.add_task(f"[blue]Running simple imputation with {strategy}", total=1) - # Imputation using median and mean strategy works with numerical data only - if strategy in {"median", "mean"}: - try: - _simple_impute(adata, var_names, strategy) - except ValueError: - raise ValueError( - f"Can only impute numerical data using {strategy} strategy. Try to restrict imputation" - "to certain columns using var_names parameter or use a different mode." - ) from None - # most_frequent imputation works with non-numerical data as well - elif strategy == "most_frequent": + if strategy in {"median", "mean"}: + try: _simple_impute(adata, var_names, strategy) - # unknown simple imputation strategy - else: - raise ValueError( # pragma: no cover - f"Unknown impute strategy {strategy} for simple Imputation. Choose any of mean, median or most_frequent." + except ValueError: + raise ValueError( + f"Can only impute numerical data using {strategy} strategy. Try to restrict imputation " + "to certain columns using var_names parameter or use a different mode." ) from None + # most_frequent imputation works with non-numerical data as well + elif strategy == "most_frequent": + _simple_impute(adata, var_names, strategy) + else: + raise ValueError( + f"Unknown impute strategy {strategy} for simple Imputation. Choose any of mean, median or most_frequent." + ) from None - if copy: - return adata + return adata if copy else None def _simple_impute(adata: AnnData, var_names: Iterable[str] | None, strategy: str) -> None: imputer = SimpleImputer(strategy=strategy) - if isinstance(var_names, Iterable): - column_indices = _get_column_indices(adata, var_names) + if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names): + column_indices = get_column_indices(adata, var_names) adata.X[::, column_indices] = imputer.fit_transform(adata.X[::, column_indices]) else: adata.X = imputer.fit_transform(adata.X) +@spinner("Performing KNN impute") @check_feature_types def knn_impute( adata: AnnData, @@ -206,9 +188,7 @@ def knn_impute( ) -> AnnData: """Imputes missing values in the input AnnData object using K-nearest neighbor imputation. - When using KNN Imputation with mixed data (non-numerical and numerical), encoding using ordinal encoding is required - since KNN Imputation can only work on numerical data. The encoding itself is just a utility and will be undone once - imputation ran successfully. + If required, the data needs to be properly encoded as this imputation requires numerical data only. .. warning:: Currently, both `n_neighbours` and `n_neighbors` are accepted as parameters for the number of neighbors. @@ -234,10 +214,8 @@ def knn_impute( kwargs: Gathering keyword arguments of earlier ehrapy versions for backwards compatibility. 
It is encouraged to use the here listed, current arguments. Returns: - An updated AnnData object with imputed values. - - Raises: - ValueError: If the input data matrix contains only categorical (non-numeric) values. + If copy is True, a modified copy of the original AnnData object with imputed X. + If copy is False, the original AnnData object is modified in place, and None is returned. Examples: >>> import ehrapy as ep @@ -274,40 +252,26 @@ def knn_impute( from sklearnex import patch_sklearn, unpatch_sklearn patch_sklearn() + try: - with Progress( - "[progress.description]{task.description}", - SpinnerColumn(), - refresh_per_second=1500, - ) as progress: - progress.add_task("[blue]Running KNN imputation", total=1) - # numerical only data needs no encoding since KNN Imputation can be applied directly - if np.issubdtype(adata.X.dtype, np.number): - _knn_impute(adata, var_names, n_neighbors, backend=backend, **backend_kwargs) - else: - # ordinal encoding is used since non-numerical data can not be imputed using KNN Imputation - enc = OrdinalEncoder() - column_indices = adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG - adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices]) - # impute the data using KNN imputation - _knn_impute(adata, var_names, n_neighbors, backend=backend, **backend_kwargs) - # imputing on encoded columns might result in float numbers; those can not be decoded - # cast them to int to ensure they can be decoded - adata.X[::, column_indices] = np.rint(adata.X[::, column_indices]).astype(int) - # knn imputer transforms X dtype to numerical (encoded), but object is needed for decoding - adata.X = adata.X.astype("object") - # decode ordinal encoding to obtain imputed original data - adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices]) + if np.issubdtype(adata.X.dtype, np.number): + _knn_impute(adata, var_names, n_neighbors, backend=backend, **backend_kwargs) + else: + # Raise exception since non-numerical data can not be imputed using KNN Imputation + raise ValueError( + "Can only impute numerical data. Try to restrict imputation to certain columns using " + "var_names parameter or perform an encoding of your data." 
+ ) + except ValueError as e: if "Data matrix has wrong shape" in str(e): logger.error("Check that your matrix does not contain any NaN only columns!") - raise + raise if _check_module_importable("sklearnex"): # pragma: no cover unpatch_sklearn() - if copy: - return adata + return adata if copy else None def _knn_impute( @@ -326,8 +290,8 @@ def _knn_impute( imputer = FaissImputer(n_neighbors=n_neighbors, **kwargs) - if isinstance(var_names, Iterable): - column_indices = _get_column_indices(adata, var_names) + if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names): + column_indices = get_column_indices(adata, var_names) adata.X[::, column_indices] = imputer.fit_transform(adata.X[::, column_indices]) # this is required since X dtype has to be numerical in order to correctly round floats adata.X = adata.X.astype("float64") @@ -335,17 +299,18 @@ def _knn_impute( adata.X = imputer.fit_transform(adata.X) +@spinner("Performing miss-forest impute") def miss_forest_impute( adata: AnnData, - var_names: dict[str, list[str]] | list[str] | None = None, + var_names: Iterable[str] | None = None, *, num_initial_strategy: Literal["mean", "median", "most_frequent", "constant"] = "mean", max_iter: int = 3, - n_estimators=100, + n_estimators: int = 100, random_state: int = 0, warning_threshold: int = 70, copy: bool = False, -) -> AnnData: +) -> AnnData | None: """Impute data using the MissForest strategy. This function uses the MissForest strategy to impute missing values in the data matrix of an AnnData object. @@ -353,12 +318,12 @@ def miss_forest_impute( and using the trained model to predict the missing values. See https://academic.oup.com/bioinformatics/article/28/1/112/219101. - This requires the computation of which columns in X contain numerical only (including NaNs) and which contain non-numerical data. + + If required, the data needs to be properly encoded as this imputation requires numerical data only. Args: adata: The AnnData object to use MissForest Imputation on. - var_names: List of columns to impute or a dict with two keys ('numerical' and 'non_numerical') indicating which var - contain mixed data and which numerical data only. + var_names: Iterable of columns to impute num_initial_strategy: The initial strategy to replace all missing numerical values with. max_iter: The maximum number of iterations if the stop criterion has not been met yet. n_estimators: The number of trees to fit for every missing variable. Has a big effect on the run time. @@ -368,21 +333,20 @@ def miss_forest_impute( copy: Whether to return a copy or act in place. Returns: - The imputed (but unencoded) AnnData object. + If copy is True, a modified copy of the original AnnData object with imputed X. + If copy is False, the original AnnData object is modified in place, and None is returned. 
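+    Since ``var_names`` is now a plain iterable of column names (the earlier dict with "numerical"
+    and "non_numerical" keys is no longer accepted), imputation can be restricted to a subset of
+    numerical features (the column names in this sketch are illustrative):
+        >>> import ehrapy as ep
+        >>> adata = ep.dt.mimic_2(encoded=True)
+        >>> ep.pp.miss_forest_impute(adata, var_names=["weight_first", "bmi"])  # doctest: +SKIP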
Examples: >>> import ehrapy as ep >>> adata = ep.dt.mimic_2(encoded=True) >>> ep.pp.miss_forest_impute(adata) """ - if copy: # pragma: no cover + if copy: adata = adata.copy() if var_names is None: _warn_imputation_threshold(adata, list(adata.var_names), threshold=warning_threshold) - elif isinstance(var_names, dict): - _warn_imputation_threshold(adata, var_names.keys(), threshold=warning_threshold) # type: ignore - elif isinstance(var_names, list): + elif isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names): _warn_imputation_threshold(adata, var_names, threshold=warning_threshold) if _check_module_importable("sklearnex"): # pragma: no cover @@ -394,74 +358,49 @@ def miss_forest_impute( from sklearn.impute import IterativeImputer try: - with Progress( - "[progress.description]{task.description}", - SpinnerColumn(), - refresh_per_second=1500, - ) as progress: - progress.add_task("[blue]Running MissForest imputation", total=1) - - if settings.n_jobs == 1: # pragma: no cover - logger.warning("The number of jobs is only 1. To decrease the runtime set ep.settings.n_jobs=-1.") - - imp_num = IterativeImputer( - estimator=ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs), - initial_strategy=num_initial_strategy, - max_iter=max_iter, - random_state=random_state, - ) - # initial strategy here will not be parametrized since only most_frequent will be applied to non numerical data - imp_cat = IterativeImputer( - estimator=RandomForestClassifier(n_estimators=n_estimators, n_jobs=settings.n_jobs), - initial_strategy="most_frequent", - max_iter=max_iter, - random_state=random_state, + imp_num = IterativeImputer( + estimator=ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs), + initial_strategy=num_initial_strategy, + max_iter=max_iter, + random_state=random_state, + ) + # initial strategy here will not be parametrized since only most_frequent will be applied to non numerical data + IterativeImputer( + estimator=RandomForestClassifier(n_estimators=n_estimators, n_jobs=settings.n_jobs), + initial_strategy="most_frequent", + max_iter=max_iter, + random_state=random_state, + ) + + if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names): # type: ignore + num_indices = get_column_indices(adata, var_names) + else: + num_indices = get_column_indices(adata, adata.var_names) + + if set(num_indices).issubset(_get_non_numerical_column_indices(adata.X)): + raise ValueError( + "Can only impute numerical data. Try to restrict imputation to certain columns using " + "var_names parameter." ) - if isinstance(var_names, list): - var_indices = _get_column_indices(adata, var_names) # type: ignore - adata.X[::, var_indices] = imp_num.fit_transform(adata.X[::, var_indices]) - elif isinstance(var_names, dict) or var_names is None: - if var_names: - try: - non_num_vars = var_names["non_numerical"] - num_vars = var_names["numerical"] - except KeyError: # pragma: no cover - raise ValueError( - "One or both of your keys provided for var_names are unknown. Only " - "numerical and non_numerical are available!" 
- ) from None - non_num_indices = _get_column_indices(adata, non_num_vars) - num_indices = _get_column_indices(adata, num_vars) - - # infer non numerical and numerical indices automatically - else: - non_num_indices_set = _get_non_numerical_column_indices(adata.X) - num_indices = [idx for idx in range(adata.X.shape[1]) if idx not in non_num_indices_set] - non_num_indices = list(non_num_indices_set) - - # encode all non numerical columns - if non_num_indices: - enc = OrdinalEncoder() - adata.X[::, non_num_indices] = enc.fit_transform(adata.X[::, non_num_indices]) - # this step is the most expensive one and might extremely slow down the impute process - if num_indices: - adata.X[::, num_indices] = imp_num.fit_transform(adata.X[::, num_indices]) - if non_num_indices: - adata.X[::, non_num_indices] = imp_cat.fit_transform(adata.X[::, non_num_indices]) - adata.X[::, non_num_indices] = enc.inverse_transform(adata.X[::, non_num_indices]) + # this step is the most expensive one and might extremely slow down the impute process + if num_indices: + adata.X[::, num_indices] = imp_num.fit_transform(adata.X[::, num_indices]) + else: + raise ValueError("Cannot find any feature to perform imputation") + except ValueError as e: if "Data matrix has wrong shape" in str(e): logger.error("Check that your matrix does not contain any NaN only columns!") - raise + raise if _check_module_importable("sklearnex"): # pragma: no cover unpatch_sklearn() - if copy: - return adata + return adata if copy else None +@spinner("Performing mice-forest impute") @check_feature_types def mice_forest_impute( adata: AnnData, @@ -475,12 +414,14 @@ def mice_forest_impute( variable_parameters: dict | None = None, verbose: bool = False, copy: bool = False, -) -> AnnData: +) -> AnnData | None: """Impute data using the miceforest. See https://github.com/AnotherSamWilson/miceforest Fast, memory efficient Multiple Imputation by Chained Equations (MICE) with lightgbm. + If required, the data needs to be properly encoded as this imputation requires numerical data only. + Args: adata: The AnnData object containing the data to impute. var_names: A list of variable names to impute. If None, impute all variables. @@ -497,7 +438,8 @@ def mice_forest_impute( copy: Whether to return a copy of the AnnData object or modify it in-place. Returns: - The imputed AnnData object. + If copy is True, a modified copy of the original AnnData object with imputed X. + If copy is False, the original AnnData object is modified in place, and None is returned. 
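+    The chained-equation loop can be tuned through ``iterations`` and made reproducible through
+    ``random_state`` (the values in this sketch are arbitrary):
+        >>> import ehrapy as ep
+        >>> adata = ep.dt.mimic_2(encoded=True)
+        >>> ep.pp.mice_forest_impute(adata, iterations=10, random_state=21)  # doctest: +SKIP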
Examples: >>> import ehrapy as ep @@ -509,49 +451,31 @@ def mice_forest_impute( adata = adata.copy() _warn_imputation_threshold(adata, var_names, threshold=warning_threshold) + try: - with Progress( - "[progress.description]{task.description}", - SpinnerColumn(), - refresh_per_second=1500, - ) as progress: - progress.add_task("[blue]Running miceforest", total=1) - if np.issubdtype(adata.X.dtype, np.number): - _miceforest_impute( - adata, - var_names, - save_all_iterations_data, - random_state, - inplace, - iterations, - variable_parameters, - verbose, - ) - else: - # ordinal encoding is used since non-numerical data can not be imputed using miceforest - enc = OrdinalEncoder() - column_indices = adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG - adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices]) - # impute the data using miceforest - _miceforest_impute( - adata, - var_names, - save_all_iterations_data, - random_state, - inplace, - iterations, - variable_parameters, - verbose, - ) - adata.X = adata.X.astype("object") - # decode ordinal encoding to obtain imputed original data - adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices]) + if np.issubdtype(adata.X.dtype, np.number): + _miceforest_impute( + adata, + var_names, + save_all_iterations_data, + random_state, + inplace, + iterations, + variable_parameters, + verbose, + ) + else: + raise ValueError( + "Can only impute numerical data. Try to restrict imputation to certain columns using " + "var_names parameter." + ) + except ValueError as e: if "Data matrix has wrong shape" in str(e): logger.warning("Check that your matrix does not contain any NaN only columns!") - raise + raise - return adata + return adata if copy else None def _miceforest_impute( @@ -562,8 +486,8 @@ def _miceforest_impute( data_df = pd.DataFrame(adata.X, columns=adata.var_names, index=adata.obs_names) data_df = data_df.apply(pd.to_numeric, errors="coerce") - if isinstance(var_names, Iterable): - column_indices = _get_column_indices(adata, var_names) + if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names): + column_indices = get_column_indices(adata, var_names) selected_columns = data_df.iloc[:, column_indices] selected_columns = selected_columns.reset_index(drop=True) @@ -616,27 +540,21 @@ def _warn_imputation_threshold(adata: AnnData, var_names: Iterable[str] | None, return var_name_to_pct -def _get_non_numerical_column_indices(X: np.ndarray) -> set: +def _get_non_numerical_column_indices(arr: np.ndarray) -> set: """Return indices of columns, that contain at least one non-numerical value that is not "Nan".""" - def _is_float_or_nan(val): # pragma: no cover + def _is_float_or_nan(val) -> bool: # pragma: no cover """Check whether a given item is a float or np.nan""" try: - float(val) - except ValueError: - if val is np.nan: - return True + _ = float(val) + return not isinstance(val, bool) + except (ValueError, TypeError): return False - else: - if not isinstance(val, bool): - return True - else: - return False - is_numeric_numpy = np.vectorize(_is_float_or_nan, otypes=[bool]) - mask = np.apply_along_axis(is_numeric_numpy, 0, X) + def _is_float_or_nan_row(row) -> list[bool]: # pragma: no cover + return [_is_float_or_nan(val) for val in row] + mask = np.apply_along_axis(_is_float_or_nan_row, 0, arr) _, column_indices = np.where(~mask) - non_num_indices = set(column_indices) - return non_num_indices + return set(column_indices) diff --git a/ehrapy/preprocessing/_normalization.py 
b/ehrapy/preprocessing/_normalization.py index 6f3d5e1a..4541cef3 100644 --- a/ehrapy/preprocessing/_normalization.py +++ b/ehrapy/preprocessing/_normalization.py @@ -13,8 +13,8 @@ daskml_pp = None from ehrapy.anndata.anndata_ext import ( - _get_column_indices, assert_numeric_vars, + get_column_indices, get_numeric_vars, set_numeric_vars, ) @@ -48,7 +48,7 @@ def _scale_func_group( adata = _prep_adata_norm(adata, copy) - var_idx = _get_column_indices(adata, vars) + var_idx = get_column_indices(adata, vars) var_values = np.take(adata.X, var_idx, axis=1) if group_key is None: @@ -379,7 +379,7 @@ def log_norm( "or offset negative values with ep.pp.offset_negative_values()." ) - var_idx = _get_column_indices(adata, vars) + var_idx = get_column_indices(adata, vars) var_values = np.take(adata.X, var_idx, axis=1) if offset == 1: diff --git a/tests/anndata/test_anndata_ext.py b/tests/anndata/test_anndata_ext.py index ab3b971a..aca4ca77 100644 --- a/tests/anndata/test_anndata_ext.py +++ b/tests/anndata/test_anndata_ext.py @@ -11,7 +11,9 @@ from ehrapy.anndata._constants import CATEGORICAL_TAG, FEATURE_TYPE_KEY, NUMERIC_TAG from ehrapy.anndata.anndata_ext import ( NotEncodedError, + _are_ndarrays_equal, _assert_encoded, + _is_val_missing, anndata_to_df, assert_numeric_vars, delete_from_obs, @@ -500,3 +502,17 @@ def test_set_numeric_vars(adata_strings_encoded): with pytest.raises(NotEncodedError, match=r"not yet been encoded"): set_numeric_vars(adata_strings, values) + + +def test_are_ndarrays_equal(impute_num_adata): + impute_num_adata_copy = impute_num_adata.copy() + assert _are_ndarrays_equal(impute_num_adata.X, impute_num_adata_copy.X) + impute_num_adata_copy.X[0, 0] = 42.0 + assert not _are_ndarrays_equal(impute_num_adata.X, impute_num_adata_copy.X) + + +def test_is_val_missing(impute_num_adata): + assert np.array_equal( + _is_val_missing(impute_num_adata.X), + np.array([[False, False, True], [False, False, False], [True, False, False], [False, False, True]]), + ) diff --git a/tests/preprocessing/test_imputation.py b/tests/preprocessing/test_imputation.py index d35d6055..21379ef0 100644 --- a/tests/preprocessing/test_imputation.py +++ b/tests/preprocessing/test_imputation.py @@ -1,11 +1,14 @@ import os import warnings +from collections.abc import Iterable from pathlib import Path import numpy as np import pytest +from anndata import AnnData from sklearn.exceptions import ConvergenceWarning +from ehrapy.anndata.anndata_ext import _are_ndarrays_equal, _is_val_missing, _to_dense_matrix from ehrapy.preprocessing._imputation import ( _warn_imputation_threshold, explicit_impute, @@ -20,17 +23,127 @@ _TEST_PATH = f"{TEST_DATA_PATH}/imputation" +def _base_check_imputation( + adata_before_imputation: AnnData, + adata_after_imputation: AnnData, + before_imputation_layer: str | None = None, + after_imputation_layer: str | None = None, + imputed_var_names: Iterable[str] | None = None, +): + """Provides a base check for all imputations: + + - Imputation doesn't leave any NaN behind + - Imputation doesn't modify anything in non-imputed columns (if the imputation on a subset was requested) + - Imputation doesn't modify any data that wasn't NaN + + Args: + adata_before_imputation: AnnData before imputation + adata_after_imputation: AnnData after imputation + before_imputation_layer: Layer to consider in the original ``AnnData``, ``X`` if not specified + after_imputation_layer: Layer to consider in the imputed ``AnnData``, ``X`` if not specified + imputed_var_names: Names of the features that were imputed,
will consider all of them if not specified + + Raises: + AssertionError: If any of the checks fail. + """ + + layer_before = _to_dense_matrix(adata_before_imputation, before_imputation_layer) + layer_after = _to_dense_matrix(adata_after_imputation, after_imputation_layer) + + if layer_before.shape != layer_after.shape: + raise AssertionError("The shapes of the two layers do not match") + + var_indices = ( + np.arange(layer_before.shape[1]) + if imputed_var_names is None + else [ + adata_before_imputation.var_names.get_loc(var_name) + for var_name in imputed_var_names + ] + ) + + before_nan_mask = _is_val_missing(layer_before) + imputed_mask = np.zeros(layer_before.shape[1], dtype=bool) + imputed_mask[var_indices] = True + + # Ensure no NaN remains in the imputed columns of layer_after + if np.any(before_nan_mask[:, imputed_mask] & _is_val_missing(layer_after[:, imputed_mask])): + raise AssertionError("NaN found in imputed columns of layer_after.") + + # Ensure unchanged values outside imputed columns + unchanged_mask = ~imputed_mask + if not _are_ndarrays_equal(layer_before[:, unchanged_mask], layer_after[:, unchanged_mask]): + raise AssertionError("Values outside imputed columns were modified.") + + # Ensure imputation does not alter non-NaN values in the imputed columns + imputed_non_nan_mask = (~before_nan_mask) & imputed_mask + if not _are_ndarrays_equal(layer_before[imputed_non_nan_mask], layer_after[imputed_non_nan_mask]): + raise AssertionError("Non-NaN values in imputed columns were modified.") + + # If reaching here: all checks passed + return + + +def test_base_check_imputation_incompatible_shapes(impute_num_adata): + adata_imputed = knn_impute(impute_num_adata, copy=True) + with pytest.raises(AssertionError): + _base_check_imputation(impute_num_adata, adata_imputed[1:, :]) + with pytest.raises(AssertionError): + _base_check_imputation(impute_num_adata, adata_imputed[:, 1:]) + + +def test_base_check_imputation_nan_detected_after_complete_imputation(impute_num_adata): + adata_imputed = knn_impute(impute_num_adata, copy=True) + adata_imputed.X[0, 2] = np.nan + with pytest.raises(AssertionError): + _base_check_imputation(impute_num_adata, adata_imputed) + + +def test_base_check_imputation_nan_detected_after_partial_imputation(impute_num_adata): + var_names = ("col2", "col3") + adata_imputed = knn_impute(impute_num_adata, var_names=var_names, copy=True) + adata_imputed.X[0, 2] = np.nan + with pytest.raises(AssertionError): + _base_check_imputation(impute_num_adata, adata_imputed, imputed_var_names=var_names) + + +def test_base_check_imputation_nan_ignored_if_not_in_imputed_column(impute_num_adata): + var_names = ("col2", "col3") + adata_imputed = knn_impute(impute_num_adata, var_names=var_names, copy=True) + # col1 has a NaN at row 2, should get ignored + _base_check_imputation(impute_num_adata, adata_imputed, imputed_var_names=var_names) + + +def test_base_check_imputation_change_detected_in_non_imputed_column(impute_num_adata): + var_names = ("col2", "col3") + adata_imputed = knn_impute(impute_num_adata, var_names=var_names, copy=True) + # col1 has a NaN at row 2, let's simulate it has been imputed by mistake + adata_imputed.X[2, 0] = 42.0 + with pytest.raises(AssertionError): + _base_check_imputation(impute_num_adata, adata_imputed, imputed_var_names=var_names) + + +def test_base_check_imputation_change_detected_in_imputed_column(impute_num_adata): + adata_imputed = knn_impute(impute_num_adata, copy=True) + # col3 didn't have a NaN at row 1,
let's simulate it has been modified by mistake + adata_imputed.X[1, 2] = 42.0 + with pytest.raises(AssertionError): + _base_check_imputation(impute_num_adata, adata_imputed) + + def test_mean_impute_no_copy(impute_num_adata): + adata_not_imputed = impute_num_adata.copy() simple_impute(impute_num_adata) - assert not np.isnan(impute_num_adata.X).any() + _base_check_imputation(adata_not_imputed, impute_num_adata) def test_mean_impute_copy(impute_num_adata): adata_imputed = simple_impute(impute_num_adata, copy=True) assert id(impute_num_adata) != id(adata_imputed) - assert not np.isnan(adata_imputed.X).any() + _base_check_imputation(impute_num_adata, adata_imputed) def test_mean_impute_throws_error_non_numerical(impute_adata): @@ -39,23 +152,25 @@ def test_mean_impute_throws_error_non_numerical(impute_adata): def test_mean_impute_subset(impute_adata): - adata_imputed = simple_impute(impute_adata, var_names=["intcol", "indexcol"], copy=True) + var_names = ("intcol", "indexcol") + adata_imputed = simple_impute(impute_adata, var_names=var_names, copy=True) - assert not np.all([item != item for item in adata_imputed.X[::, 1:2]]) + _base_check_imputation(impute_adata, adata_imputed, imputed_var_names=var_names) assert np.any([item != item for item in adata_imputed.X[::, 3:4]]) def test_median_impute_no_copy(impute_num_adata): + adata_not_imputed = impute_num_adata.copy() simple_impute(impute_num_adata, strategy="median") - assert not np.isnan(impute_num_adata.X).any() + _base_check_imputation(adata_not_imputed, impute_num_adata) -def test_median_impute_copy(impute_num_adata, impute_adata): +def test_median_impute_copy(impute_num_adata): adata_imputed = simple_impute(impute_num_adata, strategy="median", copy=True) - assert id(impute_adata) != id(adata_imputed) - assert not np.isnan(adata_imputed.X).any() + _base_check_imputation(impute_num_adata, adata_imputed) + assert id(impute_num_adata) != id(adata_imputed) def test_median_impute_throws_error_non_numerical(impute_adata): @@ -64,156 +179,137 @@ def test_median_impute_throws_error_non_numerical(impute_adata): def test_median_impute_subset(impute_adata): - adata_imputed = simple_impute(impute_adata, var_names=["intcol", "indexcol"], strategy="median", copy=True) + var_names = ("intcol", "indexcol") + adata_imputed = simple_impute(impute_adata, var_names=var_names, strategy="median", copy=True) - assert not np.all([item != item for item in adata_imputed.X[::, 1:2]]) - assert np.any([item != item for item in adata_imputed.X[::, 3:4]]) + _base_check_imputation(impute_adata, adata_imputed, imputed_var_names=var_names) def test_most_frequent_impute_no_copy(impute_adata): + adata_not_imputed = impute_adata.copy() simple_impute(impute_adata, strategy="most_frequent") - assert not (np.all([item != item for item in impute_adata.X])) + _base_check_imputation(adata_not_imputed, impute_adata) def test_most_frequent_impute_copy(impute_adata): adata_imputed = simple_impute(impute_adata, strategy="most_frequent", copy=True) + _base_check_imputation(impute_adata, adata_imputed) assert id(impute_adata) != id(adata_imputed) - assert not (np.all([item != item for item in adata_imputed.X])) + + +def test_unknown_simple_imputation_strategy(impute_adata): + with pytest.raises(ValueError): + simple_impute(impute_adata, strategy="invalid_strategy", copy=True) # type: ignore def test_most_frequent_impute_subset(impute_adata): - adata_imputed = simple_impute(impute_adata, var_names=["intcol", "strcol"], strategy="most_frequent", copy=True) + var_names = ("intcol", "strcol") 
+ adata_imputed = simple_impute(impute_adata, var_names=var_names, strategy="most_frequent", copy=True) - assert not (np.all([item != item for item in adata_imputed.X[::, 1:3]])) + _base_check_imputation(impute_adata, adata_imputed, imputed_var_names=var_names) def test_knn_impute_check_backend(impute_num_adata): - knn_impute(impute_num_adata, backend="faiss") - knn_impute(impute_num_adata, backend="scikit-learn") + knn_impute(impute_num_adata, backend="faiss", copy=True) + knn_impute(impute_num_adata, backend="scikit-learn", copy=True) with pytest.raises( ValueError, match="Unknown backend 'invalid_backend' for KNN imputation. Choose between 'scikit-learn' and 'faiss'.", ): - knn_impute(impute_num_adata, backend="invalid_backend") + knn_impute(impute_num_adata, backend="invalid_backend") # type: ignore def test_knn_impute_no_copy(impute_num_adata): + adata_not_imputed = impute_num_adata.copy() knn_impute(impute_num_adata) - assert not (np.all([item != item for item in impute_num_adata.X])) + _base_check_imputation(adata_not_imputed, impute_num_adata) def test_knn_impute_copy(impute_num_adata): adata_imputed = knn_impute(impute_num_adata, n_neighbors=3, copy=True) + _base_check_imputation(impute_num_adata, adata_imputed) assert id(impute_num_adata) != id(adata_imputed) - assert not (np.all([item != item for item in adata_imputed.X])) def test_knn_impute_non_numerical_data(impute_adata): - adata_imputed = knn_impute(impute_adata, n_neighbors=3, copy=True) - - assert not (np.all([item != item for item in adata_imputed.X])) + with pytest.raises(ValueError): + knn_impute(impute_adata, n_neighbors=3, copy=True) def test_knn_impute_numerical_data(impute_num_adata): adata_imputed = knn_impute(impute_num_adata, copy=True) - assert not (np.all([item != item for item in adata_imputed.X])) - - -def test_knn_impute_list_str(impute_adata): - adata_imputed = knn_impute(impute_adata, var_names=["intcol", "strcol", "boolcol"], copy=True) - - assert not (np.all([item != item for item in adata_imputed.X])) + _base_check_imputation(impute_num_adata, adata_imputed) def test_missforest_impute_non_numerical_data(impute_adata): - adata_imputed = miss_forest_impute(impute_adata, copy=True) - - assert not (np.all([item != item for item in adata_imputed.X])) + with pytest.raises(ValueError): + miss_forest_impute(impute_adata, copy=True) def test_missforest_impute_numerical_data(impute_num_adata): warnings.filterwarnings("ignore", category=ConvergenceWarning) adata_imputed = miss_forest_impute(impute_num_adata, copy=True) - assert not (np.all([item != item for item in adata_imputed.X])) + _base_check_imputation(impute_num_adata, adata_imputed) def test_missforest_impute_subset(impute_num_adata): - adata_imputed = miss_forest_impute( - impute_num_adata, var_names={"non_numerical": ["intcol"], "numerical": ["strcol"]}, copy=True - ) - - assert not (np.all([item != item for item in adata_imputed.X])) - - -def test_missforest_impute_list_str(impute_num_adata): - warnings.filterwarnings("ignore", category=ConvergenceWarning) - adata_imputed = miss_forest_impute(impute_num_adata, var_names=["col1", "col2", "col3"], copy=True) - - assert not (np.all([item != item for item in adata_imputed.X])) - - -def test_missforest_impute_dict(impute_adata): warnings.filterwarnings("ignore", category=ConvergenceWarning) - adata_imputed = miss_forest_impute( - impute_adata, var_names={"numerical": ["intcol", "datetime"], "non_numerical": ["strcol", "boolcol"]}, copy=True - ) + var_names = ("col2", "col3") + adata_imputed = 
miss_forest_impute(impute_num_adata, var_names=var_names, copy=True) - assert not (np.all([item != item for item in adata_imputed.X])) + _base_check_imputation(impute_num_adata, adata_imputed, imputed_var_names=var_names) @pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.") def test_miceforest_impute_no_copy(impute_iris_adata): - adata_imputed = mice_forest_impute(impute_iris_adata) + adata_not_imputed = impute_iris_adata.copy() + mice_forest_impute(impute_iris_adata) - assert id(impute_iris_adata) == id(adata_imputed) + _base_check_imputation(adata_not_imputed, impute_iris_adata) @pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.") def test_miceforest_impute_copy(impute_iris_adata): adata_imputed = mice_forest_impute(impute_iris_adata, copy=True) + _base_check_imputation(impute_iris_adata, adata_imputed) assert id(impute_iris_adata) != id(adata_imputed) @pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.") def test_miceforest_impute_non_numerical_data(impute_titanic_adata): - adata_imputed = mice_forest_impute(impute_titanic_adata) - - assert not (np.all([item != item for item in adata_imputed.X])) + with pytest.raises(ValueError): + mice_forest_impute(impute_titanic_adata) @pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.") def test_miceforest_impute_numerical_data(impute_iris_adata): - adata_imputed = mice_forest_impute(impute_iris_adata) - - assert not (np.all([item != item for item in adata_imputed.X])) - - -@pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.") -def test_miceforest_impute_list_str(impute_titanic_adata): - adata_imputed = mice_forest_impute(impute_titanic_adata, var_names=["Cabin", "Age"]) + adata_not_imputed = impute_iris_adata.copy() + mice_forest_impute(impute_iris_adata) - assert not (np.all([item != item for item in adata_imputed.X])) + _base_check_imputation(adata_not_imputed, impute_iris_adata) def test_explicit_impute_all(impute_num_adata): warnings.filterwarnings("ignore", category=FutureWarning) adata_imputed = explicit_impute(impute_num_adata, replacement=1011, copy=True) - assert (adata_imputed.X == 1011).sum() == 3 + _base_check_imputation(impute_num_adata, adata_imputed) + assert np.sum([adata_imputed.X == 1011]) == 3 def test_explicit_impute_subset(impute_adata): adata_imputed = explicit_impute(impute_adata, replacement={"strcol": "REPLACED", "intcol": 1011}, copy=True) - assert (adata_imputed.X == 1011).sum() == 1 - assert (adata_imputed.X == "REPLACED").sum() == 1 + _base_check_imputation(impute_adata, adata_imputed, imputed_var_names=("strcol", "intcol")) + assert np.sum([adata_imputed.X == 1011]) == 1 + assert np.sum([adata_imputed.X == "REPLACED"]) == 1 def test_warning(impute_num_adata): diff --git a/tests/core/_test_tool_available.py b/tests/utils/test_utils_available.py similarity index 90% rename from tests/core/_test_tool_available.py rename to tests/utils/test_utils_available.py index ebcd0f88..7e8044a5 100644 --- a/tests/core/_test_tool_available.py +++ b/tests/utils/test_utils_available.py @@ -1,6 +1,4 @@ -import pytest - -from ehrapy.core._tool_available import _check_module_importable, _shell_command_accessible +from ehrapy._utils_available import _check_module_importable, _shell_command_accessible def test_check_module_importable_true(): From 1d7c5d7dafb8ed0151fab31c9e52d8eb4a7b177c Mon Sep 17 00:00:00 2001 From: Lukas 
Heumos Date: Thu, 28 Nov 2024 21:28:42 +0100 Subject: [PATCH 04/12] Python 3.10+ & use uv for docs & fix RTD & support numpy 2 (#830) * uv docs Signed-off-by: zethson * submodule Signed-off-by: zethson * fix types Signed-off-by: zethson * Install scvelo from github Signed-off-by: zethson * cellrank from github Signed-off-by: zethson * cellrank fix? Signed-off-by: zethson * cellrank out Signed-off-by: zethson --------- Signed-off-by: zethson --- .github/workflows/run_notebooks.yml | 5 +- .readthedocs.yml | 17 ++- docs/_ext/edit_on_github.py | 2 +- ehrapy/_utils_doc.py | 4 +- ehrapy/anndata/anndata_ext.py | 4 +- ehrapy/data/_datasets.py | 2 +- ehrapy/plot/_scanpy_pl_api.py | 10 +- ehrapy/preprocessing/_encoding.py | 4 +- ehrapy/preprocessing/_imputation.py | 2 +- ehrapy/preprocessing/_normalization.py | 4 +- ehrapy/preprocessing/_scanpy_pp_api.py | 7 +- ehrapy/tools/_method_options.py | 4 +- ehrapy/tools/_scanpy_tl_api.py | 111 +++++++++--------- ehrapy/tools/causal/_dowhy.py | 2 +- .../tools/cohort_tracking/_cohort_tracker.py | 4 +- .../feature_ranking/_rank_features_groups.py | 4 +- pyproject.toml | 12 +- 17 files changed, 102 insertions(+), 96 deletions(-) diff --git a/.github/workflows/run_notebooks.yml b/.github/workflows/run_notebooks.yml index 7f19f070..995b8dae 100644 --- a/.github/workflows/run_notebooks.yml +++ b/.github/workflows/run_notebooks.yml @@ -13,7 +13,7 @@ jobs: "docs/tutorials/notebooks/ehrapy_introduction.ipynb", "docs/tutorials/notebooks/mimic_2_introduction.ipynb", "docs/tutorials/notebooks/mimic_2_survival_analysis.ipynb", - "docs/tutorials/notebooks/mimic_2_fate.ipynb", + # "docs/tutorials/notebooks/mimic_2_fate.ipynb", # https://github.com/theislab/cellrank/issues/1235 "docs/tutorials/notebooks/mimic_2_causal_inference.ipynb", # "docs/tutorials/notebooks/mimic_3_demo.ipynb", # "docs/tutorials/notebooks/medcat.ipynb", @@ -34,5 +34,8 @@ jobs: - name: Install ehrapy and additional dependencies run: uv pip install --system . cellrank nbconvert ipykernel + - name: Install scvelo from Github + run: uv pip install --system git+https://github.com/theislab/scvelo.git + - name: Run ${{ matrix.notebook }} Notebook run: jupyter nbconvert --to notebook --execute ${{ matrix.notebook }} diff --git a/.readthedocs.yml b/.readthedocs.yml index 5135ef47..bac3abbe 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -3,16 +3,13 @@ build: os: ubuntu-22.04 tools: python: "3.11" - jobs: - pre_build: - - python -c "import ehrapy" - - pip freeze - post_create_environment: - - pip install uv - post_install: - # VIRTUAL_ENV needs to be set manually for now. 
- # See https://github.com/readthedocs/readthedocs.org/pull/11152/ - - VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH pip install .[docs] + commands: + - asdf plugin add uv + - asdf install uv latest + - asdf global uv latest + - uv venv + - uv pip install .[docs] + - .venv/bin/python -m sphinx -T -b html -d docs/_build/doctrees -D language=en docs $READTHEDOCS_OUTPUT/html sphinx: configuration: docs/conf.py fail_on_warning: false diff --git a/docs/_ext/edit_on_github.py b/docs/_ext/edit_on_github.py index 85ed75de..746d2c41 100644 --- a/docs/_ext/edit_on_github.py +++ b/docs/_ext/edit_on_github.py @@ -20,7 +20,7 @@ def get_github_repo(app: Sphinx, path: str) -> str: def _html_page_context( - app: Sphinx, _pagename: str, templatename: str, context: dict[str, Any], doctree: Optional[Any] + app: Sphinx, _pagename: str, templatename: str, context: dict[str, Any], doctree: Any | None ) -> None: # doctree is None - otherwise viewcode fails if templatename != "page.html" or doctree is None: diff --git a/ehrapy/_utils_doc.py b/ehrapy/_utils_doc.py index b0cf9587..3d6b07f7 100644 --- a/ehrapy/_utils_doc.py +++ b/ehrapy/_utils_doc.py @@ -1,9 +1,9 @@ import inspect +from collections.abc import Callable from textwrap import dedent -from typing import Callable, Optional, Union -def getdoc(c_or_f: Union[Callable, type]) -> Optional[str]: # pragma: no cover +def getdoc(c_or_f: Callable | type) -> str | None: # pragma: no cover if getattr(c_or_f, "__doc__", None) is None: return None doc = inspect.getdoc(c_or_f) diff --git a/ehrapy/anndata/anndata_ext.py b/ehrapy/anndata/anndata_ext.py index a82721d3..fb420202 100644 --- a/ehrapy/anndata/anndata_ext.py +++ b/ehrapy/anndata/anndata_ext.py @@ -404,7 +404,7 @@ def _detect_binary_columns(df: pd.DataFrame, numerical_columns: list[str]) -> li for column in numerical_columns: # checking for float and int as well as NaNs (this is safe since checked columns are numericals only) # only columns that contain at least one 0 and one 1 are counted as binary (or 0.0/1.0) - if df[column].isin([0.0, 1.0, np.NaN, 0, 1]).all() and df[column].nunique() == 2: + if df[column].isin([0.0, 1.0, np.nan, 0, 1]).all() and df[column].nunique() == 2: binary_columns.append(column) return binary_columns @@ -423,7 +423,7 @@ def _cast_obs_columns(obs: pd.DataFrame) -> pd.DataFrame: # type cast each non-numerical column to either bool (if possible) or category else obs[object_columns] = obs[object_columns].apply( lambda obs_name: obs_name.astype("category") - if not set(pd.unique(obs_name)).issubset({False, True, np.NaN}) + if not set(pd.unique(obs_name)).issubset({False, True, np.nan}) else obs_name.astype("bool"), axis=0, ) diff --git a/ehrapy/data/_datasets.py b/ehrapy/data/_datasets.py index 5719373a..7957d17c 100644 --- a/ehrapy/data/_datasets.py +++ b/ehrapy/data/_datasets.py @@ -743,7 +743,7 @@ def synthea_1k_sample( df = anndata_to_df(adata) df.drop( - columns=[col for col in df.columns if any(isinstance(x, (list, dict)) for x in df[col].dropna())], inplace=True + columns=[col for col in df.columns if any(isinstance(x, list | dict) for x in df[col].dropna())], inplace=True ) df.drop(columns=df.columns[df.isna().all()], inplace=True) adata = df_to_anndata(df, index_column="id") diff --git a/ehrapy/plot/_scanpy_pl_api.py b/ehrapy/plot/_scanpy_pl_api.py index a1d032a2..4cf278d6 100644 --- a/ehrapy/plot/_scanpy_pl_api.py +++ b/ehrapy/plot/_scanpy_pl_api.py @@ -1,10 +1,10 @@ from __future__ import annotations -from collections.abc import Collection, Iterable, Mapping, Sequence +from 
collections.abc import Callable, Collection, Iterable, Mapping, Sequence from enum import Enum from functools import partial from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Callable, Literal, Union +from typing import TYPE_CHECKING, Any, Literal import scanpy as sc from scanpy.plotting import DotPlot, MatrixPlot, StackedViolin @@ -36,12 +36,12 @@ from scanpy.plotting._utils import _AxesSubplot _Basis = Literal["pca", "tsne", "umap", "diffmap", "draw_graph_fr"] -_VarNames = Union[str, Sequence[str]] -ColorLike = Union[str, tuple[float, ...]] +_VarNames = str | Sequence[str] +ColorLike = str | tuple[float, ...] _IGraphLayout = Literal["fa", "fr", "rt", "rt_circular", "drl", "eq_tree", ...] # type: ignore _FontWeight = Literal["light", "normal", "medium", "semibold", "bold", "heavy", "black"] _FontSize = Literal["xx-small", "x-small", "small", "medium", "large", "x-large", "xx-large"] -VBound = Union[str, float, Callable[[Sequence[float]], float]] +VBound = str | float | Callable[[Sequence[float]], float] @_doc_params(scatter_temp=doc_scatter_basic, show_save_ax=doc_show_save_ax) diff --git a/ehrapy/preprocessing/_encoding.py b/ehrapy/preprocessing/_encoding.py index 1f761e71..3c3426aa 100644 --- a/ehrapy/preprocessing/_encoding.py +++ b/ehrapy/preprocessing/_encoding.py @@ -73,7 +73,7 @@ def encode( if isinstance(encodings, str) and not autodetect: raise ValueError("Passing a string for parameter encodings is only possible when using autodetect=True!") - elif autodetect and not isinstance(encodings, (str, type(None))): + elif autodetect and not isinstance(encodings, str | type(None)): raise ValueError( f"Setting encode mode with autodetect=True only works by passing a string (encode mode name) or None not {type(encodings)}!" 
) @@ -630,7 +630,7 @@ def _update_obs(adata: AnnData, categorical_names: list[str]) -> pd.DataFrame: updated_obs[var_name] = adata.X[::, idx : idx + 1].flatten() # note: this will count binary columns (0 and 1 only) as well # needed for writing to .h5ad files - if set(pd.unique(updated_obs[var_name])).issubset({False, True, np.NaN}): + if set(pd.unique(updated_obs[var_name])).issubset({False, True, np.nan}): updated_obs[var_name] = updated_obs[var_name].astype("bool") # get all non bool object columns and cast them to category dtype object_columns = list(updated_obs.select_dtypes(include="object").columns) diff --git a/ehrapy/preprocessing/_imputation.py b/ehrapy/preprocessing/_imputation.py index 60facfeb..03796c2b 100644 --- a/ehrapy/preprocessing/_imputation.py +++ b/ehrapy/preprocessing/_imputation.py @@ -63,7 +63,7 @@ def explicit_impute( _warn_imputation_threshold(adata, var_names=replacement.keys(), threshold=warning_threshold) # type: ignore # 1: Replace all missing values with the specified value - if isinstance(replacement, (int, str)): + if isinstance(replacement, int | str): _replace_explicit(adata.X, replacement, impute_empty_strings) # 2: Replace all missing values in a subset of columns with a specified value per column or a default value, when the column is not explicitly named diff --git a/ehrapy/preprocessing/_normalization.py b/ehrapy/preprocessing/_normalization.py index 4541cef3..de6cf646 100644 --- a/ehrapy/preprocessing/_normalization.py +++ b/ehrapy/preprocessing/_normalization.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING import numpy as np import sklearn.preprocessing as sklearn_pp @@ -20,7 +20,7 @@ ) if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Callable, Sequence import pandas as pd from anndata import AnnData diff --git a/ehrapy/preprocessing/_scanpy_pp_api.py b/ehrapy/preprocessing/_scanpy_pp_api.py index 5317530e..e0e50221 100644 --- a/ehrapy/preprocessing/_scanpy_pp_api.py +++ b/ehrapy/preprocessing/_scanpy_pp_api.py @@ -1,7 +1,8 @@ from __future__ import annotations +from collections.abc import Callable from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Callable, Literal, Union +from typing import TYPE_CHECKING, Any, Literal, TypeAlias import numpy as np import scanpy as sc @@ -15,7 +16,7 @@ from ehrapy.preprocessing._types import KnownTransformer -AnyRandom = Union[int, np.random.RandomState, None] +AnyRandom: TypeAlias = int | np.random.RandomState | None def pca( @@ -193,7 +194,7 @@ def combat( "sqeuclidean", "yule", ] -_Metric = Union[_MetricSparseCapable, _MetricScipySpatial] +_Metric = _MetricSparseCapable | _MetricScipySpatial def neighbors( diff --git a/ehrapy/tools/_method_options.py b/ehrapy/tools/_method_options.py index 315f47f7..5bcad67c 100644 --- a/ehrapy/tools/_method_options.py +++ b/ehrapy/tools/_method_options.py @@ -1,11 +1,11 @@ -from typing import Literal, Optional +from typing import Literal _InitPos = Literal["paga", "spectral", "random"] _LAYOUTS = ("fr", "drl", "kk", "grid_fr", "lgl", "rt", "rt_circular", "fa") _Layout = Literal[_LAYOUTS] # type: ignore -_rank_features_groups_method = Optional[Literal["logreg", "t-test", "wilcoxon", "t-test_overestim_var"]] +_rank_features_groups_method = Literal["logreg", "t-test", "wilcoxon", "t-test_overestim_var"] | None _correction_method = Literal["benjamini-hochberg", "bonferroni"] _rank_features_groups_cat_method = Literal[ "chi-square", 
"g-test", "freeman-tukey", "mod-log-likelihood", "neyman", "cressie-read" diff --git a/ehrapy/tools/_scanpy_tl_api.py b/ehrapy/tools/_scanpy_tl_api.py index 4e1af6b0..6b82448e 100644 --- a/ehrapy/tools/_scanpy_tl_api.py +++ b/ehrapy/tools/_scanpy_tl_api.py @@ -1,29 +1,34 @@ -from collections.abc import Iterable, Sequence -from typing import Any, Literal, Optional, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, TypeAlias import numpy as np import scanpy as sc -from anndata import AnnData -from leidenalg.VertexPartition import MutableVertexPartition -from scipy.sparse import spmatrix -from ehrapy.tools import _method_options +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + from anndata import AnnData + from leidenalg.VertexPartition import MutableVertexPartition + from scipy.sparse import spmatrix + + from ehrapy.tools import _method_options -AnyRandom = Union[int, np.random.RandomState, None] +AnyRandom: TypeAlias = int | np.random.RandomState | None def tsne( adata: AnnData, - n_pcs: Optional[int] = None, - use_rep: Optional[str] = None, - perplexity: Union[float, int] = 30, - early_exaggeration: Union[float, int] = 12, - learning_rate: Union[float, int] = 1000, + n_pcs: int | None = None, + use_rep: str | None = None, + perplexity: float | int = 30, + early_exaggeration: float | int = 12, + learning_rate: float | int = 1000, random_state: AnyRandom = 0, - n_jobs: Optional[int] = None, + n_jobs: int | None = None, copy: bool = False, metric: str = "euclidean", -) -> Optional[AnnData]: # pragma: no cover +) -> AnnData | None: # pragma: no cover """Calculates t-SNE [Maaten08]_ [Amir13]_ [Pedregosa11]_. t-distributed stochastic neighborhood embedding (tSNE) [Maaten08]_ has been @@ -83,18 +88,18 @@ def umap( min_dist: float = 0.5, spread: float = 1.0, n_components: int = 2, - maxiter: Optional[int] = None, + maxiter: int | None = None, alpha: float = 1.0, gamma: float = 1.0, negative_sample_rate: int = 5, - init_pos: Union[_method_options._InitPos, np.ndarray, None] = "spectral", + init_pos: _method_options._InitPos | np.ndarray | None = "spectral", random_state: AnyRandom = 0, - a: Optional[float] = None, - b: Optional[float] = None, + a: float | None = None, + b: float | None = None, copy: bool = False, method: Literal["umap", "rapids"] = "umap", - neighbors_key: Optional[str] = None, -) -> Optional[AnnData]: # pragma: no cover + neighbors_key: str | None = None, +) -> AnnData | None: # pragma: no cover """Embed the neighborhood graph using UMAP [McInnes18]_. UMAP (Uniform Manifold Approximation and Projection) is a manifold learning @@ -186,17 +191,17 @@ def umap( def draw_graph( adata: AnnData, layout: _method_options._Layout = "fa", - init_pos: Union[str, bool, None] = None, - root: Optional[int] = None, + init_pos: str | bool | None = None, + root: int | None = None, random_state: AnyRandom = 0, - n_jobs: Optional[int] = None, - adjacency: Optional[spmatrix] = None, - key_added_ext: Optional[str] = None, - neighbors_key: Optional[str] = None, - obsp: Optional[str] = None, + n_jobs: int | None = None, + adjacency: spmatrix | None = None, + key_added_ext: str | None = None, + neighbors_key: str | None = None, + obsp: str | None = None, copy: bool = False, **kwds, -) -> Optional[AnnData]: # pragma: no cover +) -> AnnData | None: # pragma: no cover """Force-directed graph drawing [Islam11]_ [Jacomy14]_ [Chippada18]_. .. 
_fa2: https://github.com/bhargavchippada/forceatlas2 @@ -264,10 +269,10 @@ def draw_graph( def diffmap( adata: AnnData, n_comps: int = 15, - neighbors_key: Optional[str] = None, + neighbors_key: str | None = None, random_state: AnyRandom = 0, copy: bool = False, -) -> Optional[AnnData]: # pragma: no cover +) -> AnnData | None: # pragma: no cover """Diffusion Maps [Coifman05]_ [Haghverdi15]_ [Wolf18]_. Diffusion maps [Coifman05]_ has been proposed for visualizing single-cell @@ -309,9 +314,9 @@ def diffmap( def embedding_density( adata: AnnData, basis: str = "umap", # was positional before 1.4.5 - groupby: Optional[str] = None, - key_added: Optional[str] = None, - components: Union[str, Sequence[str]] = None, + groupby: str | None = None, + key_added: str | None = None, + components: str | Sequence[str] = None, ) -> None: # pragma: no cover """Calculate the density of observation in an embedding (per condition). @@ -353,19 +358,19 @@ def embedding_density( def leiden( adata: AnnData, resolution: float = 1, - restrict_to: Optional[tuple[str, Sequence[str]]] = None, + restrict_to: tuple[str, Sequence[str]] | None = None, random_state: AnyRandom = 0, key_added: str = "leiden", - adjacency: Optional[spmatrix] = None, + adjacency: spmatrix | None = None, directed: bool = True, use_weights: bool = True, n_iterations: int = -1, - partition_type: Optional[type[MutableVertexPartition]] = None, - neighbors_key: Optional[str] = None, - obsp: Optional[str] = None, + partition_type: type[MutableVertexPartition] | None = None, + neighbors_key: str | None = None, + obsp: str | None = None, copy: bool = False, **partition_kwargs, -) -> Optional[AnnData]: # pragma: no cover +) -> AnnData | None: # pragma: no cover """Cluster observations into subgroups [Traag18]_. Cluster observations using the Leiden algorithm [Traag18]_, @@ -429,15 +434,15 @@ def leiden( def dendrogram( adata: AnnData, groupby: str, - n_pcs: Optional[int] = None, - use_rep: Optional[str] = None, - var_names: Optional[Sequence[str]] = None, + n_pcs: int | None = None, + use_rep: str | None = None, + var_names: Sequence[str] | None = None, cor_method: str = "pearson", linkage_method: str = "complete", optimal_ordering: bool = False, - key_added: Optional[str] = None, + key_added: str | None = None, inplace: bool = True, -) -> Optional[dict[str, Any]]: # pragma: no cover +) -> dict[str, Any] | None: # pragma: no cover """Computes a hierarchical clustering for the given `groupby` categories. By default, the PCA representation is used unless `.X` has less than 50 variables. @@ -505,9 +510,9 @@ def dpt( n_branchings: int = 0, min_group_size: float = 0.01, allow_kendall_tau_shift: bool = True, - neighbors_key: Optional[str] = None, + neighbors_key: str | None = None, copy: bool = False, -) -> Optional[AnnData]: # pragma: no cover +) -> AnnData | None: # pragma: no cover """Infer progression of observations through geodesic distance along the graph [Haghverdi16]_ [Wolf19]_. Reconstruct the progression of a biological process from snapshot @@ -562,12 +567,12 @@ def dpt( def paga( adata: AnnData, - groups: Optional[str] = None, + groups: str | None = None, use_rna_velocity: bool = False, model: Literal["v1.2", "v1.0"] = "v1.2", - neighbors_key: Optional[str] = None, + neighbors_key: str | None = None, copy: bool = False, -) -> Optional[AnnData]: # pragma: no cover +) -> AnnData | None: # pragma: no cover """Mapping out the coarse-grained connectivity structures of complex manifolds [Wolf19]_. 
By quantifying the connectivity of partitions (groups, clusters), @@ -626,13 +631,13 @@ def paga( def ingest( adata: AnnData, adata_ref: AnnData, - obs: Optional[Union[str, Iterable[str]]] = None, - embedding_method: Union[str, Iterable[str]] = ("umap", "pca"), + obs: str | Iterable[str] | None = None, + embedding_method: str | Iterable[str] = ("umap", "pca"), labeling_method: str = "knn", - neighbors_key: Optional[str] = None, + neighbors_key: str | None = None, inplace: bool = True, **kwargs, -) -> Optional[AnnData]: # pragma: no cover +) -> AnnData | None: # pragma: no cover """Map labels and embeddings from reference data to new data. Integrates embeddings and annotations of an `adata` with a reference dataset diff --git a/ehrapy/tools/causal/_dowhy.py b/ehrapy/tools/causal/_dowhy.py index 55916179..f972804b 100644 --- a/ehrapy/tools/causal/_dowhy.py +++ b/ehrapy/tools/causal/_dowhy.py @@ -244,7 +244,7 @@ def causal_inference( pval = "Not applicable" # Format effect, can be list when refuter is "add_unobserved_common_cause" - if isinstance(refute.new_effect, (list, tuple)): + if isinstance(refute.new_effect, list | tuple): new_effect = ", ".join([str(np.round(x, 2)) for x in refute.new_effect]) else: new_effect = f"{refute.new_effect:.3f}" diff --git a/ehrapy/tools/cohort_tracking/_cohort_tracker.py b/ehrapy/tools/cohort_tracking/_cohort_tracker.py index 4d92fb84..fc331550 100644 --- a/ehrapy/tools/cohort_tracking/_cohort_tracker.py +++ b/ehrapy/tools/cohort_tracking/_cohort_tracker.py @@ -390,7 +390,7 @@ def create_legend_with_subtitles(patches_list, subtitles_list, tot_legend_kwargs # there can be empty lists which distort the logic of matching patches to subtitles patches_list = [patch for patch in patches_list if patch] - for patches, subtitle in zip(patches_list, subtitles_list): + for patches, subtitle in zip(patches_list, subtitles_list, strict=False): handles.append(Line2D([], [], linestyle="none", marker="", alpha=0)) # Placeholder for title labels.append(subtitle) @@ -494,7 +494,7 @@ def plot_flowchart( tot_bbox_kwargs = {"boxstyle": "round,pad=0.3", "fc": "lightblue", "alpha": 0.5} if bbox_kwargs is not None: tot_bbox_kwargs.update(bbox_kwargs) - for _, (y, label) in enumerate(zip(y_positions, node_labels)): + for _, (y, label) in enumerate(zip(y_positions, node_labels, strict=False)): axes.annotate( label, xy=(0, y), diff --git a/ehrapy/tools/feature_ranking/_rank_features_groups.py b/ehrapy/tools/feature_ranking/_rank_features_groups.py index d78bfd1c..6fb3932c 100644 --- a/ehrapy/tools/feature_ranking/_rank_features_groups.py +++ b/ehrapy/tools/feature_ranking/_rank_features_groups.py @@ -107,7 +107,7 @@ def _save_rank_features_result( fields = (names, scores, pvals, pvals_adj, logfoldchanges, pts) field_names = ("names", "scores", "pvals", "pvals_adj", "logfoldchanges", "pts") - for values, key in zip(fields, field_names): + for values, key in zip(fields, field_names, strict=False): if values is None or not len(values): continue @@ -139,7 +139,7 @@ def _get_groups_order(groups_subset, group_names, reference): """ if groups_subset == "all": groups_order = group_names - elif isinstance(groups_subset, (str, int)): + elif isinstance(groups_subset, str | int): raise ValueError("Specify a sequence of groups") else: groups_order = list(groups_subset) diff --git a/pyproject.toml b/pyproject.toml index 517a940a..7d910479 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "ehrapy" version = "0.9.0" description = "Electronic Health Record Analysis with Python." 
readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.10,<3.13" license = {file = "LICENSE"} authors = [ {name = "Lukas Heumos"}, @@ -38,7 +38,6 @@ classifiers = [ "Operating System :: MacOS :: MacOS X", "Operating System :: POSIX :: Linux", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -66,7 +65,6 @@ dependencies = [ "fknni", "python-dateutil", "filelock", - "numpy<2.0.0" # for compatiblity with lifelines ] [project.optional-dependencies] @@ -101,8 +99,7 @@ docs = [ "nbsphinx-link", "ipykernel", "ipython", - "medcat", - "ehrapy[dask]", + "ehrapy[dask,medcat]", ] test = [ "ehrapy[dask]", @@ -135,7 +132,10 @@ markers = [ filterwarnings = [ "ignore::DeprecationWarning", "ignore::anndata.OldFormatWarning:", - "ignore:X converted to numpy array with dtype object:UserWarning" + "ignore:X converted to numpy array with dtype object:UserWarning", + "ignore:`flavor='seurat_v3'` expects raw count data, but non-integers were found:UserWarning", + "ignore:All-NaN slice encountered:RuntimeWarning", + "ignore:Observation names are not unique. To make them unique, call `.obs_names_make_unique`.:UserWarning" ] minversion = 6.0 norecursedirs = [ '.*', 'build', 'dist', '*.egg', 'data', '__pycache__'] From 03cd18036761394ec9ebb3504f451935a8c48a9c Mon Sep 17 00:00:00 2001 From: Lilly May <93096564+Lilly-May@users.noreply.github.com> Date: Thu, 28 Nov 2024 21:40:53 +0100 Subject: [PATCH 05/12] Python 3.12 support (#794) * Change CI to python 3.12 * Change notebook CI to python 3.12 * Updated submodule * Add graphviz dependency * Fixed estimator data retrieval * Obtain data based on dowhy version * Estimate data retrieval only for latest dowhy versions * Move graphviz installation * More latest python 3.12 Signed-off-by: zethson --------- Signed-off-by: zethson Co-authored-by: Lukas Heumos --- .github/workflows/build.yml | 6 ++---- .github/workflows/release.yml | 4 ++-- .github/workflows/run_notebooks.yml | 4 ++-- .github/workflows/test.yml | 4 ++-- .readthedocs.yml | 4 ++-- 5 files changed, 10 insertions(+), 12 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index db548369..589dd3ea 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -16,12 +16,10 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Set up Python 3.11 + - name: Set up Python 3.12 uses: actions/setup-python@v5 with: - python-version: "3.11" - cache: "pip" - cache-dependency-path: "**/pyproject.toml" + python-version: "3.12" - name: Install build dependencies run: python -m pip install --upgrade pip wheel twine build diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 993f9bb9..98b082b7 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -13,10 +13,10 @@ jobs: - name: Checkout code uses: actions/checkout@v3 - - name: Set up Python 3.11 + - name: Set up Python 3.12 uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: "3.12" - name: Install hatch run: pip install hatch diff --git a/.github/workflows/run_notebooks.yml b/.github/workflows/run_notebooks.yml index 995b8dae..f45fd96b 100644 --- a/.github/workflows/run_notebooks.yml +++ b/.github/workflows/run_notebooks.yml @@ -26,13 +26,13 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: "3.11" + python-version: "3.12" - name: Install UV run: 
pip install uv - name: Install ehrapy and additional dependencies - run: uv pip install --system . cellrank nbconvert ipykernel + run: uv pip install --system . cellrank nbconvert ipykernel graphviz - name: Install scvelo from Github run: uv pip install --system git+https://github.com/theislab/scvelo.git diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 87e57959..8800bb5a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,9 +24,9 @@ jobs: - os: ubuntu-latest python: "3.10" - os: ubuntu-latest - python: "3.11" + python: "3.12" - os: ubuntu-latest - python: "3.11" + python: "3.12" pip-flags: "--pre" env: diff --git a/.readthedocs.yml b/.readthedocs.yml index bac3abbe..4dc2dd86 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -1,8 +1,8 @@ version: 2 build: - os: ubuntu-22.04 + os: ubuntu-24.04 tools: - python: "3.11" + python: "3.12" commands: - asdf plugin add uv - asdf install uv latest From ee84d9e9b616a8b1b82a8a1624bdf9765cf96239 Mon Sep 17 00:00:00 2001 From: Eljas Roellin <65244425+eroell@users.noreply.github.com> Date: Mon, 2 Dec 2024 08:46:27 +0100 Subject: [PATCH 06/12] move_to_x: Fix name of non-implemented argument "copy" to "copy_x", implement & test (#832) * fix name of & implement missing arg in move_to_x * fix argument description --- ehrapy/anndata/anndata_ext.py | 9 ++++++--- tests/anndata/test_anndata_ext.py | 7 +++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/ehrapy/anndata/anndata_ext.py b/ehrapy/anndata/anndata_ext.py index fb420202..38d60224 100644 --- a/ehrapy/anndata/anndata_ext.py +++ b/ehrapy/anndata/anndata_ext.py @@ -252,13 +252,13 @@ def delete_from_obs(adata: AnnData, to_delete: list[str]) -> AnnData: return adata -def move_to_x(adata: AnnData, to_x: list[str] | str) -> AnnData: +def move_to_x(adata: AnnData, to_x: list[str] | str, copy_x: bool = False) -> AnnData: """Move features from obs to X inplace. Args: adata: The AnnData object to_x: The columns to move to X - copy: Whether to return a copy or not + copy_x: Whether to copy the values to X (and therefore keep them in obs) instead of moving them completely Returns: A new AnnData object with moved columns from obs to X. This should not be used for datetime columns currently.
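A minimal sketch of the new copy_x semantics (editorial illustration in the style of the docstring examples above; the toy frame and the "name" column are hypothetical, and the helpers are assumed importable from ehrapy.anndata as in the tests below):

>>> import pandas as pd
>>> from ehrapy.anndata import df_to_anndata, move_to_obs, move_to_x
>>> adata = df_to_anndata(pd.DataFrame({"name": ["a", "b"], "x1": [1.0, 2.0]}))
>>> move_to_obs(adata, ["name"], copy_obs=False)  # "name" now lives in obs only
>>> new_adata = move_to_x(adata, ["name"], copy_x=True)  # values copied back into X
>>> "name" in new_adata.obs.columns  # obs is left untouched when copy_x=True
True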
@@ -292,7 +292,10 @@ def move_to_x(adata: AnnData, to_x: list[str] | str) -> AnnData: if cols_not_in_x: new_adata = concat([adata, AnnData(adata.obs[cols_not_in_x])], axis=1) - new_adata.obs = adata.obs[adata.obs.columns[~adata.obs.columns.isin(cols_not_in_x)]] + if copy_x: + new_adata.obs = adata.obs + else: + new_adata.obs = adata.obs[adata.obs.columns[~adata.obs.columns.isin(cols_not_in_x)]] # AnnData's concat discards var if they don't match in their keys, so we need to create a new var created_var = pd.DataFrame(index=cols_not_in_x) diff --git a/tests/anndata/test_anndata_ext.py b/tests/anndata/test_anndata_ext.py index aca4ca77..6e5bbf83 100644 --- a/tests/anndata/test_anndata_ext.py +++ b/tests/anndata/test_anndata_ext.py @@ -164,6 +164,13 @@ def test_move_to_x(adata_move_obs_mix): ) +def test_move_to_x_copy_x(adata_move_obs_mix): + move_to_obs(adata_move_obs_mix, ["name"], copy_obs=False) + obs_df = adata_move_obs_mix.obs.copy() + new_adata = move_to_x(adata_move_obs_mix, ["name"], copy_x=True) + assert_frame_equal(new_adata.obs, obs_df) + + def test_move_to_x_invalid_column_names(adata_move_obs_mix): move_to_obs(adata_move_obs_mix, ["name"], copy_obs=True) move_to_obs(adata_move_obs_mix, ["clinic_id"], copy_obs=False) From 861d76251830d4de71c301ee6b28d0917d0fdf81 Mon Sep 17 00:00:00 2001 From: Carl Buchholz <32228189+aGuyLearning@users.noreply.github.com> Date: Mon, 2 Dec 2024 09:05:42 +0100 Subject: [PATCH 07/12] Improve survival analysis interface (#825) * updated kmf to match method signature * updated notebook * updated ehrapy tutorial commit * updated docs for new method signature * added outputs to survival analysis * correctly passing on fitting options * pull request fixes. - removed kwargs - updated documentation * added legacy support * added kmf function legacy support in tests and added new kaplan_meier function in line with new signature * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * updated notebook * added stacklevel to deprecation warning * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * added deprecation warning in comment * Update ehrapy/plot/_survival_analysis.py * Update ehrapy/plot/_survival_analysis.py * Update ehrapy/plot/_survival_analysis.py * Update ehrapy/plot/_survival_analysis.py * Update tests/tools/test_sa.py * doc adjustments * change name of kmf plot to kaplan_meier, some adjustments * introduce keyword only for univariate sa * correct docstring * update submodule * add lifelines intersphinx mappings * Update ehrapy/tools/_sa.py * Update ehrapy/tools/_sa.py * Update ehrapy/tools/_sa.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Lukas Heumos Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com> Co-authored-by: eroell --- docs/conf.py | 1 + docs/contributing.md | 2 +- docs/tutorials/notebooks | 2 +- docs/usage/usage.md | 4 +- ehrapy/plot/__init__.py | 2 +- ehrapy/plot/_survival_analysis.py | 70 +++++++++-- ehrapy/tools/__init__.py | 2 + ehrapy/tools/_sa.py | 190 ++++++++++++++++++++++++++++-- tests/tools/test_sa.py | 11 +- 9 files changed, 256 insertions(+), 28 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 761c2a22..dd46ddab 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -96,6 +96,7 @@ "flax": ("https://flax.readthedocs.io/en/latest/", None), "jax": ("https://jax.readthedocs.io/en/latest/", None), "lamin": ("https://lamin.ai/docs",
None), + "lifelines": ("https://lifelines.readthedocs.io/en/latest/", None), } language = "en" diff --git a/docs/contributing.md b/docs/contributing.md index ab0d890b..ce5858eb 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -51,7 +51,7 @@ and [prettier][prettier-editors]. ## Writing tests ```{note} -Remember to first install the package with `pip install -e "[dev,test,docs]"` +Remember to first install the package with `pip install -e ".[dev,test,docs]"` ``` This package uses the [pytest][] for automated testing. Please [write tests][scanpy-test-docs] for every function added diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 99b17e70..ac088bca 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 99b17e7039699548a908433fa3ee6b5cbac5e29f +Subproject commit ac088bcabae5de8516ca9a5aa036b4e3cdf67df6 diff --git a/docs/usage/usage.md b/docs/usage/usage.md index c77593b0..6f3f2366 100644 --- a/docs/usage/usage.md +++ b/docs/usage/usage.md @@ -226,7 +226,7 @@ In contrast to a preprocessing function, a tool usually adds an easily interpret tools.ols tools.glm - tools.kmf + tools.kaplan_meier tools.test_kmf_logrank tools.test_nested_f_statistic tools.cox_ph @@ -368,7 +368,7 @@ Methods that extract and visualize tool-specific annotation in an AnnData object :nosignatures: plot.ols - plot.kmf + plot.kaplan_meier ``` ### Causal Inference diff --git a/ehrapy/plot/__init__.py b/ehrapy/plot/__init__.py index 5ae52ab1..0c740e95 100644 --- a/ehrapy/plot/__init__.py +++ b/ehrapy/plot/__init__.py @@ -2,6 +2,6 @@ from ehrapy.plot._colormaps import * # noqa: F403 from ehrapy.plot._missingno_pl_api import * # noqa: F403 from ehrapy.plot._scanpy_pl_api import * # noqa: F403 -from ehrapy.plot._survival_analysis import kmf, ols +from ehrapy.plot._survival_analysis import kaplan_meier, kmf, ols from ehrapy.plot.causal_inference._dowhy import causal_effect from ehrapy.plot.feature_ranking._feature_importances import rank_features_supervised diff --git a/ehrapy/plot/_survival_analysis.py b/ehrapy/plot/_survival_analysis.py index bf74df85..717f9202 100644 --- a/ehrapy/plot/_survival_analysis.py +++ b/ehrapy/plot/_survival_analysis.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING import matplotlib.pyplot as plt @@ -38,7 +39,7 @@ def ols( ax: Axes | None = None, title: str | None = None, **kwds, -): +) -> Axes | None: """Plots an Ordinary Least Squares (OLS) Model result, scatter plot, and line plot. Args: @@ -134,6 +135,8 @@ def ols( if not show: return ax + else: + return None def kmf( @@ -152,7 +155,48 @@ def kmf( figsize: tuple[float, float] | None = None, show: bool | None = None, title: str | None = None, -): +) -> Axes | None: + warnings.warn( + "This function is deprecated and will be removed in the next release. 
Use `ep.pl.kaplan_meier` instead.", + DeprecationWarning, + stacklevel=2, + ) + return kaplan_meier( + kmfs=kmfs, + ci_alpha=ci_alpha, + ci_force_lines=ci_force_lines, + ci_show=ci_show, + ci_legend=ci_legend, + at_risk_counts=at_risk_counts, + color=color, + grid=grid, + xlim=xlim, + ylim=ylim, + xlabel=xlabel, + ylabel=ylabel, + figsize=figsize, + show=show, + title=title, + ) + + +def kaplan_meier( + kmfs: Sequence[KaplanMeierFitter], + ci_alpha: list[float] | None = None, + ci_force_lines: list[Boolean] | None = None, + ci_show: list[Boolean] | None = None, + ci_legend: list[Boolean] | None = None, + at_risk_counts: list[Boolean] | None = None, + color: list[str] | None = None, + grid: Boolean | None = False, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + xlabel: str | None = None, + ylabel: str | None = None, + figsize: tuple[float, float] | None = None, + show: bool | None = None, + title: str | None = None, +) -> Axes | None: """Plots a pretty figure of the fitted KaplanMeierFitter model. See https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html @@ -186,23 +230,21 @@ def kmf( # So we need to flip `censor_fl` when pass `censor_fl` to KaplanMeierFitter >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0) - >>> kmf = ep.tl.kmf(adata[:, ["mort_day_censored"]].X, adata[:, ["censor_flg"]].X) - >>> ep.pl.kmf( + >>> kmf = ep.tl.kaplan_meier(adata, "mort_day_censored", "censor_flg") + >>> ep.pl.kaplan_meier( ... [kmf], color=["r"], xlim=[0, 700], ylim=[0, 1], xlabel="Days", ylabel="Proportion Survived", show=True ... ) .. image:: /_static/docstring_previews/kmf_plot_1.png - >>> T = adata[:, ["mort_day_censored"]].X - >>> E = adata[:, ["censor_flg"]].X >>> groups = adata[:, ["service_unit"]].X - >>> ix1 = groups == "FICU" - >>> ix2 = groups == "MICU" - >>> ix3 = groups == "SICU" - >>> kmf_1 = ep.tl.kmf(T[ix1], E[ix1], label="FICU") - >>> kmf_2 = ep.tl.kmf(T[ix2], E[ix2], label="MICU") - >>> kmf_3 = ep.tl.kmf(T[ix3], E[ix3], label="SICU") - >>> ep.pl.kmf([kmf_1, kmf_2, kmf_3], ci_show=[False,False,False], color=['k','r', 'g'], + >>> adata_ficu = adata[groups == "FICU"] + >>> adata_micu = adata[groups == "MICU"] + >>> adata_sicu = adata[groups == "SICU"] + >>> kmf_1 = ep.tl.kaplan_meier(adata_ficu, "mort_day_censored", "censor_flg", label="FICU") + >>> kmf_2 = ep.tl.kaplan_meier(adata_micu, "mort_day_censored", "censor_flg", label="MICU") + >>> kmf_3 = ep.tl.kaplan_meier(adata_sicu, "mort_day_censored", "censor_flg", label="SICU") + >>> ep.pl.kaplan_meier([kmf_1, kmf_2, kmf_3], ci_show=[False,False,False], color=['k','r', 'g'], + ... xlim=[0, 750], ylim=[0, 1], xlabel="Days", ylabel="Proportion Survived") ..
image:: /_static/docstring_previews/kmf_plot_2.png @@ -251,3 +293,5 @@ def kmf( if not show: return ax + else: + return None diff --git a/ehrapy/tools/__init__.py b/ehrapy/tools/__init__.py index 5da8fa69..c034882f 100644 --- a/ehrapy/tools/__init__.py +++ b/ehrapy/tools/__init__.py @@ -2,6 +2,7 @@ anova_glm, cox_ph, glm, + kaplan_meier, kmf, log_logistic_aft, nelson_aalen, @@ -31,6 +32,7 @@ "cox_ph", "glm", "kmf", + "kaplan_meier", "log_logistic_aft", "nelson_aalen", "ols", diff --git a/ehrapy/tools/_sa.py b/ehrapy/tools/_sa.py index e23b6a43..fed63b9e 100644 --- a/ehrapy/tools/_sa.py +++ b/ehrapy/tools/_sa.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Literal import numpy as np # This package is implicitly used @@ -126,7 +127,9 @@ def kmf( weights: Iterable | None = None, censoring: Literal["right", "left"] = None, ) -> KaplanMeierFitter: - """Fit the Kaplan-Meier estimate for the survival function. + """DEPRECATION WARNING: This function is deprecated and will be removed in the next release. Use `kaplan_meier` instead. + + Fit the Kaplan-Meier estimate for the survival function. The Kaplan–Meier estimator, also known as the product limit estimator, is a non-parametric statistic used to estimate the survival function from lifetime data. In medical research, it is often used to measure the fraction of patients living for a certain amount of time after treatment. @@ -158,6 +161,12 @@ def kmf( >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0) >>> kmf = ep.tl.kmf(adata[:, ["mort_day_censored"]].X, adata[:, ["censor_flg"]].X) """ + + warnings.warn( + "This function is deprecated and will be removed in the next release. Use `ep.tl.kaplan_meier` instead.", + DeprecationWarning, + stacklevel=2, + ) kmf = KaplanMeierFitter() if censoring == "None" or "right": kmf.fit( @@ -185,6 +194,71 @@ def kmf( return kmf +def kaplan_meier( + adata: AnnData, + duration_col: str, + event_col: str | None = None, + *, + timeline: list[float] | None = None, + entry: str | None = None, + label: str | None = None, + alpha: float | None = None, + ci_labels: list[str] | None = None, + weights: list[float] | None = None, + fit_options: dict | None = None, + censoring: Literal["right", "left"] = "right", +) -> KaplanMeierFitter: + """Fit the Kaplan-Meier estimate for the survival function. + + The Kaplan–Meier estimator, also known as the product limit estimator, is a non-parametric statistic used to estimate the survival function from lifetime data. + In medical research, it is often used to measure the fraction of patients living for a certain amount of time after treatment. + + See https://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator + https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html#module-lifelines.fitters.kaplan_meier_fitter + + Args: + adata: AnnData object with necessary columns `duration_col` and `event_col`. + duration_col: The name of the column in the AnnData objects that contains the subjects’ lifetimes. + event_col: The name of the column in anndata that contains the subjects’ death observation. + timeline: Return the best estimate at the values in timelines (positively increasing) + entry: Relative time when a subject entered the study. This is useful for left-truncated (not left-censored) observations. + If None, all members of the population entered study when they were "born". + label: A string to name the column of the estimate. 
+ alpha: The alpha value in the confidence intervals. Overrides the initializing alpha for this call to fit only. + ci_labels: Add custom column names to the generated confidence intervals as a length-2 list: [, ] (default: