diff --git a/trainer.py b/trainer.py
index 1eef06c..99ba853 100644
--- a/trainer.py
+++ b/trainer.py
@@ -55,6 +55,8 @@
 parser.add_argument('--save-every', dest='save_every',
                     help='Saves checkpoints at every specified number of epochs',
                     type=int, default=10)
+parser.add_argument('--lr-milestones', default=[100, 150], nargs='+',
+                    help='list of epoch indices for multi step learning rate scheduler', type=int)
 
 best_prec1 = 0
 
@@ -118,7 +120,7 @@ def main():
                                 weight_decay=args.weight_decay)
 
     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
-                                                        milestones=[100, 150], last_epoch=args.start_epoch - 1)
+                                                        milestones=args.lr_milestones, last_epoch=args.start_epoch - 1)
 
     if args.arch in ['resnet1202', 'resnet110']:
         # for resnet1202 original paper uses lr=0.01 for first 400 minibatches for warm-up