Skip to content

Commit

Permalink
chore: refacto history
Browse files Browse the repository at this point in the history
  • Loading branch information
Optimox committed Oct 12, 2020
1 parent 0ae114f commit b01339a
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 24 deletions.
24 changes: 20 additions & 4 deletions census_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"clf.fit(\n",
Expand Down Expand Up @@ -257,10 +259,24 @@
"preds = clf.predict_proba(X_test)\n",
"test_auc = roc_auc_score(y_score=preds[:,1], y_true=y_test)\n",
"\n",
"\n",
"preds_valid = clf.predict_proba(X_valid)\n",
"valid_auc = roc_auc_score(y_score=preds_valid[:,1], y_true=y_valid)\n",
"\n",
"print(f\"BEST VALID SCORE FOR {dataset_name} : {clf.best_cost}\")\n",
"print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_auc}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check that best weights are used\n",
"assert np.isclose(valid_auc, np.max(clf.history['valid_auc']), atol=1e-6)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -420,9 +436,9 @@
],
"metadata": {
"kernelspec": {
"display_name": ".shap",
"display_name": "Python 3",
"language": "python",
"name": ".shap"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -434,7 +450,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
"version": "3.7.6"
},
"toc": {
"base_numbering": 1,
Expand Down
15 changes: 15 additions & 0 deletions customizing_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@
"# 2 - Use your own loss function\n",
"\n",
"The default loss for classification is torch.nn.functional.cross_entropy\n",
"\n",
"The default loss for regression is torch.nn.functional.mse_loss\n",
"\n",
"Any differentiable loss function of the type lambda y_pred, y_true : loss(y_pred, y_true) should work if it uses torch computation (to allow gradient computation).\n",
Expand Down Expand Up @@ -410,10 +411,24 @@
"preds = clf.predict_proba(X_test)\n",
"test_auc = roc_auc_score(y_score=preds[:,1], y_true=y_test)\n",
"\n",
"preds_valid = clf.predict_proba(X_valid)\n",
"valid_auc = roc_auc_score(y_score=preds_valid[:,1], y_true=y_valid)\n",
"\n",
"\n",
"print(f\"FINAL VALID SCORE FOR {dataset_name} : {clf.history['val_auc'][-1]}\")\n",
"print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_auc}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check that last epoch's weights are used\n",
"assert np.isclose(valid_auc, clf.history['val_auc'][-1], atol=1e-6)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
10 changes: 6 additions & 4 deletions pytorch_tabnet/abstract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def fit(
self,
X_train,
y_train,
eval_set=[],
eval_set=None,
eval_name=None,
eval_metric=None,
loss_fn=None,
Expand Down Expand Up @@ -133,6 +133,7 @@ def fit(
self.drop_last = drop_last
self.input_dim = X_train.shape[1]
self._stop_training = False
self.eval_set = eval_set if eval_set else []

if loss_fn is None:
self.loss_fn = self._default_loss
Expand Down Expand Up @@ -171,7 +172,8 @@ def fit(
self._predict_epoch(eval_name, valid_dataloader)

# Call method on_epoch_end for all callbacks
self._callback_container.on_epoch_end(epoch_idx, self.history.batch_metrics)
self._callback_container.on_epoch_end(epoch_idx,
logs=self.history.epoch_metrics)

if self._stop_training:
break
Expand Down Expand Up @@ -337,7 +339,7 @@ def _train_epoch(self, train_loader):
self._callback_container.on_batch_end(batch_idx, batch_logs)

epoch_logs = {"lr": self._optimizer.param_groups[-1]["lr"]}
self.history.batch_metrics.update(epoch_logs)
self.history.epoch_metrics.update(epoch_logs)

return

Expand Down Expand Up @@ -409,7 +411,7 @@ def _predict_epoch(self, name, loader):

metrics_logs = self._metric_container_dict[name](y_true, scores)
self.network.train()
self.history.batch_metrics.update(metrics_logs)
self.history.epoch_metrics.update(metrics_logs)
return

def _predict_batch(self, X):
Expand Down
32 changes: 16 additions & 16 deletions pytorch_tabnet/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,46 +192,46 @@ def __post_init__(self):
self.total_time = 0.0

def on_train_begin(self, logs=None):
self.epoch_metrics = {"loss": []}
self.epoch_metrics.update({"lr": []})
self.epoch_metrics.update({name: [] for name in self.trainer._metrics_names})
self.history = {"loss": []}
self.history.update({"lr": []})
self.history.update({name: [] for name in self.trainer._metrics_names})
self.start_time = logs["start_time"]
self.epoch_loss = 0.

def on_epoch_begin(self, epoch, logs=None):
self.batch_metrics = {"loss": 0.0}
self.epoch_metrics = {"loss": 0.0}
self.samples_seen = 0.0

def on_epoch_end(self, epoch, logs=None):
for k in self.batch_metrics:
self.epoch_metrics[k].append(self.batch_metrics[k])
self.epoch_metrics["loss"] = self.epoch_loss
for metric_name, metric_value in self.epoch_metrics.items():
self.history[metric_name].append(metric_value)
if self.verbose == 0:
return
if epoch % self.verbose != 0:
return
msg = f"epoch: {epoch:<4}"
for metric_name, metric_value in self.batch_metrics.items():
msg = f"epoch {epoch:<3}"
for metric_name, metric_value in self.epoch_metrics.items():
if metric_name != "lr":
msg += f"| {metric_name:<5}: {np.round(metric_value, 5):<8}"
msg += f"| {metric_name:<3}: {np.round(metric_value, 5):<8}"
self.total_time = int(time.time() - self.start_time)
msg += f"| {str(datetime.timedelta(seconds=self.total_time)) + 's':<6}"
print(msg)

def on_batch_end(self, batch, logs=None):
batch_size = logs["batch_size"]
for k in self.batch_metrics:
self.batch_metrics[k] = (
self.samples_seen * self.batch_metrics[k] + logs[k] * batch_size
) / (self.samples_seen + batch_size)
self.epoch_loss = (self.samples_seen * self.epoch_loss + batch_size * logs["loss"]
) / (self.samples_seen + batch_size)
self.samples_seen += batch_size

def __getitem__(self, name):
return self.epoch_metrics[name]
return self.history[name]

def __repr__(self):
return str(self.epoch_metrics)
return str(self.history)

def __str__(self):
return str(self.epoch_metrics)
return str(self.history)


@dataclass
Expand Down

0 comments on commit b01339a

Please sign in to comment.