diff --git a/ehrapy/_compat.py b/ehrapy/_compat.py
index 5497eead..37c38c2d 100644
--- a/ehrapy/_compat.py
+++ b/ehrapy/_compat.py
@@ -1,7 +1,6 @@
 # Since we might check whether an object is an instance of dask.array.Array
 # without requiring dask installed in the environment.
-# This would become obsolete should dask become a requirement for ehrapy
-
+from collections.abc import Callable
 
 try:
     import dask.array as da
@@ -11,6 +10,12 @@
     DASK_AVAILABLE = False
 
 
+def _raise_array_type_not_implemented(func: Callable, type_: type) -> NotImplementedError:
+    raise NotImplementedError(
+        f"{func.__name__} does not support array type {type_}. Must be of type {func.registry.keys()}."  # type: ignore
+    )
+
+
 def is_dask_array(array):
     if DASK_AVAILABLE:
         return isinstance(array, da.Array)
diff --git a/ehrapy/preprocessing/_normalization.py b/ehrapy/preprocessing/_normalization.py
index af39879d..67c41d02 100644
--- a/ehrapy/preprocessing/_normalization.py
+++ b/ehrapy/preprocessing/_normalization.py
@@ -6,7 +6,7 @@
 import numpy as np
 import sklearn.preprocessing as sklearn_pp
 
-from ehrapy._compat import is_dask_array
+from ehrapy._compat import _raise_array_type_not_implemented
 
 try:
     import dask.array as da
@@ -77,7 +77,7 @@ def _scale_func_group(
 
 @singledispatch
 def _scale_norm_function(arr):
-    raise NotImplementedError(f"scale_norm does not support data to be of type {type(arr)}")
+    _raise_array_type_not_implemented(_scale_norm_function, type(arr))
 
 
 @_scale_norm_function.register
@@ -135,7 +135,7 @@ def scale_norm(
 
 @singledispatch
 def _minmax_norm_function(arr):
-    raise NotImplementedError(f"minmax_norm does not support data to be of type {type(arr)}")
+    _raise_array_type_not_implemented(_minmax_norm_function, type(arr))
 
 
 @_minmax_norm_function.register
@@ -194,7 +194,7 @@ def minmax_norm(
 
 @singledispatch
 def _maxabs_norm_function(arr):
-    raise NotImplementedError(f"maxabs_norm does not support data to be of type {type(arr)}")
+    _raise_array_type_not_implemented(_maxabs_norm_function, type(arr))
 
 
 @_maxabs_norm_function.register
@@ -243,7 +243,7 @@ def maxabs_norm(
 
 @singledispatch
 def _robust_scale_norm_function(arr, **kwargs):
-    raise NotImplementedError(f"robust_scale_norm does not support data to be of type {type(arr)}")
+    _raise_array_type_not_implemented(_robust_scale_norm_function, type(arr))
 
 
 @_robust_scale_norm_function.register
@@ -303,7 +303,7 @@ def robust_scale_norm(
 
 @singledispatch
 def _quantile_norm_function(arr):
-    raise NotImplementedError(f"robust_scale_norm does not support data to be of type {type(arr)}")
+    _raise_array_type_not_implemented(_quantile_norm_function, type(arr))
 
 
 @_quantile_norm_function.register
@@ -362,7 +362,7 @@ def quantile_norm(
 
 @singledispatch
 def _power_norm_function(arr, **kwargs):
-    raise NotImplementedError(f"power_norm does not support data to be of type {type(arr)}")
+    _raise_array_type_not_implemented(_power_norm_function, type(arr))
 
 
 @_power_norm_function.register
diff --git a/pyproject.toml b/pyproject.toml
index 654d9092..a54b680c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -72,7 +72,7 @@ medcat = [
     "medcat",
 ]
 dask = [
-    "dask",
+    "anndata[dask]",
     "dask-ml",
 ]
 dev = [
diff --git a/tests/preprocessing/test_normalization.py b/tests/preprocessing/test_normalization.py
index 18eeef49..1507c5c5 100644
--- a/tests/preprocessing/test_normalization.py
+++ b/tests/preprocessing/test_normalization.py
@@ -89,11 +89,6 @@ def test_vars_checks(adata_to_norm):
     ep.pp.scale_norm(adata_to_norm, vars=["String1"])
 
 
-# TODO: list the supported array types centrally?
-norm_scale_supported_types = [np.asarray, da.asarray]
-norm_scale_unsupported_types = [sparse.csc_matrix]
-
-
 # TODO: check this for each function, with just default settings?
 @pytest.mark.parametrize(
     "array_type,expected_error",