diff --git a/.gitignore b/.gitignore index 82fc706..cc67c38 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ fg_test/ mask_test/ merged_test/ temp/ +.cache/ VOC2008test.tar VOCdevkit/ VOCtest_06-Nov-2007.tar diff --git a/config.py b/config.py index 45f205f..e4b3981 100644 --- a/config.py +++ b/config.py @@ -1,7 +1,7 @@ img_rows, img_cols = 320, 320 img_rows_half, img_cols_half = 160, 160 channel = 4 -batch_size = 24 +batch_size = 20 epochs = 1000 patience = 50 num_samples = 43100 diff --git a/data_generator.py b/data_generator.py index ae20681..e0c9a03 100644 --- a/data_generator.py +++ b/data_generator.py @@ -1,5 +1,6 @@ import math import os +import random from random import shuffle import cv2 as cv @@ -7,7 +8,7 @@ from keras.utils import Sequence from config import batch_size -from config import fg_path, bg_path, a_path, out_path +from config import fg_path, bg_path, a_path from config import img_cols, img_rows from config import unknown_code from utils import safe_crop @@ -42,14 +43,18 @@ def get_alpha_test(name): def composite4(fg, bg, a, w, h): fg = np.array(fg, np.float32) bg_h, bg_w = bg.shape[:2] - x = np.random.randint(0, bg_w - w) - y = np.random.randint(0, bg_h - h) + x = 0 + if bg_w > w: + x = np.random.randint(0, bg_w - w) + y = 0 + if bg_h > h: + y = np.random.randint(0, bg_h - h) bg = np.array(bg[y:y + h, x:x + w], np.float32) alpha = np.zeros((h, w, 1), np.float32) alpha[:, :, 0] = a / 255. im = alpha * fg + (1 - alpha) * bg im = im.astype(np.uint8) - return im, alpha, fg, bg + return im, a, fg, bg def process(im_name, bg_name): @@ -68,9 +73,9 @@ def process(im_name, bg_name): def generate_trimap(alpha): - fg = np.equal(alpha, 255).astype(np.float32) + fg = np.array(np.equal(alpha, 255).astype(np.float32)) fg = cv.erode(fg, kernel, iterations=np.random.randint(1, 3)) - unknown = np.not_equal(alpha, 0).astype(np.float32) + unknown = np.array(np.not_equal(alpha, 0).astype(np.float32)) unknown = cv.dilate(unknown, kernel, iterations=np.random.randint(1, 20)) trimap = fg * 255 + (unknown - fg) * 128 return trimap.astype(np.uint8) @@ -109,7 +114,7 @@ def __getitem__(self, idx): length = min(batch_size, (len(self.names) - i)) batch_x = np.empty((length, img_rows, img_cols, 4), dtype=np.float32) - batch_y = np.empty((length, img_rows, img_cols, 2), dtype=np.float32) + batch_y = np.empty((length, img_rows, img_cols, 11), dtype=np.float32) for i_batch in range(length): name = self.names[i] @@ -121,7 +126,7 @@ def __getitem__(self, idx): # crop size 320:640:480 = 1:1:1 different_sizes = [(320, 320), (480, 480), (640, 640)] - crop_size = np.random.choice(different_sizes) + crop_size = random.choice(different_sizes) trimap = generate_trimap(alpha) x, y = random_choice(trimap, crop_size) diff --git a/segnet.py b/segnet.py index 1738075..a55b706 100644 --- a/segnet.py +++ b/segnet.py @@ -10,12 +10,11 @@ def build_encoder_decoder(): - num_labels = 8 kernel = 3 # Encoder # - input_tensor = Input(shape=(320, 320, 3)) + input_tensor = Input(shape=(320, 320, 4)) x = ZeroPadding2D((1, 1))(input_tensor) x = Conv2D(64, (kernel, kernel), activation='relu', name='conv1_1')(x) x = ZeroPadding2D((1, 1))(x) @@ -66,15 +65,15 @@ def build_encoder_decoder(): xReshaped = Reshape(shape)(x) together = Concatenate(axis=1)([origReshaped, xReshaped]) x = Unpooling()(together) - x = Conv2D(512, (kernel, kernel), activation='relu', padding='same', name='deconv5_1', + x = Conv2D(512, (5, 5), activation='relu', padding='same', name='deconv5_1', kernel_initializer='he_normal', bias_initializer='zeros')(x) x = BatchNormalization()(x) - x = Conv2D(512, (kernel, kernel), activation='relu', padding='same', name='deconv5_2', + x = Conv2D(512, (5, 5), activation='relu', padding='same', name='deconv5_2', kernel_initializer='he_normal', bias_initializer='zeros')(x) x = BatchNormalization()(x) - x = Conv2D(512, (kernel, kernel), activation='relu', padding='same', name='deconv5_3', + x = Conv2D(512, (5, 5), activation='relu', padding='same', name='deconv5_3', kernel_initializer='he_normal', bias_initializer='zeros')(x) x = BatchNormalization()(x) @@ -86,15 +85,15 @@ def build_encoder_decoder(): xReshaped = Reshape(shape)(x) together = Concatenate(axis=1)([origReshaped, xReshaped]) x = Unpooling()(together) - x = Conv2D(256, (kernel, kernel), activation='relu', padding='same', name='deconv4_1', + x = Conv2D(256, (5, 5), activation='relu', padding='same', name='deconv4_1', kernel_initializer='he_normal', bias_initializer='zeros')(x) x = BatchNormalization()(x) - x = Conv2D(256, (kernel, kernel), activation='relu', padding='same', name='deconv4_2', + x = Conv2D(256, (5, 5), activation='relu', padding='same', name='deconv4_2', kernel_initializer='he_normal', bias_initializer='zeros')(x) x = BatchNormalization()(x) - x = Conv2D(256, (kernel, kernel), activation='relu', padding='same', name='deconv4_3', + x = Conv2D(256, (5, 5), activation='relu', padding='same', name='deconv4_3', kernel_initializer='he_normal', bias_initializer='zeros')(x) x = BatchNormalization()(x) @@ -106,15 +105,15 @@ def build_encoder_decoder(): xReshaped = Reshape(shape)(x) together = Concatenate(axis=1)([origReshaped, xReshaped]) x = Unpooling()(together) - x = Conv2D(128, (kernel, kernel), activation='relu', padding='same', name='deconv3_1', + x = Conv2D(128, (5, 5), activation='relu', padding='same', name='deconv3_1', kernel_initializer='he_normal', bias_initializer='zeros')(x) x = BatchNormalization()(x) - x = Conv2D(128, (kernel, kernel), activation='relu', padding='same', name='deconv3_2', + x = Conv2D(128, (5, 5), activation='relu', padding='same', name='deconv3_2', kernel_initializer='he_normal', bias_initializer='zeros')(x) x = BatchNormalization()(x) - x = Conv2D(128, (kernel, kernel), activation='relu', padding='same', name='deconv3_3', + x = Conv2D(128, (5, 5), activation='relu', padding='same', name='deconv3_3', kernel_initializer='he_normal', bias_initializer='zeros')(x) x = BatchNormalization()(x) @@ -126,11 +125,11 @@ def build_encoder_decoder(): xReshaped = Reshape(shape)(x) together = Concatenate(axis=1)([origReshaped, xReshaped]) x = Unpooling()(together) - x = Conv2D(64, (kernel, kernel), activation='relu', padding='same', name='deconv2_1', + x = Conv2D(64, (5, 5), activation='relu', padding='same', name='deconv2_1', kernel_initializer='he_normal', bias_initializer='zeros')(x) x = BatchNormalization()(x) - x = Conv2D(64, (kernel, kernel), activation='relu', padding='same', name='deconv2_2', + x = Conv2D(64, (5, 5), activation='relu', padding='same', name='deconv2_2', kernel_initializer='he_normal', bias_initializer='zeros')(x) x = BatchNormalization()(x) @@ -142,16 +141,16 @@ def build_encoder_decoder(): xReshaped = Reshape(shape)(x) together = Concatenate(axis=1)([origReshaped, xReshaped]) x = Unpooling()(together) - x = Conv2D(64, (kernel, kernel), activation='relu', padding='same', name='deconv1_1', + x = Conv2D(64, (5, 5), activation='relu', padding='same', name='deconv1_1', kernel_initializer='he_normal', bias_initializer='zeros')(x) x = BatchNormalization()(x) - x = Conv2D(64, (kernel, kernel), activation='relu', padding='same', name='deconv1_2', + x = Conv2D(64, (5, 5), activation='relu', padding='same', name='deconv1_2', kernel_initializer='he_normal', bias_initializer='zeros')(x) x = BatchNormalization()(x) - x = Conv2D(num_labels, (1, 1), activation='softmax', padding='valid', name='pred', kernel_initializer='he_normal', + x = Conv2D(1, (5, 5), activation='sigmoid', padding='same', name='pred', kernel_initializer='he_normal', bias_initializer='zeros')(x) model = Model(inputs=input_tensor, outputs=x) diff --git a/unit_tests.py b/unit_tests.py index d01a056..0a7a154 100644 --- a/unit_tests.py +++ b/unit_tests.py @@ -6,7 +6,7 @@ import os from config import unknown_code from data_generator import generate_trimap -from data_generator import get_alpha_test +from data_generator import get_alpha from data_generator import random_choice from utils import safe_crop @@ -14,9 +14,10 @@ class TestStringMethods(unittest.TestCase): def test_generate_trimap(self): - image = cv.imread('fg_test/cat-1288531_1920.png') - alpha = cv.imread('mask_test/cat-1288531_1920.png', 0) + image = cv.imread('fg/1-1252426161dfXY.jpg') + alpha = cv.imread('mask/1-1252426161dfXY.jpg', 0) trimap = generate_trimap(alpha) + self.assertEqual(trimap.shape, (615, 410)) # ensure np.where works as expected. count = 0 @@ -37,7 +38,7 @@ def test_generate_trimap(self): self.assertEqual(trimap[center_x, center_y], unknown_code) x, y = random_choice(trimap) - print(x, y) + # print(x, y) image = safe_crop(image, x, y) trimap = safe_crop(trimap, x, y) alpha = safe_crop(alpha, x, y) @@ -46,8 +47,9 @@ def test_generate_trimap(self): cv.imwrite('temp/test_generate_trimap_alpha.png', alpha) def test_flip(self): - image = cv.imread('fg_test/cat-1288531_1920.png') - alpha = cv.imread('mask_test/cat-1288531_1920.png', 0) + image = cv.imread('fg/1-1252426161dfXY.jpg') + # print(image.shape) + alpha = cv.imread('mask/1-1252426161dfXY.jpg', 0) trimap = generate_trimap(alpha) x, y = random_choice(trimap) image = safe_crop(image, x, y) @@ -63,16 +65,14 @@ def test_flip(self): def test_different_sizes(self): different_sizes = [(320, 320), (320, 320), (320, 320), (480, 480), (640, 640)] crop_size = random.choice(different_sizes) - print('crop_size=' + str(crop_size)) + # print('crop_size=' + str(crop_size)) def test_resize(self): - with open('Combined_Dataset/Test_set/test_bg_names.txt') as f: - bg_test_files = f.read().splitlines() - name = '35_716.png' - filename = os.path.join('merged_test', name) + name = '0_0.png' + filename = os.path.join('merged', name) image = cv.imread(filename) bg_h, bg_w = image.shape[:2] - a = get_alpha_test(name) + a = get_alpha(name) a_h, a_w = a.shape[:2] alpha = np.zeros((bg_h, bg_w), np.float32) alpha[0:a_h, 0:a_w] = a diff --git a/utils.py b/utils.py index a588943..53e4b60 100644 --- a/utils.py +++ b/utils.py @@ -23,7 +23,7 @@ def overall_loss(y_true, y_pred): # absolute values, we use the following loss function to approximate it. def alpha_prediction_loss(y_true, y_pred): mask = y_true[:, :, :, 1] - diff = y_pred - y_true[:, :, :, 0] + diff = y_pred[:, :, :, 0] - y_true[:, :, :, 0] diff = diff * mask num_pixels = K.sum(mask) return K.sum(K.sqrt(K.square(diff) + epsilon_sqr)) / (num_pixels + epsilon) @@ -34,6 +34,7 @@ def alpha_prediction_loss(y_true, y_pred): # alpha mattes. def compositional_loss(y_true, y_pred): mask = y_true[:, :, :, 1] + mask = K.reshape(mask, (-1, img_rows, img_cols, 1)) image = y_true[:, :, :, 2:5] fg = y_true[:, :, :, 5:8] bg = y_true[:, :, :, 8:11] @@ -89,12 +90,14 @@ def get_final_output(out, trimap): def safe_crop(mat, x, y, crop_size=(img_rows, img_cols)): - h, w = crop_size + crop_height, crop_width = crop_size if len(mat.shape) == 2: - ret = np.zeros((h, w), np.float32) + ret = np.zeros((crop_height, crop_width), np.float32) else: - ret = np.zeros((h, w, 3), np.float32) - ret[0:h, 0:w] = mat[y:y + h, x:x + w] + ret = np.zeros((crop_height, crop_width, 3), np.float32) + crop = mat[y:y + crop_height, x:x + crop_width] + h, w = crop.shape[:2] + ret[0:h, 0:w] = crop if crop_size != (img_rows, img_cols): ret = cv.resize(ret, dsize=(img_rows, img_cols), interpolation=cv.INTER_CUBIC) return ret