# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adversarial losses for text models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

from six.moves import xrange
import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS

# Adversarial and virtual adversarial training parameters.
flags.DEFINE_float('perturb_norm_length', 5.0,
                   'Norm length of the adversarial perturbation; tune on a '
                   'validation set. 5.0 is optimal on IMDB with virtual '
                   'adversarial training.')

# Virtual adversarial training parameters
flags.DEFINE_integer('num_power_iteration', 1,
                     'Number of power iterations used to approximate the '
                     'virtual adversarial perturbation.')
flags.DEFINE_float('small_constant_for_finite_diff', 1e-1,
                   'Small constant used by the finite difference method.')

# Parameters for building the graph
flags.DEFINE_string('adv_training_method', None,
                    'The training method to use: '
                    '"" : non-adversarial training (e.g. for running the '
                    'semi-supervised sequence learning model); '
                    '"rp" : random perturbation training; '
                    '"at" : adversarial training; '
                    '"vat" : virtual adversarial training; '
                    '"atvat" : at + vat.')
flags.DEFINE_float('adv_reg_coeff', 1.0,
                   'Regularization coefficient of adversarial loss.')
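

# The following is an illustrative sketch, not part of the original module: it
# shows one way the `adv_training_method` flag could be dispatched to the loss
# functions defined below. The arguments (`embedded`, `length`, `cl_loss`,
# `loss_fn`, `logits`, `inputs`, `logits_from_embedding_fn`) are hypothetical
# placeholders for tensors and callables built by the model code; the actual
# wiring in this project may differ.
def _example_adv_regularizer(embedded, length, cl_loss, loss_fn, logits,
                             inputs, logits_from_embedding_fn):
  """Illustrative dispatcher returning the extra regularization term."""
  method = FLAGS.adv_training_method
  if method == 'rp':
    return random_perturbation_loss(embedded, length, loss_fn)
  elif method == 'at':
    return adversarial_loss(embedded, cl_loss, loss_fn)
  elif method == 'vat':
    return virtual_adversarial_loss(logits, embedded, inputs,
                                    logits_from_embedding_fn)
  elif method == 'atvat':
    return (adversarial_loss(embedded, cl_loss, loss_fn) +
            virtual_adversarial_loss(logits, embedded, inputs,
                                     logits_from_embedding_fn))
  return tf.constant(0.0)  # non-adversarial training: no extra term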


def random_perturbation_loss(embedded, length, loss_fn):
"""Adds noise to embeddings and recomputes classification loss."""
noise = tf.random_normal(shape=tf.shape(embedded))
perturb = _scale_l2(_mask_by_length(noise, length), FLAGS.perturb_norm_length)
return loss_fn(embedded + perturb)


def adversarial_loss(embedded, loss, loss_fn):
"""Adds gradient to embedding and recomputes classification loss."""
grad, = tf.gradients(
loss,
embedded,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
grad = tf.stop_gradient(grad)
perturb = _scale_l2(grad, FLAGS.perturb_norm_length)
return loss_fn(embedded + perturb)
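

# A minimal usage sketch (an assumption about the surrounding training code,
# not taken from this file): the term returned by `adversarial_loss` is
# typically scaled by `adv_reg_coeff` and added to the ordinary classification
# loss. `cl_loss` and `loss_fn` are hypothetical stand-ins for the model's
# cross-entropy loss tensor and for a closure that recomputes that loss from
# perturbed embeddings.
def _example_total_loss(embedded, cl_loss, loss_fn):
  """Illustrative: classification loss plus weighted adversarial loss."""
  adv_loss = adversarial_loss(embedded, cl_loss, loss_fn)
  return cl_loss + FLAGS.adv_reg_coeff * adv_loss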


def virtual_adversarial_loss(logits, embedded, inputs,
logits_from_embedding_fn):
"""Virtual adversarial loss.
Computes virtual adversarial perturbation by finite difference method and
power iteration, adds it to the embedding, and computes the KL divergence
between the new logits and the original logits.
Args:
logits: 3-D float Tensor, [batch_size, num_timesteps, m], where m=1 if
num_classes=2, otherwise m=num_classes.
embedded: 3-D float Tensor, [batch_size, num_timesteps, embedding_dim].
inputs: VatxtInput.
logits_from_embedding_fn: callable that takes embeddings and returns
classifier logits.
Returns:
kl: float scalar.
"""
# Stop gradient of logits. See https://arxiv.org/abs/1507.00677 for details.
logits = tf.stop_gradient(logits)
# Only care about the KL divergence on the final timestep.
weights = inputs.eos_weights
assert weights is not None
if FLAGS.single_label:
indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
weights = tf.expand_dims(tf.gather_nd(inputs.eos_weights, indices), 1)
# Initialize perturbation with random noise.
# shape(embedded) = (batch_size, num_timesteps, embedding_dim)
d = tf.random_normal(shape=tf.shape(embedded))
  # Perform finite difference method and power iteration.
  # See Eq.(8) in the paper http://arxiv.org/pdf/1507.00677.pdf:
  # adding small noise to the input and taking the gradient with respect to
  # the noise corresponds to one power iteration.
for _ in xrange(FLAGS.num_power_iteration):
d = _scale_l2(
_mask_by_length(d, inputs.length), FLAGS.small_constant_for_finite_diff)
d_logits = logits_from_embedding_fn(embedded + d)
kl = _kl_divergence_with_logits(logits, d_logits, weights)
d, = tf.gradients(
kl,
d,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
d = tf.stop_gradient(d)
perturb = _scale_l2(d, FLAGS.perturb_norm_length)
vadv_logits = logits_from_embedding_fn(embedded + perturb)
return _kl_divergence_with_logits(logits, vadv_logits, weights)
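
# Sketch of the computation above in equation form (our paraphrase of Miyato
# et al., https://arxiv.org/abs/1507.00677; notation is ours):
#   d_0      ~  N(0, I)
#   d_{k+1}  =  grad_d KL( p(. | e) || p(. | e + d) ),
#               evaluated at d = xi * mask(d_k) / ||mask(d_k)||_2
#   r_vadv   =  epsilon * d_K / ||d_K||_2
#   loss     =  KL( p(. | e) || p(. | e + r_vadv) )
# where e is `embedded`, mask zeroes out eos/padding timesteps, xi is
# FLAGS.small_constant_for_finite_diff, epsilon is FLAGS.perturb_norm_length,
# and K is FLAGS.num_power_iteration.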


def random_perturbation_loss_bidir(embedded, length, loss_fn):
"""Adds noise to embeddings and recomputes classification loss."""
noise = [tf.random_normal(shape=tf.shape(emb)) for emb in embedded]
masked = [_mask_by_length(n, length) for n in noise]
scaled = [_scale_l2(m, FLAGS.perturb_norm_length) for m in masked]
return loss_fn([e + s for (e, s) in zip(embedded, scaled)])


def adversarial_loss_bidir(embedded, loss, loss_fn):
"""Adds gradient to embeddings and recomputes classification loss."""
grads = tf.gradients(
loss,
embedded,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
adv_exs = [
emb + _scale_l2(tf.stop_gradient(g), FLAGS.perturb_norm_length)
for emb, g in zip(embedded, grads)
]
return loss_fn(adv_exs)


def virtual_adversarial_loss_bidir(logits, embedded, inputs,
logits_from_embedding_fn):
"""Virtual adversarial loss for bidirectional models."""
logits = tf.stop_gradient(logits)
f_inputs, _ = inputs
weights = f_inputs.eos_weights
if FLAGS.single_label:
indices = tf.stack([tf.range(FLAGS.batch_size), f_inputs.length - 1], 1)
weights = tf.expand_dims(tf.gather_nd(f_inputs.eos_weights, indices), 1)
assert weights is not None
perturbs = [
_mask_by_length(tf.random_normal(shape=tf.shape(emb)), f_inputs.length)
for emb in embedded
]
for _ in xrange(FLAGS.num_power_iteration):
perturbs = [
_scale_l2(d, FLAGS.small_constant_for_finite_diff) for d in perturbs
]
d_logits = logits_from_embedding_fn(
[emb + d for (emb, d) in zip(embedded, perturbs)])
kl = _kl_divergence_with_logits(logits, d_logits, weights)
perturbs = tf.gradients(
kl,
perturbs,
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
perturbs = [tf.stop_gradient(d) for d in perturbs]
perturbs = [_scale_l2(d, FLAGS.perturb_norm_length) for d in perturbs]
vadv_logits = logits_from_embedding_fn(
[emb + d for (emb, d) in zip(embedded, perturbs)])
return _kl_divergence_with_logits(logits, vadv_logits, weights)
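
# Note on the bidirectional variants above (based on how they are written,
# e.g. `f_inputs, _ = inputs`): `embedded` is a list of embedding tensors, one
# per direction, and `inputs` is the matching pair of input objects; only the
# forward inputs supply the lengths and eos weights used for masking and for
# weighting the KL term.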


def _mask_by_length(t, length):
"""Mask t, 3-D [batch, time, dim], by length, 1-D [batch,]."""
maxlen = t.get_shape().as_list()[1]
  # Subtract 1 from length so the perturbation never touches the 'eos' token.
mask = tf.sequence_mask(length - 1, maxlen=maxlen)
mask = tf.expand_dims(tf.cast(mask, tf.float32), -1)
# shape(mask) = (batch, num_timesteps, 1)
return t * mask
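
# Worked example for the masking above (illustrative values, not from the
# original file): with length = [3, 2] and maxlen = 4,
# tf.sequence_mask(length - 1, maxlen=4) gives
#   [[True, True, False, False],
#    [True, False, False, False]]
# so padding positions and the final ('eos') token of each sequence receive no
# perturbation.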


def _scale_l2(x, norm_length):
# shape(x) = (batch, num_timesteps, d)
# Divide x by max(abs(x)) for a numerically stable L2 norm.
# 2norm(x) = a * 2norm(x/a)
# Scale over the full sequence, dims (1, 2)
alpha = tf.reduce_max(tf.abs(x), (1, 2), keep_dims=True) + 1e-12
l2_norm = alpha * tf.sqrt(
tf.reduce_sum(tf.pow(x / alpha, 2), (1, 2), keep_dims=True) + 1e-6)
x_unit = x / l2_norm
return norm_length * x_unit
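

# Note on the scaling above: for any a > 0, ||x||_2 = a * ||x / a||_2, so
# dividing by alpha = max|x| first keeps the squared terms well-conditioned
# before the square root; the 1e-12 and 1e-6 constants guard against all-zero
# inputs. A hypothetical sanity check (not part of the original file):
def _example_check_scale_l2(norm_length=5.0):
  """Illustrative: the result should have L2 norm close to norm_length."""
  x = tf.random_normal([2, 5, 3])
  scaled = _scale_l2(x, norm_length)
  # Norm taken over the timestep and embedding dims, matching _scale_l2.
  return tf.sqrt(tf.reduce_sum(tf.square(scaled), axis=(1, 2)))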


def _kl_divergence_with_logits(q_logits, p_logits, weights):
"""Returns weighted KL divergence between distributions q and p.
Args:
q_logits: logits for 1st argument of KL divergence shape
[batch_size, num_timesteps, num_classes] if num_classes > 2, and
[batch_size, num_timesteps] if num_classes == 2.
p_logits: logits for 2nd argument of KL divergence with same shape q_logits.
weights: 1-D float tensor with shape [batch_size, num_timesteps].
Elements should be 1.0 only on end of sequences
Returns:
KL: float scalar.
"""
# For logistic regression
if FLAGS.num_classes == 2:
q = tf.nn.sigmoid(q_logits)
kl = (-tf.nn.sigmoid_cross_entropy_with_logits(logits=q_logits, labels=q) +
tf.nn.sigmoid_cross_entropy_with_logits(logits=p_logits, labels=q))
kl = tf.squeeze(kl, 2)
# For softmax regression
else:
q = tf.nn.softmax(q_logits)
kl = tf.reduce_sum(
q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), -1)
num_labels = tf.reduce_sum(weights)
num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels)
kl.get_shape().assert_has_rank(2)
weights.get_shape().assert_has_rank(2)
loss = tf.identity(tf.reduce_sum(weights * kl) / num_labels, name='kl')
return loss
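
# Derivation note for the sigmoid branch above (our paraphrase, writing
# q = sigmoid(q_logits), p = sigmoid(p_logits), and xent(l, q) for
# tf.nn.sigmoid_cross_entropy_with_logits(logits=l, labels=q)):
#   xent(l, q) = -q * log(sigmoid(l)) - (1 - q) * log(1 - sigmoid(l))
# so
#   -xent(q_logits, q) + xent(p_logits, q)
#     = q * log(q / p) + (1 - q) * log((1 - q) / (1 - p))
#     = KL(Bernoulli(q) || Bernoulli(p)),
# i.e. the per-timestep binary KL, mirroring the explicit softmax KL in the
# multi-class branch.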