Skip to content

Commit a53fd87

Browse files
committed
minor modification
1 parent b800d23 commit a53fd87

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

simple_realize/transformer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def call(self, inputs, mask=None, **kwargs):
4242
s = q @ tf.transpose(k, [0, 1, 3, 2]) / (tf.math.sqrt(self.k_f) + 1e-8)
4343
if mask is not None:
4444
s += mask * -1e9
45-
a = tf.nn.softmax(s) # [b,h,attention,s]
45+
a = tf.nn.softmax(s)
4646
self.attention = a
4747
b = a @ v
4848
o = tf.concat(tf.unstack(b, axis=1), 2) @ self.wo
@@ -188,7 +188,7 @@ def call(self, inputs, training=None, **kwargs):
188188
pad_mask = self._pad_mask(x)
189189
encoded_z = self.encoder(x_embed, mask=pad_mask)
190190
decoded_z = self.decoder(
191-
(encoded_z, y_embed), look_ahead_mask=self._look_ahead_mask(x), pad_mask=pad_mask)
191+
(encoded_z, y_embed), look_ahead_mask=self._look_ahead_mask(y), pad_mask=pad_mask)
192192
o = self.o(decoded_z)
193193
return o
194194

@@ -213,7 +213,7 @@ def translate(self, src, i2v, v2i):
213213
y = tgt[:, :-1]
214214
y_embed = self.embed(y)
215215
decoded_z = self.decoder(
216-
(encoded_z, y_embed), look_ahead_mask=self._look_ahead_mask(src_pad), pad_mask=self._pad_mask(src_pad))
216+
(encoded_z, y_embed), look_ahead_mask=self._look_ahead_mask(y), pad_mask=self._pad_mask(src_pad))
217217
logit = self.o(decoded_z)[:, tgti, :].numpy()
218218
idx = np.argmax(logit, 1)
219219
tgti += 1

0 commit comments

Comments
 (0)