-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
444 lines (376 loc) · 18.3 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import torchvision
import xml.etree.ElementTree as ET
from tqdm import tqdm_notebook as tqdm
import time
import argparse
from collections import OrderedDict
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.utils as vutils
import numpy as np
import cv2
import matplotlib.animation as animation
from IPython.display import HTML
import sys
from models import generators,discriminators
def prepare_parser():
    """Build and return the argparse parser with all training options.

    Fixes several typos in the user-facing help strings; all option names,
    types and defaults are unchanged.
    """
    parser = argparse.ArgumentParser()
    # data settings
    parser.add_argument('--data', type=str, default='channels',
                        help='type of data')
    parser.add_argument('--data_path', type=str, default='datasets/prop_channels_train/',
                        help='data path')
    parser.add_argument('--labels_path', type=str, default=None,
                        help='labels path')
    parser.add_argument('--data_ext', type=str, default='txt',
                        help='data extension txt, png')
    parser.add_argument('--sampling', type=int, default=None,
                        help='randomly sample --sampling instances from the training data if not None')
    # models settings
    parser.add_argument('--G_model', type=str, default='residual_GAN',
                        help='Generator Model can be residual_GAN, dcgan, ...')
    parser.add_argument('--D_model', type=str, default='residual_GAN',
                        help='Discriminator Model can be residual_GAN, dcgan, sngan,...')
    parser.add_argument('--cgan', action='store_true', default=False,
                        help='Use conditional GAN if True (only implemented in residual_GAN)')
    parser.add_argument('--att', action='store_true', default=False,
                        help='Use Attention if True (only implemented in residual_GAN)')
    parser.add_argument('--img_ch', type=int, default=3,
                        help='the number of image channels 1 for grayscale 3 for RGB')
    parser.add_argument('--G_ch', type=int, default=52,
                        help='base multiplier for the Generator (for cnn_GAN should be large 512/1024)')
    parser.add_argument('--D_ch', type=int, default=32,
                        help='base multiplier for the discriminator')
    parser.add_argument('--leak_G', type=float, default=0,
                        help='use leaky relu activation for generator with leak= leak_G, zero value will use RELU')
    parser.add_argument('--leak_D', type=float, default=0,
                        help='use leaky relu activation for discriminator with leak= leak_D, zero value will use RELU')
    parser.add_argument('--zdim', type=int, default=128,
                        help='dimension of latent vector')
    parser.add_argument('--spec_norm_D', default=False, action='store_true',
                        help='apply spectral normalization in discriminator')
    parser.add_argument('--spec_norm_G', default=False, action='store_true',
                        help='apply spectral normalization in generator')
    parser.add_argument('--n_layers_D', type=int, default=3,
                        help='number of layers used in discriminator of dcgan')
    # optimizers settings
    parser.add_argument('--lr_G', type=float, default=2e-4,
                        help='Generator learning rate')
    parser.add_argument('--lr_D', type=float, default=2e-4,
                        help='discriminator learning rate')
    parser.add_argument('--beta1', type=float, default=0,
                        help='first momentum value for ADAM optimizer')
    parser.add_argument('--beta2', type=float, default=0.999,
                        help='second momentum value for ADAM optimizer')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='discriminator batch size')
    parser.add_argument('--G_batch_size', type=int, default=None,
                        help='generator batch size if None it will be set to batch_size')
    # training settings
    parser.add_argument('--loss', type=str, default='standard',
                        help='Loss function can be standard, hinge or wgan')
    parser.add_argument('--disc_iters', type=int, default=1,
                        help='no. discriminator updates per one generator update')
    parser.add_argument('--epochs', type=int, default=1,
                        help='no. of epochs')
    parser.add_argument('--limit', type=float, default=None,
                        help='if not None will limit training to --limit seconds')
    parser.add_argument('--save_rate', type=int, default=30,
                        help='save checkpoints every 30 epochs')
    parser.add_argument('--ema', action='store_true', default=False,
                        help='keep EMA of G weights')
    parser.add_argument('--ema_decay', type=float, default=0.999,
                        help='EMA decay rate')
    parser.add_argument('--decay_lr', type=str, default=None,
                        help='if not None decay the learning rates (exp,step)')
    parser.add_argument('--saved_cp', type=str, default=None,
                        help='if not None start training from a loaded cp with path --saved_cp')
    parser.add_argument('--seed', type=int, default=None,
                        help='None to use random seed can be fixed for reproduction')
    parser.add_argument('--z_dist', type=str, default='normal',
                        help='distribution of latent space normal/uniform')
    parser.add_argument('--smooth', default=False, action='store_true',
                        help='Use smooth labeling if True')
    # GPU settings
    parser.add_argument('--ngpu', type=int, default=1,
                        help='number of gpus to be used')
    parser.add_argument('--dev_num', type=int, default=0,
                        help='the index of a gpu to be used if --ngpu is 1')
    # folder name
    parser.add_argument('--fname', type=str, default='models_cp',
                        help='folder name to save cp')
    # Conditional GAN settings
    parser.add_argument('--D_cond_method', type=str, default='concat',
                        help='conditioning method concat for concatenation/proj for projection')
    parser.add_argument('--G_cond_method', type=str, default='cbn',
                        help='conditioning method concat for concatenation/cbn for conditional batch normalization')
    parser.add_argument('--n_cl', type=int, default=0,
                        help='number of classes, 1 for continuous conditioning')
    parser.add_argument('--real_cond_list', nargs='+', default=None, type=float,
                        help='List of conditions values for the real samples e.g. 0.25 0.30 0.35 if not provided it will be 0 1 2 ..')
    parser.add_argument('--discrete', action='store_true', default=False,
                        help='if True Sample only discrete labels from min_label to max_label')
    parser.add_argument('--c_list', nargs='+', default=None, type=float,
                        help='list of conditions to be sampled, if provided')
    parser.add_argument('--min_label', type=float, default=0,
                        help='minimum label for conditions')
    parser.add_argument('--max_label', type=float, default=None,
                        help='maximum label for conditions, if None will be set num of classes')
    parser.add_argument('--ohe', action='store_true', default=False,
                        help='use one hot encoding for conditioning')
    parser.add_argument('--SN_y', action='store_true', default=False,
                        help='apply SN to the condition linear layer')
    parser.add_argument('--y_real_GD', action='store_true', default=False,
                        help='Use same real conditions for both G and D')
    parser.add_argument('--x_fake_GD', action='store_true', default=False,
                        help='Use same fake data for both G and D')
    return parser
