Skip to content

Commit 05670fb

Browse files
author
Ziad Al-Halah
committed
fix minor formatting issues
1 parent d49cc0d commit 05670fb

File tree

3 files changed

+128
-104
lines changed

3 files changed

+128
-104
lines changed

scripts/model_predict.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,25 +32,29 @@
3232

3333
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
3434
# data
35-
parser.add_argument('--image_dir', type=str, default=None, help='directory containing images')
36-
parser.add_argument('--category_csv', type=str, default=None, help='csv file with category names')
37-
parser.add_argument('--save_dir', type=str, default=None, help='save directory')
35+
parser.add_argument('--image_dir', type=str, default=None,
36+
help='directory containing images')
37+
parser.add_argument('--category_csv', type=str, default=None,
38+
help='csv file with category names')
39+
parser.add_argument('--save_dir', type=str, default=None,
40+
help='save directory')
3841
# model
39-
parser.add_argument('--model_file', type=str, default=None, help='model file path')
40-
parser.add_argument('--batch_size', type=int, default=64, help='batch size for training')
42+
parser.add_argument('--model_file', type=str, default=None,
43+
help='model file path')
44+
parser.add_argument('--batch_size', type=int, default=64,
45+
help='batch size for training')
4146
parser.add_argument('--num_workers', type=int, default=0,
4247
help='number of workers for data loader')
43-
parser.add_argument('--seed', type=int, default=-1, help='set random seed')
44-
parser.add_argument('--no_gpu', action='store_true', help='do not use GPUs')
48+
parser.add_argument('--seed', type=int, default=-1,
49+
help='set random seed')
50+
parser.add_argument('--no_gpu', action='store_true',
51+
help='do not use GPUs')
4552
parser.add_argument('--image_size', type=int, default=256,
4653
help='image size for qualitative results')
4754
parser.add_argument('--predict_top_k', type=int, default=5,
4855
help='number of top predictions to save to file')
4956
# logging
50-
parser.add_argument(
51-
'--log_level',
52-
type=str,
53-
default=logging.INFO)
57+
parser.add_argument('--log_level', type=str, default=logging.INFO)
5458
parser.add_argument('--log_interval', type=int, default=100,
5559
help='logging interval in terms of iterations')
5660

@@ -108,8 +112,10 @@
108112

109113
# create a dataset
110114
logger.info('Create a testing EmojiDataset')
111-
test_ds = datasets.EmojiDataset(categories_list=categories_list, samples_csv_file=image_paths_file,
112-
input_transform=image_transform, suppress_exceptions=True)
115+
test_ds = datasets.EmojiDataset(categories_list=categories_list,
116+
samples_csv_file=image_paths_file,
117+
input_transform=image_transform,
118+
suppress_exceptions=True)
113119
logger.info('Number of samples in testing file: {}'.format(test_ds.n_samples))
114120

115121
# set batch collate
@@ -123,12 +129,11 @@
123129
torch.zeros(n_categories))])
124130
# create loaders
125131
logger.info('Create data loaders')
126-
test_dataloader = torch.utils.data.DataLoader(
127-
test_ds,
128-
batch_size=opt.batch_size,
129-
shuffle=False,
130-
num_workers=opt.num_workers,
131-
collate_fn=collate_fn)
132+
test_dataloader = torch.utils.data.DataLoader(test_ds,
133+
batch_size=opt.batch_size,
134+
shuffle=False,
135+
num_workers=opt.num_workers,
136+
collate_fn=collate_fn)
132137

133138
# model
134139
logger.info('=' * 25)

scripts/model_test.py

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,28 @@
3232

