-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_pretrain.py
267 lines (222 loc) · 10.7 KB
/
main_pretrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
import argparse
import json
import os
import random
import shutil
import time
import datetime
import numpy as np
import torch
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from data.datasets import ImageFolder
from data.transforms import CustomDataAugmentation
from models import resnet
from models import vision_transformer as vit
from models.slotcon import SlotCon
from utils.lars import LARS
from utils.logger import setup_logger
from utils.lr_scheduler import get_scheduler
from utils.util import AverageMeter
# Selectable backbone names: every lowercase, callable entry that each
# backbone module lists in its ``__all__``.
model_names = sorted(
    [name for name in resnet.__all__ if name.islower() and callable(vars(resnet)[name])]
)
vit_names = sorted(
    [name for name in vit.__all__ if name.islower() and callable(vars(vit)[name])]
)
# ResNet names first, then the ViT names (each group is sorted on its own).
model_names += vit_names
def get_parser(argv=None):
    """Parse the SlotCon pre-training command-line options.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            original zero-argument call did.

    Returns:
        argparse.Namespace: The parsed options, with ``local_rank`` added from
        the ``LOCAL_RANK`` environment variable (0 when the variable is unset).
    """
    parser = argparse.ArgumentParser('SlotCon')

    # dataset
    parser.add_argument('--dataset', type=str, default='COCO', choices=['COCO', 'COCOplus', 'ImageNet'], help='dataset type')
    parser.add_argument('--data-dir', type=str, default='./data', help='dataset director')
    parser.add_argument('--image-size', type=int, default=224, help='image crop size')
    parser.add_argument('--min-scale', type=float, default=0.08, help='minimum crop scale')

    # model
    parser.add_argument('--arch', type=str, default='resnet50', choices=model_names, help='backbone architecture')
    parser.add_argument('--dim-hidden', type=int, default=4096, help='hidden dimension')
    parser.add_argument('--dim-out', type=int, default=256, help='output feature dimension')
    parser.add_argument('--num-prototypes', type=int, default=256, help='number of prototypes')
    parser.add_argument('--teacher-momentum', default=0.99, type=float, help='momentum value for the teacher model')
    parser.add_argument('--teacher-temp', default=0.07, type=float, help='teacher temperature')
    parser.add_argument('--student-temp', default=0.1, type=float, help='student temperature')
    parser.add_argument('--center-momentum', default=0.9, type=float, help='momentum for the center')
    parser.add_argument('--group-loss-weight', default=0.5, type=float, help='balancing weight of the grouping loss')

    # optim.
    parser.add_argument('--batch-size', type=int, default=512, help='total batch size')
    parser.add_argument('--base-lr', type=float, default=1.0,
                        help='base learning when batch size = 256. final lr is determined by linear scale')
    parser.add_argument('--optimizer', type=str, choices=['sgd', 'lars', 'adamw'], default='sgd', help='optimizer choice')
    parser.add_argument('--warmup-epoch', type=int, default=5, help='warmup epoch')
    parser.add_argument('--warmup-multiplier', type=int, default=100, help='warmup multiplier')
    parser.add_argument('--weight-decay', type=float, default=1e-5, help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum for SGD')
    parser.add_argument('--fp16', action='store_true', default=True, help='whether or not to turn on automatic mixed precision')
    # ``--fp16`` defaults to True and is store_true, so on its own it can never
    # be switched off; ``--no-fp16`` writes False to the same destination.
    parser.add_argument('--no-fp16', dest='fp16', action='store_false', help='turn off automatic mixed precision')
    parser.add_argument('--start-epoch', type=int, default=1, help='used for resume')
    parser.add_argument('--epochs', type=int, default=800, help='number of training epochs')

    # misc
    parser.add_argument('--output-dir', type=str, default='./output', help='output director')
    parser.add_argument('--auto-resume', action='store_true', help='auto resume from current.pth')
    parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to the latest checkpoint')
    parser.add_argument('--print-freq', type=int, default=10, help='print frequency')
    parser.add_argument('--save-freq', type=int, default=50, help='save frequency')
    parser.add_argument('--seed', type=int, help='Random seed.')
    parser.add_argument('--num-workers', type=int, default=8, help='num of workers per GPU to use')

    args = parser.parse_args(argv)

    # Indexing os.environ directly raises KeyError when LOCAL_RANK is missing,
    # so the old ``is not None`` test could never take its false branch.
    # Fall back to rank 0 so ``args.local_rank`` is always defined.
    args.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    return args
def build_model(args):
    """Create the SlotCon model and the optimizer selected by ``args.optimizer``.

    The encoder factory is looked up by ``args.arch`` in the ViT module when
    the name contains ``'vit'`` and in the ResNet module otherwise. The
    learning rate is ``args.batch_size * args.world_size / 256 * args.base_lr``.

    Returns:
        tuple: ``(model, optimizer)`` where ``model`` is the CUDA model wrapped
        in ``DistributedDataParallel`` on device ``args.local_rank``.

    Raises:
        NotImplementedError: If ``args.optimizer`` is not one of
            ``'sgd'``, ``'lars'`` or ``'adamw'``.
    """
    backbone_module = vit if 'vit' in args.arch else resnet
    encoder = backbone_module.__dict__[args.arch]
    model = SlotCon(encoder, args).cuda()

    optimizer_name = args.optimizer
    if optimizer_name not in ('sgd', 'lars', 'adamw'):
        raise NotImplementedError

    # Linear scaling rule: base_lr is defined for a total batch size of 256.
    scaled_lr = args.batch_size * args.world_size / 256 * args.base_lr
    shared_kwargs = dict(lr=scaled_lr, weight_decay=args.weight_decay)

    if optimizer_name == 'adamw':
        # AdamW is the only choice that takes no momentum argument here.
        optimizer = torch.optim.AdamW(model.parameters(), **shared_kwargs)
    else:
        optimizer_cls = torch.optim.SGD if optimizer_name == 'sgd' else LARS
        optimizer = optimizer_cls(model.parameters(), momentum=args.momentum, **shared_kwargs)

    # The optimizer is built from the bare model's parameters before the DDP wrap.
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
    return model, optimizer
def save_checkpoint(args, epoch, model, optimizer, scheduler, scaler=None):
    """Write the training state for ``epoch`` into ``args.output_dir``.

    The state is saved as ``ckpt_epoch_{epoch}.pth`` and then copied to
    ``current.pth`` in the same directory.

    Args:
        args: Run configuration; stored in the checkpoint under ``'args'``.
            ``args.output_dir`` and ``args.fp16`` are read here.
        epoch: Epoch number recorded in the checkpoint and in the file name.
        model: Object whose ``state_dict()`` is stored under ``'model'``.
        optimizer: Object whose ``state_dict()`` is stored under ``'optimizer'``.
        scheduler: Object whose ``state_dict()`` is stored under ``'scheduler'``.
        scaler: Optional gradient scaler. Its state is stored under
            ``'scaler'`` only when ``args.fp16`` is set and a scaler is given.
    """
    logger.info('==> Saving...')
    state = {
        'args': args,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'epoch': epoch,
    }
    # ``scaler`` defaults to None; calling state_dict() on it would raise
    # AttributeError, so only record a scaler that actually exists.
    if args.fp16 and scaler is not None:
        state['scaler'] = scaler.state_dict()

    file_name = os.path.join(args.output_dir, f'ckpt_epoch_{epoch}.pth')
    torch.save(state, file_name)
    # Keep a copy under a fixed name alongside the per-epoch file.
    shutil.copyfile(file_name, os.path.join(args.output_dir, 'current.pth'))
def load_checkpoint(args, model, optimizer, scheduler, scaler=None):
    """Restore training state from the checkpoint file at ``args.resume``.

    When the path is not an existing file, a message is logged and nothing is
    modified. Otherwise ``args.start_epoch`` is set to the stored epoch plus
    one and the model, optimizer and scheduler states are loaded in place.

    Args:
        args: Run configuration; ``args.resume`` and ``args.fp16`` are read
            and ``args.start_epoch`` is overwritten.
        model: Wrapped model; ``model.module.re_init(args)`` is called before
            its state dict is loaded.
        optimizer: Optimizer whose state is restored.
        scheduler: Scheduler whose state is restored.
        scaler: Optional gradient scaler. Its state is restored only when
            ``args.fp16`` is set, a scaler is given, and the checkpoint holds
            a ``'scaler'`` entry.
    """
    if not os.path.isfile(args.resume):
        logger.info("=> no checkpoint found at '{}'".format(args.resume))
        return

    logger.info(f"=> loading checkpoint '{args.resume}'")
    # NOTE(review): torch.load unpickles the file, and checkpoints written by
    # save_checkpoint include the argparse Namespace under 'args'. Only resume
    # from trusted files; PyTorch releases that default to weights_only=True
    # may refuse such a checkpoint — confirm against the installed version.
    checkpoint = torch.load(args.resume, map_location='cpu')
    args.start_epoch = checkpoint['epoch'] + 1

    model.module.re_init(args)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    # ``scaler`` defaults to None; skip it instead of raising AttributeError.
    if args.fp16 and scaler is not None and 'scaler' in checkpoint:
        scaler.load_state_dict(checkpoint['scaler'])

    logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
def main(args):
    """Run pre-training from ``args.start_epoch`` through ``args.epochs``.

    Builds the data loader, model, optimizer, scheduler and (when
    ``args.fp16`` is set) a gradient scaler, optionally resumes from a
    checkpoint, then trains one epoch at a time. Rank 0 saves a checkpoint
    every ``args.save_freq`` epochs and after the last epoch.
    """
    seed = args.seed
    if seed is not None:
        # Seed the Python, NumPy and PyTorch (CPU and CUDA) generators.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # prepare data
    augmentation = CustomDataAugmentation(args.image_size, args.min_scale)
    train_dataset = ImageFolder(args.dataset, args.data_dir, augmentation)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.num_workers,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=True,
    )
    args.num_instances = len(train_loader.dataset)
    logger.info(f"length of training dataset: {args.num_instances}")

    # create model
    logger.info("=> creating model '{}'".format(args.arch))
    model, optimizer = build_model(args)
    logger.info(model)

    # define scheduler
    scheduler = get_scheduler(optimizer, len(train_loader), args)

    # define scaler (only used for mixed-precision runs)
    scaler = torch.cuda.amp.GradScaler() if args.fp16 else None

    # optionally resume from a checkpoint
    if args.auto_resume:
        resume_file = os.path.join(args.output_dir, "current.pth")
        if not os.path.exists(resume_file):
            logger.info(f'no checkpoint found in {args.output_dir}, ignoring auto resume')
        else:
            logger.info(f'auto resume from {resume_file}')
            # Auto resume takes precedence over an explicit --resume path.
            args.resume = resume_file
    if args.resume:
        load_checkpoint(args, model, optimizer, scheduler, scaler)

    for epoch in range(args.start_epoch, args.epochs + 1):
        train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, optimizer, scaler, scheduler, epoch, args)

        # Only rank 0 writes checkpoints.
        if dist.get_rank() != 0:
            continue
        if epoch % args.save_freq == 0 or epoch == args.epochs:
            save_checkpoint(args, epoch, model, optimizer, scheduler, scaler)
