Skip to content

Commit

Permalink
Merge branch 'main' into enhancement/issue-743
Browse files Browse the repository at this point in the history
  • Loading branch information
eroell authored Jan 8, 2025
2 parents f1b648b + 5cc736d commit de4572d
Show file tree
Hide file tree
Showing 12 changed files with 318 additions and 156 deletions.
4 changes: 2 additions & 2 deletions .github/release-drafter.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name-template: "0.11.0 🌈"
tag-template: 0.11.0
name-template: "0.12.0 🌈"
tag-template: 0.12.0
exclude-labels:
- "skip-changelog"

Expand Down
1 change: 0 additions & 1 deletion .github/workflows/run_notebooks.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
name: Run Notebooks

on:
- push
- pull_request

jobs:
Expand Down
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ default_stages:
- pre-push
minimum_pre_commit_version: 2.16.0
repos:
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v4.0.0-alpha.8
- repo: https://github.com/rbubley/mirrors-prettier
rev: v3.4.2
hooks:
- id: prettier
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.3
rev: v0.8.6
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes]
Expand All @@ -27,7 +27,7 @@ repos:
- id: trailing-whitespace
- id: check-case-conflict
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.13.0
rev: v1.14.1
hooks:
- id: mypy
args: [--no-strict-optional, --ignore-missing-imports]
Expand Down
7 changes: 6 additions & 1 deletion ehrapy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

__author__ = "Lukas Heumos"
__email__ = "[email protected]"
__version__ = "0.9.0"
__version__ = "0.12.0"

import os

# https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html
os.environ["SCIPY_ARRAY_API"] = "1"

from ehrapy._settings import EhrapyConfig, ehrapy_settings

Expand Down
9 changes: 7 additions & 2 deletions ehrapy/_compat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Since we might check whether an object is an instance of dask.array.Array
# without requiring dask installed in the environment.
# This would become obsolete should dask become a requirement for ehrapy

from collections.abc import Callable

try:
import dask.array as da
Expand All @@ -11,6 +10,12 @@
DASK_AVAILABLE = False


def _raise_array_type_not_implemented(func: Callable, type_: type) -> NotImplementedError:
raise NotImplementedError(
f"{func.__name__} does not support array type {type_}. Must be of type {func.registry.keys()}." # type: ignore
)


def is_dask_array(array):
if DASK_AVAILABLE:
return isinstance(array, da.Array)
Expand Down
49 changes: 9 additions & 40 deletions ehrapy/core/meta_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
import sys
from datetime import datetime

import session_info
from rich import print
from scanpy.logging import _versions_dependencies

from ehrapy import __version__

Expand All @@ -17,23 +15,7 @@ def print_versions(): # pragma: no cover
>>> import ehrapy as ep
>>> ep.print_versions()
"""
try:
session_info.show(
dependencies=True,
html=False,
excludes=[
"builtins",
"stdlib_list",
"importlib_metadata",
"jupyter_core"
# Special module present if test coverage being calculated
# https://gitlab.com/joelostblom/session_info/-/issues/10
"$coverage",
],
)
except AttributeError:
print("[bold yellow]Unable to fetch versions for one or more dependencies.")
pass
print_header()


def print_version_and_date(*, file=None): # pragma: no cover
Expand All @@ -47,26 +29,13 @@ def print_version_and_date(*, file=None): # pragma: no cover


def print_header(*, file=None): # pragma: no cover
"""Versions that might influence the numerical results.
"""Versions that might influence the numerical results."""
from session_info2 import session_info

