Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ Latest version of code has following changes:
3. Another thing that needs to be mentioned here is that when training on a single complex sample such as the bike, the network can overfit even with deconvolution (not unpooling). However, deconvolution does not converge on the whole dataset. (Maybe I didn't train for long enough: with lr = 1e-5 and 5 days of training, it did not converge.) Discussion about 'whether deconvolution can replace unpooling' is welcome!
4. Add hard mode to allow training on tough samples

**2017-11-01:**
1. Fresh version of unpooling added. Now it can deal with variable batch size.
2. The 0th dimension of the placeholders was set to None. This allows setting a different batch size for train and test runs.
3. A trimap argument was added to the test script, so now you can test the pretrained model on your own images, even if you don't have an alpha mask. Just paint a rough trimap manually in any graphics editor.

*My Chinese blog about the implementation of this paper*
http://blog.leanote.com/post/calebge/Deep-Image-Matting%E5%A4%8D%E7%8E%B0%E8%BF%87%E7%A8%8B%E6%80%BB%E7%BB%93 <br />

Expand All @@ -25,6 +30,8 @@ simply run:<br />
python test.py --alpha --rgb<br />
sample:<br />
python test.py --alpha=./test_data/alpha/1.png --rgb=./test_data/RGB/1.png<br />
or you can use any other trimap you already have, instead of generating it from alpha:<br />
python test.py --trimap --rgb<br />

<h2>Pretrained Model</h2>
Can be found here:<br />
Expand Down
40 changes: 22 additions & 18 deletions matting.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,28 @@
]

# Scatter the values of `pool` into a larger tensor at the positions given by
# `ind`, enlarging dimensions 1 and 2 by ksize[1] and ksize[2] respectively.
#   pool  -- 4-D tensor of values to place (indexed as input_shape[0..3])
#   ind   -- index tensor, reshaped below to one flat index per element of pool
#            (presumably the argmax output of a max-pooling op -- TODO confirm
#            against the caller)
#   ksize -- per-dimension enlargement factors; only ksize[1] and ksize[2] are read
#   scope -- name passed to tf.variable_scope
# Returns the scattered tensor reshaped to output_shape.
# NOTE(review): this listing comes from a diff view with indentation removed.
# The first `with` body looks like the removed static-shape version and the
# second `with` body like the added dynamic-shape version -- confirm against
# the repository before relying on either.
def unpool(pool, ind, ksize=[1, 2, 2, 1], scope='unpool'):

# --- first body: shapes read from the static shape of `pool` ---
with tf.variable_scope(scope):
input_shape = pool.get_shape().as_list()
output_shape = (input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3])

# Total element count of `pool`, and a 2-D [batch, everything-else] target shape.
flat_input_size = np.prod(input_shape)
flat_output_shape = [output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]]

pool_ = tf.reshape(pool, [flat_input_size])
# One batch index per sample, shaped [N, 1, 1, 1] so it broadcasts against `ind`.
batch_range = tf.reshape(tf.range(output_shape[0], dtype=ind.dtype), shape=[input_shape[0], 1, 1, 1])
b = tf.ones_like(ind) * batch_range
b = tf.reshape(b, [flat_input_size, 1])
ind_ = tf.reshape(ind, [flat_input_size, 1])
# Each row becomes a (batch index, flat index) coordinate for scatter_nd.
ind_ = tf.concat([b, ind_], 1)

ret = tf.scatter_nd(ind_, pool_, shape=flat_output_shape)
ret = tf.reshape(ret, output_shape)
return ret
# --- second body: same steps, but shapes are taken from tf.shape(pool) ---
with tf.variable_scope(scope):
input_shape = tf.shape(pool)
output_shape = [input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3]]

# Element count computed with a TensorFlow op instead of np.prod.
flat_input_size = tf.reduce_prod(input_shape)
flat_output_shape = [output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]]

pool_ = tf.reshape(pool, [flat_input_size])
# Batch size is cast to int64 before building the range of batch indices.
batch_range = tf.reshape(tf.range(tf.cast(output_shape[0], tf.int64), dtype=ind.dtype),
shape=[input_shape[0], 1, 1, 1])
b = tf.ones_like(ind) * batch_range
b1 = tf.reshape(b, [flat_input_size, 1])
ind_ = tf.reshape(ind, [flat_input_size, 1])
# Each row becomes a (batch index, flat index) coordinate for scatter_nd.
ind_ = tf.concat([b1, ind_], 1)

