@@ -42,7 +42,7 @@ def call(self, inputs, mask=None, **kwargs):
         s = q @ tf.transpose(k, [0, 1, 3, 2]) / (tf.math.sqrt(self.k_f) + 1e-8)
         if mask is not None:
             s += mask * -1e9
-        a = tf.nn.softmax(s)  # [b,h,attention,s]
+        a = tf.nn.softmax(s)
         self.attention = a
         b = a @ v
         o = tf.concat(tf.unstack(b, axis=1), 2) @ self.wo
@@ -188,7 +188,7 @@ def call(self, inputs, training=None, **kwargs):
         pad_mask = self._pad_mask(x)
         encoded_z = self.encoder(x_embed, mask=pad_mask)
         decoded_z = self.decoder(
-            (encoded_z, y_embed), look_ahead_mask=self._look_ahead_mask(x), pad_mask=pad_mask)
+            (encoded_z, y_embed), look_ahead_mask=self._look_ahead_mask(y), pad_mask=pad_mask)
         o = self.o(decoded_z)
         return o
@@ -213,7 +213,7 @@ def translate(self, src, i2v, v2i):
             y = tgt[:, :-1]
             y_embed = self.embed(y)
             decoded_z = self.decoder(
-                (encoded_z, y_embed), look_ahead_mask=self._look_ahead_mask(src_pad), pad_mask=self._pad_mask(src_pad))
+                (encoded_z, y_embed), look_ahead_mask=self._look_ahead_mask(y), pad_mask=self._pad_mask(src_pad))
             logit = self.o(decoded_z)[:, tgti, :].numpy()
             idx = np.argmax(logit, 1)
             tgti += 1
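
For context on why both hunks switch the mask argument to the target sequence: the look-ahead mask has to be square in the decoder's target length so it can be added to the [batch, heads, tgt_len, tgt_len] attention scores before the softmax, which is why it must be built from y rather than from x or src_pad. Below is a minimal sketch of such a causal mask in TensorFlow; the function name causal_mask and the 1-means-masked convention are illustrative assumptions, not the repository's _look_ahead_mask implementation.

# Illustrative sketch only: a causal (look-ahead) mask built from the target
# sequence, following the common TensorFlow pattern. Names and the masking
# convention (1 = masked future position) are assumptions for illustration.
import tensorflow as tf

def causal_mask(y):
    """Return a [1, 1, tgt_len, tgt_len] mask with 1s above the diagonal."""
    seq_len = tf.shape(y)[1]                       # target length, not source length
    ones = tf.ones((seq_len, seq_len))
    mask = 1.0 - tf.linalg.band_part(ones, -1, 0)  # strictly upper triangle (future tokens) = 1
    return mask[tf.newaxis, tf.newaxis, :, :]      # broadcast over batch and heads

Combined with the s += mask * -1e9 line in the first hunk, positions where this mask is 1 receive a large negative score and are effectively zeroed out by the softmax.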