Skip to content

Commit

Permalink
feat/493-optional-importance
Browse files Browse the repository at this point in the history
  • Loading branch information
César Leblanc authored and Optimox committed Jul 6, 2023
1 parent d44f3b5 commit 9ba8991
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 22 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,7 @@ loaded_clf.load_model(saved_filepath)
- `warm_start` : bool (default=False)
In order to match scikit-learn API, this is set to False.
It allows to fit twice the same model and start from a warm start.

- `compute_importance` : bool (default=True)

Whether to compute feature importance
64 changes: 44 additions & 20 deletions census_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -272,26 +272,50 @@
"scrolled": true
},
"outputs": [],
"source": [
"# This illustrates the warm_start=False behaviour\n",
"save_history = []\n",
"for _ in range(2):\n",
" clf.fit(\n",
" X_train=X_train, y_train=y_train,\n",
" eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
" eval_name=['train', 'valid'],\n",
" eval_metric=['auc'],\n",
" max_epochs=max_epochs , patience=20,\n",
" batch_size=1024, virtual_batch_size=128,\n",
" num_workers=0,\n",
" weights=1,\n",
" drop_last=False,\n",
" augmentations=aug, #aug, None\n",
" )\n",
" save_history.append(clf.history[\"valid_auc\"])\n",
"\n",
"assert(np.all(np.array(save_history[0]==np.array(save_history[1]))))"
]
"source": [
"# This illustrates the warm_start=False behaviour\n",
"save_history = []\n",
"\n",
"# Fitting the model without starting from a warm start nor computing the feature importance\n",
"for _ in range(2):\n",
" clf.fit(\n",
" X_train=X_train, y_train=y_train,\n",
" eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
" eval_name=['train', 'valid'],\n",
" eval_metric=['auc'],\n",
" max_epochs=max_epochs , patience=20,\n",
" batch_size=1024, virtual_batch_size=128,\n",
" num_workers=0,\n",
" weights=1,\n",
" drop_last=False,\n",
" augmentations=aug, #aug, None\n",
" compute_importance=False\n",
" )\n",
" save_history.append(clf.history[\"valid_auc\"])\n",
"\n",
"assert(np.all(np.array(save_history[0]==np.array(save_history[1]))))\n",
"\n",
"save_history = [] # Resetting the list to show that it also works when computing feature importance\n",
"\n",
"# Fitting the model without starting from a warm start but with the computing of the feature importance activated\n",
"for _ in range(2):\n",
" clf.fit(\n",
" X_train=X_train, y_train=y_train,\n",
" eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
" eval_name=['train', 'valid'],\n",
" eval_metric=['auc'],\n",
" max_epochs=max_epochs , patience=20,\n",
" batch_size=1024, virtual_batch_size=128,\n",
" num_workers=0,\n",
" weights=1,\n",
" drop_last=False,\n",
" augmentations=aug, #aug, None\n",
" compute_importance=True # True by default so not needed\n",
" )\n",
" save_history.append(clf.history[\"valid_auc\"])\n",
"\n",
"assert(np.all(np.array(save_history[0]==np.array(save_history[1]))))"
]
},
{
"cell_type": "code",
Expand Down
9 changes: 7 additions & 2 deletions pytorch_tabnet/abstract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def fit(
from_unsupervised=None,
warm_start=False,
augmentations=None,
compute_importance=True
):
"""Train a neural network stored in self.network
Using train_dataloader for training data and
Expand Down Expand Up @@ -183,6 +184,8 @@ def fit(
Use a previously self supervised model as starting weights
warm_start: bool
If True, current model parameters are used to start training
compute_importance : bool
Whether to compute feature importance
"""
# update model name

Expand All @@ -196,6 +199,7 @@ def fit(
self._stop_training = False
self.pin_memory = pin_memory and (self.device.type != "cpu")
self.augmentations = augmentations
self.compute_importance = compute_importance

if self.augmentations is not None:
# This ensure reproducibility
Expand Down Expand Up @@ -267,8 +271,9 @@ def fit(
self._callback_container.on_train_end()
self.network.eval()

# compute feature importance once the best model is defined
self.feature_importances_ = self._compute_feature_importances(X_train)
if self.compute_importance:
# compute feature importance once the best model is defined
self.feature_importances_ = self._compute_feature_importances(X_train)

def predict(self, X):
"""
Expand Down

0 comments on commit 9ba8991

Please sign in to comment.