def prepare_device(args):
    """Pick the torch device: cuda:0 when a GPU is available and requested,
    otherwise CPU. Also pins the visible GPU to args.dev_num."""
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.dev_num)
    use_cuda = torch.cuda.is_available() and args.ngpu > 0
    device = torch.device("cuda:0" if use_cuda else "cpu")
    print("Device: ", device)
    return device
def prepare_seed(args):
    """Return args.seed if given, otherwise a fresh random seed in [1, 10000]."""
    seed = random.randint(1, 10000) if args.seed is None else args.seed
    print("Random Seed: ", seed)
    return seed
def init_weight(m):
    """Weight initializer for Module.apply(): orthogonal init for Conv /
    Linear / Embedding layers, N(1, 0.02) for batch-norm; biases zeroed."""
    name = type(m).__name__
    if 'Conv' in name:
        nn.init.orthogonal_(m.weight, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'Batch' in name:
        m.weight.data.normal_(1, 0.02)
        m.bias.data.zero_()
    elif 'Linear' in name:
        nn.init.orthogonal_(m.weight, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'Embedding' in name:
        nn.init.orthogonal_(m.weight, gain=1)
#dataset
def prepare_data(args):
    """Build the training dataset and dataloader selected by args.data.

    Returns (dataloader, train_data). Exits the process for an unknown
    --data value.

    Fixes: the dogs branch now returns its own dataloader (12 workers,
    pin_memory) immediately, so the generic 4-worker loader built at the
    end can no longer overwrite it; also fixes the "laoding" typo in the
    log message.
    """
    print(" loading " + args.data + " ...")
    if args.data == 'dogs':
        from datasets import Dogs_labels
        train_data = Dogs_labels.DogsDataset()
        # dogs uses its own, more aggressive loader settings
        dataloader = torch.utils.data.DataLoader(train_data,
                                                 shuffle=True,
                                                 batch_size=args.batch_size,
                                                 num_workers=12,
                                                 pin_memory=True)
        print("Finished data loading")
        return dataloader, train_data
    elif args.data == 'cifar':
        train_data = dset.CIFAR10(root='./data', train=True, download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                  ]))
    elif args.data == 'channels':
        from datasets import channels_datasets
        train_data = channels_datasets.Channels(path=args.data_path,
                                                labels_path=args.labels_path,
                                                ext=args.data_ext,
                                                sampling=args.sampling)
    else:
        print('no data named :', args.data)
        exit()
    dataloader = torch.utils.data.DataLoader(train_data,
                                             shuffle=True,
                                             batch_size=args.batch_size,
                                             num_workers=4)
    print("Finished data loading")
    return dataloader, train_data
