import os
import numpy as np
import tensorflow as tf
from modules import get_next_batch, duration_model
from hyperparams import hyperparams
hp = hyperparams()
class Duration_Graph:
    def __init__(self, mode='train'):
        self.mode = mode.lower()
        if self.mode not in ['train', 'test', 'infer']:
            raise ValueError(f'#-------------------------Unsupported mode {mode}. Please check.-------------------------#')
        self.out_dim = hp.DUR_OUT_DIM
        self.scope_name = 'duration_net'
        self.reuse = tf.AUTO_REUSE
        self.dir = hp.DUR_TF_DIR
        self.gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
        if self.mode == 'train':
            self.is_training = True
            self.build_model()
            self.show_info()
            self.saver = tf.train.Saver()
            self.sess = tf.Session(config=tf.ConfigProto(gpu_options=self.gpu_options, allow_soft_placement=True))
            self.writer = tf.summary.FileWriter(hp.DUR_LOG_DIR, self.sess.graph)
            tf.summary.scalar('{}/loss'.format(self.mode), self.loss)
            tf.summary.scalar('{}/lr'.format(self.mode), self.lr)
            self.merged = tf.summary.merge_all()
            print(f'#-------------------------Trying to load trained model from {hp.DUR_MODEL_DIR} ...-------------------------#')
            if tf.train.latest_checkpoint(hp.DUR_MODEL_DIR) is not None:
                self.saver.restore(self.sess, tf.train.latest_checkpoint(hp.DUR_MODEL_DIR))
                print('#-------------------------Successfully loaded.-------------------------#')
            else:
                # No checkpoint found: start training from a fresh initialization.
                self.sess.run(tf.global_variables_initializer())
                print(f'#-------------------------No trained model found in {hp.DUR_MODEL_DIR}. Starting training from initializer ...-------------------------#')
        elif self.mode == 'test':
            self.is_training = False
            self.build_model()
            self.saver = tf.train.Saver()
            self.sess = tf.Session(config=tf.ConfigProto(gpu_options=self.gpu_options, allow_soft_placement=True))
            print(f'#-------------------------Trying to load trained model from {hp.DUR_MODEL_DIR} ...-------------------------#')
            if tf.train.latest_checkpoint(hp.DUR_MODEL_DIR) is not None:
                self.saver.restore(self.sess, tf.train.latest_checkpoint(hp.DUR_MODEL_DIR))
                print('#-------------------------Successfully loaded.-------------------------#')
            else:
                raise Exception(f'#-------------------------No trained model found in {hp.DUR_MODEL_DIR}. Please check.-------------------------#')
        else:  # 'infer'
            self.is_training = False
            # Inference builds into its own graph so it can coexist with other graphs.
            self.graph = tf.Graph()
            with self.graph.as_default():
                self.build_model()
                self.saver = tf.train.Saver()
            self.sess = tf.Session(graph=self.graph, config=tf.ConfigProto(gpu_options=self.gpu_options, allow_soft_placement=True))
            print(f'#-------------------------Trying to load trained model from {hp.DUR_MODEL_DIR} ...-------------------------#')
            if tf.train.latest_checkpoint(hp.DUR_MODEL_DIR) is not None:
                self.saver.restore(self.sess, tf.train.latest_checkpoint(hp.DUR_MODEL_DIR))
                print('#-------------------------Successfully loaded.-------------------------#')
            else:
                raise Exception(f'#-------------------------No trained model found in {hp.DUR_MODEL_DIR}. Please check.-------------------------#')
    def build_model(self):
        if self.mode in ['train', 'test']:
            # Train/test pull (input, target) batches from the TFRecord pipeline.
            self.x, self.y = get_next_batch(self.dir, mode=self.mode, type='duration')
        else:
            # Inference feeds linguistic label features through a placeholder.
            self.x = tf.placeholder(shape=[None, None, hp.DUR_IN_DIM], dtype=tf.float32, name='dur_lab')
        self.y_hat = duration_model(self.x, size=self.out_dim, scope=self.scope_name, reuse=self.reuse)
        if self.mode in ['train', 'test']:
            self.global_steps = tf.get_variable('global_steps', initializer=0, dtype=tf.int32, trainable=False)
            self.lr = tf.train.exponential_decay(hp.DUR_LR,
                                                 global_step=self.global_steps,
                                                 decay_steps=hp.DUR_LR_DECAY_STEPS,
                                                 decay_rate=hp.DUR_LR_DECAY_RATE)
            self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr)
            # Mean squared error between predicted and target durations.
            self.loss = tf.reduce_mean(tf.square(self.y_hat - self.y))
            self.train_op = self.optimizer.minimize(self.loss, global_step=self.global_steps)
    def show_info(self):
        # Count all trainable parameters as a quick sanity check of model size.
        self.t_vars = tf.trainable_variables()
        self.num_paras = 0
        for var in self.t_vars:
            var_shape = var.get_shape().as_list()
            self.num_paras += np.prod(var_shape)
        print(f'#-------------------------Duration model total number of trainable parameters: {self.num_paras}-------------------------#')
    def train(self):
        _, y_hat, loss, summary, steps = self.sess.run((self.train_op, self.y_hat, self.loss, self.merged, self.global_steps))
        self.writer.add_summary(summary, steps)
        # Periodically checkpoint, embedding the step count and current loss in the filename.
        if steps % (hp.DUR_PER_STEPS + 1) == 0:
            self.saver.save(self.sess, os.path.join(hp.DUR_MODEL_DIR, f'dur_model_{steps}steps_{loss:.2f}loss'))
        return y_hat, loss, steps
    def test(self):
        duration_time, loss, steps = self.sess.run((self.y_hat, self.loss, self.global_steps))
        return duration_time, loss, steps
    def infer(self, dur_lab):
        # Variables were restored from the checkpoint in __init__; re-running an
        # initializer here would overwrite the trained weights, so none is run.
        duration_time = self.sess.run(self.y_hat, feed_dict={self.x: dur_lab})
        return duration_time
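
# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original file: how the train and
# infer modes might be driven. The 2000-step budget and the dummy input
# shape below are illustrative assumptions, not values defined by this
# repository; only hp.DUR_IN_DIM and hp.DUR_MODEL_DIR come from hyperparams.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # Training: each train() call runs one optimizer step and writes summaries.
    dur_graph = Duration_Graph(mode='train')
    for _ in range(2000):  # assumed step budget
        _, loss, steps = dur_graph.train()
        if steps % 100 == 0:
            print(f'step {steps}: loss = {loss:.4f}')

    # Inference (assumes a checkpoint now exists in hp.DUR_MODEL_DIR):
    # infer_graph = Duration_Graph(mode='infer')
    # dummy_lab = np.zeros((1, 50, hp.DUR_IN_DIM), dtype=np.float32)  # [batch, time, feat]
    # durations = infer_graph.infer(dummy_lab)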