From d87140623f2118e494874549752987e89be235f3 Mon Sep 17 00:00:00 2001
From: Optimox
Date: Fri, 9 Oct 2020 19:11:31 +0200
Subject: [PATCH] feat: add check nan and inf

---
 pytorch_tabnet/abstract_model.py |  3 +++
 pytorch_tabnet/utils.py          | 27 +++++++++++++++++++++++++++
 2 files changed, 30 insertions(+)

diff --git a/pytorch_tabnet/abstract_model.py b/pytorch_tabnet/abstract_model.py
index 51ca2915..45be7ef6 100644
--- a/pytorch_tabnet/abstract_model.py
+++ b/pytorch_tabnet/abstract_model.py
@@ -11,6 +11,7 @@
     create_explain_matrix,
     validate_eval_set,
     create_dataloaders,
+    check_nans,
 )
 from pytorch_tabnet.callbacks import (
     CallbackContainer,
@@ -140,6 +141,8 @@ def fit(
         else:
             self.loss_fn = loss_fn
 
+        check_nans(X_train)
+        check_nans(y_train)
         self.update_fit_params(
             X_train, y_train, eval_set, weights,
         )
diff --git a/pytorch_tabnet/utils.py b/pytorch_tabnet/utils.py
index efe740d7..5c1080e1 100644
--- a/pytorch_tabnet/utils.py
+++ b/pytorch_tabnet/utils.py
@@ -237,6 +237,8 @@ def validate_eval_set(eval_set, eval_name, X_train, y_train):
         len(elem) == 2 for elem in eval_set
     ), "Each tuple of eval_set need to have two elements"
     for name, (X, y) in zip(eval_name, eval_set):
+        check_nans(X)
+        check_nans(y)
         msg = (
             f"Number of columns is different between X_{name} "
             + f"({X.shape[1]}) and X_train ({X_train.shape[1]})"
@@ -255,3 +257,28 @@ def validate_eval_set(eval_set, eval_name, X_train, y_train):
         assert X.shape[0] == y.shape[0], msg
 
     return eval_name, eval_set
+
+
+def check_nans(array):
+    """
+    Raise an error if the array contains NaN or infinite values.
+
+    Parameters
+    ----------
+    array : array-like
+        Data to validate (features or targets).
+
+    Raises
+    ------
+    ValueError
+        If a NaN or an infinite value is found.
+    """
+    values = np.asarray(array)
+    # Only float/complex dtypes can hold NaN or inf; np.isnan raises a
+    # TypeError on string or object arrays, so other dtypes are skipped.
+    if values.dtype.kind not in "fc":
+        return
+    if np.isnan(values).any():
+        raise ValueError("NaN were found, TabNet does not allow nans.")
+    if np.isinf(values).any():
+        raise ValueError("Infinite values were found, TabNet does not allow inf.")