-
Notifications
You must be signed in to change notification settings - Fork 1
/
sample.py
113 lines (89 loc) · 3.68 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import torch
import torchvision
import BigGAN
import utils
import shutil
def trunc_trick(bs, z_dim, bound=2.0):
    """Sample a ``(bs, z_dim)`` standard-normal tensor truncated to ``[-bound, bound]``.

    Entries whose magnitude exceeds ``bound`` are redrawn from N(0, 1)
    (rejection sampling) until every entry lies inside the interval. The
    tensor is created on torch's default device; callers move it with
    ``.to(device)``.

    Args:
        bs: Number of rows (batch size). May be 0.
        z_dim: Number of columns (latent dimensionality).
        bound: Positive truncation threshold.

    Returns:
        A float tensor of shape ``(bs, z_dim)`` with every ``|z| <= bound``.

    Raises:
        ValueError: If ``bound`` is not positive. In that case resampling
            could never terminate.
    """
    if bound <= 0:
        raise ValueError(f'bound must be positive, got {bound}')
    z = torch.randn(bs, z_dim)
    # ``any()`` is well-defined on empty tensors, unlike ``max()``, which
    # raises when bs == 0.
    while (z.abs() > bound).any():
        z = torch.where(z.abs() <= bound, z, torch.randn_like(z))
    return z
def collect_bn_stats(G, n_samples, config, device):
    """Refresh ``G``'s normalization running statistics with forward passes.

    Puts ``G`` in training mode and pushes random latents through it
    without gradients. Each batch has size ``config['n_classes']`` and
    contains one sample per class label ``0..n_classes-1``. The number of
    batches is ``ceil(n_samples / n_classes)``.

    The generator outputs are discarded. The only useful effect is that
    training-mode layers which track running statistics (e.g. batch norm)
    update them. ``G`` is left in training mode.

    Args:
        G: Generator exposing ``dim_z`` and a ``shared`` label-embedding module.
        n_samples: Approximate number of samples to generate; it is rounded
            up to a whole number of batches.
        config: Config dict; only ``'n_classes'`` is read.
        device: Device on which latents and labels are created.
    """
    im_batch_size = config['n_classes']
    G.train()
    # The labels are identical for every batch, so build them once.
    y = torch.arange(im_batch_size, device=device)
    with torch.no_grad():
        for _ in range(0, n_samples, im_batch_size):
            z = torch.randn(im_batch_size, G.dim_z, device=device)
            # The output is unused, so skip the .float().cpu() host copy.
            G(z, G.shared(y))
def generate_images(out_dir, G, n_images, config, device):
    """Generate ``n_images`` class-conditional samples and save them as PNG files.

    Every batch holds one sample per class: ``config['n_classes']`` samples
    with labels ``0..n_classes-1``.

    When ``config['trunc_z']`` is positive, latents come from ``trunc_trick``
    with that bound. Otherwise they are drawn from a standard normal.

    The last batch is trimmed so that at most ``n_images`` images are
    written. Each image is passed through ``utils.denorm`` and saved to
    ``out_dir`` as ``image_<index:05d>.png``.
    """
    per_batch = config['n_classes']
    bound = config['trunc_z']
    truncate = bound > 0.0
    if truncate:
        print(f'Truncating z to (-{bound}, {bound})')

    for start in range(0, n_images, per_batch):
        with torch.no_grad():
            if truncate:
                latents = trunc_trick(per_batch, G.dim_z, bound=bound).to(device)
            else:
                latents = torch.randn(per_batch, G.dim_z, device=device)
            labels = torch.arange(per_batch).to(device)
            batch = G(latents, G.shared(labels)).float().cpu()

        # Fewer images left than a full batch: keep only the ones still needed.
        remaining = n_images - start
        if remaining < per_batch:
            print(f'Taking only {remaining} images from the last batch...')
            batch = batch[:remaining]

        for offset, img in enumerate(batch):
            path = os.path.join(out_dir, f'image_{start + offset:05d}.png')
            torchvision.utils.save_image(utils.denorm(img), path)
def run(config):
    """Load a trained BigGAN generator and write ``config['sample_num']`` samples.

    Steps:
      1. Apply the hard-coded config overrides.
      2. Build the generator on ``device``.
      3. Load its weights via ``utils.load_weights``.
      4. When ``config['use_ema']`` is set, re-collect batch-norm statistics.
      5. Optionally switch ``G`` to eval mode.
      6. Save PNGs to ``config['samples_root']``.
      7. Zip that directory into ``images.zip`` in the current working
         directory.

    Args:
        config: Parsed command-line options as a dict. Several keys are
            overwritten below regardless of their command-line values.
    """
    # Prepare state dict, which holds things like epoch # and itr #
    state_dict = {'itr': 0, 'epoch': 0, 'save_num': 0, 'config': config}

    # Update config (see train.py for explanation). These values are
    # hard-coded and override whatever was passed on the command line.
    config['resolution'] = 64
    config['n_classes'] = 120
    config['G_activation'] = utils.activation_dict[config['G_nl']]
    config['D_activation'] = utils.activation_dict[config['D_nl']]
    config = utils.update_config_roots(config)
    config['skip_init'] = True
    config['no_optim'] = True
    device = 'cuda'

    # Seed RNG
    utils.seed_rng(config['seed'])
    # Setup cudnn.benchmark for free speed
    torch.backends.cudnn.benchmark = True

    experiment_name = config['experiment_name'] or 'generative_dog_images'
    print('Experiment name is %s' % experiment_name)

    G = BigGAN.Generator(**config).to(device)

    # Load weights
    print('Loading weights...')
    # Load either the EMA weights or the normal weights into G.
    # NOTE(review): if use_ema is set but ema is not, G is passed in neither
    # slot, so no generator weights are loaded. Confirm this combination is
    # never used.
    utils.load_weights(
        G if not config['use_ema'] else None, None, state_dict,
        config['weights_root'], experiment_name, config['load_weights'],
        G if config['ema'] and config['use_ema'] else None,
        strict=False, load_optim=False)

    if config['use_ema']:
        # Re-estimate normalization running statistics with training-mode
        # forward passes. This is presumably needed because the stored
        # statistics do not match the EMA weights — confirm.
        collect_bn_stats(G, 10_000, config, device)

    if config['G_eval_mode']:
        print('Putting G in eval mode..')
        G.eval()
    else:
        print('G is in %s mode...' % ('training' if G.training else 'eval'))

    out_dir = config['samples_root']
    # makedirs(exist_ok=True) also creates missing parent directories and
    # has no check-then-create race, unlike an exists()/mkdir() pair.
    os.makedirs(out_dir, exist_ok=True)
    print('Generating images..')
    generate_images(out_dir, G, config['sample_num'], config, device)
    shutil.make_archive('images', 'zip', out_dir)
def main():
    """Command-line entry point: parse the sampling options, echo them, then call ``run``."""
    arg_parser = utils.add_sample_parser(utils.prepare_parser())
    options = vars(arg_parser.parse_args())
    print(options)
    run(options)
# Only start sampling when executed as a script, not on import.
if __name__ == "__main__":
    main()