Skip to content

Commit

Permalink
Merge pull request #37 from SonyPony/master
Browse files Browse the repository at this point in the history
Add patching of larger input
  • Loading branch information
cl0ver012 committed Aug 20, 2019
2 parents 383a835 + 3fc7a40 commit 71a7d3a
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 1 deletion.
63 changes: 63 additions & 0 deletions predit_single.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
import cv2 as cv
import keras.backend as K
import numpy as np

from model import build_encoder_decoder, build_refinement
from utils import get_final_output, create_patches, patch_dims, assemble_patches
import tensorflow as tf
import time

# Create a TensorFlow 1.x session with an explicit per-device-type count and
# register it as the session the Keras backend will use.
# NOTE(review): device_count is presumably an upper bound on the devices of
# each type that TensorFlow may use -- confirm against the ConfigProto docs.
device_limits = {"GPU": 1, "CPU": 1}
config = tf.ConfigProto(device_count=device_limits)
sess = tf.Session(config=config)
K.set_session(sess)

if __name__ == '__main__':
    # Demo entry point: run the matting network on a single image/trimap pair
    # that may be larger than the network input, by predicting it patch by
    # patch and stitching the patch predictions back together.
    PATCH_SIZE = 320
    PRETRAINED_PATH = 'models/final.42-0.0398.hdf5'
    TRIMAP_PATH = "images/trimap2.png"
    IMG_PATH = "images/frame2.png"

    # load network
    encoder_decoder = build_encoder_decoder()
    final = build_refinement(encoder_decoder)
    final.load_weights(PRETRAINED_PATH)
    # summary() prints the table itself and returns None, so wrapping it in
    # print() would emit an extra "None" line.
    final.summary()

    # loading input files
    # cv.imread signals failure by returning None instead of raising, so check
    # explicitly to fail with a clear message rather than on a later `.shape`.
    trimap = cv.imread(TRIMAP_PATH, cv.IMREAD_GRAYSCALE)
    if trimap is None:
        raise IOError("Unable to read trimap image: {}".format(TRIMAP_PATH))
    img = cv.imread(IMG_PATH)
    if img is None:
        raise IOError("Unable to read input image: {}".format(IMG_PATH))
    if img.shape[:2] != trimap.shape:
        raise ValueError(
            "Image size {} does not match trimap size {}".format(
                img.shape[:2], trimap.shape
            )
        )

    height, width = trimap.shape

    # create patches: stack the trimap as a 4th channel and scale to [0, 1]
    x = np.dstack((img, np.expand_dims(trimap, axis=2))) / 255.
    patches = create_patches(x, PATCH_SIZE)

    # create mat for patches predictions
    patches_count = int(np.prod(
        patch_dims(mat_size=trimap.shape, patch_size=PATCH_SIZE)
    ))
    patches_predictions = np.zeros(shape=(patches_count, PATCH_SIZE, PATCH_SIZE))

    # predicting, one patch per forward pass
    for i in range(patches.shape[0]):
        print("Predicting patches {}/{}".format(i + 1, patches_count))

        patch_prediction = final.predict(np.expand_dims(patches[i, :, :, :], axis=0))
        # scale the prediction back to the 0..255 range
        patches_predictions[i] = np.reshape(patch_prediction, (PATCH_SIZE, PATCH_SIZE)) * 255.

    # assemble, then crop away the zero padding added on the bottom/right
    result = assemble_patches(patches_predictions, trimap.shape, PATCH_SIZE)
    result = result[:height, :width]

    prediction = get_final_output(result, trimap).astype(np.uint8)

    # show the result next to the input image and wait for a key press
    cv.imshow("result", prediction)
    cv.imshow("image", img)
    cv.waitKey(0)

    K.clear_session()

41 changes: 40 additions & 1 deletion utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import multiprocessing

import math
import cv2 as cv
import keras.backend as K
import numpy as np
Expand Down Expand Up @@ -88,6 +88,45 @@ def get_final_output(out, trimap):
mask = np.equal(trimap, unknown_code).astype(np.float32)
return (1 - mask) * trimap + mask * out

