diff --git a/README.md b/README.md index 78400617..c57b138b 100644 --- a/README.md +++ b/README.md @@ -411,3 +411,7 @@ loaded_clf.load_model(saved_filepath) - `warm_start` : bool (default=False) In order to match scikit-learn API, this is set to False. It allows to fit twice the same model and start from a warm start. + +- `compute_importance` : bool (default=True) + + Whether to compute feature importances at the end of `fit` diff --git a/census_example.ipynb b/census_example.ipynb index 0be04b0a..8603542f 100755 --- a/census_example.ipynb +++ b/census_example.ipynb @@ -272,26 +272,50 @@ "scrolled": true }, "outputs": [], - "source": [ - "# This illustrates the warm_start=False behaviour\n", - "save_history = []\n", - "for _ in range(2):\n", - " clf.fit(\n", - " X_train=X_train, y_train=y_train,\n", - " eval_set=[(X_train, y_train), (X_valid, y_valid)],\n", - " eval_name=['train', 'valid'],\n", - " eval_metric=['auc'],\n", - " max_epochs=max_epochs , patience=20,\n", - " batch_size=1024, virtual_batch_size=128,\n", - " num_workers=0,\n", - " weights=1,\n", - " drop_last=False,\n", - " augmentations=aug, #aug, None\n", - " )\n", - " save_history.append(clf.history[\"valid_auc\"])\n", - "\n", - "assert(np.all(np.array(save_history[0]==np.array(save_history[1]))))" - ] + "source": [ + "# This illustrates the warm_start=False behaviour\n", + "save_history = []\n", + "\n", + "# Fitting the model without a warm start and without computing feature importances\n", + "for _ in range(2):\n", + " clf.fit(\n", + " X_train=X_train, y_train=y_train,\n", + " eval_set=[(X_train, y_train), (X_valid, y_valid)],\n", + " eval_name=['train', 'valid'],\n", + " eval_metric=['auc'],\n", + " max_epochs=max_epochs , patience=20,\n", + " batch_size=1024, virtual_batch_size=128,\n", + " num_workers=0,\n", + " weights=1,\n", + " drop_last=False,\n", + " augmentations=aug, #aug, None\n", + " compute_importance=False\n", + " )\n", + " save_history.append(clf.history[\"valid_auc\"])\n", + "\n", + 
"assert(np.all(np.array(save_history[0]==np.array(save_history[1]))))\n", + "\n", + "save_history = [] # Resetting the list to show that it also works when computing feature importance\n", + "\n", + "# Fitting the model without a warm start but with feature importance computation enabled\n", + "for _ in range(2):\n", + " clf.fit(\n", + " X_train=X_train, y_train=y_train,\n", + " eval_set=[(X_train, y_train), (X_valid, y_valid)],\n", + " eval_name=['train', 'valid'],\n", + " eval_metric=['auc'],\n", + " max_epochs=max_epochs , patience=20,\n", + " batch_size=1024, virtual_batch_size=128,\n", + " num_workers=0,\n", + " weights=1,\n", + " drop_last=False,\n", + " augmentations=aug, #aug, None\n", + " compute_importance=True # True by default so not needed\n", + " )\n", + " save_history.append(clf.history[\"valid_auc\"])\n", + "\n", + "assert(np.all(np.array(save_history[0]==np.array(save_history[1]))))" + ] }, { "cell_type": "code", diff --git a/pytorch_tabnet/abstract_model.py b/pytorch_tabnet/abstract_model.py index 4ccf58d4..039960f5 100644 --- a/pytorch_tabnet/abstract_model.py +++ b/pytorch_tabnet/abstract_model.py @@ -138,6 +138,7 @@ def fit( from_unsupervised=None, warm_start=False, augmentations=None, + compute_importance=True ): """Train a neural network stored in self.network Using train_dataloader for training data and @@ -183,6 +184,8 @@ def fit( Use a previously self supervised model as starting weights warm_start: bool If True, current model parameters are used to start training + compute_importance : bool + Whether to compute feature importance """ # update model name @@ -196,6 +199,7 @@ def fit( self._stop_training = False self.pin_memory = pin_memory and (self.device.type != "cpu") self.augmentations = augmentations + self.compute_importance = compute_importance if self.augmentations is not None: # This ensure reproducibility @@ -267,8 +271,9 @@ def fit( self._callback_container.on_train_end() self.network.eval() - # compute 
feature importance once the best model is defined - self.feature_importances_ = self._compute_feature_importances(X_train) + if self.compute_importance: + # compute feature importance once the best model is defined + self.feature_importances_ = self._compute_feature_importances(X_train) def predict(self, X): """