Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
cl0ver012 committed Jul 5, 2018
1 parent 06524f2 commit 1d7067d
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
4 changes: 1 addition & 3 deletions data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,6 @@ def __getitem__(self, idx):
x, y = random_choice(trimap, crop_size)
image = safe_crop(image, x, y, crop_size)
alpha = safe_crop(alpha, x, y, crop_size)
fg = safe_crop(fg, x, y, crop_size)
bg = safe_crop(bg, x, y, crop_size)

trimap = generate_trimap(alpha)

Expand Down Expand Up @@ -179,7 +177,7 @@ def shuffle_data():
bcount += 1

from config import num_valid_samples
valid_names = np.random.sample(names, num_valid_samples)
valid_names = random.sample(names, num_valid_samples)
train_names = [n for n in names if n not in valid_names]
shuffle(valid_names)
shuffle(train_names)
Expand Down
2 changes: 1 addition & 1 deletion train_final.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from config import patience, batch_size, epochs, num_train_samples, num_valid_samples
from data_generator import train_gen, valid_gen
from model import build_encoder_decoder, build_refinement
from utils import overall_loss, get_available_cpus, get_available_gpus
from utils import alpha_prediction_loss, get_available_cpus, get_available_gpus

if __name__ == '__main__':
checkpoint_models_path = 'models/'
Expand Down

0 comments on commit 1d7067d

Please sign in to comment.