3333
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
3434
# data
35-
parser.add_argument('--test_csv', type=str, default=None, help='csv file with testing samples')
36-
parser.add_argument('--category_csv', type=str, default=None, help='csv file with category names')
37-
parser.add_argument('--save_dir', type=str, default=None, help='save directory')
35+
parser.add_argument('--test_csv', type=str, default=None,
36+
help='csv file with testing samples')
37+
parser.add_argument('--category_csv', type=str, default=None,
38+
help='csv file with category names')
39+
parser.add_argument('--save_dir', type=str, default=None,
40+
help='save directory')
3841
# model
39-
parser.add_argument('--model_file', type=str, default=None, help='model file path')
40-
parser.add_argument('--batch_size', type=int, default=64, help='batch size for training')
41-
parser.add_argument('--test_steps', type=int, default=-
42-
1, help='number of iterations on evaluation data for one epoch')
42+
parser.add_argument('--model_file', type=str, default=None,
43+
help='model file path')
44+
parser.add_argument('--batch_size', type=int, default=64,
45+
help='batch size for training')
46+
parser.add_argument('--test_steps', type=int, default=-1,
47+
help='number of iterations on evaluation data for one epoch')
4348
parser.add_argument('--num_workers', type=int, default=0,
4449
help='number of workers for data loader')
45-
parser.add_argument('--seed', type=int, default=-1, help='set random seed')
46-
parser.add_argument('--no_gpu', action='store_true', help='do not use GPUs')
50+
parser.add_argument('--seed', type=int, default=-1,
51+
help='set random seed')
52+
parser.add_argument('--no_gpu', action='store_true',
53+
help='do not use GPUs')
4754
# qualitative results
48-
parser.add_argument(
49-
'--no_qualitative',
50-
action='store_true',
51-
help='disable qualitative evaluation')
55+
parser.add_argument('--no_qualitative', action='store_true',
56+
help='disable qualitative evaluation')
5257
parser.add_argument('--per_class_samples', type=int, default=20,
5358
help='the number of most confident samples per class to visualize')
5459
parser.add_argument('--multilabel_samples', type=int, default=30,
@@ -58,10 +63,7 @@
5863
parser.add_argument('--image_size', type=int, default=256,
5964
help='image size for qualitative results')
6065
# logging
61-
parser.add_argument(
62-
'--log_level',
63-
type=str,
64-
default=logging.INFO)
66+
parser.add_argument('--log_level', type=str, default=logging.INFO)
6567
parser.add_argument('--log_interval', type=int, default=100,
6668
help='logging interval in terms of iterations')
6769

@@ -117,8 +119,11 @@
117119

118120
# create a dataset
119121
logger.info('Create a testing EmojiDataset')
120-
test_ds = datasets.EmojiDataset(categories_list=categories_list, samples_csv_file=opt.test_csv,
121-
input_transform=image_transform, target_transform=label_transform, suppress_exceptions=True)
122+
test_ds = datasets.EmojiDataset(categories_list=categories_list,
123+
samples_csv_file=opt.test_csv,
124+
input_transform=image_transform,
125+
target_transform=label_transform,
126+
suppress_exceptions=True)
122127
logger.info('Number of samples in testing file: {}'.format(test_ds.n_samples))
123128

124129
# set batch collate
@@ -132,12 +137,11 @@
132137
torch.zeros(n_categories))])
133138
# create loaders
134139
logger.info('Create data loaders')
135-
test_dataloader = torch.utils.data.DataLoader(
136-
test_ds,
137-
batch_size=opt.batch_size,
138-
shuffle=False,
139-
num_workers=opt.num_workers,
140-
collate_fn=collate_fn)
140+
test_dataloader = torch.utils.data.DataLoader(test_ds,
141+
batch_size=opt.batch_size,
142+
shuffle=False,
143+
num_workers=opt.num_workers,
144+
collate_fn=collate_fn)
141145

142146
# model
143147
logger.info('=' * 25)
@@ -146,7 +150,9 @@
146150
logger.info('in checkpoint: {}'.format(checkpoint.keys()))
147151
model_name = checkpoint['opt'].net_name
148152
logger.info('model type: {}'.format(model_name))
149-
model = nf.create_and_init_model(model_name, checkpoint['model_state'], output_size=n_categories)
153+
model = nf.create_and_init_model(model_name,
154+
checkpoint['model_state'],
155+
output_size=n_categories)
150156
# check if there is a gpu
151157
device = torch.device('cuda' if torch.cuda.is_available() and not opt.no_gpu else 'cpu')
152158
logger.info('using device: {}'.format(device))
@@ -169,17 +175,13 @@
169175
metrics.add_default_eval_metrics(tester, max_k=n_categories - 1)
170176
if not opt.no_qualitative:
171177
logger.info('add qualitative metrics')
172-
tester.add_metric(
173-
'TopBPred',
174-
metrics.TopBinaryPredictions(
175-
n_samples=opt.per_class_samples),
176-
eval=True)
177-
tester.add_metric(
178-
'TopMPred',
179-
metrics.TopMultiLabelPredictions(
180-
n_samples=opt.multilabel_samples,
181-
k=opt.multilabel_k),
182-
eval=True)
178+
tester.add_metric('TopBPred',
179+
metrics.TopBinaryPredictions(n_samples=opt.per_class_samples),
180+
eval=True)
181+
tester.add_metric('TopMPred',
182+
metrics.TopMultiLabelPredictions(n_samples=opt.multilabel_samples,
183+
k=opt.multilabel_k),
184+
eval=True)
183185

