33 | 33 |
34 | 34 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
35 | 35 | # data
36 |    | - parser.add_argument('--train_csv', type=str, default=None, help='csv file with training samples')
37 |    | - parser.add_argument('--val_csv', type=str, default=None, help='csv file with validation samples')
38 |    | - parser.add_argument('--category_csv', default=None, type=str, help='csv file with category names')
39 |    | - parser.add_argument('--save_dir', default=None, type=str, help='save directory')
   | 36 | + parser.add_argument('--train_csv', type=str, default=None,
   | 37 | +                     help='csv file with training samples')
   | 38 | + parser.add_argument('--val_csv', type=str, default=None,
   | 39 | +                     help='csv file with validation samples')
   | 40 | + parser.add_argument('--category_csv', default=None, type=str,
   | 41 | +                     help='csv file with category names')
   | 42 | + parser.add_argument('--save_dir', default=None, type=str,
   | 43 | +                     help='save directory')
40 | 44 | # model
41 | 45 | parser = nf.add_model_parser_arguments(parser)
42 | 46 | # optim
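`nf.add_model_parser_arguments(parser)` (line 45) delegates the model options to the `nf` module, which is outside this diff. A minimal sketch of that pattern, assuming only that the helper registers options on the parser it receives and returns it; the option names below are illustrative, not the repo's:

    import argparse

    def add_model_parser_arguments(parser):
        # hypothetical model options; the real ones live in the nf module
        group = parser.add_argument_group('model')
        group.add_argument('--arch', type=str, default='resnet50',
                           help='model architecture (illustrative)')
        return parser

    parser = add_model_parser_arguments(argparse.ArgumentParser())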
... | ... |
47 | 51 | parser.add_argument('--scheduler_gamma', type=float, default=0.1,
48 | 52 |                     help='multiplicative factor of learning rate decay')
49 | 53 |
50 |    | - parser.add_argument('--train_steps', type=int, default=-
51 |    | - 1, help='number of iterations on training data for one epoch')
   | 54 | + parser.add_argument('--train_steps', type=int, default=-1,
   | 55 | +                     help='number of iterations on training data for one epoch')
52 | 56 | parser.add_argument('--train_epochs', type=int, default=1,
53 | 57 |                     help='number of epochs on training data')
54 |    | - parser.add_argument(
55 |    | -     '--train_layers_epochs',
56 |    | -     type=int,
57 |    | -     default=-1,
58 |    | -     help='number of epochs to train selected layers before switching to training all layers')
59 |    | - parser.add_argument(
60 |    | -     '--sampler_type',
61 |    | -     type=str,
62 |    | -     default='rnd',
63 |    | -     choices=[
64 |    | -         'rnd',
65 |    | -         'wc_rnd',
66 |    | -         'seq'],
67 |    | -     help='type of the training data sampler')
68 |    | - parser.add_argument('--sampler_persistent', action='store_true', help='use a persistent sampler')
69 |    | - parser.add_argument('--start_epoch', type=int, default=1, help='index of the start epoch')
70 |    | - parser.add_argument('--input_aug', action='store_true', help='enable input augmentation')
71 |    | - parser.add_argument('--color_aug', action='store_true', help='enable color augmentation')
72 |    | - parser.add_argument('--batch_size', type=int, default=64, help='batch size for training')
73 |    | - parser.add_argument(
74 |    | -     '--weighted_pos',
75 |    | -     action='store_true',
76 |    | -     help='weight positive samples for each class in balance with negatives')
   | 58 | + parser.add_argument('--train_layers_epochs', type=int, default=-1,
   | 59 | +                     help='number of epochs to train selected layers before switching to training all layers')
   | 60 | + parser.add_argument('--sampler_type', type=str, default='rnd',
   | 61 | +                     choices=['rnd', 'wc_rnd', 'seq'],
   | 62 | +                     help='type of the training data sampler')
   | 63 | + parser.add_argument('--sampler_persistent', action='store_true',
   | 64 | +                     help='use a persistent sampler')
   | 65 | + parser.add_argument('--start_epoch', type=int, default=1,
   | 66 | +                     help='index of the start epoch')
   | 67 | + parser.add_argument('--input_aug', action='store_true',
   | 68 | +                     help='enable input augmentation')
   | 69 | + parser.add_argument('--color_aug', action='store_true',
   | 70 | +                     help='enable color augmentation')
   | 71 | + parser.add_argument('--batch_size', type=int, default=64,
   | 72 | +                     help='batch size for training')
   | 73 | + parser.add_argument('--weighted_pos', action='store_true',
   | 74 | +                     help='weight positive samples for each class in balance with negatives')
77 | 75 | parser.add_argument('--weighted_pos_max', type=float, default=None,
78 | 76 |                     help='maximum weight of positive samples for all class')
79 |    | - parser.add_argument('--eval_steps', type=int, default=-
80 |    | - 1, help='number of iterations on evaluation data for one epoch')
81 |    | - parser.add_argument('--eval_batch_size', type=int, default=128, help='batch size for evaluation')
   | 77 | + parser.add_argument('--eval_steps', type=int, default=-1,
   | 78 | +                     help='number of iterations on evaluation data for one epoch')
   | 79 | + parser.add_argument('--eval_batch_size', type=int, default=128,
   | 80 | +                     help='batch size for evaluation')
82 | 81 | parser.add_argument('--best_metric', type=str, default='AUC',
83 | 82 |                     help='the evaluation metric used to select best model')
84 | 83 | parser.add_argument('--num_workers', type=int, default=0,
85 | 84 |                     help='number of workers for data loader')
86 |    | - parser.add_argument('--seed', type=int, default=-1, help='set random seed')
87 |    | - parser.add_argument('--no_gpu', action='store_true', help='do not use GPUs')
   | 85 | + parser.add_argument('--seed', type=int, default=-1,
   | 86 | +                     help='set random seed')
   | 87 | + parser.add_argument('--no_gpu', action='store_true',
   | 88 | +                     help='do not use GPUs')
88 | 89 | # logging
89 |    | - parser.add_argument(
90 |    | -     '--log_level',
91 |    | -     type=str,
92 |    | -     default=logging.INFO)
   | 90 | + parser.add_argument('--log_level', type=str, default=logging.INFO)
93 | 91 | parser.add_argument('--log_interval', type=int, default=100,
94 | 92 |                     help='logging interval in terms of iterations')
95 | 93 |
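Two argparse details in this hunk are worth noting. `ArgumentDefaultsHelpFormatter` (line 34) appends each default value to the `--help` text, so the defaults spelled out above all surface to the user. And argparse applies the `type` callable only to defaults given as strings, so `--log_level` keeps the integer `logging.INFO` (20) as its default while a value passed on the command line arrives as `str`. A standalone check of that second behavior:

    import argparse
    import logging

    p = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('--log_level', type=str, default=logging.INFO)

    print(type(p.parse_args([]).log_level))                        # <class 'int'>
    print(type(p.parse_args(['--log_level', 'DEBUG']).log_level))  # <class 'str'>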
... | ... |
138 | 136 |
139 | 137 | # create a dataset
140 | 138 | logger.info('Create a training EmojiDataset')
141 |     | - train_ds = datasets.EmojiDataset(categories_list=categories_list, samples_csv_file=opt.train_csv,
142 |     | -                                  input_transform=image_transform_train, target_transform=label_transform, suppress_exceptions=True)
    | 139 | + train_ds = datasets.EmojiDataset(categories_list=categories_list,
    | 140 | +                                  samples_csv_file=opt.train_csv,
    | 141 | +                                  input_transform=image_transform_train,
    | 142 | +                                  target_transform=label_transform,
    | 143 | +                                  suppress_exceptions=True)
143 | 144 | logger.info('Number of samples in training file: {}'.format(train_ds.n_samples))
144 | 145 | logger.info('Create a validation EmojiDataset')
145 |     | - valid_ds = datasets.EmojiDataset(categories_list=categories_list, samples_csv_file=opt.val_csv,
146 |     | -                                  input_transform=image_transform_eval, target_transform=label_transform, suppress_exceptions=True)
    | 146 | + valid_ds = datasets.EmojiDataset(categories_list=categories_list,
    | 147 | +                                  samples_csv_file=opt.val_csv,
    | 148 | +                                  input_transform=image_transform_eval,
    | 149 | +                                  target_transform=label_transform,
    | 150 | +                                  suppress_exceptions=True)
147 | 151 | logger.info('Number of samples in validation file: {}'.format(valid_ds.n_samples))
148 | 152 |
149 | 153 | # create data samplers
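The sampler construction itself (old lines 150-173) is collapsed out of this diff; only the `--sampler_type` choices `'rnd'`, `'wc_rnd'`, and `'seq'` are visible above. A hedged sketch of how those choices could map onto stock PyTorch samplers, assuming `wc_rnd` stands for weighted (class-balanced) random sampling; the repo's actual construction may differ:

    from torch.utils.data import (RandomSampler, SequentialSampler,
                                  WeightedRandomSampler)

    def make_sampler(sampler_type, dataset, sample_weights=None):
        # 'rnd': uniform shuffling; 'seq': fixed order; 'wc_rnd': draw with
        # per-sample weights (assumed to balance classes), with replacement
        if sampler_type == 'rnd':
            return RandomSampler(dataset)
        if sampler_type == 'seq':
            return SequentialSampler(dataset)
        if sampler_type == 'wc_rnd':
            return WeightedRandomSampler(sample_weights, num_samples=len(dataset),
                                         replacement=True)
        raise ValueError('unknown sampler_type: {}'.format(sampler_type))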
... | ... |
174 | 178 |                          torch.zeros(n_categories))])
175 | 179 | # create loaders
176 | 180 | logger.info('Create data loaders')
177 |     | - train_dataloader = torch.utils.data.DataLoader(train_ds, sampler=train_sampler,
178 |     | -                                                batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_workers, collate_fn=collate_fn)
179 |     | - eval_dataloader = torch.utils.data.DataLoader(valid_ds, sampler=valid_sampler,
180 |     | -                                               batch_size=opt.eval_batch_size, shuffle=False, num_workers=opt.num_workers, collate_fn=collate_fn, drop_last=False)
    | 181 | + train_dataloader = torch.utils.data.DataLoader(train_ds,
    | 182 | +                                                sampler=train_sampler,
    | 183 | +                                                batch_size=opt.batch_size,
    | 184 | +                                                shuffle=False,
    | 185 | +                                                num_workers=opt.num_workers,
    | 186 | +                                                collate_fn=collate_fn)
    | 187 | + eval_dataloader = torch.utils.data.DataLoader(valid_ds,
    | 188 | +                                               sampler=valid_sampler,
    | 189 | +                                               batch_size=opt.eval_batch_size,
    | 190 | +                                               shuffle=False,
    | 191 | +                                               num_workers=opt.num_workers,
    | 192 | +                                               collate_fn=collate_fn,
    | 193 | +                                               drop_last=False)
181 | 194 |
182 | 195 | # model
183 | 196 | logger.info('=' * 25)
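Both loaders above pass `shuffle=False` because DataLoader rejects `shuffle=True` combined with an explicit `sampler`. The custom `collate_fn` is not shown in this diff; a common pairing with `suppress_exceptions=True` is a collate step that drops samples the dataset failed to load, sketched here under the assumption that such samples come back as `None` (`default_collate` is public in torch >= 1.11):

    import torch

    def collate_fn(batch):
        # drop entries the dataset could not load (assumed to be returned as None)
        batch = [item for item in batch if item is not None]
        return torch.utils.data.default_collate(batch)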
... | ... |
198 | 211 | if opt.scheduler_step_size > 0:
199 | 212 |     logger.info('setup learning rate scheduler')
200 | 213 |     lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
201 |     | -                                                    step_size=opt.scheduler_step_size, gamma=opt.scheduler_gamma, last_epoch=-1)
    | 214 | +                                                    step_size=opt.scheduler_step_size,
    | 215 | +                                                    gamma=opt.scheduler_gamma,
    | 216 | +                                                    last_epoch=-1)
202 | 217 |
203 | 218 | # Loss
204 | 219 | logger.info('setup loss')
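`StepLR`, as configured above, multiplies the learning rate by `gamma` every `step_size` epochs, and `last_epoch=-1` starts the schedule from scratch. A toy run with illustrative values (`step_size=2`, `gamma=0.1`), stepping the scheduler once per epoch:

    import torch

    params = [torch.zeros(1, requires_grad=True)]
    optimizer = torch.optim.SGD(params, lr=0.01)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2,
                                                gamma=0.1, last_epoch=-1)
    for epoch in range(6):
        # lr per epoch: 0.01, 0.01, 0.001, 0.001, 1e-4, 1e-4 (up to float rounding)
        print(epoch, scheduler.get_last_lr())
        optimizer.step()
        scheduler.step()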
... | ... |
279 | 294 | if trainer.lr_scheduler is not None:
280 | 295 |     logger.info('reset scheduler')
281 | 296 |     trainer.lr_scheduler = torch.optim.lr_scheduler.StepLR(trainer.optimizer,
282 |     | -                                                            step_size=opt.scheduler_step_size, gamma=opt.scheduler_gamma, last_epoch=-1)
    | 297 | +                                                            step_size=opt.scheduler_step_size,
    | 298 | +                                                            gamma=opt.scheduler_gamma,
    | 299 | +                                                            last_epoch=-1)
283 | 300 |
284 | 301 | logger.info('Run time: {}'.format(datetime.now() - tm_start))
285 | 302 | if log_file is not None and os.path.exists(log_file):