From 4d534d9c5ae9a6fa3a86a44fbdae97a9fd22c9bd Mon Sep 17 00:00:00 2001 From: Abhishek Das Date: Sat, 9 Feb 2019 21:28:53 -0500 Subject: [PATCH 1/7] Larger batch size, gpu memory clearing, shuffling batches --- ...rcnn_x101_bs32.yml => lf_disc_faster_rcnn_x101.yml} | 4 +--- train.py | 10 +++++++--- 2 files changed, 8 insertions(+), 6 deletions(-) rename configs/{lf_disc_faster_rcnn_x101_bs32.yml => lf_disc_faster_rcnn_x101.yml} (91%) diff --git a/configs/lf_disc_faster_rcnn_x101_bs32.yml b/configs/lf_disc_faster_rcnn_x101.yml similarity index 91% rename from configs/lf_disc_faster_rcnn_x101_bs32.yml rename to configs/lf_disc_faster_rcnn_x101.yml index b18296f..3f1211a 100644 --- a/configs/lf_disc_faster_rcnn_x101_bs32.yml +++ b/configs/lf_disc_faster_rcnn_x101.yml @@ -25,10 +25,8 @@ model: # Optimization related arguments solver: - batch_size: 32 + batch_size: 128 # 32 x num_gpus is a good rule of thumb num_epochs: 20 initial_lr: 0.001 - lr_gamma: 0.9997592083 - minimum_lr: 0.00005 training_splits: "train" # "trainval" diff --git a/train.py b/train.py index 40ea94e..1b75696 100644 --- a/train.py +++ b/train.py @@ -102,7 +102,7 @@ config["dataset"], args.train_json, overfit=args.overfit, in_memory=args.in_memory ) train_dataloader = DataLoader( - train_dataset, batch_size=config["solver"]["batch_size"], num_workers=args.cpu_workers + train_dataset, batch_size=config["solver"]["batch_size"], num_workers=args.cpu_workers, shuffle=True ) val_dataset = VisDialDataset( @@ -189,9 +189,11 @@ summary_writer.add_scalar("train/loss", batch_loss, global_iteration_step) summary_writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], global_iteration_step) - if optimizer.param_groups[0]["lr"] > config["solver"]["minimum_lr"]: - scheduler.step() + global_iteration_step += 1 + scheduler.step(global_iteration_step) + + torch.cuda.empty_cache() # -------------------------------------------------------------------------------------------- # ON EPOCH END (checkpointing and validation) @@ -217,3 +219,5 @@ for metric_name, metric_value in all_metrics.items(): print(f"{metric_name}: {metric_value}") summary_writer.add_scalars("metrics", all_metrics, global_iteration_step) + + torch.cuda.empty_cache() From 7506824961157d7097fb46372fa3e7940f955b9d Mon Sep 17 00:00:00 2001 From: Abhishek Das Date: Sat, 9 Feb 2019 22:03:56 -0500 Subject: [PATCH 2/7] Better LR schedule -- warmup + multi-step --- configs/lf_disc_faster_rcnn_x101.yml | 8 +++++- evaluate.py | 2 +- train.py | 39 ++++++++++++++++++++++------ 3 files changed, 39 insertions(+), 10 deletions(-) diff --git a/configs/lf_disc_faster_rcnn_x101.yml b/configs/lf_disc_faster_rcnn_x101.yml index 3f1211a..2d81688 100644 --- a/configs/lf_disc_faster_rcnn_x101.yml +++ b/configs/lf_disc_faster_rcnn_x101.yml @@ -29,4 +29,10 @@ solver: num_epochs: 20 initial_lr: 0.001 training_splits: "train" # "trainval" - + lr_gamma: 0.1 + lr_milestones: # epochs when lr —> lr * lr_gamma + - 3 + - 5 + - 7 + warmup_factor: 0.2 + warmup_epochs: 1 \ No newline at end of file diff --git a/evaluate.py b/evaluate.py index 058b087..3139e28 100644 --- a/evaluate.py +++ b/evaluate.py @@ -18,7 +18,7 @@ parser = argparse.ArgumentParser("Evaluate and/or generate EvalAI submission file.") parser.add_argument( - "--config-yml", default="configs/lf_disc_faster_rcnn_x101_bs32.yml", + "--config-yml", default="configs/lf_disc_faster_rcnn_x101.yml", help="Path to a config file listing reader, model and optimization parameters." 
) parser.add_argument( diff --git a/train.py b/train.py index 1b75696..219f2b8 100644 --- a/train.py +++ b/train.py @@ -8,6 +8,7 @@ from torch.utils.data import DataLoader from tqdm import tqdm import yaml +from bisect import bisect from visdialch.data.dataset import VisDialDataset from visdialch.encoders import Encoder @@ -19,7 +20,7 @@ parser = argparse.ArgumentParser() parser.add_argument( - "--config-yml", default="configs/lf_disc_faster_rcnn_x101_bs32.yml", + "--config-yml", default="configs/lf_disc_faster_rcnn_x101.yml", help="Path to a config file listing reader, model and solver parameters." ) parser.add_argument( @@ -126,10 +127,36 @@ if -1 not in args.gpu_ids: model = nn.DataParallel(model, args.gpu_ids) +# loss criterion = nn.CrossEntropyLoss() -optimizer = optim.Adam(model.parameters(), lr=config["solver"]["initial_lr"]) -scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=config["solver"]["lr_gamma"]) +if config["solver"]["training_splits"] == "trainval": + iterations = (len(train_dataset) + len(val_dataset)) // config["solver"]["batch_size"] + 1 +else: + iterations = len(train_dataset) // config["solver"]["batch_size"] + 1 + + +def lr_lambda_fun(itn): + """Returns a learning rate multiplier. + + Till `warmup_epochs`, learning rate linearly increases to `initial_lr`, + and then gets multiplied by `lr_gamma` every time a milestone is crossed. + + Args: + itn: training iteration + Returns: + learning rate multiplier + """ + cur_epoch = float(itn) / iterations + if cur_epoch <= config["solver"]["warmup_epochs"]: + alpha = cur_epoch / float(config["solver"]["warmup_epochs"]) + return config["solver"]["warmup_factor"] * (1. - alpha) + alpha + else: + idx = bisect(config["solver"]["lr_milestones"], cur_epoch) + return pow(config["solver"]["lr_gamma"], idx) + +optimizer = optim.Adamax(model.parameters(), lr=config["solver"]["initial_lr"]) +scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_fun) # ================================================================================================ # SETUP BEFORE TRAINING LOOP @@ -160,12 +187,8 @@ # ================================================================================================ # Forever increasing counter keeping track of iterations completed. -if config["solver"]["training_splits"] == "trainval": - iterations = (len(train_dataset) + len(val_dataset)) // config["solver"]["batch_size"] + 1 -else: - iterations = len(train_dataset) // config["solver"]["batch_size"] + 1 - global_iteration_step = start_epoch * iterations + for epoch in range(start_epoch, config["solver"]["num_epochs"] + 1): # -------------------------------------------------------------------------------------------- From bc46ab5eae866a7ac2000ab014c1699eaad5cfc3 Mon Sep 17 00:00:00 2001 From: Abhishek Das Date: Sat, 9 Feb 2019 22:06:16 -0500 Subject: [PATCH 3/7] Use multi-layer LSTM with dropout for options --- visdialch/decoders/disc.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/visdialch/decoders/disc.py b/visdialch/decoders/disc.py index 7462ded..e293651 100644 --- a/visdialch/decoders/disc.py +++ b/visdialch/decoders/disc.py @@ -14,7 +14,9 @@ def __init__(self, config, vocabulary): padding_idx=vocabulary.PAD_INDEX) self.option_rnn = nn.LSTM(config["word_embedding_size"], config["lstm_hidden_size"], - batch_first=True) + config["lstm_num_layers"], + batch_first=True, + dropout=config["dropout"]) # Options are variable length padded sequences, use DynamicRNN. 
self.option_rnn = DynamicRNN(self.option_rnn) From 2850328eab3c13b5b8f2bf02ff171c7af8aa6e29 Mon Sep 17 00:00:00 2001 From: Abhishek Das Date: Sat, 9 Feb 2019 22:19:26 -0500 Subject: [PATCH 4/7] Switch to eval mode when computing metrics on val --- train.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/train.py b/train.py index 219f2b8..39cd76a 100644 --- a/train.py +++ b/train.py @@ -225,6 +225,10 @@ def lr_lambda_fun(itn): # Validate and report automatic metrics. if args.validate: + + # switch dropout, batchnorm etc to the correct mode + model.eval() + print(f"\nValidation after epoch {epoch}:") for i, batch in enumerate(tqdm(val_dataloader)): for key in batch: @@ -243,4 +247,5 @@ def lr_lambda_fun(itn): print(f"{metric_name}: {metric_value}") summary_writer.add_scalars("metrics", all_metrics, global_iteration_step) + model.train() torch.cuda.empty_cache() From ae1b7d8cca37264d5095350b700beaca4f4a68c4 Mon Sep 17 00:00:00 2001 From: Abhishek Das Date: Sun, 10 Feb 2019 21:00:46 -0500 Subject: [PATCH 5/7] Sets default learning rate as 1e-2 for batch size = 128 --- configs/lf_disc_faster_rcnn_x101.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/configs/lf_disc_faster_rcnn_x101.yml b/configs/lf_disc_faster_rcnn_x101.yml index 2d81688..cb61504 100644 --- a/configs/lf_disc_faster_rcnn_x101.yml +++ b/configs/lf_disc_faster_rcnn_x101.yml @@ -27,12 +27,12 @@ model: solver: batch_size: 128 # 32 x num_gpus is a good rule of thumb num_epochs: 20 - initial_lr: 0.001 + initial_lr: 0.01 training_splits: "train" # "trainval" lr_gamma: 0.1 lr_milestones: # epochs when lr —> lr * lr_gamma - - 3 - - 5 + - 4 - 7 + - 10 warmup_factor: 0.2 - warmup_epochs: 1 \ No newline at end of file + warmup_epochs: 1 From 921d73e924361db9ec0319c459b913e69bb9aaff Mon Sep 17 00:00:00 2001 From: Abhishek Das Date: Sun, 10 Feb 2019 21:04:41 -0500 Subject: [PATCH 6/7] Implements point-wise mult + fc for attention than dot product --- visdialch/encoders/lf.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/visdialch/encoders/lf.py b/visdialch/encoders/lf.py index b440e8e..19fd879 100644 --- a/visdialch/encoders/lf.py +++ b/visdialch/encoders/lf.py @@ -39,6 +39,9 @@ def __init__(self, config, vocabulary): config["img_feature_size"], config["lstm_hidden_size"] ) + # fc layer for image * question to attention weights + self.attention_proj = nn.Linear(config["lstm_hidden_size"], 1) + # fusion layer (attended_image_features + question + history) fusion_size = config["img_feature_size"] + config["lstm_hidden_size"] * 2 self.fusion = nn.Linear(fusion_size, config["lstm_hidden_size"]) @@ -78,10 +81,14 @@ def forward(self, batch): batch_size * num_rounds, -1, self.config["lstm_hidden_size"] ) - # attend the features using question + # computing attention weights # shape: (batch_size * num_rounds, num_proposals) - image_attention_weights = projected_image_features.bmm( - ques_embed.unsqueeze(-1)).squeeze() + projected_ques_features = ques_embed.unsqueeze(1).repeat( + 1, img.shape[1], 1) + projected_ques_image = projected_ques_features * projected_image_features + projected_ques_image = self.dropout(projected_ques_image) + image_attention_weights = self.attention_proj( + projected_ques_image).squeeze() image_attention_weights = F.softmax(image_attention_weights, dim=-1) # shape: (batch_size * num_rounds, num_proposals, img_features_size) @@ -105,7 +112,7 @@ def forward(self, batch): hist_embed = self.word_embed(hist) # shape: (batch_size * num_rounds, 
lstm_hidden_size) - _ , (hist_embed, _) = self.hist_rnn(hist_embed, batch["hist_len"]) + _, (hist_embed, _) = self.hist_rnn(hist_embed, batch["hist_len"]) fused_vector = torch.cat((img, ques_embed, hist_embed), 1) fused_vector = self.dropout(fused_vector) From 711689c53f64574632438b826d1763590a71de85 Mon Sep 17 00:00:00 2001 From: Karan Desai Date: Mon, 11 Feb 2019 16:26:25 -0500 Subject: [PATCH 7/7] Update docstring of lr_lambda_fun and fix epoch enumeration. --- train.py | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/train.py b/train.py index 39cd76a..b4d5fd9 100644 --- a/train.py +++ b/train.py @@ -77,6 +77,7 @@ torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True + # ================================================================================================ # INPUT ARGUMENTS AND CONFIG # ================================================================================================ @@ -96,7 +97,7 @@ # ================================================================================================ -# SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER +# SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER, SCHEDULER # ================================================================================================ train_dataset = VisDialDataset( @@ -127,7 +128,7 @@ if -1 not in args.gpu_ids: model = nn.DataParallel(model, args.gpu_ids) -# loss +# Loss function. criterion = nn.CrossEntropyLoss() if config["solver"]["training_splits"] == "trainval": @@ -136,28 +137,24 @@ iterations = len(train_dataset) // config["solver"]["batch_size"] + 1 -def lr_lambda_fun(itn): +def lr_lambda_fun(current_iteration: int) -> float: """Returns a learning rate multiplier. Till `warmup_epochs`, learning rate linearly increases to `initial_lr`, and then gets multiplied by `lr_gamma` every time a milestone is crossed. - - Args: - itn: training iteration - Returns: - learning rate multiplier """ - cur_epoch = float(itn) / iterations - if cur_epoch <= config["solver"]["warmup_epochs"]: - alpha = cur_epoch / float(config["solver"]["warmup_epochs"]) + current_epoch = float(current_iteration) / iterations + if current_epoch <= config["solver"]["warmup_epochs"]: + alpha = current_epoch / float(config["solver"]["warmup_epochs"]) return config["solver"]["warmup_factor"] * (1. - alpha) + alpha else: - idx = bisect(config["solver"]["lr_milestones"], cur_epoch) + idx = bisect(config["solver"]["lr_milestones"], current_epoch) return pow(config["solver"]["lr_gamma"], idx) optimizer = optim.Adamax(model.parameters(), lr=config["solver"]["initial_lr"]) scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_fun) + # ================================================================================================ # SETUP BEFORE TRAINING LOOP # ================================================================================================ @@ -186,10 +183,10 @@ def lr_lambda_fun(itn): # TRAINING LOOP # ================================================================================================ -# Forever increasing counter keeping track of iterations completed. +# Forever increasing counter keeping track of iterations completed (for tensorboard logging). 
global_iteration_step = start_epoch * iterations -for epoch in range(start_epoch, config["solver"]["num_epochs"] + 1): +for epoch in range(start_epoch, config["solver"]["num_epochs"]): # -------------------------------------------------------------------------------------------- # ON EPOCH START (combine dataloaders if training on train + val) @@ -213,9 +210,8 @@ def lr_lambda_fun(itn): summary_writer.add_scalar("train/loss", batch_loss, global_iteration_step) summary_writer.add_scalar("train/lr", optimizer.param_groups[0]["lr"], global_iteration_step) - global_iteration_step += 1 scheduler.step(global_iteration_step) - + global_iteration_step += 1 torch.cuda.empty_cache() # -------------------------------------------------------------------------------------------- @@ -226,7 +222,7 @@ def lr_lambda_fun(itn): # Validate and report automatic metrics. if args.validate: - # switch dropout, batchnorm etc to the correct mode + # Switch dropout, batchnorm etc to the correct mode. model.eval() print(f"\nValidation after epoch {epoch}:")
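
The warmup + multi-step schedule that patches 2, 5 and 7 introduce can be exercised on its own. Below is a minimal standalone sketch, reusing the milestone, warmup and learning-rate values from configs/lf_disc_faster_rcnn_x101.yml; the number of batches per epoch is an assumed placeholder, not a value from the repository.

    from bisect import bisect

    import torch.nn as nn
    import torch.optim as optim
    from torch.optim import lr_scheduler

    # Config values from configs/lf_disc_faster_rcnn_x101.yml after patch 5.
    initial_lr = 0.01
    warmup_factor = 0.2
    warmup_epochs = 1
    lr_gamma = 0.1
    lr_milestones = [4, 7, 10]

    # Assumed placeholder: batches per epoch (train.py computes this from the dataset size).
    iterations = 1000


    def lr_lambda_fun(current_iteration: int) -> float:
        """Multiplier on initial_lr at a given training iteration."""
        current_epoch = float(current_iteration) / iterations
        if current_epoch <= warmup_epochs:
            # Linear warmup from warmup_factor * initial_lr up to initial_lr.
            alpha = current_epoch / float(warmup_epochs)
            return warmup_factor * (1.0 - alpha) + alpha
        # After warmup, decay by lr_gamma once per milestone already crossed.
        return pow(lr_gamma, bisect(lr_milestones, current_epoch))


    model = nn.Linear(8, 8)  # stand-in model for illustration only
    optimizer = optim.Adamax(model.parameters(), lr=initial_lr)
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_fun)

    # Multiplier at the start of selected epochs:
    # epoch 0 -> 0.2, epoch 1 -> 1.0, epoch 5 -> 0.1, epoch 8 -> 0.01, epoch 11 -> 0.001
    for epoch in (0, 1, 5, 8, 11):
        print(epoch, lr_lambda_fun(epoch * iterations))

With these values the learning rate warms up linearly from 0.002 to 0.01 over the first epoch, then drops by a factor of 10 after epochs 4, 7 and 10.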
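
Similarly, a minimal sketch of the attention that patch 6 adds to visdialch/encoders/lf.py, where a dot product between projected image features and the question embedding is replaced by an element-wise product followed by a linear layer that produces one logit per proposal. The tensor sizes here are assumed placeholders.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Assumed placeholder sizes, not values read from the config.
    batch_times_rounds, num_proposals = 4, 36
    img_feature_size, lstm_hidden_size = 2048, 512

    image_features = torch.randn(batch_times_rounds, num_proposals, img_feature_size)
    ques_embed = torch.randn(batch_times_rounds, lstm_hidden_size)

    image_features_projection = nn.Linear(img_feature_size, lstm_hidden_size)
    attention_proj = nn.Linear(lstm_hidden_size, 1)  # fc layer producing one logit per proposal
    dropout = nn.Dropout(p=0.5)

    # shape: (batch_times_rounds, num_proposals, lstm_hidden_size)
    projected_image_features = image_features_projection(image_features)

    # Repeat the question embedding once per proposal, combine by element-wise product.
    projected_ques_features = ques_embed.unsqueeze(1).repeat(1, num_proposals, 1)
    projected_ques_image = dropout(projected_ques_features * projected_image_features)

    # shape: (batch_times_rounds, num_proposals)
    image_attention_weights = F.softmax(attention_proj(projected_ques_image).squeeze(-1), dim=-1)

    # Attention-weighted sum over proposals, shape: (batch_times_rounds, img_feature_size)
    attended_image_features = (image_attention_weights.unsqueeze(-1) * image_features).sum(dim=1)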