diff --git a/demo.py b/demo.py index 6fa32ea..9396a74 100644 --- a/demo.py +++ b/demo.py @@ -6,7 +6,7 @@ import numpy as np from data_generator import generate_trimap, random_choice, get_alpha_test -from model import build_encoder_decoder_net +from model import build_encoder_decoder from utils import get_final_output, safe_crop if __name__ == '__main__': @@ -14,7 +14,7 @@ channel = 4 model_weights_path = 'models/model.62-0.0524.hdf5' - model = build_encoder_decoder_net() + model = build_encoder_decoder() model.load_weights(model_weights_path) print(model.summary()) diff --git a/migrate.py b/migrate.py index 40faf4a..71eeb3c 100644 --- a/migrate.py +++ b/migrate.py @@ -2,7 +2,7 @@ import numpy as np from config import channel -from model import build_encoder_decoder_net +from model import build_encoder_decoder from vgg16 import vgg16_model @@ -43,7 +43,7 @@ def migrate_model(new_model): if __name__ == '__main__': - model = build_encoder_decoder_net() + model = build_encoder_decoder() migrate_model(model) print(model.summary()) model.save_weights('models/model_weights.h5') diff --git a/model.py b/model.py index da47e06..a396263 100644 --- a/model.py +++ b/model.py @@ -6,7 +6,7 @@ from custom_layers.unpooling_layer import Unpooling -def build_encoder_decoder_net(): +def build_encoder_decoder(): # Encoder input_tensor = Input(shape=(320, 320, 4)) x = ZeroPadding2D((1, 1))(input_tensor) @@ -96,7 +96,7 @@ def build_encoder_decoder_net(): return model -def build_refinement_net(encoder_decoder): +def build_refinement(encoder_decoder): input_tensor = encoder_decoder.input x = encoder_decoder.output x = Concatenate(axis=1)([input_tensor, x]) @@ -111,12 +111,12 @@ def build_refinement_net(encoder_decoder): if __name__ == '__main__': - encoder_decoder = build_encoder_decoder_net(320, 320, 4) + encoder_decoder = build_encoder_decoder() # input_layer = model.get_layer('input') print(encoder_decoder.summary()) plot_model(encoder_decoder, to_file='encoder_decoder.svg', show_layer_names=True, show_shapes=True) - refinement = build_refinement_net(encoder_decoder) + refinement = build_refinement(encoder_decoder) print(refinement.summary()) plot_model(refinement, to_file='refinement.svg', show_layer_names=True, show_shapes=True) diff --git a/plot_model.py b/plot_model.py index 414c679..8f5fd52 100644 --- a/plot_model.py +++ b/plot_model.py @@ -1,9 +1,9 @@ # dependency: pip install pydot & brew install graphviz -from model import build_encoder_decoder_net +from model import build_encoder_decoder from keras.utils import plot_model if __name__ == '__main__': img_rows, img_cols = 320, 320 channel = 3 - model = build_encoder_decoder_net(img_rows, img_cols, channel) + model = build_encoder_decoder(img_rows, img_cols, channel) plot_model(model, to_file='model.svg', show_layer_names=True, show_shapes=True) diff --git a/test.py b/test.py index 2bbc17c..558d949 100644 --- a/test.py +++ b/test.py @@ -3,7 +3,7 @@ import cv2 as cv import numpy as np -from model import build_encoder_decoder_net +from model import build_encoder_decoder # python test.py -i "images/image.png" -t "images/trimap.png" if __name__ == '__main__': @@ -11,7 +11,7 @@ channel = 4 model_weights_path = 'models/model.35-0.03.hdf5' - model = build_encoder_decoder_net() + model = build_encoder_decoder() model.load_weights(model_weights_path) print(model.summary()) diff --git a/train.py b/train.py index a7836f2..b33395e 100644 --- a/train.py +++ b/train.py @@ -8,7 +8,7 @@ import migrate from config import patience, batch_size, epochs, num_train_samples, num_valid_samples from data_generator import train_gen, valid_gen -from model import build_encoder_decoder_net +from model import build_encoder_decoder from utils import custom_loss_wrapper, get_available_cpus, get_available_gpus if __name__ == '__main__': @@ -57,10 +57,10 @@ def on_epoch_end(self, epoch, logs=None): # model_checkpoint = MyCbk(model) # else: if pretrained_path is not None: - new_model = build_encoder_decoder_net() + new_model = build_encoder_decoder() new_model.load_weights(pretrained_path) else: - new_model = build_encoder_decoder_net() + new_model = build_encoder_decoder() migrate.migrate_model(new_model) # sgd = SGD(lr=1e-3, decay=1e-6, momentum=0.9, nesterov=True)