diff --git a/.github/release-drafter.yml b/.github/release-drafter.yml
index a62f0e00..7d6e3aa2 100644
--- a/.github/release-drafter.yml
+++ b/.github/release-drafter.yml
@@ -1,5 +1,5 @@
-name-template: "0.9.0 🌈"
-tag-template: 0.9.0
+name-template: "0.11.0 🌈"
+tag-template: 0.11.0
 exclude-labels:
     - "skip-changelog"
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index db548369..589dd3ea 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -16,12 +16,10 @@ jobs:
     steps:
       - uses: actions/checkout@v3
-      - name: Set up Python 3.11
+      - name: Set up Python 3.12
        uses: actions/setup-python@v5
        with:
-          python-version: "3.11"
-          cache: "pip"
-          cache-dependency-path: "**/pyproject.toml"
+          python-version: "3.12"

      - name: Install build dependencies
        run: python -m pip install --upgrade pip wheel twine build
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 993f9bb9..98b082b7 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -13,10 +13,10 @@ jobs:
      - name: Checkout code
        uses: actions/checkout@v3

-      - name: Set up Python 3.11
+      - name: Set up Python 3.12
        uses: actions/setup-python@v5
        with:
-          python-version: "3.11"
+          python-version: "3.12"

      - name: Install hatch
        run: pip install hatch
diff --git a/.github/workflows/run_notebooks.yml b/.github/workflows/run_notebooks.yml
index 7f19f070..f45fd96b 100644
--- a/.github/workflows/run_notebooks.yml
+++ b/.github/workflows/run_notebooks.yml
@@ -13,7 +13,7 @@ jobs:
                 "docs/tutorials/notebooks/ehrapy_introduction.ipynb",
                 "docs/tutorials/notebooks/mimic_2_introduction.ipynb",
                 "docs/tutorials/notebooks/mimic_2_survival_analysis.ipynb",
-                "docs/tutorials/notebooks/mimic_2_fate.ipynb",
+                # "docs/tutorials/notebooks/mimic_2_fate.ipynb",  # https://github.com/theislab/cellrank/issues/1235
                 "docs/tutorials/notebooks/mimic_2_causal_inference.ipynb",
                 # "docs/tutorials/notebooks/mimic_3_demo.ipynb",
                 # "docs/tutorials/notebooks/medcat.ipynb",
@@ -26,13 +26,16 @@ jobs:
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
-          python-version: "3.11"
+          python-version: "3.12"

      - name: Install UV
        run: pip install uv

      - name: Install ehrapy and additional dependencies
-        run: uv pip install --system . cellrank nbconvert ipykernel
+        run: uv pip install --system . cellrank nbconvert ipykernel graphviz
+
+      - name: Install scvelo from GitHub
+        run: uv pip install --system git+https://github.com/theislab/scvelo.git

      - name: Run ${{ matrix.notebook }} Notebook
        run: jupyter nbconvert --to notebook --execute ${{ matrix.notebook }}
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 87e57959..8800bb5a 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -24,9 +24,9 @@ jobs:
            - os: ubuntu-latest
              python: "3.10"
            - os: ubuntu-latest
-              python: "3.11"
+              python: "3.12"
            - os: ubuntu-latest
-              python: "3.11"
+              python: "3.12"
              pip-flags: "--pre"

        env:
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4b3d9285..423780fb 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -11,7 +11,7 @@ repos:
      hooks:
          - id: prettier
    - repo: https://github.com/astral-sh/ruff-pre-commit
-      rev: v0.7.3
+      rev: v0.8.2
      hooks:
          - id: ruff
            args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes]
diff --git a/.readthedocs.yml b/.readthedocs.yml
index 5135ef47..4dc2dd86 100644
--- a/.readthedocs.yml
+++ b/.readthedocs.yml
@@ -1,18 +1,15 @@
 version: 2
 build:
-    os: ubuntu-22.04
+    os: ubuntu-24.04
     tools:
-        python: "3.11"
-    jobs:
-        pre_build:
-            - python -c "import ehrapy"
-            - pip freeze
-        post_create_environment:
-            - pip install uv
-        post_install:
-            # VIRTUAL_ENV needs to be set manually for now.
-            # See https://github.com/readthedocs/readthedocs.org/pull/11152/
-            - VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH pip install .[docs]
+        python: "3.12"
+    commands:
+        - asdf plugin add uv
+        - asdf install uv latest
+        - asdf global uv latest
+        - uv venv
+        - uv pip install .[docs]
+        - .venv/bin/python -m sphinx -T -b html -d docs/_build/doctrees -D language=en docs $READTHEDOCS_OUTPUT/html
 sphinx:
     configuration: docs/conf.py
     fail_on_warning: false
diff --git a/docs/_ext/edit_on_github.py b/docs/_ext/edit_on_github.py
index 85ed75de..746d2c41 100644
--- a/docs/_ext/edit_on_github.py
+++ b/docs/_ext/edit_on_github.py
@@ -20,7 +20,7 @@ def get_github_repo(app: Sphinx, path: str) -> str:


 def _html_page_context(
-    app: Sphinx, _pagename: str, templatename: str, context: dict[str, Any], doctree: Optional[Any]
+    app: Sphinx, _pagename: str, templatename: str, context: dict[str, Any], doctree: Any | None
 ) -> None:
     # doctree is None - otherwise viewcode fails
     if templatename != "page.html" or doctree is None:
diff --git a/docs/conf.py b/docs/conf.py
index 761c2a22..dd46ddab 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -96,6 +96,7 @@
     "flax": ("https://flax.readthedocs.io/en/latest/", None),
     "jax": ("https://jax.readthedocs.io/en/latest/", None),
     "lamin": ("https://lamin.ai/docs", None),
+    "lifelines": ("https://lifelines.readthedocs.io/en/latest/", None),
 }

 language = "en"
diff --git a/docs/contributing.md b/docs/contributing.md
index ab0d890b..ce5858eb 100644
--- a/docs/contributing.md
+++ b/docs/contributing.md
@@ -51,7 +51,7 @@ and [prettier][prettier-editors].
 ## Writing tests

 ```{note}
-Remember to first install the package with `pip install -e "[dev,test,docs]"`
+Remember to first install the package with `pip install -e ".[dev,test,docs]"`
 ```

 This package uses the [pytest][] for automated testing.
 Please [write tests][scanpy-test-docs] for every function added
diff --git a/docs/usage/usage.md b/docs/usage/usage.md
index c77593b0..6f3f2366 100644
--- a/docs/usage/usage.md
+++ b/docs/usage/usage.md
@@ -226,7 +226,7 @@ In contrast to a preprocessing function, a tool usually adds an easily interpret

     tools.ols
     tools.glm
-    tools.kmf
+    tools.kaplan_meier
     tools.test_kmf_logrank
     tools.test_nested_f_statistic
     tools.cox_ph
@@ -368,7 +368,7 @@ Methods that extract and visualize tool-specific annotation in an AnnData object
    :nosignatures:

     plot.ols
-    plot.kmf
+    plot.kaplan_meier
 ```

 ### Causal Inference
diff --git a/ehrapy/_settings.py b/ehrapy/_settings.py
index 3547af0a..f733c059 100644
--- a/ehrapy/_settings.py
+++ b/ehrapy/_settings.py
@@ -53,7 +53,7 @@ def __init__(
         figdir: str | Path = "./figures/",
         cache_compression: str | None = "lzf",
         max_memory=15,
-        n_jobs: int = 1,
+        n_jobs: int = -1,
         logfile: str | Path | None = None,
         categories_to_ignore: Iterable[str] = ("N/A", "dontknow", "no_gate", "?"),
         _frameon: bool = True,
diff --git a/ehrapy/core/_tool_available.py b/ehrapy/_utils_available.py
similarity index 79%
rename from ehrapy/core/_tool_available.py
rename to ehrapy/_utils_available.py
index 75153b41..7b116681 100644
--- a/ehrapy/core/_tool_available.py
+++ b/ehrapy/_utils_available.py
@@ -4,7 +4,7 @@

 from subprocess import PIPE, Popen


-def _check_module_importable(package: str) -> bool:  # pragma: no cover
+def _check_module_importable(package: str) -> bool:
     """Checks whether a module is installed and can be loaded.

     Args:
@@ -19,7 +19,7 @@ def _check_module_importable(package: str) -> bool:
     return module_available


-def _shell_command_accessible(command: list[str]) -> bool:  # pragma: no cover
+def _shell_command_accessible(command: list[str]) -> bool:
     """Checks whether the provided command is accessible in the current shell.

     Args:
@@ -29,7 +29,7 @@ def _shell_command_accessible(command: list[str]) -> bool:
         True if the command is accessible, False otherwise.
     """
     command_accessible = Popen(command, stdout=PIPE, stderr=PIPE, universal_newlines=True, shell=True)
-    (commmand_stdout, command_stderr) = command_accessible.communicate()
+    command_accessible.communicate()

     if command_accessible.returncode != 0:
         return False
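# Illustrative usage sketch for the relocated helpers above, not part of the patch.
# The import path follows the rename to ehrapy/_utils_available.py; the "sklearnex"
# check mirrors how the imputation module later in this patch uses it.
from ehrapy._utils_available import _check_module_importable

if _check_module_importable("sklearnex"):  # True only when the module can actually be imported
    from sklearnex import patch_sklearn

    patch_sklearn()  # enable Intel's sklearn acceleration only if it is installed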
""" command_accessible = Popen(command, stdout=PIPE, stderr=PIPE, universal_newlines=True, shell=True) - (commmand_stdout, command_stderr) = command_accessible.communicate() + command_accessible.communicate() if command_accessible.returncode != 0: return False diff --git a/ehrapy/_doc_util.py b/ehrapy/_utils_doc.py similarity index 99% rename from ehrapy/_doc_util.py rename to ehrapy/_utils_doc.py index b0cf9587..3d6b07f7 100644 --- a/ehrapy/_doc_util.py +++ b/ehrapy/_utils_doc.py @@ -1,9 +1,9 @@ import inspect +from collections.abc import Callable from textwrap import dedent -from typing import Callable, Optional, Union -def getdoc(c_or_f: Union[Callable, type]) -> Optional[str]: # pragma: no cover +def getdoc(c_or_f: Callable | type) -> str | None: # pragma: no cover if getattr(c_or_f, "__doc__", None) is None: return None doc = inspect.getdoc(c_or_f) diff --git a/ehrapy/_utils_rendering.py b/ehrapy/_utils_rendering.py new file mode 100644 index 00000000..43596c54 --- /dev/null +++ b/ehrapy/_utils_rendering.py @@ -0,0 +1,21 @@ +import functools + +from rich.progress import Progress, SpinnerColumn + + +def spinner(message: str = "Running task"): + def wrap(func): + @functools.wraps(func) + def wrapped_f(*args, **kwargs): + with Progress( + "[progress.description]{task.description}", + SpinnerColumn(), + refresh_per_second=1500, + ) as progress: + progress.add_task(f"[blue]{message}", total=1) + result = func(*args, **kwargs) + return result + + return wrapped_f + + return wrap diff --git a/ehrapy/anndata/anndata_ext.py b/ehrapy/anndata/anndata_ext.py index 36bdd6a5..85e53f6c 100644 --- a/ehrapy/anndata/anndata_ext.py +++ b/ehrapy/anndata/anndata_ext.py @@ -3,7 +3,7 @@ import random from collections import OrderedDict from string import ascii_letters -from typing import TYPE_CHECKING, NamedTuple +from typing import TYPE_CHECKING, Any, NamedTuple import numpy as np import pandas as pd @@ -252,13 +252,13 @@ def delete_from_obs(adata: AnnData, to_delete: list[str]) -> AnnData: return adata -def move_to_x(adata: AnnData, to_x: list[str] | str) -> AnnData: +def move_to_x(adata: AnnData, to_x: list[str] | str, copy_x: bool = False) -> AnnData: """Move features from obs to X inplace. Args: adata: The AnnData object to_x: The columns to move to X - copy: Whether to return a copy or not + copy_x: The values are copied to X (and therefore kept in obs) instead of moved completely Returns: A new AnnData object with moved columns from obs to X. This should not be used for datetime columns currently. 
diff --git a/ehrapy/anndata/anndata_ext.py b/ehrapy/anndata/anndata_ext.py
index 36bdd6a5..85e53f6c 100644
--- a/ehrapy/anndata/anndata_ext.py
+++ b/ehrapy/anndata/anndata_ext.py
@@ -3,7 +3,7 @@
 import random
 from collections import OrderedDict
 from string import ascii_letters
-from typing import TYPE_CHECKING, NamedTuple
+from typing import TYPE_CHECKING, Any, NamedTuple

 import numpy as np
 import pandas as pd
@@ -252,13 +252,13 @@ def delete_from_obs(adata: AnnData, to_delete: list[str]) -> AnnData:
     return adata


-def move_to_x(adata: AnnData, to_x: list[str] | str) -> AnnData:
+def move_to_x(adata: AnnData, to_x: list[str] | str, copy_x: bool = False) -> AnnData:
     """Move features from obs to X inplace.

     Args:
         adata: The AnnData object
         to_x: The columns to move to X
-        copy: Whether to return a copy or not
+        copy_x: Whether to copy the values to X (and therefore keep them in obs) instead of moving them completely

     Returns:
         A new AnnData object with moved columns from obs to X. This should not be used for datetime columns currently.
@@ -292,7 +292,10 @@ def move_to_x(adata: AnnData, to_x: list[str] | str) -> AnnData:

     if cols_not_in_x:
         new_adata = concat([adata, AnnData(adata.obs[cols_not_in_x])], axis=1)
-        new_adata.obs = adata.obs[adata.obs.columns[~adata.obs.columns.isin(cols_not_in_x)]]
+        if copy_x:
+            new_adata.obs = adata.obs
+        else:
+            new_adata.obs = adata.obs[adata.obs.columns[~adata.obs.columns.isin(cols_not_in_x)]]

         # AnnData's concat discards var if they don't match in their keys, so we need to create a new var
         created_var = pd.DataFrame(index=cols_not_in_x)
@@ -303,7 +306,7 @@ def move_to_x(adata: AnnData, to_x: list[str] | str) -> AnnData:
     return new_adata


-def _get_column_indices(adata: AnnData, col_names: str | Iterable[str]) -> list[int]:
+def get_column_indices(adata: AnnData, col_names: str | Iterable[str]) -> list[int]:
     """Fetches the column indices in X for a given list of column names

     Args:
@@ -383,7 +386,10 @@ def set_numeric_vars(
     if copy:
         adata = adata.copy()

-    vars_idx = _get_column_indices(adata, vars)
+    vars_idx = get_column_indices(adata, vars)
+
+    # if e.g. adata.X is of type int64 and values of dtype float64, the floats would otherwise be cast to int
+    adata.X = adata.X.astype(values.dtype)

     adata.X[:, vars_idx] = values

@@ -404,7 +410,7 @@ def _detect_binary_columns(df: pd.DataFrame, numerical_columns: list[str]) -> li
     for column in numerical_columns:
         # checking for float and int as well as NaNs (this is safe since checked columns are numericals only)
         # only columns that contain at least one 0 and one 1 are counted as binary (or 0.0/1.0)
-        if df[column].isin([0.0, 1.0, np.NaN, 0, 1]).all() and df[column].nunique() == 2:
+        if df[column].isin([0.0, 1.0, np.nan, 0, 1]).all() and df[column].nunique() == 2:
             binary_columns.append(column)

     return binary_columns
@@ -423,7 +429,7 @@ def _cast_obs_columns(obs: pd.DataFrame) -> pd.DataFrame:
     # type cast each non-numerical column to either bool (if possible) or category else
     obs[object_columns] = obs[object_columns].apply(
         lambda obs_name: obs_name.astype("category")
-        if not set(pd.unique(obs_name)).issubset({False, True, np.NaN})
+        if not set(pd.unique(obs_name)).issubset({False, True, np.nan})
         else obs_name.astype("bool"),
         axis=0,
     )
@@ -663,3 +669,49 @@ def get_rank_features_df(

 class NotEncodedError(AssertionError):
     pass
+
+
+def _are_ndarrays_equal(arr1: np.ndarray, arr2: np.ndarray) -> np.bool_:
+    """Check if two arrays are equal member-wise.
+
+    Note: Two NaN are considered equal.
+
+    Args:
+        arr1: First array to compare
+        arr2: Second array to compare
+
+    Returns:
+        True if the two arrays are equal member-wise
+    """
+    return np.all(np.equal(arr1, arr2, dtype=object) | ((arr1 != arr1) & (arr2 != arr2)))
+
+
+def _is_val_missing(data: np.ndarray) -> np.ndarray[Any, np.dtype[np.bool_]]:
+    """Check if values in an AnnData matrix are missing.
+
+    Args:
+        data: The AnnData matrix to check
+
+    Returns:
+        An array of bool representing the missingness of the original data, with the same shape
+    """
+    return np.isin(data, [None, ""]) | (data != data)
+
+
+def _to_dense_matrix(adata: AnnData, layer: str | None = None) -> np.ndarray:  # pragma: no cover
+    """Extract a layer from an AnnData object and convert it to a dense matrix if required.
+
+    Args:
+        adata: The AnnData where to extract the layer from.
+        layer: Name of the layer to extract. If omitted, X is considered.
+
+    Returns:
+        The layer as a dense matrix. If a conversion was required, this function returns a copy of the original layer,
+        otherwise this function returns a reference.
+    """
+    from scipy.sparse import issparse
+
+    if layer is None:
+        return adata.X.toarray() if issparse(adata.X) else adata.X
+    else:
+        return adata.layers[layer].toarray() if issparse(adata.layers[layer]) else adata.layers[layer]
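# Illustrative sketch of the new copy_x flag on move_to_x, not part of the patch.
# ep.ad is assumed to expose ehrapy.anndata as in ehrapy's public API.
import numpy as np
import pandas as pd
from anndata import AnnData

import ehrapy as ep

adata = AnnData(
    X=np.array([[1.0, 2.0], [3.0, 4.0]]),
    obs=pd.DataFrame({"age": [61, 73]}, index=["p1", "p2"]),
    var=pd.DataFrame(index=["lab_a", "lab_b"]),
)
moved = ep.ad.move_to_x(adata, "age")                # "age" becomes a var in X and leaves obs
copied = ep.ad.move_to_x(adata, "age", copy_x=True)  # "age" goes to X but is also kept in obs
assert "age" not in moved.obs.columns
assert "age" in copied.obs.columns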
+ """ + from scipy.sparse import issparse + + if layer is None: + return adata.X.toarray() if issparse(adata.X) else adata.X + else: + return adata.layers[layer].toarray() if issparse(adata.layers[layer]) else adata.layers[layer] diff --git a/ehrapy/data/_datasets.py b/ehrapy/data/_datasets.py index 5719373a..7957d17c 100644 --- a/ehrapy/data/_datasets.py +++ b/ehrapy/data/_datasets.py @@ -743,7 +743,7 @@ def synthea_1k_sample( df = anndata_to_df(adata) df.drop( - columns=[col for col in df.columns if any(isinstance(x, (list, dict)) for x in df[col].dropna())], inplace=True + columns=[col for col in df.columns if any(isinstance(x, list | dict) for x in df[col].dropna())], inplace=True ) df.drop(columns=df.columns[df.isna().all()], inplace=True) adata = df_to_anndata(df, index_column="id") diff --git a/ehrapy/plot/__init__.py b/ehrapy/plot/__init__.py index 170b9fe9..70ef2e16 100644 --- a/ehrapy/plot/__init__.py +++ b/ehrapy/plot/__init__.py @@ -2,6 +2,6 @@ from ehrapy.plot._colormaps import * # noqa: F403 from ehrapy.plot._missingno_pl_api import * # noqa: F403 from ehrapy.plot._scanpy_pl_api import * # noqa: F403 -from ehrapy.plot._survival_analysis import kmf, ols, coxph_forestplot +from ehrapy.plot._survival_analysis import kaplan_meier, ols, coxph_forestplot from ehrapy.plot.causal_inference._dowhy import causal_effect from ehrapy.plot.feature_ranking._feature_importances import rank_features_supervised diff --git a/ehrapy/plot/_scanpy_pl_api.py b/ehrapy/plot/_scanpy_pl_api.py index d09ca8ed..4cf278d6 100644 --- a/ehrapy/plot/_scanpy_pl_api.py +++ b/ehrapy/plot/_scanpy_pl_api.py @@ -1,15 +1,15 @@ from __future__ import annotations -from collections.abc import Collection, Iterable, Mapping, Sequence +from collections.abc import Callable, Collection, Iterable, Mapping, Sequence from enum import Enum from functools import partial from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Callable, Literal, Union +from typing import TYPE_CHECKING, Any, Literal import scanpy as sc from scanpy.plotting import DotPlot, MatrixPlot, StackedViolin -from ehrapy._doc_util import ( +from ehrapy._utils_doc import ( _doc_params, doc_adata_color_etc, doc_common_groupby_plot_args, @@ -36,12 +36,12 @@ from scanpy.plotting._utils import _AxesSubplot _Basis = Literal["pca", "tsne", "umap", "diffmap", "draw_graph_fr"] -_VarNames = Union[str, Sequence[str]] -ColorLike = Union[str, tuple[float, ...]] +_VarNames = str | Sequence[str] +ColorLike = str | tuple[float, ...] _IGraphLayout = Literal["fa", "fr", "rt", "rt_circular", "drl", "eq_tree", ...] # type: ignore _FontWeight = Literal["light", "normal", "medium", "semibold", "bold", "heavy", "black"] _FontSize = Literal["xx-small", "x-small", "small", "medium", "large", "x-large", "xx-large"] -VBound = Union[str, float, Callable[[Sequence[float]], float]] +VBound = str | float | Callable[[Sequence[float]], float] @_doc_params(scatter_temp=doc_scatter_basic, show_save_ax=doc_show_save_ax) diff --git a/ehrapy/plot/_survival_analysis.py b/ehrapy/plot/_survival_analysis.py index 91d78d81..e4533477 100644 --- a/ehrapy/plot/_survival_analysis.py +++ b/ehrapy/plot/_survival_analysis.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING from lifelines import CoxPHFitter @@ -42,7 +43,7 @@ def ols( ax: Axes | None = None, title: str | None = None, **kwds, -): +) -> Axes | None: """Plots an Ordinary Least Squares (OLS) Model result, scatter plot, and line plot. 
Args: @@ -138,6 +139,8 @@ def ols( if not show: return ax + else: + return None def kmf( @@ -156,7 +159,48 @@ def kmf( figsize: tuple[float, float] | None = None, show: bool | None = None, title: str | None = None, -): +) -> Axes | None: + warnings.warn( + "This function is deprecated and will be removed in the next release. Use `ep.pl.kaplan_meier` instead.", + DeprecationWarning, + stacklevel=2, + ) + return kaplan_meier( + kmfs=kmfs, + ci_alpha=ci_alpha, + ci_force_lines=ci_force_lines, + ci_show=ci_show, + ci_legend=ci_legend, + at_risk_counts=at_risk_counts, + color=color, + grid=grid, + xlim=xlim, + ylim=ylim, + xlabel=xlabel, + ylabel=ylabel, + figsize=figsize, + show=show, + title=title, + ) + + +def kaplan_meier( + kmfs: Sequence[KaplanMeierFitter], + ci_alpha: list[float] | None = None, + ci_force_lines: list[Boolean] | None = None, + ci_show: list[Boolean] | None = None, + ci_legend: list[Boolean] | None = None, + at_risk_counts: list[Boolean] | None = None, + color: list[str] | None | None = None, + grid: Boolean | None = False, + xlim: tuple[float, float] | None = None, + ylim: tuple[float, float] | None = None, + xlabel: str | None = None, + ylabel: str | None = None, + figsize: tuple[float, float] | None = None, + show: bool | None = None, + title: str | None = None, +) -> Axes | None: """Plots a pretty figure of the Fitted KaplanMeierFitter model See https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html @@ -190,23 +234,21 @@ def kmf( # So we need to flip `censor_fl` when pass `censor_fl` to KaplanMeierFitter >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0) - >>> kmf = ep.tl.kmf(adata[:, ["mort_day_censored"]].X, adata[:, ["censor_flg"]].X) - >>> ep.pl.kmf( + >>> kmf = ep.tl.kaplan_meier(adata, "mort_day_censored", "censor_flg") + >>> ep.pl.kaplan_meier( ... [kmf], color=["r"], xlim=[0, 700], ylim=[0, 1], xlabel="Days", ylabel="Proportion Survived", show=True ... ) .. image:: /_static/docstring_previews/kmf_plot_1.png - >>> T = adata[:, ["mort_day_censored"]].X - >>> E = adata[:, ["censor_flg"]].X >>> groups = adata[:, ["service_unit"]].X - >>> ix1 = groups == "FICU" - >>> ix2 = groups == "MICU" - >>> ix3 = groups == "SICU" - >>> kmf_1 = ep.tl.kmf(T[ix1], E[ix1], label="FICU") - >>> kmf_2 = ep.tl.kmf(T[ix2], E[ix2], label="MICU") - >>> kmf_3 = ep.tl.kmf(T[ix3], E[ix3], label="SICU") - >>> ep.pl.kmf([kmf_1, kmf_2, kmf_3], ci_show=[False,False,False], color=['k','r', 'g'], + >>> adata_ficu = adata[groups == "FICU"] + >>> adata_micu = adata[groups == "MICU"] + >>> adata_sicu = adata[groups == "SICU"] + >>> kmf_1 = ep.tl.kaplan_meier(adata_ficu, "mort_day_censored", "censor_flg", label="FICU") + >>> kmf_2 = ep.tl.kaplan_meier(adata_micu, "mort_day_censored", "censor_flg", label="MICU") + >>> kmf_3 = ep.tl.kaplan_meier(adata_sicu, "mort_day_censored", "censor_flg", label="SICU") + >>> ep.pl.kaplan_meier([kmf_1, kmf_2, kmf_3], ci_show=[False,False,False], color=['k','r', 'g'], >>> xlim=[0, 750], ylim=[0, 1], xlabel="Days", ylabel="Proportion Survived") .. 
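# Illustrative sketch of the kmf -> kaplan_meier deprecation path, not part of the
# patch: the old plotting entry point keeps working but now emits a warning.
import warnings

import ehrapy as ep

adata = ep.dt.mimic_2(encoded=True)
fitted = ep.tl.kaplan_meier(adata, "mort_day_censored", "censor_flg")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ep.pl.kmf([fitted])  # deprecated spelling still renders the curve
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

ep.pl.kaplan_meier([fitted])  # preferred spelling going forward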
diff --git a/ehrapy/preprocessing/_encoding.py b/ehrapy/preprocessing/_encoding.py
index 1f761e71..3c3426aa 100644
--- a/ehrapy/preprocessing/_encoding.py
+++ b/ehrapy/preprocessing/_encoding.py
@@ -73,7 +73,7 @@ def encode(

     if isinstance(encodings, str) and not autodetect:
         raise ValueError("Passing a string for parameter encodings is only possible when using autodetect=True!")
-    elif autodetect and not isinstance(encodings, (str, type(None))):
+    elif autodetect and not isinstance(encodings, str | type(None)):
         raise ValueError(
             f"Setting encode mode with autodetect=True only works by passing a string (encode mode name) or None not {type(encodings)}!"
         )
@@ -630,7 +630,7 @@ def _update_obs(adata: AnnData, categorical_names: list[str]) -> pd.DataFrame:
                 updated_obs[var_name] = adata.X[::, idx : idx + 1].flatten()
         # note: this will count binary columns (0 and 1 only) as well
         # needed for writing to .h5ad files
-        if set(pd.unique(updated_obs[var_name])).issubset({False, True, np.NaN}):
+        if set(pd.unique(updated_obs[var_name])).issubset({False, True, np.nan}):
             updated_obs[var_name] = updated_obs[var_name].astype("bool")
     # get all non bool object columns and cast them to category dtype
     object_columns = list(updated_obs.select_dtypes(include="object").columns)
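# Illustrative sketch, not part of the patch: the imputation functions below now
# operate on numerical data only, so categorical features are encoded explicitly
# up front rather than ordinal-encoded behind the scenes.
import ehrapy as ep

adata = ep.dt.mimic_2(encoded=False)
adata = ep.pp.encode(adata, autodetect=True)  # encode all categorical features numerically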
diff --git a/ehrapy/preprocessing/_imputation.py b/ehrapy/preprocessing/_imputation.py
index 6ff60a2d..03796c2b 100644
--- a/ehrapy/preprocessing/_imputation.py
+++ b/ehrapy/preprocessing/_imputation.py
@@ -7,22 +7,20 @@
 import numpy as np
 import pandas as pd
 from lamin_utils import logger
-from rich import print
-from rich.progress import Progress, SpinnerColumn
-from sklearn.experimental import enable_iterative_imputer  # required to enable IterativeImputer (experimental feature)
+from sklearn.experimental import enable_iterative_imputer

 # noinspection PyUnresolvedReference
 from sklearn.impute import SimpleImputer
-from sklearn.preprocessing import OrdinalEncoder

 from ehrapy import settings
+from ehrapy._utils_available import _check_module_importable
+from ehrapy._utils_rendering import spinner
 from ehrapy.anndata import check_feature_types
-from ehrapy.anndata._constants import CATEGORICAL_TAG, FEATURE_TYPE_KEY
-from ehrapy.anndata.anndata_ext import _get_column_indices
-from ehrapy.core._tool_available import _check_module_importable
+from ehrapy.anndata.anndata_ext import get_column_indices

 if TYPE_CHECKING:
     from anndata import AnnData


+@spinner("Performing explicit impute")
 def explicit_impute(
     adata: AnnData,
     replacement: (str | int) | (dict[str, str | int]),
     *,
     impute_empty_strings: bool = True,
     warning_threshold: int = 70,
     copy: bool = False,
-) -> AnnData:
+) -> AnnData | None:
     """Replaces all missing values in all columns or a subset of columns specified by the user with the passed replacement value.

     There are two scenarios to cover:
@@ -47,7 +45,7 @@ def explicit_impute(

     Returns:
         If copy is True, a modified copy of the original AnnData object with imputed X.
-        If copy is False, the original AnnData object is modified in place.
+        If copy is False, the original AnnData object is modified in place, and None is returned.

     Examples:
         Replace all missing values in adata with the value 0:

         >>> import ehrapy as ep
         >>> adata = ep.dt.mimic_2(encoded=True)
         >>> ep.pp.explicit_impute(adata, replacement=0)
     """
-    if copy:  # pragma: no cover
+    if copy:
         adata = adata.copy()

     if isinstance(replacement, int) or isinstance(replacement, str):
@@ -64,32 +62,25 @@ def explicit_impute(
     else:
         _warn_imputation_threshold(adata, var_names=replacement.keys(), threshold=warning_threshold)  # type: ignore

-    with Progress(
-        "[progress.description]{task.description}",
-        SpinnerColumn(),
-        refresh_per_second=1500,
-    ) as progress:
-        progress.add_task("[blue]Running explicit imputation", total=1)
-        # 1: Replace all missing values with the specified value
-        if isinstance(replacement, (int, str)):
-            _replace_explicit(adata.X, replacement, impute_empty_strings)
-
-        # 2: Replace all missing values in a subset of columns with a specified value per column or a default value, when the column is not explicitly named
-        elif isinstance(replacement, dict):
-            for idx, column_name in enumerate(adata.var_names):
-                imputation_value = _extract_impute_value(replacement, column_name)
-                # only replace if an explicit value got passed or could be extracted from replacement
-                if imputation_value:
-                    _replace_explicit(adata.X[:, idx : idx + 1], imputation_value, impute_empty_strings)
-                else:
-                    logger.warning(f"No replace value passed and found for var [not bold green]{column_name}.")
-        else:
-            raise ValueError(  # pragma: no cover
-                f"Type {type(replacement)} is not a valid datatype for replacement parameter. Either use int, str or a dict!"
-            )
+    # 1: Replace all missing values with the specified value
+    if isinstance(replacement, int | str):
+        _replace_explicit(adata.X, replacement, impute_empty_strings)
+
+    # 2: Replace all missing values in a subset of columns with a specified value per column or a default value, when the column is not explicitly named
+    elif isinstance(replacement, dict):
+        for idx, column_name in enumerate(adata.var_names):
+            imputation_value = _extract_impute_value(replacement, column_name)
+            # only replace if an explicit value got passed or could be extracted from replacement
+            if imputation_value:
+                _replace_explicit(adata.X[:, idx : idx + 1], imputation_value, impute_empty_strings)
+            else:
+                logger.warning(f"No replace value passed and found for var [not bold green]{column_name}.")
+    else:
+        raise ValueError(  # pragma: no cover
+            f"Type {type(replacement)} is not a valid datatype for replacement parameter. Either use int, str or a dict!"
+        )

-    if copy:
-        return adata
+    return adata if copy else None


 def _replace_explicit(arr: np.ndarray, replacement: str | int, impute_empty_strings: bool) -> None:
@@ -119,6 +110,7 @@ def _extract_impute_value(replacement: dict[str, str | int], column_name: str) -
     return None


+@spinner("Performing simple impute")
 def simple_impute(
     adata: AnnData,
     var_names: Iterable[str] | None = None,
@@ -126,9 +118,12 @@
     strategy: Literal["mean", "median", "most_frequent"] = "mean",
     copy: bool = False,
     warning_threshold: int = 70,
-) -> AnnData:
+) -> AnnData | None:
     """Impute missing values in numerical data using mean/median/most frequent imputation.

+    If required and using mean or median strategy, the data needs to be properly encoded as this imputation requires
+    numerical data only.
+
     Args:
         adata: The annotated data matrix to impute missing values on.
         var_names: A list of column names to apply imputation on (if None, impute all columns).
@@ -137,13 +132,8 @@
         copy: Whether to return a copy of `adata` or modify it inplace.

     Returns:
-        An updated AnnData object with imputed values.
-
-    Raises:
-        ValueError:
-            If the selected imputation strategy is not applicable to the data.
-        ValueError:
-            If an unknown imputation strategy is provided.
+        If copy is True, a modified copy of the original AnnData object with imputed X.
+        If copy is False, the original AnnData object is modified in place, and None is returned.

     Examples:
         >>> import ehrapy as ep
@@ -155,43 +145,35 @@

     _warn_imputation_threshold(adata, var_names, threshold=warning_threshold)

-    with Progress(
-        "[progress.description]{task.description}",
-        SpinnerColumn(),
-        refresh_per_second=1500,
-    ) as progress:
-        progress.add_task(f"[blue]Running simple imputation with {strategy}", total=1)
-        # Imputation using median and mean strategy works with numerical data only
-        if strategy in {"median", "mean"}:
-            try:
-                _simple_impute(adata, var_names, strategy)
-            except ValueError:
-                raise ValueError(
-                    f"Can only impute numerical data using {strategy} strategy. Try to restrict imputation"
-                    "to certain columns using var_names parameter or use a different mode."
-                ) from None
-        # most_frequent imputation works with non-numerical data as well
-        elif strategy == "most_frequent":
-            _simple_impute(adata, var_names, strategy)
-        # unknown simple imputation strategy
-        else:
-            raise ValueError(  # pragma: no cover
-                f"Unknown impute strategy {strategy} for simple Imputation. Choose any of mean, median or most_frequent."
-            ) from None
+    # Imputation using median and mean strategy works with numerical data only
+    if strategy in {"median", "mean"}:
+        try:
+            _simple_impute(adata, var_names, strategy)
+        except ValueError:
+            raise ValueError(
+                f"Can only impute numerical data using {strategy} strategy. Try to restrict imputation "
+                "to certain columns using var_names parameter or use a different mode."
+            ) from None
+    # most_frequent imputation works with non-numerical data as well
+    elif strategy == "most_frequent":
+        _simple_impute(adata, var_names, strategy)
+    else:
+        raise ValueError(
+            f"Unknown impute strategy {strategy} for simple Imputation. Choose any of mean, median or most_frequent."
+        ) from None

-    if copy:
-        return adata
+    return adata if copy else None


 def _simple_impute(adata: AnnData, var_names: Iterable[str] | None, strategy: str) -> None:
     imputer = SimpleImputer(strategy=strategy)
-    if isinstance(var_names, Iterable):
-        column_indices = _get_column_indices(adata, var_names)
+    if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names):
+        column_indices = get_column_indices(adata, var_names)
         adata.X[::, column_indices] = imputer.fit_transform(adata.X[::, column_indices])
     else:
         adata.X = imputer.fit_transform(adata.X)


+@spinner("Performing KNN impute")
 @check_feature_types
 def knn_impute(
     adata: AnnData,
@@ -206,9 +188,7 @@
 ) -> AnnData:
     """Imputes missing values in the input AnnData object using K-nearest neighbor imputation.

-    When using KNN Imputation with mixed data (non-numerical and numerical), encoding using ordinal encoding is required
-    since KNN Imputation can only work on numerical data. The encoding itself is just a utility and will be undone once
-    imputation ran successfully.
+    If required, the data needs to be properly encoded as this imputation requires numerical data only.

     .. warning:: Currently, both `n_neighbours` and `n_neighbors` are accepted as parameters for the number of neighbors.
@@ -234,10 +214,8 @@
         kwargs: Gathering keyword arguments of earlier ehrapy versions for backwards compatibility.
                 It is encouraged to use the here listed, current arguments.

     Returns:
-        An updated AnnData object with imputed values.
-
-    Raises:
-        ValueError: If the input data matrix contains only categorical (non-numeric) values.
+        If copy is True, a modified copy of the original AnnData object with imputed X.
+        If copy is False, the original AnnData object is modified in place, and None is returned.

     Examples:
         >>> import ehrapy as ep
@@ -274,40 +252,26 @@
         from sklearnex import patch_sklearn, unpatch_sklearn

         patch_sklearn()
+
     try:
-        with Progress(
-            "[progress.description]{task.description}",
-            SpinnerColumn(),
-            refresh_per_second=1500,
-        ) as progress:
-            progress.add_task("[blue]Running KNN imputation", total=1)
-            # numerical only data needs no encoding since KNN Imputation can be applied directly
-            if np.issubdtype(adata.X.dtype, np.number):
-                _knn_impute(adata, var_names, n_neighbors, backend=backend, **backend_kwargs)
-            else:
-                # ordinal encoding is used since non-numerical data can not be imputed using KNN Imputation
-                enc = OrdinalEncoder()
-                column_indices = adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG
-                adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
-                # impute the data using KNN imputation
-                _knn_impute(adata, var_names, n_neighbors, backend=backend, **backend_kwargs)
-                # imputing on encoded columns might result in float numbers; those can not be decoded
-                # cast them to int to ensure they can be decoded
-                adata.X[::, column_indices] = np.rint(adata.X[::, column_indices]).astype(int)
-                # knn imputer transforms X dtype to numerical (encoded), but object is needed for decoding
-                adata.X = adata.X.astype("object")
-                # decode ordinal encoding to obtain imputed original data
-                adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])
+        if np.issubdtype(adata.X.dtype, np.number):
+            _knn_impute(adata, var_names, n_neighbors, backend=backend, **backend_kwargs)
+        else:
+            # Raise exception since non-numerical data can not be imputed using KNN Imputation
+            raise ValueError(
+                "Can only impute numerical data. Try to restrict imputation to certain columns using "
+                "var_names parameter or perform an encoding of your data."
+            )
+
     except ValueError as e:
         if "Data matrix has wrong shape" in str(e):
             logger.error("Check that your matrix does not contain any NaN only columns!")
-            raise
+        raise

     if _check_module_importable("sklearnex"):  # pragma: no cover
         unpatch_sklearn()

-    if copy:
-        return adata
+    return adata if copy else None


 def _knn_impute(
@@ -326,8 +290,8 @@ def _knn_impute(
         imputer = FaissImputer(n_neighbors=n_neighbors, **kwargs)

-    if isinstance(var_names, Iterable):
-        column_indices = _get_column_indices(adata, var_names)
+    if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names):
+        column_indices = get_column_indices(adata, var_names)
         adata.X[::, column_indices] = imputer.fit_transform(adata.X[::, column_indices])
         # this is required since X dtype has to be numerical in order to correctly round floats
         adata.X = adata.X.astype("float64")
@@ -335,17 +299,18 @@ def _knn_impute(
         adata.X = imputer.fit_transform(adata.X)


+@spinner("Performing miss-forest impute")
 def miss_forest_impute(
     adata: AnnData,
-    var_names: dict[str, list[str]] | list[str] | None = None,
+    var_names: Iterable[str] | None = None,
     *,
     num_initial_strategy: Literal["mean", "median", "most_frequent", "constant"] = "mean",
     max_iter: int = 3,
-    n_estimators=100,
+    n_estimators: int = 100,
     random_state: int = 0,
     warning_threshold: int = 70,
     copy: bool = False,
-) -> AnnData:
+) -> AnnData | None:
     """Impute data using the MissForest strategy.

     This function uses the MissForest strategy to impute missing values in the data matrix of an AnnData object.
@@ -353,12 +318,12 @@
     and using the trained model to predict the missing values.

     See https://academic.oup.com/bioinformatics/article/28/1/112/219101.
-    This requires the computation of which columns in X contain numerical only (including NaNs) and which contain non-numerical data.
+
+    If required, the data needs to be properly encoded as this imputation requires numerical data only.

     Args:
         adata: The AnnData object to use MissForest Imputation on.
-        var_names: List of columns to impute or a dict with two keys ('numerical' and 'non_numerical') indicating which var
-            contain mixed data and which numerical data only.
+        var_names: Iterable of columns to impute.
         num_initial_strategy: The initial strategy to replace all missing numerical values with.
         max_iter: The maximum number of iterations if the stop criterion has not been met yet.
         n_estimators: The number of trees to fit for every missing variable. Has a big effect on the run time.
@@ -368,21 +333,20 @@
         copy: Whether to return a copy or act in place.

     Returns:
-        The imputed (but unencoded) AnnData object.
+        If copy is True, a modified copy of the original AnnData object with imputed X.
+        If copy is False, the original AnnData object is modified in place, and None is returned.

     Examples:
         >>> import ehrapy as ep
         >>> adata = ep.dt.mimic_2(encoded=True)
         >>> ep.pp.miss_forest_impute(adata)
     """
-    if copy:  # pragma: no cover
+    if copy:
         adata = adata.copy()

     if var_names is None:
         _warn_imputation_threshold(adata, list(adata.var_names), threshold=warning_threshold)
-    elif isinstance(var_names, dict):
-        _warn_imputation_threshold(adata, var_names.keys(), threshold=warning_threshold)  # type: ignore
-    elif isinstance(var_names, list):
+    elif isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names):
         _warn_imputation_threshold(adata, var_names, threshold=warning_threshold)

     if _check_module_importable("sklearnex"):  # pragma: no cover
         from sklearnex import patch_sklearn, unpatch_sklearn

         patch_sklearn()

@@ -394,74 +358,42 @@
     from sklearn.impute import IterativeImputer

     try:
-        with Progress(
-            "[progress.description]{task.description}",
-            SpinnerColumn(),
-            refresh_per_second=1500,
-        ) as progress:
-            progress.add_task("[blue]Running MissForest imputation", total=1)
-
-            if settings.n_jobs == 1:  # pragma: no cover
-                logger.warning("The number of jobs is only 1. To decrease the runtime set ep.settings.n_jobs=-1.")
-
-            imp_num = IterativeImputer(
-                estimator=ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs),
-                initial_strategy=num_initial_strategy,
-                max_iter=max_iter,
-                random_state=random_state,
-            )
-            # initial strategy here will not be parametrized since only most_frequent will be applied to non numerical data
-            imp_cat = IterativeImputer(
-                estimator=RandomForestClassifier(n_estimators=n_estimators, n_jobs=settings.n_jobs),
-                initial_strategy="most_frequent",
-                max_iter=max_iter,
-                random_state=random_state,
-            )
+        imp_num = IterativeImputer(
+            estimator=ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs),
+            initial_strategy=num_initial_strategy,
+            max_iter=max_iter,
+            random_state=random_state,
+        )
+
+        if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names):  # type: ignore
+            num_indices = get_column_indices(adata, var_names)
+        else:
+            num_indices = get_column_indices(adata, adata.var_names)
+
+        if set(num_indices).issubset(_get_non_numerical_column_indices(adata.X)):
+            raise ValueError(
+                "Can only impute numerical data. Try to restrict imputation to certain columns using "
+                "var_names parameter."
+            )

-        if isinstance(var_names, list):
-            var_indices = _get_column_indices(adata, var_names)  # type: ignore
-            adata.X[::, var_indices] = imp_num.fit_transform(adata.X[::, var_indices])
-        elif isinstance(var_names, dict) or var_names is None:
-            if var_names:
-                try:
-                    non_num_vars = var_names["non_numerical"]
-                    num_vars = var_names["numerical"]
-                except KeyError:  # pragma: no cover
-                    raise ValueError(
-                        "One or both of your keys provided for var_names are unknown. Only "
-                        "numerical and non_numerical are available!"
-                    ) from None
-                non_num_indices = _get_column_indices(adata, non_num_vars)
-                num_indices = _get_column_indices(adata, num_vars)
-
-            # infer non numerical and numerical indices automatically
-            else:
-                non_num_indices_set = _get_non_numerical_column_indices(adata.X)
-                num_indices = [idx for idx in range(adata.X.shape[1]) if idx not in non_num_indices_set]
-                non_num_indices = list(non_num_indices_set)
-
-            # encode all non numerical columns
-            if non_num_indices:
-                enc = OrdinalEncoder()
-                adata.X[::, non_num_indices] = enc.fit_transform(adata.X[::, non_num_indices])
-            # this step is the most expensive one and might extremely slow down the impute process
-            if num_indices:
-                adata.X[::, num_indices] = imp_num.fit_transform(adata.X[::, num_indices])
-            if non_num_indices:
-                adata.X[::, non_num_indices] = imp_cat.fit_transform(adata.X[::, non_num_indices])
-                adata.X[::, non_num_indices] = enc.inverse_transform(adata.X[::, non_num_indices])
+        # this step is the most expensive one and might extremely slow down the impute process
+        if num_indices:
+            adata.X[::, num_indices] = imp_num.fit_transform(adata.X[::, num_indices])
+        else:
+            raise ValueError("Cannot find any feature to perform imputation")
+
     except ValueError as e:
         if "Data matrix has wrong shape" in str(e):
             logger.error("Check that your matrix does not contain any NaN only columns!")
-            raise
+        raise

     if _check_module_importable("sklearnex"):  # pragma: no cover
         unpatch_sklearn()

-    if copy:
-        return adata
+    return adata if copy else None


+@spinner("Performing mice-forest impute")
 @check_feature_types
 def mice_forest_impute(
     adata: AnnData,
@@ -475,12 +414,14 @@
     variable_parameters: dict | None = None,
     verbose: bool = False,
     copy: bool = False,
-) -> AnnData:
+) -> AnnData | None:
     """Impute data using the miceforest.

     See https://github.com/AnotherSamWilson/miceforest
     Fast, memory efficient Multiple Imputation by Chained Equations (MICE) with lightgbm.

+    If required, the data needs to be properly encoded as this imputation requires numerical data only.
+
     Args:
         adata: The AnnData object containing the data to impute.
         var_names: A list of variable names to impute. If None, impute all variables.
@@ -497,7 +438,8 @@
         copy: Whether to return a copy of the AnnData object or modify it in-place.

     Returns:
-        The imputed AnnData object.
+        If copy is True, a modified copy of the original AnnData object with imputed X.
+        If copy is False, the original AnnData object is modified in place, and None is returned.

     Examples:
         >>> import ehrapy as ep
         >>> adata = ep.dt.mimic_2(encoded=True)
         >>> ep.pp.mice_forest_impute(adata)
     """
@@ -509,49 +451,31 @@
         adata = adata.copy()

     _warn_imputation_threshold(adata, var_names, threshold=warning_threshold)
+
     try:
-        with Progress(
-            "[progress.description]{task.description}",
-            SpinnerColumn(),
-            refresh_per_second=1500,
-        ) as progress:
-            progress.add_task("[blue]Running miceforest", total=1)
-            if np.issubdtype(adata.X.dtype, np.number):
-                _miceforest_impute(
-                    adata,
-                    var_names,
-                    save_all_iterations_data,
-                    random_state,
-                    inplace,
-                    iterations,
-                    variable_parameters,
-                    verbose,
-                )
-            else:
-                # ordinal encoding is used since non-numerical data can not be imputed using miceforest
-                enc = OrdinalEncoder()
-                column_indices = adata.var[FEATURE_TYPE_KEY] == CATEGORICAL_TAG
-                adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
-                # impute the data using miceforest
-                _miceforest_impute(
-                    adata,
-                    var_names,
-                    save_all_iterations_data,
-                    random_state,
-                    inplace,
-                    iterations,
-                    variable_parameters,
-                    verbose,
-                )
-                adata.X = adata.X.astype("object")
-                # decode ordinal encoding to obtain imputed original data
-                adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])
+        if np.issubdtype(adata.X.dtype, np.number):
+            _miceforest_impute(
+                adata,
+                var_names,
+                save_all_iterations_data,
+                random_state,
+                inplace,
+                iterations,
+                variable_parameters,
+                verbose,
+            )
+        else:
+            raise ValueError(
+                "Can only impute numerical data. Try to restrict imputation to certain columns using "
+                "var_names parameter."
+            )
+
     except ValueError as e:
         if "Data matrix has wrong shape" in str(e):
             logger.warning("Check that your matrix does not contain any NaN only columns!")
-            raise
+        raise

-    return adata
+    return adata if copy else None


 def _miceforest_impute(
@@ -562,8 +486,8 @@ def _miceforest_impute(
     data_df = pd.DataFrame(adata.X, columns=adata.var_names, index=adata.obs_names)
     data_df = data_df.apply(pd.to_numeric, errors="coerce")

-    if isinstance(var_names, Iterable):
-        column_indices = _get_column_indices(adata, var_names)
+    if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names):
+        column_indices = get_column_indices(adata, var_names)
         selected_columns = data_df.iloc[:, column_indices]
         selected_columns = selected_columns.reset_index(drop=True)
@@ -616,27 +540,21 @@ def _warn_imputation_threshold(adata: AnnData, var_names: Iterable[str] | None,
     return var_name_to_pct


-def _get_non_numerical_column_indices(X: np.ndarray) -> set:
+def _get_non_numerical_column_indices(arr: np.ndarray) -> set:
     """Return indices of columns, that contain at least one non-numerical value that is not "Nan"."""

-    def _is_float_or_nan(val):  # pragma: no cover
+    def _is_float_or_nan(val) -> bool:  # pragma: no cover
         """Check whether a given item is a float or np.nan"""
         try:
-            float(val)
-        except ValueError:
-            if val is np.nan:
-                return True
+            _ = float(val)
+            return not isinstance(val, bool)
+        except (ValueError, TypeError):
             return False
-        else:
-            if not isinstance(val, bool):
-                return True
-            else:
-                return False

-    is_numeric_numpy = np.vectorize(_is_float_or_nan, otypes=[bool])
-    mask = np.apply_along_axis(is_numeric_numpy, 0, X)
+    def _is_float_or_nan_row(row) -> list[bool]:  # pragma: no cover
+        return [_is_float_or_nan(val) for val in row]

+    mask = np.apply_along_axis(_is_float_or_nan_row, 0, arr)
     _, column_indices = np.where(~mask)
-    non_num_indices = set(column_indices)

-    return non_num_indices
+    return set(column_indices)
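# Illustrative end-to-end sketch of the reworked imputation API above, not part of
# the patch: encode first, then impute. copy=False (the default) mutates in place
# and returns None; copy=True returns the imputed AnnData.
import ehrapy as ep

adata = ep.dt.mimic_2(encoded=False)
adata = ep.pp.encode(adata, autodetect=True)  # make all features numerical first

ep.pp.simple_impute(adata, strategy="median")                 # in place, returns None
imputed = ep.pp.knn_impute(adata, n_neighbors=5, copy=True)   # returns a modified copy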
diff --git a/ehrapy/preprocessing/_normalization.py b/ehrapy/preprocessing/_normalization.py
index 6f3d5e1a..de6cf646 100644
--- a/ehrapy/preprocessing/_normalization.py
+++ b/ehrapy/preprocessing/_normalization.py
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING

 import numpy as np
 import sklearn.preprocessing as sklearn_pp
@@ -13,14 +13,14 @@
     daskml_pp = None

 from ehrapy.anndata.anndata_ext import (
-    _get_column_indices,
     assert_numeric_vars,
+    get_column_indices,
     get_numeric_vars,
     set_numeric_vars,
 )

 if TYPE_CHECKING:
-    from collections.abc import Sequence
+    from collections.abc import Callable, Sequence

     import pandas as pd
     from anndata import AnnData
@@ -48,7 +48,7 @@ def _scale_func_group(

     adata = _prep_adata_norm(adata, copy)

-    var_idx = _get_column_indices(adata, vars)
+    var_idx = get_column_indices(adata, vars)
     var_values = np.take(adata.X, var_idx, axis=1)

     if group_key is None:
@@ -379,7 +379,7 @@ def log_norm(
             "or offset negative values with ep.pp.offset_negative_values()."
         )

-    var_idx = _get_column_indices(adata, vars)
+    var_idx = get_column_indices(adata, vars)
     var_values = np.take(adata.X, var_idx, axis=1)

     if offset == 1:
diff --git a/ehrapy/preprocessing/_scanpy_pp_api.py b/ehrapy/preprocessing/_scanpy_pp_api.py
index 5317530e..e0e50221 100644
--- a/ehrapy/preprocessing/_scanpy_pp_api.py
+++ b/ehrapy/preprocessing/_scanpy_pp_api.py
@@ -1,7 +1,8 @@
 from __future__ import annotations

+from collections.abc import Callable
 from types import MappingProxyType
-from typing import TYPE_CHECKING, Any, Callable, Literal, Union
+from typing import TYPE_CHECKING, Any, Literal, TypeAlias

 import numpy as np
 import scanpy as sc
@@ -15,7 +16,7 @@
     from ehrapy.preprocessing._types import KnownTransformer

-AnyRandom = Union[int, np.random.RandomState, None]
+AnyRandom: TypeAlias = int | np.random.RandomState | None


 def pca(
@@ -193,7 +194,7 @@ def combat(
     "sqeuclidean",
     "yule",
 ]
-_Metric = Union[_MetricSparseCapable, _MetricScipySpatial]
+_Metric = _MetricSparseCapable | _MetricScipySpatial


 def neighbors(
diff --git a/ehrapy/tools/__init__.py b/ehrapy/tools/__init__.py
index 5da8fa69..c034882f 100644
--- a/ehrapy/tools/__init__.py
+++ b/ehrapy/tools/__init__.py
@@ -2,6 +2,7 @@
     anova_glm,
     cox_ph,
     glm,
+    kaplan_meier,
     kmf,
     log_logistic_aft,
     nelson_aalen,
@@ -31,6 +32,7 @@
     "cox_ph",
     "glm",
     "kmf",
+    "kaplan_meier",
     "log_logistic_aft",
     "nelson_aalen",
     "ols",
diff --git a/ehrapy/tools/_method_options.py b/ehrapy/tools/_method_options.py
index 315f47f7..5bcad67c 100644
--- a/ehrapy/tools/_method_options.py
+++ b/ehrapy/tools/_method_options.py
@@ -1,11 +1,11 @@
-from typing import Literal, Optional
+from typing import Literal

 _InitPos = Literal["paga", "spectral", "random"]

 _LAYOUTS = ("fr", "drl", "kk", "grid_fr", "lgl", "rt", "rt_circular", "fa")
 _Layout = Literal[_LAYOUTS]  # type: ignore

-_rank_features_groups_method = Optional[Literal["logreg", "t-test", "wilcoxon", "t-test_overestim_var"]]
+_rank_features_groups_method = Literal["logreg", "t-test", "wilcoxon", "t-test_overestim_var"] | None
 _correction_method = Literal["benjamini-hochberg", "bonferroni"]
 _rank_features_groups_cat_method = Literal[
     "chi-square", "g-test", "freeman-tukey", "mod-log-likelihood", "neyman", "cressie-read"
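# Illustrative sketch, not part of the patch: ehrapy.tools now exports both names,
# with kaplan_meier as the successor to kmf (its new AnnData-based signature is
# introduced in the ehrapy/tools/_sa.py diff below).
import ehrapy as ep

adata = ep.dt.mimic_2(encoded=True)
fitted = ep.tl.kaplan_meier(adata, duration_col="mort_day_censored", event_col="censor_flg")
print(fitted.median_survival_time_)  # attribute of the fitted lifelines KaplanMeierFitter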
diff --git a/ehrapy/tools/_sa.py b/ehrapy/tools/_sa.py
index e23b6a43..fed63b9e 100644
--- a/ehrapy/tools/_sa.py
+++ b/ehrapy/tools/_sa.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import warnings
 from typing import TYPE_CHECKING, Literal

 import numpy as np  # This package is implicitly used
@@ -126,7 +127,9 @@ def kmf(
     weights: Iterable | None = None,
     censoring: Literal["right", "left"] = None,
 ) -> KaplanMeierFitter:
-    """Fit the Kaplan-Meier estimate for the survival function.
+    """DEPRECATION WARNING: This function is deprecated and will be removed in the next release. Use `kaplan_meier` instead.
+
+    Fit the Kaplan-Meier estimate for the survival function.

     The Kaplan–Meier estimator, also known as the product limit estimator, is a non-parametric statistic used to estimate the survival function from lifetime data.
     In medical research, it is often used to measure the fraction of patients living for a certain amount of time after treatment.
@@ -158,6 +161,12 @@ def kmf(
         >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
         >>> kmf = ep.tl.kmf(adata[:, ["mort_day_censored"]].X, adata[:, ["censor_flg"]].X)
     """
+
+    warnings.warn(
+        "This function is deprecated and will be removed in the next release. Use `ep.tl.kaplan_meier` instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
     kmf = KaplanMeierFitter()
     if censoring == "None" or "right":
         kmf.fit(
@@ -185,6 +194,71 @@ def kmf(
     return kmf


+def kaplan_meier(
+    adata: AnnData,
+    duration_col: str,
+    event_col: str | None = None,
+    *,
+    timeline: list[float] | None = None,
+    entry: str | None = None,
+    label: str | None = None,
+    alpha: float | None = None,
+    ci_labels: list[str] | None = None,
+    weights: list[float] | None = None,
+    fit_options: dict | None = None,
+    censoring: Literal["right", "left"] = "right",
+) -> KaplanMeierFitter:
+    """Fit the Kaplan-Meier estimate for the survival function.
+
+    The Kaplan–Meier estimator, also known as the product limit estimator, is a non-parametric statistic used to estimate the survival function from lifetime data.
+    In medical research, it is often used to measure the fraction of patients living for a certain amount of time after treatment.
+
+    See https://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator
+    https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html#module-lifelines.fitters.kaplan_meier_fitter
+
+    Args:
+        adata: AnnData object with necessary columns `duration_col` and `event_col`.
+        duration_col: The name of the column in the AnnData object that contains the subjects’ lifetimes.
+        event_col: The name of the column in the AnnData object that contains the subjects’ death observation.
+        timeline: Return the best estimate at the values in timelines (positively increasing).
+        entry: Relative time when a subject entered the study. This is useful for left-truncated (not left-censored) observations.
+            If None, all members of the population entered study when they were "born".
+        label: A string to name the column of the estimate.
+        alpha: The alpha value in the confidence intervals. Overrides the initializing alpha for this call to fit only.
+        ci_labels: Add custom column names to the generated confidence intervals as a length-2 list: [<lower-bound name>, <upper-bound name>] (default: