Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
cl0ver012 committed May 28, 2018
1 parent ed98626 commit d3034bf
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions train_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from config import patience, batch_size, epochs, num_train_samples, num_valid_samples
from data_generator import train_gen, valid_gen
from model import build_encoder_decoder
from utils import custom_loss_wrapper, get_available_cpus, get_available_gpus
from utils import overall_loss, get_available_cpus, get_available_gpus

if __name__ == '__main__':
# Parse arguments
Expand Down Expand Up @@ -43,28 +43,28 @@ def on_epoch_end(self, epoch, logs=None):

# Load our model, added support for Multi-GPUs
num_gpu = len(get_available_gpus())
# if num_gpu >= 2:
# with tf.device("/cpu:0"):
# if pretrained_path is not None:
# model = create_model()
# model.load_weights(pretrained_path)
# else:
# model = create_model()
# migrate.migrate_model(model)
#
# new_model = multi_gpu_model(model, gpus=num_gpu)
# # rewrite the callback: saving through the original model and not the multi-gpu model.
# model_checkpoint = MyCbk(model)
# else:
if pretrained_path is not None:
new_model = build_encoder_decoder()
new_model.load_weights(pretrained_path)
if num_gpu >= 2:
with tf.device("/cpu:0"):
if pretrained_path is not None:
model = build_encoder_decoder()
model.load_weights(pretrained_path)
else:
model = build_encoder_decoder()
migrate.migrate_model(model)

new_model = multi_gpu_model(model, gpus=num_gpu)
# rewrite the callback: saving through the original model and not the multi-gpu model.
model_checkpoint = MyCbk(model)
else:
new_model = build_encoder_decoder()
migrate.migrate_model(new_model)
if pretrained_path is not None:
new_model = build_encoder_decoder()
new_model.load_weights(pretrained_path)
else:
new_model = build_encoder_decoder()
migrate.migrate_model(new_model)

# sgd = SGD(lr=1e-3, decay=1e-6, momentum=0.9, nesterov=True)
new_model.compile(optimizer='nadam', loss=custom_loss_wrapper(new_model.input))
new_model.compile(optimizer='nadam', loss=overall_loss)

print(new_model.summary())

Expand Down

0 comments on commit d3034bf

Please sign in to comment.