
Commit

Refactor quality control metrics functions to streamline computation and improve readability

aGuyLearning committed Feb 5, 2025
1 parent 16a5817 commit a603dbb
Showing 3 changed files with 63 additions and 49 deletions.
57 changes: 15 additions & 42 deletions ehrapy/preprocessing/_quality_control.py
@@ -59,17 +59,12 @@ def qc_metrics(
>>> obs_qc["missing_values_pct"].plot(kind="hist", bins=20)
"""

# obs_metrics = _obs_qc_metrics(adata, layer, qc_vars)
# var_metrics = _var_qc_metrics(adata, layer)
obs_metrics = pd.DataFrame(index=adata.obs_names)
var_metrics = pd.DataFrame(index=adata.var_names)

mtx = adata.X if layer is None else adata.layers[layer]
_compute_var_metrics(mtx, var_metrics, adata)
_compute_obs_metrics(mtx, obs_metrics, var_metrics, adata, qc_vars, log1p=True)
var_metrics = _compute_var_metrics(mtx, adata)
obs_metrics = _compute_obs_metrics(mtx, adata, qc_vars=qc_vars, log1p=True)

adata.obs[obs_metrics.columns] = obs_metrics
adata.var[var_metrics.columns] = var_metrics
adata.obs[obs_metrics.columns] = obs_metrics

return obs_metrics, var_metrics
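
With the refactor, qc_metrics delegates to helpers that return fresh DataFrames rather than mutating ones passed in, and then writes those results back into adata.obs and adata.var. A minimal usage sketch of the public entry point, assuming the ep.pp.qc_metrics wrapper and a small hand-built AnnData (the data below is illustrative):

    import numpy as np
    import pandas as pd
    from anndata import AnnData

    import ehrapy as ep

    # Two observations, three features, with a few missing values.
    X = np.array([[1.0, np.nan, 3.0], [np.nan, np.nan, 6.0]])
    adata = AnnData(X, obs=pd.DataFrame(index=["p1", "p2"]), var=pd.DataFrame(index=["a", "b", "c"]))

    obs_qc, var_qc = ep.pp.qc_metrics(adata)
    print(obs_qc["missing_values_pct"])  # per-observation share of missing features
    print(var_qc["missing_values_pct"])  # per-feature share of missing observations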

@@ -98,9 +93,8 @@ def _missing_values(
@singledispatch
def _compute_obs_metrics(
arr,
obs_metrics: pd.DataFrame,
var_metrics: pd.DataFrame,
adata: AnnData,
*,
qc_vars: Collection[str],
log1p: bool,
):
@@ -120,19 +114,15 @@ def _compute_obs_metrics(
A Pandas DataFrame with the calculated metrics.
"""
_raise_array_type_not_implemented(_compute_obs_metrics, type(arr))
# TODO: add tests for this function


@_compute_obs_metrics.register(np.ndarray)
def _(
arr: np.array,
obs_metrics: pd.DataFrame,
var_metrics: pd.DataFrame,
adata: AnnData,
qc_vars: Collection[str],
log1p: bool,
):
# has no return, modifies obs_metrics and var_metrics in place
def _(arr: np.array, adata: AnnData, *, qc_vars: Collection[str] = (), log1p: bool = True):
obs_metrics = pd.DataFrame(index=adata.obs_names)
var_metrics = pd.DataFrame(index=adata.var_names)
mtx = copy.deepcopy(arr.astype(object))

if "encoding_mode" in adata.var:
for original_values_categorical in _get_encoded_features(adata):
index = np.where(var_metrics.index.str.contains(original_values_categorical))[0]
@@ -160,25 +150,12 @@ def _(
obs_metrics[f"total_features_{qc_var}"] / obs_metrics["total_features"] * 100
)


@_compute_obs_metrics.register(da.Array)
def _(
arr: da.Array,
obs_metrics: pd.DataFrame,
var_metrics: pd.DataFrame,
adata: AnnData,
qc_vars: Collection[str],
log1p: bool,
):
return _compute_obs_metrics(
arr.compute(), obs_metrics, var_metrics, adata, qc_vars, log1p
) # TODO: is it okay to compute here?
return obs_metrics


@singledispatch
def _compute_var_metrics(
arr,
var_metrics: pd.DataFrame,
adata: AnnData,
):
"""Compute variable metrics for quality control.
@@ -189,16 +166,17 @@ def _compute_var_metrics(
adata: Annotated data matrix.
"""
_raise_array_type_not_implemented(_compute_var_metrics, type(arr))
# TODO: add tests for this function


@_compute_var_metrics.register(np.ndarray)
def _(
arr: np.array,
var_metrics: pd.DataFrame,
adata: AnnData,
):
categorical_indices = np.ndarray([0], dtype=int)
mtx = copy.deepcopy(arr.astype(object))
var_metrics = pd.DataFrame(index=adata.var_names)

if "encoding_mode" in adata.var.keys():
for original_values_categorical in _get_encoded_features(adata):
@@ -215,8 +193,10 @@ def _(
mtx[:, index].shape[1],
)
categorical_indices = np.concatenate([categorical_indices, index])

non_categorical_indices = np.ones(mtx.shape[1], dtype=bool)
non_categorical_indices[categorical_indices] = False

var_metrics["missing_values_abs"] = np.apply_along_axis(_missing_values, 0, mtx, mode="abs")
var_metrics["missing_values_pct"] = np.apply_along_axis(_missing_values, 0, mtx, mode="pct", df_type="var")

@@ -259,14 +239,7 @@ def _(
# We assume that the data just hasn't been encoded yet
pass


@_compute_var_metrics.register(da.Array)
def _(
arr: da.Array,
var_metrics: pd.DataFrame,
adata: AnnData,
):
return _compute_var_metrics(arr.compute(), var_metrics, adata) # TODO: is it okay to compute here?
return var_metrics


def qc_lab_measurements(
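
Both helpers keep the functools.singledispatch layout: the undecorated base function rejects unsupported array types via _raise_array_type_not_implemented, and concrete implementations are registered per array class and now build and return their own DataFrame. A generic sketch of that dispatch pattern (the function and column names here are illustrative, not the module's internals):

    from functools import singledispatch

    import numpy as np
    import pandas as pd


    @singledispatch
    def _compute_metrics_sketch(arr, adata):
        # Fallback for unregistered array types, analogous to _raise_array_type_not_implemented.
        raise NotImplementedError(f"No metrics implementation registered for {type(arr)}")


    @_compute_metrics_sketch.register(np.ndarray)
    def _(arr: np.ndarray, adata) -> pd.DataFrame:
        # Dense NumPy arrays are handled directly; a fresh DataFrame is built
        # and returned instead of being filled in place.
        metrics = pd.DataFrame(index=adata.var_names)
        metrics["missing_values_abs"] = np.isnan(arr.astype(float)).sum(axis=0)
        return metrics

Registering further implementations (e.g. for dask or sparse arrays) would follow the same shape, which is why the currently disabled cases in the tests below expect NotImplementedError.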
3 changes: 1 addition & 2 deletions tests/preprocessing/test_imputation.py
@@ -49,7 +49,6 @@ def _base_check_imputation(
AssertionError: If any of the checks fail.
"""
# if .x of the AnnData is a dask array, convert it to a numpy array
# TODO: look into a better way to handle this
if isinstance(adata_before_imputation.X, da.Array):
adata_before_imputation.X = adata_before_imputation.X.compute()
if isinstance(adata_after_imputation.X, da.Array):
@@ -319,7 +318,7 @@ def test_explicit_impute_types(impute_num_adata, array_type, expected_error):
explicit_impute(impute_num_adata, replacement=1011, copy=True)


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
@pytest.mark.parametrize("array_type", [np.array]) # TODO: discuss, should we add a new fixture with supported types?
def test_explicit_impute_all(array_type, impute_num_adata):
impute_num_adata.X = array_type(impute_num_adata.X)
warnings.filterwarnings("ignore", category=FutureWarning)
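
The TODO in the parametrize decorator above asks whether a dedicated fixture of supported array types would be cleaner than an inline list. One possible shape for that, using hypothetical names (SUPPORTED_ARRAY_TYPES and supported_array_type are not part of this diff; impute_num_adata is the existing fixture):

    import numpy as np
    import pytest

    # Hypothetical: only the array constructors the imputation path currently supports.
    SUPPORTED_ARRAY_TYPES = [np.asarray]


    @pytest.fixture(params=SUPPORTED_ARRAY_TYPES)
    def supported_array_type(request):
        return request.param


    def test_explicit_impute_all_sketch(supported_array_type, impute_num_adata):
        impute_num_adata.X = supported_array_type(impute_num_adata.X)

This keeps the list of supported types in one place instead of repeating it in every parametrize call.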
52 changes: 47 additions & 5 deletions tests/preprocessing/test_quality_control.py
@@ -1,14 +1,16 @@
from pathlib import Path

import dask.array as da
import numpy as np
import pandas as pd
import pytest
from anndata import AnnData
from scipy import sparse

import ehrapy as ep
from ehrapy.io._read import read_csv
from ehrapy.preprocessing._encoding import encode
from ehrapy.preprocessing._quality_control import _obs_qc_metrics, _var_qc_metrics, mcar_test
from ehrapy.preprocessing._quality_control import _compute_obs_metrics, _compute_var_metrics, mcar_test
from tests.conftest import TEST_DATA_PATH

CURRENT_DIR = Path(__file__).parent
@@ -64,14 +66,16 @@ def lab_measurements_layer_adata(obs_data, var_data):


def test_obs_qc_metrics(missing_values_adata):
obs_metrics = _obs_qc_metrics(missing_values_adata)
mtx = missing_values_adata.X
obs_metrics = _compute_obs_metrics(mtx, missing_values_adata)

assert np.array_equal(obs_metrics["missing_values_abs"].values, np.array([1, 2]))
assert np.allclose(obs_metrics["missing_values_pct"].values, np.array([33.3333, 66.6667]))


def test_var_qc_metrics(missing_values_adata):
var_metrics = _var_qc_metrics(missing_values_adata)
mtx = missing_values_adata.X
var_metrics = _compute_var_metrics(mtx, missing_values_adata)

assert np.array_equal(var_metrics["missing_values_abs"].values, np.array([1, 2, 0]))
assert np.allclose(var_metrics["missing_values_pct"].values, np.array([50.0, 100.0, 0.0]))
@@ -82,19 +86,57 @@ def test_var_qc_metrics(missing_values_adata):
assert (~var_metrics["iqr_outliers"]).all()


@pytest.mark.parametrize(
"array_type, expected_error",
[
(np.array, None),
# (da.array, NotImplementedError),
# (sparse.csr_matrix, NotImplementedError),
# TODO: currently disabled, due to sparse matrix not supporting data type conversion
],
)
def test_obs_array_types(array_type, expected_error):
adata = read_csv(dataset_path=f"{_TEST_PATH_ENCODE}/dataset1.csv")
adata.X = array_type(adata.X)
mtx = adata.X
if expected_error:
with pytest.raises(expected_error):
_compute_obs_metrics(mtx, adata)


def test_obs_nan_qc_metrics():
adata = read_csv(dataset_path=f"{_TEST_PATH_ENCODE}/dataset1.csv")
adata.X[0][4] = np.nan
adata2 = encode(adata, encodings={"one-hot": ["clinic_day"]})
obs_metrics = _obs_qc_metrics(adata2)
mtx = adata2.X
obs_metrics = _compute_obs_metrics(mtx, adata2)
assert obs_metrics.iloc[0].iloc[0] == 1


@pytest.mark.parametrize(
"array_type, expected_error",
[
(np.array, None),
# (da.array, NotImplementedError),
# (sparse.csr_matrix, NotImplementedError),
# TODO: currently disabled, due to sparse matrix not supporting data type conversion
],
)
def test_var_array_types(array_type, expected_error):
adata = read_csv(dataset_path=f"{_TEST_PATH_ENCODE}/dataset1.csv")
adata.X = array_type(adata.X)
mtx = adata.X
if expected_error:
with pytest.raises(expected_error):
_compute_var_metrics(mtx, adata)


def test_var_nan_qc_metrics():
adata = read_csv(dataset_path=f"{_TEST_PATH_ENCODE}/dataset1.csv")
adata.X[0][4] = np.nan
adata2 = encode(adata, encodings={"one-hot": ["clinic_day"]})
var_metrics = _var_qc_metrics(adata2)
mtx = adata2.X
var_metrics = _compute_var_metrics(mtx, adata2)
assert var_metrics.iloc[0].iloc[0] == 1
assert var_metrics.iloc[1].iloc[0] == 1
assert var_metrics.iloc[2].iloc[0] == 1
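
The test module now imports dask.array, but the dask and sparse cases stay commented out until the corresponding implementations exist. Until then, one defensive option, mirroring the conversion in _base_check_imputation above, is to materialize a dask-backed matrix before calling the NumPy implementations. A small sketch, where _as_dense is a hypothetical helper:

    import dask.array as da
    import numpy as np


    def _as_dense(mtx):
        # Pull a dask-backed matrix into memory so the np.ndarray
        # implementations of the metric functions can handle it.
        return mtx.compute() if isinstance(mtx, da.Array) else np.asarray(mtx)

For example, var_metrics = _compute_var_metrics(_as_dense(adata.X), adata) would then work for both in-memory and dask-backed adata.X.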