def prepare_models(args, n_cl=0, device='cpu', only_G=False):
    """Instantiate the generator (and discriminator unless only_G) selected
    by args.G_model / args.D_model and move them to `device`.

    Returns netG when only_G is True, otherwise (netG, netD).

    Fix: an unrecognized model name previously fell through and crashed
    later with a NameError on the undefined net; now raises ValueError
    immediately with a clear message.
    """
    # generator
    if args.G_model == 'dcgan':
        netG = generators.DC_Generator(args.zdim, img_ch=args.img_ch,
                                       base_ch=args.G_ch).to(device)
        netG.apply(init_weight)
    elif args.G_model == 'residual_GAN':
        netG = generators.Res_Generator(args.zdim, img_ch=args.img_ch, n_classes=n_cl,
                                        base_ch=args.G_ch, leak=args.leak_G, att=args.att,
                                        SN=args.spec_norm_G,
                                        cond_method=args.G_cond_method).to(device)
    else:
        raise ValueError('unknown G_model: ' + str(args.G_model))
    if only_G:
        return netG
    # discriminator
    if args.D_model == 'dcgan':
        netD = discriminators.DC_Discriminator(img_ch=args.img_ch,
                                               base_ch=args.D_ch,
                                               leak=args.leak_D,
                                               n_layers=args.n_layers_D).to(device)
        netD.apply(init_weight)
    elif args.D_model == 'cnn_sngan':
        netD = discriminators.SN_Discriminator(img_ch=args.img_ch,
                                               base_ch=args.D_ch,
                                               leak=args.leak_D,
                                               SN=args.spec_norm_D).to(device)
        netD.apply(init_weight)
    elif args.D_model == 'residual_GAN':
        netD = discriminators.Res_Discriminator(img_ch=args.img_ch, n_classes=n_cl,
                                                base_ch=args.D_ch,
                                                leak=args.leak_D, att=args.att,
                                                cond_method=args.D_cond_method,
                                                SN=args.spec_norm_D,
                                                SN_y=args.SN_y).to(device)
    else:
        raise ValueError('unknown D_model: ' + str(args.D_model))
    return netG, netD
def _strip_module_prefix(state_dict):
    """Remove the leading 'module.' (added by nn.DataParallel) from every key."""
    cleaned = OrderedDict()
    for k, v in state_dict.items():
        if 'module' in k:
            k = k.replace('module.', '')
        cleaned[k] = v
    return cleaned


