We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9d6b6db commit 1612dc7Copy full SHA for 1612dc7
paxml/contrib/gpu/scripts_gpu/te_helper.py
@@ -141,7 +141,8 @@ def __call__(self,
141
inputs=x_in,
142
attention_mask=attention_mask,
143
encoded=cross_inputs,
144
- encoder_decoder_mask=cross_attention_mask)
+ encoder_decoder_mask=cross_attention_mask,
145
+ deterministic=self.do_eval)
146
x_out = checkpoint_name(x_out, 'transformer_layer_out')
147
return x_out
148
0 commit comments