184186
# testing
185187
epoch = checkpoint['epoch']

scripts/model_train.py

Lines changed: 65 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,14 @@
3333

3434
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
3535
# data
36-
parser.add_argument('--train_csv', type=str, default=None, help='csv file with training samples')
37-
parser.add_argument('--val_csv', type=str, default=None, help='csv file with validation samples')
38-
parser.add_argument('--category_csv', default=None, type=str, help='csv file with category names')
39-
parser.add_argument('--save_dir', default=None, type=str, help='save directory')
36+
parser.add_argument('--train_csv', type=str, default=None,
37+
help='csv file with training samples')
38+
parser.add_argument('--val_csv', type=str, default=None,
39+
help='csv file with validation samples')
40+
parser.add_argument('--category_csv', default=None, type=str,
41+
help='csv file with category names')
42+
parser.add_argument('--save_dir', default=None, type=str,
43+
help='save directory')
4044
# model
4145
parser = nf.add_model_parser_arguments(parser)
4246
# optim
@@ -47,49 +51,43 @@
4751
parser.add_argument('--scheduler_gamma', type=float, default=0.1,
4852
help='multiplicative factor of learning rate decay')
4953

50-
parser.add_argument('--train_steps', type=int, default=-
51-
1, help='number of iterations on training data for one epoch')
54+
parser.add_argument('--train_steps', type=int, default=-1,
55+
help='number of iterations on training data for one epoch')
5256
parser.add_argument('--train_epochs', type=int, default=1,
5357
help='number of epochs on training data')
54-
parser.add_argument(
55-
'--train_layers_epochs',
56-
type=int,
57-
default=-1,
58-
help='number of epochs to train selected layers before switching to training all layers')
59-
parser.add_argument(
60-
'--sampler_type',
61-
type=str,
62-
default='rnd',
63-
choices=[
64-
'rnd',
65-
'wc_rnd',
66-
'seq'],
67-
help='type of the training data sampler')
68-
parser.add_argument('--sampler_persistent', action='store_true', help='use a persistent sampler')
69-
parser.add_argument('--start_epoch', type=int, default=1, help='index of the start epoch')
70-
parser.add_argument('--input_aug', action='store_true', help='enable input augmentation')
71-
parser.add_argument('--color_aug', action='store_true', help='enable color augmentation')
72-
parser.add_argument('--batch_size', type=int, default=64, help='batch size for training')
73-
parser.add_argument(
74-
'--weighted_pos',
75-
action='store_true',
76-
help='weight positive samples for each class in balance with negatives')
58+
parser.add_argument('--train_layers_epochs', type=int, default=-1,
59+
help='number of epochs to train selected layers before switching to training all layers')
60+
parser.add_argument('--sampler_type', type=str, default='rnd',
61+
choices=['rnd', 'wc_rnd', 'seq'],
62+
help='type of the training data sampler')
63+
parser.add_argument('--sampler_persistent', action='store_true',
64+
help='use a persistent sampler')
65+
parser.add_argument('--start_epoch', type=int, default=1,
66+
help='index of the start epoch')
67+
parser.add_argument('--input_aug', action='store_true',
68+
help='enable input augmentation')
69+
parser.add_argument('--color_aug', action='store_true',
70+
help='enable color augmentation')
71+
parser.add_argument('--batch_size', type=int, default=64,
72+
help='batch size for training')
73+
parser.add_argument('--weighted_pos', action='store_true',
74+
help='weight positive samples for each class in balance with negatives')
7775
parser.add_argument('--weighted_pos_max', type=float, default=None,
7876
help='maximum weight of positive samples for all class')
79-
parser.add_argument('--eval_steps', type=int, default=-
80-
1, help='number of iterations on evaluation data for one epoch')
81-
parser.add_argument('--eval_batch_size', type=int, default=128, help='batch size for evaluation')
77+
parser.add_argument('--eval_steps', type=int, default=-1,
78+
help='number of iterations on evaluation data for one epoch')
79+
parser.add_argument('--eval_batch_size', type=int, default=128,
80+
help='batch size for evaluation')
8281
parser.add_argument('--best_metric', type=str, default='AUC',
8382
help='the evaluation metric used to select best model')
8483
parser.add_argument('--num_workers', type=int, default=0,
8584
help='number of workers for data loader')
86-
parser.add_argument('--seed', type=int, default=-1, help='set random seed')
87-
parser.add_argument('--no_gpu', action='store_true', help='do not use GPUs')
85+
parser.add_argument('--seed', type=int, default=-1,
86+
help='set random seed')
87+
parser.add_argument('--no_gpu', action='store_true',
88+
help='do not use GPUs')
8889
# logging
89-
parser.add_argument(
90-
'--log_level',
91-
type=str,
92-
default=logging.INFO)
90+
parser.add_argument('--log_level', type=str, default=logging.INFO)
9391
parser.add_argument('--log_interval', type=int, default=100,
9492
help='logging interval in terms of iterations')
9593

