diff --git a/ehrapy/preprocessing/_quality_control.py b/ehrapy/preprocessing/_quality_control.py index 103cb4cc..582f6a3c 100644 --- a/ehrapy/preprocessing/_quality_control.py +++ b/ehrapy/preprocessing/_quality_control.py @@ -59,17 +59,12 @@ def qc_metrics( >>> obs_qc["missing_values_pct"].plot(kind="hist", bins=20) """ - # obs_metrics = _obs_qc_metrics(adata, layer, qc_vars) - # var_metrics = _var_qc_metrics(adata, layer) - obs_metrics = pd.DataFrame(index=adata.obs_names) - var_metrics = pd.DataFrame(index=adata.var_names) - mtx = adata.X if layer is None else adata.layers[layer] - _compute_var_metrics(mtx, var_metrics, adata) - _compute_obs_metrics(mtx, obs_metrics, var_metrics, adata, qc_vars, log1p=True) + var_metrics = _compute_var_metrics(mtx, adata) + obs_metrics = _compute_obs_metrics(mtx, adata, qc_vars=qc_vars, log1p=True) - adata.obs[obs_metrics.columns] = obs_metrics adata.var[var_metrics.columns] = var_metrics + adata.obs[obs_metrics.columns] = obs_metrics return obs_metrics, var_metrics @@ -98,9 +93,8 @@ def _missing_values( @singledispatch def _compute_obs_metrics( arr, - obs_metrics: pd.DataFrame, - var_metrics: pd.DataFrame, adata: AnnData, + *, qc_vars: Collection[str], log1p: bool, ): @@ -120,19 +114,15 @@ def _compute_obs_metrics( A Pandas DataFrame with the calculated metrics. 
""" _raise_array_type_not_implemented(_compute_obs_metrics, type(arr)) + # TODO: add tests for this function @_compute_obs_metrics.register(np.ndarray) -def _( - arr: np.array, - obs_metrics: pd.DataFrame, - var_metrics: pd.DataFrame, - adata: AnnData, - qc_vars: Collection[str], - log1p: bool, -): - # has no return, modifies obs_metrics and var_metrics in place +def _(arr: np.array, adata: AnnData, *, qc_vars: Collection[str] = (), log1p: bool = True): + obs_metrics = pd.DataFrame(index=adata.obs_names) + var_metrics = pd.DataFrame(index=adata.var_names) mtx = copy.deepcopy(arr.astype(object)) + if "encoding_mode" in adata.var: for original_values_categorical in _get_encoded_features(adata): index = np.where(var_metrics.index.str.contains(original_values_categorical))[0] @@ -160,25 +150,12 @@ def _( obs_metrics[f"total_features_{qc_var}"] / obs_metrics["total_features"] * 100 ) - -@_compute_obs_metrics.register(da.Array) -def _( - arr: da.Array, - obs_metrics: pd.DataFrame, - var_metrics: pd.DataFrame, - adata: AnnData, - qc_vars: Collection[str], - log1p: bool, -): - return _compute_obs_metrics( - arr.compute(), obs_metrics, var_metrics, adata, qc_vars, log1p - ) # TODO: is it okay to compute here? + return obs_metrics @singledispatch def _compute_var_metrics( arr, - var_metrics: pd.DataFrame, adata: AnnData, ): """Compute variable metrics for quality control. @@ -189,16 +166,17 @@ def _compute_var_metrics( adata: Annotated data matrix. 
""" _raise_array_type_not_implemented(_compute_var_metrics, type(arr)) + # TODO: add tests for this function @_compute_var_metrics.register(np.ndarray) def _( arr: np.array, - var_metrics: pd.DataFrame, adata: AnnData, ): categorical_indices = np.ndarray([0], dtype=int) mtx = copy.deepcopy(arr.astype(object)) + var_metrics = pd.DataFrame(index=adata.var_names) if "encoding_mode" in adata.var.keys(): for original_values_categorical in _get_encoded_features(adata): @@ -215,8 +193,10 @@ def _( mtx[:, index].shape[1], ) categorical_indices = np.concatenate([categorical_indices, index]) + non_categorical_indices = np.ones(mtx.shape[1], dtype=bool) non_categorical_indices[categorical_indices] = False + var_metrics["missing_values_abs"] = np.apply_along_axis(_missing_values, 0, mtx, mode="abs") var_metrics["missing_values_pct"] = np.apply_along_axis(_missing_values, 0, mtx, mode="pct", df_type="var") @@ -259,14 +239,7 @@ def _( # We assume that the data just hasn't been encoded yet pass - -@_compute_var_metrics.register(da.Array) -def _( - arr: da.Array, - var_metrics: pd.DataFrame, - adata: AnnData, -): - return _compute_var_metrics(arr.compute(), var_metrics, adata) # TODO: is it okay to compute here? + return var_metrics def qc_lab_measurements( diff --git a/tests/preprocessing/test_imputation.py b/tests/preprocessing/test_imputation.py index 0729482f..d9382084 100644 --- a/tests/preprocessing/test_imputation.py +++ b/tests/preprocessing/test_imputation.py @@ -49,7 +49,6 @@ def _base_check_imputation( AssertionError: If any of the checks fail. 
""" # if .x of the AnnData is a dask array, convert it to a numpy array - # TODO: look into a better way to handle this if isinstance(adata_before_imputation.X, da.Array): adata_before_imputation.X = adata_before_imputation.X.compute() if isinstance(adata_after_imputation.X, da.Array): @@ -319,7 +318,7 @@ def test_explicit_impute_types(impute_num_adata, array_type, expected_error): explicit_impute(impute_num_adata, replacement=1011, copy=True) -@pytest.mark.parametrize("array_type", ARRAY_TYPES) +@pytest.mark.parametrize("array_type", [np.array]) # TODO: discuss, should we add a new fixture with supported types? def test_explicit_impute_all(array_type, impute_num_adata): impute_num_adata.X = array_type(impute_num_adata.X) warnings.filterwarnings("ignore", category=FutureWarning) diff --git a/tests/preprocessing/test_quality_control.py b/tests/preprocessing/test_quality_control.py index dee27b3c..6a0c2fbd 100644 --- a/tests/preprocessing/test_quality_control.py +++ b/tests/preprocessing/test_quality_control.py @@ -1,14 +1,16 @@ from pathlib import Path +import dask.array as da import numpy as np import pandas as pd import pytest from anndata import AnnData +from scipy import sparse import ehrapy as ep from ehrapy.io._read import read_csv from ehrapy.preprocessing._encoding import encode -from ehrapy.preprocessing._quality_control import _obs_qc_metrics, _var_qc_metrics, mcar_test +from ehrapy.preprocessing._quality_control import _compute_obs_metrics, _compute_var_metrics, mcar_test from tests.conftest import TEST_DATA_PATH CURRENT_DIR = Path(__file__).parent @@ -64,14 +66,16 @@ def lab_measurements_layer_adata(obs_data, var_data): def test_obs_qc_metrics(missing_values_adata): - obs_metrics = _obs_qc_metrics(missing_values_adata) + mtx = missing_values_adata.X + obs_metrics = _compute_obs_metrics(mtx, missing_values_adata) assert np.array_equal(obs_metrics["missing_values_abs"].values, np.array([1, 2])) assert np.allclose(obs_metrics["missing_values_pct"].values, 
np.array([33.3333, 66.6667])) def test_var_qc_metrics(missing_values_adata): - var_metrics = _var_qc_metrics(missing_values_adata) + mtx = missing_values_adata.X + var_metrics = _compute_var_metrics(mtx, missing_values_adata) assert np.array_equal(var_metrics["missing_values_abs"].values, np.array([1, 2, 0])) assert np.allclose(var_metrics["missing_values_pct"].values, np.array([50.0, 100.0, 0.0])) @@ -82,19 +86,57 @@ def test_var_qc_metrics(missing_values_adata): assert (~var_metrics["iqr_outliers"]).all() +@pytest.mark.parametrize( + "array_type, expected_error", + [ + (np.array, None), + # (da.array, NotImplementedError), + # (sparse.csr_matrix, NotImplementedError), + # TODO: currently disabled, due to sparse matrix not supporting data type conversion + ], +) +def test_obs_array_types(array_type, expected_error): + adata = read_csv(dataset_path=f"{_TEST_PATH_ENCODE}/dataset1.csv") + adata.X = array_type(adata.X) + mtx = adata.X + if expected_error: + with pytest.raises(expected_error): + _compute_obs_metrics(mtx, adata) + + def test_obs_nan_qc_metrics(): adata = read_csv(dataset_path=f"{_TEST_PATH_ENCODE}/dataset1.csv") adata.X[0][4] = np.nan adata2 = encode(adata, encodings={"one-hot": ["clinic_day"]}) - obs_metrics = _obs_qc_metrics(adata2) + mtx = adata2.X + obs_metrics = _compute_obs_metrics(mtx, adata2) assert obs_metrics.iloc[0].iloc[0] == 1 +@pytest.mark.parametrize( + "array_type, expected_error", + [ + (np.array, None), + # (da.array, NotImplementedError), + # (sparse.csr_matrix, NotImplementedError), + # TODO: currently disabled, due to sparse matrix not supporting data type conversion + ], +) +def test_var_array_types(array_type, expected_error): + adata = read_csv(dataset_path=f"{_TEST_PATH_ENCODE}/dataset1.csv") + adata.X = array_type(adata.X) + mtx = adata.X + if expected_error: + with pytest.raises(expected_error): + _compute_var_metrics(mtx, adata) + + def test_var_nan_qc_metrics(): adata = 
read_csv(dataset_path=f"{_TEST_PATH_ENCODE}/dataset1.csv") adata.X[0][4] = np.nan adata2 = encode(adata, encodings={"one-hot": ["clinic_day"]}) - var_metrics = _var_qc_metrics(adata2) + mtx = adata2.X + var_metrics = _compute_var_metrics(mtx, adata2) assert var_metrics.iloc[0].iloc[0] == 1 assert var_metrics.iloc[1].iloc[0] == 1 assert var_metrics.iloc[2].iloc[0] == 1