
Commit

fix bugs
cl0ver012 committed May 25, 2018
1 parent 41e7d5c commit b7ef7a9
Showing 6 changed files with 50 additions and 42 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -14,6 +14,7 @@ fg_test/
mask_test/
merged_test/
temp/
+.cache/
VOC2008test.tar
VOCdevkit/
VOCtest_06-Nov-2007.tar
2 changes: 1 addition & 1 deletion config.py
@@ -1,7 +1,7 @@
img_rows, img_cols = 320, 320
img_rows_half, img_cols_half = 160, 160
channel = 4
-batch_size = 24
+batch_size = 20
epochs = 1000
patience = 50
num_samples = 43100
21 changes: 13 additions & 8 deletions data_generator.py
@@ -1,13 +1,14 @@
import math
import os
+import random
from random import shuffle

import cv2 as cv
import numpy as np
from keras.utils import Sequence

from config import batch_size
-from config import fg_path, bg_path, a_path, out_path
+from config import fg_path, bg_path, a_path
from config import img_cols, img_rows
from config import unknown_code
from utils import safe_crop
@@ -42,14 +43,18 @@ def get_alpha_test(name):
def composite4(fg, bg, a, w, h):
    fg = np.array(fg, np.float32)
    bg_h, bg_w = bg.shape[:2]
-    x = np.random.randint(0, bg_w - w)
-    y = np.random.randint(0, bg_h - h)
+    x = 0
+    if bg_w > w:
+        x = np.random.randint(0, bg_w - w)
+    y = 0
+    if bg_h > h:
+        y = np.random.randint(0, bg_h - h)
    bg = np.array(bg[y:y + h, x:x + w], np.float32)
    alpha = np.zeros((h, w, 1), np.float32)
    alpha[:, :, 0] = a / 255.
    im = alpha * fg + (1 - alpha) * bg
    im = im.astype(np.uint8)
-    return im, alpha, fg, bg
+    return im, a, fg, bg


def process(im_name, bg_name):
@@ -68,9 +73,9 @@ def process(im_name, bg_name):


def generate_trimap(alpha):
-    fg = np.equal(alpha, 255).astype(np.float32)
+    fg = np.array(np.equal(alpha, 255).astype(np.float32))
    fg = cv.erode(fg, kernel, iterations=np.random.randint(1, 3))
-    unknown = np.not_equal(alpha, 0).astype(np.float32)
+    unknown = np.array(np.not_equal(alpha, 0).astype(np.float32))
    unknown = cv.dilate(unknown, kernel, iterations=np.random.randint(1, 20))
    trimap = fg * 255 + (unknown - fg) * 128
    return trimap.astype(np.uint8)
@@ -109,7 +114,7 @@ def __getitem__(self, idx):

        length = min(batch_size, (len(self.names) - i))
        batch_x = np.empty((length, img_rows, img_cols, 4), dtype=np.float32)
-        batch_y = np.empty((length, img_rows, img_cols, 2), dtype=np.float32)
+        batch_y = np.empty((length, img_rows, img_cols, 11), dtype=np.float32)

        for i_batch in range(length):
            name = self.names[i]
@@ -121,7 +126,7 @@ def __getitem__(self, idx):

            # crop size 320:640:480 = 1:1:1
            different_sizes = [(320, 320), (480, 480), (640, 640)]
-            crop_size = np.random.choice(different_sizes)
+            crop_size = random.choice(different_sizes)

            trimap = generate_trimap(alpha)
            x, y = random_choice(trimap, crop_size)
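For context, a minimal, self-contained sketch of the patched composite4 placement logic (NumPy only, toy arrays; process, generate_trimap and the Sequence subclass from data_generator.py are left out): np.random.randint(0, 0) raises ValueError when the background is exactly the crop size, so the offsets now default to 0 and are only randomized when there is room to shift.

import numpy as np

def composite4(fg, bg, a, w, h):
    # Sketch of the patched placement: only draw a random offset when the
    # background is strictly larger than the requested width/height, so the
    # randint range is never empty.
    fg = np.array(fg, np.float32)
    bg_h, bg_w = bg.shape[:2]
    x = 0
    if bg_w > w:
        x = np.random.randint(0, bg_w - w)
    y = 0
    if bg_h > h:
        y = np.random.randint(0, bg_h - h)
    bg = np.array(bg[y:y + h, x:x + w], np.float32)
    alpha = np.zeros((h, w, 1), np.float32)
    alpha[:, :, 0] = a / 255.
    im = alpha * fg + (1 - alpha) * bg
    im = im.astype(np.uint8)
    return im, a, fg, bg

# Toy inputs where foreground and background have the same size; the old
# version raised ValueError here because randint(0, 0) has an empty range.
fg = np.full((320, 320, 3), 200, np.uint8)
bg = np.zeros((320, 320, 3), np.uint8)
a = np.full((320, 320), 255, np.uint8)
im, a, fg, bg = composite4(fg, bg, a, 320, 320)
assert im.shape == (320, 320, 3)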
31 changes: 15 additions & 16 deletions segnet.py
@@ -10,12 +10,11 @@


def build_encoder_decoder():
-    num_labels = 8
    kernel = 3

    # Encoder
    #
-    input_tensor = Input(shape=(320, 320, 3))
+    input_tensor = Input(shape=(320, 320, 4))
    x = ZeroPadding2D((1, 1))(input_tensor)
    x = Conv2D(64, (kernel, kernel), activation='relu', name='conv1_1')(x)
    x = ZeroPadding2D((1, 1))(x)
@@ -66,15 +65,15 @@ def build_encoder_decoder():
    xReshaped = Reshape(shape)(x)
    together = Concatenate(axis=1)([origReshaped, xReshaped])
    x = Unpooling()(together)
-    x = Conv2D(512, (kernel, kernel), activation='relu', padding='same', name='deconv5_1',
+    x = Conv2D(512, (5, 5), activation='relu', padding='same', name='deconv5_1',
               kernel_initializer='he_normal',
               bias_initializer='zeros')(x)
    x = BatchNormalization()(x)
-    x = Conv2D(512, (kernel, kernel), activation='relu', padding='same', name='deconv5_2',
+    x = Conv2D(512, (5, 5), activation='relu', padding='same', name='deconv5_2',
               kernel_initializer='he_normal',
               bias_initializer='zeros')(x)
    x = BatchNormalization()(x)
-    x = Conv2D(512, (kernel, kernel), activation='relu', padding='same', name='deconv5_3',
+    x = Conv2D(512, (5, 5), activation='relu', padding='same', name='deconv5_3',
               kernel_initializer='he_normal',
               bias_initializer='zeros')(x)
    x = BatchNormalization()(x)
@@ -86,15 +85,15 @@ def build_encoder_decoder():
    xReshaped = Reshape(shape)(x)
    together = Concatenate(axis=1)([origReshaped, xReshaped])
    x = Unpooling()(together)
-    x = Conv2D(256, (kernel, kernel), activation='relu', padding='same', name='deconv4_1',
+    x = Conv2D(256, (5, 5), activation='relu', padding='same', name='deconv4_1',
               kernel_initializer='he_normal',
               bias_initializer='zeros')(x)
    x = BatchNormalization()(x)
-    x = Conv2D(256, (kernel, kernel), activation='relu', padding='same', name='deconv4_2',
+    x = Conv2D(256, (5, 5), activation='relu', padding='same', name='deconv4_2',
               kernel_initializer='he_normal',
               bias_initializer='zeros')(x)
    x = BatchNormalization()(x)
-    x = Conv2D(256, (kernel, kernel), activation='relu', padding='same', name='deconv4_3',
+    x = Conv2D(256, (5, 5), activation='relu', padding='same', name='deconv4_3',
               kernel_initializer='he_normal',
               bias_initializer='zeros')(x)
    x = BatchNormalization()(x)
@@ -106,15 +105,15 @@ def build_encoder_decoder():
    xReshaped = Reshape(shape)(x)
    together = Concatenate(axis=1)([origReshaped, xReshaped])
    x = Unpooling()(together)
-    x = Conv2D(128, (kernel, kernel), activation='relu', padding='same', name='deconv3_1',
+    x = Conv2D(128, (5, 5), activation='relu', padding='same', name='deconv3_1',
               kernel_initializer='he_normal',
               bias_initializer='zeros')(x)
    x = BatchNormalization()(x)
-    x = Conv2D(128, (kernel, kernel), activation='relu', padding='same', name='deconv3_2',
+    x = Conv2D(128, (5, 5), activation='relu', padding='same', name='deconv3_2',
               kernel_initializer='he_normal',
               bias_initializer='zeros')(x)
    x = BatchNormalization()(x)
-    x = Conv2D(128, (kernel, kernel), activation='relu', padding='same', name='deconv3_3',
+    x = Conv2D(128, (5, 5), activation='relu', padding='same', name='deconv3_3',
               kernel_initializer='he_normal',
               bias_initializer='zeros')(x)
    x = BatchNormalization()(x)
@@ -126,11 +125,11 @@ def build_encoder_decoder():
    xReshaped = Reshape(shape)(x)
    together = Concatenate(axis=1)([origReshaped, xReshaped])
    x = Unpooling()(together)
-    x = Conv2D(64, (kernel, kernel), activation='relu', padding='same', name='deconv2_1',
+    x = Conv2D(64, (5, 5), activation='relu', padding='same', name='deconv2_1',
               kernel_initializer='he_normal',
               bias_initializer='zeros')(x)
    x = BatchNormalization()(x)
-    x = Conv2D(64, (kernel, kernel), activation='relu', padding='same', name='deconv2_2',
+    x = Conv2D(64, (5, 5), activation='relu', padding='same', name='deconv2_2',
               kernel_initializer='he_normal',
               bias_initializer='zeros')(x)
    x = BatchNormalization()(x)
@@ -142,16 +141,16 @@ def build_encoder_decoder():
    xReshaped = Reshape(shape)(x)
    together = Concatenate(axis=1)([origReshaped, xReshaped])
    x = Unpooling()(together)
-    x = Conv2D(64, (kernel, kernel), activation='relu', padding='same', name='deconv1_1',
+    x = Conv2D(64, (5, 5), activation='relu', padding='same', name='deconv1_1',
               kernel_initializer='he_normal',
               bias_initializer='zeros')(x)
    x = BatchNormalization()(x)
-    x = Conv2D(64, (kernel, kernel), activation='relu', padding='same', name='deconv1_2',
+    x = Conv2D(64, (5, 5), activation='relu', padding='same', name='deconv1_2',
               kernel_initializer='he_normal',
               bias_initializer='zeros')(x)
    x = BatchNormalization()(x)

-    x = Conv2D(num_labels, (1, 1), activation='softmax', padding='valid', name='pred', kernel_initializer='he_normal',
+    x = Conv2D(1, (5, 5), activation='sigmoid', padding='same', name='pred', kernel_initializer='he_normal',
               bias_initializer='zeros')(x)

    model = Model(inputs=input_tensor, outputs=x)
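As a rough sketch of what the segnet.py changes amount to at the network's boundary (assuming Keras 2, and leaving out the VGG-style encoder and the Unpooling decoder in between): the model now takes a 4-channel input, presumably RGB plus the trimap, and ends in a single-channel sigmoid 'pred' layer instead of the old 8-way softmax, i.e. it regresses one alpha value per pixel.

from keras.layers import Conv2D, Input
from keras.models import Model

# Toy stand-in for build_encoder_decoder(): only the revised input/output
# contract is shown, not the real encoder-decoder stack.
input_tensor = Input(shape=(320, 320, 4))          # RGB + trimap channel
x = Conv2D(64, (5, 5), activation='relu', padding='same')(input_tensor)
pred = Conv2D(1, (5, 5), activation='sigmoid', padding='same', name='pred')(x)
model = Model(inputs=input_tensor, outputs=pred)
print(model.output_shape)                          # (None, 320, 320, 1)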
24 changes: 12 additions & 12 deletions unit_tests.py
@@ -6,17 +6,18 @@
import os
from config import unknown_code
from data_generator import generate_trimap
-from data_generator import get_alpha_test
+from data_generator import get_alpha
from data_generator import random_choice
from utils import safe_crop


class TestStringMethods(unittest.TestCase):

    def test_generate_trimap(self):
-        image = cv.imread('fg_test/cat-1288531_1920.png')
-        alpha = cv.imread('mask_test/cat-1288531_1920.png', 0)
+        image = cv.imread('fg/1-1252426161dfXY.jpg')
+        alpha = cv.imread('mask/1-1252426161dfXY.jpg', 0)
        trimap = generate_trimap(alpha)
+        self.assertEqual(trimap.shape, (615, 410))

        # ensure np.where works as expected.
        count = 0
@@ -37,7 +38,7 @@ def test_generate_trimap(self):
        self.assertEqual(trimap[center_x, center_y], unknown_code)

        x, y = random_choice(trimap)
-        print(x, y)
+        # print(x, y)
        image = safe_crop(image, x, y)
        trimap = safe_crop(trimap, x, y)
        alpha = safe_crop(alpha, x, y)
@@ -46,8 +47,9 @@ def test_generate_trimap(self):
        cv.imwrite('temp/test_generate_trimap_alpha.png', alpha)

    def test_flip(self):
-        image = cv.imread('fg_test/cat-1288531_1920.png')
-        alpha = cv.imread('mask_test/cat-1288531_1920.png', 0)
+        image = cv.imread('fg/1-1252426161dfXY.jpg')
+        # print(image.shape)
+        alpha = cv.imread('mask/1-1252426161dfXY.jpg', 0)
        trimap = generate_trimap(alpha)
        x, y = random_choice(trimap)
        image = safe_crop(image, x, y)
@@ -63,16 +65,14 @@ def test_flip(self):
    def test_different_sizes(self):
        different_sizes = [(320, 320), (320, 320), (320, 320), (480, 480), (640, 640)]
        crop_size = random.choice(different_sizes)
-        print('crop_size=' + str(crop_size))
+        # print('crop_size=' + str(crop_size))

    def test_resize(self):
-        with open('Combined_Dataset/Test_set/test_bg_names.txt') as f:
-            bg_test_files = f.read().splitlines()
-        name = '35_716.png'
-        filename = os.path.join('merged_test', name)
+        name = '0_0.png'
+        filename = os.path.join('merged', name)
        image = cv.imread(filename)
        bg_h, bg_w = image.shape[:2]
-        a = get_alpha_test(name)
+        a = get_alpha(name)
        a_h, a_w = a.shape[:2]
        alpha = np.zeros((bg_h, bg_w), np.float32)
        alpha[0:a_h, 0:a_w] = a
13 changes: 8 additions & 5 deletions utils.py
@@ -23,7 +23,7 @@ def overall_loss(y_true, y_pred):
# absolute values, we use the following loss function to approximate it.
def alpha_prediction_loss(y_true, y_pred):
    mask = y_true[:, :, :, 1]
-    diff = y_pred - y_true[:, :, :, 0]
+    diff = y_pred[:, :, :, 0] - y_true[:, :, :, 0]
    diff = diff * mask
    num_pixels = K.sum(mask)
    return K.sum(K.sqrt(K.square(diff) + epsilon_sqr)) / (num_pixels + epsilon)
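A small NumPy illustration of why the y_pred slice matters (the broadcasting rules mirror the Keras backend's): the network output carries a trailing channel axis of size 1, so subtracting a rank-3 tensor from it does not line up.

import numpy as np

y_pred = np.zeros((2, 320, 320, 1), np.float32)   # network output: one alpha channel
y_true = np.zeros((2, 320, 320, 11), np.float32)  # packed ground truth

# Old expression: (2, 320, 320, 1) minus (2, 320, 320) is a broadcast mismatch.
# diff = y_pred - y_true[:, :, :, 0]              # raises ValueError

# Patched expression: both operands are (2, 320, 320).
diff = y_pred[:, :, :, 0] - y_true[:, :, :, 0]
assert diff.shape == (2, 320, 320)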
@@ -34,6 +34,7 @@ def alpha_prediction_loss(y_true, y_pred):
# alpha mattes.
def compositional_loss(y_true, y_pred):
    mask = y_true[:, :, :, 1]
+    mask = K.reshape(mask, (-1, img_rows, img_cols, 1))
    image = y_true[:, :, :, 2:5]
    fg = y_true[:, :, :, 5:8]
    bg = y_true[:, :, :, 8:11]
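The channel slices above, together with the new batch_y shape (length, img_rows, img_cols, 11) in data_generator.py, suggest the ground-truth tensor is packed as alpha, loss mask, merged image, foreground and background. A quick sanity check of that presumed layout:

import numpy as np

alpha = np.zeros((320, 320, 1), np.float32)   # channel 0: ground-truth alpha
mask = np.zeros((320, 320, 1), np.float32)    # channel 1: loss mask (presumably the trimap's unknown region)
image = np.zeros((320, 320, 3), np.float32)   # channels 2:5: merged image
fg = np.zeros((320, 320, 3), np.float32)      # channels 5:8: foreground
bg = np.zeros((320, 320, 3), np.float32)      # channels 8:11: background
y_true = np.concatenate([alpha, mask, image, fg, bg], axis=-1)
assert y_true.shape == (320, 320, 11)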
@@ -89,12 +90,14 @@ def get_final_output(out, trimap):


def safe_crop(mat, x, y, crop_size=(img_rows, img_cols)):
-    h, w = crop_size
+    crop_height, crop_width = crop_size
    if len(mat.shape) == 2:
-        ret = np.zeros((h, w), np.float32)
+        ret = np.zeros((crop_height, crop_width), np.float32)
    else:
-        ret = np.zeros((h, w, 3), np.float32)
-    ret[0:h, 0:w] = mat[y:y + h, x:x + w]
+        ret = np.zeros((crop_height, crop_width, 3), np.float32)
+    crop = mat[y:y + crop_height, x:x + crop_width]
+    h, w = crop.shape[:2]
+    ret[0:h, 0:w] = crop
    if crop_size != (img_rows, img_cols):
        ret = cv.resize(ret, dsize=(img_rows, img_cols), interpolation=cv.INTER_CUBIC)
    return ret
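Finally, a self-contained sketch of the reworked safe_crop behaviour (mirrors the patched function; img_rows and img_cols taken to be 320 as in config.py): a crop that runs past the image border is zero-padded up to the requested size instead of failing on a shape mismatch.

import cv2 as cv
import numpy as np

img_rows, img_cols = 320, 320  # as in config.py

def safe_crop(mat, x, y, crop_size=(img_rows, img_cols)):
    # Take whatever part of the crop actually lies inside the image, then
    # paste it into a zero canvas of the requested size.
    crop_height, crop_width = crop_size
    if len(mat.shape) == 2:
        ret = np.zeros((crop_height, crop_width), np.float32)
    else:
        ret = np.zeros((crop_height, crop_width, 3), np.float32)
    crop = mat[y:y + crop_height, x:x + crop_width]
    h, w = crop.shape[:2]
    ret[0:h, 0:w] = crop
    if crop_size != (img_rows, img_cols):
        ret = cv.resize(ret, dsize=(img_rows, img_cols), interpolation=cv.INTER_CUBIC)
    return ret

image = np.random.randint(0, 256, (400, 400, 3)).astype(np.uint8)
patch = safe_crop(image, 300, 300)   # bottom-right corner: only 100x100 pixels exist
assert patch.shape == (320, 320, 3)  # the rest of the canvas stays zero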
