Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
cl0ver012 committed Jan 7, 2019
1 parent c25b9e3 commit 90956f2
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@
import cv2 as cv
import numpy as np

from model import build_encoder_decoder
from model import build_encoder_decoder, build_refinement
from utils import get_final_output

# python test.py -i "images/image.png" -t "images/trimap.png"
if __name__ == '__main__':
img_rows, img_cols = 320, 320
channel = 4

model_weights_path = 'models/model.35-0.03.hdf5'
model = build_encoder_decoder()
model.load_weights(model_weights_path)
print(model.summary())
model_weights_path = 'models/final.42-0.0398.hdf5'
encoder_decoder = build_encoder_decoder()
final = build_refinement(encoder_decoder)
final.load_weights(model_weights_path)
print(final.summary())

ap = argparse.ArgumentParser()
ap.add_argument("-i", "--image", help="path to the image file")
Expand All @@ -22,20 +24,25 @@
image_path = args["image"]
trimap_path = args["trimap"]

if image_path is None:
image_path = 'images/image.jpg'
if trimap_path is None:
trimap_path = 'images/trimap.jpg'

print('Start processing image: {}'.format(image_path))

x_test = np.empty((1, img_rows, img_cols, 4), dtype=np.float32)
bgr_img = cv.imread(image_path)
trimap = cv.imread(trimap_path, 0)

x_test = np.empty((1, 320, 320, 4), dtype=np.float32)
x_test[0, :, :, 0:3] = bgr_img / 255.
x_test[0, :, :, 3] = trimap / 255.

out = model.predict(x_test)
out = final.predict(x_test)
out = np.reshape(out, (img_rows, img_cols))
print(out.shape)
out = out * 255.0
out = get_final_output(out, trimap)
out = out.astype(np.uint8)
cv.imshow('out', out)
cv.imwrite('images/out.png', out)
Expand Down

0 comments on commit 90956f2

Please sign in to comment.