Skip to content

Commit 539d264

Browse files
committed
fixed image tensorboard viewer
1 parent 53af006 commit 539d264

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

datasets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ def __getitem__(self, i):
3939

4040
inputImage = rgb2lab(np.array(self.transf(img))).astype(float)
4141
inputImage[:,:,0] -= 50.0
42-
image_ab = self.toTensor(cv2.resize(inputImage, (56, 56), interpolation = cv2.INTER_AREA)[:,:,1:].astype(float))
43-
image_L = torch.from_numpy(inputImage[:,:,0]).unsqueeze_(0)
42+
image_ab = self.toTensor(cv2.resize(inputImage, (56, 56), interpolation = cv2.INTER_AREA)[:,:,1:].astype(float)).float()
43+
image_L = torch.from_numpy(inputImage[:,:,0]).unsqueeze_(0).float()
4444
if self.output_full:
45-
img_all = torch.from_numpy(inputImage)
45+
img_all = self.toTensor(inputImage).float()
4646
return img_all, image_L, image_ab
4747
else:
4848
return image_L, image_ab

train_coGAN.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def train(train_loader, model_G, model_D, criterion_G, criterion_GAN, optimizer_
143143
data_time.update(time.time() - end)
144144
var = Variable(img_L.float(), requires_grad=True).cuda()
145145
real = Variable(real.float(), requires_grad=True).cuda()
146-
target = Variable(utils.soft_encode_ab(target).float(), requires_grad=False).cuda()
146+
target_class = Variable(utils.soft_encode_ab(target).float(), requires_grad=False).cuda()
147147
# compute output G(L)
148148
output = model_G(var)
149149

@@ -174,7 +174,7 @@ def train(train_loader, model_G, model_D, criterion_G, criterion_GAN, optimizer_
174174
# Fool the discriminator
175175
loss_G_GAN = criterion_GAN(fake_prob, True)
176176
# Regressor loss term
177-
loss_G_L2 = criterion_G(output, target)
177+
loss_G_L2 = criterion_G(output, target_class)
178178
loss_G = loss_G_GAN + loss_G_L2*10
179179
if torch.isnan(loss_G):
180180
print('NaN value encountered in loss_G.')
@@ -208,7 +208,7 @@ def train(train_loader, model_G, model_D, criterion_G, criterion_GAN, optimizer_
208208
start = time.time()
209209
batch_num = np.maximum(args.batch_size//4,2)
210210
idx = i + epoch*len(train_loader)
211-
imgs = utils.getImages(img_L.float(), target.cpu(), output.detach().cpu(), batch_num, decode=True)
211+
imgs = utils.getImages(var.detach(), target.cuda(), output.detach(), batch_num, do_decode=True)
212212
writer.add_image('data/imgs_gen', imgs, idx)
213213
print("Img conversion time: ", time.time() - start)
214214
writer.add_scalar('data/L2_loss_train', losses_L2.avg, i + epoch*len(train_loader))

utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,18 +62,18 @@ def generateImg(Z,light):
6262
Image = cv2.cvtColor(newImg.astype(np.uint8), cv2.COLOR_LAB2RGB)
6363
return Image
6464

65-
def getImages(L_channel, ab_target, ab_gen, batch_num, decode=True):
65+
def getImages(L_channel, ab_target, ab_gen, batch_num, do_decode=True):
6666
L_channel = interpolate(L_channel[:batch_num,:,:,:] + 50.0, scale_factor=0.25, mode='bilinear',
6767
recompute_scale_factor=True, align_corners=True)
68-
if decode:
68+
if do_decode:
6969
ab_gen = decode(ab_gen[:batch_num,:,:,:], T=0.38)
7070
else:
7171
ab_gen = ab_gen[:batch_num,:,:,:]
7272

7373
ab_target = ab_target[:batch_num,:,:,:]
7474
img_target = torch.cat([L_channel, ab_target], dim=1)
7575
img_gen = torch.cat([L_channel, ab_gen], dim=1)
76-
img_all = torch.cat([img_target, img_gen], dim=0).numpy().transpose((0,2,3,1))
76+
img_all = torch.cat([img_target, img_gen], dim=0).cpu().numpy().transpose((0,2,3,1))
7777

7878
imgs_all_l = []
7979
for i in range(batch_num):

0 commit comments

Comments
 (0)