def load_from_saved(args, netG, netD, optimizerG, optimizerD):
    """Restore G, D and both optimizers from the checkpoint at args.saved_cp.

    Returns (netG, netD, optimizerG, optimizerD, st_epoch, G_losses, D_losses)
    where st_epoch is the epoch to resume from (saved epoch + 1).

    The two identical key-cleaning loops were factored into
    _strip_module_prefix.
    """
    checkpoint = torch.load(args.saved_cp)
    netG.load_state_dict(_strip_module_prefix(checkpoint['netG_state_dict']))
    netD.load_state_dict(_strip_module_prefix(checkpoint['netD_state_dict']))
    optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
    optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
    st_epoch = checkpoint['epoch'] + 1
    G_losses = checkpoint['Gloss']
    D_losses = checkpoint['Dloss']
    return netG, netD, optimizerG, optimizerD, st_epoch, G_losses, D_losses
def prepare_filename(args):
    """Return the checkpoint filename prefix '<epochs>_', placed under
    args.fname (created on demand) when that folder name is set."""
    prefix = str(args.epochs) + "_"
    if args.fname is None:
        return prefix
    if not os.path.exists(args.fname):
        os.makedirs(args.fname)
    return args.fname + "/" + prefix
#generate random labels
def sample_pseudo_labels(args, num_classes, batch_size, device):
    """Sample random condition labels for the discriminator and generator.

    Returns a pair (y_D, y_G); the four sampling modes are, in priority
    order: an explicit list of values (args.c_list), discrete integer
    labels, one-hot-encoded labels, and continuous uniform labels.

    Fix: the continuous branch now uses max_value (which falls back to
    num_classes) instead of args.max_label, which is None when --max_label
    was not given and crashed the arithmetic.
    """
    if args.max_label is None:
        max_value = num_classes
    else:
        max_value = args.max_label
    # explicit list of condition values
    if args.c_list is not None:
        c = torch.tensor(args.c_list)
        y_ind = torch.randint(low=0, high=len(c), size=(batch_size, 1))
        y = c[y_ind].to(device)
        return y, y
    # discrete integer conditions in [min_label, max_value)
    elif args.discrete:
        y = torch.randint(low=int(args.min_label), high=int(max_value),
                          size=(batch_size, 1)).to(device)
        return y, y
    # one-hot-encoded discrete conditions
    elif args.ohe:
        y = torch.randint(low=int(args.min_label), high=int(max_value),
                          size=(batch_size,)).to(device)
        y_ohe = torch.eye(num_classes)[y].to(device)
        return y_ohe, y_ohe
    # continuous conditions, uniform in [min_label, max_value)
    else:
        y = (max_value - args.min_label) * torch.rand(batch_size, num_classes) + args.min_label
        y = y.to(device)
        return y, y
def disc_2_ohe(y, num_classes, device):
    """Convert integer class labels to one-hot float vectors on `device`."""
    eye = torch.eye(num_classes)
    return eye[y].to(device)
def disc_2_cont(y, c_list, device):  # convert discrete label to label in c_list
    """Replace each integer label i in y (in place) with c_list[i] and append
    a trailing singleton dimension.

    Fix: the masks are computed before any assignment. The original chained
    `y[y == i] = v` passes, so a c_list value equal to a later index (e.g.
    c_list=[1, 2]: 0 -> 1, then the new 1s -> 2) was remapped twice.
    """
    masks = [(y == i) for i in range(len(c_list))]
    for mask, value in zip(masks, c_list):
        y[mask] = value
    y = y.unsqueeze(1)
    return y
#generate fake images
def sample_from_gen(args, b_size, zdim, num_classes, netG, device='cpu', truncated=0, real_y=None):
    """Sample latent vectors (and condition labels when num_classes > 0) and
    run them through the generator.

    Returns (fake_images, y_D); y_D is None for the unconditional case.
    """
    # latent sample: normal or scaled-uniform, optionally truncated
    if args.z_dist == 'normal':
        z = torch.randn(b_size, zdim).to(device=device)
    elif args.z_dist == 'uniform':
        z = 2 * torch.rand(b_size, zdim).to(device=device) - 1
    if truncated > 0:
        z = get_trun_noise(truncated, zdim, b_size, device)
    # condition labels
    y_D = y_G = None
    if num_classes > 0:
        if args.y_real_GD:
            # reuse the real batch's conditions for both networks
            y_D = y_G = real_y
        else:
            y_D, y_G = sample_pseudo_labels(args, num_classes, b_size, device)
    fake = netG(z, y_G)
    return fake, y_D
