Add input checks for imputers #625

Merged · 1 commit · Dec 17, 2023
233 changes: 131 additions & 102 deletions ehrapy/preprocessing/_imputation.py
@@ -239,30 +239,34 @@ def knn_impute(
print(
"[bold yellow]scikit-learn-intelex is not available. Install via [blue]pip install scikit-learn-intelex [yellow] for faster imputations."
)

with Progress(
"[progress.description]{task.description}",
SpinnerColumn(),
refresh_per_second=1500,
) as progress:
progress.add_task("[blue]Running KNN imputation", total=1)
# numerical only data needs no encoding since KNN Imputation can be applied directly
if np.issubdtype(adata.X.dtype, np.number):
_knn_impute(adata, var_names, n_neighbours)
else:
# ordinal encoding is used since non-numerical data can not be imputed using KNN Imputation
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using KNN imputation
_knn_impute(adata, var_names, n_neighbours)
# imputing on encoded columns might result in float numbers; those can not be decoded
# cast them to int to ensure they can be decoded
adata.X[::, column_indices] = np.rint(adata.X[::, column_indices]).astype(int)
# knn imputer transforms X dtype to numerical (encoded), but object is needed for decoding
adata.X = adata.X.astype("object")
# decode ordinal encoding to obtain imputed original data
adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])
try:
with Progress(
"[progress.description]{task.description}",
SpinnerColumn(),
refresh_per_second=1500,
) as progress:
progress.add_task("[blue]Running KNN imputation", total=1)
# numerical only data needs no encoding since KNN Imputation can be applied directly
if np.issubdtype(adata.X.dtype, np.number):
_knn_impute(adata, var_names, n_neighbours)
else:
# ordinal encoding is used since non-numerical data can not be imputed using KNN Imputation
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using KNN imputation
_knn_impute(adata, var_names, n_neighbours)
# imputing on encoded columns might result in float numbers; those can not be decoded
# cast them to int to ensure they can be decoded
adata.X[::, column_indices] = np.rint(adata.X[::, column_indices]).astype(int)
# knn imputer transforms X dtype to numerical (encoded), but object is needed for decoding
adata.X = adata.X.astype("object")
# decode ordinal encoding to obtain imputed original data
adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])
except ValueError as e:
if "Data matrix has wrong shape" in str(e):
print("[bold red]Check that your matrix does not contain any NaN values!")
raise

if _check_module_importable("sklearnex"): # pragma: no cover
unpatch_sklearn()
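
For reference, the encode → impute → decode round trip that this hunk now wraps in a try/except can be reproduced standalone with scikit-learn. This is a minimal sketch using KNNImputer and OrdinalEncoder directly, independent of ehrapy's internal helpers; the toy matrix and column indices are illustrative, and encoded_missing_value requires scikit-learn >= 1.1.

import numpy as np
from sklearn.impute import KNNImputer
from sklearn.preprocessing import OrdinalEncoder

# Mixed matrix: column 0 is numerical, column 1 is categorical (object dtype).
X = np.array(
    [[1.0, "a"],
     [np.nan, "b"],
     [3.0, "a"],
     [4.0, np.nan]],
    dtype=object,
)
cat_cols = [1]

# Encode categoricals to ordinal codes; NaNs stay NaN so the imputer can fill them.
enc = OrdinalEncoder(encoded_missing_value=np.nan)
X[:, cat_cols] = enc.fit_transform(X[:, cat_cols])

# Impute the now fully numeric matrix.
X = KNNImputer(n_neighbors=2).fit_transform(X.astype(float))

# Imputed codes may be fractional; round them so they map back to known categories.
X = X.astype(object)
X[:, cat_cols] = np.rint(X[:, cat_cols].astype(float)).astype(int)
X[:, cat_cols] = enc.inverse_transform(X[:, cat_cols])
print(X)
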
@@ -356,62 +360,69 @@ def miss_forest_impute(
from sklearn.ensemble import ExtraTreesRegressor, RandomForestClassifier
from sklearn.impute import IterativeImputer

with Progress(
"[progress.description]{task.description}",
SpinnerColumn(),
refresh_per_second=1500,
) as progress:
progress.add_task("[blue]Running MissForest imputation", total=1)

if settings.n_jobs == 1: # pragma: no cover
print("[bold yellow]The number of jobs is only 1. To decrease the runtime set [blue]ep.settings.n_jobs=-1.")

imp_num = IterativeImputer(
estimator=ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs),
initial_strategy=num_initial_strategy,
max_iter=max_iter,
random_state=random_state,
)
# initial strategy here will not be parametrized since only most_frequent will be applied to non numerical data
imp_cat = IterativeImputer(
estimator=RandomForestClassifier(n_estimators=n_estimators, n_jobs=settings.n_jobs),
initial_strategy="most_frequent",
max_iter=max_iter,
random_state=random_state,
)

if isinstance(var_names, list):
var_indices = _get_column_indices(adata, var_names) # type: ignore
adata.X[::, var_indices] = imp_num.fit_transform(adata.X[::, var_indices])
elif isinstance(var_names, dict) or var_names is None:
if var_names:
try:
non_num_vars = var_names["non_numerical"]
num_vars = var_names["numerical"]
except KeyError: # pragma: no cover
raise ValueError(
"One or both of your keys provided for var_names are unknown. Only "
"numerical and non_numerical are available!"
) from None
non_num_indices = _get_column_indices(adata, non_num_vars)
num_indices = _get_column_indices(adata, num_vars)

# infer non numerical and numerical indices automatically
else:
non_num_indices_set = _get_non_numerical_column_indices(adata.X)
num_indices = [idx for idx in range(adata.X.shape[1]) if idx not in non_num_indices_set]
non_num_indices = list(non_num_indices_set)
try:
with Progress(
"[progress.description]{task.description}",
SpinnerColumn(),
refresh_per_second=1500,
) as progress:
progress.add_task("[blue]Running MissForest imputation", total=1)

if settings.n_jobs == 1: # pragma: no cover
print(
"[bold yellow]The number of jobs is only 1. To decrease the runtime set [blue]ep.settings.n_jobs=-1."
)

imp_num = IterativeImputer(
estimator=ExtraTreesRegressor(n_estimators=n_estimators, n_jobs=settings.n_jobs),
initial_strategy=num_initial_strategy,
max_iter=max_iter,
random_state=random_state,
)
# initial strategy here will not be parametrized since only most_frequent will be applied to non numerical data
imp_cat = IterativeImputer(
estimator=RandomForestClassifier(n_estimators=n_estimators, n_jobs=settings.n_jobs),
initial_strategy="most_frequent",
max_iter=max_iter,
random_state=random_state,
)

# encode all non numerical columns
if non_num_indices:
enc = OrdinalEncoder()
adata.X[::, non_num_indices] = enc.fit_transform(adata.X[::, non_num_indices])
# this step is the most expensive one and might extremely slow down the impute process
if num_indices:
adata.X[::, num_indices] = imp_num.fit_transform(adata.X[::, num_indices])
if non_num_indices:
adata.X[::, non_num_indices] = imp_cat.fit_transform(adata.X[::, non_num_indices])
adata.X[::, non_num_indices] = enc.inverse_transform(adata.X[::, non_num_indices])
if isinstance(var_names, list):
var_indices = _get_column_indices(adata, var_names) # type: ignore
adata.X[::, var_indices] = imp_num.fit_transform(adata.X[::, var_indices])
elif isinstance(var_names, dict) or var_names is None:
if var_names:
try:
non_num_vars = var_names["non_numerical"]
num_vars = var_names["numerical"]
except KeyError: # pragma: no cover
raise ValueError(
"One or both of your keys provided for var_names are unknown. Only "
"numerical and non_numerical are available!"
) from None
non_num_indices = _get_column_indices(adata, non_num_vars)
num_indices = _get_column_indices(adata, num_vars)

# infer non numerical and numerical indices automatically
else:
non_num_indices_set = _get_non_numerical_column_indices(adata.X)
num_indices = [idx for idx in range(adata.X.shape[1]) if idx not in non_num_indices_set]
non_num_indices = list(non_num_indices_set)

# encode all non numerical columns
if non_num_indices:
enc = OrdinalEncoder()
adata.X[::, non_num_indices] = enc.fit_transform(adata.X[::, non_num_indices])
# this step is the most expensive one and might extremely slow down the impute process
if num_indices:
adata.X[::, num_indices] = imp_num.fit_transform(adata.X[::, num_indices])
if non_num_indices:
adata.X[::, non_num_indices] = imp_cat.fit_transform(adata.X[::, non_num_indices])
adata.X[::, non_num_indices] = enc.inverse_transform(adata.X[::, non_num_indices])
except ValueError as e:
if "Data matrix has wrong shape" in str(e):
print("[bold red]Check that your matrix does not contain any NaN values!")
raise

if _check_module_importable("sklearnex"): # pragma: no cover
unpatch_sklearn()
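
The hunk above splits the matrix into numerical and ordinal-encoded non-numerical blocks and runs a separate IterativeImputer over each. A compact standalone sketch of that split, using plain scikit-learn with toy data and illustrative column indices (not ehrapy's API; encoded_missing_value again needs scikit-learn >= 1.1):

import numpy as np
from sklearn.experimental import enable_iterative_imputer  # noqa: F401
from sklearn.ensemble import ExtraTreesRegressor, RandomForestClassifier
from sklearn.impute import IterativeImputer
from sklearn.preprocessing import OrdinalEncoder

X = np.array(
    [[1.0, 10.0, "low", "x"],
     [2.0, np.nan, "high", "y"],
     [np.nan, 30.0, "low", np.nan],
     [4.0, 40.0, np.nan, "x"]],
    dtype=object,
)
num_idx, cat_idx = [0, 1], [2, 3]

# Regressor-based imputer for numerical columns, classifier-based for categoricals.
imp_num = IterativeImputer(estimator=ExtraTreesRegressor(n_estimators=10), max_iter=5, random_state=0)
imp_cat = IterativeImputer(
    estimator=RandomForestClassifier(n_estimators=10), initial_strategy="most_frequent", max_iter=5, random_state=0
)

# Encode categoricals, impute each block separately, then decode.
enc = OrdinalEncoder(encoded_missing_value=np.nan)
X[:, cat_idx] = enc.fit_transform(X[:, cat_idx])
X[:, num_idx] = imp_num.fit_transform(X[:, num_idx].astype(float))
X[:, cat_idx] = imp_cat.fit_transform(X[:, cat_idx].astype(float))
X[:, cat_idx] = enc.inverse_transform(X[:, cat_idx])
print(X)
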
@@ -1025,29 +1036,47 @@ def mice_forest_impute(
adata = adata.copy()

_warn_imputation_threshold(adata, var_names, threshold=warning_threshold)

with Progress(
"[progress.description]{task.description}",
SpinnerColumn(),
refresh_per_second=1500,
) as progress:
progress.add_task("[blue]Running miceforest", total=1)
if np.issubdtype(adata.X.dtype, np.number):
_miceforest_impute(
adata, var_names, save_all_iterations, random_state, inplace, iterations, variable_parameters, verbose
)
else:
# ordinal encoding is used since non-numerical data can not be imputed using miceforest
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using miceforest
_miceforest_impute(
adata, var_names, save_all_iterations, random_state, inplace, iterations, variable_parameters, verbose
)
adata.X = adata.X.astype("object")
# decode ordinal encoding to obtain imputed original data
adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])
try:
with Progress(
"[progress.description]{task.description}",
SpinnerColumn(),
refresh_per_second=1500,
) as progress:
progress.add_task("[blue]Running miceforest", total=1)
if np.issubdtype(adata.X.dtype, np.number):
_miceforest_impute(
adata,
var_names,
save_all_iterations,
random_state,
inplace,
iterations,
variable_parameters,
verbose,
)
else:
# ordinal encoding is used since non-numerical data can not be imputed using miceforest
enc = OrdinalEncoder()
column_indices = _get_column_indices(adata, adata.uns["non_numerical_columns"])
adata.X[::, column_indices] = enc.fit_transform(adata.X[::, column_indices])
# impute the data using miceforest
_miceforest_impute(
adata,
var_names,
save_all_iterations,
random_state,
inplace,
iterations,
variable_parameters,
verbose,
)
adata.X = adata.X.astype("object")
# decode ordinal encoding to obtain imputed original data
adata.X[::, column_indices] = enc.inverse_transform(adata.X[::, column_indices])
except ValueError as e:
if "Data matrix has wrong shape" in str(e):
print("[bold red]Check that your matrix does not contain any NaN values!")
raise

if var_names:
logg.debug(
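The same try/except with the NaN hint now appears in knn_impute, miss_forest_impute, and mice_forest_impute. Purely as a hypothetical follow-up (not part of this PR), the check could be centralized in a small context manager so the hint text lives in one place:

from contextlib import contextmanager

from rich import print  # rich's print renders the [bold red] markup, as in the module


@contextmanager
def _nan_hint_on_shape_error():
    """Hypothetical helper: print a NaN hint for the shape ValueError, then re-raise."""
    try:
        yield
    except ValueError as e:
        if "Data matrix has wrong shape" in str(e):
            print("[bold red]Check that your matrix does not contain any NaN values!")
        raise

# Usage inside an imputer would then reduce to:
# with _nan_hint_on_shape_error():
#     ...run the imputation...
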
8 changes: 0 additions & 8 deletions tests/anndata/test_anndata_ext.py
@@ -347,14 +347,6 @@ def _setup_anndata_to_df() -> tuple[list, list, list]:

return col1_val, col2_val, col3_val

def test_generate_anndata(self):
adata = generate_anndata((3, 3), include_nlp=False)
assert adata.X.shape == (3, 3)

adata = generate_anndata((2, 2), include_nlp=True)
assert adata.X.shape == (2, 2)
assert "nlp" in adata.obs.columns


class TestAnnDataUtil:
def setup_method(self):