-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
60 lines (49 loc) · 2.13 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import numpy as np
def weights_init_normal(m):
    """Initialize a layer's parameters DCGAN-style (apply via ``model.apply``).

    Conv* layers get weights ~ N(0, 0.02); BatchNorm2d layers get
    weights ~ N(1, 0.02) and zero bias. Other layer types are untouched.

    @param m: A torch.nn module, typically visited by ``nn.Module.apply``.
    """
    layer_name = type(m).__name__
    if 'Conv' in layer_name:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm2d' in layer_name:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
def log10_cuda(t):
    """
    Calculates the base-10 log of each element in t.

    Fix: the original built the log(10) constant with ``.cuda()``, which
    raised on CUDA-less machines and mismatched devices whenever ``t`` lived
    on the CPU. The constant now follows ``t``'s device and dtype, so CUDA
    inputs behave exactly as before and CPU inputs also work.

    @param t: The tensor (positive values) from which to calculate the
        base-10 log.
    @return: A tensor of the same shape with the base-10 log of each element.
    """
    numerator = torch.log(t)
    denominator = torch.log(torch.tensor(10.0, dtype=t.dtype, device=t.device))
    return numerator / denominator
def psnr_error_cuda(gen_frames, gt_frames):
    """
    Computes the Peak Signal to Noise Ratio error between the generated images
    and the ground truth images.

    Fix: uses ``torch.log10`` directly instead of the ``log10_cuda`` helper,
    which hard-required a CUDA device; results are identical and the function
    now runs on whatever device the inputs are on.

    NOTE(review): assumes inputs are in [-1, 1] (tanh generator output) — the
    (x + 1) / 2 rescaling maps them to [0, 1] so the PSNR peak value is 1.
    Modify the rescaling if the generator uses a sigmoid output.

    @param gen_frames: A tensor of shape [batch_size, 3, height, width]. The
        frames generated by the generator model.
    @param gt_frames: A tensor of shape [batch_size, 3, height, width]. The
        ground-truth frames for each frame in gen_frames.
    @return: A scalar tensor. The mean Peak Signal to Noise Ratio error over
        each frame in the batch.
    """
    _, channels, height, width = gen_frames.shape
    num_pixels = channels * height * width
    # Rescale from [-1, 1] to [0, 1]; change here for sigmoid outputs.
    gt_frames = (gt_frames + 1.0) / 2.0
    gen_frames = (gen_frames + 1.0) / 2.0
    square_diff = (gt_frames - gen_frames) ** 2
    # Per-frame MSE over channel/height/width, then PSNR = 10 * log10(1 / MSE).
    mse = torch.sum(square_diff, dim=[1, 2, 3]) * (1.0 / num_pixels)
    batch_errors = 10 * torch.log10(1.0 / mse)
    return torch.mean(batch_errors)
def Visualization(fake_pic):
    """Convert a channel-first image tensor in [-1, 1] to a numpy image.

    @param fake_pic: A 3-D tensor, channels first (C first; H/W order as
        produced by the model — see transpose below).
    @return: A channels-last numpy array scaled to the [0, 255] pixel range.
    """
    img = fake_pic.cpu().detach().numpy()
    # Move channels to the last axis for plotting (from channel-first layout).
    img = np.transpose(img, (1, 2, 0))
    return 127.5 * (img + 1)
def stack_Visualization(real_frame, fake_frame):
    """Build a comparison strip: real frame | fake frame | absolute difference.

    @param real_frame: A 3-D channel-first tensor in [-1, 1] (ground truth).
    @param fake_frame: A 3-D channel-first tensor in [-1, 1] (generated).
    @return: A numpy image with the three panels concatenated along axis 1.
    """
    panels = [Visualization(real_frame), Visualization(fake_frame)]
    # Third panel highlights where the generated frame deviates from the real one.
    panels.append(abs(panels[0] - panels[1]))
    return np.concatenate(panels, axis=1)