-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
64 lines (52 loc) · 1.96 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import yaml
import argparse
import numpy as np
from models import *
from utils.dataset_CelebA import genDatasetCelebA
import tensorflow as tf
# --- command-line interface
# `--config` is mandatory: every later step reads from the parsed config, so
# running without it could only fail (previously with an opaque open(None) error).
parser = argparse.ArgumentParser()
parser.add_argument('--config', '-c',
                    dest="filename",
                    metavar='FILE',
                    required=True,
                    help='path to the YAML configuration file')
# Kept as a string flag; the value 'true' (the default) requests saving the
# model weights after training.
parser.add_argument('--save_model', '-s',
                    dest="Is_save_model",
                    default='true',
                    help="whether to save model weights after training (default: 'true')")
args = parser.parse_args()
# --- config
# NOTE(review): yaml.FullLoader can construct some Python-specific objects;
# if configs never use Python tags, yaml.safe_load is the safer choice.
with open(args.filename, 'r', encoding='utf-8') as file:
    try:
        config = yaml.load(file, Loader=yaml.FullLoader)
    except yaml.YAMLError as e:
        # Nothing below can run without a parsed config, so stop here rather
        # than continuing and failing later with a NameError on `config`.
        raise SystemExit(f"Failed to parse config file {args.filename!r}: {e}") from e
# An empty file parses to None; every later step indexes config by section name.
if not isinstance(config, dict):
    raise SystemExit(f"Config file {args.filename!r} must contain a YAML mapping of sections.")
# --- dataset
# A single call yields both generators, configured by the `dataset_params` section.
_dataset_kwargs = config['dataset_params']
train_gen, test_gen = genDatasetCelebA(**_dataset_kwargs)
# --- model
# `model_name` selects the constructor from the `gan_models` mapping; the whole
# `model_params` section (model_name included) is forwarded as keyword arguments.
_model_kwargs = config['model_params']
_model_factory = gan_models[_model_kwargs['model_name']]
model = _model_factory(**_model_kwargs)
# --- Optimizer
def _build_optimizer(opt_cfg):
    """Create an optimizer from one entry of the ``opt_params`` config section.

    Args:
        opt_cfg: mapping with a ``name`` key (only ``'Adam'`` is supported)
            and a ``learning_rate`` key.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: if ``name`` is not a supported optimizer, so a typo in the
            config fails here instead of as an unbound variable later.
    """
    name = opt_cfg['name']
    if name == 'Adam':
        return tfk.optimizers.Adam(opt_cfg['learning_rate'])
    raise ValueError(f"Unsupported optimizer {name!r}; supported optimizers: 'Adam'")


_opt_params = config['opt_params']
gen_opt = _build_optimizer(_opt_params['gen_opt'])
disc_opt = _build_optimizer(_opt_params['disc_opt'])
# The encoder optimizer is optional: it is built only when `en_opt` is present
# in `opt_params`; otherwise None is passed on to the trainer.
en_opt = _build_optimizer(_opt_params['en_opt']) if 'en_opt' in _opt_params else None
# --- train
# Training hyper-parameters come from `train_params`; `scale` and `batch_size`
# are taken from the dataset section so the trainer matches the generators.
_train_cfg = config['train_params']
_data_cfg = config['dataset_params']
trainer(
    model,
    train_gen,
    test_gen,
    gen_opt=gen_opt,
    disc_opt=disc_opt,
    en_opt=en_opt,
    epochs=_train_cfg['epochs'],
    iter_disc=_train_cfg['iter_disc'],
    iter_gen=_train_cfg['iter_gen'],
    save_path=_train_cfg['save_path'],
    save_model_path=_train_cfg['save_model_path'],
    scale=_data_cfg['scale'],
    batch_size=_data_cfg['batch_size'],
)
# --- save model
# Accept common spellings of "yes" (case-insensitive) so e.g. `-s True` or
# `-s 1` does not silently skip saving; 'true' behaves as before.
if args.Is_save_model.strip().lower() in ('true', '1', 'yes', 'y'):
    # save_model_path is used as a plain string prefix, so it presumably ends
    # with a path separator (e.g. 'checkpoints/') — confirm against configs.
    path = config['train_params']['save_model_path'] + model.model_name + '.h5'
    model.save_weights(path)