def train(train_loader, model, optimizer, scaler, scheduler, epoch, args):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        train_loader: Iterable yielding ``(crops, coords, flags)`` batches,
            each an iterable of tensors that are moved to the GPU.
        model: Callable that takes the ``(crops, coords, flags)`` tuple and
            returns the loss.
        optimizer: Optimizer stepped once per batch.
        scaler: Gradient scaler, or ``None``. When given, the forward pass
            runs under autocast and the backward pass goes through the scaler.
        scheduler: Learning-rate scheduler, stepped once per batch.
        epoch: Current epoch number (used for logging only).
        args: Run configuration; ``print_freq`` and ``epochs`` are read.
    """
    batch_time = AverageMeter()
    loss_meter = AverageMeter()

    # One flag drives both autocast and loss scaling. Previously autocast
    # followed ``scaler is not None`` while scaling followed ``args.fp16``; if
    # the two disagreed the step either crashed on a None scaler or ran
    # autocast without loss scaling.
    use_amp = scaler is not None

    # switch to train mode
    model.train()

    end = time.time()
    train_len = len(train_loader)
    for i, batch in enumerate(train_loader):
        crops, coords, flags = batch
        crops = [crop.cuda(non_blocking=True) for crop in crops]
        coords = [coord.cuda(non_blocking=True) for coord in coords]
        flags = [flag.cuda(non_blocking=True) for flag in flags]

        # compute output and loss
        with torch.cuda.amp.autocast(use_amp):
            loss = model((crops, coords, flags))

        optimizer.zero_grad()
        if use_amp:
            scaler.scale(loss).backward()
            # Gradients are unscaled explicitly before the step, as in the
            # original code.
            scaler.unscale_(optimizer)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        scheduler.step()

        # avg loss from batch size
        loss_meter.update(loss.item(), crops[0].size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            lr = optimizer.param_groups[0]['lr']
            # Remaining time in this epoch, estimated from the mean batch time.
            etas = batch_time.avg * (train_len - i)
            logger.info(
                f'Train: [{epoch}/{args.epochs}][{i}/{train_len}]  '
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.4f}  '
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})  '
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})  ')
if __name__ == '__main__':
    args = get_parser()

    # Select this process's GPU, then join the NCCL process group using the
    # environment-variable rendezvous.
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    cudnn.benchmark = True

    # --batch-size is the total batch size; convert it to the per-process share.
    world_size = dist.get_world_size()
    args.world_size = world_size
    args.batch_size = int(args.batch_size / world_size)

    # setup logger
    os.makedirs(args.output_dir, exist_ok=True)
    rank = dist.get_rank()
    logger = setup_logger(output=args.output_dir, distributed_rank=rank, name="slotcon")

    # Only rank 0 writes the resolved configuration to disk.
    if rank == 0:
        path = os.path.join(args.output_dir, "config.json")
        with open(path, 'w') as f:
            json.dump(vars(args), f, indent=2)
        logger.info(f"Full config saved to {path}")

    # print args, one "key: value" line per option in sorted key order
    config_lines = [f"{k}: {v}" for k, v in sorted(vars(args).items())]
    logger.info("\n".join(config_lines))

    main(args)