-
Notifications
You must be signed in to change notification settings - Fork 665
/
train_adobe.py
182 lines (118 loc) · 5.66 KB
/
train_adobe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
from __future__ import print_function
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
import os
import time
import argparse
from data_loader import AdobeDataAffineHR
from functions import *
from networks import ResnetConditionHR, conv_init
from loss_functions import alpha_loss, compose_loss, alpha_gradient_loss
# CUDA
# GPUs are normally chosen from the shell (CUDA_VISIBLE_DEVICES=...); uncomment
# the next line to pin them from inside the script instead.
#os.environ["CUDA_VISIBLE_DEVICES"]="4"
# Indexing os.environ directly raised KeyError whenever the variable was unset
# (the usual case, since the assignment above is commented out).
print('CUDA Device: ' + os.environ.get("CUDA_VISIBLE_DEVICES", "not set"))

# Parse command-line arguments.
# --name, --batch_size and --reso have no usable default (None would crash later
# in string concatenation / DataLoader construction), so argparse now enforces them.
parser = argparse.ArgumentParser(description='Training Background Matting on Adobe Dataset.')
parser.add_argument('-n', '--name', type=str, required=True, help='Name of tensorboard and model saving folders.')
parser.add_argument('-bs', '--batch_size', type=int, required=True, help='Batch Size.')
parser.add_argument('-res', '--reso', type=int, required=True, help='Input image resolution')
parser.add_argument('-epoch', '--epoch', type=int, default=60, help='Maximum Epoch')
parser.add_argument('-n_blocks1', '--n_blocks1', type=int, default=7, help='Number of residual blocks after Context Switching.')
parser.add_argument('-n_blocks2', '--n_blocks2', type=int, default=3, help='Number of residual blocks for Fg and alpha each.')
args = parser.parse_args()

## Directories: TensorBoard logs and checkpoints, one sub-folder per run name.
tb_dir = 'TB_Summary/' + args.name
model_dir = 'Models/' + args.name
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
if not os.path.exists(tb_dir):
    os.makedirs(tb_dir)

## Input list: data-loading parameters handed to the training dataset.
data_config_train = {'reso': [args.reso, args.reso], 'trimapK': [5, 5], 'noise': True}

# DATA LOADING
print('\n[Phase 1] : Data Preparation')
def collate_filter_none(batch):
    """Collate a DataLoader batch after discarding samples that are ``None``.

    The dataset may yield ``None`` for some samples (presumably ones it could
    not load -- confirm in data_loader.py); those are dropped and the remaining
    samples are merged with PyTorch's ``default_collate``.

    Args:
        batch: list of samples produced by the dataset for one batch.

    Returns:
        The collated batch of the non-``None`` samples. If every sample is
        ``None``, ``default_collate`` receives an empty list and fails.
    """
    kept = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(kept)
# Original data: training samples listed in the Adobe training CSV.
traindata = AdobeDataAffineHR(
    csv_file='Data_adobe/Adobe_train_data.csv',
    data_config=data_config_train,
    transform=None,
)
# One worker process per sample in the batch; collate_filter_none drops
# samples for which the dataset returned None before collating.
train_loader = torch.utils.data.DataLoader(
    traindata,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.batch_size,
    collate_fn=collate_filter_none,
)
print('\n[Phase 2] : Initialization')
# Build the matting network. input_nc=(3, 3, 1, 4) presumably gives the channel
# counts of (image, bg_tr, seg, multi_fr) in forward-call order -- confirm in
# networks.py. The residual-block counts now come from --n_blocks1/--n_blocks2
# (defaults 7 and 3); they were previously hard-coded, so those flags were ignored.
net = ResnetConditionHR(input_nc=(3, 3, 1, 4), output_nc=4,
                        n_blocks1=args.n_blocks1, n_blocks2=args.n_blocks2,
                        norm_layer=nn.BatchNorm2d)
net.apply(conv_init)
net = nn.DataParallel(net)
# Uncomment to resume from a checkpoint saved by the training loop
# (replace X with the saved file name):
#net.load_state_dict(torch.load(os.path.join(model_dir, 'net_epoch_X')))
net.cuda()
# Let cuDNN benchmark and pick convolution algorithms for the input shapes seen.
torch.backends.cudnn.benchmark = True
# Loss functions (project-defined in loss_functions).
l1_loss = alpha_loss()
c_loss = compose_loss()
g_loss = alpha_gradient_loss()

# Adam over all network parameters with a fixed learning rate.
optimizer = optim.Adam(net.parameters(), lr=1e-4)
# Uncomment to resume the optimizer state (replace X with the saved file name):
#optimizer.load_state_dict(torch.load(os.path.join(model_dir, 'optim_epoch_X')))

log_writer = SummaryWriter(tb_dir)

print('Starting Training')
# Every `step` iterations: print running-average losses and log image snapshots.
step = 50
# Iterations per epoch; used to turn (epoch, i) into a global TensorBoard step.
KK = len(train_loader)
# Keys of the per-batch dict returned by the DataLoader, in unpacking order.
_BATCH_KEYS = ('fg', 'bg', 'alpha', 'image', 'seg', 'bg_tr', 'multi_fr')