Matplotlib and Seaborn are excluded from this.
"""
_DEPENDENCIES_NUMERICS = [
"scanpy",
"anndata",
"umap",
"numpy",
"scipy",
"pandas",
("sklearn", "scikit-learn"),
"statsmodels",
("igraph", "python-igraph"),
"leidenalg",
"pynndescent",
]
sinfo = session_info(os=True, cpu=True, gpu=True, dependencies=True)

modules = ["ehrapy"] + _DEPENDENCIES_NUMERICS
print(
" ".join(f"{mod}=={ver}" for mod, ver in _versions_dependencies(modules)),
file=file or sys.stdout,
)
if file is not None:
print(sinfo, file=file)
return

return sinfo
3 changes: 2 additions & 1 deletion ehrapy/plot/_scanpy_pl_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1984,7 +1984,8 @@ def paga_compare(
save=save,
title_graph=title_graph,
groups_graph=groups_graph,
pos=pos**paga_graph_params,
pos=pos,
**paga_graph_params,
)


Expand Down
130 changes: 105 additions & 25 deletions ehrapy/preprocessing/_normalization.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
from __future__ import annotations

from functools import singledispatch
from typing import TYPE_CHECKING

import numpy as np
import sklearn.preprocessing as sklearn_pp

from ehrapy._compat import is_dask_array
from ehrapy._compat import _raise_array_type_not_implemented

try:
import dask.array as da
import dask_ml.preprocessing as daskml_pp

DASK_AVAILABLE = True
except ImportError:
daskml_pp = None
DASK_AVAILABLE = False


from ehrapy.anndata.anndata_ext import (
assert_numeric_vars,
Expand Down Expand Up @@ -69,6 +75,23 @@ def _scale_func_group(
return None


@singledispatch
def _scale_norm_function(arr, **kwargs):
    """Return a standard-scaling callable appropriate for the array type of ``arr``.

    Accepts ``**kwargs`` so that calls like ``_scale_norm_function(adata.X, **kwargs)``
    on an unsupported array type reach the informative ``NotImplementedError`` below
    instead of failing with a ``TypeError`` about unexpected keyword arguments.
    """
    _raise_array_type_not_implemented(_scale_norm_function, type(arr))


@_scale_norm_function.register
def _(arr: np.ndarray, **kwargs):
    # In-memory NumPy arrays: scikit-learn StandardScaler.
    return sklearn_pp.StandardScaler(**kwargs).fit_transform


if DASK_AVAILABLE:

    @_scale_norm_function.register
    def _(arr: da.Array, **kwargs):
        # Lazy dask arrays: dask-ml StandardScaler (registered only when dask is installed).
        return daskml_pp.StandardScaler(**kwargs).fit_transform


def scale_norm(
adata: AnnData,
vars: str | Sequence[str] | None = None,
Expand Down Expand Up @@ -98,10 +121,7 @@ def scale_norm(
>>> adata_norm = ep.pp.scale_norm(adata, copy=True)
"""

if is_dask_array(adata.X):
scale_func = daskml_pp.StandardScaler(**kwargs).fit_transform
else:
scale_func = sklearn_pp.StandardScaler(**kwargs).fit_transform
scale_func = _scale_norm_function(adata.X, **kwargs)

