-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
104 lines (84 loc) · 4.33 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
class VGGLoss(nn.Module):
    """Computes the VGG perceptual loss between two batches of images.

    The input and target must be 4D tensors with three channels
    ``(B, 3, H, W)`` and must have equivalent shapes. Pixel values should be
    normalized to the range 0-1.

    The VGG perceptual loss is the mean squared difference between the features
    computed for the input and target at layer :attr:`layer` (default 8, or
    ``relu2_2``) of the pretrained model specified by :attr:`model` (either
    ``'vgg16'`` (default) or ``'vgg19'``).

    If :attr:`shift` is nonzero, a random shift of at most :attr:`shift`
    pixels in both height and width will be applied to all images in the input
    and target. The shift will only be applied when the loss function is in
    training mode, and will not be applied if a precomputed feature map is
    supplied as the target.

    :attr:`reduction` can be set to ``'mean'``, ``'sum'``, or ``'none'``
    similarly to the loss functions in :mod:`torch.nn`. The default is
    ``'mean'``.

    :meth:`get_features()` may be used to precompute the features for the
    target, to speed up the case where inputs are compared against the same
    target over and over. To use the precomputed features, pass them in as
    :attr:`target` and set :attr:`target_is_features` to :code:`True`.

    Instances of :class:`VGGLoss` must be manually converted to the same
    device and dtype as their inputs.
    """

    # Backbone constructors selectable via the `model` argument.
    models = {'vgg16': models.vgg16, 'vgg19': models.vgg19}

    def __init__(self, model='vgg16', layer=8, shift=0, reduction='mean'):
        super().__init__()
        self.shift = shift
        self.reduction = reduction
        # ImageNet statistics expected by torchvision's pretrained VGGs.
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])
        # Keep only the feature layers up to and including `layer`.
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision
        # in favor of `weights=...`; kept as-is for compatibility with the
        # torchvision version this file currently targets.
        self.model = self.models[model](pretrained=True).features[:layer+1]
        # The feature extractor is frozen: always eval, no gradients.
        self.model.eval()
        self.model.requires_grad_(False)

    def get_features(self, input):
        """Return the backbone feature map for a 0-1-normalized image batch."""
        return self.model(self.normalize(input))

    def train(self, mode=True):
        """Set this module's training flag WITHOUT touching submodules.

        Deliberately does not call ``super().train()`` so the frozen VGG
        feature extractor stays in eval mode regardless of `mode`; the flag
        only controls whether the random shift augmentation is applied in
        :meth:`forward`. Returns ``self`` to honor the ``nn.Module.train``
        contract (the original override returned ``None``, breaking call
        chaining).
        """
        self.training = mode
        return self

    def forward(self, input, target, target_is_features=False):
        """Compute the perceptual loss between `input` and `target`.

        If `target_is_features` is true, `target` is assumed to be a feature
        map previously produced by :meth:`get_features` and is used directly.
        """
        if target_is_features:
            input_feats = self.get_features(input)
            target_feats = target
        else:
            sep = input.shape[0]
            # Run input and target through the backbone as one batch so the
            # same random shift (if any) is applied to both.
            batch = torch.cat([input, target])
            if self.shift and self.training:
                padded = F.pad(batch, [self.shift] * 4, mode='replicate')
                batch = transforms.RandomCrop(batch.shape[2:])(padded)
            feats = self.get_features(batch)
            input_feats, target_feats = feats[:sep], feats[sep:]
        return F.mse_loss(input_feats, target_feats, reduction=self.reduction)
def get_gan_losses_fn():
    """Build the vanilla (non-saturating) GAN loss pair.

    Returns a tuple ``(d_loss_fn, g_loss_fn)``. Both losses are sigmoid
    binary cross-entropy computed on raw logits with ``reduction='sum'``.
    """
    def _bce_sum(logits, label):
        # Target tensor filled with `label`, matching the logits' shape/dtype.
        target = torch.full_like(logits, label)
        return F.binary_cross_entropy_with_logits(logits, target,
                                                  reduction="sum")

    def d_loss_fn(real_logits, fake_logits):
        # Discriminator: push real logits toward 1 and fake logits toward 0.
        return _bce_sum(real_logits, 1.0) + _bce_sum(fake_logits, 0.0)

    def g_loss_fn(fake_logits):
        # Generator: wants its fakes classified as real (label 1).
        return _bce_sum(fake_logits, 1.0)

    return d_loss_fn, g_loss_fn
def crop(x, x_left_pos, x_right_pos, crop_h=30, crop_w=50, out_size=256 // 4):
    """Extract and resize the left/right eye patches from a batch of faces.

    Args:
        x: image batch ``[B, 3, H, W]`` — assumes 64x64 faces per the original
           comment, but any H/W large enough for the crops works. (TODO confirm
           against caller.)
        x_left_pos: per-image left-eye centers; ``x_left_pos[i]`` is ``(x, y)``
           pixel coordinates (column, row).
        x_right_pos: per-image right-eye centers, same layout.
        crop_h, crop_w: size of the raw crop around each eye center
           (previously hard-coded to 30x50).
        out_size: side length of the square resized output patches
           (previously hard-coded to 256 // 4 = 64).

    Returns:
        ``(left_eyes, right_eyes)``, each ``[B, 3, out_size, out_size]`` on
        the same device as `x`.
    """
    # Allocate the crop buffers on x's device/dtype. The original always
    # allocated float32 CPU tensors, so the per-row assignments below raised
    # a device/dtype mismatch whenever x was on GPU or in half precision.
    x_left_eye = torch.zeros((len(x_left_pos), 3, crop_h, crop_w),
                             device=x.device, dtype=x.dtype)
    x_right_eye = torch.zeros((len(x_right_pos), 3, crop_h, crop_w),
                              device=x.device, dtype=x.dtype)
    for i in range(len(x_left_pos)):
        # crop(top, left, h, w): center the window on the (x, y) eye position.
        x_left_eye[i] = transforms.functional.crop(
            x[i], x_left_pos[i][1] - crop_h // 2, x_left_pos[i][0] - crop_w // 2,
            crop_h, crop_w)
        x_right_eye[i] = transforms.functional.crop(
            x[i], x_right_pos[i][1] - crop_h // 2, x_right_pos[i][0] - crop_w // 2,
            crop_h, crop_w)
    resize = transforms.Resize((out_size, out_size))
    # Buffers already live on x.device, so no trailing .to(device) is needed.
    return resize(x_left_eye), resize(x_right_eye)
def add_sp(m):
    """Attach spectral normalization to a conv or linear layer, in place.

    Designed to be passed to ``module.apply(add_sp)``; modules of any other
    type are left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.utils.parametrizations.spectral_norm(m)