diff --git a/data_generator.py b/data_generator.py index 888a45e..72174aa 100644 --- a/data_generator.py +++ b/data_generator.py @@ -132,8 +132,6 @@ def __getitem__(self, idx): x, y = random_choice(trimap, crop_size) image = safe_crop(image, x, y, crop_size) alpha = safe_crop(alpha, x, y, crop_size) - fg = safe_crop(fg, x, y, crop_size) - bg = safe_crop(bg, x, y, crop_size) trimap = generate_trimap(alpha) @@ -179,7 +177,7 @@ def shuffle_data(): bcount += 1 from config import num_valid_samples - valid_names = np.random.sample(names, num_valid_samples) + valid_names = random.sample(names, num_valid_samples) train_names = [n for n in names if n not in valid_names] shuffle(valid_names) shuffle(train_names) diff --git a/train_final.py b/train_final.py index 9389a9f..c6e5edf 100644 --- a/train_final.py +++ b/train_final.py @@ -8,7 +8,7 @@ from config import patience, batch_size, epochs, num_train_samples, num_valid_samples from data_generator import train_gen, valid_gen from model import build_encoder_decoder, build_refinement -from utils import overall_loss, get_available_cpus, get_available_gpus +from utils import alpha_prediction_loss, get_available_cpus, get_available_gpus if __name__ == '__main__': checkpoint_models_path = 'models/'