# The target shape is cast to int64 before being handed to scatter_nd.
ret = tf.scatter_nd(ind_, pool_, shape=tf.cast(flat_output_shape, tf.int64))
ret = tf.reshape(ret, output_shape)

# Re-attach the shape information known from pool.get_shape() to the result.
set_input_shape = pool.get_shape()
set_output_shape = [set_input_shape[0], set_input_shape[1] * ksize[1], set_input_shape[2] * ksize[2], set_input_shape[3]]
ret.set_shape(set_output_shape)
return ret


def UR_center(trimap):
Expand Down
12 changes: 6 additions & 6 deletions matting_unpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@
# Shuffled producer of integer indices below range_size; each dequeue pulls
# train_batch_size indices at once.
index_queue = tf.train.range_input_producer(range_size, num_epochs=None,shuffle=True, seed=None, capacity=32)
index_dequeue_op = index_queue.dequeue_many(train_batch_size, 'index_dequeue')

# NOTE(review): this listing comes from a diff view; the six placeholders below
# with a leading train_batch_size dimension look like the removed lines, and the
# six that follow with a leading None dimension look like their replacements --
# confirm against the repository.
image_batch = tf.placeholder(tf.float32, shape=(train_batch_size,image_size,image_size,3))
raw_RGBs = tf.placeholder(tf.float32, shape=(train_batch_size,image_size,image_size,3))
GT_matte_batch = tf.placeholder(tf.float32, shape = (train_batch_size,image_size,image_size,1))
GT_trimap = tf.placeholder(tf.float32, shape = (train_batch_size,image_size,image_size,1))
GTBG_batch = tf.placeholder(tf.float32, shape = (train_batch_size,image_size,image_size,3))
GTFG_batch = tf.placeholder(tf.float32, shape = (train_batch_size,image_size,image_size,3))
# Same placeholders with the first dimension left as None instead of train_batch_size.
image_batch = tf.placeholder(tf.float32, shape=(None,image_size,image_size,3))
raw_RGBs = tf.placeholder(tf.float32, shape=(None,image_size,image_size,3))
GT_matte_batch = tf.placeholder(tf.float32, shape = (None,image_size,image_size,1))
GT_trimap = tf.placeholder(tf.float32, shape = (None,image_size,image_size,1))
GTBG_batch = tf.placeholder(tf.float32, shape = (None,image_size,image_size,3))
GTFG_batch = tf.placeholder(tf.float32, shape = (None,image_size,image_size,3))


# Register the placeholder in the 'image_batch' graph collection.
tf.add_to_collection('image_batch',image_batch)
Expand Down
13 changes: 9 additions & 4 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ def main(args):
pred_mattes = tf.get_collection('pred_mattes')[0]

rgb = misc.imread(args.rgb)
alpha = misc.imread(args.alpha,'L')
trimap = generate_trimap(np.expand_dims(np.copy(alpha),2),np.expand_dims(alpha,2))[:,:,0]
origin_shape = alpha.shape
if args.trimap is not None:
trimap = misc.imread(args.trimap,'L')
else:
alpha = misc.imread(args.alpha,'L')
trimap = generate_trimap(np.expand_dims(np.copy(alpha),2),np.expand_dims(alpha,2))[:,:,0]
origin_shape = trimap.shape
rgb = np.expand_dims(misc.imresize(rgb.astype(np.uint8),[320,320,3]).astype(np.float32)-g_mean,0)
trimap = np.expand_dims(np.expand_dims(misc.imresize(trimap.astype(np.uint8),[320,320],interp = 'nearest').astype(np.float32),2),0)

Expand All @@ -38,11 +41,13 @@ def parse_arguments(argv):
help='input alpha')
parser.add_argument('--rgb', type=str,
help='input rgb')
parser.add_argument('--trimap', type=str,
help='input trimap')
parser.add_argument('--gpu_fraction', type=float,
help='how much gpu is needed, usually 4G is enough',default = 0.4)
return parser.parse_args(argv)


if __name__ == '__main__':
main(parse_arguments(sys.argv[1:]))
main(parse_arguments(sys.argv[1:]))