diff --git a/.gitignore b/.gitignore index 8eb1899..9fe5204 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ custom ffhq afhq logs +abandon # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/inference.py b/inference.py index 0f890c9..87e1e4c 100644 --- a/inference.py +++ b/inference.py @@ -25,25 +25,35 @@ def __init__(self, self.ckpt = tf.train.Checkpoint(generator_clone=self.Gs) print(f'Loading network from {checkpoint_path}...') self.ckpt.restore(tf.train.latest_checkpoint(checkpoint_path)).expect_partial() - + tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True}) # Enable mixed precision + + + def _get_labels(self, label, length): + if self.num_labels > 0: # Contain labels + labels_indice = [label]*length if label is not None else tf.random.uniform([length], 0, self.num_labels, dtype=tf.int32) + labels = tf.one_hot(labels_indice, self.num_labels) + else: + labels_indice = [0]*length + labels = tf.zeros([length, 0], tf.float32) + + return labels, labels_indice + def genetate_example(self, num_examples, batch_size=1, label=None): create_dir(f'{self.result_path}/example') print('Generating images...') for begin in range(0, num_examples, batch_size): latents = tf.random.normal([batch_size, 512]) - labels_indice = [label]*batch_size if label is not None else tf.random.uniform([batch_size], 0, self.num_labels, dtype=tf.int32) - labels = tf.one_hot(labels_indice, self.num_labels) if self.num_labels > 0 else tf.zeros([batch_size, 0], tf.float32) + labels, labels_indice = self._get_labels(label, batch_size) images = self.Gs([latents, labels], self.truncation_psi, training=False) - for idx, (image, label) in enumerate(zip(images, labels_indice)): - imsave(image, f'{self.result_path}/example/{begin+idx}_label-{label}.jpg') + for idx, (image, indice) in enumerate(zip(images, labels_indice)): + imsave(image, f'{self.result_path}/example/{begin+idx}_label-{indice}.jpg') def style_mixing_example(self, row_seeds, col_seeds, label=None, col_styles='0-6'): create_dir(f'{self.result_path}/mixing') all_seeds = list(set(row_seeds + col_seeds)) - all_labels_indice = [label]*len(all_seeds) if label is not None else tf.random.uniform([len(all_seeds)], 0, self.num_labels, dtype=tf.int32) - all_labels = tf.one_hot(all_labels_indice, self.num_labels) if self.num_labels > 0 else tf.zeros([len(all_seeds), 0], tf.float32) + all_labels, all_labels_indices = self._get_labels(label, len(all_seeds)) all_z = tf.stack([tf.random.normal([512], seed=seed) for seed in all_seeds]) # [minibatch, component] print('Generating images...') @@ -88,8 +98,8 @@ def generate_gif(self, output='test', label=None, num_rows=2, num_cols=3, resolu output_seq = [] batch_size = num_rows * num_cols latents = [tf.random.normal([batch_size, 512]) for _ in range(num_phases)] - labels_indice = [label]*batch_size if label is not None else tf.random.uniform([batch_size], 0, self.num_labels, dtype=tf.int32) - labels = [tf.one_hot(labels_indice, self.num_labels) if self.num_labels > 0 else tf.zeros([batch_size, 0], tf.float32) for _ in range(num_phases)] + labels, labels_indice = self._get_labels(label, batch_size) + labels = tf.repeat(tf.expand_dims(labels, axis=0), repeats=num_phases, axis=0) def to_image_grid(outputs): outputs = (outputs + 1) * 127.5 diff --git a/main.py b/main.py index 70d8857..1afc45c 100644 --- a/main.py +++ b/main.py @@ -26,7 +26,7 @@ def main(): parser.add_argument('--total_img', help='Training length of images', type=int, default=25000000) parser.add_argument('--impl', help="(Faster)Custom op use:'cuda'; (Slower)Tensorflow op use:'ref'", type=str, default='ref', choices=['ref','cuda']) # Inference - parser.add_argument('--label', help='Inference label', type=int, default=0) + parser.add_argument('--label', help='Inference label', type=int, default=None) parser.add_argument('--num_labels', help='Number of labels', type=int, default=0) parser.add_argument('--truncation_psi', help='Inference truncation psi', type=float, default=0.5) parser.add_argument('--mode', help="Inference mode be one of: 'example', 'gif', 'mixing'", type=str, default='example', choices=['example','gif','mixing']) @@ -71,7 +71,6 @@ def main(): assert 0.0 <= args.truncation_psi <= 1.0, 'Error: Inference truncation_psi needs to be between 0 and 1.' assert args.res >= 4 assert args.num_labels >= 0 - assert 0 <= args.label <= max(0, args.num_labels-1) parameters = { 'resolution' : args.res, diff --git a/modules/generator.py b/modules/generator.py index 334f177..87406fe 100644 --- a/modules/generator.py +++ b/modules/generator.py @@ -205,7 +205,7 @@ def truncation_trick(self, w_latents, truncation_psi): truncated_w = self.w_avg + (w_latents - self.w_avg) * layer_psi return truncated_w - + @tf.function def call(self, inputs, truncation_psi=0.5, return_latents=False, training=None): latents_in, labels_in = inputs