Skip to content

Commit

Permalink
fix bugs onsite
Browse files Browse the repository at this point in the history
  • Loading branch information
cl0ver012 committed May 23, 2018
1 parent b53da7d commit 1fbeff7
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 15 deletions.
2 changes: 1 addition & 1 deletion config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
img_rows, img_cols = 320, 320
img_rows_half, img_cols_half = 160, 160
channel = 4
batch_size = 32
batch_size = 24
epochs = 1000
patience = 50
num_samples = 43100
Expand Down
10 changes: 6 additions & 4 deletions model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import keras.backend as K
from keras.layers import Input, Conv2D, UpSampling2D, BatchNormalization, ZeroPadding2D, MaxPooling2D, Concatenate, Lambda
from keras.models import Model
from keras.utils import plot_model
import tensorflow as tf
from keras.layers import Input, Conv2D, UpSampling2D, BatchNormalization, ZeroPadding2D, MaxPooling2D, Concatenate, \
Reshape, Lambda
from keras.models import Model
from keras.utils import multi_gpu_model
from keras.utils import plot_model

from custom_layers.unpooling_layer import Unpooling


Expand Down Expand Up @@ -60,7 +62,7 @@ def build_encoder_decoder():
bias_initializer='zeros')(x)
x = BatchNormalization()(x)
x = UpSampling2D(size=(2, 2))(x)
the_shape = K.int_shape(orig_5)
the_shape = K.int_shape(orig_5)
shape = (1, the_shape[1], the_shape[2], the_shape[3])
origReshaped = Reshape(shape)(orig_5)
# print('origReshaped.shape: ' + str(K.int_shape(origReshaped)))
Expand Down
25 changes: 15 additions & 10 deletions train_final.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import argparse

import keras
import tensorflow as tf
from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
Expand All @@ -10,6 +12,11 @@

if __name__ == '__main__':
checkpoint_models_path = 'models/'
# Parse arguments
ap = argparse.ArgumentParser()
ap.add_argument("-p", "--pretrained", help="path to save pretrained model files")
args = vars(ap.parse_args())
pretrained_path = args["pretrained"]

# Callbacks
tensor_board = keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=True, write_images=True)
Expand All @@ -28,23 +35,21 @@ def on_epoch_end(self, epoch, logs=None):
fmt = checkpoint_models_path + 'model.%02d-%.4f.hdf5'
self.model_to_save.save(fmt % (epoch, logs['val_loss']))


pretrained_path = 'models/final.61-0.0459.hdf5'
num_gpu = len(get_available_gpus())
if num_gpu >= 2:
with tf.device("/cpu:0"):
# Load our model, added support for Multi-GPUs
encoder_decoder = build_encoder_decoder()
final = build_refinement(encoder_decoder)
final.load_weights(pretrained_path)
model = build_encoder_decoder()
model = build_refinement(model)
model.load_weights(pretrained_path)

final = multi_gpu_model(final, gpus=num_gpu)
final = multi_gpu_model(model, gpus=num_gpu)
# rewrite the callback: saving through the original model and not the multi-gpu model.
model_checkpoint = MyCbk(final)
model_checkpoint = MyCbk(model)
else:
encoder_decoder = build_encoder_decoder()
final = build_refinement(encoder_decoder)
final.load_weights(pretrained_path)
model = build_encoder_decoder()
final = build_refinement(model)
final.load_weights(pretrained_path)

# finetune the whole network together.
for layer in final.layers:
Expand Down

0 comments on commit 1fbeff7

Please sign in to comment.