-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlosses.py
42 lines (38 loc) · 1.57 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch.nn as nn
import torch
class GANLoss(nn.Module):
    """Adversarial loss supporting several GAN objectives.

    Supported ``gan_mode`` values:

    * ``'lsgan'``   -- least-squares GAN: ``nn.MSELoss`` between the prediction
      and a constant label tensor.
    * ``'vanilla'`` -- original GAN: ``nn.BCEWithLogitsLoss`` between the
      prediction and a constant label tensor, so predictions are treated as
      raw logits.
    * ``'wgangp'``  -- Wasserstein term only: ``-mean(prediction)`` for real
      targets and ``mean(prediction)`` for fake targets. No gradient penalty
      is computed by this class.

    Args:
        gan_mode: one of ``'lsgan'``, ``'vanilla'``, ``'wgangp'``.
        target_real_label: value used for the "real" label tensor.
        target_fake_label: value used for the "fake" label tensor.

    Raises:
        NotImplementedError: if ``gan_mode`` is not a supported mode.
    """

    def __init__(self, gan_mode='lsgan', target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        # Registered as buffers so they follow the module through .to()/.cuda()
        # and are saved in state_dict.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'wgangp':
            # The Wasserstein term is computed directly in forward().
            self.loss = None
        else:
            raise NotImplementedError(f'gan mode {gan_mode} not implemented')

    def get_device(self):
        """Set ``self.device`` to ``'cuda:0'`` if CUDA is available, else ``'cpu:0'``.

        Retained for backward compatibility only. Label placement now follows
        the prediction tensor in :meth:`get_target_tensor`, so this method is
        no longer called internally.
        """
        if torch.cuda.is_available():
            self.device = 'cuda:0'
        else:
            self.device = 'cpu:0'

    def get_target_tensor(self, prediction, target_is_real):
        """Return the real/fake label broadcast to the shape of ``prediction``.

        Args:
            prediction: discriminator output tensor.
            target_is_real: truthy selects the real label, falsy the fake label.

        Returns:
            A view of the label tensor with ``prediction``'s shape, device and
            dtype.
        """
        label = self.real_label if target_is_real else self.fake_label
        # Match the prediction's device *and* dtype rather than hard-coding
        # 'cuda:0'. The old approach failed for CPU predictions on CUDA
        # machines, for non-default GPUs, and for half/double predictions.
        return label.to(prediction).expand_as(prediction)

    def forward(self, prediction, target_is_real):
        """Compute the adversarial loss for ``prediction``.

        Defined as ``forward`` (not an override of ``__call__``) so that
        ``nn.Module`` hooks and call machinery keep working. Call the module
        directly: ``criterion(prediction, target_is_real)``.

        Args:
            prediction: discriminator output tensor.
            target_is_real: whether the target for this prediction is "real".

        Returns:
            A scalar loss tensor.

        Raises:
            NotImplementedError: if ``self.gan_mode`` is not a supported mode.
        """
        if self.gan_mode in ('lsgan', 'vanilla'):
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            return self.loss(prediction, target_tensor)
        if self.gan_mode == 'wgangp':
            # Minimizing -E[D(real)] raises real scores; E[D(fake)] lowers fake ones.
            return -prediction.mean() if target_is_real else prediction.mean()
        # Only reachable if gan_mode was changed after construction.
        raise NotImplementedError(f'gan mode {self.gan_mode} not implemented')