def patch_dims(mat_size, patch_size):
    """Return how many patches are needed along each axis of a matrix.

    Each entry of ``mat_size`` is divided by ``patch_size`` and rounded up,
    so a partially filled trailing patch still counts as a whole patch.

    :param mat_size: sequence of axis lengths, e.g. ``(height, width)``.
    :param patch_size: edge length of a (square) patch.
    :return: integer numpy array with one patch count per axis.
    """
    axis_lengths = np.asarray(mat_size)
    fractional_counts = axis_lengths / patch_size
    return np.ceil(fractional_counts).astype(int)

def create_patches(mat, patch_size):
    """Split a 4-channel matrix into zero-padded square patches.

    The matrix is cut into a grid of ``patch_size`` x ``patch_size`` tiles.
    Tiles on the bottom and right border that would reach outside the matrix
    are padded with zeros. Patches are stored in column-major grid order
    (``patch_id = row + col * rows``), which is the order
    ``assemble_patches`` expects.

    :param mat: array of shape (height, width, 4); per the assertion message
        the channels are the colour image plus the trimap.
    :param patch_size: edge length of each square patch.
    :return: float32 array of shape (patch_count, patch_size, patch_size, 4).
    :raises AssertionError: if ``mat`` is not 3-dimensional or does not have
        exactly 4 channels.
    """
    mat_size = mat.shape
    # The rank check gets its own message; it previously reused the channel message.
    assert len(mat_size) == 3, "Input mat need to have 3 dimensions (height, width, channels)"
    assert mat_size[-1] == 4, "Input mat need to have 4 channels (R, G, B, trimap)"

    patches_dim = patch_dims(mat_size=mat_size[:2], patch_size=patch_size)
    # np.prod instead of np.product: the latter alias was removed in NumPy 2.0.
    patches_count = int(np.prod(patches_dim))

    patches = np.zeros(shape=(patches_count, patch_size, patch_size, 4), dtype=np.float32)
    for y in range(patches_dim[0]):
        y_start = y * patch_size
        for x in range(patches_dim[1]):
            x_start = x * patch_size

            # extract patch from input mat; slicing clips at the matrix border
            single_patch = mat[y_start: y_start + patch_size, x_start: x_start + patch_size, :]

            # zero pad patch in bottom and right side if real patch size is smaller than patch size
            real_patch_h, real_patch_w = single_patch.shape[:2]
            patch_id = y + x * patches_dim[0]
            patches[patch_id, :real_patch_h, :real_patch_w, :] = single_patch

    return patches

def assemble_patches(pred_patches, mat_size, patch_size):
    """Stitch per-patch predictions back into a single 2-D matrix.

    Patch ``i`` is placed at grid row ``i % rows`` and grid column
    ``i // rows``, i.e. patches are expected in column-major grid order.
    The returned matrix covers the whole patch grid, so it can be larger
    than ``mat_size``; cropping is left to the caller.

    :param pred_patches: array whose first axis indexes patches, each of
        shape (patch_size, patch_size).
    :param mat_size: (height, width) of the original matrix.
    :param patch_size: edge length of each square patch.
    :return: uint8 array of shape (patch_size * rows, patch_size * cols).
    """
    grid_rows, grid_cols = patch_dims(mat_size=mat_size, patch_size=patch_size)
    canvas_shape = (patch_size * grid_rows, patch_size * grid_cols)
    # Values are cast to uint8 when written into the canvas.
    # NOTE(review): assumes patch values already lie in 0..255 -- confirm with caller.
    canvas = np.zeros(shape=canvas_shape, dtype=np.uint8)

    for index in range(pred_patches.shape[0]):
        grid_col, grid_row = divmod(index, grid_rows)
        top = grid_row * patch_size
        left = grid_col * patch_size

        canvas[top:top + patch_size, left:left + patch_size] = pred_patches[index]

    return canvas

def safe_crop(mat, x, y, crop_size=(img_rows, img_cols)):
crop_height, crop_width = crop_size
Expand Down

0 comments on commit 71a7d3a

Please sign in to comment.