From 711689c53f64574632438b826d1763590a71de85 Mon Sep 17 00:00:00 2001
From: Karan Desai
Date: Mon, 11 Feb 2019 16:26:25 -0500
Subject: [PATCH] Update docstring of lr_lambda_fun and fix epoch enumeration.

---
 train.py | 30 +++++++++++++-----------------
 1 file changed, 13 insertions(+), 17 deletions(-)

diff --git a/train.py b/train.py
index 39cd76a..b4d5fd9 100644
--- a/train.py
+++ b/train.py
@@ -77,6 +77,7 @@
 torch.backends.cudnn.benchmark = False
 torch.backends.cudnn.deterministic = True
 
+
 # ================================================================================================
 # INPUT ARGUMENTS AND CONFIG
 # ================================================================================================
@@ -96,7 +97,7 @@
 
 
 # ================================================================================================
-# SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER
+# SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER, SCHEDULER
 # ================================================================================================
 
 train_dataset = VisDialDataset(
@@ -127,7 +128,7 @@
 if -1 not in args.gpu_ids:
     model = nn.DataParallel(model, args.gpu_ids)
 
-# loss
+# Loss function.
 criterion = nn.CrossEntropyLoss()
 
 if config["solver"]["training_splits"] == "trainval":
@@ -136,28 +137,24 @@
     iterations = len(train_dataset) // config["solver"]["batch_size"] + 1
 
 
-def lr_lambda_fun(itn):
+def lr_lambda_fun(current_iteration: int) -> float:
     """Returns a learning rate multiplier.
 
     Till `warmup_epochs`, learning rate linearly increases to `initial_lr`,
     and then gets multiplied by `lr_gamma` every time a milestone is crossed.
-
-    Args:
-        itn: training iteration
-    Returns:
-        learning rate multiplier
     """
-    cur_epoch = float(itn) / iterations
-    if cur_epoch <= config["solver"]["warmup_epochs"]:
-        alpha = cur_epoch / float(config["solver"]["warmup_epochs"])
+    current_epoch = float(current_iteration) / iterations
+    if current_epoch <= config["solver"]["warmup_epochs"]:
+        alpha = current_epoch / float(config["solver"]["warmup_epochs"])
         return config["solver"]["warmup_factor"] * (1. - alpha) + alpha
     else:
-        idx = bisect(config["solver"]["lr_milestones"], cur_epoch)
+        idx = bisect(config["solver"]["lr_milestones"], current_epoch)
         return pow(config["solver"]["lr_gamma"], idx)
 
 optimizer = optim.Adamax(model.parameters(), lr=config["solver"]["initial_lr"])
 scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_fun)
 
+
 # ================================================================================================
 # SETUP BEFORE TRAINING LOOP
 # ================================================================================================
@@ -186,10 +183,10 @@ def lr_lambda_fun(itn):
 # TRAINING LOOP
 # ================================================================================================
 
-# Forever increasing counter keeping track of iterations completed.
+# Forever increasing counter keeping track of iterations completed (for tensorboard logging).
 global_iteration_step = start_epoch * iterations
 
-for epoch in range(start_epoch, config["solver"]["num_epochs"] + 1):
+for epoch in range(start_epoch, config["solver"]["num_epochs"]):
 
     # --------------------------------------------------------------------------------------------
     # ON EPOCH START (combine dataloaders if training on train + val)
@@ -213,9 +210,8 @@ def lr_lambda_fun(itn):
         summary_writer.add_scalar("train/loss", batch_loss, global_iteration_step)
         summary_writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], global_iteration_step)
 
-        global_iteration_step += 1
         scheduler.step(global_iteration_step)
-
+        global_iteration_step += 1
         torch.cuda.empty_cache()
 
     # --------------------------------------------------------------------------------------------
@@ -226,7 +222,7 @@ def lr_lambda_fun(itn):
 
     # Validate and report automatic metrics.
     if args.validate:
-        # switch dropout, batchnorm etc to the correct mode
+        # Switch dropout, batchnorm etc to the correct mode.
         model.eval()
 
         print(f"\nValidation after epoch {epoch}:")
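
Note (not part of the patch): a quick standalone sanity check of the warmup/decay
schedule updated above. The function body is copied from this commit; the solver
values and the per-epoch iteration count are invented purely for illustration.

    from bisect import bisect

    # Hypothetical solver config, chosen only to make the numbers easy to read.
    config = {"solver": {"warmup_epochs": 1, "warmup_factor": 0.2,
                         "lr_milestones": [4, 7, 10], "lr_gamma": 0.5}}
    iterations = 100  # assumed iterations per epoch, for this sketch only

    def lr_lambda_fun(current_iteration: int) -> float:
        # Linear warmup from warmup_factor to 1.0 over warmup_epochs, then
        # multiply by lr_gamma once for each milestone already crossed.
        current_epoch = float(current_iteration) / iterations
        if current_epoch <= config["solver"]["warmup_epochs"]:
            alpha = current_epoch / float(config["solver"]["warmup_epochs"])
            return config["solver"]["warmup_factor"] * (1. - alpha) + alpha
        else:
            idx = bisect(config["solver"]["lr_milestones"], current_epoch)
            return pow(config["solver"]["lr_gamma"], idx)

    for it in (0, 50, 100, 450, 750):
        print(it, lr_lambda_fun(it))
    # Expected: 0 -> 0.2, 50 -> 0.6, 100 -> 1.0, 450 -> 0.5, 750 -> 0.25

Since LambdaLR scales initial_lr by this multiplier and the training loop calls
scheduler.step(global_iteration_step) once per batch, the learning rate follows
exactly this per-iteration schedule.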