Skip to content

Commit 1612dc7

Browse files
mingxu1067ashors1
authored andcommitted
Set deterministic=True for inference.
1 parent 9d6b6db commit 1612dc7

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

paxml/contrib/gpu/scripts_gpu/te_helper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,8 @@ def __call__(self,
141141
inputs=x_in,
142142
attention_mask=attention_mask,
143143
encoded=cross_inputs,
144-
encoder_decoder_mask=cross_attention_mask)
144+
encoder_decoder_mask=cross_attention_mask,
145+
deterministic=self.do_eval)
145146
x_out = checkpoint_name(x_out, 'transformer_layer_out')
146147
return x_out
147148

0 commit comments

Comments
 (0)