diff --git a/predit_single.py b/predit_single.py
new file mode 100644
index 0000000..ee875e4
--- /dev/null
+++ b/predit_single.py
@@ -0,0 +1,63 @@
+# -*- coding: utf-8 -*-
+import cv2 as cv
+import keras.backend as K
+import numpy as np
+
+from model import build_encoder_decoder, build_refinement
+from utils import get_final_output, create_patches, patch_dims, assemble_patches
+import tensorflow as tf
+import time
+
+config = tf.ConfigProto(device_count={"GPU": 1, "CPU": 1})
+sess = tf.Session(config=config)
+K.set_session(sess)
+
+if __name__ == '__main__':
+    # load network
+    PATCH_SIZE = 320
+    PRETRAINED_PATH = 'models/final.42-0.0398.hdf5'
+    TRIMAP_PATH = "images/trimap2.png"
+    IMG_PATH = "images/frame2.png"
+
+    encoder_decoder = build_encoder_decoder()
+    final = build_refinement(encoder_decoder)
+    final.load_weights(PRETRAINED_PATH)
+    final.summary()
+
+    # load input files
+    trimap = cv.imread(TRIMAP_PATH, cv.IMREAD_GRAYSCALE)
+    img = cv.imread(IMG_PATH)
+    result = np.zeros(trimap.shape, dtype=np.uint8)
+
+    img_size = np.array(trimap.shape)
+
+    # create patches from the stacked RGB + trimap input
+    x = np.dstack((img, np.expand_dims(trimap, axis=2))) / 255.
+    patches = create_patches(x, PATCH_SIZE)
+
+    # allocate array for the per-patch predictions
+    patches_count = np.prod(
+        patch_dims(mat_size=trimap.shape, patch_size=PATCH_SIZE)
+    )
+    patches_predictions = np.zeros(shape=(patches_count, PATCH_SIZE, PATCH_SIZE))
+
+    # predict each patch
+    for i in range(patches.shape[0]):
+        print("Predicting patch {}/{}".format(i + 1, patches_count))
+
+        patch_prediction = final.predict(np.expand_dims(patches[i, :, :, :], axis=0))
+        patches_predictions[i] = np.reshape(patch_prediction, (PATCH_SIZE, PATCH_SIZE)) * 255.
+
+    # assemble patches back into a full-size matte and crop off the padding
+    result = assemble_patches(patches_predictions, trimap.shape, PATCH_SIZE)
+    result = result[:img_size[0], :img_size[1]]
+
+    prediction = get_final_output(result, trimap).astype(np.uint8)
+
+    # display the result
+    cv.imshow("result", prediction)
+    cv.imshow("image", img)
+    cv.waitKey(0)
+
+    K.clear_session()
+
diff --git a/utils.py b/utils.py
index 3006e37..549a536 100644
--- a/utils.py
+++ b/utils.py
@@ -1,5 +1,5 @@
 import multiprocessing
-
+import math
 import cv2 as cv
 import keras.backend as K
 import numpy as np
@@ -88,6 +88,45 @@ def get_final_output(out, trimap):
     mask = np.equal(trimap, unknown_code).astype(np.float32)
     return (1 - mask) * trimap + mask * out
 
+def patch_dims(mat_size, patch_size):
+    return np.ceil(np.array(mat_size) / patch_size).astype(int)
+
+def create_patches(mat, patch_size):
+    mat_size = mat.shape
+    assert len(mat_size) == 3, "Input mat needs to be 3-dimensional (H, W, C)"
+    assert mat_size[-1] == 4, "Input mat needs to have 4 channels (R, G, B, trimap)"
+
+    patches_dim = patch_dims(mat_size=mat_size[:2], patch_size=patch_size)
+    patches_count = np.prod(patches_dim)
+
+    patches = np.zeros(shape=(patches_count, patch_size, patch_size, 4), dtype=np.float32)
+    for y in range(patches_dim[0]):
+        y_start = y * patch_size
+        for x in range(patches_dim[1]):
+            x_start = x * patch_size
+
+            # extract patch from the input mat
+            single_patch = mat[y_start: y_start + patch_size, x_start: x_start + patch_size, :]
+
+            # zero-pad the patch on the bottom and right if it is smaller than patch_size
+            real_patch_h, real_patch_w = single_patch.shape[:2]
+            patch_id = y + x * patches_dim[0]  # column-major patch index
+            patches[patch_id, :real_patch_h, :real_patch_w, :] = single_patch
+
+    return patches
+
+def assemble_patches(pred_patches, mat_size, patch_size):
+    patch_dim_h, patch_dim_w = patch_dims(mat_size=mat_size, patch_size=patch_size)
+    result = np.zeros(shape=(patch_size * patch_dim_h, patch_size * patch_dim_w), dtype=np.uint8)
+    patches_count = pred_patches.shape[0]
+
+    for i in range(patches_count):
+        y = (i % patch_dim_h) * patch_size  # column-major ordering, matching create_patches
+        x = int(math.floor(i / patch_dim_h)) * patch_size
+
+        result[y:y+patch_size, x:x+patch_size] = pred_patches[i]
+
+    return result
 
 def safe_crop(mat, x, y, crop_size=(img_rows, img_cols)):
     crop_height, crop_width = crop_size
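Note (not part of the diff): a minimal sketch of how the new helpers fit together, assuming the utils.py changes above are applied. It checks that the column-major patch ordering used by create_patches (patch_id = y + x * number_of_patch_rows) is the ordering assemble_patches expects. The 713x937 input size and the use of the trimap channel as stand-in predictions are arbitrary choices made only for illustration.

import numpy as np
from utils import create_patches, assemble_patches

PATCH_SIZE = 320

# arbitrary 4-channel (R, G, B, trimap) input whose size is not a multiple of the patch size
x = (np.random.rand(713, 937, 4) * 255).astype(np.float32)

patches = create_patches(x, PATCH_SIZE)            # (9, 320, 320, 4), zero-padded at the borders
fake_predictions = patches[:, :, :, 3]             # trimap channel as stand-in predictions
restored = assemble_patches(fake_predictions, x.shape[:2], PATCH_SIZE)
restored = restored[:713, :937]                    # crop the padding, as predit_single.py does

# assemble_patches writes into a uint8 array, so compare after the same cast
assert np.array_equal(restored, x[:, :, 3].astype(np.uint8))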