Skip to content

Commit

Permalink
Merge branch 'master' of git+ssh://github.com/foamliu/Deep-Image-Matting
Browse files Browse the repository at this point in the history
  • Loading branch information
cl0ver012 committed May 18, 2018
2 parents 45be108 + 79f694d commit 661d29e
Show file tree
Hide file tree
Showing 45 changed files with 46 additions and 22 deletions.
2 changes: 1 addition & 1 deletion demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
img_rows, img_cols = 320, 320
channel = 4

model_weights_path = 'models/model.62-0.0524.hdf5'
model_weights_path = 'models/model.98-0.0459.hdf5'
model = build_encoder_decoder()
model.load_weights(model_weights_path)
print(model.summary())
Expand Down
Binary file modified images/0_alpha.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/0_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/0_out.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/0_trimap.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/1_alpha.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/1_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/1_out.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/1_trimap.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/2_alpha.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/2_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/2_out.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/2_trimap.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/3_alpha.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/3_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/3_out.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/3_trimap.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/4_alpha.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/4_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/4_out.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/4_trimap.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/5_alpha.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/5_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/5_out.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/5_trimap.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/6_alpha.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified images/6_image.png
Binary file modified images/6_out.png
Binary file modified images/6_trimap.png
Binary file modified images/7_alpha.png
Binary file modified images/7_image.png
Binary file modified images/7_out.png
Binary file modified images/7_trimap.png
Binary file modified images/8_alpha.png
Binary file modified images/8_image.png
Binary file modified images/8_out.png
Binary file modified images/8_trimap.png
Binary file modified images/9_alpha.png
Binary file modified images/9_image.png
Binary file modified images/9_out.png
Binary file modified images/9_trimap.png
7 changes: 4 additions & 3 deletions model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import keras.backend as K
from keras.layers import Input, Conv2D, UpSampling2D, BatchNormalization, ZeroPadding2D, MaxPooling2D, Concatenate
from keras.layers import Input, Conv2D, UpSampling2D, BatchNormalization, ZeroPadding2D, MaxPooling2D, Concatenate, Lambda
from keras.models import Model
from keras.utils import plot_model

Expand Down Expand Up @@ -99,8 +99,9 @@ def build_encoder_decoder():
def build_refinement(encoder_decoder):
input_tensor = encoder_decoder.input

x = encoder_decoder.output
x = Concatenate(axis=3)([input_tensor[:, :, :, 0:3], x])
input = Lambda(lambda i: i[:, :, :, 0:3])(input_tensor)

x = Concatenate(axis=3)([input, encoder_decoder.output])
x = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal',
bias_initializer='zeros')(x)
x = BatchNormalization()(x)
Expand Down
32 changes: 16 additions & 16 deletions train_final.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,23 @@

# Callbacks
tensor_board = keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=True, write_images=True)
model_names = checkpoint_models_path + 'refinement.{epoch:02d}-{val_loss:.4f}.hdf5'
model_names = checkpoint_models_path + 'final.{epoch:02d}-{val_loss:.4f}.hdf5'
model_checkpoint = ModelCheckpoint(model_names, monitor='val_loss', verbose=1, save_best_only=True)
early_stop = EarlyStopping('val_loss', patience=patience)
reduce_lr = ReduceLROnPlateau('val_loss', factor=0.1, patience=int(patience / 4), verbose=1)

pretrained_path = 'models/'
encoder_decoder = build_encoder_decoder()
refinement = build_refinement(encoder_decoder)
refinement.load_weights(pretrained_path)
final = build_refinement(encoder_decoder)
final.load_weights(pretrained_path)
# finetune the whole network together.
for layer in refinement.layers:
for layer in final.layers:
layer.trainable = True

sgd = SGD(lr=1e-5, decay=1e-6, momentum=0.9, nesterov=True)
refinement.compile(optimizer=sgd, loss=custom_loss_wrapper(refinement.input))
final.compile(optimizer=sgd, loss=custom_loss_wrapper(final.input))

print(refinement.summary())
print(final.summary())

# Summarize then go!
num_cpu = get_available_cpus()
Expand All @@ -38,13 +38,13 @@
callbacks = [tensor_board, model_checkpoint, early_stop, reduce_lr]

# Start Fine-tuning
refinement.fit_generator(train_gen(),
steps_per_epoch=num_train_samples // batch_size,
validation_data=valid_gen(),
validation_steps=num_valid_samples // batch_size,
epochs=epochs,
verbose=1,
callbacks=callbacks,
use_multiprocessing=True,
workers=workers
)
final.fit_generator(train_gen(),
steps_per_epoch=num_train_samples // batch_size,
validation_data=valid_gen(),
validation_steps=num_valid_samples // batch_size,
epochs=epochs,
verbose=1,
callbacks=callbacks,
use_multiprocessing=True,
workers=workers
)
2 changes: 1 addition & 1 deletion train_refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
early_stop = EarlyStopping('val_loss', patience=patience)
reduce_lr = ReduceLROnPlateau('val_loss', factor=0.1, patience=int(patience / 4), verbose=1)

pretrained_path = 'models/'
pretrained_path = 'models/model.98-0.0459.hdf5'
encoder_decoder = build_encoder_decoder()
encoder_decoder.load_weights(pretrained_path)
# fix encoder-decoder part parameters and then update the refinement part.
Expand Down
25 changes: 24 additions & 1 deletion unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@

import cv2 as cv
import numpy as np

import os
from config import unknown
from data_generator import generate_trimap
from data_generator import get_alpha_test
from data_generator import random_choice
from utils import safe_crop

Expand Down Expand Up @@ -64,6 +65,28 @@ def test_different_sizes(self):
crop_size = random.choice(different_sizes)
print('crop_size=' + str(crop_size))

def test_resize(self):
with open('Combined_Dataset/Test_set/test_bg_names.txt') as f:
bg_test_files = f.read().splitlines()
name = '35_716.png'
filename = os.path.join('merged_test', name)
image = cv.imread(filename)
bg_h, bg_w = image.shape[:2]
a = get_alpha_test(name)
a_h, a_w = a.shape[:2]
alpha = np.zeros((bg_h, bg_w), np.float32)
alpha[0:a_h, 0:a_w] = a
trimap = generate_trimap(alpha)
# 剪切尺寸 320:640:480 = 3:1:1
crop_size = (480, 480)
x, y = random_choice(trimap, crop_size)
image = safe_crop(image, x, y, crop_size)
trimap = safe_crop(trimap, x, y, crop_size)
alpha = safe_crop(alpha, x, y, crop_size)
cv.imwrite('temp/test_resize_image.png', image)
cv.imwrite('temp/test_resize_trimap.png', trimap)
cv.imwrite('temp/test_resize_alpha.png', alpha)


if __name__ == '__main__':
unittest.main()

0 comments on commit 661d29e

Please sign in to comment.