
Commit

Merge pull request #1 from cryu854/Metrics-feature
Metrics feature
cryu854 authored Nov 23, 2020
2 parents 209141f + 1fe79d5 commit 066a03f
Showing 6 changed files with 99 additions and 65 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -10,7 +10,7 @@ weights-lpips
custom
ffhq
afhq
-log
+logs

# Byte-compiled / optimized / DLL files
__pycache__/
15 changes: 8 additions & 7 deletions cal_metrics.py
@@ -22,16 +22,15 @@ def get_generator(checkpoint_path, resolution, num_labels, config, randomize_noi
def calculate_metric(generator, num_labels, mode, dataset_path):
fid50k_full_parameters = {'num_images':50000, 'num_labels':num_labels , 'batch_size':8}
ppl_wend_parameters = {'num_images':50000, 'num_labels':num_labels, 'epsilon':1e-4, 'space':'w', 'sampling':'end', 'crop':False, 'batch_size':2}
-# ppl_wfull_parameters = {'num_images':50000, 'num_labels':num_labels, 'epsilon':1e-4, 'space':'w', 'sampling':'full','crop':True, 'batch_size':2}

-start = time.perf_counter()
if mode == 'fid':
assert os.path.exists(dataset_path), 'Error: Dataset does not exist.'
-fid = FID(**fid50k_full_parameters)
-dist = fid.evaluate(generator, real_dir=dataset_path)
+FID_metric = FID(**fid50k_full_parameters)
+dist = FID_metric(generator, real_dir=dataset_path)
else: # mode == 'ppl'
-ppl = PPL(**ppl_wend_parameters)
-dist = ppl.evaluate(generator)
-print(f'Time taken for {mode} evaluation: {round(time.perf_counter()-start)}s')
+PPL_metric = PPL(**ppl_wend_parameters)
+dist = PPL_metric(generator)
return dist


@@ -60,9 +59,11 @@ def main():
'randomize_noise' : False,
}
Gs = get_generator(**Gs_parameters)

+start = time.perf_counter()
dist = calculate_metric(Gs, args.num_labels, args.mode.lower(), args.dataset)
print(f'{args.mode} : {dist:.3f}')

+print(f'Time taken for evaluation: {round(time.perf_counter()-start)}s')

if __name__ == '__main__':
main()
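The change above turns the metric objects into callables: FID.__call__ and PPL.__call__ replace the former .evaluate() methods, and the timing/printing moves out of calculate_metric() into main(). A minimal sketch of driving the updated API, assuming FID and PPL are importable from modules/metrics as the file paths suggest; run_fid, run_ppl, and the dataset directory are illustrative names, not code from this repository:

from modules.metrics import FID, PPL

def run_fid(generator, dataset_dir, num_labels=0):
    # FID is built once with the fid50k_full settings above, then called like a function.
    fid = FID(num_images=50000, num_labels=num_labels, batch_size=8)
    return fid(generator, real_dir=dataset_dir)

def run_ppl(generator, num_labels=0):
    # PPL mirrors the ppl_wend settings above (W space, endpoint sampling, no crop).
    ppl = PPL(num_images=50000, num_labels=num_labels, epsilon=1e-4,
              space='w', sampling='end', crop=False, batch_size=2)
    return ppl(generator)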
26 changes: 20 additions & 6 deletions main.py
@@ -15,27 +15,38 @@ def main():
# Parse command line arguments
parser = argparse.ArgumentParser(description='StyleGAN2')
parser.add_argument('command',help="'train' or 'inference'", type=str, choices=['train', 'inference'])
-parser.add_argument('--impl', help="(Faster)Custom op use:'cuda'; (Slower)Tensorflow op use:'ref'", type=str, default='ref', choices=['ref','cuda'])
+# Common
+parser.add_argument('--res', help='Resolution of image', type=int, default=1024)
parser.add_argument('--config', help="Model's config be one of: 'e', 'f'", type=str, default='f')
+parser.add_argument('--ckpt', help='Checkpoints/Weights directory', type=str, default='./checkpoint')
# Training
parser.add_argument('--dataset_name', help="Specific dataset be one of: 'ffhq', 'afhq', 'custom'", type=str, default='afhq', choices=['ffhq','afhq','custom'])
parser.add_argument('--dataset_path', help='Dataset directory', type=str, default='./../../datasets/afhq/train_labels')
parser.add_argument('--batch_size', help='Training batch size', type=int, default=4)
-parser.add_argument('--res', help='Resolution of image', type=int, default=1024)
parser.add_argument('--total_img', help='Training length of images', type=int, default=25000000)
-parser.add_argument('--ckpt', help='Checkpoints/Weights directory', type=str, default='./checkpoint')
+parser.add_argument('--num_labels', help='Number of labels', type=int, default=0)
+parser.add_argument('--impl', help="(Faster)Custom op use:'cuda'; (Slower)Tensorflow op use:'ref'", type=str, default='ref', choices=['ref','cuda'])
# Inference
parser.add_argument('--label', help='Inference label', type=int, default=0)
-parser.add_argument('--num_labels', help='Number of labels', type=int, default=0)
parser.add_argument('--truncation_psi', help='Inference truncation psi', type=float, default=0.5)
parser.add_argument('--mode', help="Inference mode be one of: 'example', 'gif', 'mixing'", type=str, default='example', choices=['example','gif','mixing'])
+# Misc
+parser.add_argument('--save_step', help='Steps to save checkpoint and write summary', type=int, default=5000)
+parser.add_argument('--print_step', help="Steps to print training losses", type=int, default=100)
+parser.add_argument('--metric_step', help="Steps to calculate FID/PPL metrics(A multiple of save_step)", type=int, default=0)

args = parser.parse_args()


# Validate arguments
if args.command == 'train':
assert os.path.exists(args.dataset_path), 'Error: Dataset does not exist.'
-assert args.batch_size > 0
assert args.res >= 4
+assert args.batch_size > 0
assert args.total_img > 0
+assert args.save_step > 0
+assert args.print_step > 0
+assert args.metric_step % args.save_step == 0 if args.metric_step > 0 else args.metric_step == 0

parameters = {
'resolution' : args.res,
@@ -45,7 +56,10 @@ def main():
'dataset_name' : args.dataset_name,
'dataset_path' : args.dataset_path,
'checkpoint_path' : args.ckpt,
-'impl' : args.impl
+'impl' : args.impl,
+'save_step': args.save_step,
+'print_step': args.print_step,
+'metric_step': args.metric_step,
}

trainer = Trainer(**parameters)
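The new --metric_step flag is validated with a conditional-expression assert that is easy to misread: the whole `A if cond else B` expression is the assert's single condition, so it requires metric_step to be either 0 (metrics disabled) or a positive multiple of save_step, keeping metric evaluation aligned with checkpoint saves. An equivalent, more explicit sketch of that rule (illustrative only, not the repository's code; validate_metric_step is a hypothetical name):

def validate_metric_step(metric_step, save_step):
    # metric_step == 0 disables FID/PPL evaluation entirely.
    if metric_step == 0:
        return
    # Otherwise it must be a positive multiple of save_step, so metrics run
    # only at steps where a checkpoint is also written.
    if metric_step < 0 or metric_step % save_step != 0:
        raise ValueError('metric_step must be 0 or a positive multiple of save_step')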
15 changes: 7 additions & 8 deletions modules/generator.py
@@ -53,18 +53,17 @@ def __init__(self, randomize_noise, **kwargs):

def build(self, input_shape):
self.noise_shape = input_shape
-self.noise = self.add_weight(name='noise',
-shape=[1, self.noise_shape[1], self.noise_shape[2], 1],
-dtype=tf.float32,
-initializer=tf.random_normal_initializer(0, 1.0),
-trainable=False)
self.noise_strength = self.add_weight(name='noise_strength',
shape=[],
dtype=tf.float32,
initializer=tf.zeros_initializer(),
trainable=True)
+if self.randomize_noise is not True:
+self.noise = self.add_weight(name='noise',
+shape=[1, self.noise_shape[1], self.noise_shape[2], 1],
+dtype=tf.float32,
+initializer=tf.random_normal_initializer(0, 1.0),
+trainable=False)


def call(self, inputs, training=None):
if self.randomize_noise:
noise = tf.random.normal([tf.shape(inputs)[0], self.noise_shape[1], self.noise_shape[2], 1])
@@ -101,7 +100,7 @@ def __init__(self,
impl='ref',
randomize_noise=True,
w_avg_beta=0.995,
-style_mixing_prob=0.9, # 0.9 for f, 0.5 for e
+style_mixing_prob=0.9,
**kwargs):
super(generator, self).__init__(**kwargs)
self.resolution = resolution
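The build() change registers the fixed per-layer noise buffer only when randomize_noise is disabled; when noise is randomized, call() samples a fresh tensor each step, so a stored buffer would just be dead weight in the checkpoint. A minimal sketch of that pattern as a stand-alone Keras layer, assuming TF 2.x; AddNoise is an illustrative name, not the repository's class:

import tensorflow as tf

class AddNoise(tf.keras.layers.Layer):
    def __init__(self, randomize_noise=True, **kwargs):
        super().__init__(**kwargs)
        self.randomize_noise = randomize_noise

    def build(self, input_shape):
        # The learned per-layer noise strength is always created.
        self.noise_strength = self.add_weight(name='noise_strength', shape=[],
                                              initializer='zeros', trainable=True)
        # The fixed noise buffer is only needed when noise is not re-sampled.
        if not self.randomize_noise:
            self.noise = self.add_weight(name='noise',
                                         shape=[1, input_shape[1], input_shape[2], 1],
                                         initializer=tf.random_normal_initializer(0, 1.0),
                                         trainable=False)

    def call(self, inputs):
        if self.randomize_noise:
            shape = tf.shape(inputs)
            noise = tf.random.normal([shape[0], shape[1], shape[2], 1])
        else:
            noise = self.noise
        return inputs + noise * self.noise_strength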
18 changes: 8 additions & 10 deletions modules/metrics.py
@@ -41,7 +41,7 @@ def _evaluate_fakes_step(self, Gs):
feats = self.inception_v3(images)
return feats

-def evaluate(self, Gs, real_dir=None):
+def __call__(self, Gs, real_dir=None):
cache_file = f'{real_dir}/FID-{self.num_images}-cache.npy'

# Calculate mean and covariance statistics for reals.
@@ -55,7 +55,7 @@ def evaluate(self, Gs, real_dir=None):
real_dataset = self._create_dataset(real_dir)
feats = []
for reals in real_dataset.take(self.num_images//self.batch_size):
-feats.append(self.inception_v3(reals).numpy())
+feats.append(self.inception_v3(reals))
feats = np.concatenate(feats, axis=0)
real_mu = np.mean(feats, axis=0)
real_sigma = np.cov(feats, rowvar=False)
@@ -67,7 +67,7 @@ def evaluate(self, Gs, real_dir=None):
print('Start evaluating fake statistics...')
feats = []
for _ in range(0, self.num_images, self.batch_size):
-feats.append(self._evaluate_fakes_step(Gs).numpy())
+feats.append(self._evaluate_fakes_step(Gs))
feats = np.concatenate(feats, axis=0)
fake_mu = np.mean(feats, axis=0)
fake_sigma = np.cov(feats, rowvar=False)
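For the reals, the mean and covariance of the Inception features are computed once and cached next to the dataset (the FID-{num_images}-cache.npy file above), then reused on later runs; fake statistics are regenerated every time. The Fréchet distance itself is computed further down the file, outside the visible hunks. For reference, a sketch of the standard formula, d^2 = ||mu_r - mu_f||^2 + Tr(sigma_r + sigma_f - 2*(sigma_r*sigma_f)^(1/2)), applied to the two (mu, sigma) pairs; this is the usual formulation, not necessarily the repository's exact code:

import numpy as np
import scipy.linalg

def frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma):
    diff = real_mu - fake_mu
    # Matrix square root of the covariance product; drop tiny imaginary parts
    # introduced by numerical error.
    covmean, _ = scipy.linalg.sqrtm(real_sigma.dot(fake_sigma), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff.dot(diff) + np.trace(real_sigma + fake_sigma - 2.0 * covmean)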
@@ -156,12 +156,12 @@ def _evaluate_step(self, Gs):
dist = self.lpips([img_e0, img_e1]) * (1 / self.epsilon**2)
return dist

-def evaluate(self, Gs):
+def __call__(self, Gs):
# Sampling loop.
all_distances = []
print('Start evaluating PPL...')
for _ in range(0, self.num_images, self.batch_size):
-dist = self._evaluate_step(Gs).numpy()
+dist = self._evaluate_step(Gs)
all_distances.append(dist)
all_distances = np.concatenate(all_distances, axis=0)

@@ -172,6 +172,7 @@ def evaluate(self, Gs):
dist = np.mean(filtered_distances)
return dist


# Learned perceptual metric
class LPIPS(Model):
def __init__(self, lpips=True, spatial=False):
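PPL.__call__ accumulates per-pair perceptual distances (already scaled by 1/epsilon^2 in _evaluate_step) and then averages a filtered subset. In the reference StyleGAN implementation the filter discards samples outside the 1st-99th percentile band before taking the mean; the exact filter used here sits just above the visible lines. A sketch of that conventional outlier rejection, offered as an assumption about what filtered_distances holds rather than a transcription of this file:

import numpy as np

def reject_outliers_and_average(all_distances):
    lo = np.percentile(all_distances, 1)
    hi = np.percentile(all_distances, 99)
    # Keep only distances inside the [1st, 99th] percentile band, then average.
    filtered_distances = np.extract((all_distances >= lo) & (all_distances <= hi),
                                    all_distances)
    return np.mean(filtered_distances)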
@@ -202,9 +203,6 @@ def _normalize_tensor(self, inputs, epsilon=1e-10):
def _spatial_average(self, inputs, keepdims=True):
return tf.reduce_mean(inputs, axis=[1,2], keepdims=keepdims)

-def _upsample(self, inputs, out_HW=(64,64)): # assumes scale factor is same for H and W
-return UpSampling2D(size=out_HW, mode='bilinear')(inputs)

def call(self, inputs, training=None):
""" Expected inputs range between [1, -1] """
imgs1, imgs2 = inputs
@@ -221,12 +219,12 @@ def call(self, inputs, training=None):

if(self.lpips):
if(self.spatial):
-res = [self._upsample(self.lin_layers[kk](diffs[kk]), out_HW=imgs1.shape[1:-1]) for kk in range(self.L)]
+res = [tf.image.resize(self.lin_layers[kk](diffs[kk]), size=imgs1.shape[1:-1], method='bilinear') for kk in range(self.L)]
else:
res = [self._spatial_average(self.lin_layers[kk](diffs[kk]), keepdims=True) for kk in range(self.L)]
else:
if(self.spatial):
-res = [self._upsample(diffs[kk].sum(dim=1,keepdims=True), out_HW=imgs1.shape[1:-1]) for kk in range(self.L)]
+res = [tf.image.resize(diffs[kk].sum(dim=1,keepdims=True), size=imgs1.shape[1:-1], method='bilinear') for kk in range(self.L)]
else:
res = [self._spatial_average(diffs[kk].sum(dim=1,keepdims=True), keepdims=True) for kk in range(self.L)]

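The dropped _upsample helper wrapped UpSampling2D; if that was the stock Keras layer, its size argument is a per-axis scale factor and bilinear interpolation is selected with interpolation= (there is no mode= keyword), so passing an output shape as size= would not have produced the intended resize. tf.image.resize takes the absolute target height/width directly, which is what the spatial LPIPS maps need. A small comparison sketch, assuming TF 2.x:

import tensorflow as tf

x = tf.random.normal([1, 8, 8, 3])

# tf.image.resize takes an absolute output size (height, width).
y = tf.image.resize(x, size=(64, 64), method='bilinear')   # shape (1, 64, 64, 3)

# UpSampling2D's `size` is a scale factor, and the method is chosen with
# `interpolation=`; this call scales 8x8 up by 8 in each axis.
z = tf.keras.layers.UpSampling2D(size=(8, 8), interpolation='bilinear')(x)  # shape (1, 64, 64, 3)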
(Diff for the sixth changed file is not shown here.)
