@@ -143,7 +143,7 @@ def train(train_loader, model_G, model_D, criterion_G, criterion_GAN, optimizer_
143
143
data_time .update (time .time () - end )
144
144
var = Variable (img_L .float (), requires_grad = True ).cuda ()
145
145
real = Variable (real .float (), requires_grad = True ).cuda ()
146
- target = Variable (utils .soft_encode_ab (target ).float (), requires_grad = False ).cuda ()
146
+ target_class = Variable (utils .soft_encode_ab (target ).float (), requires_grad = False ).cuda ()
147
147
# compute output G(L)
148
148
output = model_G (var )
149
149
@@ -174,7 +174,7 @@ def train(train_loader, model_G, model_D, criterion_G, criterion_GAN, optimizer_
174
174
# Fool the discriminator
175
175
loss_G_GAN = criterion_GAN (fake_prob , True )
176
176
# Regressor loss term
177
- loss_G_L2 = criterion_G (output , target )
177
+ loss_G_L2 = criterion_G (output , target_class )
178
178
loss_G = loss_G_GAN + loss_G_L2 * 10
179
179
if torch .isnan (loss_G ):
180
180
print ('NaN value encountered in loss_G.' )
@@ -208,7 +208,7 @@ def train(train_loader, model_G, model_D, criterion_G, criterion_GAN, optimizer_
208
208
start = time .time ()
209
209
batch_num = np .maximum (args .batch_size // 4 ,2 )
210
210
idx = i + epoch * len (train_loader )
211
- imgs = utils .getImages (img_L . float (), target .cpu (), output .detach (). cpu () , batch_num , decode = True )
211
+ imgs = utils .getImages (var . detach (), target .cuda (), output .detach (), batch_num , do_decode = True )
212
212
writer .add_image ('data/imgs_gen' , imgs , idx )
213
213
print ("Img conversion time: " , time .time () - start )
214
214
writer .add_scalar ('data/L2_loss_train' , losses_L2 .avg , i + epoch * len (train_loader ))
0 commit comments