Skip to content

Commit

Permalink
Fix for #42
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-valverde-l committed Sep 30, 2022
1 parent fc0c402 commit 5f3b196
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion TabSurvey/models/tabnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, params, args):
#self.model = TabNetClassifier(**self.params)
self.model = TabNetClassifierPatched(**self.params)
self.metric = ["logloss"]

def fit(self, X, y, X_val=None, y_val=None):
if self.args.objective == "regression":
y, y_val = y.reshape(-1, 1), y_val.reshape(-1, 1)
Expand All @@ -42,6 +42,8 @@ def fit(self, X, y, X_val=None, y_val=None):
elif self.args.objective == "classification":
self.model.num_classes = self.args.num_classes

# Drop last only if last batch has only one sample
drop_last = X.shape[0] % self.args.batch_size == 1
self.model.fit(
X,
y,
Expand All @@ -51,6 +53,7 @@ def fit(self, X, y, X_val=None, y_val=None):
max_epochs=self.args.epochs,
patience=self.args.early_stopping_rounds,
batch_size=self.args.batch_size,
drop_last=drop_last,
)
history = self.model.history
self.save_model(filename_extension="best")
Expand Down

0 comments on commit 5f3b196

Please sign in to comment.