Skip to content

Commit

Permalink
initial suggestions of array type checks on example of scale_norm
Browse files Browse the repository at this point in the history
  • Loading branch information
eroell committed Dec 1, 2024
1 parent 03cd180 commit 877034d
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
22 changes: 18 additions & 4 deletions ehrapy/preprocessing/_normalization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

from functools import singledispatch
from typing import TYPE_CHECKING

import dask.array as da
import numpy as np
import sklearn.preprocessing as sklearn_pp

Expand Down Expand Up @@ -69,6 +71,21 @@ def _scale_func_group(
return None


@singledispatch
def _scale_norm_function(arr):
raise NotImplementedError(f"scale_norm does not support data to be of type {type(arr)}")


@_scale_norm_function.register
def _(arr: np.ndarray, **kwargs):
return sklearn_pp.StandardScaler(**kwargs).fit_transform


@_scale_norm_function.register
def _(arr: da.Array, **kwargs):
return sklearn_pp.StandardScaler(**kwargs).fit_transform


def scale_norm(
adata: AnnData,
vars: str | Sequence[str] | None = None,
Expand Down Expand Up @@ -98,10 +115,7 @@ def scale_norm(
>>> adata_norm = ep.pp.scale_norm(adata, copy=True)
"""

if is_dask_array(adata.X):
scale_func = daskml_pp.StandardScaler(**kwargs).fit_transform
else:
scale_func = sklearn_pp.StandardScaler(**kwargs).fit_transform
scale_func = _scale_norm_function(adata.X, **kwargs)

return _scale_func_group(
adata=adata,
Expand Down
19 changes: 18 additions & 1 deletion tests/preprocessing/test_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import OrderedDict
from pathlib import Path

import dask.array as da
import numpy as np
import pandas as pd
import pytest
Expand All @@ -13,6 +14,7 @@
from tests.conftest import ARRAY_TYPES, TEST_DATA_PATH

CURRENT_DIR = Path(__file__).parent
from scipy import sparse


@pytest.fixture
Expand Down Expand Up @@ -74,7 +76,14 @@ def test_vars_checks(adata_to_norm):
ep.pp.scale_norm(adata_to_norm, vars=["String1"])


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
# TODO: where to list the supported types?
norm_scale_supported_types = [np.asarray, da.asarray]
norm_scale_unsupported_types = [sparse.csc_matrix]


# TODO: find consens for "minimal" test of ehrapy functions when make this casting test. vanilla settings, all defaults?
# even test for value matchings?
@pytest.mark.parametrize("array_type", norm_scale_supported_types)
def test_norm_scale(array_type, adata_to_norm):
"""Test for the scaling normalization method."""
warnings.filterwarnings("ignore")
Expand All @@ -94,6 +103,14 @@ def test_norm_scale(array_type, adata_to_norm):
assert np.allclose(adata_norm.X[:, 5], adata_to_norm_casted.X[:, 5], equal_nan=True)


@pytest.mark.parametrize("array_type", norm_scale_unsupported_types)
def test_norm_scale_notimplemented(array_type, adata_to_norm):
adata_to_norm_casted = adata_to_norm.copy()
adata_to_norm_casted.X = array_type(adata_to_norm_casted.X)
with pytest.raises(NotImplementedError):
ep.pp.scale_norm(adata_to_norm_casted)


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
def test_norm_scale_kwargs(array_type, adata_to_norm):
adata_to_norm_casted = adata_to_norm.copy()
Expand Down

0 comments on commit 877034d

Please sign in to comment.