"""GAN loss functions for the discriminator and the generator."""

import tensorflow as tf
def discriminator_loss(loss_func: str, real, fake, use_ra: bool = False):
    """Compute the discriminator loss for the selected GAN objective.

    Args:
        loss_func: Name of the objective. Any name containing ``"wgan"``
            selects the Wasserstein critic loss; ``"lsgan"``, ``"gan"``,
            ``"gan-gp"``, ``"dragan"`` and ``"hinge"`` are matched exactly.
        real: Discriminator outputs for real samples (passed as logits to
            the sigmoid cross-entropy in the ``gan`` family).
        fake: Discriminator outputs for generated samples.
        use_ra: When true, each output is shifted by the mean of the
            opposite group before the loss is evaluated (presumably the
            "relativistic average" formulation -- confirm with callers).
            Ignored for wgan objectives.

    Returns:
        ``real_loss + fake_loss``. When ``loss_func`` matches no known
        objective, both terms keep their initial value and the plain float
        ``0.0`` is returned.
    """
    real_loss: float = .0
    fake_loss: float = .0

    is_wgan = "wgan" in loss_func

    if use_ra and not is_wgan:
        # Both shifts must be derived from the *unshifted* outputs. Updating
        # `real` first and then taking its mean would shift `fake` by
        # mean(real) - mean(fake) instead of mean(real).
        real_mean = tf.reduce_mean(real)
        fake_mean = tf.reduce_mean(fake)
        real, fake = real - fake_mean, fake - real_mean

    if is_wgan:
        real_loss = -tf.reduce_mean(real)
        fake_loss = tf.reduce_mean(fake)
    elif loss_func == "lsgan":
        # tf.square(x - 1.) equals squared_difference(x, 1.) and, unlike the
        # top-level tf.squared_difference, exists in both TF1 and TF2.
        real_loss = tf.reduce_mean(tf.square(real - 1.))
        fake_loss = tf.reduce_mean(tf.square(fake))
    elif loss_func in ("gan", "gan-gp", "dragan"):
        # "gan-gp" is accepted here so that both loss functions in this
        # module agree on the set of supported objectives.
        real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(real), logits=real))
        fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(fake), logits=fake))
    elif loss_func == "hinge":
        real_loss = tf.reduce_mean(tf.nn.relu(1. - real))
        fake_loss = tf.reduce_mean(tf.nn.relu(1. + fake))

    return real_loss + fake_loss
def generator_loss(loss_func: str, real, fake, use_ra: bool = False):
    """Compute the generator loss for the selected GAN objective.

    Args:
        loss_func: Name of the objective. Any name containing ``"wgan"``
            selects the Wasserstein loss; ``"lsgan"``, ``"gan"``,
            ``"gan-gp"``, ``"dragan"`` and ``"hinge"`` are matched exactly.
        real: Discriminator outputs for real samples. Only used when the
            mean-shifted (``use_ra``) formulation is active.
        fake: Discriminator outputs for generated samples.
        use_ra: When true, each output is shifted by the mean of the
            opposite group and a loss term on ``real`` is added
            (presumably the "relativistic average" formulation -- confirm
            with callers). Ignored for wgan objectives, matching
            ``discriminator_loss``.

    Returns:
        ``fake_loss + real_loss``; ``real_loss`` stays ``0.0`` unless the
        ``use_ra`` formulation is active. When ``loss_func`` matches no
        known objective the plain float ``0.0`` is returned.
    """
    fake_loss: float = .0
    real_loss: float = .0

    is_wgan = "wgan" in loss_func

    if use_ra and not is_wgan:
        fake_logit = fake - tf.reduce_mean(real)
        real_logit = real - tf.reduce_mean(fake)

        if loss_func == "lsgan":
            fake_loss = tf.reduce_mean(tf.square(fake_logit - 1.))
            real_loss = tf.reduce_mean(tf.square(real_logit + 1.))
        elif loss_func in ("gan", "gan-gp", "dragan"):
            # Labels are swapped relative to the discriminator: the shifted
            # fake outputs are pushed towards 1 and the real ones towards 0.
            fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.ones_like(fake), logits=fake_logit))
            real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.zeros_like(real), logits=real_logit))
        elif loss_func == "hinge":
            fake_loss = tf.reduce_mean(tf.nn.relu(1. - fake_logit))
            real_loss = tf.reduce_mean(tf.nn.relu(1. + real_logit))
    else:
        # Also reached for wgan objectives when use_ra is set. Previously
        # that combination matched no branch and silently produced a
        # constant 0.0 loss.
        if is_wgan or loss_func == "hinge":
            fake_loss = -tf.reduce_mean(fake)
        elif loss_func == "lsgan":
            fake_loss = tf.reduce_mean(tf.square(fake - 1.0))
        elif loss_func in ("gan", "gan-gp", "dragan"):
            fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.ones_like(fake), logits=fake))

    return fake_loss + real_loss