@@ -138,12 +136,18 @@
138136

139137
# create a dataset
140138
logger.info('Create a training EmojiDataset')
141-
train_ds = datasets.EmojiDataset(categories_list=categories_list, samples_csv_file=opt.train_csv,
142-
input_transform=image_transform_train, target_transform=label_transform, suppress_exceptions=True)
139+
train_ds = datasets.EmojiDataset(categories_list=categories_list,
140+
samples_csv_file=opt.train_csv,
141+
input_transform=image_transform_train,
142+
target_transform=label_transform,
143+
suppress_exceptions=True)
143144
logger.info('Number of samples in training file: {}'.format(train_ds.n_samples))
144145
logger.info('Create a validation EmojiDataset')
145-
valid_ds = datasets.EmojiDataset(categories_list=categories_list, samples_csv_file=opt.val_csv,
146-
input_transform=image_transform_eval, target_transform=label_transform, suppress_exceptions=True)
146+
valid_ds = datasets.EmojiDataset(categories_list=categories_list,
147+
samples_csv_file=opt.val_csv,
148+
input_transform=image_transform_eval,
149+
target_transform=label_transform,
150+
suppress_exceptions=True)
147151
logger.info('Number of samples in validation file: {}'.format(valid_ds.n_samples))
148152

149153
# create data samplers
@@ -174,10 +178,19 @@
174178
torch.zeros(n_categories))])
175179
# create loaders
176180
logger.info('Create data loaders')
177-
train_dataloader = torch.utils.data.DataLoader(train_ds, sampler=train_sampler,
178-
batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_workers, collate_fn=collate_fn)
179-
eval_dataloader = torch.utils.data.DataLoader(valid_ds, sampler=valid_sampler,
180-
batch_size=opt.eval_batch_size, shuffle=False, num_workers=opt.num_workers, collate_fn=collate_fn, drop_last=False)
181+
train_dataloader = torch.utils.data.DataLoader(train_ds,
182+
sampler=train_sampler,
183+
batch_size=opt.batch_size,
184+
shuffle=False,
185+
num_workers=opt.num_workers,
186+
collate_fn=collate_fn)
187+
eval_dataloader = torch.utils.data.DataLoader(valid_ds,
188+
sampler=valid_sampler,
189+
batch_size=opt.eval_batch_size,
190+
shuffle=False,
191+
num_workers=opt.num_workers,
192+
collate_fn=collate_fn,
193+
drop_last=False)
181194

182195
# model
183196
logger.info('=' * 25)
@@ -198,7 +211,9 @@
198211
if opt.scheduler_step_size > 0:
199212
logger.info('setup learning rate scheduler')
200213
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
201-
step_size=opt.scheduler_step_size, gamma=opt.scheduler_gamma, last_epoch=-1)
214+
step_size=opt.scheduler_step_size,
215+
gamma=opt.scheduler_gamma,
216+
last_epoch=-1)
202217

203218
# Loss
204219
logger.info('setup loss')
@@ -279,7 +294,9 @@
279294
if trainer.lr_scheduler is not None:
280295
logger.info('reset scheduler')
281296
trainer.lr_scheduler = torch.optim.lr_scheduler.StepLR(trainer.optimizer,
282-
step_size=opt.scheduler_step_size, gamma=opt.scheduler_gamma, last_epoch=-1)
297+
step_size=opt.scheduler_step_size,
298+
gamma=opt.scheduler_gamma,
299+
last_epoch=-1)
283300

284301
logger.info('Run time: {}'.format(datetime.now() - tm_start))
285302
if log_file is not None and os.path.exists(log_file):

0 commit comments

Comments
 (0)