From 2dda7e3918b07914e26485806f50c229e81d1e44 Mon Sep 17 00:00:00 2001 From: Jessie Pathfinder <55774978+jessiepathfinder@users.noreply.github.com> Date: Thu, 5 Oct 2023 07:43:55 +0200 Subject: [PATCH] Deduplicated duplicate subtraction --- src/model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/model.py b/src/model.py index 230b83cc2..4b9266279 100644 --- a/src/model.py +++ b/src/model.py @@ -32,8 +32,9 @@ def norm(x, scope, *, axis=-1, epsilon=1e-5): g = tf.get_variable('g', [n_state], initializer=tf.constant_initializer(1)) b = tf.get_variable('b', [n_state], initializer=tf.constant_initializer(0)) u = tf.reduce_mean(x, axis=axis, keepdims=True) - s = tf.reduce_mean(tf.square(x-u), axis=axis, keepdims=True) - x = (x - u) * tf.rsqrt(s + epsilon) + n = x-u + s = tf.reduce_mean(tf.square(n), axis=axis, keepdims=True) + x = n * tf.rsqrt(s + epsilon) x = x*g + b return x