
Commit 1e15ff8

Fixes for naszilla#33 and for issues with NODE and the case of missing classes on data splits.
jonathan-valverde-l committed Sep 30, 2022
1 parent c75d4a6 commit 1e15ff8
Showing 2 changed files with 9 additions and 4 deletions.
9 changes: 7 additions & 2 deletions TabSurvey/models/node.py
@@ -1,4 +1,5 @@
 import time
+import shutil
 
 import numpy as np
 import torch
@@ -17,6 +18,7 @@
     Code adapted from: https://github.com/Qwicen/node
 """
 
+n_last_checkpoints = 5  # Number of models that are averaged
 
 class NODE(BaseModelTorch):
     def __init__(self, params, args):
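
The new module-level constant names what the magic number 5 below actually controls: the NODE trainer keeps the last n_last_checkpoints weight snapshots and averages them into an "avg" checkpoint. A minimal sketch of that kind of weight averaging, assuming state dicts with matching keys (hypothetical helper, not code from this repository):

import torch

def average_state_dicts(state_dicts):
    # Element-wise mean of each parameter tensor across the
    # last n saved checkpoints, as checkpoint averaging does.
    return {
        key: torch.stack([sd[key].float() for sd in state_dicts]).mean(dim=0)
        for key in state_dicts[0]
    }
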
@@ -98,7 +100,7 @@ def fit(self, X, y, X_val=None, y_val=None):
             Optimizer=QHAdam,
             optimizer_params=dict(lr=1e-3, nus=(0.7, 1.0), betas=(0.95, 0.998)),
             verbose=True,
-            n_last_checkpoints=5,
+            n_last_checkpoints=n_last_checkpoints,
         )
 
         best_loss = float("inf")
@@ -120,7 +122,8 @@ def fit(self, X, y, X_val=None, y_val=None):
             metrics = self.trainer.train_on_batch(*batch, device=self.device)
             loss_history.append(metrics["loss"].item())
 
-            if self.trainer.step % self.args.logging_period == 0:
+            # Checkpoint every args.logging_period steps, counted back from the last step
+            if (self.args.epochs - self.trainer.step) % self.args.logging_period == 0:
                 self.trainer.save_checkpoint()
                 self.trainer.average_checkpoints(out_tag="avg")
                 self.trainer.load_checkpoint(tag="avg")
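
The rewritten condition anchors the modulus to the distance from the final step, so a checkpoint and average always happen at the very end of training, which step % logging_period only guarantees when the total step count is a multiple of the period. A small illustration with assumed values (assuming trainer.step runs from 1 to args.epochs):

epochs, logging_period = 10, 4

old = [s for s in range(1, epochs + 1) if s % logging_period == 0]
new = [s for s in range(1, epochs + 1) if (epochs - s) % logging_period == 0]

print(old)  # [4, 8]     -> the final step 10 never triggers a checkpoint
print(new)  # [2, 6, 10] -> the final step always triggers a checkpoint
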
@@ -141,6 +144,7 @@ def fit(self, X, y, X_val=None, y_val=None):
                         data.y_valid,
                         device=self.device,
                         batch_size=self.args.batch_size,
+                        num_classes=self.args.num_classes,
                     )
                     print("Val LogLoss: %0.5f" % loss)
                 elif self.args.objective == "binary":
@@ -172,6 +176,7 @@ def fit(self, X, y, X_val=None, y_val=None):
                     break
 
         self.trainer.load_checkpoint(tag="best")
+        shutil.rmtree(self.trainer.experiment_path)
         return loss_history, val_loss_history
 
     def predict_helper(self, X):
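
The added shutil.rmtree call deletes the trainer's on-disk checkpoint directory once the best weights are loaded back into memory, so repeated fit calls do not accumulate stale checkpoints. A guarded variant, under the assumption that the directory might be missing if training aborted early (hypothetical helper, not part of the commit):

import os
import shutil

def cleanup_checkpoints(experiment_path):
    # Remove the checkpoint directory only if the trainer created it.
    if os.path.isdir(experiment_path):
        shutil.rmtree(experiment_path)
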
4 changes: 2 additions & 2 deletions TabSurvey/models/node_lib/trainer.py
@@ -171,13 +171,13 @@ def evaluate_binarylogloss(self, X_test, y_test, device, batch_size=512):
         logloss = log_loss(y_test, logits)
         return logloss
 
-    def evaluate_logloss(self, X_test, y_test, device, batch_size=512):
+    def evaluate_logloss(self, X_test, y_test, device, batch_size=512, num_classes=None):
         X_test = torch.as_tensor(X_test, device=device)
         y_test = check_numpy(y_test)
         self.model.train(False)
         with torch.no_grad():
             logits = F.softmax(process_in_chunks(self.model, X_test, batch_size=batch_size), dim=1)
             logits = check_numpy(logits)
             y_test = torch.tensor(y_test)
-            logloss = log_loss(check_numpy(to_one_hot(y_test)), logits)
+            logloss = log_loss(check_numpy(to_one_hot(y_test, depth=num_classes)), logits)
         return logloss
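
This is the core of the missing-classes fix: sklearn's log_loss infers the label set from y_test, so a validation split that happens to lack some classes yields a one-hot matrix with fewer columns than the model's probability output, and the metric call fails. Fixing the one-hot depth to the dataset-wide num_classes keeps the shapes aligned. A minimal reproduction of the failure mode, using sklearn's labels argument to the same effect (assumed 3-class problem, illustrative values):

import numpy as np
from sklearn.metrics import log_loss

# A validation split that happens to contain only classes 0 and 1.
y_val = np.array([0, 1, 1, 0])
probs = np.array([[0.70, 0.20, 0.10],
                  [0.10, 0.80, 0.10],
                  [0.20, 0.60, 0.20],
                  [0.90, 0.05, 0.05]])  # the model still outputs 3 classes

# log_loss(y_val, probs) raises a ValueError: y_val implies 2 classes,
# while probs has 3 columns. Declaring the full label set fixes it:
print(log_loss(y_val, probs, labels=[0, 1, 2]))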