def _log_tb_images(tb_step, image, seg, alpha, alpha_pred, fg, fg_pred, mask, multi_fr, bg):
    """Write input/target/prediction snapshots to TensorBoard via write_tb_log.

    ``tb_step`` is forwarded unchanged as the last argument of ``write_tb_log``
    (star-imported from ``functions``) for every image tag.
    """
    write_tb_log(image, 'image', log_writer, tb_step)
    write_tb_log(seg, 'seg', log_writer, tb_step)
    write_tb_log(alpha, 'alpha', log_writer, tb_step)
    write_tb_log(alpha_pred, 'alpha_pred', log_writer, tb_step)
    write_tb_log(fg * mask, 'fg', log_writer, tb_step)
    write_tb_log(fg_pred * mask, 'fg_pred', log_writer, tb_step)
    # Channel 0 of (up to) the first 4 samples, with the channel axis restored.
    write_tb_log(multi_fr[0:4, 0, ...].unsqueeze(1), 'multi_fr', log_writer, tb_step)
    # Composite predicted foreground over the true background; alpha_pred is
    # mapped by (x + 1) / 2 first, i.e. treated as lying in [-1, 1].
    alpha_vis = (alpha_pred + 1) / 2
    comp = fg_pred * alpha_vis + (1 - alpha_vis) * bg
    write_tb_log(comp, 'composite', log_writer, tb_step)


for epoch in range(0, args.epoch):
    net.train()
    # Running sums of losses/timings, printed and reset every `step` iterations.
    netL, alL, fgL, fg_cL, al_fg_cL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0
    t0 = time.time()
    # Whole-epoch total-loss sum and count; their mean is encoded in checkpoint names.
    testL = 0
    ct_tst = 0
    for i, data in enumerate(train_loader):
        global_step = epoch * KK + i + 1
        # Move the batch to the GPU (Variable wrappers are no-ops in modern PyTorch).
        fg, bg, alpha, image, seg, bg_tr, multi_fr = tuple(data[k].cuda() for k in _BATCH_KEYS)
        # Foreground loss is restricted to pixels with alpha > -0.99.
        mask = (alpha > -0.99).type(torch.cuda.FloatTensor)
        # All-ones mask: the remaining losses cover every pixel.
        mask0 = torch.ones(alpha.shape).cuda()

        tr0 = time.time()
        alpha_pred, fg_pred = net(image, bg_tr, seg, multi_fr)

        ## Losses
        al_loss = l1_loss(alpha, alpha_pred, mask0)
        fg_loss = l1_loss(fg, fg_pred, mask)
        # Where predicted alpha > 0.95, use the input image as foreground
        # instead of fg_pred before scoring the re-composition.
        al_mask = (alpha_pred > 0.95).type(torch.cuda.FloatTensor)
        fg_pred_c = image * al_mask + fg_pred * (1 - al_mask)
        fg_c_loss = c_loss(image, alpha_pred, fg_pred_c, bg, mask0)
        al_fg_c_loss = g_loss(alpha, alpha_pred, mask0)
        loss = al_loss + 2 * fg_loss + fg_c_loss + al_fg_c_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Plain Python floats, so the running sums keep no GPU tensors alive.
        # (The losses are single-element: they are printed with %.4f below.)
        loss_v = loss.item()
        al_v = al_loss.item()
        fg_v = fg_loss.item()
        fg_c_v = fg_c_loss.item()
        al_fg_c_v = al_fg_c_loss.item()

        netL += loss_v
        alL += al_v
        fgL += fg_v
        fg_cL += fg_c_v
        al_fg_cL += al_fg_c_v

        log_writer.add_scalar('training_loss', loss_v, global_step)
        log_writer.add_scalar('alpha_loss', al_v, global_step)
        log_writer.add_scalar('fg_loss', fg_v, global_step)
        log_writer.add_scalar('comp_loss', fg_c_v, global_step)
        log_writer.add_scalar('alpha_gradient_loss', al_fg_c_v, global_step)

        t1 = time.time()
        elapse += t1 - t0
        elapse_run += t1 - tr0
        t0 = t1

        testL += loss_v
        ct_tst += 1

        if i % step == (step - 1):
            print('[%d, %5d] Total-loss: %.4f Alpha-loss: %.4f Fg-loss: %.4f Comp-loss: %.4f Alpha-gradient-loss: %.4f Time-all: %.4f Time-fwbw: %.4f' % (epoch + 1, i + 1, netL / step, alL / step, fgL / step, fg_cL / step, al_fg_cL / step, elapse / step, elapse_run / step))
            netL, alL, fgL, fg_cL, al_fg_cL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0
            # Offset by epoch * KK (as the scalars are) so image snapshots from
            # different epochs no longer share a step; identical at epoch 0.
            _log_tb_images(epoch * KK + i, image, seg, alpha, alpha_pred, fg, fg_pred, mask, multi_fr, bg)

        # Drop this batch's GPU tensors before the next batch is loaded.
        del fg, bg, alpha, image, alpha_pred, fg_pred, seg, bg_tr, multi_fr

    # Saving: join paths properly -- model_dir has no trailing separator, so plain
    # concatenation wrote '<Models>/<name>net_epoch_...' outside the run folder.
    mean_loss = testL / max(ct_tst, 1)
    torch.save(net.state_dict(), os.path.join(model_dir, 'net_epoch_%d_%.4f.pth' % (epoch, mean_loss)))
    torch.save(optimizer.state_dict(), os.path.join(model_dir, 'optim_epoch_%d_%.4f.pth' % (epoch, mean_loss)))