Skip to content

Commit bf3e171

Browse files
authored
Update reference scale calculation in TensorFlow test (#463)
Signed-off-by: Tim Moon <[email protected]>
1 parent 38b85c3 commit bf3e171

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

tests/tensorflow/test_layers.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,9 @@ def get_fp8_recipe(override_wgrad=False):
3434

3535
def compute_scale(amax, scale, fp8_max, margin):
3636
"""Default function to convert amax to scaling factor."""
37-
exp = tf.math.floor(tf.experimental.numpy.log2(fp8_max / amax)) - margin
38-
sf = tf.math.round(tf.math.pow(2., tf.math.abs(exp)))
37+
sf = (fp8_max / amax) / (2 ** margin)
3938
sf = tf.where(amax > 0.0, sf, scale)
4039
sf = tf.where(tf.math.is_finite(amax), sf, scale)
41-
sf = tf.where(exp < 0, 1.0 / sf, sf)
4240
return sf
4341

4442

0 commit comments

Comments
 (0)