@@ -287,7 +287,7 @@ def train_model(
287
287
288
288
return_dict = {
289
289
"model" : model ,
290
- "best_model" : best_model if (do_evaluation and best_model ) is not None else model ,
290
+ "best_model" : best_model if (do_evaluation and best_model is not None ) else model ,
291
291
"train_logger" : train_logger ,
292
292
"val_logger" : val_logger ,
293
293
"optimizer" : optimizer ,
@@ -840,8 +840,8 @@ def load_model(
840
840
# Verify consistency of last epoch trained
841
841
if epoch_trained != checkpoint ["epoch" ]:
842
842
logger .warning (
843
- f"Mismatch between epoch specified in checkpoint path (' { epoch_trained } '), "
844
- f"epoch specified at saving time (' { checkpoint ['epoch' ]} ' )."
843
+ f"Mismatch between epoch specified in checkpoint path ({ epoch_trained } ) "
844
+ f"and epoch specified at saving time ({ checkpoint ['epoch' ]} )."
845
845
)
846
846
847
847
# Load train / val loggers if provided
@@ -992,7 +992,7 @@ class EarlyStopping:
992
992
993
993
def __init__ (
994
994
self ,
995
- criterion : Optional [str ] = "f1 " ,
995
+ criterion : Optional [str ] = "accuracy " ,
996
996
mode : Optional [str ] = None ,
997
997
min_delta : Optional [float ] = None ,
998
998
patience : Optional [int ] = None ,
0 commit comments