Update docstring of lr_lambda_fun and fix epoch enumeration.
Karan Desai committed Feb 11, 2019
1 parent 921d73e · commit 711689c
Showing 1 changed file with 13 additions and 17 deletions.
train.py: 30 changes (13 additions & 17 deletions)
@@ -77,6 +77,7 @@
 torch.backends.cudnn.benchmark = False
 torch.backends.cudnn.deterministic = True
 
+
 # ================================================================================================
 # INPUT ARGUMENTS AND CONFIG
 # ================================================================================================
@@ -96,7 +97,7 @@
 
 
 # ================================================================================================
-# SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER
+# SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER, SCHEDULER
 # ================================================================================================
 
 train_dataset = VisDialDataset(
@@ -127,7 +128,7 @@
 if -1 not in args.gpu_ids:
     model = nn.DataParallel(model, args.gpu_ids)
 
-# loss
+# Loss function.
 criterion = nn.CrossEntropyLoss()
 
 if config["solver"]["training_splits"] == "trainval":
@@ -136,28 +137,24 @@
     iterations = len(train_dataset) // config["solver"]["batch_size"] + 1
 
 
-def lr_lambda_fun(itn):
+def lr_lambda_fun(current_iteration: int) -> float:
     """Returns a learning rate multiplier.
 
     Till `warmup_epochs`, learning rate linearly increases to `initial_lr`,
     and then gets multiplied by `lr_gamma` every time a milestone is crossed.
-    Args:
-        itn: training iteration
-    Returns:
-        learning rate multiplier
     """
-    cur_epoch = float(itn) / iterations
-    if cur_epoch <= config["solver"]["warmup_epochs"]:
-        alpha = cur_epoch / float(config["solver"]["warmup_epochs"])
+    current_epoch = float(current_iteration) / iterations
+    if current_epoch <= config["solver"]["warmup_epochs"]:
+        alpha = current_epoch / float(config["solver"]["warmup_epochs"])
         return config["solver"]["warmup_factor"] * (1. - alpha) + alpha
     else:
-        idx = bisect(config["solver"]["lr_milestones"], cur_epoch)
+        idx = bisect(config["solver"]["lr_milestones"], current_epoch)
         return pow(config["solver"]["lr_gamma"], idx)
 
 optimizer = optim.Adamax(model.parameters(), lr=config["solver"]["initial_lr"])
 scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_fun)
 
 
 # ================================================================================================
 # SETUP BEFORE TRAINING LOOP
 # ================================================================================================
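
To make the schedule that lr_lambda_fun encodes concrete, here is a small self-contained sketch. The values of warmup_factor, warmup_epochs, lr_milestones, lr_gamma and iterations below are illustrative placeholders, not the numbers from the repo's config yaml:

from bisect import bisect

# Placeholder solver settings (assumed for illustration only).
iterations = 1000           # batches per epoch
warmup_factor = 0.2
warmup_epochs = 1
lr_milestones = [4, 7, 10]  # epochs at which the LR decays
lr_gamma = 0.1

def multiplier(current_iteration):
    current_epoch = float(current_iteration) / iterations
    if current_epoch <= warmup_epochs:
        # Linear warmup from warmup_factor up to 1.0.
        alpha = current_epoch / float(warmup_epochs)
        return warmup_factor * (1.0 - alpha) + alpha
    else:
        # Step decay: multiply by lr_gamma once per crossed milestone.
        idx = bisect(lr_milestones, current_epoch)
        return pow(lr_gamma, idx)

print(multiplier(0))      # 0.2  (start of warmup)
print(multiplier(500))    # 0.6  (halfway through warmup)
print(multiplier(1000))   # 1.0  (warmup done)
print(multiplier(5000))   # 0.1  (past the first milestone at epoch 4)
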
@@ -186,10 +183,10 @@ def lr_lambda_fun(itn):
 # TRAINING LOOP
 # ================================================================================================
 
-# Forever increasing counter keeping track of iterations completed.
+# Forever increasing counter keeping track of iterations completed (for tensorboard logging).
 global_iteration_step = start_epoch * iterations
 
-for epoch in range(start_epoch, config["solver"]["num_epochs"] + 1):
+for epoch in range(start_epoch, config["solver"]["num_epochs"]):
 
     # --------------------------------------------------------------------------------------------
     # ON EPOCH START (combine dataloaders if training on train + val)
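
The changed range() bound above is the epoch enumeration fix named in the commit message: range(start_epoch, num_epochs) runs exactly num_epochs - start_epoch epochs, while the old + 1 bound ran one extra. A tiny check, with num_epochs as an example value:

num_epochs, start_epoch = 20, 0
print(len(range(start_epoch, num_epochs + 1)))  # 21 epochs with the old bound (one extra)
print(len(range(start_epoch, num_epochs)))      # 20 epochs with the fixed bound
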
@@ -213,9 +210,8 @@ def lr_lambda_fun(itn):
         summary_writer.add_scalar("train/loss", batch_loss, global_iteration_step)
         summary_writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], global_iteration_step)
 
-        global_iteration_step += 1
         scheduler.step(global_iteration_step)
-
+        global_iteration_step += 1
     torch.cuda.empty_cache()
 
     # --------------------------------------------------------------------------------------------
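
Moving the increment after scheduler.step means the scheduler is now stepped with the 0-based iteration count: on a fresh run the lambda is evaluated at 0 (so the multiplier starts at exactly warmup_factor), and on resume it picks up from start_epoch * iterations. A minimal sketch of the counter values the two orderings would pass to the scheduler (no real scheduler involved):

# Fixed ordering: step with the current count, then increment.
global_iteration_step = 0
for _ in range(3):
    print("scheduler.step sees", global_iteration_step)  # 0, 1, 2
    global_iteration_step += 1
# The old ordering (increment first) would have passed 1, 2, 3 instead.
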
@@ -226,7 +222,7 @@ def lr_lambda_fun(itn):
     # Validate and report automatic metrics.
     if args.validate:
 
-        # switch dropout, batchnorm etc to the correct mode
+        # Switch dropout, batchnorm etc to the correct mode.
         model.eval()
 
         print(f"\nValidation after epoch {epoch}:")