return _scale_func_group(
adata=adata,
Expand All @@ -113,6 +133,23 @@ def scale_norm(
)


@singledispatch
def _minmax_norm_function(arr, **kwargs):
    """Return a min-max-scaling callable appropriate for the array type of ``arr``.

    Accepts ``**kwargs`` so that calls like ``_minmax_norm_function(adata.X, **kwargs)``
    on an unsupported array type reach the informative ``NotImplementedError`` below
    instead of failing with a ``TypeError`` about unexpected keyword arguments.
    """
    _raise_array_type_not_implemented(_minmax_norm_function, type(arr))


@_minmax_norm_function.register
def _(arr: np.ndarray, **kwargs):
    # In-memory NumPy arrays: scikit-learn MinMaxScaler.
    return sklearn_pp.MinMaxScaler(**kwargs).fit_transform


if DASK_AVAILABLE:

    @_minmax_norm_function.register
    def _(arr: da.Array, **kwargs):
        # Lazy dask arrays: dask-ml MinMaxScaler (registered only when dask is installed).
        return daskml_pp.MinMaxScaler(**kwargs).fit_transform


def minmax_norm(
adata: AnnData,
vars: str | Sequence[str] | None = None,
Expand Down Expand Up @@ -143,10 +180,7 @@ def minmax_norm(
>>> adata_norm = ep.pp.minmax_norm(adata, copy=True)
"""

if is_dask_array(adata.X):
scale_func = daskml_pp.MinMaxScaler(**kwargs).fit_transform
else:
scale_func = sklearn_pp.MinMaxScaler(**kwargs).fit_transform
scale_func = _minmax_norm_function(adata.X, **kwargs)

return _scale_func_group(
adata=adata,
Expand All @@ -158,6 +192,16 @@ def minmax_norm(
)


@singledispatch
def _maxabs_norm_function(arr):
    """Return a max-abs-scaling callable appropriate for the array type of ``arr``.

    Raises:
        NotImplementedError: For array types with no registered implementation.
    """
    # Fixed: previously passed `_scale_norm_function` here, so the error message
    # misattributed the failure and listed the wrong function's registry.
    _raise_array_type_not_implemented(_maxabs_norm_function, type(arr))


@_maxabs_norm_function.register
def _(arr: np.ndarray):
    # Only NumPy is supported: dask-ml provides no MaxAbsScaler.
    return sklearn_pp.MaxAbsScaler().fit_transform


def maxabs_norm(
adata: AnnData,
vars: str | Sequence[str] | None = None,
Expand All @@ -184,10 +228,8 @@ def maxabs_norm(
>>> adata = ep.dt.mimic_2(encoded=True)
>>> adata_norm = ep.pp.maxabs_norm(adata, copy=True)
"""
if is_dask_array(adata.X):
raise NotImplementedError("MaxAbsScaler is not implemented in dask_ml.")
else:
scale_func = sklearn_pp.MaxAbsScaler().fit_transform

scale_func = _maxabs_norm_function(adata.X)

return _scale_func_group(
adata=adata,
Expand All @@ -199,6 +241,23 @@ def maxabs_norm(
)


@singledispatch
def _robust_scale_norm_function(arr, **kwargs):
    """Return a robust-scaling callable appropriate for the array type of ``arr``.

    Raises:
        NotImplementedError: For array types with no registered implementation.
    """
    _raise_array_type_not_implemented(_robust_scale_norm_function, type(arr))


@_robust_scale_norm_function.register
def _(arr: np.ndarray, **kwargs):
    # In-memory NumPy arrays: scikit-learn RobustScaler.
    return sklearn_pp.RobustScaler(**kwargs).fit_transform


if DASK_AVAILABLE:

    @_robust_scale_norm_function.register
    def _(arr: da.Array, **kwargs):
        # Lazy dask arrays: dask-ml RobustScaler (registered only when dask is installed).
        return daskml_pp.RobustScaler(**kwargs).fit_transform


def robust_scale_norm(
adata: AnnData,
vars: str | Sequence[str] | None = None,
Expand Down Expand Up @@ -229,10 +288,8 @@ def robust_scale_norm(
>>> adata = ep.dt.mimic_2(encoded=True)
>>> adata_norm = ep.pp.robust_scale_norm(adata, copy=True)
"""
if is_dask_array(adata.X):
scale_func = daskml_pp.RobustScaler(**kwargs).fit_transform
else:
scale_func = sklearn_pp.RobustScaler(**kwargs).fit_transform

scale_func = _robust_scale_norm_function(adata.X, **kwargs)

return _scale_func_group(
adata=adata,
Expand All @@ -244,6 +301,23 @@ def robust_scale_norm(
)


@singledispatch
def _quantile_norm_function(arr, **kwargs):
    """Return a quantile-transform callable appropriate for the array type of ``arr``.

    Accepts ``**kwargs`` so that calls like ``_quantile_norm_function(adata.X, **kwargs)``
    on an unsupported array type reach the informative ``NotImplementedError`` below
    instead of failing with a ``TypeError`` about unexpected keyword arguments.
    """
    _raise_array_type_not_implemented(_quantile_norm_function, type(arr))


@_quantile_norm_function.register
def _(arr: np.ndarray, **kwargs):
    # In-memory NumPy arrays: scikit-learn QuantileTransformer.
    return sklearn_pp.QuantileTransformer(**kwargs).fit_transform


if DASK_AVAILABLE:

    @_quantile_norm_function.register
    def _(arr: da.Array, **kwargs):
        # Lazy dask arrays: dask-ml QuantileTransformer (registered only when dask is installed).
        return daskml_pp.QuantileTransformer(**kwargs).fit_transform


def quantile_norm(
adata: AnnData,
vars: str | Sequence[str] | None = None,
Expand Down Expand Up @@ -273,10 +347,8 @@ def quantile_norm(
>>> adata = ep.dt.mimic_2(encoded=True)
>>> adata_norm = ep.pp.quantile_norm(adata, copy=True)
"""
if is_dask_array(adata.X):
scale_func = daskml_pp.QuantileTransformer(**kwargs).fit_transform
else:
scale_func = sklearn_pp.QuantileTransformer(**kwargs).fit_transform

scale_func = _quantile_norm_function(adata.X, **kwargs)

return _scale_func_group(
adata=adata,
Expand All @@ -288,6 +360,16 @@ def quantile_norm(
)


@singledispatch
def _power_norm_function(arr, **kwargs):
    """Return a power-transform callable appropriate for the array type of ``arr``.

    Raises:
        NotImplementedError: For array types with no registered implementation.
    """
    _raise_array_type_not_implemented(_power_norm_function, type(arr))


@_power_norm_function.register
def _(arr: np.ndarray, **kwargs):
    # Only NumPy is supported: dask-ml provides no PowerTransformer.
    return sklearn_pp.PowerTransformer(**kwargs).fit_transform


def power_norm(
adata: AnnData,
vars: str | Sequence[str] | None = None,
Expand Down Expand Up @@ -317,10 +399,8 @@ def power_norm(
>>> adata = ep.dt.mimic_2(encoded=True)
>>> adata_norm = ep.pp.power_norm(adata, copy=True)
"""
if is_dask_array(adata.X):
raise NotImplementedError("dask-ml has no PowerTransformer, this is only available in scikit-learn")
else:
scale_func = sklearn_pp.PowerTransformer(**kwargs).fit_transform

scale_func = _power_norm_function(adata.X, **kwargs)

return _scale_func_group(
adata=adata,
Expand Down
Loading

0 comments on commit de4572d

Please sign in to comment.