From 228ac1cf402623a2dd3f4c8b6ed26b37b2c7204b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Dec 2024 11:56:39 +0100 Subject: [PATCH 01/10] [pre-commit.ci] pre-commit autoupdate (#841) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [pre-commit.ci] pre-commit autoupdate updates: - [github.com/astral-sh/ruff-pre-commit: v0.8.2 → v0.8.3](https://github.com/astral-sh/ruff-pre-commit/compare/v0.8.2...v0.8.3) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix doc build --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: eroell --- .github/pull_request_template.md | 8 ++++---- .pre-commit-config.yaml | 2 +- CODE_OF_CONDUCT.md | 28 ++++++++++++++-------------- README.md | 8 ++++---- docs/contributing.md | 18 +++++++++--------- docs/index.md | 6 +++--- docs/installation.md | 8 ++++---- pyproject.toml | 2 +- 8 files changed, 40 insertions(+), 40 deletions(-) diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 4214c3b8..0bafff61 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -4,10 +4,10 @@ -- [ ] This comment contains a description of changes (with reason) -- [ ] Referenced issue is linked -- [ ] If you've fixed a bug or added code that should be tested, add tests! -- [ ] Documentation in `docs` is updated +- [ ] This comment contains a description of changes (with reason) +- [ ] Referenced issue is linked +- [ ] If you've fixed a bug or added code that should be tested, add tests! +- [ ] Documentation in `docs` is updated **Description of changes** diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 423780fb..12e5ab31 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: hooks: - id: prettier - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.2 + rev: v0.8.3 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes] diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index d6209cca..39816a93 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -14,23 +14,23 @@ religion, or sexual identity and orientation. 
Examples of behavior that contributes to creating a positive environment include: -- Using welcoming and inclusive language -- Being respectful of differing viewpoints and experiences -- Gracefully accepting constructive criticism -- Focusing on what is best for the community -- Showing empathy towards other community members +- Using welcoming and inclusive language +- Being respectful of differing viewpoints and experiences +- Gracefully accepting constructive criticism +- Focusing on what is best for the community +- Showing empathy towards other community members Examples of unacceptable behavior by participants include: -- The use of sexualized language or imagery and unwelcome sexual - attention or advances -- Trolling, insulting/derogatory comments, and personal or political - attacks -- Public or private harassment -- Publishing others’ private information, such as a physical or - electronic address, without explicit permission -- Other conduct which could reasonably be considered inappropriate in a - professional setting +- The use of sexualized language or imagery and unwelcome sexual + attention or advances +- Trolling, insulting/derogatory comments, and personal or political + attacks +- Public or private harassment +- Publishing others’ private information, such as a physical or + electronic address, without explicit permission +- Other conduct which could reasonably be considered inappropriate in a + professional setting ## Our Responsibilities diff --git a/README.md b/README.md index 32e66dec..6c4533b4 100644 --- a/README.md +++ b/README.md @@ -16,10 +16,10 @@ ## Features -- Exploratory and targeted analysis of Electronic Health Records -- Quality control & preprocessing -- Visualization & Exploration -- Clustering & trajectory inference +- Exploratory and targeted analysis of Electronic Health Records +- Quality control & preprocessing +- Visualization & Exploration +- Clustering & trajectory inference ## Installation diff --git a/docs/contributing.md b/docs/contributing.md index ce5858eb..0a5b318e 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -126,11 +126,11 @@ in the cookiecutter-scverse template. Please write documentation for new or changed features and use-cases. This project uses [sphinx][] with the following features: -- the [myst][] extension allows to write documentation in markdown/Markedly Structured Text -- Google-style docstrings -- Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks)) -- [Sphinx autodoc typehints][], to automatically reference annotated input and output types -- Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/) +- the [myst][] extension allows to write documentation in markdown/Markedly Structured Text +- Google-style docstrings +- Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks)) +- [Sphinx autodoc typehints][], to automatically reference annotated input and output types +- Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/) See the [scanpy developer docs](https://scanpy.readthedocs.io/en/latest/dev/documentation.html) for more information on how to write documentation. 
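For orientation, a minimal sketch of what the `intersphinx_mapping` and `nitpick_ignore` hints below refer to. This is a generic Sphinx `docs/conf.py` fragment, not ehrapy's actual configuration; the extension list, mapping entries, and ignored target are illustrative assumptions.

```python
# Hypothetical docs/conf.py excerpt -- names and URLs are examples, not ehrapy's real config.
extensions = [
    "myst_nb",                    # Markdown / notebook tutorials
    "sphinx_autodoc_typehints",   # reference annotated input and output types
    "sphinxcontrib.bibtex",       # citations such as {cite:p}`Virshup_2023`
    "sphinx.ext.intersphinx",
]

# Each entry lets Sphinx link references to objects documented in another project.
intersphinx_mapping = {
    "python": ("https://docs.python.org/3", None),
    "numpy": ("https://numpy.org/doc/stable/", None),
}

# Silence warnings for cross-references that are outside your control.
nitpick_ignore = [
    ("py:class", "some_external_package.SomeClass"),
]
```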
@@ -144,10 +144,10 @@ These notebooks come from [pert-tutorials](https://github.com/theislab/ehrapy-tu #### Hints -- If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only - if you do so can sphinx automatically create a link to the external documentation. -- If building the documentation fails because of a missing link that is outside your control, you can add an entry to - the `nitpick_ignore` list in `docs/conf.py` +- If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only + if you do so can sphinx automatically create a link to the external documentation. +- If building the documentation fails because of a missing link that is outside your control, you can add an entry to + the `nitpick_ignore` list in `docs/conf.py` #### Building the docs locally diff --git a/docs/index.md b/docs/index.md index 56cc3037..03a0987d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -61,8 +61,8 @@ medRxiv 2023.12.11.23299816; doi: https://doi.org/10.1101/2023.12.11.23299816 ]( # Indices and tables -- {ref}`genindex` -- {ref}`modindex` -- {ref}`search` +- {ref}`genindex` +- {ref}`modindex` +- {ref}`search` [scanpy genome biology (2018)]: https://doi.org/10.1186/s13059-017-1382-0 diff --git a/docs/installation.md b/docs/installation.md index ba7010a9..b349394e 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -51,10 +51,10 @@ pip install ehrapy[medcat] Available language models are -- en_core_web_md (python -m spacy download en_core_web_md) -- en-core-sci-sm (pip install ) -- en-core-sci-md (pip install ) -- en-core-sci-lg (pip install ) +- en_core_web_md (python -m spacy download en_core_web_md) +- en-core-sci-sm (pip install ) +- en-core-sci-md (pip install ) +- en-core-sci-lg (pip install ) [github repo]: https://github.com/theislab/ehrapy [pip]: https://pip.pypa.io diff --git a/pyproject.toml b/pyproject.toml index 78931dcc..7bc35d29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,7 +99,7 @@ docs = [ "nbsphinx-link", "ipykernel", "ipython", - "ehrapy[dask,medcat]", + "ehrapy[dask]", ] test = [ "ehrapy[dask]", From 4903d6a6586ac3310320f8842eea59f257bcb504 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 5 Jan 2025 07:07:05 -0800 Subject: [PATCH 02/10] [pre-commit.ci] pre-commit autoupdate (#843) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.8.3 → v0.8.4](https://github.com/astral-sh/ruff-pre-commit/compare/v0.8.3...v0.8.4) - [github.com/pre-commit/mirrors-mypy: v1.13.0 → v1.14.0](https://github.com/pre-commit/mirrors-mypy/compare/v1.13.0...v1.14.0) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 12e5ab31..d16fc187 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: hooks: - id: prettier - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.3 + rev: v0.8.4 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes] @@ -27,7 +27,7 @@ repos: - id: trailing-whitespace - id: check-case-conflict - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.13.0 + rev: v1.14.0 hooks: - id: mypy args: [--no-strict-optional, --ignore-missing-imports] From 
8e00c45921534427f41b7fc2196453d78a6580cc Mon Sep 17 00:00:00 2001 From: Lukas Heumos Date: Mon, 6 Jan 2025 15:19:40 +0100 Subject: [PATCH 03/10] Fix CI (#844) * Fix CI Signed-off-by: zethson * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix CI Signed-off-by: zethson * Fix CI Signed-off-by: zethson * Fix CI Signed-off-by: zethson --------- Signed-off-by: zethson Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 +- ehrapy/__init__.py | 5 ++ ehrapy/core/meta_information.py | 49 ++++--------------- .../feature_ranking/_rank_features_groups.py | 3 +- pyproject.toml | 5 +- tests/anndata/test_anndata_ext.py | 4 +- tests/preprocessing/test_normalization.py | 22 ++++----- 7 files changed, 34 insertions(+), 58 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d16fc187..fdddc766 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,8 +6,8 @@ default_stages: - pre-push minimum_pre_commit_version: 2.16.0 repos: - - repo: https://github.com/pre-commit/mirrors-prettier - rev: v4.0.0-alpha.8 + - repo: https://github.com/rbubley/mirrors-prettier + rev: v3.4.2 hooks: - id: prettier - repo: https://github.com/astral-sh/ruff-pre-commit diff --git a/ehrapy/__init__.py b/ehrapy/__init__.py index 53d9b85e..0255cd2a 100644 --- a/ehrapy/__init__.py +++ b/ehrapy/__init__.py @@ -4,6 +4,11 @@ __email__ = "lukas.heumos@posteo.net" __version__ = "0.9.0" +import os + +# https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html +os.environ["SCIPY_ARRAY_API"] = "1" + from ehrapy._settings import EhrapyConfig, ehrapy_settings settings: EhrapyConfig = ehrapy_settings diff --git a/ehrapy/core/meta_information.py b/ehrapy/core/meta_information.py index 9af10c6b..b08ba1ea 100644 --- a/ehrapy/core/meta_information.py +++ b/ehrapy/core/meta_information.py @@ -3,9 +3,7 @@ import sys from datetime import datetime -import session_info from rich import print -from scanpy.logging import _versions_dependencies from ehrapy import __version__ @@ -17,23 +15,7 @@ def print_versions(): # pragma: no cover >>> import ehrapy as ep >>> ep.print_versions() """ - try: - session_info.show( - dependencies=True, - html=False, - excludes=[ - "builtins", - "stdlib_list", - "importlib_metadata", - "jupyter_core" - # Special module present if test coverage being calculated - # https://gitlab.com/joelostblom/session_info/-/issues/10 - "$coverage", - ], - ) - except AttributeError: - print("[bold yellow]Unable to fetch versions for one or more dependencies.") - pass + print_header() def print_version_and_date(*, file=None): # pragma: no cover @@ -47,26 +29,13 @@ def print_version_and_date(*, file=None): # pragma: no cover def print_header(*, file=None): # pragma: no cover - """Versions that might influence the numerical results. + """Versions that might influence the numerical results.""" + from session_info2 import session_info - Matplotlib and Seaborn are excluded from this. 
- """ - _DEPENDENCIES_NUMERICS = [ - "scanpy", - "anndata", - "umap", - "numpy", - "scipy", - "pandas", - ("sklearn", "scikit-learn"), - "statsmodels", - ("igraph", "python-igraph"), - "leidenalg", - "pynndescent", - ] + sinfo = session_info(os=True, cpu=True, gpu=True, dependencies=True) - modules = ["ehrapy"] + _DEPENDENCIES_NUMERICS - print( - " ".join(f"{mod}=={ver}" for mod, ver in _versions_dependencies(modules)), - file=file or sys.stdout, - ) + if file is not None: + print(sinfo, file=file) + return + + return sinfo diff --git a/ehrapy/tools/feature_ranking/_rank_features_groups.py b/ehrapy/tools/feature_ranking/_rank_features_groups.py index 6fb3932c..dc778868 100644 --- a/ehrapy/tools/feature_ranking/_rank_features_groups.py +++ b/ehrapy/tools/feature_ranking/_rank_features_groups.py @@ -3,6 +3,7 @@ from collections.abc import Iterable from typing import TYPE_CHECKING, Literal +import anndata as ad import numpy as np import pandas as pd import scanpy as sc @@ -446,7 +447,7 @@ def rank_features_groups( X_to_keep = np.zeros((len(adata), 1)) var_to_keep = pd.DataFrame({"dummy": [0]}) - adata_minimal = sc.AnnData( + adata_minimal = ad.AnnData( X=X_to_keep, obs=adata.obs, var=var_to_keep, diff --git a/pyproject.toml b/pyproject.toml index 7bc35d29..654d9092 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ classifiers = [ ] dependencies = [ - "session-info", + "session-info2", "lamin_utils", "rich", "scanpy[leiden]", @@ -135,7 +135,8 @@ filterwarnings = [ "ignore:X converted to numpy array with dtype object:UserWarning", "ignore:`flavor='seurat_v3'` expects raw count data, but non-integers were found:UserWarning", "ignore:All-NaN slice encountered:RuntimeWarning", - "ignore:Observation names are not unique. To make them unique, call `.obs_names_make_unique`.:UserWarning" + "ignore:Observation names are not unique. 
To make them unique, call `.obs_names_make_unique`.:UserWarning", + "ignore:Trying to modify attribute .var of view" ] minversion = 6.0 norecursedirs = [ '.*', 'build', 'dist', '*.egg', 'data', '__pycache__'] diff --git a/tests/anndata/test_anndata_ext.py b/tests/anndata/test_anndata_ext.py index 6e5bbf83..8e50635a 100644 --- a/tests/anndata/test_anndata_ext.py +++ b/tests/anndata/test_anndata_ext.py @@ -42,10 +42,10 @@ def setup_binary_df_to_anndata() -> DataFrame: col2_val = ["another_str" + str(idx) for idx in range(100)] col3_val = [0 for _ in range(100)] col4_val = [1.0 for _ in range(100)] - col5_val = [0.0 if idx % 2 == 0 else np.NaN for idx in range(100)] + col5_val = [0.0 if idx % 2 == 0 else np.nan for idx in range(100)] col6_val = [idx % 2 for idx in range(100)] col7_val = [float(idx % 2) for idx in range(100)] - col8_val = [idx % 3 if idx % 3 in {0, 1} else np.NaN for idx in range(100)] + col8_val = [idx % 3 if idx % 3 in {0, 1} else np.nan for idx in range(100)] df = DataFrame( { "col1": col1_val, diff --git a/tests/preprocessing/test_normalization.py b/tests/preprocessing/test_normalization.py index 5067b237..9b5b1c7b 100644 --- a/tests/preprocessing/test_normalization.py +++ b/tests/preprocessing/test_normalization.py @@ -503,8 +503,8 @@ def test_norm_power_kwargs(array_type, adata_to_norm): num1_norm = np.array([201.03636, 1132.8341, 1399.3877], dtype=np.float32) num2_norm = np.array([-1.8225479, 5.921072, 3.397709], dtype=np.float32) - assert np.allclose(adata_norm.X[:, 3], num1_norm) - assert np.allclose(adata_norm.X[:, 4], num2_norm) + assert np.allclose(adata_norm.X[:, 3], num1_norm, rtol=1e-02, atol=1e-02) + assert np.allclose(adata_norm.X[:, 4], num2_norm, rtol=1e-02, atol=1e-02) @pytest.mark.parametrize("array_type", ARRAY_TYPES) @@ -540,18 +540,18 @@ def test_norm_power_group(array_type, adata_mini): ) col2_norm = np.array( [ - -1.34342372, - -0.44542197, - 0.44898626, - 1.33985944, - -1.34344617, - -0.4453993, - 0.44900845, - 1.33983703, + -1.3504524, + -0.43539175, + 0.4501508, + 1.3356934, + -1.3437141, + -0.44512963, + 0.44927517, + 1.3395685, ], dtype=np.float32, ) - assert np.allclose(adata_mini_norm.X[:, 0], adata_mini_casted.X[:, 0]) + assert np.allclose(adata_mini_norm.X[:, 0], adata_mini_casted.X[:, 0], rtol=1e-02, atol=1e-02) assert np.allclose(adata_mini_norm.X[:, 1], col1_norm, rtol=1e-02, atol=1e-02) assert np.allclose(adata_mini_norm.X[:, 2], col2_norm, rtol=1e-02, atol=1e-02) From 15c63154bba280dc6a75f53de335fa89acccf558 Mon Sep 17 00:00:00 2001 From: Eljas Roellin <65244425+eroell@users.noreply.github.com> Date: Mon, 6 Jan 2025 17:10:08 +0100 Subject: [PATCH 04/10] initial suggestions of array type handling on example of normalization methods (#835) * initial suggestions of array type checks on example of scale_norm * singledispatch normalization functions and test them * try dask import * doc build fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * DRY Signed-off-by: zethson * Fix tests Signed-off-by: zethson * Fix tests Signed-off-by: zethson * Simplify tests Signed-off-by: zethson --------- Signed-off-by: zethson Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Lukas Heumos --- .github/workflows/run_notebooks.yml | 1 - ehrapy/_compat.py | 9 +- ehrapy/preprocessing/_normalization.py | 130 ++++++++++--- pyproject.toml | 5 +- tests/preprocessing/test_normalization.py | 224 ++++++++++++++++------ 5 files changed, 277 insertions(+), 92 
deletions(-) diff --git a/.github/workflows/run_notebooks.yml b/.github/workflows/run_notebooks.yml index f45fd96b..4af421c5 100644 --- a/.github/workflows/run_notebooks.yml +++ b/.github/workflows/run_notebooks.yml @@ -1,7 +1,6 @@ name: Run Notebooks on: - - push - pull_request jobs: diff --git a/ehrapy/_compat.py b/ehrapy/_compat.py index 5497eead..dc94a6d3 100644 --- a/ehrapy/_compat.py +++ b/ehrapy/_compat.py @@ -1,7 +1,6 @@ # Since we might check whether an object is an instance of dask.array.Array # without requiring dask installed in the environment. -# This would become obsolete should dask become a requirement for ehrapy - +from collections.abc import Callable try: import dask.array as da @@ -11,6 +10,12 @@ DASK_AVAILABLE = False +def _raise_array_type_not_implemented(func: Callable, type_: type) -> NotImplementedError: + raise NotImplementedError( + f"{func.__name__} does not support array type {type_}. Must be of type {func.registry.keys()}." # type: ignore + ) + + def is_dask_array(array): if DASK_AVAILABLE: return isinstance(array, da.Array) diff --git a/ehrapy/preprocessing/_normalization.py b/ehrapy/preprocessing/_normalization.py index de6cf646..67c41d02 100644 --- a/ehrapy/preprocessing/_normalization.py +++ b/ehrapy/preprocessing/_normalization.py @@ -1,16 +1,22 @@ from __future__ import annotations +from functools import singledispatch from typing import TYPE_CHECKING import numpy as np import sklearn.preprocessing as sklearn_pp -from ehrapy._compat import is_dask_array +from ehrapy._compat import _raise_array_type_not_implemented try: + import dask.array as da import dask_ml.preprocessing as daskml_pp + + DASK_AVAILABLE = True except ImportError: daskml_pp = None + DASK_AVAILABLE = False + from ehrapy.anndata.anndata_ext import ( assert_numeric_vars, @@ -69,6 +75,23 @@ def _scale_func_group( return None +@singledispatch +def _scale_norm_function(arr): + _raise_array_type_not_implemented(_scale_norm_function, type(arr)) + + +@_scale_norm_function.register +def _(arr: np.ndarray, **kwargs): + return sklearn_pp.StandardScaler(**kwargs).fit_transform + + +if DASK_AVAILABLE: + + @_scale_norm_function.register + def _(arr: da.Array, **kwargs): + return daskml_pp.StandardScaler(**kwargs).fit_transform + + def scale_norm( adata: AnnData, vars: str | Sequence[str] | None = None, @@ -98,10 +121,7 @@ def scale_norm( >>> adata_norm = ep.pp.scale_norm(adata, copy=True) """ - if is_dask_array(adata.X): - scale_func = daskml_pp.StandardScaler(**kwargs).fit_transform - else: - scale_func = sklearn_pp.StandardScaler(**kwargs).fit_transform + scale_func = _scale_norm_function(adata.X, **kwargs) return _scale_func_group( adata=adata, @@ -113,6 +133,23 @@ def scale_norm( ) +@singledispatch +def _minmax_norm_function(arr): + _raise_array_type_not_implemented(_minmax_norm_function, type(arr)) + + +@_minmax_norm_function.register +def _(arr: np.ndarray, **kwargs): + return sklearn_pp.MinMaxScaler(**kwargs).fit_transform + + +if DASK_AVAILABLE: + + @_minmax_norm_function.register + def _(arr: da.Array, **kwargs): + return daskml_pp.MinMaxScaler(**kwargs).fit_transform + + def minmax_norm( adata: AnnData, vars: str | Sequence[str] | None = None, @@ -143,10 +180,7 @@ def minmax_norm( >>> adata_norm = ep.pp.minmax_norm(adata, copy=True) """ - if is_dask_array(adata.X): - scale_func = daskml_pp.MinMaxScaler(**kwargs).fit_transform - else: - scale_func = sklearn_pp.MinMaxScaler(**kwargs).fit_transform + scale_func = _minmax_norm_function(adata.X, **kwargs) return _scale_func_group( adata=adata, @@ 
-158,6 +192,16 @@ def minmax_norm( ) +@singledispatch +def _maxabs_norm_function(arr): + _raise_array_type_not_implemented(_scale_norm_function, type(arr)) + + +@_maxabs_norm_function.register +def _(arr: np.ndarray): + return sklearn_pp.MaxAbsScaler().fit_transform + + def maxabs_norm( adata: AnnData, vars: str | Sequence[str] | None = None, @@ -184,10 +228,8 @@ def maxabs_norm( >>> adata = ep.dt.mimic_2(encoded=True) >>> adata_norm = ep.pp.maxabs_norm(adata, copy=True) """ - if is_dask_array(adata.X): - raise NotImplementedError("MaxAbsScaler is not implemented in dask_ml.") - else: - scale_func = sklearn_pp.MaxAbsScaler().fit_transform + + scale_func = _maxabs_norm_function(adata.X) return _scale_func_group( adata=adata, @@ -199,6 +241,23 @@ def maxabs_norm( ) +@singledispatch +def _robust_scale_norm_function(arr, **kwargs): + _raise_array_type_not_implemented(_robust_scale_norm_function, type(arr)) + + +@_robust_scale_norm_function.register +def _(arr: np.ndarray, **kwargs): + return sklearn_pp.RobustScaler(**kwargs).fit_transform + + +if DASK_AVAILABLE: + + @_robust_scale_norm_function.register + def _(arr: da.Array, **kwargs): + return daskml_pp.RobustScaler(**kwargs).fit_transform + + def robust_scale_norm( adata: AnnData, vars: str | Sequence[str] | None = None, @@ -229,10 +288,8 @@ def robust_scale_norm( >>> adata = ep.dt.mimic_2(encoded=True) >>> adata_norm = ep.pp.robust_scale_norm(adata, copy=True) """ - if is_dask_array(adata.X): - scale_func = daskml_pp.RobustScaler(**kwargs).fit_transform - else: - scale_func = sklearn_pp.RobustScaler(**kwargs).fit_transform + + scale_func = _robust_scale_norm_function(adata.X, **kwargs) return _scale_func_group( adata=adata, @@ -244,6 +301,23 @@ def robust_scale_norm( ) +@singledispatch +def _quantile_norm_function(arr): + _raise_array_type_not_implemented(_quantile_norm_function, type(arr)) + + +@_quantile_norm_function.register +def _(arr: np.ndarray, **kwargs): + return sklearn_pp.QuantileTransformer(**kwargs).fit_transform + + +if DASK_AVAILABLE: + + @_quantile_norm_function.register + def _(arr: da.Array, **kwargs): + return daskml_pp.QuantileTransformer(**kwargs).fit_transform + + def quantile_norm( adata: AnnData, vars: str | Sequence[str] | None = None, @@ -273,10 +347,8 @@ def quantile_norm( >>> adata = ep.dt.mimic_2(encoded=True) >>> adata_norm = ep.pp.quantile_norm(adata, copy=True) """ - if is_dask_array(adata.X): - scale_func = daskml_pp.QuantileTransformer(**kwargs).fit_transform - else: - scale_func = sklearn_pp.QuantileTransformer(**kwargs).fit_transform + + scale_func = _quantile_norm_function(adata.X, **kwargs) return _scale_func_group( adata=adata, @@ -288,6 +360,16 @@ def quantile_norm( ) +@singledispatch +def _power_norm_function(arr, **kwargs): + _raise_array_type_not_implemented(_power_norm_function, type(arr)) + + +@_power_norm_function.register +def _(arr: np.ndarray, **kwargs): + return sklearn_pp.PowerTransformer(**kwargs).fit_transform + + def power_norm( adata: AnnData, vars: str | Sequence[str] | None = None, @@ -317,10 +399,8 @@ def power_norm( >>> adata = ep.dt.mimic_2(encoded=True) >>> adata_norm = ep.pp.power_norm(adata, copy=True) """ - if is_dask_array(adata.X): - raise NotImplementedError("dask-ml has no PowerTransformer, this is only available in scikit-learn") - else: - scale_func = sklearn_pp.PowerTransformer(**kwargs).fit_transform + + scale_func = _power_norm_function(adata.X, **kwargs) return _scale_func_group( adata=adata, diff --git a/pyproject.toml b/pyproject.toml index 654d9092..55f2f277 
100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,7 @@ medcat = [ "medcat", ] dask = [ - "dask", + "anndata[dask]", "dask-ml", ] dev = [ @@ -136,7 +136,8 @@ filterwarnings = [ "ignore:`flavor='seurat_v3'` expects raw count data, but non-integers were found:UserWarning", "ignore:All-NaN slice encountered:RuntimeWarning", "ignore:Observation names are not unique. To make them unique, call `.obs_names_make_unique`.:UserWarning", - "ignore:Trying to modify attribute .var of view" + "ignore:Trying to modify attribute `.var` of view, initializing view as actual.:anndata.ImplicitModificationWarning", + "ignore:Transforming to str index.:anndata.ImplicitModificationWarning:" ] minversion = 6.0 norecursedirs = [ '.*', 'build', 'dist', '*.egg', 'data', '__pycache__'] diff --git a/tests/preprocessing/test_normalization.py b/tests/preprocessing/test_normalization.py index 9b5b1c7b..249e6892 100644 --- a/tests/preprocessing/test_normalization.py +++ b/tests/preprocessing/test_normalization.py @@ -2,6 +2,7 @@ from collections import OrderedDict from pathlib import Path +import dask.array as da import numpy as np import pandas as pd import pytest @@ -13,6 +14,7 @@ from tests.conftest import ARRAY_TYPES, TEST_DATA_PATH CURRENT_DIR = Path(__file__).parent +from scipy import sparse @pytest.fixture @@ -87,24 +89,40 @@ def test_vars_checks(adata_to_norm): ep.pp.scale_norm(adata_to_norm, vars=["String1"]) -@pytest.mark.parametrize("array_type", ARRAY_TYPES) -def test_norm_scale(array_type, adata_to_norm): +# TODO: check this for each function, with just default settings? +@pytest.mark.parametrize( + "array_type,expected_error", + [ + (np.array, None), + (da.array, None), + (sparse.csr_matrix, NotImplementedError), + ], +) +def test_norm_scale_array_types(adata_to_norm, array_type, expected_error): + adata_to_norm.X = array_type(adata_to_norm.X) + if expected_error: + with pytest.raises(expected_error): + ep.pp.scale_norm(adata_to_norm) + + +@pytest.mark.parametrize("array_type", [np.array, da.array]) +def test_norm_scale(adata_to_norm, array_type): """Test for the scaling normalization method.""" warnings.filterwarnings("ignore") - adata_to_norm_casted = adata_to_norm.copy() - adata_to_norm_casted.X = array_type(adata_to_norm_casted.X) + adata_to_norm.X = array_type(adata_to_norm.X) + ep.pp.scale_norm(adata_to_norm) adata_norm = ep.pp.scale_norm(adata_to_norm, copy=True) num1_norm = np.array([-1.4039999, 0.55506986, 0.84893], dtype=np.float32) num2_norm = np.array([-1.3587323, 1.0190493, 0.3396831], dtype=np.float32) - assert np.array_equal(adata_norm.X[:, 0], adata_to_norm_casted.X[:, 0]) - assert np.array_equal(adata_norm.X[:, 1], adata_to_norm_casted.X[:, 1]) - assert np.array_equal(adata_norm.X[:, 2], adata_to_norm_casted.X[:, 2]) + assert np.array_equal(adata_norm.X[:, 0], adata_to_norm.X[:, 0]) + assert np.array_equal(adata_norm.X[:, 1], adata_to_norm.X[:, 1]) + assert np.array_equal(adata_norm.X[:, 2], adata_to_norm.X[:, 2]) assert np.allclose(adata_norm.X[:, 3], num1_norm) assert np.allclose(adata_norm.X[:, 4], num2_norm) - assert np.allclose(adata_norm.X[:, 5], adata_to_norm_casted.X[:, 5], equal_nan=True) + assert np.allclose(adata_norm.X[:, 5], adata_to_norm.X[:, 5], equal_nan=True) def test_norm_scale_integers(adata_mini_integers_in_X): @@ -130,8 +148,7 @@ def test_norm_scale_integers(adata_mini_integers_in_X): @pytest.mark.parametrize("array_type", ARRAY_TYPES) def test_norm_scale_kwargs(array_type, adata_to_norm): - adata_to_norm_casted = adata_to_norm.copy() - adata_to_norm_casted.X 
= array_type(adata_to_norm_casted.X) + adata_to_norm.X = array_type(adata_to_norm.X) adata_norm = ep.pp.scale_norm(adata_to_norm, copy=True, with_mean=False) @@ -174,23 +191,37 @@ def test_norm_scale_group(array_type, adata_mini): assert np.allclose(adata_mini_norm.X[:, 2], col2_norm) +@pytest.mark.parametrize( + "array_type,expected_error", + [ + (np.array, None), + (da.array, None), + (sparse.csr_matrix, NotImplementedError), + ], +) +def test_norm_minmax_array_types(adata_to_norm, array_type, expected_error): + adata_to_norm.X = array_type(adata_to_norm.X) + if expected_error: + with pytest.raises(expected_error): + ep.pp.minmax_norm(adata_to_norm) + + @pytest.mark.parametrize("array_type", ARRAY_TYPES) def test_norm_minmax(array_type, adata_to_norm): """Test for the minmax normalization method.""" - adata_to_norm_casted = adata_to_norm.copy() - adata_to_norm_casted.X = array_type(adata_to_norm_casted.X) + adata_to_norm.X = array_type(adata_to_norm.X) - adata_norm = ep.pp.minmax_norm(adata_to_norm_casted, copy=True) + adata_norm = ep.pp.minmax_norm(adata_to_norm, copy=True) num1_norm = np.array([0.0, 0.86956537, 0.9999999], dtype=np.dtype(np.float32)) num2_norm = np.array([0.0, 1.0, 0.71428573], dtype=np.float32) - assert np.array_equal(adata_norm.X[:, 0], adata_to_norm_casted.X[:, 0]) - assert np.array_equal(adata_norm.X[:, 1], adata_to_norm_casted.X[:, 1]) - assert np.array_equal(adata_norm.X[:, 2], adata_to_norm_casted.X[:, 2]) + assert np.array_equal(adata_norm.X[:, 0], adata_to_norm.X[:, 0]) + assert np.array_equal(adata_norm.X[:, 1], adata_to_norm.X[:, 1]) + assert np.array_equal(adata_norm.X[:, 2], adata_to_norm.X[:, 2]) assert np.allclose(adata_norm.X[:, 3], num1_norm) assert np.allclose(adata_norm.X[:, 4], num2_norm) - assert np.allclose(adata_norm.X[:, 5], adata_to_norm_casted.X[:, 5], equal_nan=True) + assert np.allclose(adata_norm.X[:, 5], adata_to_norm.X[:, 5], equal_nan=True) def test_norm_minmax_integers(adata_mini_integers_in_X): @@ -201,10 +232,9 @@ def test_norm_minmax_integers(adata_mini_integers_in_X): @pytest.mark.parametrize("array_type", ARRAY_TYPES) def test_norm_minmax_kwargs(array_type, adata_to_norm): - adata_to_norm_casted = adata_to_norm.copy() - adata_to_norm_casted.X = array_type(adata_to_norm_casted.X) + adata_to_norm.X = array_type(adata_to_norm.X) - adata_norm = ep.pp.minmax_norm(adata_to_norm_casted, copy=True, feature_range=(0, 2)) + adata_norm = ep.pp.minmax_norm(adata_to_norm, copy=True, feature_range=(0, 2)) num1_norm = np.array([0.0, 1.7391307, 1.9999998], dtype=np.float32) num2_norm = np.array([0.0, 2.0, 1.4285715], dtype=np.float32) @@ -234,28 +264,44 @@ def test_norm_minmax_group(array_type, adata_mini): assert np.allclose(adata_mini_norm.X[:, 2], col2_norm) +@pytest.mark.parametrize( + "array_type,expected_error", + [ + (np.array, None), + (da.array, NotImplementedError), + (sparse.csr_matrix, NotImplementedError), + ], +) +def test_norm_maxabs_array_types(adata_to_norm, array_type, expected_error): + adata_to_norm.X = array_type(adata_to_norm.X) + if expected_error: + with pytest.raises(expected_error): + ep.pp.maxabs_norm(adata_to_norm) + else: + ep.pp.maxabs_norm(adata_to_norm) + + @pytest.mark.parametrize("array_type", ARRAY_TYPES) def test_norm_maxabs(array_type, adata_to_norm): """Test for the maxabs normalization method.""" - adata_to_norm_casted = adata_to_norm.copy() - adata_to_norm_casted.X = array_type(adata_to_norm_casted.X) + adata_to_norm.X = array_type(adata_to_norm.X) if "dask" in array_type.__name__: with 
pytest.raises(NotImplementedError): - adata_norm = ep.pp.maxabs_norm(adata_to_norm_casted, copy=True) + adata_norm = ep.pp.maxabs_norm(adata_to_norm, copy=True) else: - adata_norm = ep.pp.maxabs_norm(adata_to_norm_casted, copy=True) + adata_norm = ep.pp.maxabs_norm(adata_to_norm, copy=True) num1_norm = np.array([0.5964913, 0.94736844, 1.0], dtype=np.float32) num2_norm = np.array([-0.4, 1.0, 0.6], dtype=np.float32) - assert np.array_equal(adata_norm.X[:, 0], adata_to_norm_casted.X[:, 0]) - assert np.array_equal(adata_norm.X[:, 1], adata_to_norm_casted.X[:, 1]) - assert np.array_equal(adata_norm.X[:, 2], adata_to_norm_casted.X[:, 2]) + assert np.array_equal(adata_norm.X[:, 0], adata_to_norm.X[:, 0]) + assert np.array_equal(adata_norm.X[:, 1], adata_to_norm.X[:, 1]) + assert np.array_equal(adata_norm.X[:, 2], adata_to_norm.X[:, 2]) assert np.allclose(adata_norm.X[:, 3], num1_norm) assert np.allclose(adata_norm.X[:, 4], num2_norm) - assert np.allclose(adata_norm.X[:, 5], adata_to_norm_casted.X[:, 5], equal_nan=True) + assert np.allclose(adata_norm.X[:, 5], adata_to_norm.X[:, 5], equal_nan=True) def test_norm_maxabs_integers(adata_mini_integers_in_X): @@ -300,23 +346,37 @@ def test_norm_maxabs_group(array_type, adata_mini): assert np.allclose(adata_mini_norm.X[:, 2], col2_norm) +@pytest.mark.parametrize( + "array_type,expected_error", + [ + (np.array, None), + (da.array, None), + (sparse.csr_matrix, NotImplementedError), + ], +) +def test_norm_robust_scale_array_types(adata_to_norm, array_type, expected_error): + adata_to_norm.X = array_type(adata_to_norm.X) + if expected_error: + with pytest.raises(expected_error): + ep.pp.robust_scale_norm(adata_to_norm) + + @pytest.mark.parametrize("array_type", ARRAY_TYPES) def test_norm_robust_scale(array_type, adata_to_norm): """Test for the robust_scale normalization method.""" - adata_to_norm_casted = adata_to_norm.copy() - adata_to_norm_casted.X = array_type(adata_to_norm_casted.X) + adata_to_norm.X = array_type(adata_to_norm.X) - adata_norm = ep.pp.robust_scale_norm(adata_to_norm_casted, copy=True) + adata_norm = ep.pp.robust_scale_norm(adata_to_norm, copy=True) num1_norm = np.array([-1.73913043, 0.0, 0.26086957], dtype=np.float32) num2_norm = np.array([-1.4285715, 0.5714286, 0.0], dtype=np.float32) - assert np.array_equal(adata_norm.X[:, 0], adata_to_norm_casted.X[:, 0]) - assert np.array_equal(adata_norm.X[:, 1], adata_to_norm_casted.X[:, 1]) - assert np.array_equal(adata_norm.X[:, 2], adata_to_norm_casted.X[:, 2]) + assert np.array_equal(adata_norm.X[:, 0], adata_to_norm.X[:, 0]) + assert np.array_equal(adata_norm.X[:, 1], adata_to_norm.X[:, 1]) + assert np.array_equal(adata_norm.X[:, 2], adata_to_norm.X[:, 2]) assert np.allclose(adata_norm.X[:, 3], num1_norm) assert np.allclose(adata_norm.X[:, 4], num2_norm) - assert np.allclose(adata_norm.X[:, 5], adata_to_norm_casted.X[:, 5], equal_nan=True) + assert np.allclose(adata_norm.X[:, 5], adata_to_norm.X[:, 5], equal_nan=True) def test_norm_robust_scale_integers(adata_mini_integers_in_X): @@ -326,11 +386,10 @@ def test_norm_robust_scale_integers(adata_mini_integers_in_X): @pytest.mark.parametrize("array_type", ARRAY_TYPES) -def test_norm_robust_scale_kwargs(array_type, adata_to_norm): - adata_to_norm_casted = adata_to_norm.copy() - adata_to_norm_casted.X = array_type(adata_to_norm_casted.X) +def test_norm_robust_scale_kwargs(adata_to_norm, array_type): + adata_to_norm.X = array_type(adata_to_norm.X) - adata_norm = ep.pp.robust_scale_norm(adata_to_norm_casted, copy=True, with_scaling=False) + adata_norm 
= ep.pp.robust_scale_norm(adata_to_norm, copy=True, with_scaling=False) num1_norm = np.array([-2.0, 0.0, 0.2999997], dtype=np.float32) num2_norm = np.array([-5.0, 2.0, 0.0], dtype=np.float32) @@ -363,24 +422,38 @@ def test_norm_robust_scale_group(array_type, adata_mini): assert np.allclose(adata_mini_norm.X[:, 2], col2_norm) +@pytest.mark.parametrize( + "array_type,expected_error", + [ + (np.array, None), + (da.array, None), + (sparse.csr_matrix, NotImplementedError), + ], +) +def test_norm_quantile_array_types(adata_to_norm, array_type, expected_error): + adata_to_norm.X = array_type(adata_to_norm.X) + if expected_error: + with pytest.raises(expected_error): + ep.pp.quantile_norm(adata_to_norm) + + @pytest.mark.parametrize("array_type", ARRAY_TYPES) def test_norm_quantile_uniform(array_type, adata_to_norm): """Test for the quantile normalization method.""" warnings.filterwarnings("ignore", category=UserWarning) - adata_to_norm_casted = adata_to_norm.copy() - adata_to_norm_casted.X = array_type(adata_to_norm_casted.X) + adata_to_norm.X = array_type(adata_to_norm.X) - adata_norm = ep.pp.quantile_norm(adata_to_norm_casted, copy=True) + adata_norm = ep.pp.quantile_norm(adata_to_norm, copy=True) num1_norm = np.array([0.0, 0.5, 1.0], dtype=np.float32) num2_norm = np.array([0.0, 1.0, 0.5], dtype=np.float32) - assert np.array_equal(adata_norm.X[:, 0], adata_to_norm_casted.X[:, 0]) - assert np.array_equal(adata_norm.X[:, 1], adata_to_norm_casted.X[:, 1]) - assert np.array_equal(adata_norm.X[:, 2], adata_to_norm_casted.X[:, 2]) + assert np.array_equal(adata_norm.X[:, 0], adata_to_norm.X[:, 0]) + assert np.array_equal(adata_norm.X[:, 1], adata_to_norm.X[:, 1]) + assert np.array_equal(adata_norm.X[:, 2], adata_to_norm.X[:, 2]) assert np.allclose(adata_norm.X[:, 3], num1_norm) assert np.allclose(adata_norm.X[:, 4], num2_norm) - assert np.allclose(adata_norm.X[:, 5], adata_to_norm_casted.X[:, 5], equal_nan=True) + assert np.allclose(adata_norm.X[:, 5], adata_to_norm.X[:, 5], equal_nan=True) def test_norm_quantile_integers(adata_mini_integers_in_X): @@ -406,10 +479,9 @@ def test_norm_quantile_integers(adata_mini_integers_in_X): @pytest.mark.parametrize("array_type", ARRAY_TYPES) def test_norm_quantile_uniform_kwargs(array_type, adata_to_norm): - adata_to_norm_casted = adata_to_norm.copy() - adata_to_norm_casted.X = array_type(adata_to_norm_casted.X) + adata_to_norm.X = array_type(adata_to_norm.X) - adata_norm = ep.pp.quantile_norm(adata_to_norm_casted, copy=True, output_distribution="normal") + adata_norm = ep.pp.quantile_norm(adata_to_norm, copy=True, output_distribution="normal") num1_norm = np.array([-5.19933758, 0.0, 5.19933758], dtype=np.float32) num2_norm = np.array([-5.19933758, 5.19933758, 0.0], dtype=np.float32) @@ -442,27 +514,41 @@ def test_norm_quantile_uniform_group(array_type, adata_mini): assert np.allclose(adata_mini_norm.X[:, 2], col2_norm) +@pytest.mark.parametrize( + "array_type,expected_error", + [ + (np.array, None), + (da.array, None), + (sparse.csr_matrix, NotImplementedError), + ], +) +def test_norm_power_array_types(adata_to_norm, array_type, expected_error): + adata_to_norm.X = array_type(adata_to_norm.X) + if expected_error: + with pytest.raises(expected_error): + ep.pp.power_norm(adata_to_norm) + + @pytest.mark.parametrize("array_type", ARRAY_TYPES) def test_norm_power(array_type, adata_to_norm): """Test for the power transformation normalization method.""" - adata_to_norm_casted = adata_to_norm.copy() - adata_to_norm_casted.X = array_type(adata_to_norm_casted.X) + 
adata_to_norm.X = array_type(adata_to_norm.X) if "dask" in array_type.__name__: with pytest.raises(NotImplementedError): - ep.pp.power_norm(adata_to_norm_casted, copy=True) + ep.pp.power_norm(adata_to_norm, copy=True) else: - adata_norm = ep.pp.power_norm(adata_to_norm_casted, copy=True) + adata_norm = ep.pp.power_norm(adata_to_norm, copy=True) num1_norm = np.array([-1.3821232, 0.43163615, 0.950487], dtype=np.float32) num2_norm = np.array([-1.340104, 1.0613203, 0.27878374], dtype=np.float32) - assert np.array_equal(adata_norm.X[:, 0], adata_to_norm_casted.X[:, 0]) - assert np.array_equal(adata_norm.X[:, 1], adata_to_norm_casted.X[:, 1]) - assert np.array_equal(adata_norm.X[:, 2], adata_to_norm_casted.X[:, 2]) + assert np.array_equal(adata_norm.X[:, 0], adata_to_norm.X[:, 0]) + assert np.array_equal(adata_norm.X[:, 1], adata_to_norm.X[:, 1]) + assert np.array_equal(adata_norm.X[:, 2], adata_to_norm.X[:, 2]) assert np.allclose(adata_norm.X[:, 3], num1_norm, rtol=1.1) assert np.allclose(adata_norm.X[:, 4], num2_norm, rtol=1.1) - assert np.allclose(adata_norm.X[:, 5], adata_to_norm_casted.X[:, 5], equal_nan=True) + assert np.allclose(adata_norm.X[:, 5], adata_to_norm.X[:, 5], equal_nan=True) def test_norm_power_integers(adata_mini_integers_in_X): @@ -488,17 +574,16 @@ def test_norm_power_integers(adata_mini_integers_in_X): @pytest.mark.parametrize("array_type", ARRAY_TYPES) def test_norm_power_kwargs(array_type, adata_to_norm): - adata_to_norm_casted = adata_to_norm.copy() - adata_to_norm_casted.X = array_type(adata_to_norm_casted.X) + adata_to_norm.X = array_type(adata_to_norm.X) if "dask" in array_type.__name__: with pytest.raises(NotImplementedError): - ep.pp.power_norm(adata_to_norm_casted, copy=True) + ep.pp.power_norm(adata_to_norm, copy=True) else: with pytest.raises(ValueError): - ep.pp.power_norm(adata_to_norm_casted, copy=True, method="box-cox") + ep.pp.power_norm(adata_to_norm, copy=True, method="box-cox") - adata_norm = ep.pp.power_norm(adata_to_norm_casted, copy=True, standardize=False) + adata_norm = ep.pp.power_norm(adata_to_norm, copy=True, standardize=False) num1_norm = np.array([201.03636, 1132.8341, 1399.3877], dtype=np.float32) num2_norm = np.array([-1.8225479, 5.921072, 3.397709], dtype=np.float32) @@ -556,6 +641,21 @@ def test_norm_power_group(array_type, adata_mini): assert np.allclose(adata_mini_norm.X[:, 2], col2_norm, rtol=1e-02, atol=1e-02) +@pytest.mark.parametrize( + "array_type,expected_error", + [ + (np.array, None), + (da.array, None), + (sparse.csr_matrix, None), + ], +) +def test_norm_log_norm_array_types(adata_to_norm, array_type, expected_error): + adata_to_norm.X = array_type(adata_to_norm.X) + if expected_error: + with pytest.raises(expected_error): + ep.pp.log_norm(adata_to_norm) + + def test_norm_log1p(adata_to_norm): """Test for the log normalization method.""" # Ensure that some test data is strictly positive From 61bd7bfce83b159f3341cc05a2b3ccf2ba5a4f27 Mon Sep 17 00:00:00 2001 From: zethson Date: Mon, 6 Jan 2025 17:11:44 +0100 Subject: [PATCH 05/10] 0.11.0 release prep Signed-off-by: zethson --- ehrapy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ehrapy/__init__.py b/ehrapy/__init__.py index 0255cd2a..1a52786c 100644 --- a/ehrapy/__init__.py +++ b/ehrapy/__init__.py @@ -2,7 +2,7 @@ __author__ = "Lukas Heumos" __email__ = "lukas.heumos@posteo.net" -__version__ = "0.9.0" +__version__ = "0.11.0" import os From 2b1bf44c8a52325a48a463d2f495449c67de097f Mon Sep 17 00:00:00 2001 From: zethson Date: Mon, 6 Jan 2025 17:17:47 
+0100 Subject: [PATCH 06/10] 0.12.0 prep Signed-off-by: zethson --- .github/release-drafter.yml | 4 ++-- ehrapy/__init__.py | 2 +- pyproject.toml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/release-drafter.yml b/.github/release-drafter.yml index 7d6e3aa2..7fe75c1e 100644 --- a/.github/release-drafter.yml +++ b/.github/release-drafter.yml @@ -1,5 +1,5 @@ -name-template: "0.11.0 🌈" -tag-template: 0.11.0 +name-template: "0.12.0 🌈" +tag-template: 0.12.0 exclude-labels: - "skip-changelog" diff --git a/ehrapy/__init__.py b/ehrapy/__init__.py index 1a52786c..81bdecb3 100644 --- a/ehrapy/__init__.py +++ b/ehrapy/__init__.py @@ -2,7 +2,7 @@ __author__ = "Lukas Heumos" __email__ = "lukas.heumos@posteo.net" -__version__ = "0.11.0" +__version__ = "0.12.0" import os diff --git a/pyproject.toml b/pyproject.toml index 55f2f277..a2bec3dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ requires = ["hatchling"] [project] name = "ehrapy" -version = "0.11.0" +version = "0.12.0" description = "Electronic Health Record Analysis with Python." readme = "README.md" requires-python = ">=3.10,<3.13" From 0701a45b818ce15af039e598aa9e8684bd01e223 Mon Sep 17 00:00:00 2001 From: Vladimir Shitov <35199218+VladimirShitov@users.noreply.github.com> Date: Tue, 7 Jan 2025 20:01:21 +0100 Subject: [PATCH 07/10] Fix a typo in `pl.paga_compare`: `pos` -> `pos,` (#846) --- ehrapy/plot/_scanpy_pl_api.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ehrapy/plot/_scanpy_pl_api.py b/ehrapy/plot/_scanpy_pl_api.py index 4cf278d6..81791036 100644 --- a/ehrapy/plot/_scanpy_pl_api.py +++ b/ehrapy/plot/_scanpy_pl_api.py @@ -1984,7 +1984,8 @@ def paga_compare( save=save, title_graph=title_graph, groups_graph=groups_graph, - pos=pos**paga_graph_params, + pos=pos, + **paga_graph_params, ) From 5cc736d410bc392f1fd1221ebccc439e7d434215 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Jan 2025 20:01:38 +0100 Subject: [PATCH 08/10] [pre-commit.ci] pre-commit autoupdate (#845) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.8.4 → v0.8.6](https://github.com/astral-sh/ruff-pre-commit/compare/v0.8.4...v0.8.6) - [github.com/pre-commit/mirrors-mypy: v1.14.0 → v1.14.1](https://github.com/pre-commit/mirrors-mypy/compare/v1.14.0...v1.14.1) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fdddc766..61fd4955 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: hooks: - id: prettier - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.4 + rev: v0.8.6 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes] @@ -27,7 +27,7 @@ repos: - id: trailing-whitespace - id: check-case-conflict - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.14.0 + rev: v1.14.1 hooks: - id: mypy args: [--no-strict-optional, --ignore-missing-imports] From b2b8e403fd4e365810f23b0366fb090e27d1e3b5 Mon Sep 17 00:00:00 2001 From: Carl Buchholz <32228189+aGuyLearning@users.noreply.github.com> Date: Fri, 10 Jan 2025 17:18:00 +0100 Subject: [PATCH 09/10] Revamp survival analysis interface (#842) * cox_ph add all arguments * updated test to use keywords * weibull_aft arguments update * 
log_logistic update * updated log logistic example * store summary df in adata.uns * try moving np * omit inplace keyword * added explanation, as to where the results are stored * corrected spelling * updated tests to check for .uns ( should be removed later, when the univariates are updated ) * fix argument order, doc fixes * slightly simpler wording * fiexed spelling Co-authored-by: Lukas Heumos * Update ehrapy/tools/_sa.py Co-authored-by: Lukas Heumos * Update ehrapy/tools/_sa.py Co-authored-by: Lukas Heumos * Update ehrapy/tools/_sa.py Co-authored-by: Lukas Heumos * Update ehrapy/tools/_sa.py Co-authored-by: Lukas Heumos * Update ehrapy/tools/_sa.py Co-authored-by: Lukas Heumos * Update ehrapy/tools/_sa.py Co-authored-by: Lukas Heumos * Update ehrapy/tools/_sa.py Co-authored-by: Lukas Heumos * renamed function to be clearer * Add uns_key parameter to Kaplan-Meier, Nelson-Aalen, and Weibull functions for customizable storage in AnnData object * Update test assertions in TestSA for event_table handling and pass adata to assertion method * uns to in doc --------- Co-authored-by: Lukas Heumos Co-authored-by: eroell Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com> --- ehrapy/tools/_sa.py | 310 +++++++++++++++++++++++++++++++++++------ tests/tools/test_sa.py | 16 ++- 2 files changed, 281 insertions(+), 45 deletions(-) diff --git a/ehrapy/tools/_sa.py b/ehrapy/tools/_sa.py index fed63b9e..e436e73d 100644 --- a/ehrapy/tools/_sa.py +++ b/ehrapy/tools/_sa.py @@ -3,7 +3,7 @@ import warnings from typing import TYPE_CHECKING, Literal -import numpy as np # This package is implicitly used +import numpy as np # noqa: TC002 import pandas as pd import statsmodels.api as sm import statsmodels.formula.api as smf @@ -199,6 +199,7 @@ def kaplan_meier( duration_col: str, event_col: str | None = None, *, + uns_key: str = "kaplan_meier", timeline: list[float] | None = None, entry: str | None = None, label: str | None = None, @@ -212,14 +213,18 @@ def kaplan_meier( The Kaplan–Meier estimator, also known as the product limit estimator, is a non-parametric statistic used to estimate the survival function from lifetime data. In medical research, it is often used to measure the fraction of patients living for a certain amount of time after treatment. + The results will be stored in the `.uns` slot of the :class:`AnnData` object under the key 'kaplan_meier' unless specified otherwise in the `uns_key` parameter. See https://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html#module-lifelines.fitters.kaplan_meier_fitter Args: - adata: AnnData object with necessary columns `duration_col` and `event_col`. - duration_col: The name of the column in the AnnData objects that contains the subjects’ lifetimes. - event_col: The name of the column in anndata that contains the subjects’ death observation. + adata: AnnData object. + duration_col: The name of the column in the AnnData object that contains the subjects’ lifetimes. + event_col: The name of the column in the AnnData object that specifies whether the event has been observed, or censored. + Column values are `True` if the event was observed, `False` if the event was lost (right-censored). + If left `None`, all individuals are assumed to be uncensored. + uns_key: The key to use for the `.uns` slot in the AnnData object. timeline: Return the best estimate at the values in timelines (positively increasing) entry: Relative time when a subject entered the study. 
This is useful for left-truncated (not left-censored) observations. If None, all members of the population entered study when they were "born". @@ -247,6 +252,7 @@ def kaplan_meier( duration_col, event_col, KaplanMeierFitter, + uns_key, True, timeline, entry, @@ -347,9 +353,7 @@ def anova_glm(result_1: GLMResultsWrapper, result_2: GLMResultsWrapper, formula_ return dataframe -def _regression_model( - model_class, adata: AnnData, duration_col: str, event_col: str, entry_col: str = None, accept_zero_duration=True -): +def _build_model_input_dataframe(adata: AnnData, duration_col: str, accept_zero_duration=True): """Convenience function for regression models.""" df = anndata_to_df(adata) df = df.dropna() @@ -357,26 +361,67 @@ def _regression_model( if not accept_zero_duration: df.loc[df[duration_col] == 0, duration_col] += 1e-5 - model = model_class() - model.fit(df, duration_col, event_col, entry_col=entry_col) - - return model + return df -def cox_ph(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> CoxPHFitter: +def cox_ph( + adata: AnnData, + duration_col: str, + event_col: str = None, + *, + uns_key: str = "cox_ph", + alpha: float = 0.05, + label: str | None = None, + baseline_estimation_method: Literal["breslow", "spline", "piecewise"] = "breslow", + penalizer: float | np.ndarray = 0.0, + l1_ratio: float = 0.0, + strata: list[str] | str | None = None, + n_baseline_knots: int = 4, + knots: list[float] | None = None, + breakpoints: list[float] | None = None, + weights_col: str | None = None, + cluster_col: str | None = None, + entry_col: str = None, + robust: bool = False, + formula: str = None, + batch_mode: bool = None, + show_progress: bool = False, + initial_point: np.ndarray | None = None, + fit_options: dict | None = None, +) -> CoxPHFitter: """Fit the Cox’s proportional hazard for the survival function. The Cox proportional hazards model (CoxPH) examines the relationship between the survival time of subjects and one or more predictor variables. It models the hazard rate as a product of a baseline hazard function and an exponential function of the predictors, assuming proportional hazards over time. + The results will be stored in the `.uns` slot of the :class:`AnnData` object under the key 'cox_ph' unless specified otherwise in the `uns_key` parameter. See https://lifelines.readthedocs.io/en/latest/fitters/regression/CoxPHFitter.html Args: - adata: AnnData object with necessary columns `duration_col` and `event_col`. + adata: AnnData object. duration_col: The name of the column in the AnnData objects that contains the subjects’ lifetimes. - event_col: The name of the column in anndata that contains the subjects’ death observation. - If left as None, assume all individuals are uncensored. + event_col: The name of the column in the AnnData object that specifies whether the event has been observed, or censored. + Column values are `True` if the event was observed, `False` if the event was lost (right-censored). + If left `None`, all individuals are assumed to be uncensored. + uns_key: The key to use for the `.uns` slot in the AnnData object. + alpha: The alpha value in the confidence intervals. + label: The name of the column of the estimate. + baseline_estimation_method: The method used to estimate the baseline hazard. Options are 'breslow', 'spline', and 'piecewise'. + penalizer: Attach a penalty to the size of the coefficients during regression. This improves stability of the estimates and controls for high correlation between covariates. 
+ l1_ratio: Specify what ratio to assign to a L1 vs L2 penalty. Same as scikit-learn. See penalizer above. + strata: specify a list of columns to use in stratification. This is useful if a categorical covariate does not obey the proportional hazard assumption. This is used similar to the strata expression in R. See http://courses.washington.edu/b515/l17.pdf. + n_baseline_knots: Used when baseline_estimation_method="spline". Set the number of knots (interior & exterior) in the baseline hazard, which will be placed evenly along the time axis. Should be at least 2. Royston et. al, the authors of this model, suggest 4 to start, but any values between 2 and 8 are reasonable. If you need to customize the timestamps used to calculate the curve, use the knots parameter instead. + knots: When baseline_estimation_method="spline", this allows customizing the points in the time axis for the baseline hazard curve. To use evenly-spaced points in time, the n_baseline_knots parameter can be employed instead. + breakpoints: Used when baseline_estimation_method="piecewise". Set the positions of the baseline hazard breakpoints. + weights_col: The name of the column in DataFrame that contains the weights for each subject. + cluster_col: The name of the column in DataFrame that contains the cluster variable. Using this forces the sandwich estimator (robust variance estimator) to be used. entry_col: Column denoting when a subject entered the study, i.e. left-truncation. + robust: Compute the robust errors using the Huber sandwich estimator, aka Wei-Lin estimate. This does not handle ties, so if there are high number of ties, results may significantly differ. + formula: an Wilkinson formula, like in R and statsmodels, for the right-hand-side. If left as None, all columns not assigned as durations, weights, etc. are used. Uses the library Formulaic for parsing. + batch_mode: Enabling batch_mode can be faster for datasets with a large number of ties. If left as `None`, lifelines will choose the best option. + show_progress: Since the fitter is iterative, show convergence diagnostics. Useful if convergence is failing. + initial_point: set the starting point for the iterative solver. + fit_options: Additional keyword arguments to pass into the estimator. Returns: Fitted CoxPHFitter. 
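As a brief aside, a usage sketch of the revamped `cox_ph` interface described above. It mirrors the docstring example from this patch; `"cox_ph"` is the default `uns_key` under which the summary DataFrame is stored.

```python
import numpy as np
import ehrapy as ep

adata = ep.dt.mimic_2(encoded=False)
# Flip 'censor_flg' because in this dataset 0 = death and 1 = censored
adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)

# Fit the Cox proportional hazards model; besides returning the fitted
# CoxPHFitter, the summary DataFrame is written to adata.uns["cox_ph"].
cph = ep.tl.cox_ph(adata, duration_col="mort_day_censored", event_col="censor_flg")
print(adata.uns["cox_ph"])
```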
@@ -388,24 +433,95 @@ def cox_ph(adata: AnnData, duration_col: str, event_col: str, entry_col: str = N >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0) >>> cph = ep.tl.cox_ph(adata, "mort_day_censored", "censor_flg") """ - return _regression_model(CoxPHFitter, adata, duration_col, event_col, entry_col) + df = _build_model_input_dataframe(adata, duration_col) + cox_ph = CoxPHFitter( + alpha=alpha, + label=label, + strata=strata, + baseline_estimation_method=baseline_estimation_method, + penalizer=penalizer, + l1_ratio=l1_ratio, + n_baseline_knots=n_baseline_knots, + knots=knots, + breakpoints=breakpoints, + ) + cox_ph.fit( + df, + duration_col=duration_col, + event_col=event_col, + entry_col=entry_col, + robust=robust, + initial_point=initial_point, + weights_col=weights_col, + cluster_col=cluster_col, + batch_mode=batch_mode, + formula=formula, + fit_options=fit_options, + show_progress=show_progress, + ) + + summary = cox_ph.summary + adata.uns[uns_key] = summary + return cox_ph -def weibull_aft(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> WeibullAFTFitter: + +def weibull_aft( + adata: AnnData, + duration_col: str, + event_col: str, + *, + uns_key: str = "weibull_aft", + alpha: float = 0.05, + fit_intercept: bool = True, + penalizer: float | np.ndarray = 0.0, + l1_ratio: float = 0.0, + model_ancillary: bool = True, + ancillary: bool | pd.DataFrame | str | None = None, + show_progress: bool = False, + weights_col: str | None = None, + robust: bool = False, + initial_point=None, + entry_col: str | None = None, + formula: str | None = None, + fit_options: dict | None = None, +) -> WeibullAFTFitter: """Fit the Weibull accelerated failure time regression for the survival function. The Weibull Accelerated Failure Time (AFT) survival regression model is a statistical method used to analyze time-to-event data, where the underlying assumption is that the logarithm of survival time follows a Weibull distribution. It models the survival time as an exponential function of the predictors, assuming a specific shape parameter for the distribution and allowing for accelerated or decelerated failure times based on the covariates. + The results will be stored in the `.uns` slot of the :class:`AnnData` object under the key 'weibull_aft' unless specified otherwise in the `uns_key` parameter. + See https://lifelines.readthedocs.io/en/latest/fitters/regression/WeibullAFTFitter.html Args: - adata: AnnData object with necessary columns `duration_col` and `event_col`. + adata: AnnData object. duration_col: Name of the column in the AnnData objects that contains the subjects’ lifetimes. - event_col: Name of the column in anndata that contains the subjects’ death observation. - If left as None, assume all individuals are uncensored. + event_col: The name of the column in the AnnData object that specifies whether the event has been observed, or censored. + Column values are `True` if the event was observed, `False` if the event was lost (right-censored). + If left `None`, all individuals are assumed to be uncensored. + uns_key: The key to use for the `.uns` slot in the AnnData object. + alpha: The alpha value in the confidence intervals. + fit_intercept: Whether to fit an intercept term in the model. + penalizer: Attach a penalty to the size of the coefficients during regression. This improves stability of the estimates and controls for high correlation between covariates. + l1_ratio: Specify what ratio to assign to a L1 vs L2 penalty. 
Same as scikit-learn. See penalizer above. + model_ancillary: set the model instance to always model the ancillary parameter with the supplied Dataframe. This is useful for grid-search optimization. + ancillary: Choose to model the ancillary parameters. + If None or False, explicitly do not fit the ancillary parameters using any covariates. + If True, model the ancillary parameters with the same covariates as ``df``. + If DataFrame, provide covariates to model the ancillary parameters. Must be the same row count as ``df``. + If str, should be a formula + show_progress: since the fitter is iterative, show convergence diagnostics. Useful if convergence is failing. + weights_col: The name of the column in DataFrame that contains the weights for each subject. + robust: Compute the robust errors using the Huber sandwich estimator, aka Wei-Lin estimate. This does not handle ties, so if there are high number of ties, results may significantly differ. + initial_point: set the starting point for the iterative solver. entry_col: Column denoting when a subject entered the study, i.e. left-truncation. + formula: Use an R-style formula for modeling the dataset. See formula syntax: https://matthewwardrop.github.io/formulaic/basic/grammar/ + If a formula is not provided, all variables in the dataframe are used (minus those used for other purposes like event_col, etc.) + fit_options: Additional keyword arguments to pass into the estimator. + Returns: Fitted WeibullAFTFitter. @@ -413,27 +529,96 @@ def weibull_aft(adata: AnnData, duration_col: str, event_col: str, entry_col: st Examples: >>> import ehrapy as ep >>> adata = ep.dt.mimic_2(encoded=False) - >>> # Flip 'censor_fl' because 0 = death and 1 = censored >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0) - >>> aft = ep.tl.weibull_aft(adata, "mort_day_censored", "censor_flg") + >>> adata = adata[:, ["mort_day_censored", "censor_flg"]] + >>> aft = ep.tl.weibull_aft(adata, duration_col="mort_day_censored", event_col="censor_flg") + >>> aft.print_summary() """ - return _regression_model(WeibullAFTFitter, adata, duration_col, event_col, entry_col, accept_zero_duration=False) + df = _build_model_input_dataframe(adata, duration_col, accept_zero_duration=False) -def log_logistic_aft(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> LogLogisticAFTFitter: + weibull_aft = WeibullAFTFitter( + alpha=alpha, + fit_intercept=fit_intercept, + penalizer=penalizer, + l1_ratio=l1_ratio, + model_ancillary=model_ancillary, + ) + + weibull_aft.fit( + df, + duration_col=duration_col, + event_col=event_col, + entry_col=entry_col, + ancillary=ancillary, + show_progress=show_progress, + weights_col=weights_col, + robust=robust, + initial_point=initial_point, + formula=formula, + fit_options=fit_options, + ) + + summary = weibull_aft.summary + adata.uns[uns_key] = summary + + return weibull_aft + + +def log_logistic_aft( + adata: AnnData, + duration_col: str, + event_col: str | None = None, + *, + uns_key: str = "log_logistic_aft", + alpha: float = 0.05, + fit_intercept: bool = True, + penalizer: float | np.ndarray = 0.0, + l1_ratio: float = 0.0, + model_ancillary: bool = False, + ancillary: bool | pd.DataFrame | str | None = None, + show_progress: bool = False, + weights_col: str | None = None, + robust: bool = False, + initial_point=None, + entry_col: str | None = None, + formula: str | None = None, + fit_options: dict | None = None, +) -> LogLogisticAFTFitter: """Fit the log logistic accelerated failure time regression 
for the survival function. The Log-Logistic Accelerated Failure Time (AFT) survival regression model is a powerful statistical tool employed in the analysis of time-to-event data. This model operates under the assumption that the logarithm of survival time adheres to a log-logistic distribution, offering a flexible framework for understanding the impact of covariates on survival times. By modeling survival time as a function of predictors, the Log-Logistic AFT model enables researchers to explore how specific factors influence the acceleration or deceleration of failure times, providing valuable insights into the underlying mechanisms driving event occurrence. + The results will be stored in the `.uns` slot of the :class:`AnnData` object under the key 'log_logistic_aft' unless specified otherwise in the `uns_key` parameter. + See https://lifelines.readthedocs.io/en/latest/fitters/regression/LogLogisticAFTFitter.html Args: - adata: AnnData object with necessary columns `duration_col` and `event_col`. + adata: AnnData object. duration_col: Name of the column in the AnnData objects that contains the subjects’ lifetimes. - event_col: Name of the column in anndata that contains the subjects’ death observation. - If left as None, assume all individuals are uncensored. + event_col: The name of the column in the AnnData object that specifies whether the event has been observed, or censored. + Column values are `True` if the event was observed, `False` if the event was lost (right-censored). + If left `None`, all individuals are assumed to be uncensored. + uns_key: The key to use for the `.uns` slot in the AnnData object. + alpha: The alpha value in the confidence intervals. + fit_intercept: Whether to fit an intercept term in the model. + penalizer: Attach a penalty to the size of the coefficients during regression. This improves stability of the estimates and controls for high correlation between covariates. + l1_ratio: Specify what ratio to assign to an L1 vs L2 penalty. Same as scikit-learn. See penalizer above. + model_ancillary: Set the model instance to always model the ancillary parameter with the supplied DataFrame. This is useful for grid-search optimization. + ancillary: Choose to model the ancillary parameters. + If None or False, explicitly do not fit the ancillary parameters using any covariates. + If True, model the ancillary parameters with the same covariates as ``df``. + If DataFrame, provide covariates to model the ancillary parameters. Must be the same row count as ``df``. + If str, should be a formula. + show_progress: Since the fitter is iterative, show convergence diagnostics. Useful if convergence is failing. + weights_col: The name of the column in the DataFrame that contains the weights for each subject. + robust: Compute the robust errors using the Huber sandwich estimator, aka the Wei-Lin estimate. This does not handle ties, so if there is a high number of ties, results may differ significantly. + initial_point: Set the starting point for the iterative solver. entry_col: Column denoting when a subject entered the study, i.e. left-truncation. + formula: Use an R-style formula for modeling the dataset. See formula syntax: https://matthewwardrop.github.io/formulaic/basic/grammar/ + If a formula is not provided, all variables in the dataframe are used (minus those used for other purposes like event_col, etc.). + fit_options: Additional keyword arguments to pass into the estimator. Returns: Fitted LogLogisticAFTFitter. 
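The AFT fitters follow the same pattern: the fitted summary is written to `.uns` under `uns_key`. A minimal sketch, sticking to the columns used in the docstring examples above (the key "llf_demo" is arbitrary):

    import numpy as np
    import ehrapy as ep

    adata = ep.dt.mimic_2(encoded=False)
    # Flip 'censor_flg' so that 1 = observed event, 0 = censored, as in the examples.
    adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
    adata = adata[:, ["mort_day_censored", "censor_flg"]]

    llf = ep.tl.log_logistic_aft(
        adata,
        duration_col="mort_day_censored",
        event_col="censor_flg",
        uns_key="llf_demo",
    )
    # The stored table is the fitter's own summary DataFrame.
    assert adata.uns["llf_demo"].equals(llf.summary)
    llf.print_summary()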
@@ -443,18 +628,45 @@ def log_logistic_aft(adata: AnnData, duration_col: str, event_col: str, entry_co >>> adata = ep.dt.mimic_2(encoded=False) >>> # Flip 'censor_fl' because 0 = death and 1 = censored >>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0) - >>> llf = ep.tl.log_logistic_aft(adata, "mort_day_censored", "censor_flg") + >>> adata = adata[:, ["mort_day_censored", "censor_flg"]] + >>> llf = ep.tl.log_logistic_aft(adata, duration_col="mort_day_censored", event_col="censor_flg") """ - return _regression_model( - LogLogisticAFTFitter, adata, duration_col, event_col, entry_col, accept_zero_duration=False + df = _build_model_input_dataframe(adata, duration_col, accept_zero_duration=False) + + log_logistic_aft = LogLogisticAFTFitter( + alpha=alpha, + fit_intercept=fit_intercept, + penalizer=penalizer, + l1_ratio=l1_ratio, + model_ancillary=model_ancillary, ) + log_logistic_aft.fit( + df, + duration_col=duration_col, + event_col=event_col, + entry_col=entry_col, + ancillary=ancillary, + show_progress=show_progress, + weights_col=weights_col, + robust=robust, + initial_point=initial_point, + formula=formula, + fit_options=fit_options, + ) + + summary = log_logistic_aft.summary + adata.uns[uns_key] = summary + + return log_logistic_aft + def _univariate_model( adata: AnnData, duration_col: str, event_col: str, model_class, + uns_key: str, accept_zero_duration=True, timeline: list[float] | None = None, entry: str | None = None, @@ -466,10 +678,7 @@ def _univariate_model( censoring: Literal["right", "left"] = "right", ): """Convenience function for univariate models.""" - df = anndata_to_df(adata) - - if not accept_zero_duration: - df.loc[df[duration_col] == 0, duration_col] += 1e-5 + df = _build_model_input_dataframe(adata, duration_col, accept_zero_duration) T = df[duration_col] E = df[event_col] @@ -490,6 +699,14 @@ def _univariate_model( fit_options=fit_options, ) + if isinstance(model, NelsonAalenFitter) or isinstance( + model, KaplanMeierFitter + ): # NelsonAalenFitter and KaplanMeierFitter have no summary attribute + summary = model.event_table + else: + summary = model.summary + adata.uns[uns_key] = summary + return model @@ -498,6 +715,7 @@ def nelson_aalen( duration_col: str, event_col: str | None = None, *, + uns_key: str = "nelson_aalen", timeline: list[float] | None = None, entry: str | None = None, label: str | None = None, @@ -512,13 +730,16 @@ def nelson_aalen( The Nelson-Aalen estimator is a non-parametric method used in survival analysis to estimate the cumulative hazard function. This technique is particularly useful when dealing with censored data, as it accounts for the presence of individuals whose event times are unknown due to censoring. By estimating the cumulative hazard function, the Nelson-Aalen estimator allows researchers to assess the risk of an event occurring over time, providing valuable insights into the underlying dynamics of the survival process. + The results will be stored in the `.uns` slot of the :class:`AnnData` object under the key 'nelson_aalen' unless specified otherwise in the `uns_key` parameter. See https://lifelines.readthedocs.io/en/latest/fitters/univariate/NelsonAalenFitter.html Args: - adata: AnnData object with necessary columns `duration_col` and `event_col`. + adata: AnnData object. duration_col: The name of the column in the AnnData objects that contains the subjects’ lifetimes. - event_col: The name of the column in anndata that contains the subjects’ death observation. 
- If left as None, assume all individuals are uncensored. + event_col: The name of the column in the AnnData object that specifies whether the event has been observed, or censored. + Column values are `True` if the event was observed, `False` if the event was lost (right-censored). + If left `None`, all individuals are assumed to be uncensored. + uns_key: The key to use for the `.uns` slot in the AnnData object. timeline: Return the best estimate at the values in timelines (positively increasing) entry: Relative time when a subject entered the study. This is useful for left-truncated (not left-censored) observations. If None, all members of the population entered study when they were "born". @@ -547,7 +768,8 @@ def nelson_aalen( duration_col, event_col, NelsonAalenFitter, - True, + uns_key=uns_key, + accept_zero_duration=True, timeline=timeline, entry=entry, label=label, @@ -564,6 +786,7 @@ def weibull( duration_col: str, event_col: str, *, + uns_key: str = "weibull", timeline: list[float] | None = None, entry: str | None = None, label: str | None = None, @@ -580,14 +803,16 @@ def weibull( By fitting the Weibull model to censored survival data, researchers can estimate these parameters and gain insights into the hazard rate over time, facilitating comparisons between different groups or treatments. This method provides a comprehensive framework for examining survival data and offers valuable insights into the factors influencing event occurrence dynamics. + The results will be stored in the `.uns` slot of the :class:`AnnData` object under the key 'weibull' unless specified otherwise in the `uns_key` parameter. See https://lifelines.readthedocs.io/en/latest/fitters/univariate/WeibullFitter.html Args: - adata: AnnData object with necessary columns `duration_col` and `event_col`. + adata: AnnData object. duration_col: Name of the column in the AnnData objects that contains the subjects’ lifetimes. - event_col: Name of the column in the AnnData object that contains the subjects’ death observation. - If left as None, assume all individuals are uncensored. - adata: AnnData object with necessary columns `duration_col` and `event_col`. + event_col: The name of the column in the AnnData object that specifies whether the event has been observed, or censored. + Column values are `True` if the event was observed, `False` if the event was lost (right-censored). + If left `None`, all individuals are assumed to be uncensored. + uns_key: The key to use for the `.uns` slot in the AnnData object. timeline: Return the best estimate at the values in timelines (positively increasing) entry: Relative time when a subject entered the study. This is useful for left-truncated (not left-censored) observations. If None, all members of the population entered study when they were "born". 
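For the univariate estimators, `_univariate_model` stores the event table for Kaplan-Meier and Nelson-Aalen (which have no `summary` attribute) and the summary table otherwise. A minimal sketch with the Nelson-Aalen estimator, again assuming the MIMIC-II columns from the examples (the key "naf_demo" is arbitrary):

    import numpy as np
    import ehrapy as ep

    adata = ep.dt.mimic_2(encoded=False)
    adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)

    naf = ep.tl.nelson_aalen(
        adata,
        duration_col="mort_day_censored",
        event_col="censor_flg",
        uns_key="naf_demo",
    )
    print(naf.cumulative_hazard_.head())  # lifelines' cumulative hazard estimate
    print(adata.uns["naf_demo"].head())  # event table copied into .uns by this patch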
@@ -613,6 +838,7 @@ def weibull( duration_col, event_col, WeibullFitter, + uns_key=uns_key, accept_zero_duration=False, timeline=timeline, entry=entry, diff --git a/tests/tools/test_sa.py b/tests/tools/test_sa.py index 48d85b36..a383bfc0 100644 --- a/tests/tools/test_sa.py +++ b/tests/tools/test_sa.py @@ -84,16 +84,26 @@ def test_anova_glm(self): assert dataframe.iloc[1, 4] == 2 assert pytest.approx(dataframe.iloc[1, 5], 0.1) == 0.103185 - def _sa_function_assert(self, model, model_class): + def _sa_function_assert(self, model, model_class, adata=None): assert isinstance(model, model_class) assert len(model.durations) == 1776 assert sum(model.event_observed) == 497 + if adata is not None: # doing it disway, due to legacy kmf function + model_summary = adata.uns.get("test") + assert model_summary is not None + if isinstance(model, KaplanMeierFitter) or isinstance( + model, NelsonAalenFitter + ): # kmf and nelson_aalen have event_table + assert model_summary.equals(model.event_table) + else: + assert model_summary.equals(model.summary) + def _sa_func_test(self, sa_function, sa_class, mimic_2_sa): adata, duration_col, event_col = mimic_2_sa + sa = sa_function(adata, duration_col=duration_col, event_col=event_col, uns_key="test") - sa = sa_function(adata, duration_col, event_col) - self._sa_function_assert(sa, sa_class) + self._sa_function_assert(sa, sa_class, adata) def test_kmf(self, mimic_2_sa): # check for deprecation warning From bddb1f3aedda7f6514e95715b3dd72e370cc3450 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Jan 2025 22:43:12 +0100 Subject: [PATCH 10/10] [pre-commit.ci] pre-commit autoupdate (#852) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [pre-commit.ci] pre-commit autoupdate updates: - [github.com/astral-sh/ruff-pre-commit: v0.8.6 → v0.9.1](https://github.com/astral-sh/ruff-pre-commit/compare/v0.8.6...v0.9.1) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- docs/_ext/typed_returns.py | 2 +- ehrapy/_settings.py | 4 ++-- ehrapy/anndata/_feature_specifications.py | 2 +- ehrapy/core/meta_information.py | 2 +- ehrapy/plot/_scanpy_pl_api.py | 10 ++++++++-- tests/preprocessing/test_quality_control.py | 6 +++--- 7 files changed, 17 insertions(+), 11 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 61fd4955..0d2a7f34 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: hooks: - id: prettier - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.6 + rev: v0.9.1 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes] diff --git a/docs/_ext/typed_returns.py b/docs/_ext/typed_returns.py index 1ed85771..f625eff5 100644 --- a/docs/_ext/typed_returns.py +++ b/docs/_ext/typed_returns.py @@ -10,7 +10,7 @@ def _process_return(lines: Iterable[str]) -> Iterator[str]: m = re.fullmatch(r"(?P\w+)\s+:\s+(?P[\w.]+)", line) if m: # Once this is in scanpydoc, we can use the fancy hover stuff - yield f'**{m["param"]}** : :class:`~{m["type"]}`' + yield f"**{m['param']}** : :class:`~{m['type']}`" else: yield line diff --git a/ehrapy/_settings.py b/ehrapy/_settings.py index f733c059..e9ddfe73 100644 --- a/ehrapy/_settings.py +++ b/ehrapy/_settings.py @@ -144,7 +144,7 @@ def file_format_data(self, 
file_format: str): _type_check(file_format, "file_format_data", str) file_format_options = {"csv", "h5ad"} if file_format not in file_format_options: - raise ValueError(f"Cannot set file_format_data to {file_format}. " f"Must be one of {file_format_options}") + raise ValueError(f"Cannot set file_format_data to {file_format}. Must be one of {file_format_options}") self._file_format_data = file_format @property @@ -237,7 +237,7 @@ def cache_compression(self) -> str | None: @cache_compression.setter def cache_compression(self, cache_compression: str | None): if cache_compression not in {"lzf", "gzip", None}: - raise ValueError(f"`cache_compression` ({cache_compression}) " "must be in {'lzf', 'gzip', None}") + raise ValueError(f"`cache_compression` ({cache_compression}) must be in {{'lzf', 'gzip', None}}") self._cache_compression = cache_compression @property diff --git a/ehrapy/anndata/_feature_specifications.py b/ehrapy/anndata/_feature_specifications.py index fdebb9ff..e3b8c2d7 100644 --- a/ehrapy/anndata/_feature_specifications.py +++ b/ehrapy/anndata/_feature_specifications.py @@ -113,7 +113,7 @@ def infer_feature_types( if verbose: logger.warning( - f"{'Features' if len(uncertain_features) >1 else 'Feature'} {str(uncertain_features)[1:-1]} {'were' if len(uncertain_features) >1 else 'was'} detected as categorical features stored numerically." + f"{'Features' if len(uncertain_features) > 1 else 'Feature'} {str(uncertain_features)[1:-1]} {'were' if len(uncertain_features) > 1 else 'was'} detected as categorical features stored numerically." f"Please verify and correct using `ep.ad.replace_feature_types` if necessary." ) diff --git a/ehrapy/core/meta_information.py b/ehrapy/core/meta_information.py index b08ba1ea..d9f702b2 100644 --- a/ehrapy/core/meta_information.py +++ b/ehrapy/core/meta_information.py @@ -23,7 +23,7 @@ def print_version_and_date(*, file=None): # pragma: no cover if file is None: file = sys.stdout print( - f"Running ehrapy {__version__}, " f"on {datetime.now():%Y-%m-%d %H:%M}.", + f"Running ehrapy {__version__}, on {datetime.now():%Y-%m-%d %H:%M}.", file=file, ) diff --git a/ehrapy/plot/_scanpy_pl_api.py b/ehrapy/plot/_scanpy_pl_api.py index 81791036..73272fe3 100644 --- a/ehrapy/plot/_scanpy_pl_api.py +++ b/ehrapy/plot/_scanpy_pl_api.py @@ -1188,7 +1188,10 @@ def tsne(adata, **kwargs) -> Axes | list[Axes] | None: # pragma: no cover .. image:: /_static/docstring_previews/tsne_1.png >>> ep.pl.tsne( - ... adata, color=["day_icu_intime", "service_unit"], wspace=0.5, title=["Day of ICU admission", "Service unit"] + ... adata, + ... color=["day_icu_intime", "service_unit"], + ... wspace=0.5, + ... title=["Day of ICU admission", "Service unit"], ... ) .. image:: /_static/docstring_previews/tsne_2.png @@ -1233,7 +1236,10 @@ def umap(adata: AnnData, **kwargs) -> Axes | list[Axes] | None: # pragma: no co .. image:: /_static/docstring_previews/umap_1.png >>> ep.pl.umap( - ... adata, color=["day_icu_intime", "service_unit"], wspace=0.5, title=["Day of ICU admission", "Service unit"] + ... adata, + ... color=["day_icu_intime", "service_unit"], + ... wspace=0.5, + ... title=["Day of ICU admission", "Service unit"], ... ) .. 
image:: /_static/docstring_previews/umap_2.png diff --git a/tests/preprocessing/test_quality_control.py b/tests/preprocessing/test_quality_control.py index eeff849e..dee27b3c 100644 --- a/tests/preprocessing/test_quality_control.py +++ b/tests/preprocessing/test_quality_control.py @@ -200,9 +200,9 @@ def test_qc_lab_measurements_multiple_measurements(): def test_mcar_test_method_output_types(mar_adata, method, expected_output_type): """Tests if mcar_test returns the correct output type for different methods.""" output = mcar_test(mar_adata, method=method) - assert isinstance( - output, expected_output_type - ), f"Output type for method '{method}' should be {expected_output_type}, got {type(output)} instead." + assert isinstance(output, expected_output_type), ( + f"Output type for method '{method}' should be {expected_output_type}, got {type(output)} instead." + ) def test_mar_data_identification(mar_adata):
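Looping back to the test_sa.py refactor earlier in this excerpt: the round-trip between the fitted model and the table stored in `.uns` can be asserted directly. A pytest-style sketch (the `mimic_2_sa` fixture name and `uns_key="test"` are taken from the existing tests; the test function name is hypothetical):

    import ehrapy as ep
    from lifelines import KaplanMeierFitter, NelsonAalenFitter


    def test_nelson_aalen_uns_roundtrip(mimic_2_sa):
        adata, duration_col, event_col = mimic_2_sa
        model = ep.tl.nelson_aalen(adata, duration_col=duration_col, event_col=event_col, uns_key="test")
        stored = adata.uns["test"]
        if isinstance(model, (KaplanMeierFitter, NelsonAalenFitter)):
            # Kaplan-Meier and Nelson-Aalen expose event_table rather than summary.
            assert stored.equals(model.event_table)
        else:
            assert stored.equals(model.summary)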