def load_netG(netG, checkpointname=None):
    """Load generator weights from a checkpoint file, stripping any
    DataParallel 'module.' key prefix, and switch the net to eval mode."""
    ckpt = torch.load(checkpointname)
    cleaned = OrderedDict(
        (k.replace('module.', '') if 'module' in k else k, v)
        for k, v in ckpt['netG_state_dict'].items()
    )
    netG.load_state_dict(cleaned)
    netG.eval()
    return netG
def get_trun_noise(truncated, z_dim, b_size, device):
    """Sample a (b_size, z_dim) truncated-normal latent batch: standard
    normal draws are kept only while |z| < truncated, redrawing until enough
    samples survive.

    Fix: the loop now requires b_size*z_dim kept samples — the amount the
    final view actually consumes — instead of the hard-coded 64*z_dim,
    which could accept too few samples when b_size > 64 and crash the view.
    """
    needed = b_size * z_dim
    while True:
        z = np.random.randn(100 * needed)
        z = z[np.abs(z) < truncated]
        if len(z) >= needed:
            break
    z = torch.from_numpy(z[:needed]).view(b_size, z_dim)
    return z.float().to(device)
def save_grid(args,netG,device,nrows=3,ncol=3,out_path = 'plot'):
    """Sample nrows*ncol images from netG and save them to out_path+'.png'.

    Grayscale batches (args.img_ch == 1) are drawn one-per-axis with
    matplotlib; multi-channel batches are tiled into a single normalized
    grid with torchvision.

    NOTE(review): this reads args.truncated, which prepare_parser() in this
    file does not define — confirm callers attach it before use.
    """
    b_size = nrows*ncol
    gen_images,_ = sample_from_gen(args,b_size,args.zdim,args.n_cl,netG,device,truncated = args.truncated)
    # detach and move to CPU before any numpy/matplotlib handling
    gen_images = gen_images.cpu().detach()
    if args.img_ch == 1:
        # grayscale: drop the channel dim and plot each image on its own axis
        gen_images = gen_images.squeeze()
        plt.close('all')
        fig,axes = plt.subplots(nrows,ncol,figsize=[10,10])
        #fake = fake.permute((0,3, 2, 1))
        for i,iax in enumerate( axes.flatten() ):
            iax.imshow(gen_images[i,:,:])
            iax.set_xticks([])
            iax.set_yticks([])
        fig.show()
    else:
        # multi-channel: let torchvision tile the batch into one normalized grid
        plt.close('all')
        fig = plt.figure(figsize=(10,10))
        l = vutils.make_grid(gen_images,nrow=nrows, padding=2, normalize=True).numpy()
        # make_grid returns CHW; imshow expects HWC
        plt.imshow(np.transpose(l,(1,2,0)))
        plt.xticks([])
        plt.yticks([])
    fig.savefig(out_path+'.png')
def elapsed_time(start_time):
    """Seconds elapsed since start_time (a time.time() timestamp)."""
    now = time.time()
    return now - start_time
def calc_ralsloss_G(real, fake, margin=1.0):
    """Relativistic average least-squares generator loss (RaLSGAN):
    penalizes real scores above the mean fake score plus margin, and fake
    scores below the mean real score minus margin."""
    real_term = ((real - fake.mean() + margin) ** 2).mean()
    fake_term = ((fake - real.mean() - margin) ** 2).mean()
    return real_term + fake_term
def replace_face(img, old_face, new_face):
    """Return a copy of img in which every pixel whose 3 channels all equal
    old_face is replaced by new_face; img itself is left unmodified.

    Vectorized: a single NumPy boolean mask replaces the original O(H*W)
    Python double loop, which compared each pixel one at a time.
    """
    new_img = img.copy()
    # pixel matches when all channel values equal old_face
    mask = np.all(new_img == np.asarray(old_face), axis=-1)
    new_img[mask] = new_face
    return new_img