From 1e15ff86303a9055f28214a695d9c3566cc06b7f Mon Sep 17 00:00:00 2001
From: Jonathan
Date: Fri, 30 Sep 2022 22:59:21 +0000
Subject: [PATCH] Fixes for #33 and for issues with NODE and the case of
 missing classes on data splits.

---
 TabSurvey/models/node.py             | 9 +++++++--
 TabSurvey/models/node_lib/trainer.py | 4 ++--
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/TabSurvey/models/node.py b/TabSurvey/models/node.py
index 2e48b5e..8695282 100644
--- a/TabSurvey/models/node.py
+++ b/TabSurvey/models/node.py
@@ -1,4 +1,5 @@
 import time
+import shutil
 
 import numpy as np
 import torch
@@ -17,6 +18,7 @@
 Code adapted from: https://github.com/Qwicen/node
 """
 
+n_last_checkpoints = 5  # Number of models that are averaged
 
 class NODE(BaseModelTorch):
     def __init__(self, params, args):
@@ -98,7 +100,7 @@ def fit(self, X, y, X_val=None, y_val=None):
             Optimizer=QHAdam,
             optimizer_params=dict(lr=1e-3, nus=(0.7, 1.0), betas=(0.95, 0.998)),
             verbose=True,
-            n_last_checkpoints=5,
+            n_last_checkpoints=n_last_checkpoints,
         )
 
         best_loss = float("inf")
@@ -120,7 +122,8 @@ def fit(self, X, y, X_val=None, y_val=None):
                 metrics = self.trainer.train_on_batch(*batch, device=self.device)
                 loss_history.append(metrics["loss"].item())
 
-                if self.trainer.step % self.args.logging_period == 0:
+                # Checkpoint every args.logging_period steps, aligned so the last step is included
+                if (self.args.epochs - self.trainer.step) % self.args.logging_period == 0:
                     self.trainer.save_checkpoint()
                     self.trainer.average_checkpoints(out_tag="avg")
                     self.trainer.load_checkpoint(tag="avg")
@@ -141,6 +144,7 @@ def fit(self, X, y, X_val=None, y_val=None):
                             data.y_valid,
                             device=self.device,
                             batch_size=self.args.batch_size,
+                            num_classes=self.args.num_classes,
                         )
                         print("Val LogLoss: %0.5f" % loss)
                     elif self.args.objective == "binary":
@@ -172,6 +176,7 @@ def fit(self, X, y, X_val=None, y_val=None):
                 break
 
         self.trainer.load_checkpoint(tag="best")
+        shutil.rmtree(self.trainer.experiment_path)
         return loss_history, val_loss_history
 
     def predict_helper(self, X):
diff --git a/TabSurvey/models/node_lib/trainer.py b/TabSurvey/models/node_lib/trainer.py
index 4c05673..800c59d 100644
--- a/TabSurvey/models/node_lib/trainer.py
+++ b/TabSurvey/models/node_lib/trainer.py
@@ -171,7 +171,7 @@ def evaluate_binarylogloss(self, X_test, y_test, device, batch_size=512):
         logloss = log_loss(y_test, logits)
         return logloss
 
-    def evaluate_logloss(self, X_test, y_test, device, batch_size=512):
+    def evaluate_logloss(self, X_test, y_test, device, batch_size=512, num_classes=None):
         X_test = torch.as_tensor(X_test, device=device)
         y_test = check_numpy(y_test)
         self.model.train(False)
@@ -179,5 +179,5 @@ def evaluate_logloss(self, X_test, y_test, device, batch_size=512):
         logits = F.softmax(process_in_chunks(self.model, X_test, batch_size=batch_size), dim=1)
         logits = check_numpy(logits)
         y_test = torch.tensor(y_test)
-        logloss = log_loss(check_numpy(to_one_hot(y_test)), logits)
+        logloss = log_loss(check_numpy(to_one_hot(y_test, depth=num_classes)), logits)
         return logloss